Skip to content

Commit

Permalink
[FRONTEND] Fix cast when both src_ty and dst_ty are of block_type (
Browse files Browse the repository at this point in the history
…#1301)

Commonly used in atomic_rmw ops
  • Loading branch information
Jokeren committed Mar 8, 2023
1 parent f5c9f9b commit 78b311f
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
16 changes: 16 additions & 0 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,22 @@ def kernel(Z, X, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr)
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)


def test_tensor_atomic_rmw_block(device="cuda"):
shape = (8, 8)

@triton.jit
def kernel(X, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
off0 = tl.arange(0, SHAPE0)
off1 = tl.arange(0, SHAPE1)
offs = off0[:, None] * SHAPE1 + off1[None, :]
val = offs.to(tl.float32)
x = X + offs
tl.atomic_min(x, val)
x = torch.ones((8, 8), device=device, dtype=torch.float32)
kernel[(2,)](x, shape[0], shape[1])
assert torch.min(x).item() == 0.0


def test_atomic_cas():
# 1. make sure that atomic_cas changes the original value (Lock)
@triton.jit
Expand Down
4 changes: 2 additions & 2 deletions python/triton/language/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,7 @@ def bitcast(input: tl.tensor,
builder: ir.builder) -> tl.tensor:
src_ty = input.type
if src_ty.is_block():
dst_ty = tl.block_type(dst_ty, input.type.get_block_shapes())
dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes())
if src_ty == dst_ty:
return input
src_sca_ty = src_ty.scalar
Expand All @@ -665,7 +665,7 @@ def cast(input: tl.tensor,
if isinstance(dst_ty, tl.constexpr):
dst_ty = dst_ty.value
if src_ty.is_block():
dst_ty = tl.block_type(dst_ty, input.type.get_block_shapes())
dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes())
if src_ty == dst_ty:
return input

Expand Down

0 comments on commit 78b311f

Please sign in to comment.