diff --git a/tests/ignite/__init__.py b/tests/ignite/__init__.py index d553c222e58..8f84e2e74b9 100644 --- a/tests/ignite/__init__.py +++ b/tests/ignite/__init__.py @@ -3,3 +3,14 @@ def cpu_and_maybe_cuda(): return ("cpu",) + (("cuda",) if torch.cuda.is_available() else ()) + + +def is_mps_available_and_functional(): + if not torch.backends.mps.is_available(): + return False + try: + # Try to allocate a small tensor on the MPS device + torch.tensor([1.0], device="mps") + return True + except RuntimeError: + return False diff --git a/tests/ignite/distributed/test_auto.py b/tests/ignite/distributed/test_auto.py index 761e328944c..a53999c1cb6 100644 --- a/tests/ignite/distributed/test_auto.py +++ b/tests/ignite/distributed/test_auto.py @@ -12,6 +12,7 @@ import ignite.distributed as idist from ignite.distributed.auto import auto_dataloader, auto_model, auto_optim, DistributedProxySampler +from tests.ignite import is_mps_available_and_functional class DummyDS(Dataset): @@ -179,6 +180,9 @@ def _test_auto_model_optimizer(ws, device): assert optimizer.backward_passes_per_step == backward_passes_per_step +@pytest.mark.skipif( + torch.backends.mps.is_available() and not is_mps_available_and_functional(), reason="Skip if MPS not functional" +) def test_auto_methods_no_dist(): _test_auto_dataloader(1, 1, batch_size=1) _test_auto_dataloader(1, 1, batch_size=10, num_workers=2) diff --git a/tests/ignite/distributed/test_launcher.py b/tests/ignite/distributed/test_launcher.py index b12e2acf1c2..8cc1001aa74 100644 --- a/tests/ignite/distributed/test_launcher.py +++ b/tests/ignite/distributed/test_launcher.py @@ -9,6 +9,7 @@ import ignite.distributed as idist from ignite.distributed.utils import has_hvd_support, has_native_dist_support, has_xla_support +from tests.ignite import is_mps_available_and_functional def test_parallel_wrong_inputs(): @@ -54,6 +55,9 @@ def execute(cmd, env=None): return str(process.stdout.read()) + str(process.stderr.read()) +@pytest.mark.skipif( + torch.backends.mps.is_available() and not is_mps_available_and_functional(), reason="Skip if MPS not functional" +) def test_check_idist_parallel_no_dist(exec_filepath): cmd = [sys.executable, "-u", exec_filepath] out = execute(cmd) diff --git a/tests/ignite/engine/test_create_supervised.py b/tests/ignite/engine/test_create_supervised.py index 31ca43f4bbf..d9b0c161f75 100644 --- a/tests/ignite/engine/test_create_supervised.py +++ b/tests/ignite/engine/test_create_supervised.py @@ -25,6 +25,8 @@ ) from ignite.metrics import MeanSquaredError +from tests.ignite import is_mps_available_and_functional + class DummyModel(torch.nn.Module): def __init__(self, output_as_list=False): @@ -485,7 +487,7 @@ def test_create_supervised_trainer_on_cuda(): _test_create_mocked_supervised_trainer(model_device=model_device, trainer_device=trainer_device) -@pytest.mark.skipif(not (_torch_version_le_112 and torch.backends.mps.is_available()), reason="Skip if no MPS") +@pytest.mark.skipif(not (_torch_version_le_112 and is_mps_available_and_functional()), reason="Skip if no MPS") def test_create_supervised_trainer_on_mps(): model_device = trainer_device = "mps" _test_create_supervised_trainer_wrong_accumulation(model_device=model_device, trainer_device=trainer_device) @@ -666,14 +668,14 @@ def test_create_supervised_evaluator_on_cuda_with_model_on_cpu(): _test_mocked_supervised_evaluator(evaluator_device="cuda") -@pytest.mark.skipif(not (_torch_version_le_112 and torch.backends.mps.is_available()), reason="Skip if no MPS") +@pytest.mark.skipif(not (_torch_version_le_112 and is_mps_available_and_functional()), reason="Skip if no MPS") def test_create_supervised_evaluator_on_mps(): model_device = evaluator_device = "mps" _test_create_supervised_evaluator(model_device=model_device, evaluator_device=evaluator_device) _test_mocked_supervised_evaluator(model_device=model_device, evaluator_device=evaluator_device) -@pytest.mark.skipif(not (_torch_version_le_112 and torch.backends.mps.is_available()), reason="Skip if no MPS") +@pytest.mark.skipif(not (_torch_version_le_112 and is_mps_available_and_functional()), reason="Skip if no MPS") def test_create_supervised_evaluator_on_mps_with_model_on_cpu(): _test_create_supervised_evaluator(evaluator_device="mps") _test_mocked_supervised_evaluator(evaluator_device="mps")