Skip to content

Commit edb01ce

Browse files
committed
Add sigmoid/softmax interface for AsymmetricUnifiedFocalLoss
Signed-off-by: ytl0623 <[email protected]>
1 parent 15fd428 commit edb01ce

File tree

2 files changed

+225
-124
lines changed

2 files changed

+225
-124
lines changed

monai/losses/unified_focal_loss.py

Lines changed: 149 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,13 @@
2222

2323
class AsymmetricFocalTverskyLoss(_Loss):
2424
"""
25-
AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss, which attentions to the foreground class.
25+
AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss that prioritizes the foreground classes.
2626
27-
Actually, it's only supported for binary image segmentation now.
27+
It supports both binary and multi-class segmentation.
2828
2929
Reimplementation of the Asymmetric Focal Tversky Loss described in:
30-
3130
- "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation",
32-
Michael Yeung, Computerized Medical Imaging and Graphics
31+
Michael Yeung, Computerized Medical Imaging and Graphics
3332
"""
3433

3534
def __init__(
@@ -39,119 +38,200 @@ def __init__(
3938
gamma: float = 0.75,
4039
epsilon: float = 1e-7,
4140
reduction: LossReduction | str = LossReduction.MEAN,
41+
use_softmax: bool = False,
4242
) -> None:
4343
"""
4444
Args:
4545
to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
46-
delta : weight of the background. Defaults to 0.7.
47-
gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
48-
epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
46+
delta: weight of the background class (used in the Tversky index denominator). Defaults to 0.7.
47+
gamma: focal exponent value to down-weight easy foreground examples. Defaults to 0.75.
48+
epsilon: a small value to prevent division by zero. Defaults to 1e-7.
49+
reduction: {``"none"``, ``"mean"``, ``"sum"``}
50+
Specifies the reduction to apply to the output. Defaults to ``"mean"``.
51+
use_softmax: whether to use softmax to transform original logits into probabilities.
52+
If True, softmax is used (for multi-class). If False, sigmoid is used (for binary/multi-label).
53+
Defaults to False.
4954
"""
5055
super().__init__(reduction=LossReduction(reduction).value)
5156
self.to_onehot_y = to_onehot_y
5257
self.delta = delta
5358
self.gamma = gamma
5459
self.epsilon = epsilon
60+
self.use_softmax = use_softmax
5561

5662
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
63+
"""
64+
Args:
65+
y_pred: prediction logits or probabilities. Shape should be (B, C, spatial_dims).
66+
y_true: ground truth labels. Shape should match y_pred.
67+
"""
68+
69+
# Auto-handle single channel input (binary segmentation case)
70+
if y_pred.shape[1] == 1 and not self.use_softmax:
71+
y_pred = torch.sigmoid(y_pred)
72+
y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
73+
is_already_prob = True
74+
if y_true.shape[1] == 1:
75+
y_true = one_hot(y_true, num_classes=2)
76+
else:
77+
is_already_prob = False
78+
5779
n_pred_ch = y_pred.shape[1]
5880

5981
if self.to_onehot_y:
6082
if n_pred_ch == 1:
6183
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
6284
else:
63-
y_true = one_hot(y_true, num_classes=n_pred_ch)
85+
if y_true.shape[1] != n_pred_ch:
86+
y_true = one_hot(y_true, num_classes=n_pred_ch)
6487

6588
if y_true.shape != y_pred.shape:
6689
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
6790

68-
# clip the prediction to avoid NaN
91+
# Convert logits to probabilities if not already done
92+
if not is_already_prob:
93+
if self.use_softmax:
94+
y_pred = torch.softmax(y_pred, dim=1)
95+
else:
96+
y_pred = torch.sigmoid(y_pred)
97+
98+
# Clip the prediction to avoid NaN
6999
y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
100+
70101
axis = list(range(2, len(y_pred.shape)))
71102

72103
# Calculate true positives (tp), false negatives (fn) and false positives (fp)
73104
tp = torch.sum(y_true * y_pred, dim=axis)
74105
fn = torch.sum(y_true * (1 - y_pred), dim=axis)
75106
fp = torch.sum((1 - y_true) * y_pred, dim=axis)
107+
76108
dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon)
77109

78-
# Calculate losses separately for each class, enhancing both classes
110+
# Calculate losses separately for each class
111+
# Background: Standard Dice Loss
79112
back_dice = 1 - dice_class[:, 0]
80-
fore_dice = (1 - dice_class[:, 1]) * torch.pow(1 - dice_class[:, 1], -self.gamma)
81113

82-
# Average class scores
83-
loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1))
84-
return loss
114+
# Foreground: Focal Tversky Loss
115+
fore_dice = torch.pow(1 - dice_class[:, 1:], 1 / self.gamma)
116+
117+
# Concatenate background and foreground losses
118+
# back_dice needs unsqueeze to match dimensions: (B,) -> (B, 1)
119+
all_losses = torch.cat([back_dice.unsqueeze(1), fore_dice], dim=1)
120+
121+
# Apply reduction
122+
if self.reduction == LossReduction.MEAN.value:
123+
return torch.mean(all_losses)
124+
if self.reduction == LossReduction.SUM.value:
125+
return torch.sum(all_losses)
126+
if self.reduction == LossReduction.NONE.value:
127+
return all_losses
128+
129+
return torch.mean(all_losses)
85130

86131

87132
class AsymmetricFocalLoss(_Loss):
88133
"""
89-
AsymmetricFocalLoss is a variant of FocalTverskyLoss, which attentions to the foreground class.
134+
AsymmetricFocalLoss is a variant of Focal Loss that treats background and foreground differently.
90135
91-
Actually, it's only supported for binary image segmentation now.
136+
It supports both binary and multi-class segmentation.
92137
93138
Reimplementation of the Asymmetric Focal Loss described in:
94-
95139
- "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation",
96-
Michael Yeung, Computerized Medical Imaging and Graphics
140+
Michael Yeung, Computerized Medical Imaging and Graphics
97141
"""
98142

99143
def __init__(
100144
self,
101145
to_onehot_y: bool = False,
102146
delta: float = 0.7,
103-
gamma: float = 2,
147+
gamma: float = 2.0,
104148
epsilon: float = 1e-7,
105149
reduction: LossReduction | str = LossReduction.MEAN,
150+
use_softmax: bool = False,
106151
):
107152
"""
108153
Args:
109-
to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
110-
delta : weight of the background. Defaults to 0.7.
111-
gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
112-
epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
154+
to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
155+
delta: weight for the foreground classes. Defaults to 0.7.
156+
gamma: focusing parameter for the background class (to down-weight easy background examples). Defaults to 2.0.
157+
epsilon: a small value to prevent calculation errors. Defaults to 1e-7.
158+
reduction: {``"none"``, ``"mean"``, ``"sum"``}
159+
use_softmax: whether to use softmax to transform logits. Defaults to False.
113160
"""
114161
super().__init__(reduction=LossReduction(reduction).value)
115162
self.to_onehot_y = to_onehot_y
116163
self.delta = delta
117164
self.gamma = gamma
118165
self.epsilon = epsilon
166+
self.use_softmax = use_softmax
119167

120168
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
169+
"""
170+
Args:
171+
y_pred: prediction logits or probabilities.
172+
y_true: ground truth labels.
173+
"""
174+
175+
if y_pred.shape[1] == 1 and not self.use_softmax:
176+
y_pred = torch.sigmoid(y_pred)
177+
y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
178+
is_already_prob = True
179+
if y_true.shape[1] == 1:
180+
y_true = one_hot(y_true, num_classes=2)
181+
else:
182+
is_already_prob = False
183+
121184
n_pred_ch = y_pred.shape[1]
122185

123186
if self.to_onehot_y:
124187
if n_pred_ch == 1:
125188
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
126189
else:
127-
y_true = one_hot(y_true, num_classes=n_pred_ch)
190+
if y_true.shape[1] != n_pred_ch:
191+
y_true = one_hot(y_true, num_classes=n_pred_ch)
128192

129193
if y_true.shape != y_pred.shape:
130194
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
131195

196+
if not is_already_prob:
197+
if self.use_softmax:
198+
y_pred = torch.softmax(y_pred, dim=1)
199+
else:
200+
y_pred = torch.sigmoid(y_pred)
201+
132202
y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
203+
133204
cross_entropy = -y_true * torch.log(y_pred)
134205

206+
# Background (Channel 0): Focal Loss
135207
back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0]
136208
back_ce = (1 - self.delta) * back_ce
137-
138-
fore_ce = cross_entropy[:, 1]
209+
# Foreground (Channels 1+): Standard Cross Entropy
210+
fore_ce = cross_entropy[:, 1:]
139211
fore_ce = self.delta * fore_ce
140212

141-
loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], dim=1), dim=1))
142-
return loss
213+
# Concatenate losses
214+
all_ce = torch.cat([back_ce.unsqueeze(1), fore_ce], dim=1)
143215

216+
# Sum over classes (dim=1) to get total loss per pixel
217+
total_loss = torch.sum(all_ce, dim=1)
144218

145-
class AsymmetricUnifiedFocalLoss(_Loss):
146-
"""
147-
AsymmetricUnifiedFocalLoss is a variant of Focal Loss.
219+
# Apply reduction
220+
if self.reduction == LossReduction.MEAN.value:
221+
return torch.mean(total_loss)
222+
if self.reduction == LossReduction.SUM.value:
223+
return torch.sum(total_loss)
224+
if self.reduction == LossReduction.NONE.value:
225+
return total_loss
226+
return torch.mean(total_loss)
148227

149-
Actually, it's only supported for binary image segmentation now
150228

151-
Reimplementation of the Asymmetric Unified Focal Tversky Loss described in:
229+
class AsymmetricUnifiedFocalLoss(_Loss):
230+
"""
231+
AsymmetricUnifiedFocalLoss is a wrapper that combines AsymmetricFocalLoss and AsymmetricFocalTverskyLoss.
152232
153-
- "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation",
154-
Michael Yeung, Computerized Medical Imaging and Graphics
233+
This unified loss allows for simultaneously optimizing distribution-based (CE) and region-based (Dice) metrics,
234+
while handling class imbalance through asymmetric weighting.
155235
"""
156236

157237
def __init__(
@@ -162,79 +242,57 @@ def __init__(
162242
gamma: float = 0.5,
163243
delta: float = 0.7,
164244
reduction: LossReduction | str = LossReduction.MEAN,
245+
use_softmax: bool = False,
165246
):
166247
"""
167248
Args:
168-
to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
169-
num_classes : number of classes, it only supports 2 now. Defaults to 2.
170-
delta : weight of the background. Defaults to 0.7.
171-
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
172-
epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
173-
weight : weight for each loss function, if it's none it's 0.5. Defaults to None.
174-
175-
Example:
176-
>>> import torch
177-
>>> from monai.losses import AsymmetricUnifiedFocalLoss
178-
>>> pred = torch.ones((1,1,32,32), dtype=torch.float32)
179-
>>> grnd = torch.ones((1,1,32,32), dtype=torch.int64)
180-
>>> fl = AsymmetricUnifiedFocalLoss(to_onehot_y=True)
181-
>>> fl(pred, grnd)
249+
to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
250+
num_classes: number of classes. Defaults to 2.
251+
weight: weight factor to balance between Focal Loss and Tversky Loss.
252+
Loss = weight * FocalLoss + (1-weight) * TverskyLoss. Defaults to 0.5.
253+
gamma: focal exponent. Defaults to 0.5.
254+
delta: background/foreground balancing weight. Defaults to 0.7.
255+
reduction: specifies the reduction to apply to the output. Defaults to "mean".
256+
use_softmax: whether to use softmax for probability conversion. Defaults to False.
182257
"""
183258
super().__init__(reduction=LossReduction(reduction).value)
184259
self.to_onehot_y = to_onehot_y
185260
self.num_classes = num_classes
186261
self.gamma = gamma
187262
self.delta = delta
188-
self.weight: float = weight
189-
self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta)
190-
self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta)
263+
self.weight = weight
264+
self.use_softmax = use_softmax
265+
266+
self.asy_focal_loss = AsymmetricFocalLoss(
267+
gamma=self.gamma,
268+
delta=self.delta,
269+
use_softmax=self.use_softmax,
270+
to_onehot_y=to_onehot_y,
271+
reduction=reduction,
272+
)
273+
self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(
274+
gamma=self.gamma,
275+
delta=self.delta,
276+
use_softmax=self.use_softmax,
277+
to_onehot_y=to_onehot_y,
278+
reduction=reduction,
279+
)
191280

192-
# TODO: Implement this function to support multiple classes segmentation
193281
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
194282
"""
195283
Args:
196-
y_pred : the shape should be BNH[WD], where N is the number of classes.
197-
It only supports binary segmentation.
198-
The input should be the original logits since it will be transformed by
199-
a sigmoid in the forward function.
200-
y_true : the shape should be BNH[WD], where N is the number of classes.
201-
It only supports binary segmentation.
202-
203-
Raises:
204-
ValueError: When input and target are different shape
205-
ValueError: When len(y_pred.shape) != 4 and len(y_pred.shape) != 5
206-
ValueError: When num_classes
207-
ValueError: When the number of classes entered does not match the expected number
284+
y_pred: Prediction logits. Shape: (B, C, H, W, [D]).
285+
Supports binary (C=1 or C=2) and multi-class (C>2) segmentation.
286+
y_true: Ground truth labels. Shape should match y_pred (or be indices if to_onehot_y is True).
208287
"""
209288
if y_pred.shape != y_true.shape:
210-
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
211-
212-
if len(y_pred.shape) != 4 and len(y_pred.shape) != 5:
213-
raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}")
214-
215-
if y_pred.shape[1] == 1:
216-
y_pred = one_hot(y_pred, num_classes=self.num_classes)
217-
y_true = one_hot(y_true, num_classes=self.num_classes)
218-
219-
if torch.max(y_true) != self.num_classes - 1:
220-
raise ValueError(f"Please make sure the number of classes is {self.num_classes-1}")
221-
222-
n_pred_ch = y_pred.shape[1]
223-
if self.to_onehot_y:
224-
if n_pred_ch == 1:
225-
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
226-
else:
227-
y_true = one_hot(y_true, num_classes=n_pred_ch)
289+
is_binary_logits = (y_pred.shape[1] == 1 and not self.use_softmax)
290+
if not self.to_onehot_y and not is_binary_logits:
291+
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
228292

229293
asy_focal_loss = self.asy_focal_loss(y_pred, y_true)
230294
asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true)
231295

232-
loss: torch.Tensor = self.weight * asy_focal_loss + (1 - self.weight) * asy_focal_tversky_loss
296+
loss = self.weight * asy_focal_loss + (1 - self.weight) * asy_focal_tversky_loss
233297

234-
if self.reduction == LossReduction.SUM.value:
235-
return torch.sum(loss) # sum over the batch and channel dims
236-
if self.reduction == LossReduction.NONE.value:
237-
return loss # returns [N, num_classes] losses
238-
if self.reduction == LossReduction.MEAN.value:
239-
return torch.mean(loss)
240-
raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
298+
return loss

0 commit comments

Comments
 (0)