From 78b311f6e2efb6342fd5d991e74f8e220b036f61 Mon Sep 17 00:00:00 2001
From: Keren Zhou
Date: Wed, 8 Mar 2023 09:25:00 -0800
Subject: [PATCH] [FRONTEND] Fix cast when both `src_ty` and `dst_ty` are of
 block_type (#1301)

Commonly used in atomic_rmw ops
---
 python/test/unit/language/test_core.py | 16 ++++++++++++++++
 python/triton/language/semantic.py     |  4 ++--
 2 files changed, 18 insertions(+), 2 deletions(-)

diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index 4f81f481eec5..cbcc15b7143b 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -683,6 +683,22 @@ def kernel(Z, X, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr)
     np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
 
 
+def test_tensor_atomic_rmw_block(device="cuda"):
+    shape = (8, 8)
+
+    @triton.jit
+    def kernel(X, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
+        off0 = tl.arange(0, SHAPE0)
+        off1 = tl.arange(0, SHAPE1)
+        offs = off0[:, None] * SHAPE1 + off1[None, :]
+        val = offs.to(tl.float32)
+        x = X + offs
+        tl.atomic_min(x, val)
+    x = torch.ones((8, 8), device=device, dtype=torch.float32)
+    kernel[(2,)](x, shape[0], shape[1])
+    assert torch.min(x).item() == 0.0
+
+
 def test_atomic_cas():
     # 1. make sure that atomic_cas changes the original value (Lock)
     @triton.jit
diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py
index b908b21df478..f26e633bbb99 100644
--- a/python/triton/language/semantic.py
+++ b/python/triton/language/semantic.py
@@ -641,7 +641,7 @@ def bitcast(input: tl.tensor,
             builder: ir.builder) -> tl.tensor:
     src_ty = input.type
     if src_ty.is_block():
-        dst_ty = tl.block_type(dst_ty, input.type.get_block_shapes())
+        dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes())
     if src_ty == dst_ty:
         return input
     src_sca_ty = src_ty.scalar
@@ -665,7 +665,7 @@ def cast(input: tl.tensor,
     if isinstance(dst_ty, tl.constexpr):
         dst_ty = dst_ty.value
     if src_ty.is_block():
-        dst_ty = tl.block_type(dst_ty, input.type.get_block_shapes())
+        dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes())
     if src_ty == dst_ty:
         return input
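
Why the fix takes `dst_ty.scalar`: when `cast` (or `bitcast`) receives a `dst_ty` that is already a `block_type` — which, per the commit message, commonly happens in atomic_rmw ops — the old code wrapped it in `tl.block_type` again, producing a nested block type that never compares equal to `src_ty`. The `src_ty == dst_ty` no-op fast path therefore never fired, and the cast fell through to the scalar conversion logic with a malformed type. Below is a minimal sketch of that failure mode in plain Python; `ScalarType` and `BlockType` are hypothetical stand-ins for Triton's `tl.dtype` and `tl.block_type`, not the real classes.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class ScalarType:
    """Stand-in for tl.dtype (e.g. fp32)."""
    name: str

    @property
    def scalar(self) -> "ScalarType":
        # Mirrors tl.dtype.scalar: a scalar type is its own scalar.
        return self


@dataclass(frozen=True)
class BlockType:
    """Stand-in for tl.block_type: an element type plus a block shape."""
    element: ScalarType
    shape: Tuple[int, ...]

    @property
    def scalar(self) -> ScalarType:
        # Mirrors tl.block_type.scalar: unwrap to the element type.
        return self.element


def wrap_dst_before_fix(dst_ty, shape):
    # Old behavior: re-wrap dst_ty unconditionally. If dst_ty is already
    # a block type, this nests blocks and equality with src_ty never holds.
    return BlockType(dst_ty, shape)


def wrap_dst_after_fix(dst_ty, shape):
    # Fixed behavior: unwrap to the scalar element first. This is
    # idempotent whether dst_ty is a scalar or already a block type.
    return BlockType(dst_ty.scalar, shape)


f32 = ScalarType("fp32")
src_ty = BlockType(f32, (8, 8))

# dst_ty arriving as a block type, as in the atomic_rmw path:
print(wrap_dst_before_fix(src_ty, (8, 8)) == src_ty)  # False -> no-op fast path skipped
print(wrap_dst_after_fix(src_ty, (8, 8)) == src_ty)   # True  -> cast short-circuits
```

Taking `.scalar` makes the normalization idempotent, so callers may pass either a scalar dtype or a full block type to `cast`/`bitcast`; the new `test_tensor_atomic_rmw_block` exercises this through `tl.atomic_min` on a 2D block.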