Adjust execution order of activation and masking in MaskedDiceLoss #8704
base: dev
Conversation
📝 Walkthrough
MaskedDiceLoss now exposes an explicit constructor with activation options (…
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: ✅ 4 | ❌ 1
❌ Failed checks (1 inconclusive)
✅ Passed checks (4 passed)
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@monai/losses/dice.py`:
- Around lines 256-283: The forward block mutates instance flags (`self.sigmoid`, `self.softmax`, `self.other_act`) before calling `spatial_weighted`, which is not thread-safe. Instead, avoid changing instance state: compute local variables for the activated input and any activation flags (e.g., `local_sigmoid`/`local_softmax`/`local_other_act`, or simply pass the already-activated input) and call `spatial_weighted` with those locals (or with the activated input and any needed override parameters) rather than flipping `self.sigmoid`/`self.softmax`/`self.other_act`. Remove the try/finally restore logic and ensure the `spatial_weighted` signature or call site accepts the local overrides or only the prepared input, eliminating instance mutation and race conditions.
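A minimal sketch of that local-variable approach, written as methods of MaskedDiceLoss (this assumes, as suggested later in the thread, that the DiceLoss wrapped by `self.spatial_weighted` is constructed with `sigmoid=False`, `softmax=False`, `other_act=None`, so the activation is not applied twice; the `_apply_activations` name is illustrative, not existing MONAI API):

```python
from __future__ import annotations

import warnings

import torch


def _apply_activations(self, input: torch.Tensor) -> torch.Tensor:
    """Return ``input`` with the configured activation applied, leaving instance flags untouched (sketch)."""
    if self.sigmoid:
        input = torch.sigmoid(input)
    if self.softmax:
        if input.shape[1] == 1:
            warnings.warn("single channel prediction, `softmax=True` ignored.", stacklevel=2)
        else:
            input = torch.softmax(input, 1)
    if self.other_act is not None:
        input = self.other_act(input)
    return input


def forward(self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
    """Activate into a local tensor, then delegate to the masked loss without flipping instance state (sketch)."""
    activated = self._apply_activations(input)
    return self.spatial_weighted(input=activated, target=target, mask=mask)
```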
🧹 Nitpick comments (3)
monai/losses/dice.py (2)
263-263: Add `stacklevel=2` to warning. Per static analysis (B028): set `stacklevel=2` so the warning points to the caller rather than this line.
Fix:
```diff
- warnings.warn("single channel prediction, `softmax=True` ignored.")
+ warnings.warn("single channel prediction, `softmax=True` ignored.", stacklevel=2)
```
257-268: Code duplication with `DiceLoss.forward` activation handling. Lines 257-268 mirror `DiceLoss.forward` (lines 154-165). Consider extracting an `_apply_activations` helper to reduce duplication.
tests/losses/test_masked_dice_loss.py (1)
115-115: Class name mismatch. Class `TestDiceLoss` tests `MaskedDiceLoss`. Consider renaming to `TestMaskedDiceLoss`.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (2)
monai/losses/dice.py
tests/losses/test_masked_dice_loss.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definitions, describing each variable, return value, and raised exception in the appropriate section of the Google style of docstrings. Examine code for logical errors or inconsistencies, and suggest what may be changed to address these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
tests/losses/test_masked_dice_loss.py
monai/losses/dice.py
🪛 Ruff (0.14.11)
monai/losses/dice.py
263-263: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: build-docs
- GitHub Check: packaging
🔇 Additional comments (1)
tests/losses/test_masked_dice_loss.py (1)
30-30: Updated expected values reflect activation-before-masking fix. The test values (0.333333, 0.301128, 0.579184) correctly correspond to the implementation change where sigmoid is applied before masking. The code confirms that activation layers execute at lines 257-268, then masked loss computation follows at line 279.
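As a quick standalone illustration of the ordering difference (not taken from the test file; values are just `sigmoid` outputs): masking logits before activation leaves excluded voxels at `sigmoid(0) = 0.5`, while activating first and masking afterwards drives them to 0.

```python
import torch

logits = torch.tensor([2.0, -1.0, 3.0])
mask = torch.tensor([1.0, 1.0, 0.0])  # last voxel is excluded by the mask

mask_then_activate = torch.sigmoid(logits * mask)  # excluded voxel -> sigmoid(0) = 0.5
activate_then_mask = torch.sigmoid(logits) * mask  # excluded voxel -> 0.0

print(mask_then_activate)  # tensor([0.8808, 0.2689, 0.5000])
print(activate_then_mask)  # tensor([0.8808, 0.2689, 0.0000])
```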
Signed-off-by: ytl0623 <david89062388@gmail.com>
Force-pushed from 092990d to adff162.
Hi @ytl0623 thanks for this fix, but I don't feel this is the right solution for it. The issue here is that the activation needs to happen before masking, however the approach here creates a new …

```python
def __init__(
    self,
    include_background: bool = True,
    to_onehot_y: bool = False,
    sigmoid: bool = False,
    softmax: bool = False,
    other_act: Callable | None = None,
    squared_pred: bool = False,
    jaccard: bool = False,
    reduction: LossReduction | str = LossReduction.MEAN,
    smooth_nr: float = 1e-5,
    smooth_dr: float = 1e-5,
    batch: bool = False,
    weight: Sequence[float] | float | int | torch.Tensor | None = None,
    soft_label: bool = False,
) -> None:
    """
    Args follow :py:class:`monai.losses.DiceLoss`.
    """
    self.pre_sigmoid = sigmoid
    self.pre_softmax = softmax
    self.pre_other_act = other_act
    super().__init__(
        include_background=include_background,
        to_onehot_y=to_onehot_y,
        sigmoid=False,
        softmax=False,
        other_act=None,
        squared_pred=squared_pred,
        jaccard=jaccard,
        reduction=reduction,
        smooth_nr=smooth_nr,
        smooth_dr=smooth_dr,
        batch=batch,
        weight=weight,
        soft_label=soft_label,
    )
    self.spatial_weighted = MaskedLoss(loss=super().forward)
```

Your …
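To make the suggestion concrete, a `forward` that would pair with that constructor might look like the following; this completion is an assumption about how the sketch continues, not code from the comment (the `elif` chain relies on the activation options being mutually exclusive, as `DiceLoss` enforces):

```python
def forward(self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
    """Apply the stored pre-activation option, then delegate to the masked DiceLoss (sketch only)."""
    if self.pre_sigmoid:
        input = torch.sigmoid(input)
    elif self.pre_softmax and input.shape[1] > 1:  # single-channel softmax warning omitted for brevity
        input = torch.softmax(input, 1)
    elif self.pre_other_act is not None:
        input = self.pre_other_act(input)
    return self.spatial_weighted(input=input, target=target, mask=mask)
```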
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@monai/losses/dice.py`:
- Around lines 262-269: The constructor currently assigns `pre_other_act` without validating `other_act`, so a non-callable `other_act` causes a runtime error later. Add a check in the `__init__` (before assigning `self.pre_other_act`) that raises `TypeError` if `other_act` is not None and not callable, and ensure this validation sits alongside the existing incompatible-flag checks for `sigmoid`/`softmax` (referencing `other_act`, `pre_other_act`, `pre_sigmoid`, `pre_softmax`).
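A minimal sketch of that check (its placement inside the constructor and the exact messages are assumptions, mirroring the existing DiceLoss validation style):

```python
# inside MaskedDiceLoss.__init__, before storing the pre-activation options (sketch)
if other_act is not None and not callable(other_act):
    raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.")
if int(sigmoid) + int(softmax) + int(other_act is not None) > 1:
    raise ValueError("Incompatible values: more than 1 of [sigmoid=True, softmax=True, other_act is not None].")
self.pre_sigmoid = sigmoid
self.pre_softmax = softmax
self.pre_other_act = other_act
```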
🧹 Nitpick comments (1)
monai/losses/dice.py (1)
303-303: Add `stacklevel=2` to `warnings.warn`. Without it, the warning points to this internal line rather than the caller's code.
Proposed fix
```diff
- warnings.warn("single channel prediction, `softmax=True` ignored.")
+ warnings.warn("single channel prediction, `softmax=True` ignored.", stacklevel=2)
```
Force-pushed from e96354a to 9bde913.
Signed-off-by: ytl0623 <david89062388@gmail.com>
Force-pushed from 1e5f824 to 061f0d2.
Fixes #8655
Description
Adjusts `MaskedDiceLoss.forward` to ensure masked regions result in 0.0 probability instead of 0.5.

Types of changes
- Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- Documentation tested with the `make html` command in the `docs/` folder.
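For reference, a small usage sketch of the intended behaviour (shapes and values are illustrative, not part of the PR's test suite):

```python
import torch

from monai.losses import MaskedDiceLoss

loss_fn = MaskedDiceLoss(sigmoid=True)

pred = torch.randn(1, 1, 4, 4, requires_grad=True)  # raw logits
target = torch.randint(0, 2, (1, 1, 4, 4)).float()
mask = torch.ones(1, 1, 4, 4)
mask[..., 2:, :] = 0  # exclude the lower half of the image from the loss

# With the fix, sigmoid is applied to the logits before masking, so excluded
# voxels contribute 0.0 rather than sigmoid(0) = 0.5 to the Dice computation.
loss = loss_fn(pred, target, mask=mask)
loss.backward()
print(loss.item())
```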