Fix batch size broadcasting bug in GeneralizedWassersteinDiceLoss #8744
hongjie-qiu wants to merge 2 commits into Project-MONAI:dev
Conversation
📝 Walkthrough: The PR corrects batch handling in GeneralizedWassersteinDiceLoss.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: ✅ 5 passed
Actionable comments posted: 1
🧹 Nitpick comments (3)
monai/losses/dice.py (2)
509-514: Document the `reduction="none"` output shape in the docstring.

The `forward` docstring has no `Returns:` section. Since this PR changes the `reduction="none"` output shape from a broken reshape attempt to a well-defined `(B,)`, callers need to know what to expect.

📝 Proposed docstring update

```diff
 def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
     """
     Args:
         input: the shape should be BNH[WD].
         target: the shape should be BNH[WD].
+    Returns:
+        Scalar when ``reduction`` is ``"mean"`` or ``"sum"``.
+        Tensor of shape ``(B,)`` when ``reduction`` is ``"none"``, one loss value
+        per sample (GWDL aggregates over classes and spatial dims internally).
     """
```

As per coding guidelines, Google-style docstrings should describe each return value in a `Returns:` section.

Also applies to: 550-553
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@monai/losses/dice.py` around lines 509 - 514, Update the forward docstring for the Dice loss (method forward) to add a Google-style "Returns:" section that clearly states the tensor shape and meaning for each reduction mode; specifically document that when reduction="none" the output is a 1-D tensor of shape (B,) containing per-batch loss values (and describe shapes for "mean"/"sum" if applicable), and apply the same docstring clarification to the other forward variant referenced around lines 550-553 so callers know to expect a (B,) output instead of the previous reshape behavior.
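For context, here is a quick usage sketch of the output shapes discussed in this comment. It is an illustrative example written for this review (not taken from the MONAI docs), and it assumes the fixed behavior described in this PR:

```python
import numpy as np
import torch
from monai.losses import GeneralizedWassersteinDiceLoss

dist = np.array([[0.0, 1.0], [1.0, 0.0]])
pred = torch.randn(4, 2, 8, 8)            # (B, C, H, W) logits
target = torch.randint(0, 2, (4, 8, 8))   # (B, H, W) integer labels

# With the fix, reduction="none" yields one loss value per sample.
loss_none = GeneralizedWassersteinDiceLoss(dist_matrix=dist, reduction="none")(pred, target)
print(loss_none.shape)  # expected: torch.Size([4]), i.e. (B,)

# "mean" and "sum" reduce that (B,) vector to a scalar.
loss_mean = GeneralizedWassersteinDiceLoss(dist_matrix=dist, reduction="mean")(pred, target)
print(loss_mean.shape)  # expected: torch.Size([])
```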
596-630: Both helpers share identical alpha-mapping code — extract a private helper.

The five-line alpha-extension/gather/squeeze block is duplicated verbatim in `_compute_generalized_true_positive` and `_compute_denominator`. A private `_map_alpha_to_voxels` would remove the duplication and make future fixes a one-place change (as this PR illustrates — the squeeze had to be added in both places).

Additionally, both methods are missing `Returns:` sections in their docstrings. As per coding guidelines, Google-style docstrings must describe return values.

♻️ Proposed refactor

```diff
+    def _map_alpha_to_voxels(self, alpha: torch.Tensor, flat_target: torch.Tensor) -> torch.Tensor:
+        """Map per-class alpha weights to a per-voxel tensor via flat_target.
+
+        Args:
+            alpha: per-class weights of shape (B, C).
+            flat_target: flattened target labels of shape (B, S).
+
+        Returns:
+            Per-voxel alpha values of shape (B, S).
+        """
+        alpha_extended = torch.unsqueeze(alpha, dim=2)
+        alpha_extended = alpha_extended.expand((flat_target.size(0), self.num_classes, flat_target.size(1)))
+        flat_target_extended = torch.unsqueeze(flat_target, dim=1)
+        alpha_extended = torch.gather(alpha_extended, index=flat_target_extended, dim=1)
+        return torch.squeeze(alpha_extended, dim=1)
+
     def _compute_generalized_true_positive(
         self, alpha: torch.Tensor, flat_target: torch.Tensor, wasserstein_distance_map: torch.Tensor
     ) -> torch.Tensor:
         """
         Args:
             alpha: generalised number of true positives of target class.
             flat_target: the target tensor.
             wasserstein_distance_map: the map obtained from the above function.
+
+        Returns:
+            Per-sample generalised true positives of shape (B,).
         """
-        # Extend alpha to a map and select value at each voxel according to flat_target
-        alpha_extended = torch.unsqueeze(alpha, dim=2)
-        alpha_extended = alpha_extended.expand((flat_target.size(0), self.num_classes, flat_target.size(1)))
-        flat_target_extended = torch.unsqueeze(flat_target, dim=1)
-        alpha_extended = torch.gather(alpha_extended, index=flat_target_extended, dim=1)
-        alpha_extended = torch.squeeze(alpha_extended, dim=1)
-        return torch.sum(alpha_extended * (1.0 - wasserstein_distance_map), dim=1)
+        alpha_per_voxel = self._map_alpha_to_voxels(alpha, flat_target)
+        return torch.sum(alpha_per_voxel * (1.0 - wasserstein_distance_map), dim=1)

     def _compute_denominator(
         self, alpha: torch.Tensor, flat_target: torch.Tensor, wasserstein_distance_map: torch.Tensor
     ) -> torch.Tensor:
         """
         Args:
             alpha: generalised number of true positives of target class.
             flat_target: the target tensor.
             wasserstein_distance_map: the map obtained from the above function.
+
+        Returns:
+            Per-sample denominator of shape (B,).
         """
-        # Extend alpha to a map and select value at each voxel according to flat_target
-        alpha_extended = torch.unsqueeze(alpha, dim=2)
-        alpha_extended = alpha_extended.expand((flat_target.size(0), self.num_classes, flat_target.size(1)))
-        flat_target_extended = torch.unsqueeze(flat_target, dim=1)
-        alpha_extended = torch.gather(alpha_extended, index=flat_target_extended, dim=1)
-        alpha_extended = torch.squeeze(alpha_extended, dim=1)
-        return torch.sum(alpha_extended * (2.0 - wasserstein_distance_map), dim=1)
+        alpha_per_voxel = self._map_alpha_to_voxels(alpha, flat_target)
+        return torch.sum(alpha_per_voxel * (2.0 - wasserstein_distance_map), dim=1)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@monai/losses/dice.py` around lines 596 - 630, Both methods _compute_generalized_true_positive and _compute_denominator duplicate the alpha-extension/gather/squeeze block; extract that logic into a private helper (e.g., _map_alpha_to_voxels(self, alpha: torch.Tensor, flat_target: torch.Tensor) -> torch.Tensor) that returns the per-voxel alpha_extended tensor, update both methods to call this helper and use its result in their sums, and remove the duplicated code; also add a Returns: section to the docstrings of _compute_generalized_true_positive and _compute_denominator describing the returned tensor shape and meaning.

tests/losses/test_generalized_wasserstein_dice_loss.py (1)
293-295: `float()` on a `reduction="none"` output is fragile for future readers.

`loss_fn(pred_a, target_a)` returns shape `(1,)` here; `float()` only works because batch size is 1. Prefer `.item()` or index explicitly to make intent clear.

🔧 Suggested clarification

```diff
-        loss_a = float(loss_fn(pred_a, target_a))
-        loss_b = float(loss_fn(pred_b, target_b))
+        loss_a = loss_fn(pred_a, target_a)[0].item()
+        loss_b = loss_fn(pred_b, target_b)[0].item()
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/losses/test_generalized_wasserstein_dice_loss.py` around lines 293 - 295, The tests call float(loss_fn(pred_a, target_a)) and float(loss_fn(pred_b, target_b)) where loss_fn is configured with reduction="none" and returns a 1-element tensor; using float() is fragile and hides the intent. Replace float(...) with .item() (e.g., loss_fn(pred_a, target_a).item()) or explicitly index [0] to extract the scalar so the code clearly indicates you're converting a single-element tensor to a Python float; update both occurrences referencing loss_fn, pred_a, target_a, pred_b, and target_b.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/losses/test_generalized_wasserstein_dice_loss.py`:
- Around line 254-270: The mean-reduction subtest uses trivially-zero losses
(pred_single from pred_very_good), so it won't catch a broken mean reduction;
update the test to use a non-trivial prediction for at least one sample: when
constructing pred_single/pred_batch for the GeneralizedWassersteinDiceLoss
checks, replace one of the batch entries (or pred_single) with a clearly poor
prediction (e.g., pred_very_poor) so loss_single and loss_batch produce a
non-zero, different-per-sample value; keep using loss_fn =
GeneralizedWassersteinDiceLoss(..., weighting_mode=w_mode, reduction="mean") and
then assert loss_batch equals loss_single to verify mean reduction behavior.
---
Nitpick comments:
In `@monai/losses/dice.py`:
- Around line 509-514: Update the forward docstring for the Dice loss (method
forward) to add a Google-style "Returns:" section that clearly states the tensor
shape and meaning for each reduction mode; specifically document that when
reduction="none" the output is a 1-D tensor of shape (B,) containing per-batch
loss values (and describe shapes for "mean"/"sum" if applicable), and apply the
same docstring clarification to the other forward variant referenced around
lines 550-553 so callers know to expect a (B,) output instead of the previous
reshape behavior.
- Around line 596-630: Both methods _compute_generalized_true_positive and
_compute_denominator duplicate the alpha-extension/gather/squeeze block; extract
that logic into a private helper (e.g., _map_alpha_to_voxels(self, alpha:
torch.Tensor, flat_target: torch.Tensor) -> torch.Tensor) that returns the
per-voxel alpha_extended tensor, update both methods to call this helper and use
its result in their sums, and remove the duplicated code; also add a Returns:
section to the docstrings of _compute_generalized_true_positive and
_compute_denominator describing the returned tensor shape and meaning.
In `@tests/losses/test_generalized_wasserstein_dice_loss.py`:
- Around line 293-295: The tests call float(loss_fn(pred_a, target_a)) and
float(loss_fn(pred_b, target_b)) where loss_fn is configured with
reduction="none" and returns a 1-element tensor; using float() is fragile and
hides the intent. Replace float(...) with .item() (e.g., loss_fn(pred_a,
target_a).item()) or explicitly index [0] to extract the scalar so the code
clearly indicates you're converting a single-element tensor to a Python float;
update both occurrences referencing loss_fn, pred_a, target_a, pred_b, and
target_b.
…oject-MONAI#4650)

After `torch.gather`, `alpha_extended` retains shape (B, 1, S) while `wasserstein_distance_map` has shape (B, S). When batch size > 1 the element-wise multiply broadcasts to (B, B, S), mixing values across samples. Fixed by squeezing dim=1 after gather in both `_compute_generalized_true_positive` and `_compute_denominator`, and reducing with `dim=1` instead of `dim=[1, 2]`.

Also fixed the `reduction="none"` code path, which incorrectly tried to reshape the per-sample loss tensor (B,) to (B, C, 1, ...) — GWDL aggregates over classes internally, so the class dimension doesn't apply.

Added regression tests that verify batch consistency:
- identical samples in a batch produce the same loss as a single sample
- batched per-sample losses match individually computed losses

Signed-off-by: hongjie-qiu <77599736+hongjie-qiu@users.noreply.github.com>
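To make the broadcasting failure concrete, here is a minimal, self-contained sketch of the shape interaction described in the commit message. It is illustrative only; the tensor names mirror those in `dice.py`, but the values are made up:

```python
import torch

B, C, S = 2, 2, 16  # batch size, number of classes, flattened spatial size

alpha = torch.rand(B, C)                      # per-class weights
flat_target = torch.randint(0, C, (B, S))     # flattened labels
wasserstein_distance_map = torch.rand(B, S)   # per-voxel distances

# Buggy path: after gather, alpha_extended keeps a singleton class dimension.
alpha_extended = torch.gather(
    alpha.unsqueeze(2).expand(B, C, S), index=flat_target.unsqueeze(1), dim=1
)
print(alpha_extended.shape)                                # (B, 1, S)
buggy = alpha_extended * (1.0 - wasserstein_distance_map)  # broadcasts silently
print(buggy.shape)                                         # (B, B, S): samples mixed across the batch

# Fixed path: squeeze dim=1 so both operands are (B, S), then reduce over dim=1.
alpha_fixed = alpha_extended.squeeze(1)
fixed = alpha_fixed * (1.0 - wasserstein_distance_map)
print(fixed.shape)                                         # (B, S)
true_pos = torch.sum(fixed, dim=1)                         # (B,), one value per sample
```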
063df92 to 4887d9d
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/losses/test_generalized_wasserstein_dice_loss.py`:
- Around line 268-297: The test test_batch_size_different_samples is currently
trivial because both pred_a and pred_b are perfect one-hot predictions; change
pred_b to produce a poor prediction (e.g., build pred_b from 1 - target_b before
one_hot) so that loss_b is ≈1.0 and the batch variant checks are meaningful;
keep pred_a as the perfect 1000*F.one_hot(target_a, ...) and ensure pred_batch =
torch.cat([pred_a, pred_b], dim=0) remains consistent so loss_a, loss_b, and
loss_batch indices compare correctly inside the weighting_mode loop that calls
GeneralizedWassersteinDiceLoss and asserts loss_batch[0]==loss_a and
loss_batch[1]==loss_b.
---
Duplicate comments:
In `@tests/losses/test_generalized_wasserstein_dice_loss.py`:
- Around line 252-266: The mean-reduction sub-test is weak because
pred_single/pred_batch are perfect (loss ≈ 0), so it doesn't catch
mean-reduction bugs; update the test to use non-perfect predictions so
per-sample losses are non-zero and comparable: for
GeneralizedWassersteinDiceLoss with weighting_mode set from ["default","GDL"]
and reduction="mean", construct pred_single/pred_batch (or modify pred_single to
contain a small deliberate error and replicate into pred_batch as the batch of
samples) and/or use multiple distinct samples in pred_batch, compute loss_single
= float(loss_fn(pred_single, target_single)) and loss_batch =
float(loss_fn(pred_batch, target_batch)), then assert loss_batch ≈ loss_single
to verify the mean aggregation is correct; reference
GeneralizedWassersteinDiceLoss, pred_single, pred_batch, target_single,
target_batch, weighting_mode, and reduction="mean" when making the change.
Hi @hongjie-qiu, thanks for looking into this. Please look at the comments from Coderabbit and we can then review again.
Address review feedback: use poor predictions in the mean-reduction and different-samples tests so the expected loss values are non-trivial (~1.0 instead of ~0.0), ensuring the assertions are meaningful.

Signed-off-by: Jeffrey Qiu <hongjie.qiu@gmail.com>
Signed-off-by: hongjie-qiu <77599736+hongjie-qiu@users.noreply.github.com>
🧹 Nitpick comments (1)
tests/losses/test_generalized_wasserstein_dice_loss.py (1)
235-250: `reduction="none"` sub-test with identical perfect predictions can't catch the batch broadcasting bug.

For B identical samples, the buggy code scales both `true_pos` and `denom` by B (the extraneous broadcasting factor), so the ratio `(2·B·tp) / (2·B·tp + all_error)` is unchanged from the single-sample result regardless of prediction quality. With perfect predictions the loss is ≈ 0 for both buggy and correct code; with poor predictions it's ≈ 1 for both. The assertion `0 ≈ 0` (or `1 ≈ 1`) trivially passes on unpatched code.

`test_batch_size_different_samples` is the proper regression anchor — it catches the bug because cross-batch sample contamination changes the numerator for sample 1 meaningfully (≈0.67 buggy vs ≈1.0 correct). This sub-test in `test_batch_size_greater_than_one` ends up only verifying that the batch infrastructure doesn't add noise on top of perfect predictions, which is a much weaker guarantee.

Consider replacing `pred_single` (perfect) with a poor prediction, and pairing it with different targets per batch element (e.g., offset the second sample), so the "none" path has a non-trivially-different expected value and actually exercises cross-sample isolation:

♻️ Suggested improvement
```diff
-        target_single = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]])
-        target_single = target_single.unsqueeze(0)  # shape (1, H, W)
-        pred_single = 1000 * F.one_hot(target_single, num_classes=2).permute(0, 3, 1, 2).float()
-
-        # Create a batch of size 2 by repeating the same sample
-        target_batch = target_single.repeat(2, 1, 1)  # shape (2, H, W)
-        pred_batch = pred_single.repeat(2, 1, 1, 1)  # shape (2, C, H, W)
+        target_single = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]])
+        target_single = target_single.unsqueeze(0)  # shape (1, H, W)
+        pred_poor = 1000 * F.one_hot(1 - target_single, num_classes=2).permute(0, 3, 1, 2).float()
+
+        # Create a batch with one good and one poor prediction so the "none"
+        # path has a non-trivial per-sample expected value.
+        target_batch = target_single.repeat(2, 1, 1)  # shape (2, H, W)
+        pred_batch = torch.cat([pred_poor, pred_poor], dim=0)  # shape (2, C, H, W)

         for w_mode in ["default", "GDL"]:
             loss_fn = GeneralizedWassersteinDiceLoss(
                 dist_matrix=np.array([[0.0, 1.0], [1.0, 0.0]]), weighting_mode=w_mode, reduction="none"
             )
-            loss_single = loss_fn(pred_single, target_single)
-            loss_batch = loss_fn(pred_batch, target_batch)
+            loss_single = loss_fn(pred_poor, target_single)
+            loss_batch = loss_fn(pred_batch, target_batch)

             for i in range(2):
+                self.assertGreater(float(loss_single[0]), 0.5,
+                    msg=f"Expected non-trivial loss for weighting_mode={w_mode}")
                 self.assertAlmostEqual(
                     float(loss_batch[i]),
                     float(loss_single[0]),
                     places=5,
                     msg=f"Batch loss[{i}] != single loss for weighting_mode={w_mode}",
                 )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/losses/test_generalized_wasserstein_dice_loss.py` around lines 235 - 250, The current reduction="none" sub-test uses identical perfect predictions so it cannot detect the batching/broadcasting bug; update the test in tests/losses/test_generalized_wasserstein_dice_loss.py to use a non-perfect prediction for pred_single (e.g., a poor/confusing prediction) and construct pred_batch/target_batch so the two batch elements are different (for example offset or swap labels in the second sample) rather than duplicating the perfect sample; then call GeneralizedWassersteinDiceLoss(..., reduction="none") and assert each loss_batch[i] equals the corresponding loss_single[i] (compare loss_batch[i] to loss_single[i] or compute single-sample losses for each batch element) to ensure cross-sample isolation and surface the broadcasting bug (referenced symbols: pred_single, pred_batch, target_single, target_batch, GeneralizedWassersteinDiceLoss, loss_fn, loss_single, loss_batch).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@tests/losses/test_generalized_wasserstein_dice_loss.py`:
- Around line 235-250: The current reduction="none" sub-test uses identical
perfect predictions so it cannot detect the batching/broadcasting bug; update
the test in tests/losses/test_generalized_wasserstein_dice_loss.py to use a
non-perfect prediction for pred_single (e.g., a poor/confusing prediction) and
construct pred_batch/target_batch so the two batch elements are different (for
example offset or swap labels in the second sample) rather than duplicating the
perfect sample; then call GeneralizedWassersteinDiceLoss(..., reduction="none")
and assert each loss_batch[i] equals the corresponding loss_single[i] (compare
loss_batch[i] to loss_single[i] or compute single-sample losses for each batch
element) to ensure cross-sample isolation and surface the broadcasting bug
(referenced symbols: pred_single, pred_batch, target_single, target_batch,
GeneralizedWassersteinDiceLoss, loss_fn, loss_single, loss_batch).
Fixes #4650
Description
When `batch_size > 1`, `GeneralizedWassersteinDiceLoss` produces incorrect loss values because of a tensor broadcasting issue in `_compute_generalized_true_positive` and `_compute_denominator`.

After `torch.gather`, `alpha_extended` has shape `(B, 1, S)` while `wasserstein_distance_map` has shape `(B, S)`. The element-wise multiply silently broadcasts to `(B, B, S)`, which mixes values across batch samples. This means the loss has always been wrong for any training run with `batch_size > 1`.

The fix follows the reference implementation by the original paper's author — squeeze `dim=1` after the gather so both tensors are `(B, S)`, and reduce with `dim=1` instead of `dim=[1, 2]`.

I also noticed that `reduction="none"` was broken (it never had test coverage) — it tried to reshape the per-sample loss `(B,)` into `(B, C, 1, ...)`, but GWDL aggregates over classes internally, so the class dimension doesn't exist in the output. Fixed that as well.
Changes

- `monai/losses/dice.py`: squeeze + dim fix in `_compute_generalized_true_positive` and `_compute_denominator`; fixed the `reduction="none"` path
- `tests/losses/test_generalized_wasserstein_dice_loss.py`: two new regression tests for batch consistency

Tests
All existing tests pass. The new regression tests fail on unpatched code and pass with the fix.
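For illustration, here is a condensed sketch of the kind of batch-consistency check the new regression tests perform. This is a hypothetical standalone script, not the test code itself; the sample tensors are made up:

```python
import numpy as np
import torch
import torch.nn.functional as F
from monai.losses import GeneralizedWassersteinDiceLoss

loss_fn = GeneralizedWassersteinDiceLoss(
    dist_matrix=np.array([[0.0, 1.0], [1.0, 0.0]]), reduction="none"
)

# Two different samples: one perfect prediction, one deliberately wrong.
target = torch.tensor([[[0, 0], [0, 1]], [[1, 1], [1, 0]]])            # (2, H, W)
pred_good = 1000 * F.one_hot(target[:1], num_classes=2).permute(0, 3, 1, 2).float()
pred_poor = 1000 * F.one_hot(1 - target[1:], num_classes=2).permute(0, 3, 1, 2).float()
pred_batch = torch.cat([pred_good, pred_poor], dim=0)                  # (2, C, H, W)

# Per-sample losses from the batched call should match individual calls.
loss_batch = loss_fn(pred_batch, target)                               # shape (2,) with the fix
loss_0 = loss_fn(pred_good, target[:1])[0]
loss_1 = loss_fn(pred_poor, target[1:])[0]
assert torch.allclose(loss_batch, torch.stack([loss_0, loss_1]), atol=1e-5)
```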