Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions monai/metrics/active_learning_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,7 @@ def compute_variance(
y_pred = y_pred.float()

if not include_background:
y = y_pred
# TODO If this utils is made to be optional for 'y' it would be nice
y_pred, y = ignore_background(y_pred=y_pred, y=y)
y_pred = ignore_background(y_pred=y_pred)

# Set any values below 0 to threshold
y_pred[y_pred <= 0] = threshold
Expand Down
37 changes: 30 additions & 7 deletions monai/metrics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from collections.abc import Iterable, Sequence
from functools import cache, partial
from types import ModuleType
from typing import Any
from typing import Any, overload

import numpy as np
import torch
Expand Down Expand Up @@ -51,21 +51,44 @@
]


def ignore_background(y_pred: NdarrayTensor, y: NdarrayTensor) -> tuple[NdarrayTensor, NdarrayTensor]:
@overload
def ignore_background(y_pred: NdarrayTensor, y: None = ...) -> NdarrayTensor: ...


@overload
def ignore_background(y_pred: NdarrayTensor, y: NdarrayTensor) -> tuple[NdarrayTensor, NdarrayTensor]: ...


def ignore_background(
y_pred: NdarrayTensor, y: NdarrayTensor | None = None
) -> NdarrayTensor | tuple[NdarrayTensor, NdarrayTensor]:
"""
This function is used to remove background (the first channel) for `y_pred` and `y`.

Args:
y_pred: predictions. As for classification tasks,
`y_pred` should has the shape [BN] where N is larger than 1. As for segmentation tasks,
`y_pred` should have the shape [BN] where N is larger than 1. As for segmentation tasks,
the shape should be [BNHW] or [BNHWD].
y: ground truth, the first dim is batch.
y: ground truth, the first dim is batch. (Optional)

Returns:
NdarrayTensor | tuple[NdarrayTensor, NdarrayTensor]:
- If `y` is None: returns background-removed `y_pred` only.
- If `y` is provided: returns a tuple of (background-removed `y_pred`, background-removed `y`).
"""

y = y[:, 1:] if y.shape[1] > 1 else y # type: ignore[assignment]
y_pred = y_pred[:, 1:] if y_pred.shape[1] > 1 else y_pred # type: ignore[assignment]
return y_pred, y
y_pred_out = y_pred
if y_pred.shape[1] > 1:
y_pred_out = y_pred[:, 1:] # type: ignore

if y is None:
return y_pred_out

y_out = y
if y.shape[1] > 1:
y_out = y[:, 1:] # type: ignore

return y_pred_out, y_out


def do_metric_reduction(
Expand Down
Loading