skip tests when mps not functional
leej3 committed May 21, 2024
1 parent 37d9a67 commit c98e9b1
Showing 4 changed files with 24 additions and 3 deletions.
11 changes: 11 additions & 0 deletions tests/ignite/__init__.py
@@ -3,3 +3,14 @@
 
 def cpu_and_maybe_cuda():
     return ("cpu",) + (("cuda",) if torch.cuda.is_available() else ())
+
+
+def is_mps_available_and_functional():
+    if not torch.backends.mps.is_available():
+        return False
+    try:
+        # Try to allocate a small tensor on the MPS device
+        torch.tensor([1.0], device="mps")
+        return True
+    except RuntimeError:
+        return False
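
The helper above guards against machines where torch.backends.mps.is_available() reports True but the device cannot actually allocate memory (the probe allocation raises a RuntimeError). A minimal usage sketch, not part of the commit — the test name and tensor below are illustrative:

import pytest
import torch

from tests.ignite import is_mps_available_and_functional


# Hypothetical test: runs only when an MPS device exists and can allocate.
@pytest.mark.skipif(not is_mps_available_and_functional(), reason="Skip if MPS not functional")
def test_ones_on_mps():
    t = torch.ones(3, device="mps")
    assert t.device.type == "mps"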
4 changes: 4 additions & 0 deletions tests/ignite/distributed/test_auto.py
@@ -12,6 +12,7 @@
 
 import ignite.distributed as idist
 from ignite.distributed.auto import auto_dataloader, auto_model, auto_optim, DistributedProxySampler
+from tests.ignite import is_mps_available_and_functional
 
 
 class DummyDS(Dataset):
@@ -179,6 +180,9 @@ def _test_auto_model_optimizer(ws, device):
         assert optimizer.backward_passes_per_step == backward_passes_per_step
 
 
+@pytest.mark.skipif(
+    torch.backends.mps.is_available() and not is_mps_available_and_functional(), reason="Skip if MPS not functional"
+)
 def test_auto_methods_no_dist():
     _test_auto_dataloader(1, 1, batch_size=1)
     _test_auto_dataloader(1, 1, batch_size=10, num_workers=2)
4 changes: 4 additions & 0 deletions tests/ignite/distributed/test_launcher.py
@@ -9,6 +9,7 @@
 
 import ignite.distributed as idist
 from ignite.distributed.utils import has_hvd_support, has_native_dist_support, has_xla_support
+from tests.ignite import is_mps_available_and_functional
 
 
 def test_parallel_wrong_inputs():
@@ -54,6 +55,9 @@ def execute(cmd, env=None):
     return str(process.stdout.read()) + str(process.stderr.read())
 
 
+@pytest.mark.skipif(
+    torch.backends.mps.is_available() and not is_mps_available_and_functional(), reason="Skip if MPS not functional"
+)
 def test_check_idist_parallel_no_dist(exec_filepath):
     cmd = [sys.executable, "-u", exec_filepath]
     out = execute(cmd)
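
In both test_auto.py and test_launcher.py the new marker skips a test only when MPS is advertised but broken; machines without any MPS backend keep running these CPU-path tests. An illustrative breakdown of the condition, assuming the same imports as in the diffs above (not part of the commit):

import torch

from tests.ignite import is_mps_available_and_functional

# How the guard evaluates on different machines (illustrative):
#   no MPS backend (e.g. Linux CI)      -> is_available() is False      -> test runs
#   MPS available and working           -> not ...functional() is False -> test runs
#   MPS available but allocation fails  -> both operands True           -> test is skipped
should_skip = torch.backends.mps.is_available() and not is_mps_available_and_functional()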
8 changes: 5 additions & 3 deletions tests/ignite/engine/test_create_supervised.py
@@ -25,6 +25,8 @@
 )
 from ignite.metrics import MeanSquaredError
 
+from tests.ignite import is_mps_available_and_functional
+
 
 class DummyModel(torch.nn.Module):
     def __init__(self, output_as_list=False):
@@ -485,7 +487,7 @@ def test_create_supervised_trainer_on_cuda():
     _test_create_mocked_supervised_trainer(model_device=model_device, trainer_device=trainer_device)
 
 
-@pytest.mark.skipif(not (_torch_version_le_112 and torch.backends.mps.is_available()), reason="Skip if no MPS")
+@pytest.mark.skipif(not (_torch_version_le_112 and is_mps_available_and_functional()), reason="Skip if no MPS")
 def test_create_supervised_trainer_on_mps():
     model_device = trainer_device = "mps"
     _test_create_supervised_trainer_wrong_accumulation(model_device=model_device, trainer_device=trainer_device)
@@ -666,14 +668,14 @@ def test_create_supervised_evaluator_on_cuda_with_model_on_cpu():
     _test_mocked_supervised_evaluator(evaluator_device="cuda")
 
 
-@pytest.mark.skipif(not (_torch_version_le_112 and torch.backends.mps.is_available()), reason="Skip if no MPS")
+@pytest.mark.skipif(not (_torch_version_le_112 and is_mps_available_and_functional()), reason="Skip if no MPS")
 def test_create_supervised_evaluator_on_mps():
     model_device = evaluator_device = "mps"
     _test_create_supervised_evaluator(model_device=model_device, evaluator_device=evaluator_device)
     _test_mocked_supervised_evaluator(model_device=model_device, evaluator_device=evaluator_device)
 
 
-@pytest.mark.skipif(not (_torch_version_le_112 and torch.backends.mps.is_available()), reason="Skip if no MPS")
+@pytest.mark.skipif(not (_torch_version_le_112 and is_mps_available_and_functional()), reason="Skip if no MPS")
 def test_create_supervised_evaluator_on_mps_with_model_on_cpu():
     _test_create_supervised_evaluator(evaluator_device="mps")
     _test_mocked_supervised_evaluator(evaluator_device="mps")
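
The MPS-only tests above use the inverse gate: they are skipped unless the existing version check (_torch_version_le_112) passes and MPS is genuinely functional, so they never run against a broken MPS backend. A hedged sketch of that pattern applied to a stand-alone test (the test body is hypothetical and the version gate is omitted for brevity):

import pytest
import torch

from tests.ignite import is_mps_available_and_functional


# Hypothetical MPS-only test: skipped everywhere except on a working MPS device.
@pytest.mark.skipif(not is_mps_available_and_functional(), reason="Skip if no MPS")
def test_linear_forward_on_mps():
    model = torch.nn.Linear(2, 1).to("mps")
    x = torch.rand(4, 2, device="mps")
    assert model(x).device.type == "mps"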
