
Fix batch size broadcasting bug in GeneralizedWassersteinDiceLoss #8744

Open
hongjie-qiu wants to merge 2 commits into Project-MONAI:dev from hongjie-qiu:4650-fix-gwdl-batch-size

Conversation


@hongjie-qiu hongjie-qiu commented Feb 19, 2026

Fixes #4650

Description

When batch_size > 1, GeneralizedWassersteinDiceLoss produces incorrect loss values because of a tensor broadcasting issue in _compute_generalized_true_positive and _compute_denominator.

After torch.gather, alpha_extended has shape (B, 1, S) while wasserstein_distance_map has shape (B, S). The element-wise multiply silently broadcasts to (B, B, S), which mixes values across batch samples. This means the loss has always been wrong for any training run with batch_size > 1.

The fix follows the reference implementation by the original paper's author — squeeze dim=1 after the gather so both tensors are (B, S), and reduce with dim=1 instead of dim=[1, 2].
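
A minimal shape sketch of the bug and the fix (illustrative only: B = 2, S = 6 and the random values are placeholders, not MONAI code):

import torch

B, S = 2, 6  # placeholder batch size and flattened spatial size

# shapes as they appear after torch.gather in the loss helpers
alpha_extended = torch.rand(B, 1, S)          # keeps a singleton class dim
wasserstein_distance_map = torch.rand(B, S)

# buggy: (B, 1, S) * (B, S) silently broadcasts to (B, B, S), mixing samples
mixed = alpha_extended * wasserstein_distance_map
print(mixed.shape)  # torch.Size([2, 2, 6])

# fixed: squeeze the singleton dim so both tensors are (B, S), reduce over dim=1
alpha_per_voxel = alpha_extended.squeeze(dim=1)
true_pos = torch.sum(alpha_per_voxel * (1.0 - wasserstein_distance_map), dim=1)
print(true_pos.shape)  # torch.Size([2]): one value per sample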

I also noticed that reduction="none" was broken (never had test coverage) — it tried to reshape the per-sample loss (B,) into (B, C, 1, ...), but GWDL aggregates over classes internally so the class dimension doesn't exist in the output. Fixed that as well.
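
For reference, a hedged usage sketch of the reduction="none" behaviour with this fix applied (the 2-class distance matrix matches the one used in the tests; the random target and sizes are illustrative):

import numpy as np
import torch
import torch.nn.functional as F
from monai.losses import GeneralizedWassersteinDiceLoss

target = torch.randint(0, 2, (3, 4, 4))  # (B, H, W) with B = 3
pred = 1000 * F.one_hot(target, num_classes=2).permute(0, 3, 1, 2).float()  # (B, C, H, W)

loss_fn = GeneralizedWassersteinDiceLoss(
    dist_matrix=np.array([[0.0, 1.0], [1.0, 0.0]]), reduction="none"
)
print(loss_fn(pred, target).shape)  # torch.Size([3]): one loss value per sample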

Changes

  • monai/losses/dice.py: squeeze and dim fix in _compute_generalized_true_positive and _compute_denominator; fixed the reduction="none" path
  • tests/losses/test_generalized_wasserstein_dice_loss.py: two new regression tests for batch consistency

Tests

All existing tests pass. The new regression tests fail on unpatched code and pass with the fix.
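
Roughly what the batch-consistency check verifies, as a sketch rather than the exact test code (the tensor construction mirrors the existing tests; pairing one good and one poor prediction is an assumption of this sketch):

import numpy as np
import torch
import torch.nn.functional as F
from monai.losses import GeneralizedWassersteinDiceLoss

target = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]]).unsqueeze(0)
pred_good = 1000 * F.one_hot(target, num_classes=2).permute(0, 3, 1, 2).float()
pred_poor = 1000 * F.one_hot(1 - target, num_classes=2).permute(0, 3, 1, 2).float()

loss_fn = GeneralizedWassersteinDiceLoss(
    dist_matrix=np.array([[0.0, 1.0], [1.0, 0.0]]), reduction="none"
)

# per-sample losses in a batch should equal the individually computed losses
loss_batch = loss_fn(torch.cat([pred_good, pred_poor]), torch.cat([target, target]))
assert torch.allclose(loss_batch[0], loss_fn(pred_good, target)[0])
assert torch.allclose(loss_batch[1], loss_fn(pred_poor, target)[0])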


coderabbitai bot commented Feb 19, 2026

📝 Walkthrough

The PR corrects batch handling in GeneralizedWassersteinDiceLoss: for reduction="none" the forward now preserves a per-sample shape (B,) instead of broadcasting to per-voxel. Internal helpers _compute_generalized_true_positive and _compute_denominator squeeze the gathered alpha_extended class dimension and switch their sum reductions from dims [1, 2] to dim=1, keeping results per batch. Two regression tests were added to ensure per-sample losses match single-sample computations for identical and differing batch entries across weighting modes.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 5 passed
  • Title check (✅ Passed): Title clearly and specifically describes the main bug fix: a batch size broadcasting issue in GeneralizedWassersteinDiceLoss.
  • Description check (✅ Passed): Description covers all required sections: issue reference, clear problem explanation, implementation details, file changes, and test coverage notes.
  • Linked Issues check (✅ Passed): Code changes directly address issue #4650: fixes the broadcasting bug via squeeze + dim reduction in the loss computation and adds regression tests validating batch consistency.
  • Out of Scope Changes check (✅ Passed): All changes are tightly scoped to fixing the identified broadcasting bug and adding corresponding test coverage; no unrelated modifications detected.
  • Docstring Coverage (✅ Passed): Docstring coverage is 87.50%, which is sufficient. The required threshold is 80.00%.


@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (3)
monai/losses/dice.py (2)

509-514: Document the reduction="none" output shape in the docstring.

The forward docstring has no Returns: section. Since this PR changes the reduction="none" output shape from a broken reshape attempt to a well-defined (B,), callers need to know what to expect.

📝 Proposed docstring update
 def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
     """
     Args:
         input: the shape should be BNH[WD].
         target: the shape should be BNH[WD].

+    Returns:
+        Scalar when ``reduction`` is ``"mean"`` or ``"sum"``.
+        Tensor of shape ``(B,)`` when ``reduction`` is ``"none"``, one loss value
+        per sample (GWDL aggregates over classes and spatial dims internally).
     """

As per coding guidelines, Google-style docstrings should describe each return value in a Returns: section.

Also applies to: 550-553

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@monai/losses/dice.py` around lines 509 - 514, Update the forward docstring
for the Dice loss (method forward) to add a Google-style "Returns:" section that
clearly states the tensor shape and meaning for each reduction mode;
specifically document that when reduction="none" the output is a 1-D tensor of
shape (B,) containing per-batch loss values (and describe shapes for
"mean"/"sum" if applicable), and apply the same docstring clarification to the
other forward variant referenced around lines 550-553 so callers know to expect
a (B,) output instead of the previous reshape behavior.

596-630: Both helpers share identical alpha-mapping code — extract a private helper.

The five-line alpha-extension/gather/squeeze block is duplicated verbatim in _compute_generalized_true_positive and _compute_denominator. A private _map_alpha_to_voxels would remove the duplication and make future fixes a one-place change (as this PR illustrates — the squeeze had to be added in both places).

Additionally, both methods are missing Returns: sections in their docstrings. As per coding guidelines, Google-style docstrings must describe return values.

♻️ Proposed refactor
+    def _map_alpha_to_voxels(self, alpha: torch.Tensor, flat_target: torch.Tensor) -> torch.Tensor:
+        """Map per-class alpha weights to a per-voxel tensor via flat_target.
+
+        Args:
+            alpha: per-class weights of shape (B, C).
+            flat_target: flattened target labels of shape (B, S).
+
+        Returns:
+            Per-voxel alpha values of shape (B, S).
+        """
+        alpha_extended = torch.unsqueeze(alpha, dim=2)
+        alpha_extended = alpha_extended.expand((flat_target.size(0), self.num_classes, flat_target.size(1)))
+        flat_target_extended = torch.unsqueeze(flat_target, dim=1)
+        alpha_extended = torch.gather(alpha_extended, index=flat_target_extended, dim=1)
+        return torch.squeeze(alpha_extended, dim=1)

     def _compute_generalized_true_positive(
         self, alpha: torch.Tensor, flat_target: torch.Tensor, wasserstein_distance_map: torch.Tensor
     ) -> torch.Tensor:
         """
         Args:
             alpha: generalised number of true positives of target class.
             flat_target: the target tensor.
             wasserstein_distance_map: the map obtained from the above function.
+
+        Returns:
+            Per-sample generalised true positives of shape (B,).
         """
-        # Extend alpha to a map and select value at each voxel according to flat_target
-        alpha_extended = torch.unsqueeze(alpha, dim=2)
-        alpha_extended = alpha_extended.expand((flat_target.size(0), self.num_classes, flat_target.size(1)))
-        flat_target_extended = torch.unsqueeze(flat_target, dim=1)
-        alpha_extended = torch.gather(alpha_extended, index=flat_target_extended, dim=1)
-        alpha_extended = torch.squeeze(alpha_extended, dim=1)
-        return torch.sum(alpha_extended * (1.0 - wasserstein_distance_map), dim=1)
+        alpha_per_voxel = self._map_alpha_to_voxels(alpha, flat_target)
+        return torch.sum(alpha_per_voxel * (1.0 - wasserstein_distance_map), dim=1)

     def _compute_denominator(
         self, alpha: torch.Tensor, flat_target: torch.Tensor, wasserstein_distance_map: torch.Tensor
     ) -> torch.Tensor:
         """
         Args:
             alpha: generalised number of true positives of target class.
             flat_target: the target tensor.
             wasserstein_distance_map: the map obtained from the above function.
+
+        Returns:
+            Per-sample denominator of shape (B,).
         """
-        # Extend alpha to a map and select value at each voxel according to flat_target
-        alpha_extended = torch.unsqueeze(alpha, dim=2)
-        alpha_extended = alpha_extended.expand((flat_target.size(0), self.num_classes, flat_target.size(1)))
-        flat_target_extended = torch.unsqueeze(flat_target, dim=1)
-        alpha_extended = torch.gather(alpha_extended, index=flat_target_extended, dim=1)
-        alpha_extended = torch.squeeze(alpha_extended, dim=1)
-        return torch.sum(alpha_extended * (2.0 - wasserstein_distance_map), dim=1)
+        alpha_per_voxel = self._map_alpha_to_voxels(alpha, flat_target)
+        return torch.sum(alpha_per_voxel * (2.0 - wasserstein_distance_map), dim=1)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@monai/losses/dice.py` around lines 596 - 630, Both methods
_compute_generalized_true_positive and _compute_denominator duplicate the
alpha-extension/gather/squeeze block; extract that logic into a private helper
(e.g., _map_alpha_to_voxels(self, alpha: torch.Tensor, flat_target:
torch.Tensor) -> torch.Tensor) that returns the per-voxel alpha_extended tensor,
update both methods to call this helper and use its result in their sums, and
remove the duplicated code; also add a Returns: section to the docstrings of
_compute_generalized_true_positive and _compute_denominator describing the
returned tensor shape and meaning.
tests/losses/test_generalized_wasserstein_dice_loss.py (1)

293-295: float() on a reduction="none" output is fragile for future readers.

loss_fn(pred_a, target_a) returns shape (1,) here; float() only works because batch size is 1. Prefer .item() or index explicitly to make intent clear.

🔧 Suggested clarification
-            loss_a = float(loss_fn(pred_a, target_a))
-            loss_b = float(loss_fn(pred_b, target_b))
+            loss_a = loss_fn(pred_a, target_a)[0].item()
+            loss_b = loss_fn(pred_b, target_b)[0].item()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/losses/test_generalized_wasserstein_dice_loss.py` around lines 293 -
295, The tests call float(loss_fn(pred_a, target_a)) and float(loss_fn(pred_b,
target_b)) where loss_fn is configured with reduction="none" and returns a
1-element tensor; using float() is fragile and hides the intent. Replace
float(...) with .item() (e.g., loss_fn(pred_a, target_a).item()) or explicitly
index [0] to extract the scalar so the code clearly indicates you're converting
a single-element tensor to a Python float; update both occurrences referencing
loss_fn, pred_a, target_a, pred_b, and target_b.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/losses/test_generalized_wasserstein_dice_loss.py`:
- Around line 254-270: The mean-reduction subtest uses trivially-zero losses
(pred_single from pred_very_good), so it won't catch a broken mean reduction;
update the test to use a non-trivial prediction for at least one sample: when
constructing pred_single/pred_batch for the GeneralizedWassersteinDiceLoss
checks, replace one of the batch entries (or pred_single) with a clearly poor
prediction (e.g., pred_very_poor) so loss_single and loss_batch produce a
non-zero, different-per-sample value; keep using loss_fn =
GeneralizedWassersteinDiceLoss(..., weighting_mode=w_mode, reduction="mean") and
then assert loss_batch equals loss_single to verify mean reduction behavior.


…oject-MONAI#4650)

After `torch.gather`, `alpha_extended` retains shape (B, 1, S) while
`wasserstein_distance_map` has shape (B, S). When batch size > 1 the
element-wise multiply broadcasts to (B, B, S), mixing values across
samples. Fixed by squeezing dim=1 after gather in both
`_compute_generalized_true_positive` and `_compute_denominator`, and
reducing with `dim=1` instead of `dim=[1, 2]`.

Also fixed the `reduction="none"` code path which incorrectly tried to
reshape the per-sample loss tensor (B,) to (B, C, 1, ...) — GWDL
aggregates over classes internally so the class dimension doesn't apply.

Added regression tests that verify batch consistency:
- identical samples in a batch produce the same loss as a single sample
- batched per-sample losses match individually computed losses

Signed-off-by: hongjie-qiu <77599736+hongjie-qiu@users.noreply.github.com>
@hongjie-qiu hongjie-qiu force-pushed the 4650-fix-gwdl-batch-size branch from 063df92 to 4887d9d on February 19, 2026 at 17:57
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/losses/test_generalized_wasserstein_dice_loss.py`:
- Around line 268-297: The test test_batch_size_different_samples is currently
trivial because both pred_a and pred_b are perfect one-hot predictions; change
pred_b to produce a poor prediction (e.g., build pred_b from 1 - target_b before
one_hot) so that loss_b is ≈1.0 and the batch variant checks are meaningful;
keep pred_a as the perfect 1000*F.one_hot(target_a, ...) and ensure pred_batch =
torch.cat([pred_a, pred_b], dim=0) remains consistent so loss_a, loss_b, and
loss_batch indices compare correctly inside the weighting_mode loop that calls
GeneralizedWassersteinDiceLoss and asserts loss_batch[0]==loss_a and
loss_batch[1]==loss_b.

---

Duplicate comments:
In `@tests/losses/test_generalized_wasserstein_dice_loss.py`:
- Around line 252-266: The mean-reduction sub-test is weak because
pred_single/pred_batch are perfect (loss ≈ 0), so it doesn't catch
mean-reduction bugs; update the test to use non-perfect predictions so
per-sample losses are non-zero and comparable: for
GeneralizedWassersteinDiceLoss with weighting_mode set from ["default","GDL"]
and reduction="mean", construct pred_single/pred_batch (or modify pred_single to
contain a small deliberate error and replicate into pred_batch as the batch of
samples) and/or use multiple distinct samples in pred_batch, compute loss_single
= float(loss_fn(pred_single, target_single)) and loss_batch =
float(loss_fn(pred_batch, target_batch)), then assert loss_batch ≈ loss_single
to verify the mean aggregation is correct; reference
GeneralizedWassersteinDiceLoss, pred_single, pred_batch, target_single,
target_batch, weighting_mode, and reduction="mean" when making the change.

@ericspod (Member)

Hi @hongjie-qiu, thanks for looking into this. Please look at the comments from CodeRabbit and we can then review again.

Address review feedback: use poor predictions in mean-reduction and
different-samples tests so the expected loss values are non-trivial
(~1.0 instead of ~0.0), ensuring the assertions are meaningful.

Signed-off-by: Jeffrey Qiu <hongjie.qiu@gmail.com>
Signed-off-by: hongjie-qiu <77599736+hongjie-qiu@users.noreply.github.com>
@coderabbitai coderabbitai bot left a comment

🧹 Nitpick comments (1)
tests/losses/test_generalized_wasserstein_dice_loss.py (1)

235-250: reduction="none" sub-test with identical perfect predictions can't catch the batch broadcasting bug.

For B identical samples, the buggy code scales both true_pos and denom by B (the extraneous broadcasting factor), so the ratio (2·B·tp) / (2·B·tp + all_error) is unchanged from the single-sample result regardless of prediction quality. With perfect predictions the loss is ≈ 0 for both buggy and correct code; with poor predictions it's ≈ 1 for both. The assertion 0 ≈ 0 (or 1 ≈ 1) trivially passes on unpatched code.

test_batch_size_different_samples is the proper regression anchor — it catches the bug because cross-batch sample contamination changes the numerator for sample 1 meaningfully (≈0.67 buggy vs ≈1.0 correct). This sub-test in test_batch_size_greater_than_one ends up only verifying that the batch infrastructure doesn't add noise on top of perfect predictions, which is a much weaker guarantee.
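
A worked version of that cancellation, as a sketch for the default weighting mode only (there both the generalised true positives and the denominator pick up the same broadcast factor B when the samples are identical; ε denotes the smoothing terms):

\[
\frac{2 \cdot B \cdot \mathrm{tp} + \varepsilon}{B \cdot \mathrm{denom} + \varepsilon}
\;\approx\; \frac{2 \cdot \mathrm{tp}}{\mathrm{denom}},
\]

so the buggy batched ratio matches the single-sample ratio, and an identical-samples assertion cannot distinguish patched from unpatched code.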

Consider replacing pred_single (perfect) with a poor prediction, and pairing it with different targets per batch element (e.g., offset the second sample), so the "none" path has a non-trivially-different expected value and actually exercises cross-sample isolation:

♻️ Suggested improvement
-        target_single = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]])
-        target_single = target_single.unsqueeze(0)  # shape (1, H, W)
-        pred_single = 1000 * F.one_hot(target_single, num_classes=2).permute(0, 3, 1, 2).float()
-
-        # Create a batch of size 2 by repeating the same sample
-        target_batch = target_single.repeat(2, 1, 1)  # shape (2, H, W)
-        pred_batch = pred_single.repeat(2, 1, 1, 1)  # shape (2, C, H, W)
+        target_single = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]])
+        target_single = target_single.unsqueeze(0)  # shape (1, H, W)
+        pred_poor = 1000 * F.one_hot(1 - target_single, num_classes=2).permute(0, 3, 1, 2).float()
+
+        # Create a batch of two poor predictions so the "none" path has a
+        # non-trivial per-sample expected value.
+        target_batch = target_single.repeat(2, 1, 1)  # shape (2, H, W)
+        pred_batch = torch.cat([pred_poor, pred_poor], dim=0)  # shape (2, C, H, W)

         for w_mode in ["default", "GDL"]:
             loss_fn = GeneralizedWassersteinDiceLoss(
                 dist_matrix=np.array([[0.0, 1.0], [1.0, 0.0]]), weighting_mode=w_mode, reduction="none"
             )

-            loss_single = loss_fn(pred_single, target_single)
-            loss_batch = loss_fn(pred_batch, target_batch)
+            loss_single = loss_fn(pred_poor, target_single)
+            loss_batch = loss_fn(pred_batch, target_batch)

             for i in range(2):
+                self.assertGreater(float(loss_single[0]), 0.5,
+                    msg=f"Expected non-trivial loss for weighting_mode={w_mode}")
                 self.assertAlmostEqual(
                     float(loss_batch[i]),
                     float(loss_single[0]),
                     places=5,
                     msg=f"Batch loss[{i}] != single loss for weighting_mode={w_mode}",
                 )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/losses/test_generalized_wasserstein_dice_loss.py` around lines 235 -
250, The current reduction="none" sub-test uses identical perfect predictions so
it cannot detect the batching/broadcasting bug; update the test in
tests/losses/test_generalized_wasserstein_dice_loss.py to use a non-perfect
prediction for pred_single (e.g., a poor/confusing prediction) and construct
pred_batch/target_batch so the two batch elements are different (for example
offset or swap labels in the second sample) rather than duplicating the perfect
sample; then call GeneralizedWassersteinDiceLoss(..., reduction="none") and
assert each loss_batch[i] equals the corresponding loss_single[i] (compare
loss_batch[i] to loss_single[i] or compute single-sample losses for each batch
element) to ensure cross-sample isolation and surface the broadcasting bug
(referenced symbols: pred_single, pred_batch, target_single, target_batch,
GeneralizedWassersteinDiceLoss, loss_fn, loss_single, loss_batch).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@tests/losses/test_generalized_wasserstein_dice_loss.py`:
- Around line 235-250: The current reduction="none" sub-test uses identical
perfect predictions so it cannot detect the batching/broadcasting bug; update
the test in tests/losses/test_generalized_wasserstein_dice_loss.py to use a
non-perfect prediction for pred_single (e.g., a poor/confusing prediction) and
construct pred_batch/target_batch so the two batch elements are different (for
example offset or swap labels in the second sample) rather than duplicating the
perfect sample; then call GeneralizedWassersteinDiceLoss(..., reduction="none")
and assert each loss_batch[i] equals the corresponding loss_single[i] (compare
loss_batch[i] to loss_single[i] or compute single-sample losses for each batch
element) to ensure cross-sample isolation and surface the broadcasting bug
(referenced symbols: pred_single, pred_batch, target_single, target_batch,
GeneralizedWassersteinDiceLoss, loss_fn, loss_single, loss_batch).


Successfully merging this pull request may close these issues.

Bug in the generalized Wasserstein Dice loss
