diff --git a/tests/ignite/distributed/test_auto.py b/tests/ignite/distributed/test_auto.py index 761e328944c4..a53999c1cb69 100644 --- a/tests/ignite/distributed/test_auto.py +++ b/tests/ignite/distributed/test_auto.py @@ -12,6 +12,7 @@ import ignite.distributed as idist from ignite.distributed.auto import auto_dataloader, auto_model, auto_optim, DistributedProxySampler +from tests.ignite import is_mps_available_and_functional class DummyDS(Dataset): @@ -179,6 +180,9 @@ def _test_auto_model_optimizer(ws, device): assert optimizer.backward_passes_per_step == backward_passes_per_step +@pytest.mark.skipif( + torch.backends.mps.is_available() and not is_mps_available_and_functional(), reason="Skip if MPS not functional" +) def test_auto_methods_no_dist(): _test_auto_dataloader(1, 1, batch_size=1) _test_auto_dataloader(1, 1, batch_size=10, num_workers=2) diff --git a/tests/ignite/distributed/test_launcher.py b/tests/ignite/distributed/test_launcher.py index b12e2acf1c26..cf45fc76add7 100644 --- a/tests/ignite/distributed/test_launcher.py +++ b/tests/ignite/distributed/test_launcher.py @@ -9,6 +9,7 @@ import ignite.distributed as idist from ignite.distributed.utils import has_hvd_support, has_native_dist_support, has_xla_support +from tests.ignite.distributed.utils import is_mps_available_and_functional def test_parallel_wrong_inputs(): @@ -54,6 +55,9 @@ def execute(cmd, env=None): return str(process.stdout.read()) + str(process.stderr.read()) +@pytest.mark.skipif( + torch.backends.mps.is_available() and not is_mps_available_and_functional(), reason="Skip if MPS not functional" +) def test_check_idist_parallel_no_dist(exec_filepath): cmd = [sys.executable, "-u", exec_filepath] out = execute(cmd) diff --git a/tests/ignite/distributed/utils/test_serial.py b/tests/ignite/distributed/utils/test_serial.py index 6aef0de221e4..fdbf26e83608 100644 --- a/tests/ignite/distributed/utils/test_serial.py +++ b/tests/ignite/distributed/utils/test_serial.py @@ -2,7 +2,6 @@ import ignite.distributed as idist from ignite.distributed.comp_models.base import _torch_version_le_112 -from tests.ignite import is_mps_available_and_functional from tests.ignite.distributed.utils import ( _sanity_check, _test_distrib__get_max_length, @@ -19,7 +18,7 @@ def test_no_distrib(capsys): assert idist.backend() is None if torch.cuda.is_available(): assert idist.device().type == "cuda" - elif _torch_version_le_112 and is_mps_available_and_functional(): + elif _torch_version_le_112 and torch.backends.mps.is_available(): assert idist.device().type == "mps" else: assert idist.device().type == "cpu" @@ -42,7 +41,7 @@ def test_no_distrib(capsys): assert "ignite.distributed.utils INFO: backend: None" in out[-1] if torch.cuda.is_available(): assert "ignite.distributed.utils INFO: device: cuda" in out[-1] - elif _torch_version_le_112 and is_mps_available_and_functional(): + elif _torch_version_le_112 and torch.backends.mps.is_available(): assert "ignite.distributed.utils INFO: device: mps" in out[-1] else: assert "ignite.distributed.utils INFO: device: cpu" in out[-1]