class AsymmetricFocalTverskyLoss(_Loss):
    """
-    AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss, which attentions to the foreground class.
+    AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss that prioritizes the foreground classes.

-    Actually, it's only supported for binary image segmentation now.
+    It supports both binary and multi-class segmentation.

    Reimplementation of the Asymmetric Focal Tversky Loss described in:
-
    - "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation",
-    Michael Yeung, Computerized Medical Imaging and Graphics
+      Michael Yeung, Computerized Medical Imaging and Graphics
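    Example (an illustrative sketch; tensor shapes are arbitrary and the ``monai.losses`` import path is assumed):

        >>> import torch
        >>> from monai.losses import AsymmetricFocalTverskyLoss
        >>> pred = torch.randn(2, 3, 32, 32)             # multi-class logits, shape (B, C, H, W)
        >>> grnd = torch.randint(0, 3, (2, 1, 32, 32))   # class indices, shape (B, 1, H, W)
        >>> fl = AsymmetricFocalTverskyLoss(to_onehot_y=True, use_softmax=True)
        >>> fl(pred, grnd)  # scalar tensor with the default "mean" reduction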
    """

    def __init__(
@@ -39,119 +38,200 @@ def __init__(
        gamma: float = 0.75,
        epsilon: float = 1e-7,
        reduction: LossReduction | str = LossReduction.MEAN,
+        use_softmax: bool = False,
    ) -> None:
        """
        Args:
            to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
-            delta: weight of the background. Defaults to 0.7.
-            gamma: value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
-            epsilon: it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
+            delta: weight of false negatives relative to false positives in the Tversky index denominator. Defaults to 0.7.
+            gamma: focal exponent used to down-weight easy foreground examples. Defaults to 0.75.
+            epsilon: a small value to prevent division by zero. Defaults to 1e-7.
+            reduction: {``"none"``, ``"mean"``, ``"sum"``}
+                Specifies the reduction to apply to the output. Defaults to ``"mean"``.
+            use_softmax: whether to use softmax to transform original logits into probabilities.
+                If True, softmax is used (for multi-class). If False, sigmoid is used (for binary/multi-label).
+                Defaults to False.
        """
        super().__init__(reduction=LossReduction(reduction).value)
        self.to_onehot_y = to_onehot_y
        self.delta = delta
        self.gamma = gamma
        self.epsilon = epsilon
+        self.use_softmax = use_softmax

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            y_pred: prediction logits or probabilities. Shape should be (B, C, spatial_dims).
+            y_true: ground truth labels. Shape should match y_pred.
+        """
+
+        # Auto-handle single channel input (binary segmentation case)
+        if y_pred.shape[1] == 1 and not self.use_softmax:
+            y_pred = torch.sigmoid(y_pred)
+            y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
+            is_already_prob = True
+            if y_true.shape[1] == 1:
+                y_true = one_hot(y_true, num_classes=2)
+        else:
+            is_already_prob = False
+
        n_pred_ch = y_pred.shape[1]

        if self.to_onehot_y:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
            else:
-                y_true = one_hot(y_true, num_classes=n_pred_ch)
+                if y_true.shape[1] != n_pred_ch:
+                    y_true = one_hot(y_true, num_classes=n_pred_ch)

        if y_true.shape != y_pred.shape:
            raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")

-        # clip the prediction to avoid NaN
+        # Convert logits to probabilities if not already done
+        if not is_already_prob:
+            if self.use_softmax:
+                y_pred = torch.softmax(y_pred, dim=1)
+            else:
+                y_pred = torch.sigmoid(y_pred)
+
+        # Clip the prediction to avoid NaN
        y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
+
        axis = list(range(2, len(y_pred.shape)))

        # Calculate true positives (tp), false negatives (fn) and false positives (fp)
        tp = torch.sum(y_true * y_pred, dim=axis)
        fn = torch.sum(y_true * (1 - y_pred), dim=axis)
        fp = torch.sum((1 - y_true) * y_pred, dim=axis)
+
        dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon)
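        # dice_class is the Tversky index per class: (tp + eps) / (tp + delta * fn + (1 - delta) * fp + eps),
        # so delta > 0.5 penalises false negatives more heavily than false positives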

-        # Calculate losses separately for each class, enhancing both classes
+        # Calculate losses separately for each class
+        # Background: Standard Dice Loss
        back_dice = 1 - dice_class[:, 0]
-        fore_dice = (1 - dice_class[:, 1]) * torch.pow(1 - dice_class[:, 1], -self.gamma)

-        # Average class scores
-        loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1))
-        return loss
+        # Foreground: Focal Tversky Loss
+        fore_dice = torch.pow(1 - dice_class[:, 1:], 1 / self.gamma)
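        # with the default gamma = 0.75 the exponent 1 / gamma is greater than 1, which shrinks the loss
        # contribution of already well-segmented (easy) foreground classes relative to hard ones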
+
+        # Concatenate background and foreground losses
+        # back_dice needs unsqueeze to match dimensions: (B,) -> (B, 1)
+        all_losses = torch.cat([back_dice.unsqueeze(1), fore_dice], dim=1)
+
+        # Apply reduction
+        if self.reduction == LossReduction.MEAN.value:
+            return torch.mean(all_losses)
+        if self.reduction == LossReduction.SUM.value:
+            return torch.sum(all_losses)
+        if self.reduction == LossReduction.NONE.value:
+            return all_losses
+
+        return torch.mean(all_losses)


class AsymmetricFocalLoss(_Loss):
    """
-    AsymmetricFocalLoss is a variant of FocalTverskyLoss, which attentions to the foreground class.
+    AsymmetricFocalLoss is a variant of Focal Loss that treats background and foreground differently.

-    Actually, it's only supported for binary image segmentation now.
+    It supports both binary and multi-class segmentation.

    Reimplementation of the Asymmetric Focal Loss described in:
-
    - "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation",
-    Michael Yeung, Computerized Medical Imaging and Graphics
+      Michael Yeung, Computerized Medical Imaging and Graphics
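    Example (an illustrative sketch; tensor shapes are arbitrary and the ``monai.losses`` import path is assumed):

        >>> import torch
        >>> from monai.losses import AsymmetricFocalLoss
        >>> pred = torch.randn(2, 3, 32, 32)             # multi-class logits, shape (B, C, H, W)
        >>> grnd = torch.randint(0, 3, (2, 1, 32, 32))   # class indices, shape (B, 1, H, W)
        >>> fl = AsymmetricFocalLoss(to_onehot_y=True, use_softmax=True)
        >>> fl(pred, grnd)  # scalar tensor with the default "mean" reduction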
    """

    def __init__(
        self,
        to_onehot_y: bool = False,
        delta: float = 0.7,
-        gamma: float = 2,
+        gamma: float = 2.0,
        epsilon: float = 1e-7,
        reduction: LossReduction | str = LossReduction.MEAN,
+        use_softmax: bool = False,
    ):
        """
        Args:
-            to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
-            delta: weight of the background. Defaults to 0.7.
-            gamma: value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
-            epsilon: it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
+            to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
+            delta: weight for the foreground classes. Defaults to 0.7.
+            gamma: focusing parameter for the background class (to down-weight easy background examples). Defaults to 2.0.
+            epsilon: a small value used to clamp predictions and avoid log(0). Defaults to 1e-7.
+            reduction: {``"none"``, ``"mean"``, ``"sum"``}
+                Specifies the reduction to apply to the output. Defaults to ``"mean"``.
+            use_softmax: whether to use softmax to transform original logits into probabilities. Defaults to False.
        """
        super().__init__(reduction=LossReduction(reduction).value)
        self.to_onehot_y = to_onehot_y
        self.delta = delta
        self.gamma = gamma
        self.epsilon = epsilon
+        self.use_softmax = use_softmax

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            y_pred: prediction logits or probabilities.
+            y_true: ground truth labels.
+        """
+
+        if y_pred.shape[1] == 1 and not self.use_softmax:
+            y_pred = torch.sigmoid(y_pred)
+            y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
+            is_already_prob = True
+            if y_true.shape[1] == 1:
+                y_true = one_hot(y_true, num_classes=2)
+        else:
+            is_already_prob = False
+
        n_pred_ch = y_pred.shape[1]

        if self.to_onehot_y:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
            else:
-                y_true = one_hot(y_true, num_classes=n_pred_ch)
+                if y_true.shape[1] != n_pred_ch:
+                    y_true = one_hot(y_true, num_classes=n_pred_ch)

        if y_true.shape != y_pred.shape:
            raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")

+        if not is_already_prob:
+            if self.use_softmax:
+                y_pred = torch.softmax(y_pred, dim=1)
+            else:
+                y_pred = torch.sigmoid(y_pred)
+
        y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
+
        cross_entropy = -y_true * torch.log(y_pred)

+        # Background (Channel 0): Focal Loss
        back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0]
        back_ce = (1 - self.delta) * back_ce
-
-        fore_ce = cross_entropy[:, 1]
+        # Foreground (Channels 1+): Standard Cross Entropy
+        fore_ce = cross_entropy[:, 1:]
        fore_ce = self.delta * fore_ce

-        loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], dim=1), dim=1))
-        return loss
+        # Concatenate losses
+        all_ce = torch.cat([back_ce.unsqueeze(1), fore_ce], dim=1)

+        # Sum over classes (dim=1) to get total loss per pixel
+        total_loss = torch.sum(all_ce, dim=1)

-class AsymmetricUnifiedFocalLoss(_Loss):
-    """
-    AsymmetricUnifiedFocalLoss is a variant of Focal Loss.
+        # Apply reduction
+        if self.reduction == LossReduction.MEAN.value:
+            return torch.mean(total_loss)
+        if self.reduction == LossReduction.SUM.value:
+            return torch.sum(total_loss)
+        if self.reduction == LossReduction.NONE.value:
+            return total_loss
+        return torch.mean(total_loss)

-    Actually, it's only supported for binary image segmentation now

-    Reimplementation of the Asymmetric Unified Focal Tversky Loss described in:
+class AsymmetricUnifiedFocalLoss(_Loss):
+    """
+    AsymmetricUnifiedFocalLoss is a wrapper that combines AsymmetricFocalLoss and AsymmetricFocalTverskyLoss.

-    - "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation",
-    Michael Yeung, Computerized Medical Imaging and Graphics
+    This unified loss allows for simultaneously optimizing distribution-based (CE) and region-based (Dice) metrics,
+    while handling class imbalance through asymmetric weighting.
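    Example (an illustrative sketch; shapes and values are arbitrary):

        >>> import torch
        >>> from monai.losses import AsymmetricUnifiedFocalLoss
        >>> pred = torch.randn(1, 1, 32, 32)                # single-channel binary logits
        >>> grnd = torch.randint(0, 2, (1, 1, 32, 32))      # binary labels
        >>> fl = AsymmetricUnifiedFocalLoss(to_onehot_y=True)
        >>> fl(pred, grnd)  # weight * AsymmetricFocalLoss + (1 - weight) * AsymmetricFocalTverskyLoss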
    """

    def __init__(
@@ -162,79 +242,57 @@ def __init__(
        gamma: float = 0.5,
        delta: float = 0.7,
        reduction: LossReduction | str = LossReduction.MEAN,
+        use_softmax: bool = False,
    ):
        """
        Args:
-            to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
-            num_classes: number of classes, it only supports 2 now. Defaults to 2.
-            delta: weight of the background. Defaults to 0.7.
-            gamma: value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
-            epsilon: it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
-            weight: weight for each loss function, if it's none it's 0.5. Defaults to None.
-
-            Example:
-                >>> import torch
-                >>> from monai.losses import AsymmetricUnifiedFocalLoss
-                >>> pred = torch.ones((1,1,32,32), dtype=torch.float32)
-                >>> grnd = torch.ones((1,1,32,32), dtype=torch.int64)
-                >>> fl = AsymmetricUnifiedFocalLoss(to_onehot_y=True)
-                >>> fl(pred, grnd)
+            to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
+            num_classes: number of classes. Defaults to 2.
+            weight: weight factor to balance between Focal Loss and Tversky Loss.
+                Loss = weight * FocalLoss + (1 - weight) * TverskyLoss. Defaults to 0.5.
+            gamma: focal exponent. Defaults to 0.5.
+            delta: background/foreground balancing weight. Defaults to 0.7.
+            reduction: specifies the reduction to apply to the output. Defaults to ``"mean"``.
+            use_softmax: whether to use softmax for probability conversion. Defaults to False.
        """
        super().__init__(reduction=LossReduction(reduction).value)
        self.to_onehot_y = to_onehot_y
        self.num_classes = num_classes
        self.gamma = gamma
        self.delta = delta
-        self.weight: float = weight
-        self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta)
-        self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta)
+        self.weight = weight
+        self.use_softmax = use_softmax
+
+        self.asy_focal_loss = AsymmetricFocalLoss(
+            gamma=self.gamma,
+            delta=self.delta,
+            use_softmax=self.use_softmax,
+            to_onehot_y=to_onehot_y,
+            reduction=reduction,
+        )
+        self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(
+            gamma=self.gamma,
+            delta=self.delta,
+            use_softmax=self.use_softmax,
+            to_onehot_y=to_onehot_y,
+            reduction=reduction,
+        )

-    # TODO: Implement this function to support multiple classes segmentation
    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        """
        Args:
-            y_pred: the shape should be BNH[WD], where N is the number of classes.
-                It only supports binary segmentation.
-                The input should be the original logits since it will be transformed by
-                a sigmoid in the forward function.
-            y_true: the shape should be BNH[WD], where N is the number of classes.
-                It only supports binary segmentation.
-
-        Raises:
-            ValueError: When input and target are different shape
-            ValueError: When len(y_pred.shape) != 4 and len(y_pred.shape) != 5
-            ValueError: When num_classes
-            ValueError: When the number of classes entered does not match the expected number
+            y_pred: prediction logits. Shape should be (B, C, H, W[, D]).
+                Supports binary (C=1 or C=2) and multi-class (C>2) segmentation.
+            y_true: ground truth labels. Shape should match y_pred (or contain class indices if to_onehot_y is True).
        """
        if y_pred.shape != y_true.shape:
-            raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
-
-        if len(y_pred.shape) != 4 and len(y_pred.shape) != 5:
-            raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}")
-
-        if y_pred.shape[1] == 1:
-            y_pred = one_hot(y_pred, num_classes=self.num_classes)
-            y_true = one_hot(y_true, num_classes=self.num_classes)
-
-        if torch.max(y_true) != self.num_classes - 1:
-            raise ValueError(f"Please make sure the number of classes is {self.num_classes - 1}")
-
-        n_pred_ch = y_pred.shape[1]
-        if self.to_onehot_y:
-            if n_pred_ch == 1:
-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
-            else:
-                y_true = one_hot(y_true, num_classes=n_pred_ch)
+            is_binary_logits = y_pred.shape[1] == 1 and not self.use_softmax
+            if not self.to_onehot_y and not is_binary_logits:
+                raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")

        asy_focal_loss = self.asy_focal_loss(y_pred, y_true)
        asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true)

-        loss: torch.Tensor = self.weight * asy_focal_loss + (1 - self.weight) * asy_focal_tversky_loss
+        loss = self.weight * asy_focal_loss + (1 - self.weight) * asy_focal_tversky_loss

-        if self.reduction == LossReduction.SUM.value:
-            return torch.sum(loss)  # sum over the batch and channel dims
-        if self.reduction == LossReduction.NONE.value:
-            return loss  # returns [N, num_classes] losses
-        if self.reduction == LossReduction.MEAN.value:
-            return torch.mean(loss)
-        raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
+        return loss