diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 9f1429d477..03702e444c 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -85,6 +85,10 @@ cp, has_cp = optional_import("cupy") cp_ndarray, _ = optional_import("cupy", name="ndarray") exposure, has_skimage = optional_import("skimage.exposure") +_cucim_skimage, _has_cucim_skimage = optional_import("cucim.skimage") +_cucim_morphology_edt, _has_cucim_morphology = optional_import( + "cucim.core.operations.morphology", name="distance_transform_edt" +) __all__ = [ "allow_missing_keys_mode", @@ -1147,11 +1151,10 @@ def get_largest_connected_component_mask( """ # use skimage/cucim.skimage and np/cp depending on whether packages are # available and input is non-cpu torch.tensor - skimage, has_cucim = optional_import("cucim.skimage") - use_cp = has_cp and has_cucim and isinstance(img, torch.Tensor) and img.device != torch.device("cpu") + use_cp = has_cp and _has_cucim_skimage and isinstance(img, torch.Tensor) and img.device != torch.device("cpu") if use_cp: img_ = convert_to_cupy(img.short()) # type: ignore - label = skimage.measure.label + label = _cucim_skimage.measure.label lib = cp else: if not has_measure: @@ -1204,13 +1207,13 @@ def keep_merge_components_with_points( margins: include points outside of the region but within the margin. """ - cucim_skimage, has_cucim = optional_import("cucim.skimage") - - use_cp = has_cp and has_cucim and isinstance(img_pos, torch.Tensor) and img_pos.device != torch.device("cpu") + use_cp = ( + has_cp and _has_cucim_skimage and isinstance(img_pos, torch.Tensor) and img_pos.device != torch.device("cpu") + ) if use_cp: img_pos_ = convert_to_cupy(img_pos.short()) # type: ignore img_neg_ = convert_to_cupy(img_neg.short()) # type: ignore - label = cucim_skimage.measure.label + label = _cucim_skimage.measure.label lib = cp else: if not has_measure: @@ -2463,10 +2466,7 @@ def distance_transform_edt( Returned only when `return_indices` is True and `indices` is not supplied. dtype np.float64. """ - distance_transform_edt, has_cucim = optional_import( - "cucim.core.operations.morphology", name="distance_transform_edt" - ) - use_cp = has_cp and has_cucim and isinstance(img, torch.Tensor) and img.device.type == "cuda" + use_cp = has_cp and _has_cucim_morphology and isinstance(img, torch.Tensor) and img.device.type == "cuda" if not return_distances and not return_indices: raise RuntimeError("Neither return_distances nor return_indices True") @@ -2499,7 +2499,7 @@ def distance_transform_edt( indices_ = convert_to_cupy(indices) img_ = convert_to_cupy(img) for channel_idx in range(img_.shape[0]): - distance_transform_edt( + _cucim_morphology_edt( img_[channel_idx], sampling=sampling, return_distances=return_distances, diff --git a/monai/utils/module.py b/monai/utils/module.py index a64f73cd6b..c8851714ce 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -17,6 +17,7 @@ import pdb import re import sys +import traceback as traceback_mod import warnings from collections.abc import Callable, Collection, Hashable, Iterable, Mapping from functools import partial, wraps @@ -368,8 +369,9 @@ def optional_import( OptionalImportError: from torch.nn.functional import conv1d (requires version '42' by 'min_version'). """ - tb = None + had_exception = False exception_str = "" + tb_str = "" if name: actual_cmd = f"from {module} import {name}" else: @@ -384,8 +386,12 @@ def optional_import( if name: # user specified to load class/function/... from the module the_module = getattr(the_module, name) except Exception as import_exception: # any exceptions during import - tb = import_exception.__traceback__ + tb_str = "".join( + traceback_mod.format_exception(type(import_exception), import_exception, import_exception.__traceback__) + ) + import_exception.__traceback__ = None exception_str = f"{import_exception}" + had_exception = True else: # found the module if version_args and version_checker(pkg, f"{version}", version_args): return the_module, True @@ -394,7 +400,7 @@ def optional_import( # preparing lazy error message msg = descriptor.format(actual_cmd) - if version and tb is None: # a pure version issue + if version and not had_exception: # a pure version issue msg += f" (requires '{module} {version}' by '{version_checker.__name__}')" if exception_str: msg += f" ({exception_str})" @@ -407,10 +413,9 @@ def __init__(self, *_args, **_kwargs): + "\n\nFor details about installing the optional dependencies, please visit:" + "\n https://monai.readthedocs.io/en/latest/installation.html#installing-the-recommended-dependencies" ) - if tb is None: - self._exception = OptionalImportError(_default_msg) - else: - self._exception = OptionalImportError(_default_msg).with_traceback(tb) + if tb_str: + _default_msg += f"\n\nOriginal traceback:\n{tb_str}" + self._exception = OptionalImportError(_default_msg) def __getattr__(self, name): """ diff --git a/tests/utils/test_optional_import.py b/tests/utils/test_optional_import.py index 2f640f88d0..d8aa55b907 100644 --- a/tests/utils/test_optional_import.py +++ b/tests/utils/test_optional_import.py @@ -11,7 +11,9 @@ from __future__ import annotations +import gc import unittest +import weakref from parameterized import parameterized @@ -75,6 +77,34 @@ def versioning(module, ver, a): nn, flag = optional_import("torch", "1.1", version_checker=versioning, name="nn", version_args=test_args) self.assertTrue(flag) + def test_no_traceback_leak(self): + """Verify optional_import does not retain references to stack frames (issue #7480).""" + + class _Marker: + pass + + def _do_import(): + marker = _Marker() + ref = weakref.ref(marker) + # Call optional_import for a module that does not exist. + # If the traceback is leaked, `marker` stays alive via frame references. + mod, flag = optional_import("nonexistent_module_for_leak_test") + self.assertFalse(flag) + return ref + + ref = _do_import() + gc.collect() + self.assertIsNone(ref(), "optional_import is leaking frame references via traceback") + + def test_failed_import_shows_traceback_string(self): + """Verify the error message includes the original traceback as a string.""" + mod, flag = optional_import("nonexistent_module_for_tb_test") + self.assertFalse(flag) + with self.assertRaises(OptionalImportError) as ctx: + mod.something + self.assertIn("Original traceback", str(ctx.exception)) + self.assertIn("ModuleNotFoundError", str(ctx.exception)) + if __name__ == "__main__": unittest.main()