-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Fix #8239: Enhance SoftclDiceLoss and SoftDiceclDiceLoss with DiceLoss-compatible API #8703
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Fix #8239: Enhance SoftclDiceLoss and SoftDiceclDiceLoss with DiceLoss-compatible API #8703
Conversation
… with additional parameters - Add include_background, to_onehot_y, sigmoid, softmax, other_act, and reduction parameters - Fix argument order in forward() to match other losses (y_pred, y_true) - Add proper input validation and comprehensive docstrings - These changes make the losses consistent with DiceLoss API and fix zero loss issues Signed-off-by: Soumya Snigdha Kundu <[email protected]>
📝 WalkthroughWalkthroughReplaces the previous soft_dice path with a new SoftclDiceLoss class that supports configurable activation (sigmoid/softmax/custom), optional to_onehot target conversion, include/exclude background, iterative skeleton-based clDice computation, numeric smoothing, and LossReduction-based reductions. Adds SoftDiceclDiceLoss to combine DiceLoss and SoftclDiceLoss via a weighted alpha. Implements forward methods with shape/activation validation and warnings for incompatible settings. Removes the old soft_dice function and updates imports. Tests rewritten and parameterized to cover activations, reductions, shape/channel mismatches, CUDA, and edge cases. Estimated code review effort🎯 3 (Moderate) | ⏱️ ~25 minutes 🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing touches
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@monai/losses/cldice.py`:
- Around line 217-223: The computation of tprec/tsens and cl_dice can produce
NaN when self.smooth == 0 and the denominator sums (torch.sum(skel_pred, ...) or
torch.sum(skel_true, ...)) are zero; update the logic in the CLDice computation
(references: skel_pred, skel_true, input, tprec, tsens, cl_dice, self.smooth,
reduce_axis) to guard denominators by a small positive epsilon (or enforce
self.smooth > 0) — e.g., compute denom_pred = torch.sum(skel_pred,
dim=reduce_axis).clamp_min(eps) (and similarly for denom_true) or use
torch.where to replace zero denominators with eps before dividing, so
tprec/tsens and cl_dice never become NaN; optionally add a docstring note that
smooth must be positive.
🧹 Nitpick comments (4)
monai/losses/cldice.py (2)
187-187: Addstacklevel=2to warnings.Without stacklevel, warnings point to this line instead of the caller's location.
Proposed fix
- warnings.warn("single channel prediction, `softmax=True` ignored.") + warnings.warn("single channel prediction, `softmax=True` ignored.", stacklevel=2)Same applies to lines 196 and 202.
340-340: Addstacklevel=2here as well.Proposed fix
- warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") + warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)tests/losses/test_cldice_loss.py (2)
36-71: Consider adding a test case forinclude_background=False.Current cases cover sigmoid, softmax, and to_onehot_y, but not background exclusion.
82-87: Consider using@unittest.skipUnlessdecorator.More idiomatic than early return.
Proposed fix
+ `@unittest.skipUnless`(torch.cuda.is_available(), "CUDA not available") def test_cuda(self): - if not torch.cuda.is_available(): - return loss = SoftclDiceLoss() result = loss(ONES_2D["input"].cuda(), ONES_2D["target"].cuda()) np.testing.assert_allclose(result.detach().cpu().numpy(), 0.0, atol=1e-4)
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (2)
monai/losses/cldice.pytests/losses/test_cldice_loss.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
tests/losses/test_cldice_loss.pymonai/losses/cldice.py
🧬 Code graph analysis (1)
monai/losses/cldice.py (3)
monai/losses/dice.py (1)
DiceLoss(31-229)monai/networks/utils.py (1)
one_hot(170-220)monai/utils/enums.py (1)
LossReduction(253-264)
🪛 Ruff (0.14.11)
monai/losses/cldice.py
158-158: Avoid specifying long messages outside the exception class
(TRY003)
160-160: Avoid specifying long messages outside the exception class
(TRY003)
187-187: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
196-196: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
202-202: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
208-208: Avoid specifying long messages outside the exception class
(TRY003)
233-233: Avoid specifying long messages outside the exception class
(TRY003)
326-329: Avoid specifying long messages outside the exception class
(TRY003)
332-335: Avoid specifying long messages outside the exception class
(TRY003)
340-340: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
🔇 Additional comments (5)
monai/losses/cldice.py (4)
14-23: LGTM!Imports are appropriate for the enhanced functionality.
119-167: LGTM!Constructor properly validates activation options and aligns with DiceLoss API.
291-312: LGTM!Correctly centralizes one-hot conversion to avoid double application in composed losses.
314-348: LGTM!Input validation and combined loss computation are correct.
tests/losses/test_cldice_loss.py (1)
111-134: LGTM!Good coverage of combined loss functionality and error cases.
✏️ Tip: You can disable this entire section by setting review_details to false in your review settings.
Signed-off-by: Soumya Snigdha Kundu <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@monai/losses/cldice.py`:
- Around line 161-163: The code currently validates smooth but not the
skeletonization iteration count iter_; add validation in the clDice constructor
(or the function where self.iter is set) to ensure iter_ is a positive integer:
check that iter_ is an int (or castable to int) and greater than 0, and raise a
ValueError with a clear message if not; then assign self.iter = int(iter_) so
downstream skeletonize/skel operations use a safe positive integer (refer to the
self.iter assignment and the iter_ parameter in the clDice class/constructor).
🧹 Nitpick comments (6)
monai/losses/cldice.py (6)
188-189: Addstacklevel=2to warnings.Without it, warnings point to this file instead of the caller's location.
Proposed fix
if n_pred_ch == 1: - warnings.warn("single channel prediction, `softmax=True` ignored.") + warnings.warn("single channel prediction, `softmax=True` ignored.", stacklevel=2)Apply similarly to lines 198 and 204.
Also applies to: 197-198, 203-204
232-233: Inconsistentreduction="none"behavior with DiceLoss.DiceLoss applies
.view(-1)forreduction="none". Here it's left as-is. This may cause issues when combining losses or stacking results.Proposed fix for consistency
elif self.reduction == LossReduction.NONE.value: - pass # keep per-batch values + cl_dice = cl_dice.view(-1)
293-315: Consider validatingalpharange.
alphaoutside[0, 1]would produce unusual weighting. While possibly intentional, a validation or warning could prevent mistakes.Proposed fix
if smooth <= 0: raise ValueError(f"smooth must be a positive value but got {smooth}.") + if not 0.0 <= alpha <= 1.0: + warnings.warn(f"alpha={alpha} is outside [0, 1], loss weighting may be unusual.", stacklevel=2)
344-344: Addstacklevel=2here as well.- warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") + warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
171-181: MissingReturnssection in docstring.Per coding guidelines, return values should be documented.
Proposed addition
Raises: AssertionError: When input and target (after one hot transform if set) have different shapes. + + Returns: + torch.Tensor: The computed clDice loss, reduced according to `self.reduction`. """
318-328: MissingReturnssection in docstring.Same as
SoftclDiceLoss.forward.Proposed addition
Raises: ValueError: When number of dimensions for input and target are different. ValueError: When number of channels for target is neither 1 nor the same as input. + + Returns: + torch.Tensor: The weighted combination of Dice and clDice losses. """
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (1)
monai/losses/cldice.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/losses/cldice.py
🧬 Code graph analysis (1)
monai/losses/cldice.py (3)
monai/losses/dice.py (1)
DiceLoss(31-229)monai/networks/utils.py (1)
one_hot(170-220)monai/utils/enums.py (1)
LossReduction(253-264)
🪛 Ruff (0.14.11)
monai/losses/cldice.py
158-158: Avoid specifying long messages outside the exception class
(TRY003)
160-160: Avoid specifying long messages outside the exception class
(TRY003)
162-162: Avoid specifying long messages outside the exception class
(TRY003)
189-189: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
198-198: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
204-204: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
210-210: Avoid specifying long messages outside the exception class
(TRY003)
235-235: Avoid specifying long messages outside the exception class
(TRY003)
294-294: Avoid specifying long messages outside the exception class
(TRY003)
330-333: Avoid specifying long messages outside the exception class
(TRY003)
336-339: Avoid specifying long messages outside the exception class
(TRY003)
344-344: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
🔇 Additional comments (3)
monai/losses/cldice.py (3)
14-23: LGTM!Imports are appropriate for the new functionality.
212-225: LGTM!clDice computation is correct. The
smoothvalidation in__init__addresses the NaN concern from previous reviews, and the1e-8epsilon provides additional safety for the harmonic mean.
318-352: LGTM!Forward method properly validates dimensions/channels, centralizes one-hot conversion, and combines losses correctly.
✏️ Tip: You can disable this entire section by setting review_details to false in your review settings.
| if smooth <= 0: | ||
| raise ValueError(f"smooth must be a positive value but got {smooth}.") | ||
| self.iter = iter_ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing validation for iter_ parameter.
smooth is validated but iter_ is not. A non-positive value would produce incorrect skeletonization.
Proposed fix
if smooth <= 0:
raise ValueError(f"smooth must be a positive value but got {smooth}.")
+ if iter_ < 0:
+ raise ValueError(f"iter_ must be a non-negative integer but got {iter_}.")
self.iter = iter_🧰 Tools
🪛 Ruff (0.14.11)
162-162: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In `@monai/losses/cldice.py` around lines 161 - 163, The code currently validates
smooth but not the skeletonization iteration count iter_; add validation in the
clDice constructor (or the function where self.iter is set) to ensure iter_ is a
positive integer: check that iter_ is an int (or castable to int) and greater
than 0, and raise a ValueError with a clear message if not; then assign
self.iter = int(iter_) so downstream skeletonize/skel operations use a safe
positive integer (refer to the self.iter assignment and the iter_ parameter in
the clDice class/constructor).
Summary
include_background,to_onehot_y,sigmoid,softmax,other_act, andreductionparameters toSoftclDiceLossandSoftDiceclDiceLossforward()to match MONAI convention (input,targetinstead ofy_true,y_pred)DiceLossFixes #8239
Changes
These changes make the clDice losses consistent with the
DiceLossAPI, addressing the issues reported in #8239 where users encountered zero loss due to missing preprocessing options.Checklist
./runtests.sh --codeformat)y_true, y_predtoinput, target)Test plan