diff --git a/cyclops/evaluate/metrics/experimental/metric_dict.py b/cyclops/evaluate/metrics/experimental/metric_dict.py index 6f2845349..341e28e99 100644 --- a/cyclops/evaluate/metrics/experimental/metric_dict.py +++ b/cyclops/evaluate/metrics/experimental/metric_dict.py @@ -38,6 +38,8 @@ attribute="Metric", error="ignore", ) + if TorchMetric is None: + TorchMetric = type(None) LOGGER = logging.getLogger(__name__) diff --git a/cyclops/models/torch_utils.py b/cyclops/models/torch_utils.py index a5a82b730..235436272 100644 --- a/cyclops/models/torch_utils.py +++ b/cyclops/models/torch_utils.py @@ -21,6 +21,8 @@ attribute="PackedSequence", error="warn", ) + if PackedSequence is None: + PackedSequence = type(None) def _get_class_members( diff --git a/cyclops/utils/optional.py b/cyclops/utils/optional.py index 8b32e8063..e7795f344 100644 --- a/cyclops/utils/optional.py +++ b/cyclops/utils/optional.py @@ -3,14 +3,15 @@ import importlib import importlib.util import warnings -from typing import Any, Literal, Optional +from types import ModuleType +from typing import Literal, Optional, Union def import_optional_module( name: str, attribute: Optional[str] = None, error: Literal["raise", "warn", "ignore"] = "raise", -) -> Optional[Any]: +) -> Union[ModuleType, None]: """Import an optional module. Parameters @@ -27,9 +28,27 @@ def import_optional_module( Returns ------- - Optional[Any] - The imported module or attribute from the module, or `None` if the - module could not be imported. + ModuleType or None + None if the module could not be imported, + or the module or attribute if it was imported successfully. + + Raises + ------ + ImportError + If the module could not be imported and `error` is set to "raise". + + Warns + ----- + UserWarning + If the module could not be imported and `error` is set to "warn". + + Notes + ----- + This function is useful for handling optional dependencies. It will + attempt to import the specified module and return it if it is found. + If the module is not found, it will raise an ImportError, raise a + warning, or return ``None`` based on the value of + the `error` parameter. """ if error not in ("raise", "warn", "ignore"):