Skip to content

Commit

Permalink
fix mps errors for older pytorch versions
Browse files Browse the repository at this point in the history
  • Loading branch information
leej3 committed May 29, 2024
1 parent 2692a7f commit 2054d08
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
4 changes: 3 additions & 1 deletion tests/ignite/distributed/test_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import ignite.distributed as idist
from ignite.distributed.auto import auto_dataloader, auto_model, auto_optim, DistributedProxySampler
from ignite.distributed.comp_models.base import _torch_version_le_112
from tests.ignite import is_mps_available_and_functional


Expand Down Expand Up @@ -181,7 +182,8 @@ def _test_auto_model_optimizer(ws, device):


@pytest.mark.skipif(
torch.backends.mps.is_available() and not is_mps_available_and_functional(), reason="Skip if MPS not functional"
_torch_version_le_112 or (torch.backends.mps.is_available() and not is_mps_available_and_functional()),
reason="Skip if MPS not functional",
)
def test_auto_methods_no_dist():
_test_auto_dataloader(1, 1, batch_size=1)
Expand Down
4 changes: 3 additions & 1 deletion tests/ignite/distributed/test_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from packaging.version import Version

import ignite.distributed as idist
from ignite.distributed.comp_models.base import _torch_version_le_112
from ignite.distributed.utils import has_hvd_support, has_native_dist_support, has_xla_support
from tests.ignite import is_mps_available_and_functional

Expand Down Expand Up @@ -56,7 +57,8 @@ def execute(cmd, env=None):


@pytest.mark.skipif(
torch.backends.mps.is_available() and not is_mps_available_and_functional(), reason="Skip if MPS not functional"
_torch_version_le_112 or (torch.backends.mps.is_available() and not is_mps_available_and_functional()),
reason="Skip if MPS not functional",
)
def test_check_idist_parallel_no_dist(exec_filepath):
cmd = [sys.executable, "-u", exec_filepath]
Expand Down

0 comments on commit 2054d08

Please sign in to comment.