diff --git a/tests/ignite/__init__.py b/tests/ignite/__init__.py index d553c222e58b..8f84e2e74b99 100644 --- a/tests/ignite/__init__.py +++ b/tests/ignite/__init__.py @@ -3,3 +3,14 @@ def cpu_and_maybe_cuda(): return ("cpu",) + (("cuda",) if torch.cuda.is_available() else ()) + + +def is_mps_available_and_functional(): + if not torch.backends.mps.is_available(): + return False + try: + # Try to allocate a small tensor on the MPS device + torch.tensor([1.0], device="mps") + return True + except RuntimeError: + return False diff --git a/tests/ignite/distributed/utils/test_serial.py b/tests/ignite/distributed/utils/test_serial.py index d61fa8b75695..6aef0de221e4 100644 --- a/tests/ignite/distributed/utils/test_serial.py +++ b/tests/ignite/distributed/utils/test_serial.py @@ -2,6 +2,7 @@ import ignite.distributed as idist from ignite.distributed.comp_models.base import _torch_version_le_112 +from tests.ignite import is_mps_available_and_functional from tests.ignite.distributed.utils import ( _sanity_check, _test_distrib__get_max_length, @@ -12,7 +13,6 @@ _test_distrib_new_group, _test_sync, ) -from ....utils_for_tests import is_mps_available_and_functional def test_no_distrib(capsys): diff --git a/tests/ignite/engine/test_create_supervised.py b/tests/ignite/engine/test_create_supervised.py index 2ecf7438b458..d9b0c161f75d 100644 --- a/tests/ignite/engine/test_create_supervised.py +++ b/tests/ignite/engine/test_create_supervised.py @@ -25,7 +25,7 @@ ) from ignite.metrics import MeanSquaredError -from ...utils_for_tests import is_mps_available_and_functional # type: ignore +from tests.ignite import is_mps_available_and_functional class DummyModel(torch.nn.Module): diff --git a/tests/utils_for_tests.py b/tests/utils_for_tests.py deleted file mode 100644 index 3076d080bfb9..000000000000 --- a/tests/utils_for_tests.py +++ /dev/null @@ -1,12 +0,0 @@ -import torch - - -def is_mps_available_and_functional(): - if not torch.backends.mps.is_available(): - return False - try: - # Try to allocate a small tensor on the MPS device - torch.tensor([1.0], device="mps") - return True - except RuntimeError: - return False