@@ -67,21 +67,33 @@ def dice_similarity_coefficient_multiclass(outputs: torch.Tensor, labels: torch.
6767 return apply_multiclass_to_binary (dice_similarity_coefficient_binary , outputs , labels , num_classes , if_empty )
6868
6969
70- def dice_similarity_coefficient_with_logits (outputs : torch .Tensor , labels : torch .Tensor , * ,
71- if_empty : float = 1 ) -> torch .Tensor :
70+ def _dice_with_logits (outputs : torch .Tensor , labels : torch .Tensor ) -> tuple [torch .Tensor , torch .Tensor ]:
7271 _args_check (outputs , labels , dtype = torch .float )
7372 axes = tuple (range (2 , outputs .ndim ))
7473 tp = (outputs * labels ).sum (axes )
7574 fp = (outputs * (1 - labels )).sum (axes )
7675 fn = ((1 - outputs ) * labels ).sum (axes )
77- volume_sum = 2 * tp + fp + fn
78- if volume_sum == 0 :
76+ return tp , 2 * tp + fp + fn
77+
78+
79+ def dice_similarity_coefficient_with_logits (outputs : torch .Tensor , labels : torch .Tensor , * ,
80+ if_empty : float = 1 ) -> torch .Tensor :
81+ tp , volume_sum = _dice_with_logits (outputs , labels )
82+ if (volume_sum == 0 ).any ():
7983 return torch .tensor (if_empty , dtype = torch .float )
80- return 2 * tp / volume_sum
84+ dice = 2 * tp / volume_sum
85+ return dice .mean ()
86+
87+
88+ def dice_similarity_coefficient_with_logits_clip (outputs : torch .Tensor , labels : torch .Tensor , * ,
89+ clip_min : float = 1e-8 ) -> torch .Tensor :
90+ tp , volume_sum = _dice_with_logits (outputs , labels )
91+ dice = 2 * tp / torch .clip (volume_sum , clip_min )
92+ return dice .mean ()
8193
8294
83- def soft_dice_coefficient (outputs : torch .Tensor , labels : torch .Tensor , * , smooth : float = 1 , batch_dice : bool = True ,
84- min_percentage_per_class : float | None = None ) -> torch .Tensor :
95+ def soft_dice_coefficient (outputs : torch .Tensor , labels : torch .Tensor , * , smooth : float = 1 , clip_min : float = 1e-8 ,
96+ batch_dice : bool = True ) -> torch .Tensor :
8597 _args_check (outputs , labels )
8698 axes = tuple (range (2 , outputs .ndim ))
8799 if batch_dice :
@@ -93,16 +105,7 @@ def soft_dice_coefficient(outputs: torch.Tensor, labels: torch.Tensor, *, smooth
93105 intersection = intersection .sum (0 )
94106 output_sum = output_sum .sum (0 )
95107 label_sum = label_sum .sum (0 )
96- dice = (2 * intersection + smooth ) / (torch .clip (label_sum + output_sum + smooth , 1e-8 ))
97- if min_percentage_per_class :
98- total = label_sum .sum ()
99- if total == 0 :
100- return torch .tensor (1 , device = outputs .device , dtype = outputs .dtype )
101- min_voxels = total * min_percentage_per_class
102- valid = label_sum >= min_voxels
103- if valid .any ():
104- return dice [valid ].mean ()
105- return torch .tensor (1 , device = outputs .device , dtype = outputs .dtype )
108+ dice = (2 * intersection + smooth ) / torch .clip (label_sum + output_sum + smooth , clip_min )
106109 return dice .mean ()
107110
108111
0 commit comments