
Commit 562e31d
Merge pull request #18 from ROCmSoftwarePlatform/fix_test_dot
Fix 6/7 test dot
rsanthanam-amd authored Nov 1, 2022
2 parents f40e891 + 5f07ef2 commit 562e31d
Showing 3 changed files with 32 additions and 18 deletions.
7 changes: 4 additions & 3 deletions lib/codegen/selection/generator.cc
@@ -2973,11 +2973,11 @@ void generator::visit_dot_inst(ir::dot_inst* dot) {
   size_t red_axis = 1;
   unsigned NK = A_shapes[red_axis];
   bool is_outer = NK == 1;
+
 #ifdef USE_ROCM
-  bool is_mma = layouts_->get(dot)->to_mma();
+  return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add);
 #else
-  bool is_mma = false;
-#endif
+  bool is_mma = layouts_->get(dot)->to_mma();
   if(!is_outer && is_mma && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80)
     return visit_mma884(dot, A, B, D, NK);
   if(!is_outer && is_mma && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80)
@@ -2986,6 +2986,7 @@ void generator::visit_dot_inst(ir::dot_inst* dot) {
       A->get_type()->get_scalar_ty()->is_fp32_ty())
     return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add);
   throw std::runtime_error("dot has invalid operand type");
+#endif
 }
 
 void generator::visit_trans_inst(ir::trans_inst* trans) {
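
The net effect of this hunk: ROCm builds bypass the NVIDIA mma paths entirely and lower every dot through the FMA code path. Below is a minimal Python sketch of the resulting dispatch, illustrative only; `use_rocm`, `sm`, `is_mma`, and `is_outer` are stand-ins for the values that `visit_dot_inst` queries.

```python
def select_dot_lowering(use_rocm: bool, is_outer: bool, is_mma: bool, sm):
    # ROCm: no tensor-core (mma) lowering in this fork; always use FMA.
    if use_rocm:
        return "fmadot"
    # NVIDIA: pick the mma variant by compute capability.
    if not is_outer and is_mma and sm is not None and sm < 80:
        return "mma884"    # visit_mma884 path
    if not is_outer and is_mma and sm is not None and sm >= 80:
        return "mma_sm80"  # sm >= 80 mma path (body elided in the hunk above)
    # fp32 x fp32 falls back to FMA; anything else raises in the real code.
    return "fmadot"
```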
36 changes: 21 additions & 15 deletions python/test/unit/language/test_core.py
@@ -1070,12 +1070,15 @@ def kernel(X, stride_xm, stride_xn,
                          for dtype in ['float16']
                          if not (allow_tf32 and (dtype in ['float16']))])
 def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
-    cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
-    if cc < 80:
-        if dtype == 'int8':
-            pytest.skip("Only test int8 on devices with sm >= 80")
-        elif dtype == 'float32' and allow_tf32:
-            pytest.skip("Only test tf32 on devices with sm >= 80")
+    if torch.version.hip is not None:
+        pass
+    else:
+        cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
+        if cc < 80:
+            if dtype == 'int8':
+                pytest.skip("Only test int8 on devices with sm >= 80")
+            elif dtype == 'float32' and allow_tf32:
+                pytest.skip("Only test tf32 on devices with sm >= 80")
 
     M, N, K = 128, 128, 64
     num_warps = 8
@@ -1170,15 +1173,18 @@ def kernel(X, stride_xm, stride_xk,
     # print(z_ref[:,0], z_tri[:,0])
     np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
     # make sure ld/st are vectorized
-    ptx = pgm.asm['ptx']
-    assert 'ld.global.v4' in ptx
-    assert 'st.global.v4' in ptx
-    if allow_tf32:
-        assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
-    elif dtype == 'float32':
-        assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
-    elif dtype == 'int8':
-        assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
+    if torch.version.hip is not None:
+        pass
+    else:
+        ptx = pgm.asm['ptx']
+        assert 'ld.global.v4' in ptx
+        assert 'st.global.v4' in ptx
+        if allow_tf32:
+            assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
+        elif dtype == 'float32':
+            assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
+        elif dtype == 'int8':
+            assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
 
 
 def test_dot_without_load():
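
On HIP, the new guards skip both the compute-capability gate and the PTX assertions, since there is no PTX to inspect on AMD hardware. The same intent could also be written as a pytest `skipif` marker; the sketch below is not part of this commit and assumes `torch.version.hip` as the backend probe, matching the rest of this PR.

```python
import pytest
import torch

# torch.version.hip is a version string on ROCm builds and None on CUDA builds.
IS_HIP = torch.version.hip is not None

only_cuda_ptx = pytest.mark.skipif(
    IS_HIP, reason="PTX vectorization checks only apply to NVIDIA backends")

@only_cuda_ptx
def test_ptx_checks_example():
    # Placeholder: on CUDA builds the 'ld.global.v4' / 'st.global.v4'
    # assertions from test_dot would run here.
    assert True
```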
7 changes: 7 additions & 0 deletions python/triton/language/semantic.py
@@ -6,6 +6,8 @@
 from triton._C.libtriton.triton import ir
 
 
+import torch
+
 # Create custom exception that prints message "hello"
 class IncompatibleTypeErrorimpl(Exception):
     def __init__(self, type_a, type_b):
@@ -969,6 +971,11 @@ def dot(a: tl.tensor,
         trans_b: bool,
         allow_tf32: bool,
         builder: ir.builder) -> tl.tensor:
+
+    if torch.version.hip is not None:
+        a = cast(a, tl.float32, builder)
+        b = cast(b, tl.float32, builder)
+
     in_a = 1 if not trans_a else 0
     in_b = 1 if trans_b else 0
     assert a.type.is_block() and b.type.is_block()
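
With this cast, every `tl.dot` on HIP receives fp32 operands, which matches the FMA lowering selected in generator.cc. A small torch sketch of the same promotion (assumed shapes, illustrative only):

```python
import torch

# fp16 operands, as a Triton kernel on HIP would receive them
a = torch.randn(128, 64, dtype=torch.float16)
b = torch.randn(64, 128, dtype=torch.float16)

# What the added cast does: widen to fp32 first, so the dot can be
# computed with plain fused multiply-adds instead of mma instructions.
c = a.float() @ b.float()
print(c.dtype)  # torch.float32
```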
