diff --git a/source/tests/consistent/io/test_io.py b/source/tests/consistent/io/test_io.py index feafde234d..df81c24ff5 100644 --- a/source/tests/consistent/io/test_io.py +++ b/source/tests/consistent/io/test_io.py @@ -21,6 +21,11 @@ DeepEval, ) +from ...utils import ( + CI, + TEST_DEVICE, +) + infer_path = Path(__file__).parent.parent.parent / "infer" @@ -66,6 +71,7 @@ def tearDown(self): elif Path(ii).is_dir(): shutil.rmtree(ii) + @unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.") def test_data_equal(self): prefix = "test_consistent_io_" + self.__class__.__name__.lower() for backend_name in ("tensorflow", "pytorch", "dpmodel", "jax"):