From 64f980ad69fc18e6ccbd6d1ed3ca3034716bee64 Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Tue, 1 Nov 2022 15:00:12 +0000 Subject: [PATCH] reduce the skips for test_reduce functions --- python/test/unit/language/test_core.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 00241403c407..cad10c871a3a 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -892,7 +892,8 @@ def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): def test_reduce1d(op, dtype_str, shape, device='cuda'): check_type_supported(dtype_str) # bfloat16 on cc < 80 will not be tested if torch.version.hip is not None: - pytest.skip(f"test_reduce1d currently has segfaults on ROCM") + if dtype_str in ["int8", "int16", "uint8", "uint16"]: + pytest.skip(f"test_reduce1d[{dtype_str}] skipped on ROCM") # triton kernel @triton.jit @@ -953,7 +954,8 @@ def kernel(X, Z, BLOCK: tl.constexpr): def test_reduce2d(op, dtype_str, shape, axis, device='cuda'): check_type_supported(dtype_str) # bfloat16 on cc < 80 will not be tested if torch.version.hip is not None: - pytest.skip(f"test_reduce2d currently has segfaults on ROCM") + if dtype_str in ["int8", "int16", "uint8", "uint16"]: + pytest.skip(f"test_reduce2d[{dtype_str}] skipped on ROCM") # triton kernel @triton.jit def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):