diff --git a/earth2studio/models/px/stormcast.py b/earth2studio/models/px/stormcast.py index 430b009d..a85ce100 100644 --- a/earth2studio/models/px/stormcast.py +++ b/earth2studio/models/px/stormcast.py @@ -26,6 +26,7 @@ from modulus.models import Module from modulus.utils.generative import deterministic_sampler from omegaconf import OmegaConf +from packaging.version import Version from earth2studio.data import DataSource, fetch_data from earth2studio.models.auto import AutoModelMixin, Package @@ -220,8 +221,8 @@ def load_model(cls, package: Package) -> DiagnosticModel: """Load StormCast model.""" # Require appropriate modulus version - installed_version = modulus.__version__ - if installed_version < "0.10.0a0": + installed_version = Version(modulus.__version__) + if installed_version < Version("0.10.0a0"): raise RuntimeError( f"modulus version 0.10.0a0 or later is required " f"to load the StormCast package from NGC, "