Skip to content

Commit

Permalink
Merge pull request triton-lang#19 from ROCmSoftwarePlatform/unskip_te…
Browse files Browse the repository at this point in the history
…st_reduce

reduce the skips for test_reduce functions
  • Loading branch information
rsanthanam-amd authored Nov 1, 2022
2 parents f3bcbcf + dfad6bd commit cc6b518
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,7 +892,8 @@ def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
def test_reduce1d(op, dtype_str, shape, device='cuda'):
check_type_supported(dtype_str) # bfloat16 on cc < 80 will not be tested
if torch.version.hip is not None:
pytest.skip(f"test_reduce1d currently has segfaults on ROCM")
if dtype_str in ["int8", "int16", "uint8", "uint16"]:
pytest.skip(f"test_reduce1d[{dtype_str}] skipped on ROCM")

# triton kernel
@triton.jit
Expand Down Expand Up @@ -953,7 +954,8 @@ def kernel(X, Z, BLOCK: tl.constexpr):
def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
check_type_supported(dtype_str) # bfloat16 on cc < 80 will not be tested
if torch.version.hip is not None:
pytest.skip(f"test_reduce2d currently has segfaults on ROCM")
if dtype_str in ["int8", "int16", "uint8", "uint16"]:
pytest.skip(f"test_reduce2d[{dtype_str}] skipped on ROCM")
# triton kernel
@triton.jit
def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
Expand Down

0 comments on commit cc6b518

Please sign in to comment.