From cc21cefda22640ba234530b87616cc6eaa29962f Mon Sep 17 00:00:00 2001 From: Kumar Abhishek <7644965+kakumarabhishek@users.noreply.github.com> Date: Thu, 19 Mar 2026 03:14:23 -0700 Subject: [PATCH] Add MCC loss Add implementation of Matthews Correlation Coefficient (MCC)-based loss Add tests for MCC loss Add entry for MCC loss in documentation Signed-off-by: Kumar Abhishek <7644965+kakumarabhishek@users.noreply.github.com> --- docs/source/losses.rst | 5 + monai/losses/__init__.py | 1 + monai/losses/mcc_loss.py | 188 ++++++++++++++++++++++ tests/losses/test_mcc_loss.py | 154 ++++++++++++++++++++++ 4 files changed, 348 insertions(+) create mode 100644 monai/losses/mcc_loss.py create mode 100644 tests/losses/test_mcc_loss.py diff --git a/docs/source/losses.rst b/docs/source/losses.rst index 528ccd1173..baeebbbe9c 100644 --- a/docs/source/losses.rst +++ b/docs/source/losses.rst @@ -98,6 +98,11 @@ Segmentation Losses .. autoclass:: NACLLoss :members: +`MCCLoss` +~~~~~~~~~ +.. autoclass:: MCCLoss + :members: + Registration Losses ------------------- diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py index 41935be204..f3e40b15d5 100644 --- a/monai/losses/__init__.py +++ b/monai/losses/__init__.py @@ -35,6 +35,7 @@ from .focal_loss import FocalLoss from .giou_loss import BoxGIoULoss, giou from .hausdorff_loss import HausdorffDTLoss, LogHausdorffDTLoss from .image_dissimilarity import GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss +from .mcc_loss import MCCLoss from .multi_scale import MultiScaleLoss from .nacl_loss import NACLLoss diff --git a/monai/losses/mcc_loss.py b/monai/losses/mcc_loss.py new file mode 100644 index 0000000000..ac2877e5f7 --- /dev/null +++ b/monai/losses/mcc_loss.py @@ -0,0 +1,188 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import warnings +from collections.abc import Callable + +import torch +from torch.nn.modules.loss import _Loss + +from monai.networks import one_hot +from monai.utils import LossReduction + + +class MCCLoss(_Loss): + """ + Compute the Matthews Correlation Coefficient (MCC) loss between two tensors. + + Unlike Dice and Tversky losses which only use TP, FP, and FN, the MCC loss considers all four entries + of the confusion matrix (TP, TN, FP, FN), making it effective for class-imbalanced segmentation tasks + where background dominates the image. The loss is computed as ``1 - MCC`` where + ``MCC = (TP * TN - FP * FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN))``. + + The soft confusion matrix entries are computed as: + + - ``TP = sum(input * target)`` + - ``TN = sum((1 - input) * (1 - target))`` + - ``FP = sum(input * (1 - target))`` + - ``FN = sum((1 - input) * target)`` + + The data `input` (BNHW[D] where N is number of classes) is compared with ground truth `target` (BNHW[D]). + + Note that axis N of `input` is expected to be logits or probabilities for each class, if passing logits as input, + must set `sigmoid=True` or `softmax=True`, or specifying `other_act`. And the same axis of `target` + can be 1 or N (one-hot format). + + The original paper: + + Abhishek, K. and Hamarneh, G. (2021) Matthews Correlation Coefficient Loss for Deep Convolutional + Networks: Application to Skin Lesion Segmentation. IEEE ISBI, pp. 225-229. 
+ (https://doi.org/10.1109/ISBI48211.2021.9433782) + + """ + + def __init__( + self, + include_background: bool = True, + to_onehot_y: bool = False, + sigmoid: bool = False, + softmax: bool = False, + other_act: Callable | None = None, + reduction: LossReduction | str = LossReduction.MEAN, + smooth_nr: float = 0.0, + smooth_dr: float = 1e-5, + batch: bool = False, + ) -> None: + """ + Args: + include_background: if False, channel index 0 (background category) is excluded from the calculation. + if the non-background segmentations are small compared to the total image size they can get + overwhelmed by the signal from the background so excluding it in such cases helps convergence. + to_onehot_y: whether to convert the ``target`` into the one-hot format, + using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False. + sigmoid: if True, apply a sigmoid function to the prediction. + softmax: if True, apply a softmax function to the prediction. + other_act: callable function to execute other activation layers, Defaults to ``None``. for example: + ``other_act = torch.tanh``. + reduction: {``"none"``, ``"mean"``, ``"sum"``} + Specifies the reduction to apply to the output. Defaults to ``"mean"``. + + - ``"none"``: no reduction will be applied. + - ``"mean"``: the sum of the output will be divided by the number of elements in the output. + - ``"sum"``: the output will be summed. + + smooth_nr: a small constant added to the numerator to avoid zero. + smooth_dr: a small constant added to the denominator to avoid nan. + batch: whether to sum the confusion matrix entries over the batch dimension before computing MCC. + Defaults to False, MCC is computed independently for each item in the batch + before any `reduction`. + + Raises: + TypeError: When ``other_act`` is not an ``Optional[Callable]``. + ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``]. + Incompatible values. 
+ + """ + super().__init__(reduction=LossReduction(reduction).value) + if other_act is not None and not callable(other_act): + raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.") + if int(sigmoid) + int(softmax) + int(other_act is not None) > 1: + raise ValueError("Incompatible values: more than 1 of [sigmoid=True, softmax=True, other_act is not None].") + self.include_background = include_background + self.to_onehot_y = to_onehot_y + self.sigmoid = sigmoid + self.softmax = softmax + self.other_act = other_act + self.smooth_nr = float(smooth_nr) + self.smooth_dr = float(smooth_dr) + self.batch = batch + + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Args: + input: the shape should be BNH[WD], where N is the number of classes. + target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes. + + Raises: + AssertionError: When input and target (after one hot transform if set) + have different shapes. + ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. 
+ + Example: + >>> from monai.losses.mcc_loss import MCCLoss + >>> import torch + >>> B, C, H, W = 7, 1, 3, 2 + >>> input = torch.rand(B, C, H, W) + >>> target = torch.randint(low=0, high=2, size=(B, C, H, W)).float() + >>> self = MCCLoss(reduction='none') + >>> loss = self(input, target) + """ + if self.sigmoid: + input = torch.sigmoid(input) + + n_pred_ch = input.shape[1] + if self.softmax: + if n_pred_ch == 1: + warnings.warn("single channel prediction, `softmax=True` ignored.") + else: + input = torch.softmax(input, 1) + + if self.other_act is not None: + input = self.other_act(input) + + if self.to_onehot_y: + if n_pred_ch == 1: + warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") + else: + target = one_hot(target, num_classes=n_pred_ch) + + if not self.include_background: + if n_pred_ch == 1: + warnings.warn("single channel prediction, `include_background=False` ignored.") + else: + target = target[:, 1:] + input = input[:, 1:] + + if target.shape != input.shape: + raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})") + + # reducing only spatial dimensions (not batch nor channels) + reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist() + if self.batch: + reduce_axis = [0] + reduce_axis + + # Soft confusion matrix entries (Eq. 5 in the paper). + tp = torch.sum(input * target, dim=reduce_axis) + tn = torch.sum((1.0 - input) * (1.0 - target), dim=reduce_axis) + fp = torch.sum(input * (1.0 - target), dim=reduce_axis) + fn = torch.sum((1.0 - input) * target, dim=reduce_axis) + + # MCC (Eq. 3) and loss (Eq. 4). + numerator = tp * tn - fp * fn + self.smooth_nr + denominator = torch.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn) + self.smooth_dr) + + mcc = numerator / denominator + score: torch.Tensor = 1.0 - mcc + + # When fp = fn = 0, prediction is perfect but the denominator product + # tends to 0 when tp = 0 or tn = 0, giving mcc ~ 0 instead of 1. 
+ perfect = (fp == 0) & (fn == 0) + score = torch.where(perfect, torch.zeros_like(score), score) + + if self.reduction == LossReduction.SUM.value: + return torch.sum(score) + if self.reduction == LossReduction.NONE.value: + return score + if self.reduction == LossReduction.MEAN.value: + return torch.mean(score) + raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') diff --git a/tests/losses/test_mcc_loss.py b/tests/losses/test_mcc_loss.py new file mode 100644 index 0000000000..2451cf76d4 --- /dev/null +++ b/tests/losses/test_mcc_loss.py @@ -0,0 +1,154 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from __future__ import annotations

import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.losses import MCCLoss
from tests.test_utils import test_script_save

# Each case: [constructor kwargs, forward inputs, expected loss value].
# Expected values were computed from the soft confusion-matrix definition of MCC;
# perfect predictions are expected to yield exactly 0.0 (see the zero-denominator
# guard in MCCLoss.forward), and a fully inverted prediction yields 2.0 (mcc = -1).
TEST_CASES = [
    [  # shape: (1, 1, 2, 2), (1, 1, 2, 2), sigmoid
        {"include_background": True, "sigmoid": True},
        {"input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])},
        0.733197,
    ],
    [  # shape: (2, 1, 2, 2), (2, 1, 2, 2), sigmoid
        {"include_background": True, "sigmoid": True},
        {
            "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]]),
            "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]),
        },
        1.0,
    ],
    [  # shape: (1, 1, 2, 2), (1, 1, 2, 2), perfect prediction
        {"include_background": True},
        {"input": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]])},
        0.0,
    ],
    [  # shape: (1, 1, 2, 2), (1, 1, 2, 2), worst case (inverted)
        {"include_background": True},
        {"input": torch.tensor([[[[0.0, 1.0], [1.0, 0.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]])},
        2.0,
    ],
    [  # shape: (2, 2, 3), (2, 1, 3), multi-class, exclude background, one-hot
        {"include_background": False, "to_onehot_y": True},
        {
            "input": torch.tensor([[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]),
            "target": torch.tensor([[[0.0, 0.0, 1.0]], [[0.0, 1.0, 0.0]]]),
        },
        0.0,
    ],
    [  # shape: (2, 2, 3), (2, 1, 3), multi-class, sigmoid, one-hot
        {"include_background": True, "to_onehot_y": True, "sigmoid": True},
        {
            "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
            "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
        },
        0.836617,
    ],
    [  # shape: (2, 2, 3), (2, 1, 3), multi-class, sigmoid, one-hot, batch=True
        {"include_background": True, "to_onehot_y": True, "sigmoid": True, "batch": True},
        {
            "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
            "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
        },
        0.845961,
    ],
    [  # shape: (2, 2, 3), (2, 1, 3), multi-class, sigmoid, one-hot, reduction=sum
        {"include_background": True, "to_onehot_y": True, "sigmoid": True, "reduction": "sum"},
        {
            "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
            "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
        },
        3.346468,
    ],
    [  # shape: (2, 2, 3), (2, 1, 3), multi-class, softmax, one-hot
        {"include_background": True, "to_onehot_y": True, "softmax": True},
        {
            "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
            "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
        },
        0.730736,
    ],
    [  # shape: (2, 2, 3), (2, 1, 3), multi-class, softmax, one-hot, reduction=none
        {"include_background": True, "to_onehot_y": True, "softmax": True, "reduction": "none"},
        {
            "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
            "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
        },
        [[0.461472, 0.461472], [1.0, 1.0]],
    ],
    [  # shape: (1, 1, 3, 3), (1, 1, 3, 3), all-ones perfect prediction
        {"include_background": True},
        {"input": torch.ones(1, 1, 3, 3), "target": torch.ones(1, 1, 3, 3)},
        0.0,
    ],
    [  # shape: (1, 1, 3, 3), (1, 1, 3, 3), all-zeros perfect prediction
        {"include_background": True},
        {"input": torch.zeros(1, 1, 3, 3), "target": torch.zeros(1, 1, 3, 3)},
        0.0,
    ],
    [  # shape: (2, 1, 2, 2), (2, 1, 2, 2), other_act=torch.tanh
        {"include_background": True, "other_act": torch.tanh},
        {
            "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]]),
            "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]),
        },
        1.0,
    ],
]


class TestMCCLoss(unittest.TestCase):
    """Unit tests for :py:class:`monai.losses.MCCLoss`."""

    @parameterized.expand(TEST_CASES)
    def test_shape(self, input_param, input_data, expected_val):
        """Check loss values against precomputed expectations for each parameterized case."""
        result = MCCLoss(**input_param).forward(**input_data)
        np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-4)

    def test_ill_shape(self):
        """Mismatched input/target shapes raise AssertionError; bad reduction raises ValueError."""
        loss = MCCLoss()
        with self.assertRaisesRegex(AssertionError, ""):
            loss.forward(torch.ones((2, 2, 3)), torch.ones((4, 5, 6)))
        chn_input = torch.ones((1, 1, 3))
        chn_target = torch.ones((1, 1, 3))
        with self.assertRaisesRegex(ValueError, ""):
            # invalid reduction string is rejected by the LossReduction enum in __init__
            MCCLoss(reduction="unknown")(chn_input, chn_target)
        with self.assertRaisesRegex(ValueError, ""):
            MCCLoss(reduction=None)(chn_input, chn_target)

    def test_ill_opts(self):
        """Conflicting activations raise ValueError; non-callable other_act raises TypeError."""
        with self.assertRaisesRegex(ValueError, ""):
            MCCLoss(sigmoid=True, softmax=True)
        with self.assertRaisesRegex(TypeError, ""):
            MCCLoss(other_act="tanh")

    @parameterized.expand([(False, False, False), (False, True, False), (False, False, True)])
    def test_input_warnings(self, include_background, softmax, to_onehot_y):
        """Single-channel input with channel-dependent options emits a warning instead of failing."""
        chn_input = torch.ones((1, 1, 3))
        chn_target = torch.ones((1, 1, 3))
        with self.assertWarns(Warning):
            loss = MCCLoss(include_background=include_background, softmax=softmax, to_onehot_y=to_onehot_y)
            loss.forward(chn_input, chn_target)

    def test_script(self):
        """The loss must remain TorchScript-compatible."""
        loss = MCCLoss()
        test_input = torch.ones(2, 1, 8, 8)
        test_script_save(loss, test_input, test_input)


if __name__ == "__main__":
    unittest.main()