From 741dbbef096de5546d2a78006fda5a5475048d87 Mon Sep 17 00:00:00 2001
From: leej3
Date: Wed, 29 May 2024 11:30:54 +0100
Subject: [PATCH] fix torch 1.12 check

---
 ignite/distributed/comp_models/base.py            | 4 ++--
 tests/ignite/distributed/comp_models/test_base.py | 4 ++--
 tests/ignite/distributed/test_auto.py             | 4 ++--
 tests/ignite/distributed/test_launcher.py         | 4 ++--
 tests/ignite/distributed/utils/test_serial.py     | 6 +++---
 tests/ignite/engine/test_create_supervised.py     | 8 ++++----
 6 files changed, 15 insertions(+), 15 deletions(-)

diff --git a/ignite/distributed/comp_models/base.py b/ignite/distributed/comp_models/base.py
index 6e86193381c..128472a6467 100644
--- a/ignite/distributed/comp_models/base.py
+++ b/ignite/distributed/comp_models/base.py
@@ -5,7 +5,7 @@
 import torch
 from packaging.version import Version
 
-_torch_version_le_112 = Version(torch.__version__) > Version("1.12.0")
+_torch_version_ge_112 = Version(torch.__version__) >= Version("1.12.0")
 
 
 class ComputationModel(metaclass=ABCMeta):
@@ -329,7 +329,7 @@ def get_node_rank(self) -> int:
     def device(self) -> torch.device:
         if torch.cuda.is_available():
             return torch.device("cuda")
-        if _torch_version_le_112 and torch.backends.mps.is_available():
+        if _torch_version_ge_112 and torch.backends.mps.is_available():
             return torch.device("mps")
         return torch.device("cpu")
 
diff --git a/tests/ignite/distributed/comp_models/test_base.py b/tests/ignite/distributed/comp_models/test_base.py
index c8041c6dc33..de75eec41f4 100644
--- a/tests/ignite/distributed/comp_models/test_base.py
+++ b/tests/ignite/distributed/comp_models/test_base.py
@@ -1,7 +1,7 @@
 import pytest
 import torch
 
-from ignite.distributed.comp_models.base import _SerialModel, _torch_version_le_112, ComputationModel
+from ignite.distributed.comp_models.base import _SerialModel, _torch_version_ge_112, ComputationModel
 
 
 def test_serial_model():
@@ -16,7 +16,7 @@ def test_serial_model():
     assert model.get_node_rank() == 0
     if torch.cuda.is_available():
         assert model.device().type == "cuda"
-    elif _torch_version_le_112 and torch.backends.mps.is_available():
+    elif _torch_version_ge_112 and torch.backends.mps.is_available():
         assert model.device().type == "mps"
     else:
         assert model.device().type == "cpu"
diff --git a/tests/ignite/distributed/test_auto.py b/tests/ignite/distributed/test_auto.py
index 5ad12d95e50..44c19eaa459 100644
--- a/tests/ignite/distributed/test_auto.py
+++ b/tests/ignite/distributed/test_auto.py
@@ -12,7 +12,7 @@
 
 import ignite.distributed as idist
 from ignite.distributed.auto import auto_dataloader, auto_model, auto_optim, DistributedProxySampler
-from ignite.distributed.comp_models.base import _torch_version_le_112
+from ignite.distributed.comp_models.base import _torch_version_ge_112
 from tests.ignite import is_mps_available_and_functional
 
 
@@ -182,7 +182,7 @@ def _test_auto_model_optimizer(ws, device):
 
 
 @pytest.mark.skipif(
-    _torch_version_le_112 or (torch.backends.mps.is_available() and not is_mps_available_and_functional()),
+    (not _torch_version_ge_112) or (torch.backends.mps.is_available() and not is_mps_available_and_functional()),
     reason="Skip if MPS not functional",
 )
 def test_auto_methods_no_dist():
diff --git a/tests/ignite/distributed/test_launcher.py b/tests/ignite/distributed/test_launcher.py
index d6682c4a56b..ca453f6a470 100644
--- a/tests/ignite/distributed/test_launcher.py
+++ b/tests/ignite/distributed/test_launcher.py
@@ -8,7 +8,7 @@
 from packaging.version import Version
 
 import ignite.distributed as idist
-from ignite.distributed.comp_models.base import _torch_version_le_112
+from ignite.distributed.comp_models.base import _torch_version_ge_112
 from ignite.distributed.utils import has_hvd_support, has_native_dist_support, has_xla_support
 from tests.ignite import is_mps_available_and_functional
 
@@ -57,7 +57,7 @@ def execute(cmd, env=None):
 
 
 @pytest.mark.skipif(
-    _torch_version_le_112 or (torch.backends.mps.is_available() and not is_mps_available_and_functional()),
+    (not _torch_version_ge_112) or (torch.backends.mps.is_available() and not is_mps_available_and_functional()),
     reason="Skip if MPS not functional",
 )
 def test_check_idist_parallel_no_dist(exec_filepath):
diff --git a/tests/ignite/distributed/utils/test_serial.py b/tests/ignite/distributed/utils/test_serial.py
index fdbf26e8360..414b3ba3132 100644
--- a/tests/ignite/distributed/utils/test_serial.py
+++ b/tests/ignite/distributed/utils/test_serial.py
@@ -1,7 +1,7 @@
 import torch
 
 import ignite.distributed as idist
-from ignite.distributed.comp_models.base import _torch_version_le_112
+from ignite.distributed.comp_models.base import _torch_version_ge_112
 from tests.ignite.distributed.utils import (
     _sanity_check,
     _test_distrib__get_max_length,
@@ -18,7 +18,7 @@ def test_no_distrib(capsys):
     assert idist.backend() is None
     if torch.cuda.is_available():
         assert idist.device().type == "cuda"
-    elif _torch_version_le_112 and torch.backends.mps.is_available():
+    elif _torch_version_ge_112 and torch.backends.mps.is_available():
         assert idist.device().type == "mps"
     else:
         assert idist.device().type == "cpu"
@@ -41,7 +41,7 @@ def test_no_distrib(capsys):
     assert "ignite.distributed.utils INFO: backend: None" in out[-1]
     if torch.cuda.is_available():
         assert "ignite.distributed.utils INFO: device: cuda" in out[-1]
-    elif _torch_version_le_112 and torch.backends.mps.is_available():
+    elif _torch_version_ge_112 and torch.backends.mps.is_available():
         assert "ignite.distributed.utils INFO: device: mps" in out[-1]
     else:
         assert "ignite.distributed.utils INFO: device: cpu" in out[-1]
diff --git a/tests/ignite/engine/test_create_supervised.py b/tests/ignite/engine/test_create_supervised.py
index d9b0c161f75..b465dfbeef9 100644
--- a/tests/ignite/engine/test_create_supervised.py
+++ b/tests/ignite/engine/test_create_supervised.py
@@ -12,7 +12,7 @@
 from torch.optim import SGD
 
 import ignite.distributed as idist
-from ignite.distributed.comp_models.base import _torch_version_le_112
+from ignite.distributed.comp_models.base import _torch_version_ge_112
 from ignite.engine import (
     _check_arg,
     create_supervised_evaluator,
@@ -487,7 +487,7 @@ def test_create_supervised_trainer_on_cuda():
     _test_create_mocked_supervised_trainer(model_device=model_device, trainer_device=trainer_device)
 
 
-@pytest.mark.skipif(not (_torch_version_le_112 and is_mps_available_and_functional()), reason="Skip if no MPS")
+@pytest.mark.skipif(not (_torch_version_ge_112 and is_mps_available_and_functional()), reason="Skip if no MPS")
 def test_create_supervised_trainer_on_mps():
     model_device = trainer_device = "mps"
     _test_create_supervised_trainer_wrong_accumulation(model_device=model_device, trainer_device=trainer_device)
@@ -668,14 +668,14 @@ def test_create_supervised_evaluator_on_cuda_with_model_on_cpu():
     _test_mocked_supervised_evaluator(evaluator_device="cuda")
 
 
-@pytest.mark.skipif(not (_torch_version_le_112 and is_mps_available_and_functional()), reason="Skip if no MPS")
+@pytest.mark.skipif(not (_torch_version_ge_112 and is_mps_available_and_functional()), reason="Skip if no MPS")
 def test_create_supervised_evaluator_on_mps():
     model_device = evaluator_device = "mps"
     _test_create_supervised_evaluator(model_device=model_device, evaluator_device=evaluator_device)
     _test_mocked_supervised_evaluator(model_device=model_device, evaluator_device=evaluator_device)
 
 
-@pytest.mark.skipif(not (_torch_version_le_112 and is_mps_available_and_functional()), reason="Skip if no MPS")
+@pytest.mark.skipif(not (_torch_version_ge_112 and is_mps_available_and_functional()), reason="Skip if no MPS")
 def test_create_supervised_evaluator_on_mps_with_model_on_cpu():
     _test_create_supervised_evaluator(evaluator_device="mps")
     _test_mocked_supervised_evaluator(evaluator_device="mps")
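
Note: the device-selection logic this rename clarifies reduces to the minimal standalone sketch below. It assumes only that torch and packaging are installed; select_device is a hypothetical name used for illustration, not part of ignite's API. The version guard matters because torch.backends.mps only exists from torch 1.12 onwards, so the boolean must short-circuit before the mps attribute is touched on older releases.

    import torch
    from packaging.version import Version

    # True from torch 1.12.0 onwards, the first release that ships torch.backends.mps.
    _torch_version_ge_112 = Version(torch.__version__) >= Version("1.12.0")

    def select_device() -> torch.device:
        # Prefer CUDA, then MPS (guarded so pre-1.12 torch never evaluates
        # torch.backends.mps and raises AttributeError), else fall back to CPU.
        if torch.cuda.is_available():
            return torch.device("cuda")
        if _torch_version_ge_112 and torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    print(select_device())  # e.g. "mps" on an Apple-silicon Mac with torch >= 1.12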