From 2054d08f944b7c5e2db860a51e2318645f453c8a Mon Sep 17 00:00:00 2001 From: leej3 Date: Wed, 29 May 2024 11:13:44 +0100 Subject: [PATCH] fix mps errors for older pytorch versions --- tests/ignite/distributed/test_auto.py | 4 +++- tests/ignite/distributed/test_launcher.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/ignite/distributed/test_auto.py b/tests/ignite/distributed/test_auto.py index a53999c1cb6..5ad12d95e50 100644 --- a/tests/ignite/distributed/test_auto.py +++ b/tests/ignite/distributed/test_auto.py @@ -12,6 +12,7 @@ import ignite.distributed as idist from ignite.distributed.auto import auto_dataloader, auto_model, auto_optim, DistributedProxySampler +from ignite.distributed.comp_models.base import _torch_version_le_112 from tests.ignite import is_mps_available_and_functional @@ -181,7 +182,8 @@ def _test_auto_model_optimizer(ws, device): @pytest.mark.skipif( - torch.backends.mps.is_available() and not is_mps_available_and_functional(), reason="Skip if MPS not functional" + _torch_version_le_112 or (torch.backends.mps.is_available() and not is_mps_available_and_functional()), + reason="Skip if MPS not functional", ) def test_auto_methods_no_dist(): _test_auto_dataloader(1, 1, batch_size=1) diff --git a/tests/ignite/distributed/test_launcher.py b/tests/ignite/distributed/test_launcher.py index 8cc1001aa74..d6682c4a56b 100644 --- a/tests/ignite/distributed/test_launcher.py +++ b/tests/ignite/distributed/test_launcher.py @@ -8,6 +8,7 @@ from packaging.version import Version import ignite.distributed as idist +from ignite.distributed.comp_models.base import _torch_version_le_112 from ignite.distributed.utils import has_hvd_support, has_native_dist_support, has_xla_support from tests.ignite import is_mps_available_and_functional @@ -56,7 +57,8 @@ def execute(cmd, env=None): @pytest.mark.skipif( - torch.backends.mps.is_available() and not is_mps_available_and_functional(), reason="Skip if MPS not functional" + _torch_version_le_112 or (torch.backends.mps.is_available() and not is_mps_available_and_functional()), + reason="Skip if MPS not functional", ) def test_check_idist_parallel_no_dist(exec_filepath): cmd = [sys.executable, "-u", exec_filepath]