diff --git a/monai/losses/dice.py b/monai/losses/dice.py index cd76ec1323..b6fb471cfb 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -23,6 +23,7 @@ from monai.losses.focal_loss import FocalLoss from monai.losses.spatial_mask import MaskedLoss from monai.losses.utils import compute_tp_fp_fn +from monai.metrics.utils import create_ignore_mask from monai.networks import one_hot from monai.utils import DiceCEReduction, LossReduction, Weight, look_up_option @@ -67,6 +68,7 @@ def __init__( batch: bool = False, weight: Sequence[float] | float | int | torch.Tensor | None = None, soft_label: bool = False, + ignore_index: int | None = None, ) -> None: """ Args: @@ -100,6 +102,7 @@ def __init__( The value/values should be no less than 0. Defaults to None. soft_label: whether the target contains non-binary values (soft labels) or not. If True a soft label formulation of the loss will be used. + ignore_index: class index to ignore from the loss computation. Raises: TypeError: When ``other_act`` is not an ``Optional[Callable]``. @@ -122,6 +125,7 @@ def __init__( self.smooth_nr = float(smooth_nr) self.smooth_dr = float(smooth_dr) self.batch = batch + self.ignore_index = ignore_index weight = torch.as_tensor(weight) if weight is not None else None self.register_buffer("class_weight", weight) self.class_weight: None | torch.Tensor @@ -163,12 +167,15 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if self.other_act is not None: input = self.other_act(input) + original_target = target if self.ignore_index is not None else None + if self.to_onehot_y: if n_pred_ch == 1: warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") else: target = one_hot(target, num_classes=n_pred_ch) + mask = create_ignore_mask(original_target if original_target is not None else target, self.ignore_index) if not self.include_background: if n_pred_ch == 1: warnings.warn("single channel prediction, `include_background=False` ignored.") @@ -180,6 +187,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if target.shape != input.shape: raise AssertionError(f"ground truth has different shape ({target.shape}) from input ({input.shape})") + if mask is not None: + input = input * mask + target = target * mask + # reducing only spatial dimensions (not batch nor channels) reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist() if self.batch: diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py index caa237fca8..de3e4f0985 100644 --- a/monai/losses/focal_loss.py +++ b/monai/losses/focal_loss.py @@ -18,6 +18,7 @@ import torch.nn.functional as F from torch.nn.modules.loss import _Loss +from monai.metrics.utils import create_ignore_mask from monai.networks import one_hot from monai.utils import LossReduction @@ -73,6 +74,7 @@ def __init__( weight: Sequence[float] | float | int | torch.Tensor | None = None, reduction: LossReduction | str = LossReduction.MEAN, use_softmax: bool = False, + ignore_index: int | None = None, ) -> None: """ Args: @@ -99,6 +101,7 @@ def __init__( use_softmax: whether to use softmax to transform the original logits into probabilities. If True, softmax is used. If False, sigmoid is used. Defaults to False. + ignore_index: class index to ignore from the loss computation. Example: >>> import torch @@ -124,6 +127,7 @@ def __init__( weight = torch.as_tensor(weight) if weight is not None else None self.register_buffer("class_weight", weight) self.class_weight: None | torch.Tensor + self.ignore_index = ignore_index def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ @@ -161,6 +165,11 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if target.shape != input.shape: raise ValueError(f"ground truth has different shape ({target.shape}) from input ({input.shape})") + mask = create_ignore_mask(target, self.ignore_index) + if mask is not None: + input = input * mask + target = target * mask + loss: torch.Tensor | None = None input = input.float() target = target.float() diff --git a/monai/losses/tversky.py b/monai/losses/tversky.py index 154f34c526..887d5f4c0d 100644 --- a/monai/losses/tversky.py +++ b/monai/losses/tversky.py @@ -18,6 +18,7 @@ from torch.nn.modules.loss import _Loss from monai.losses.utils import compute_tp_fp_fn +from monai.metrics.utils import create_ignore_mask from monai.networks import one_hot from monai.utils import LossReduction @@ -51,6 +52,7 @@ def __init__( smooth_dr: float = 1e-5, batch: bool = False, soft_label: bool = False, + ignore_index: int | None = None, ) -> None: """ Args: @@ -77,6 +79,7 @@ def __init__( before any `reduction`. soft_label: whether the target contains non-binary values (soft labels) or not. If True a soft label formulation of the loss will be used. + ignore_index: index of the class to ignore during calculation. Raises: TypeError: When ``other_act`` is not an ``Optional[Callable]``. @@ -101,6 +104,7 @@ def __init__( self.smooth_dr = float(smooth_dr) self.batch = batch self.soft_label = soft_label + self.ignore_index = ignore_index def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ @@ -129,8 +133,16 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if n_pred_ch == 1: warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") else: + original_target = target target = one_hot(target, num_classes=n_pred_ch) + if self.ignore_index is not None: + mask_src = original_target if self.to_onehot_y and n_pred_ch > 1 else target + mask = create_ignore_mask(mask_src, self.ignore_index) + if mask is not None: + input = input * mask + target = target * mask + if not self.include_background: if n_pred_ch == 1: warnings.warn("single channel prediction, `include_background=False` ignored.") diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py index 745513fec0..106123a69d 100644 --- a/monai/losses/unified_focal_loss.py +++ b/monai/losses/unified_focal_loss.py @@ -16,6 +16,7 @@ import torch from torch.nn.modules.loss import _Loss +from monai.metrics.utils import create_ignore_mask from monai.networks import one_hot from monai.utils import LossReduction @@ -39,6 +40,7 @@ def __init__( gamma: float = 0.75, epsilon: float = 1e-7, reduction: LossReduction | str = LossReduction.MEAN, + ignore_index: int | None = None, ) -> None: """ Args: @@ -46,41 +48,61 @@ def __init__( delta : weight of the background. Defaults to 0.7. gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75. epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7. + ignore_index: class index to ignore from the loss computation. """ super().__init__(reduction=LossReduction(reduction).value) self.to_onehot_y = to_onehot_y self.delta = delta self.gamma = gamma self.epsilon = epsilon + self.ignore_index = ignore_index def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: n_pred_ch = y_pred.shape[1] + # Save original for masking + original_y_true = y_true if self.ignore_index is not None else None + if self.to_onehot_y: if n_pred_ch == 1: warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") else: + if self.ignore_index is not None: + # Replace ignore_index with valid class before one_hot + y_true = torch.where(y_true == self.ignore_index, torch.tensor(0, device=y_true.device), y_true) y_true = one_hot(y_true, num_classes=n_pred_ch) if y_true.shape != y_pred.shape: raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") - # clip the prediction to avoid NaN + mask = create_ignore_mask(original_y_true if original_y_true is not None else y_true, self.ignore_index) + if mask is not None: + mask = mask.expand_as(y_true) + y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon) axis = list(range(2, len(y_pred.shape))) # Calculate true positives (tp), false negatives (fn) and false positives (fp) - tp = torch.sum(y_true * y_pred, dim=axis) - fn = torch.sum(y_true * (1 - y_pred), dim=axis) - fp = torch.sum((1 - y_true) * y_pred, dim=axis) + if mask is not None: + tp = torch.sum(y_true * y_pred * mask, dim=axis) + fn = torch.sum(y_true * (1 - y_pred) * mask, dim=axis) + fp = torch.sum((1 - y_true) * y_pred * mask, dim=axis) + else: + tp = torch.sum(y_true * y_pred, dim=axis) + fn = torch.sum(y_true * (1 - y_pred), dim=axis) + fp = torch.sum((1 - y_true) * y_pred, dim=axis) dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon) # Calculate losses separately for each class, enhancing both classes back_dice = 1 - dice_class[:, 0] - fore_dice = (1 - dice_class[:, 1]) * torch.pow(1 - dice_class[:, 1], -self.gamma) + fore_dice = torch.pow(1 - dice_class[:, 1], 1 - self.gamma) # Average class scores - loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1)) + loss = torch.stack([back_dice, fore_dice], dim=-1) + if self.reduction == LossReduction.MEAN.value: + return torch.mean(loss) + if self.reduction == LossReduction.SUM.value: + return torch.sum(loss) return loss @@ -103,6 +125,7 @@ def __init__( gamma: float = 2, epsilon: float = 1e-7, reduction: LossReduction | str = LossReduction.MEAN, + ignore_index: int | None = None, ): """ Args: @@ -110,35 +133,58 @@ def __init__( delta : weight of the background. Defaults to 0.7. gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 2. epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7. + ignore_index: class index to ignore from the loss computation. """ super().__init__(reduction=LossReduction(reduction).value) self.to_onehot_y = to_onehot_y self.delta = delta self.gamma = gamma self.epsilon = epsilon + self.ignore_index = ignore_index def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: n_pred_ch = y_pred.shape[1] + # Save original for masking + original_y_true = y_true if self.ignore_index is not None else None + if self.to_onehot_y: if n_pred_ch == 1: warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") else: + if self.ignore_index is not None: + # Replace ignore_index with valid class before one_hot + y_true = torch.where(y_true == self.ignore_index, torch.tensor(0, device=y_true.device), y_true) y_true = one_hot(y_true, num_classes=n_pred_ch) if y_true.shape != y_pred.shape: raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") + mask = create_ignore_mask(original_y_true if original_y_true is not None else y_true, self.ignore_index) + if mask is not None: + mask = mask.expand_as(y_true) + y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon) cross_entropy = -y_true * torch.log(y_pred) + if mask is not None: + cross_entropy = cross_entropy * mask + back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0] back_ce = (1 - self.delta) * back_ce fore_ce = cross_entropy[:, 1] fore_ce = self.delta * fore_ce - loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], dim=1), dim=1)) + loss = torch.stack([back_ce, fore_ce], dim=1) # [B, 2, H, W] + + if self.reduction == LossReduction.MEAN.value: + if mask is not None: + masked_loss = loss * mask + return masked_loss.sum() / mask.expand_as(loss).sum().clamp(min=1e-5) + return loss.mean() + if self.reduction == LossReduction.SUM.value: + return loss.sum() return loss @@ -162,6 +208,7 @@ def __init__( gamma: float = 0.5, delta: float = 0.7, reduction: LossReduction | str = LossReduction.MEAN, + ignore_index: int | None = None, ): """ Args: @@ -170,8 +217,7 @@ def __init__( weight : weight for each loss function. Defaults to 0.5. gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.5. delta : weight of the background. Defaults to 0.7. - - + ignore_index: class index to ignore from the loss computation. Example: >>> import torch @@ -187,10 +233,14 @@ def __init__( self.gamma = gamma self.delta = delta self.weight: float = weight - self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta) - self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta) + self.asy_focal_loss = AsymmetricFocalLoss( + to_onehot_y=False, gamma=self.gamma, delta=self.delta, ignore_index=ignore_index + ) + self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss( + to_onehot_y=False, gamma=self.gamma, delta=self.delta, ignore_index=ignore_index + ) + self.ignore_index = ignore_index - # TODO: Implement this function to support multiple classes segmentation def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: """ Args: @@ -207,28 +257,41 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: ValueError: When num_classes ValueError: When the number of classes entered does not match the expected number """ - if y_pred.shape != y_true.shape: - raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") - if len(y_pred.shape) != 4 and len(y_pred.shape) != 5: raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}") + # Transform binary inputs to 2-channel space if y_pred.shape[1] == 1: - y_pred = one_hot(y_pred, num_classes=self.num_classes) - y_true = one_hot(y_true, num_classes=self.num_classes) + y_pred = torch.cat([1 - y_pred, y_pred], dim=1) - if torch.max(y_true) != self.num_classes - 1: - raise ValueError(f"Please make sure the number of classes is {self.num_classes - 1}") + # Save original for masking before one-hot conversion + original_y_true = y_true if self.ignore_index is not None else None - n_pred_ch = y_pred.shape[1] if self.to_onehot_y: - if n_pred_ch == 1: - warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") - else: - y_true = one_hot(y_true, num_classes=n_pred_ch) + if self.ignore_index is not None: + # Replace ignore_index with valid class before one_hot + y_true = torch.where(y_true == self.ignore_index, torch.tensor(0, device=y_true.device), y_true) + y_true = one_hot(y_true, num_classes=self.num_classes) + elif y_true.shape[1] == 1 and y_pred.shape[1] == 2: + y_true = torch.cat([1 - y_true, y_true], dim=1) + + if y_true.shape != y_pred.shape: + raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") + if self.ignore_index is None and torch.max(y_true) > self.num_classes - 1: + raise ValueError(f"Invalid class index found. Maximum class should be {self.num_classes - 1}") + + mask = create_ignore_mask(original_y_true if original_y_true is not None else y_true, self.ignore_index) + + if mask is not None: + mask_expanded = mask.expand_as(y_true) + y_pred_masked = y_pred * mask_expanded + y_true_masked = y_true * mask_expanded + else: + y_pred_masked = y_pred + y_true_masked = y_true - asy_focal_loss = self.asy_focal_loss(y_pred, y_true) - asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true) + asy_focal_loss = self.asy_focal_loss(y_pred_masked, y_true_masked) + asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred_masked, y_true_masked) loss: torch.Tensor = self.weight * asy_focal_loss + (1 - self.weight) * asy_focal_tversky_loss diff --git a/monai/metrics/confusion_matrix.py b/monai/metrics/confusion_matrix.py index 26ec823081..9b667526cb 100644 --- a/monai/metrics/confusion_matrix.py +++ b/monai/metrics/confusion_matrix.py @@ -16,7 +16,7 @@ import torch -from monai.metrics.utils import do_metric_reduction, ignore_background +from monai.metrics.utils import create_ignore_mask, do_metric_reduction, ignore_background from monai.utils import MetricReduction, ensure_tuple from .metric import CumulativeIterationMetric @@ -69,6 +69,7 @@ def __init__( compute_sample: bool = False, reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False, + ignore_index: int | None = None, ) -> None: super().__init__() self.include_background = include_background @@ -76,6 +77,7 @@ def __init__( self.compute_sample = compute_sample self.reduction = reduction self.get_not_nans = get_not_nans + self.ignore_index = ignore_index def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override] """ @@ -96,7 +98,9 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor warnings.warn("As for classification task, compute_sample should be False.") self.compute_sample = False - return get_confusion_matrix(y_pred=y_pred, y=y, include_background=self.include_background) + return get_confusion_matrix( + y_pred=y_pred, y=y, include_background=self.include_background, ignore_index=self.ignore_index + ) def aggregate( self, compute_sample: bool = False, reduction: MetricReduction | str | None = None @@ -131,7 +135,9 @@ def aggregate( return results -def get_confusion_matrix(y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True) -> torch.Tensor: +def get_confusion_matrix( + y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True, ignore_index: int | None = None +) -> torch.Tensor: """ Compute confusion matrix. A tensor with the shape [BC4] will be returned. Where, the third dimension represents the number of true positive, false positive, true negative and false negative values for @@ -145,6 +151,9 @@ def get_confusion_matrix(y_pred: torch.Tensor, y: torch.Tensor, include_backgrou The values should be binarized. include_background: whether to include metric computation on the first channel of the predicted output. Defaults to True. + ignore_index: index of the class to ignore during calculation. + If ignore_index < number of classes, that class channel is excluded + else ignored regions are inferred from spatial locations where all label channels are zero. Raises: ValueError: when `y_pred` and `y` have different shapes. @@ -158,17 +167,35 @@ def get_confusion_matrix(y_pred: torch.Tensor, y: torch.Tensor, include_backgrou # get confusion matrix related metric batch_size, n_class = y_pred.shape[:2] + + # Create spatial mask if ignore_index is provided + mask = create_ignore_mask(y, ignore_index) + # convert to [BNS], where S is the number of pixels for one sample. - # As for classification tasks, S equals to 1. y_pred = y_pred.reshape(batch_size, n_class, -1) y = y.reshape(batch_size, n_class, -1) + + if mask is not None: + mask = mask.reshape(batch_size, 1, -1) + y_pred = y_pred * mask + y = y * mask + tp = (y_pred + y) == 2 tn = (y_pred + y) == 0 + if mask is not None: + # When masking, TN must only count locations where the mask is 1 + tn = tn * mask.bool() + tp = tp.sum(dim=[2]).float() tn = tn.sum(dim=[2]).float() p = y.sum(dim=[2]).float() - n = y.shape[-1] - p + + if mask is not None: + # n is total valid pixels (per sample) minus the positives for that class + n = mask.reshape(batch_size, -1).sum(dim=1, keepdim=True) - p + else: + n = y.shape[-1] - p fn = p - tp fp = n - tn diff --git a/monai/metrics/generalized_dice.py b/monai/metrics/generalized_dice.py index 05eb94af48..abdb63e67e 100644 --- a/monai/metrics/generalized_dice.py +++ b/monai/metrics/generalized_dice.py @@ -13,7 +13,7 @@ import torch -from monai.metrics.utils import do_metric_reduction, ignore_background +from monai.metrics.utils import create_ignore_mask, do_metric_reduction from monai.utils import MetricReduction, Weight, deprecated_arg, look_up_option from .metric import CumulativeIterationMetric @@ -41,6 +41,7 @@ class GeneralizedDiceScore(CumulativeIterationMetric): Old versions computed `mean` when `mean_batch` was provided due to bug in reduction. weight_type: {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform ground truth volume into a weight factor. Defaults to ``"square"``. + ignore_index: class index to ignore from the metric computation. Raises: ValueError: When the `reduction` is not one of MetricReduction enum. @@ -51,11 +52,13 @@ def __init__( include_background: bool = True, reduction: MetricReduction | str = MetricReduction.MEAN, weight_type: Weight | str = Weight.SQUARE, + ignore_index: int | None = None, ) -> None: super().__init__() self.include_background = include_background self.reduction = look_up_option(reduction, MetricReduction) self.weight_type = look_up_option(weight_type, Weight) + self.ignore_index = ignore_index self.sum_over_classes = self.reduction in { MetricReduction.SUM, MetricReduction.MEAN, @@ -71,6 +74,8 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor y_pred (torch.Tensor): Binarized segmentation model output. It must be in one-hot format and in the NCHW[D] format, where N is the batch dimension, C is the channel dimension, and the remaining are the spatial dimensions. y (torch.Tensor): Binarized ground-truth. It must be in one-hot format and have the same shape as `y_pred`. + Note: + The ignore_index for this computation is taken from self.ignore_index if set during initialization. Returns: torch.Tensor: Generalized Dice Score averaged across batch and class @@ -84,6 +89,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor include_background=self.include_background, weight_type=self.weight_type, sum_over_classes=self.sum_over_classes, + ignore_index=self.ignore_index, ) @deprecated_arg( @@ -118,6 +124,7 @@ def compute_generalized_dice( include_background: bool = True, weight_type: Weight | str = Weight.SQUARE, sum_over_classes: bool = False, + ignore_index: int | None = None, ) -> torch.Tensor: """ Computes the Generalized Dice Score and returns a tensor with its per image values. @@ -131,7 +138,8 @@ def compute_generalized_dice( predicted output. Defaults to True. weight_type (Union[Weight, str], optional): {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform ground truth volume into a weight factor. Defaults to ``"square"``. - sum_over_labels (bool): Whether to sum the numerator and denominator across all labels before the final computation. + sum_over_classes (bool): Whether to sum the numerator and denominator across all classes before the final computation. + ignore_index: class index to ignore from the metric computation. Returns: torch.Tensor: Per batch and per class Generalized Dice Score, i.e., with the shape [batch_size, num_classes]. @@ -147,52 +155,77 @@ def compute_generalized_dice( if y.shape != y_pred.shape: raise ValueError(f"y_pred - {y_pred.shape} - and y - {y.shape} - should have the same shapes.") - # Ignore background, if needed + # Apply ignore_index masking + mask = create_ignore_mask(y, ignore_index) + if mask is not None: + y_pred = y_pred * mask + y = y * mask + + n_channels = y_pred.shape[1] + channels_to_use = list(range(n_channels)) + if not include_background: - y_pred, y = ignore_background(y_pred=y_pred, y=y) + channels_to_use.pop(0) + + if ignore_index is not None and 0 <= ignore_index < n_channels: + # If background was 0 and we ignore class 2, we need the correct absolute index + if ignore_index in channels_to_use: + channels_to_use.remove(ignore_index) + + if not channels_to_use: + return torch.zeros(y_pred.shape[0], 1, device=y_pred.device) # Reducing only spatial dimensions (not batch nor channels), compute the intersection and non-weighted denominator reduce_axis = list(range(2, y_pred.dim())) - intersection = torch.sum(y * y_pred, dim=reduce_axis) - y_o = torch.sum(y, dim=reduce_axis) - y_pred_o = torch.sum(y_pred, dim=reduce_axis) + intersection = torch.sum(y[:, channels_to_use, ...] * y_pred[:, channels_to_use, ...], dim=reduce_axis) + y_o = torch.sum(y[:, channels_to_use, ...], dim=reduce_axis) + y_pred_o = torch.sum(y_pred[:, channels_to_use, ...], dim=reduce_axis) + denominator = y_o + y_pred_o # Set the class weights + # Set the class weights (computed from scored channels only) weight_type = look_up_option(weight_type, Weight) + y_o_float = y_o.float() + if weight_type == Weight.SIMPLE: - w = torch.reciprocal(y_o.float()) + w = torch.reciprocal(y_o_float) elif weight_type == Weight.SQUARE: - w = torch.reciprocal(y_o.float() * y_o.float()) + w = torch.reciprocal(y_o_float * y_o_float) else: - w = torch.ones_like(y_o.float()) + w = torch.ones_like(y_o_float) # Replace infinite values for non-appearing classes by the maximum weight - for b in w: - infs = torch.isinf(b) - b[infs] = 0 - b[infs] = torch.max(b) + for b_idx in range(w.shape[0]): + batch_w = w[b_idx] + infs = torch.isinf(batch_w) + if infs.any(): + batch_w[infs] = 0 + max_w = torch.max(batch_w) + batch_w[infs] = max_w if max_w > 0 else 1.0 - # Compute the weighted numerator and denominator, summing along the class axis when sum_over_classes is True if sum_over_classes: - numer = 2.0 * (intersection * w).sum(dim=1, keepdim=True) - denom = (denominator * w).sum(dim=1, keepdim=True) - y_pred_o = y_pred_o.sum(dim=-1, keepdim=True) + intersection = (intersection * w).sum(dim=1, keepdim=True) + denominator = (denominator * w).sum(dim=1, keepdim=True) + numer = 2.0 * intersection + denom = denominator else: numer = 2.0 * (intersection * w) denom = denominator * w - y_pred_o = y_pred_o # Compute the score - generalized_dice_score = numer / denom + generalized_dice_score = numer / (denom + 1e-6) - # Handle zero division. Where denom == 0 and the prediction volume is 0, score is 1. - # Where denom == 0 but the prediction volume is not 0, score is 0 + # Handle zero division denom_zeros = denom == 0 - generalized_dice_score[denom_zeros] = torch.where( - (y_pred_o == 0)[denom_zeros], - torch.tensor(1.0, device=generalized_dice_score.device), - torch.tensor(0.0, device=generalized_dice_score.device), - ) + if denom_zeros.any(): + if sum_over_classes: + generalized_dice_score[denom_zeros] = 1.0 + else: + generalized_dice_score[denom_zeros] = torch.where( + (y_pred_o * w)[denom_zeros] == 0, + torch.ones_like(generalized_dice_score[denom_zeros]), + torch.zeros_like(generalized_dice_score[denom_zeros]), + ) return generalized_dice_score diff --git a/monai/metrics/hausdorff_distance.py b/monai/metrics/hausdorff_distance.py index 1b83c93e5b..de045984d5 100644 --- a/monai/metrics/hausdorff_distance.py +++ b/monai/metrics/hausdorff_distance.py @@ -17,7 +17,13 @@ import numpy as np import torch -from monai.metrics.utils import do_metric_reduction, get_edge_surface_distance, ignore_background, prepare_spacing +from monai.metrics.utils import ( + create_ignore_mask, + do_metric_reduction, + get_edge_surface_distance, + ignore_background, + prepare_spacing, +) from monai.utils import MetricReduction, convert_data_type from .metric import CumulativeIterationMetric @@ -62,6 +68,7 @@ def __init__( directed: bool = False, reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False, + ignore_index: int | None = None, ) -> None: super().__init__() self.include_background = include_background @@ -70,6 +77,7 @@ def __init__( self.directed = directed self.reduction = reduction self.get_not_nans = get_not_nans + self.ignore_index = ignore_index def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) -> torch.Tensor: # type: ignore[override] """ @@ -97,6 +105,11 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) if dims < 3: raise ValueError("y_pred should have at least three dimensions.") + mask = create_ignore_mask(y, self.ignore_index) + if mask is not None: + y_pred = y_pred * mask + y = y * mask + # compute (BxC) for each channel for each batch return compute_hausdorff_distance( y_pred=y_pred, @@ -179,17 +192,20 @@ def compute_hausdorff_distance( spacing_list = prepare_spacing(spacing=spacing, batch_size=batch_size, img_dim=img_dim) for b, c in np.ndindex(batch_size, n_class): + yp = y_pred[b, c] + yt = y[b, c] + _, distances, _ = get_edge_surface_distance( - y_pred[b, c], - y[b, c], - distance_metric=distance_metric, - spacing=spacing_list[b], - symmetric=not directed, - class_index=c, + yp, yt, distance_metric=distance_metric, spacing=spacing_list[b], symmetric=not directed ) + + if len(distances) == 0: + hd[b, c] = torch.tensor(0.0, device=y_pred.device) + continue + percentile_distances = [_compute_percentile_hausdorff_distance(d, percentile) for d in distances] - max_distance = torch.max(torch.stack(percentile_distances)) - hd[b, c] = max_distance + + hd[b, c] = torch.max(torch.stack(percentile_distances)) return hd diff --git a/monai/metrics/meandice.py b/monai/metrics/meandice.py index fedd94fb93..403b78d9c6 100644 --- a/monai/metrics/meandice.py +++ b/monai/metrics/meandice.py @@ -13,7 +13,7 @@ import torch -from monai.metrics.utils import do_metric_reduction +from monai.metrics.utils import create_ignore_mask, do_metric_reduction from monai.utils import MetricReduction, deprecated_arg from .metric import CumulativeIterationMetric @@ -106,6 +106,7 @@ def __init__( ignore_empty: bool = True, num_classes: int | None = None, return_with_label: bool | list[str] = False, + ignore_index: int | None = None, ) -> None: super().__init__() self.include_background = include_background @@ -114,6 +115,7 @@ def __init__( self.ignore_empty = ignore_empty self.num_classes = num_classes self.return_with_label = return_with_label + self.ignore_index = ignore_index self.dice_helper = DiceHelper( include_background=self.include_background, reduction=MetricReduction.NONE, @@ -121,6 +123,7 @@ def __init__( apply_argmax=False, ignore_empty=self.ignore_empty, num_classes=self.num_classes, + ignore_index=self.ignore_index, ) def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override] @@ -175,6 +178,7 @@ def compute_dice( include_background: bool = True, ignore_empty: bool = True, num_classes: int | None = None, + ignore_index: int | None = None, ) -> torch.Tensor: """ Computes Dice score metric for a batch of predictions. This performs the same computation as @@ -192,6 +196,7 @@ def compute_dice( num_classes: number of input channels (always including the background). When this is ``None``, ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are single-channel class indices and the number of classes is not automatically inferred from data. + ignore_index: index of the class to ignore during calculation. Returns: Dice scores per batch and per class, (shape: [batch_size, num_classes]). @@ -204,6 +209,7 @@ def compute_dice( apply_argmax=False, ignore_empty=ignore_empty, num_classes=num_classes, + ignore_index=ignore_index, )(y_pred=y_pred, y=y) @@ -262,6 +268,7 @@ def __init__( num_classes: int | None = None, sigmoid: bool | None = None, softmax: bool | None = None, + ignore_index: int | None = None, ) -> None: # handling deprecated arguments if sigmoid is not None: @@ -277,8 +284,9 @@ def __init__( self.activate = activate self.ignore_empty = ignore_empty self.num_classes = num_classes + self.ignore_index = ignore_index - def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor: """ Compute the dice metric for binary inputs which have only spatial dimensions. This method is called separately for each batch item and for each channel of those items. @@ -286,7 +294,12 @@ def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor Args: y_pred: input predictions with shape HW[D]. y: ground truth with shape HW[D]. + mask: binary mask where 0 indicates voxels to ignore. """ + if mask is not None: + y_pred = y_pred & mask.bool() + y = y * mask + y_o = torch.sum(y) if y_o > 0: return (2.0 * torch.sum(torch.masked_select(y, y_pred))) / (y_o + torch.sum(y_pred)) @@ -322,6 +335,9 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl y_pred = torch.sigmoid(y_pred) y_pred = y_pred > 0.5 + # Create global mask for ignored voxels if ignore_index is set + mask = create_ignore_mask(y, self.ignore_index) + first_ch = 0 if self.include_background else 1 data = [] for b in range(y_pred.shape[0]): @@ -329,7 +345,11 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl for c in range(first_ch, n_pred_ch) if n_pred_ch > 1 else [1]: x_pred = (y_pred[b, 0] == c) if (y_pred.shape[1] == 1) else y_pred[b, c].bool() x = (y[b, 0] == c) if (y.shape[1] == 1) else y[b, c] - c_list.append(self.compute_channel(x_pred, x)) + + # Extract the spatial mask for the current batch item + b_mask = mask[b, 0] if mask is not None else None + + c_list.append(self.compute_channel(x_pred, x, mask=b_mask)) data.append(torch.stack(c_list)) data = torch.stack(data, dim=0).contiguous() # type: ignore diff --git a/monai/metrics/meaniou.py b/monai/metrics/meaniou.py index 65c53f7aa5..ed4bef45bc 100644 --- a/monai/metrics/meaniou.py +++ b/monai/metrics/meaniou.py @@ -13,7 +13,7 @@ import torch -from monai.metrics.utils import do_metric_reduction, ignore_background +from monai.metrics.utils import create_ignore_mask, do_metric_reduction, ignore_background from monai.utils import MetricReduction from .metric import CumulativeIterationMetric @@ -54,12 +54,14 @@ def __init__( reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False, ignore_empty: bool = True, + ignore_index: int | None = None, ) -> None: super().__init__() self.include_background = include_background self.reduction = reduction self.get_not_nans = get_not_nans self.ignore_empty = ignore_empty + self.ignore_index = ignore_index def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override] """ @@ -78,7 +80,11 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor raise ValueError(f"y_pred should have at least 3 dimensions (batch, channel, spatial), got {dims}.") # compute IoU (BxC) for each channel for each batch return compute_iou( - y_pred=y_pred, y=y, include_background=self.include_background, ignore_empty=self.ignore_empty + y_pred=y_pred, + y=y, + include_background=self.include_background, + ignore_empty=self.ignore_empty, + ignore_index=self.ignore_index, ) def aggregate( @@ -103,7 +109,11 @@ def aggregate( def compute_iou( - y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True, ignore_empty: bool = True + y_pred: torch.Tensor, + y: torch.Tensor, + include_background: bool = True, + ignore_empty: bool = True, + ignore_index: int | None = None, ) -> torch.Tensor: """Computes Intersection over Union (IoU) score metric from a batch of predictions. @@ -133,6 +143,13 @@ def compute_iou( if y.shape != y_pred.shape: raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.") + mask = create_ignore_mask(y, ignore_index) + if mask is not None: + if mask.shape != y_pred.shape: + mask = mask.expand_as(y_pred) + y_pred = y_pred * mask + y = y * mask + # reducing only spatial dimensions (not batch nor channels) n_len = len(y_pred.shape) reduce_axis = list(range(2, n_len)) diff --git a/monai/metrics/surface_dice.py b/monai/metrics/surface_dice.py index b20b47a1a5..d5ef2a5d87 100644 --- a/monai/metrics/surface_dice.py +++ b/monai/metrics/surface_dice.py @@ -17,7 +17,13 @@ import numpy as np import torch -from monai.metrics.utils import do_metric_reduction, get_edge_surface_distance, ignore_background, prepare_spacing +from monai.metrics.utils import ( + create_ignore_mask, + do_metric_reduction, + get_edge_surface_distance, + ignore_background, + prepare_spacing, +) from monai.utils import MetricReduction from .metric import CumulativeIterationMetric @@ -57,6 +63,7 @@ class SurfaceDiceMetric(CumulativeIterationMetric): If set to ``True``, the function `aggregate` will return both the aggregated NSD and the `not_nans` count. If set to ``False``, `aggregate` will only return the aggregated NSD. use_subvoxels: Whether to use subvoxel distances. Defaults to ``False``. + ignore_index: class index to ignore from the metric computation. """ def __init__( @@ -67,6 +74,7 @@ def __init__( reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False, use_subvoxels: bool = False, + ignore_index: int | None = None, ) -> None: super().__init__() self.class_thresholds = class_thresholds @@ -75,6 +83,7 @@ def __init__( self.reduction = reduction self.get_not_nans = get_not_nans self.use_subvoxels = use_subvoxels + self.ignore_index = ignore_index def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) -> torch.Tensor: # type: ignore[override] r""" @@ -94,6 +103,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) else the inner sequence length must be equal to the image dimensions. If ``None``, spacing of unity is used for all images in batch. Defaults to ``None``. use_subvoxels: Whether to use subvoxel distances. Defaults to ``False``. + ignore_index: class index to ignore from the metric computation. Returns: @@ -108,6 +118,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) distance_metric=self.distance_metric, spacing=kwargs.get("spacing"), use_subvoxels=self.use_subvoxels, + ignore_index=self.ignore_index, ) def aggregate( @@ -142,6 +153,7 @@ def compute_surface_dice( distance_metric: str = "euclidean", spacing: int | float | np.ndarray | Sequence[int | float | np.ndarray | Sequence[int | float]] | None = None, use_subvoxels: bool = False, + ignore_index: int | None = None, ) -> torch.Tensor: r""" This function computes the (Normalized) Surface Dice (NSD) between the two tensors `y_pred` (referred to as @@ -199,6 +211,7 @@ def compute_surface_dice( else the inner sequence length must be equal to the image dimensions. If ``None``, spacing of unity is used for all images in batch. Defaults to ``None``. use_subvoxels: Whether to use subvoxel distances. Defaults to ``False``. + ignore_index: class index to ignore from the metric computation. Raises: ValueError: If `y_pred` and/or `y` are not PyTorch tensors. @@ -213,6 +226,11 @@ def compute_surface_dice( Pytorch Tensor of shape [B,C], containing the NSD values :math:`\operatorname {NSD}_{b,c}` for each batch index :math:`b` and class :math:`c`. """ + # Apply ignore_index masking using centralized helper + mask = create_ignore_mask(y, ignore_index) + if mask is not None: + y_pred = y_pred * mask + y = y * mask if not include_background: y_pred, y = ignore_background(y_pred=y_pred, y=y) @@ -264,12 +282,21 @@ def compute_surface_dice( distances_gt_pred <= class_thresholds[c] ) else: - areas_pred, areas_gt = areas # type: ignore + # Handle areas being returned as a single item or a tuple + if isinstance(areas, (list, tuple)): + if len(areas) == 2: + areas_pred, areas_gt = areas + elif len(areas) == 1: + areas_pred = areas_gt = areas[0] + else: + areas_pred = areas_gt = torch.tensor([], device=y_pred.device) + else: + areas_pred = areas_gt = areas areas_gt, areas_pred = areas_gt[edges_gt], areas_pred[edges_pred] boundary_complete = areas_gt.sum() + areas_pred.sum() gt_true = areas_gt[distances_gt_pred <= class_thresholds[c]].sum() if len(areas_gt) > 0 else 0.0 pred_true = areas_pred[distances_pred_gt <= class_thresholds[c]].sum() if len(areas_pred) > 0 else 0.0 - boundary_correct = gt_true + pred_true + boundary_correct = gt_true + pred_true # type: ignore[assignment,operator] if boundary_complete == 0: # the class is neither present in the prediction, nor in the reference segmentation nsd[b, c] = torch.tensor(np.nan) diff --git a/monai/metrics/surface_distance.py b/monai/metrics/surface_distance.py index 3cb336d6a0..c82ab53e95 100644 --- a/monai/metrics/surface_distance.py +++ b/monai/metrics/surface_distance.py @@ -17,7 +17,13 @@ import numpy as np import torch -from monai.metrics.utils import do_metric_reduction, get_edge_surface_distance, ignore_background, prepare_spacing +from monai.metrics.utils import ( + create_ignore_mask, + do_metric_reduction, + get_edge_surface_distance, + ignore_background, + prepare_spacing, +) from monai.utils import MetricReduction, convert_data_type from .metric import CumulativeIterationMetric @@ -46,6 +52,7 @@ class SurfaceDistanceMetric(CumulativeIterationMetric): ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans). Here `not_nans` count the number of not nans for the metric, thus its shape equals to the shape of the metric. + ignore_index: class index to ignore from the metric computation. """ @@ -56,6 +63,7 @@ def __init__( distance_metric: str = "euclidean", reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False, + ignore_index: int | None = None, ) -> None: super().__init__() self.include_background = include_background @@ -63,6 +71,7 @@ def __init__( self.symmetric = symmetric self.reduction = reduction self.get_not_nans = get_not_nans + self.ignore_index = ignore_index def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) -> torch.Tensor: # type: ignore[override] """ @@ -89,6 +98,11 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) if y_pred.dim() < 3: raise ValueError("y_pred should have at least three dimensions.") + mask = create_ignore_mask(y, self.ignore_index) + if mask is not None: + y_pred = y_pred * mask + y = y * mask + # compute (BxC) for each channel for each batch return compute_average_surface_distance( y_pred=y_pred, @@ -172,15 +186,16 @@ def compute_average_surface_distance( spacing_list = prepare_spacing(spacing=spacing, batch_size=batch_size, img_dim=img_dim) for b, c in np.ndindex(batch_size, n_class): + yp = y_pred[b, c] + yt = y[b, c] + _, distances, _ = get_edge_surface_distance( - y_pred[b, c], - y[b, c], - distance_metric=distance_metric, - spacing=spacing_list[b], - symmetric=symmetric, - class_index=c, + yp, yt, distance_metric=distance_metric, spacing=spacing_list[b], symmetric=symmetric, class_index=c ) + surface_distance = torch.cat(distances) - asd[b, c] = torch.tensor(np.nan) if surface_distance.shape == (0,) else surface_distance.mean() + asd[b, c] = ( + torch.tensor(float("nan"), device=asd.device) if surface_distance.numel() == 0 else surface_distance.mean() + ) return convert_data_type(asd, output_type=torch.Tensor, device=y_pred.device, dtype=torch.float)[0] diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index a451b1a770..1d72c5ab49 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -41,6 +41,7 @@ __all__ = [ "ignore_background", + "create_ignore_mask", "do_metric_reduction", "get_mask_edges", "get_surface_distance", @@ -57,7 +58,7 @@ def ignore_background(y_pred: NdarrayTensor, y: NdarrayTensor) -> tuple[NdarrayT Args: y_pred: predictions. As for classification tasks, - `y_pred` should has the shape [BN] where N is larger than 1. As for segmentation tasks, + `y_pred` should have the shape [BN] where N is larger than 1. As for segmentation tasks, the shape should be [BNHW] or [BNHWD]. y: ground truth, the first dim is batch. @@ -68,6 +69,44 @@ def ignore_background(y_pred: NdarrayTensor, y: NdarrayTensor) -> tuple[NdarrayT return y_pred, y +def create_ignore_mask(y: torch.Tensor, ignore_index: int | None) -> torch.Tensor | None: + """ + Create a spatial mask for ignore_index functionality. + + Handles three cases: + + 1. ignore_index is None: returns None (no masking) + 2. Label-encoded input (y.shape[1] == 1): direct comparison + 3. One-hot encoded input (y.shape[1] > 1): + - Valid class index (0 <= ignore_index < num_classes): mask that channel + - Sentinel value (ignore_index >= num_classes): mask all-zero pixels + + Args: + y: Target tensor with shape (B, C, H, W, [D]). + C=1 for label-encoded, C>1 for one-hot encoded. + ignore_index: Class index or sentinel value to ignore, or None. + + Returns: + Mask tensor of shape (B, 1, H, W, [D]) where 1=valid, 0=ignore. + Returns None if ignore_index is None. + """ + if ignore_index is None: + return None + + if y.shape[1] == 1: + # Label-encoded: direct comparison + return (y != ignore_index).float() + + # One-hot encoded + num_classes = y.shape[1] + if 0 <= ignore_index < num_classes: + # Valid class index: exclude that channel + return 1.0 - y[:, ignore_index : ignore_index + 1] # type: ignore[no-any-return] + else: + # Sentinel value: exclude where all channels are zero + return (y.sum(dim=1, keepdim=True) > 0).float() + + def do_metric_reduction( f: torch.Tensor, reduction: MetricReduction | str = MetricReduction.MEAN ) -> tuple[torch.Tensor | Any, torch.Tensor]: @@ -169,7 +208,7 @@ def get_mask_edges( images. Defaults to ``True``. spacing: the input spacing. If not None, the subvoxel edges and areas will be computed. otherwise `scipy`'s binary erosion is used to calculate the edges. - always_return_as_numpy: whether to a numpy array regardless of the input type. + always_return_as_numpy: whether to return a numpy array regardless of the input type. If False, return the same type as inputs. The default value is changed from `True` to `False` in v1.5.0. """ @@ -275,14 +314,20 @@ def get_surface_distance( dis = np.inf * lib.ones_like(seg_gt, dtype=lib.float32) dis = dis[seg_gt] return convert_to_dst_type(dis, seg_pred, dtype=dis.dtype)[0] + if distance_metric == "euclidean": dis = monai_distance_transform_edt((~seg_gt)[None, ...], sampling=spacing)[0] # type: ignore elif distance_metric in {"chessboard", "taxicab"}: dis = distance_transform_cdt(convert_to_numpy(~seg_gt), metric=distance_metric) else: raise ValueError(f"distance_metric {distance_metric} is not implemented.") + dis = convert_to_dst_type(dis, seg_pred, dtype=lib.float32)[0] - return dis[seg_pred] # type: ignore + if isinstance(seg_pred, torch.Tensor): + return dis[seg_pred.bool()] # type: ignore[union-attr,no-any-return] + else: + # NumPy array + return dis[seg_pred.astype(bool)] # type: ignore[union-attr,no-any-return] def get_edge_surface_distance( @@ -320,28 +365,40 @@ def get_edge_surface_distance( edges_spacing = None if use_subvoxels: edges_spacing = spacing if spacing is not None else ([1] * len(y_pred.shape)) - (edges_pred, edges_gt, *areas) = get_mask_edges( - y_pred, y, crop=True, spacing=edges_spacing, always_return_as_numpy=False - ) - if not edges_gt.any(): - warnings.warn( - f"the ground truth of class {class_index if class_index != -1 else 'Unknown'} is all 0," - " this may result in nan/inf distance." - ) - if not edges_pred.any(): - warnings.warn( - f"the prediction of class {class_index if class_index != -1 else 'Unknown'} is all 0," - " this may result in nan/inf distance." - ) - distances: tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor] + + edge_results = get_mask_edges(y_pred, y, crop=True, spacing=edges_spacing, always_return_as_numpy=False) + edges_pred, edges_gt = edge_results[0], edge_results[1] + + distances_raw: tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor] if symmetric: - distances = ( + distances_raw = ( get_surface_distance(edges_pred, edges_gt, distance_metric, spacing), get_surface_distance(edges_gt, edges_pred, distance_metric, spacing), ) # type: ignore else: - distances = (get_surface_distance(edges_pred, edges_gt, distance_metric, spacing),) # type: ignore - return convert_to_tensor(((edges_pred, edges_gt), distances, tuple(areas)), device=y_pred.device) # type: ignore[no-any-return] + distances_raw = (get_surface_distance(edges_pred, edges_gt, distance_metric, spacing),) # type: ignore + + distances_list = list(distances_raw) + distances: tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor] = ( + tuple(distances_list) if len(distances_list) == 2 else (distances_list[0],) # type: ignore[assignment] + ) + + areas = edge_results[2:] if use_subvoxels else () + + # Ensure areas is always a tuple of 2 when use_subvoxels=True + if use_subvoxels and isinstance(areas, (list, tuple)): + if len(areas) == 1: + areas = (areas[0], areas[0]) + elif len(areas) != 2: + # Unexpected length, create empty tensors + areas = (torch.tensor([], device=y_pred.device), torch.tensor([], device=y_pred.device)) + + out = convert_to_tensor(((edges_pred, edges_gt), distances, tuple(areas)), device=y_pred.device) # type: ignore[no-any-return] + + if out is None: + out = torch.empty((0,), device=y_pred.device) + + return out # type: ignore[return-value,no-any-return] def is_binary_tensor(input: torch.Tensor, name: str) -> None: diff --git a/tests/losses/test_ignore_index_losses.py b/tests/losses/test_ignore_index_losses.py new file mode 100644 index 0000000000..b07ba5c98d --- /dev/null +++ b/tests/losses/test_ignore_index_losses.py @@ -0,0 +1,68 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.losses import AsymmetricUnifiedFocalLoss, DiceLoss, FocalLoss, TverskyLoss + +# Defining test cases: (LossClass, args) +TEST_CASES = [ + (DiceLoss, {"sigmoid": True}), + (FocalLoss, {"use_softmax": False}), + (TverskyLoss, {"sigmoid": True}), + (AsymmetricUnifiedFocalLoss, {}), +] + + +class TestIgnoreIndexLosses(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_loss_ignore_consistency(self, loss_class, kwargs): + ignore_index = 255 + loss_func = loss_class(ignore_index=ignore_index, **kwargs) + + # Create two inputs that are identical EXCEPT in the area designated as 'ignored' + # Input shape: [Batch, Channel, H, W] + input_base = torch.randn(1, 1, 4, 4) + input_alt = input_base.clone() + input_alt[0, 0, 2:, :] += 5.0 # Significant difference in the bottom half + + # Target: Top half is valid (0,1), Bottom half is ignored (255) + target = torch.tensor( + [[[[1, 0, 1, 0], [0, 1, 0, 1], [255, 255, 255, 255], [255, 255, 255, 255]]]], dtype=torch.float + ) + + # Execute + loss_base = loss_func(input_base, target) + loss_alt = loss_func(input_alt, target) + + # ASSERTION: The losses must be identical because the difference + # occurred only in the ignored region. + torch.testing.assert_close(loss_base, loss_alt, atol=1e-5, rtol=1e-5) + + @parameterized.expand(TEST_CASES) + def test_no_ignore_behavior(self, loss_class, kwargs): + # Ensure that when ignore_index is None, the loss functions normally + loss_func = loss_class(ignore_index=None, **kwargs) + input_data = torch.randn(1, 1, 4, 4) + target = torch.randint(0, 2, (1, 1, 4, 4)).float() + + output = loss_func(input_data, target) + self.assertFalse(torch.isnan(output)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/metrics/test_ignore_index_metrics.py b/tests/metrics/test_ignore_index_metrics.py new file mode 100644 index 0000000000..0efcda561d --- /dev/null +++ b/tests/metrics/test_ignore_index_metrics.py @@ -0,0 +1,91 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.metrics import ( + ConfusionMatrixMetric, + DiceMetric, + GeneralizedDiceScore, + HausdorffDistanceMetric, + MeanIoU, + SurfaceDiceMetric, + SurfaceDistanceMetric, +) +from monai.utils import optional_import + +scipy, has_scipy = optional_import("scipy") + +# Test cases for metrics with their specific required arguments +TEST_METRICS = [ + (DiceMetric, {"include_background": True, "reduction": "mean"}), + (MeanIoU, {"include_background": True, "reduction": "mean"}), + (GeneralizedDiceScore, {"include_background": True}), + (ConfusionMatrixMetric, {"metric_name": "accuracy"}), +] + +# Metrics that require SciPy (Hausdorff and Surface metrics) +SCIPY_METRICS = [ + (HausdorffDistanceMetric, {"include_background": True}), + (SurfaceDistanceMetric, {"include_background": True}), + (SurfaceDiceMetric, {"class_thresholds": [0.5, 0.5], "include_background": True}), +] + + +@unittest.skipUnless(has_scipy, "Scipy required for surface metrics") +class TestIgnoreIndexMetrics(unittest.TestCase): + @parameterized.expand(TEST_METRICS + SCIPY_METRICS) + def test_metric_ignore_consistency(self, metric_class, kwargs): + # Initialize metric with ignore_index + metric = metric_class(ignore_index=255, **kwargs) + + # Batch size 1, 2 Classes, 4x4 Image + # y_pred1 and y_pred2 differ ONLY in the bottom half (the ignore zone) + y_pred1 = torch.zeros((1, 2, 4, 4)) + y_pred1[:, 1, 0:2, :] = 1.0 # Top half prediction + + y_pred2 = y_pred1.clone() + y_pred2[:, 1, 2:4, :] = 1.0 # Bottom half prediction (different!) + + # Target: Top half is valid (0/1), Bottom half should be ignored + # For ignore_index=255 (sentinel), we need to mark ignored pixels differently + # Option 1: Use ignore_index as a class index (e.g., ignore_index=1) + # Option 2: Keep one-hot but set ignored region to all zeros + y = torch.zeros((1, 2, 4, 4)) + y[:, 1, 0:2, 0:2] = 1.0 # Top-left is class 1 + y[:, 0, 0:2, 2:4] = 1.0 # Top-right is class 0 + # Bottom half: leave as all zeros to indicate "no valid class" + + # Run metric for both predictions + metric.reset() + metric(y_pred=y_pred1, y=y) + res1 = metric.aggregate() + if isinstance(res1, list): + res1 = res1[0] + + metric.reset() + metric(y_pred=y_pred2, y=y) + res2 = metric.aggregate() + if isinstance(res2, list): + res2 = res2[0] + + # The result must be identical because the spatial difference + # is hidden by the ignore_index + torch.testing.assert_close(res1, res2, msg=f"Failed for {metric_class.__name__}") + + +if __name__ == "__main__": + unittest.main()