From fcbd65d8cd48488a0ea998f86a93e9b02b263aca Mon Sep 17 00:00:00 2001 From: Archermmt Date: Wed, 3 Jan 2024 08:18:15 +0800 Subject: [PATCH] change requires_gpu to requires_cuda --- tests/python/contrib/test_msc/test_runner.py | 8 ++++---- tests/python/contrib/test_msc/test_tools.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/python/contrib/test_msc/test_runner.py b/tests/python/contrib/test_msc/test_runner.py index f9133e9c1534e..e3d5bcf245032 100644 --- a/tests/python/contrib/test_msc/test_runner.py +++ b/tests/python/contrib/test_msc/test_runner.py @@ -109,8 +109,8 @@ def test_tvm_runner_cpu(): _test_from_torch(TVMRunner, "cpu", is_training=True) -@tvm.testing.requires_gpu -def test_tvm_runner_gpu(): +@tvm.testing.requires_cuda +def test_tvm_runner_cuda(): """Test runner for tvm on cuda""" _test_from_torch(TVMRunner, "cuda", is_training=True) @@ -122,8 +122,8 @@ def test_torch_runner_cpu(): _test_from_torch(TorchRunner, "cpu") -@tvm.testing.requires_gpu -def test_torch_runner_gpu(): +@tvm.testing.requires_cuda +def test_torch_runner_cuda(): """Test runner for torch on cuda""" _test_from_torch(TorchRunner, "cuda", atol=1e-1, rtol=1e-1) diff --git a/tests/python/contrib/test_msc/test_tools.py b/tests/python/contrib/test_msc/test_tools.py index 9216761bcb256..8fa9e5cf10ccc 100644 --- a/tests/python/contrib/test_msc/test_tools.py +++ b/tests/python/contrib/test_msc/test_tools.py @@ -251,7 +251,7 @@ def test_tvm_tool(tool_type): ) -@tvm.testing.requires_gpu +@tvm.testing.requires_cuda @pytest.mark.parametrize("tool_type", [ToolType.PRUNER, ToolType.QUANTIZER]) def test_tvm_distill(tool_type): """Test tools for tvm with distiller"""