Skip to content

Commit acefee4

Browse files
authored
dice_similarity_coefficient_with_logits(): RuntimeError: Boolean value of Tensor with more than one value is ambiguous (#210)
* Fix ambiguity. (#209)
* Add `dice_similarity_coefficient_with_logits_clip` and update usages.
* Remove `min_percentage_per_class` from loss and metric functions. (#209)
1 parent 958cf76 commit acefee4

File tree

4 files changed

+29
-30
lines changed

4 files changed

+29
-30
lines changed

mipcandy/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
from mipcandy.layer import batch_int_multiply, batch_int_divide, LayerT, HasDevice, auto_device, WithPaddingModule, \
99
WithNetwork
1010
from mipcandy.metrics import do_reduction, dice_similarity_coefficient_binary, \
11-
dice_similarity_coefficient_multiclass, dice_similarity_coefficient_with_logits, soft_dice_coefficient, \
12-
accuracy_binary, accuracy_multiclass, precision_binary, precision_multiclass, recall_binary, recall_multiclass, \
13-
iou_binary, iou_multiclass
11+
dice_similarity_coefficient_multiclass, dice_similarity_coefficient_with_logits, \
12+
dice_similarity_coefficient_with_logits_clip, soft_dice_coefficient, accuracy_binary, accuracy_multiclass, \
13+
precision_binary, precision_multiclass, recall_binary, recall_multiclass, iou_binary, iou_multiclass
1414
from mipcandy.presets import *
1515
from mipcandy.profiler import ProfilerFrame, Profiler
1616
from mipcandy.run import config

mipcandy/common/optim/loss.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from mipcandy.data import convert_ids_to_logits, convert_logits_to_ids
77
from mipcandy.metrics import do_reduction, soft_dice_coefficient, dice_similarity_coefficient_binary, \
8-
dice_similarity_coefficient_with_logits
8+
dice_similarity_coefficient_with_logits_clip
99

1010

1111
class FocalBCEWithLogits(nn.Module):
@@ -56,21 +56,19 @@ def forward(self, outputs: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Te
5656
class_dice = dice_similarity_coefficient_binary(outputs == i, labels == i).item()
5757
dice += class_dice
5858
metrics[f"dice {i}"] = class_dice
59-
metrics["dice"] = dice_similarity_coefficient_with_logits(
59+
metrics["dice"] = dice_similarity_coefficient_with_logits_clip(
6060
self.logitfy_no_grad(outputs), self.logitfy_no_grad(labels)
6161
).item()
6262
return c, metrics
6363

6464

6565
class DiceCELossWithLogits(_SegmentationLoss):
6666
def __init__(self, num_classes: int, *, lambda_ce: float = 1, lambda_soft_dice: float = 1,
67-
smooth: float = 1e-5, include_background: bool = True,
68-
min_percentage_per_class: float | None = None) -> None:
67+
smooth: float = 1e-5, include_background: bool = True) -> None:
6968
super().__init__(num_classes, include_background)
7069
self.lambda_ce: float = lambda_ce
7170
self.lambda_soft_dice: float = lambda_soft_dice
7271
self.smooth: float = smooth
73-
self.min_percentage_per_class: float | None = min_percentage_per_class
7472

7573
def _forward(self, outputs: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, dict[str, float]]:
7674
ce = nn.functional.cross_entropy(outputs, labels[:, 0].long())
@@ -79,8 +77,7 @@ def _forward(self, outputs: torch.Tensor, labels: torch.Tensor) -> tuple[torch.T
7977
if not self.include_background:
8078
outputs = outputs[:, 1:]
8179
labels = labels[:, 1:]
82-
soft_dice = soft_dice_coefficient(outputs, labels, smooth=self.smooth,
83-
min_percentage_per_class=self.min_percentage_per_class)
80+
soft_dice = soft_dice_coefficient(outputs, labels, smooth=self.smooth)
8481
metrics = {"soft dice": soft_dice.item(), "ce loss": ce.item()}
8582
c = self.lambda_ce * ce + self.lambda_soft_dice * (1 - soft_dice)
8683
return c, metrics

mipcandy/metrics.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -67,21 +67,33 @@ def dice_similarity_coefficient_multiclass(outputs: torch.Tensor, labels: torch.
6767
return apply_multiclass_to_binary(dice_similarity_coefficient_binary, outputs, labels, num_classes, if_empty)
6868

6969

70-
def dice_similarity_coefficient_with_logits(outputs: torch.Tensor, labels: torch.Tensor, *,
71-
if_empty: float = 1) -> torch.Tensor:
70+
def _dice_with_logits(outputs: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
7271
_args_check(outputs, labels, dtype=torch.float)
7372
axes = tuple(range(2, outputs.ndim))
7473
tp = (outputs * labels).sum(axes)
7574
fp = (outputs * (1 - labels)).sum(axes)
7675
fn = ((1 - outputs) * labels).sum(axes)
77-
volume_sum = 2 * tp + fp + fn
78-
if volume_sum == 0:
76+
return tp, 2 * tp + fp + fn
77+
78+
79+
def dice_similarity_coefficient_with_logits(outputs: torch.Tensor, labels: torch.Tensor, *,
80+
if_empty: float = 1) -> torch.Tensor:
81+
tp, volume_sum = _dice_with_logits(outputs, labels)
82+
if (volume_sum == 0).any():
7983
return torch.tensor(if_empty, dtype=torch.float)
80-
return 2 * tp / volume_sum
84+
dice = 2 * tp / volume_sum
85+
return dice.mean()
86+
87+
88+
def dice_similarity_coefficient_with_logits_clip(outputs: torch.Tensor, labels: torch.Tensor, *,
89+
clip_min: float = 1e-8) -> torch.Tensor:
90+
tp, volume_sum = _dice_with_logits(outputs, labels)
91+
dice = 2 * tp / torch.clip(volume_sum, clip_min)
92+
return dice.mean()
8193

8294

83-
def soft_dice_coefficient(outputs: torch.Tensor, labels: torch.Tensor, *, smooth: float = 1, batch_dice: bool = True,
84-
min_percentage_per_class: float | None = None) -> torch.Tensor:
95+
def soft_dice_coefficient(outputs: torch.Tensor, labels: torch.Tensor, *, smooth: float = 1, clip_min: float = 1e-8,
96+
batch_dice: bool = True) -> torch.Tensor:
8597
_args_check(outputs, labels)
8698
axes = tuple(range(2, outputs.ndim))
8799
if batch_dice:
@@ -93,16 +105,7 @@ def soft_dice_coefficient(outputs: torch.Tensor, labels: torch.Tensor, *, smooth
93105
intersection = intersection.sum(0)
94106
output_sum = output_sum.sum(0)
95107
label_sum = label_sum.sum(0)
96-
dice = (2 * intersection + smooth) / (torch.clip(label_sum + output_sum + smooth, 1e-8))
97-
if min_percentage_per_class:
98-
total = label_sum.sum()
99-
if total == 0:
100-
return torch.tensor(1, device=outputs.device, dtype=outputs.dtype)
101-
min_voxels = total * min_percentage_per_class
102-
valid = label_sum >= min_voxels
103-
if valid.any():
104-
return dice[valid].mean()
105-
return torch.tensor(1, device=outputs.device, dtype=outputs.dtype)
108+
dice = (2 * intersection + smooth) / torch.clip(label_sum + output_sum + smooth, clip_min)
106109
return dice.mean()
107110

108111

mipcandy/presets/segmentation.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,9 @@ def build_criterion(self) -> nn.Module:
9494
if self.num_classes < 2:
9595
if not self.include_background:
9696
raise ValueError("Binary segmentation models must include background class")
97-
loss = DiceBCELossWithLogits(min_percentage_per_class=1e-5)
97+
loss = DiceBCELossWithLogits()
9898
else:
99-
loss = DiceCELossWithLogits(self.num_classes, include_background=self.include_background,
100-
min_percentage_per_class=1e-5)
99+
loss = DiceCELossWithLogits(self.num_classes, include_background=self.include_background)
101100
if self.deep_supervision:
102101
if not self.deep_supervision_weights and self.deep_supervision_scales:
103102
weights = np.array([1 / (2 ** i) for i in range(len(self.deep_supervision_scales))])

0 commit comments

Comments (0)