skip tests when mps not functional
leej3 committed May 21, 2024
1 parent 37d9a67 commit 655c12a
Showing 3 changed files with 20 additions and 5 deletions.
5 changes: 3 additions & 2 deletions tests/ignite/distributed/utils/test_serial.py
@@ -12,13 +12,14 @@
     _test_distrib_new_group,
     _test_sync,
 )
+from ....utils_for_tests import is_mps_available_and_functional


 def test_no_distrib(capsys):
     assert idist.backend() is None
     if torch.cuda.is_available():
         assert idist.device().type == "cuda"
-    elif _torch_version_le_112 and torch.backends.mps.is_available():
+    elif _torch_version_le_112 and is_mps_available_and_functional():
         assert idist.device().type == "mps"
     else:
         assert idist.device().type == "cpu"
@@ -41,7 +42,7 @@ def test_no_distrib(capsys):
     assert "ignite.distributed.utils INFO: backend: None" in out[-1]
     if torch.cuda.is_available():
         assert "ignite.distributed.utils INFO: device: cuda" in out[-1]
-    elif _torch_version_le_112 and torch.backends.mps.is_available():
+    elif _torch_version_le_112 and is_mps_available_and_functional():
         assert "ignite.distributed.utils INFO: device: mps" in out[-1]
     else:
         assert "ignite.distributed.utils INFO: device: cpu" in out[-1]
8 changes: 5 additions & 3 deletions tests/ignite/engine/test_create_supervised.py
@@ -25,6 +25,8 @@
 )
 from ignite.metrics import MeanSquaredError

+from ...utils_for_tests import is_mps_available_and_functional  # type: ignore
+

 class DummyModel(torch.nn.Module):
     def __init__(self, output_as_list=False):
@@ -485,7 +487,7 @@ def test_create_supervised_trainer_on_cuda():
     _test_create_mocked_supervised_trainer(model_device=model_device, trainer_device=trainer_device)


-@pytest.mark.skipif(not (_torch_version_le_112 and torch.backends.mps.is_available()), reason="Skip if no MPS")
+@pytest.mark.skipif(not (_torch_version_le_112 and is_mps_available_and_functional()), reason="Skip if no MPS")
 def test_create_supervised_trainer_on_mps():
     model_device = trainer_device = "mps"
     _test_create_supervised_trainer_wrong_accumulation(model_device=model_device, trainer_device=trainer_device)
@@ -666,14 +668,14 @@ def test_create_supervised_evaluator_on_cuda_with_model_on_cpu():
     _test_mocked_supervised_evaluator(evaluator_device="cuda")


-@pytest.mark.skipif(not (_torch_version_le_112 and torch.backends.mps.is_available()), reason="Skip if no MPS")
+@pytest.mark.skipif(not (_torch_version_le_112 and is_mps_available_and_functional()), reason="Skip if no MPS")
 def test_create_supervised_evaluator_on_mps():
     model_device = evaluator_device = "mps"
     _test_create_supervised_evaluator(model_device=model_device, evaluator_device=evaluator_device)
     _test_mocked_supervised_evaluator(model_device=model_device, evaluator_device=evaluator_device)


-@pytest.mark.skipif(not (_torch_version_le_112 and torch.backends.mps.is_available()), reason="Skip if no MPS")
+@pytest.mark.skipif(not (_torch_version_le_112 and is_mps_available_and_functional()), reason="Skip if no MPS")
 def test_create_supervised_evaluator_on_mps_with_model_on_cpu():
     _test_create_supervised_evaluator(evaluator_device="mps")
     _test_mocked_supervised_evaluator(evaluator_device="mps")
12 changes: 12 additions & 0 deletions tests/utils_for_tests.py
@@ -0,0 +1,12 @@
+import torch
+
+
+def is_mps_available_and_functional():
+    if not torch.backends.mps.is_available():
+        return False
+    try:
+        # Try to allocate a small tensor on the MPS device
+        torch.tensor([1.0], device="mps")
+        return True
+    except RuntimeError:
+        return False
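
For reference, a minimal sketch (not part of this commit) of how the new helper can gate an MPS-specific test. The test name and the flat `utils_for_tests` import path are illustrative assumptions; the test files in this commit use package-relative imports such as `from ...utils_for_tests import is_mps_available_and_functional`.

import pytest
import torch

from utils_for_tests import is_mps_available_and_functional  # hypothetical import path


@pytest.mark.skipif(not is_mps_available_and_functional(), reason="Skip if no MPS")
def test_add_on_mps():  # hypothetical test mirroring the skipif pattern above
    x = torch.ones(2, device="mps")
    assert (x + x).sum().item() == 4.0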
