Skip to content

Commit

Permalink
added version check and skipped mps tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vfdev-5 committed Nov 21, 2023
1 parent ff9a2cd commit d1dfbec
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 5 deletions.
5 changes: 4 additions & 1 deletion ignite/distributed/comp_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
from typing import Any, Callable, cast, List, Optional, Union

import torch
from packaging.version import Version

_torch_version_le_112 = Version(torch.__version__) > Version("1.12.0")


class ComputationModel(metaclass=ABCMeta):
Expand Down Expand Up @@ -326,7 +329,7 @@ def get_node_rank(self) -> int:
def device(self) -> torch.device:
if torch.cuda.is_available():
return torch.device("cuda")
if torch.backends.mps.is_available():
if _torch_version_le_112 and torch.backends.mps.is_available():
return torch.device("mps")
return torch.device("cpu")

Expand Down
5 changes: 4 additions & 1 deletion tests/ignite/distributed/comp_models/test_base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import pytest
import torch

from ignite.distributed.comp_models.base import _SerialModel, ComputationModel
from ignite.distributed.comp_models.base import _torch_version_le_112, _SerialModel, ComputationModel


@pytest.mark.skipif(
_torch_version_le_112 and torch.backends.mps.is_available(), reason="Temporary skip if MPS is available"
)
def test_serial_model():
_SerialModel.create_from_backend()
model = _SerialModel.create_from_context()
Expand Down
4 changes: 4 additions & 0 deletions tests/ignite/distributed/test_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import ignite.distributed as idist
from ignite.distributed.auto import auto_dataloader, auto_model, auto_optim, DistributedProxySampler
from ignite.distributed.comp_models.base import _torch_version_le_112


class DummyDS(Dataset):
Expand Down Expand Up @@ -179,6 +180,9 @@ def _test_auto_model_optimizer(ws, device):
assert optimizer.backward_passes_per_step == backward_passes_per_step


@pytest.mark.skipif(
_torch_version_le_112 and torch.backends.mps.is_available(), reason="Temporary skip if MPS is available"
)
def test_auto_methods_no_dist():
_test_auto_dataloader(1, 1, batch_size=1)
_test_auto_dataloader(1, 1, batch_size=10, num_workers=2)
Expand Down
4 changes: 4 additions & 0 deletions tests/ignite/distributed/test_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from packaging.version import Version

import ignite.distributed as idist
from ignite.distributed.comp_models.base import _torch_version_le_112
from ignite.distributed.utils import has_hvd_support, has_native_dist_support, has_xla_support


Expand Down Expand Up @@ -257,6 +258,9 @@ def test_idist_parallel_n_procs_native(init_method, backend, get_fixed_dirname,


@pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc")
@pytest.mark.skipif(
_torch_version_le_112 and torch.backends.mps.is_available(), reason="Temporary skip if MPS is available"
)
def test_idist_parallel_no_dist():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
with idist.Parallel(backend=None) as parallel:
Expand Down
4 changes: 4 additions & 0 deletions tests/ignite/distributed/utils/test_serial.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch

import ignite.distributed as idist
from ignite.distributed.comp_models.base import _torch_version_le_112
from tests.ignite.distributed.utils import (
_sanity_check,
_test_distrib__get_max_length,
Expand All @@ -13,6 +14,9 @@
)


@pytest.mark.skipif(
_torch_version_le_112 and torch.backends.mps.is_available(), reason="Temporary skip if MPS is available"
)
def test_no_distrib(capsys):
assert idist.backend() is None
if torch.cuda.is_available():
Expand Down
9 changes: 6 additions & 3 deletions tests/ignite/engine/test_create_supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from torch.optim import SGD

import ignite.distributed as idist
from ignite.distributed.comp_models.base import _torch_version_le_112
from ignite.engine import (
_check_arg,
create_supervised_evaluator,
Expand Down Expand Up @@ -307,7 +308,9 @@ def _test_create_supervised_evaluator(
else:
if Version(torch.__version__) >= Version("1.7.0"):
# This is broken in 1.6.0 but will be probably fixed with 1.7.0
with pytest.raises(RuntimeError, match=r"Expected all tensors to be on the same device"):
err_msg_1 = "Expected all tensors to be on the same device"
err_msg_2 = "Placeholder storage has not been allocated on MPS device"
with pytest.raises(RuntimeError, match=f"({err_msg_1}|{err_msg_2})"):
evaluator.run(data)


Expand Down Expand Up @@ -659,14 +662,14 @@ def test_create_supervised_evaluator_on_cuda_with_model_on_cpu():
_test_mocked_supervised_evaluator(evaluator_device="cuda")


@pytest.mark.skipif(not torch.backends.mps.is_available(), reason="Skip if no MPS Backend")
@pytest.mark.skipif(not (_torch_version_le_112 and torch.backends.mps.is_available()), reason="Skip if no MPS")
def test_create_supervised_evaluator_on_mps():
model_device = evaluator_device = "mps"
_test_create_supervised_evaluator(model_device=model_device, evaluator_device=evaluator_device)
_test_mocked_supervised_evaluator(model_device=model_device, evaluator_device=evaluator_device)


@pytest.mark.skipif(not torch.backends.mps.is_available(), reason="Skip if no MPS Backend")
@pytest.mark.skipif(not (_torch_version_le_112 and torch.backends.mps.is_available()), reason="Skip if no MPS")
def test_create_supervised_evaluator_on_mps_with_model_on_cpu():
_test_create_supervised_evaluator(evaluator_device="mps")
_test_mocked_supervised_evaluator(evaluator_device="mps")
Expand Down

0 comments on commit d1dfbec

Please sign in to comment.