Skip to content

Commit

Permalink
distinguish between available and functional
Browse files Browse the repository at this point in the history
  • Loading branch information
leej3 committed May 21, 2024
1 parent 861ff97 commit 8f7275f
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 3 deletions.
4 changes: 4 additions & 0 deletions tests/ignite/distributed/test_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import ignite.distributed as idist
from ignite.distributed.auto import auto_dataloader, auto_model, auto_optim, DistributedProxySampler
from tests.ignite import is_mps_available_and_functional


class DummyDS(Dataset):
Expand Down Expand Up @@ -179,6 +180,9 @@ def _test_auto_model_optimizer(ws, device):
assert optimizer.backward_passes_per_step == backward_passes_per_step


@pytest.mark.skipif(
torch.backends.mps.is_available() and not is_mps_available_and_functional(), reason="Skip if MPS not functional"
)
def test_auto_methods_no_dist():
_test_auto_dataloader(1, 1, batch_size=1)
_test_auto_dataloader(1, 1, batch_size=10, num_workers=2)
Expand Down
4 changes: 4 additions & 0 deletions tests/ignite/distributed/test_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import ignite.distributed as idist
from ignite.distributed.utils import has_hvd_support, has_native_dist_support, has_xla_support
from tests.ignite.distributed.utils import is_mps_available_and_functional


def test_parallel_wrong_inputs():
Expand Down Expand Up @@ -54,6 +55,9 @@ def execute(cmd, env=None):
return str(process.stdout.read()) + str(process.stderr.read())


@pytest.mark.skipif(
torch.backends.mps.is_available() and not is_mps_available_and_functional(), reason="Skip if MPS not functional"
)
def test_check_idist_parallel_no_dist(exec_filepath):
cmd = [sys.executable, "-u", exec_filepath]
out = execute(cmd)
Expand Down
5 changes: 2 additions & 3 deletions tests/ignite/distributed/utils/test_serial.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import ignite.distributed as idist
from ignite.distributed.comp_models.base import _torch_version_le_112
from tests.ignite import is_mps_available_and_functional
from tests.ignite.distributed.utils import (
_sanity_check,
_test_distrib__get_max_length,
Expand All @@ -19,7 +18,7 @@ def test_no_distrib(capsys):
assert idist.backend() is None
if torch.cuda.is_available():
assert idist.device().type == "cuda"
elif _torch_version_le_112 and is_mps_available_and_functional():
elif _torch_version_le_112 and torch.backends.mps.is_available():
assert idist.device().type == "mps"
else:
assert idist.device().type == "cpu"
Expand All @@ -42,7 +41,7 @@ def test_no_distrib(capsys):
assert "ignite.distributed.utils INFO: backend: None" in out[-1]
if torch.cuda.is_available():
assert "ignite.distributed.utils INFO: device: cuda" in out[-1]
elif _torch_version_le_112 and is_mps_available_and_functional():
elif _torch_version_le_112 and torch.backends.mps.is_available():
assert "ignite.distributed.utils INFO: device: mps" in out[-1]
else:
assert "ignite.distributed.utils INFO: device: cpu" in out[-1]
Expand Down

0 comments on commit 8f7275f

Please sign in to comment.