From bf23fc50f0ffa99e875d9247ca66acec0c36677f Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 17 May 2022 12:23:30 +0900 Subject: [PATCH] mma intrin generation with meta programming --- python/tvm/tir/tensor_intrin/cuda.py | 163 +++++++++++++++++- .../unittest/test_mma_16x8x16_4k_tune.py | 82 +-------- .../test_mma_16x8x16_4k_tune_trans.py | 84 +-------- .../unittest/test_mma_16x8x16_fp16_4k_tune.py | 82 +-------- .../test_mma_16x8x16_fp16_4k_tune_trans.py | 82 +-------- .../unittest/test_mma_16x8x32_4k_tune.py | 83 +-------- .../test_mma_16x8x32_4k_tune_trans.py | 83 +-------- .../unittest/test_mma_16x8x8_4k_tune.py | 18 +- 8 files changed, 188 insertions(+), 489 deletions(-) diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py index ec0511c8ed5f..d1e67cae1192 100644 --- a/python/tvm/tir/tensor_intrin/cuda.py +++ b/python/tvm/tir/tensor_intrin/cuda.py @@ -16,6 +16,7 @@ # under the License. # pylint: disable=invalid-name,missing-function-docstring """Intrinsics for tensorization on NVIDIA GPU.""" +from .. import Cast from ..._ffi import register_func from ...runtime import convert from .. import TensorIntrin @@ -46,6 +47,7 @@ def index_map_shared_16x16_to_ldmatrix_32x8_layout(i, j): lift = convert M_DIM = 16 +N_DIM = 16 WARP_SIZE = 32 HALF_WARP = WARP_SIZE // 2 HALF_WARP_expr = lift(HALF_WARP) @@ -81,7 +83,6 @@ def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed): assert dtype == "int8" if ldmatrix_col_major: - print("foo") index_map = shared_32x16_to_ldmatrix_32x16_layout shared_offset = ( lambda _, stride: stride @@ -172,6 +173,148 @@ def ldmatrix_impl(warp_handle: T.handle, shared_handle: T.handle) -> None: return ldmatrix_desc, ldmatrix_impl +def get_mma_intrin(k_dim, out_dtype, transposed): + local_size = (M_DIM * k_dim) // WARP_SIZE + local_size_out = (M_DIM * N_DIM) // 32 + + index_map_C = shared_16x16_to_ldmatrix_32x8_layout + + if k_dim == 16: + index_map_A = shared_16x16_to_ldmatrix_32x8_layout + index_map_B = shared_16x16_to_ldmatrix_32x8_layout + mma_prefix = "m16n8k16" + elif k_dim == 32 and transposed: + index_map_A = index_map_B = shared_16x32_to_ldmatrix_32x16_layout + mma_prefix = "m16n8k32" + elif k_dim == 32 and not transposed: + index_map_A = shared_16x32_to_ldmatrix_32x16_layout + index_map_B = shared_32x16_to_ldmatrix_32x16_layout + mma_prefix = "m16n8k32" + else: + assert False + + out_dtype_abbrv = {"float16": "fp16", "float32": "fp32", "int32": "int32"}[out_dtype] + + if out_dtype in ["float16", "float32"]: + in_dtype = "float16" + in_dtype_abbrv = "fp16" + else: + in_dtype = "int8" + in_dtype_abbrv = "int8" + + def maybe_cast(v): + if out_dtype in ["float32", "int32"]: + return Cast(out_dtype, v) + return v + + def maybe_swap(i, j): + if transposed: + return j, i + return i, j + + @T.prim_func + def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer( + a, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp" + ) + B = T.match_buffer( + b, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp" + ) + C = T.match_buffer( + c, (WARP_SIZE, local_size_out), out_dtype, align=128, offset_factor=16, scope="warp" + ) + + with T.block("root"): + T.reads( + C[0:WARP_SIZE, 0:local_size_out], + A[0:WARP_SIZE, 0:local_size], + B[0:WARP_SIZE, 0:local_size], + ) + T.writes(C[0:WARP_SIZE, 0:local_size_out]) + + for i, j, k in T.grid(M_DIM, N_DIM, k_dim): + with T.block("C"): + i, j, k = T.axis.remap("SSR", [i, j, k]) + b_row_ind, b_col_ind = maybe_swap(k, j) + + thread_id_C, local_id_C = index_map_C(i, j) + thread_id_A, local_id_A = index_map_A(i, k) + thread_id_B, local_id_B = index_map_B(b_row_ind, b_col_ind) + + T.reads( + C[thread_id_C, local_id_C], + A[thread_id_A, local_id_A], + B[thread_id_B, local_id_B], + ) + T.writes(C[thread_id_C, local_id_C]) + + C[thread_id_C, local_id_C] += maybe_cast( + A[thread_id_A, local_id_A] + ) * maybe_cast(B[thread_id_B, local_id_B]) + + @T.prim_func + def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer( + a, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp" + ) + B = T.match_buffer( + b, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp" + ) + C = T.match_buffer( + c, (WARP_SIZE, local_size_out), out_dtype, align=128, offset_factor=16, scope="warp" + ) + + with T.block("root"): + T.reads( + C[0:WARP_SIZE, 0:local_size_out], + A[0:WARP_SIZE, 0:local_size], + B[0:WARP_SIZE, 0:local_size], + ) + T.writes(C[0:WARP_SIZE, 0:local_size_out]) + tx = T.env_thread("threadIdx.x") + T.launch_thread(tx, WARP_SIZE) + + T.evaluate( + T.ptx_mma( + mma_prefix, + "row", + "col", + in_dtype_abbrv, + in_dtype_abbrv, + out_dtype_abbrv, + A.data, + A.elem_offset + tx * lift(local_size), + B.data, + B.elem_offset + tx * lift(local_size), + C.data, + C.elem_offset + tx * lift(local_size_out), + False, + dtype=out_dtype, + ) + ) + + T.evaluate( + T.ptx_mma( + mma_prefix, + "row", + "col", + in_dtype_abbrv, + in_dtype_abbrv, + out_dtype_abbrv, + A.data, + A.elem_offset + tx * lift(local_size), + B.data, + B.elem_offset + tx * lift(local_size) + lift(local_size) // 2, + C.data, + C.elem_offset + tx * lift(local_size_out) + lift(local_size_out) // 2, + False, + dtype=out_dtype, + ) + ) + + return mma_sync_desc, mma_sync_impl + + LDMATRIX_16x16_A_INTRIN = "mma.ldmatrix_16x16_a" TensorIntrin.register(LDMATRIX_16x16_A_INTRIN, *get_ldmatrix_intrin(16, "float16", False, False)) @@ -191,3 +334,21 @@ def ldmatrix_impl(warp_handle: T.handle, shared_handle: T.handle) -> None: LDMATRIX_16x32_B_TRANS_INTRIN = "mma.ldmatrix_16x32_b_trans" TensorIntrin.register(LDMATRIX_16x32_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "int8", True, True)) + +MMA_f16f16f32_INTRIN = "mma_f16f16f32" +TensorIntrin.register(MMA_f16f16f32_INTRIN, *get_mma_intrin(16, "float32", False)) + +MMA_f16f16f32_TRANS_INTRIN = "mma_f16f16f32_trans" +TensorIntrin.register(MMA_f16f16f32_TRANS_INTRIN, *get_mma_intrin(16, "float32", True)) + +MMA_f16f16f16_INTRIN = "mma_f16f16f16" +TensorIntrin.register(MMA_f16f16f16_INTRIN, *get_mma_intrin(16, "float16", False)) + +MMA_f16f16f16_TRANS_INTRIN = "mma_f16f16f16_trans" +TensorIntrin.register(MMA_f16f16f16_TRANS_INTRIN, *get_mma_intrin(16, "float16", True)) + +MMA_i8i8i32_INTRIN = "mma_i8i8i32" +TensorIntrin.register(MMA_i8i8i32_INTRIN, *get_mma_intrin(32, "int32", False)) + +MMA_i8i8i32_TRANS_INTRIN = "mma_i8i8i32_trans" +TensorIntrin.register(MMA_i8i8i32_TRANS_INTRIN, *get_mma_intrin(32, "int32", True)) diff --git a/tests/python/unittest/test_mma_16x8x16_4k_tune.py b/tests/python/unittest/test_mma_16x8x16_4k_tune.py index 4378c46d14b3..043ab4a345e5 100644 --- a/tests/python/unittest/test_mma_16x8x16_4k_tune.py +++ b/tests/python/unittest/test_mma_16x8x16_4k_tune.py @@ -7,90 +7,13 @@ from tvm.tir.tensor_intrin.cuda import ( LDMATRIX_16x16_A_INTRIN, LDMATRIX_16x16_B_INTRIN, + MMA_f16f16f32_INTRIN, shared_16x16_to_ldmatrix_32x8_layout, ) import tvm.testing import numpy as np -@T.prim_func -def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (32, 8), "float16", align=128, offset_factor=16, scope="warp") - B = T.match_buffer(b, (32, 8), "float16", align=128, offset_factor=16, scope="warp") - C = T.match_buffer(c, (32, 8), "float32", align=128, offset_factor=16, scope="warp") - - with T.block("root"): - T.reads(C[0:32, 0:8], A[0:32, 0:8], B[0:32, 0:8]) - T.writes(C[0:32, 0:8]) - for i, j, k in T.grid(16, 16, 16): - with T.block("C"): - i, j, k = T.axis.remap("SSR", [i, j, k]) - thread_id_C, local_id_C = shared_16x16_to_ldmatrix_32x8_layout(i, j) - thread_id_A, local_id_A = shared_16x16_to_ldmatrix_32x8_layout(i, k) - thread_id_B, local_id_B = shared_16x16_to_ldmatrix_32x8_layout(k, j) - - T.reads( - C[thread_id_C, local_id_C], - A[thread_id_A, local_id_A], - B[thread_id_B, local_id_B], - ) - T.writes(C[thread_id_C, local_id_C]) - C[thread_id_C, local_id_C] += T.cast( - A[thread_id_A, local_id_A], "float32" - ) * T.cast(B[thread_id_B, local_id_B], "float32") - - -@T.prim_func -def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (32, 8), "float16", align=128, offset_factor=16, scope="warp") - B = T.match_buffer(b, (32, 8), "float16", align=128, offset_factor=16, scope="warp") - C = T.match_buffer(c, (32, 8), "float32", align=128, offset_factor=16, scope="warp") - - with T.block("root"): - T.reads(C[0:32, 0:8], A[0:32, 0:8], B[0:32, 0:8]) - T.writes(C[0:32, 0:8]) - tx = T.env_thread("threadIdx.x") - T.launch_thread(tx, 32) - - T.evaluate( - T.ptx_mma( - "m16n8k16", - "row", - "col", - "fp16", - "fp16", - "fp32", - A.data, - A.elem_offset + tx * 8, - B.data, - B.elem_offset + tx * 8, - C.data, - C.elem_offset + tx * 8, - False, - dtype="float32", - ) - ) - - T.evaluate( - T.ptx_mma( - "m16n8k16", - "row", - "col", - "fp16", - "fp16", - "fp32", - A.data, - A.elem_offset + tx * 8, - B.data, - B.elem_offset + tx * 8 + 4, - C.data, - C.elem_offset + tx * 8 + 4, - False, - dtype="float32", - ) - ) - - @T.prim_func def mma_store_desc(a: T.handle, c: T.handle) -> None: C_warp = T.match_buffer(a, [32, 8], dtype="float32", scope="warp") @@ -160,7 +83,6 @@ def mma_fill_impl(a: T.handle) -> None: T.evaluate(T.mma_fill(8, C_warp.data, C_warp.elem_offset, dtype="float32")) -tir.TensorIntrin.register("mma_sync", mma_sync_desc, mma_sync_impl) tir.TensorIntrin.register("mma_store", mma_store_desc, mma_store_impl) tir.TensorIntrin.register("mma_fill", mma_fill_desc, mma_fill_impl) @@ -291,7 +213,7 @@ def index_map(i, j): sch.tensorize(loop_a, LDMATRIX_16x16_A_INTRIN) sch.tensorize(loop_b, LDMATRIX_16x16_B_INTRIN) - sch.tensorize(sch.get_loops(block_inner)[-3], "mma_sync") + sch.tensorize(sch.get_loops(block_inner)[-3], MMA_f16f16f32_INTRIN) sch.tensorize(sch.get_loops(block_init_c)[-2], "mma_fill") sch.tensorize(sch.get_loops(C_warp)[-2], "mma_store") diff --git a/tests/python/unittest/test_mma_16x8x16_4k_tune_trans.py b/tests/python/unittest/test_mma_16x8x16_4k_tune_trans.py index 097ac1d55a71..cc6032846825 100644 --- a/tests/python/unittest/test_mma_16x8x16_4k_tune_trans.py +++ b/tests/python/unittest/test_mma_16x8x16_4k_tune_trans.py @@ -6,6 +6,7 @@ from tvm.tir.tensor_intrin.cuda import ( LDMATRIX_16x16_A_INTRIN, LDMATRIX_16x16_B_TRANS_INTRIN, + MMA_f16f16f32_TRANS_INTRIN, shared_16x16_to_ldmatrix_32x8_layout, ) from tvm import meta_schedule as ms @@ -13,85 +14,6 @@ import numpy as np -@T.prim_func -def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (32, 8), "float16", align=128, offset_factor=16, scope="warp") - B = T.match_buffer(b, (32, 8), "float16", align=128, offset_factor=16, scope="warp") - C = T.match_buffer(c, (32, 8), "float32", align=128, offset_factor=16, scope="warp") - - with T.block("root"): - T.reads(C[0:32, 0:8], A[0:32, 0:8], B[0:32, 0:8]) - T.writes(C[0:32, 0:8]) - for i, j, k in T.grid(16, 16, 16): - with T.block("C"): - i, j, k = T.axis.remap("SSR", [i, j, k]) - - thread_id_C, local_id_C = shared_16x16_to_ldmatrix_32x8_layout(i, j) - thread_id_A, local_id_A = shared_16x16_to_ldmatrix_32x8_layout(i, k) - thread_id_B, local_id_B = shared_16x16_to_ldmatrix_32x8_layout(j, k) - - T.reads( - C[thread_id_C, local_id_C], - A[thread_id_A, local_id_A], - B[thread_id_B, local_id_B], - ) - T.writes(C[thread_id_C, local_id_C]) - C[thread_id_C, local_id_C] += T.cast( - A[thread_id_A, local_id_A], "float32" - ) * T.cast(B[thread_id_B, local_id_B], "float32") - - -@T.prim_func -def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (32, 8), "float16", align=128, offset_factor=16, scope="warp") - B = T.match_buffer(b, (32, 8), "float16", align=128, offset_factor=16, scope="warp") - C = T.match_buffer(c, (32, 8), "float32", align=128, offset_factor=16, scope="warp") - - with T.block("root"): - T.reads(C[0:32, 0:8], A[0:32, 0:8], B[0:32, 0:8]) - T.writes(C[0:32, 0:8]) - tx = T.env_thread("threadIdx.x") - T.launch_thread(tx, 32) - - T.evaluate( - T.ptx_mma( - "m16n8k16", - "row", - "col", - "fp16", - "fp16", - "fp32", - A.data, - A.elem_offset + tx * 8, - B.data, - B.elem_offset + tx * 8, - C.data, - C.elem_offset + tx * 8, - False, - dtype="float32", - ) - ) - - T.evaluate( - T.ptx_mma( - "m16n8k16", - "row", - "col", - "fp16", - "fp16", - "fp32", - A.data, - A.elem_offset + tx * 8, - B.data, - B.elem_offset + tx * 8 + 4, - C.data, - C.elem_offset + tx * 8 + 4, - False, - dtype="float32", - ) - ) - - @T.prim_func def mma_store_desc(a: T.handle, c: T.handle) -> None: C_warp = T.match_buffer(a, [32, 8], dtype="float32", scope="warp") @@ -161,7 +83,6 @@ def mma_fill_impl(a: T.handle) -> None: T.evaluate(T.mma_fill(8, C_warp.data, C_warp.elem_offset, dtype="float32")) -tir.TensorIntrin.register("mma_sync", mma_sync_desc, mma_sync_impl) tir.TensorIntrin.register("mma_store", mma_store_desc, mma_store_impl) tir.TensorIntrin.register("mma_fill", mma_fill_desc, mma_fill_impl) @@ -292,7 +213,6 @@ def tile_wmma_fragment(block_read, height): sch.reorder(i0, j0, i1, j1) return i1 - loop_a = tile_wmma_fragment(A_warp, 16) loop_b = tile_wmma_fragment(B_warp, 16) @@ -309,8 +229,8 @@ def index_map(i, j): sch.tensorize(loop_a, LDMATRIX_16x16_A_INTRIN) sch.tensorize(loop_b, LDMATRIX_16x16_B_TRANS_INTRIN) + sch.tensorize(sch.get_loops(block_inner)[-3], MMA_f16f16f32_TRANS_INTRIN) - sch.tensorize(sch.get_loops(block_inner)[-3], "mma_sync") sch.tensorize(sch.get_loops(block_init_c)[-2], "mma_fill") sch.tensorize(sch.get_loops(C_warp)[-2], "mma_store") diff --git a/tests/python/unittest/test_mma_16x8x16_fp16_4k_tune.py b/tests/python/unittest/test_mma_16x8x16_fp16_4k_tune.py index bd56eacda249..f0f59c0f5209 100644 --- a/tests/python/unittest/test_mma_16x8x16_fp16_4k_tune.py +++ b/tests/python/unittest/test_mma_16x8x16_fp16_4k_tune.py @@ -6,6 +6,7 @@ from tvm.tir.tensor_intrin.cuda import ( LDMATRIX_16x16_A_INTRIN, LDMATRIX_16x16_B_INTRIN, + MMA_f16f16f16_INTRIN, shared_16x16_to_ldmatrix_32x8_layout, ) from tvm import meta_schedule as ms @@ -13,84 +14,6 @@ import numpy as np -@T.prim_func -def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (32, 8), "float16", align=128, offset_factor=16, scope="warp") - B = T.match_buffer(b, (32, 8), "float16", align=128, offset_factor=16, scope="warp") - C = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp") - - with T.block("root"): - T.reads(C[0:32, 0:8], A[0:32, 0:8], B[0:32, 0:8]) - T.writes(C[0:32, 0:8]) - for i, j, k in T.grid(16, 16, 16): - with T.block("C"): - i, j, k = T.axis.remap("SSR", [i, j, k]) - thread_id_C, local_id_C = shared_16x16_to_ldmatrix_32x8_layout(i, j) - thread_id_A, local_id_A = shared_16x16_to_ldmatrix_32x8_layout(i, k) - thread_id_B, local_id_B = shared_16x16_to_ldmatrix_32x8_layout(k, j) - - T.reads( - C[thread_id_C, local_id_C], - A[thread_id_A, local_id_A], - B[thread_id_B, local_id_B], - ) - T.writes(C[thread_id_C, local_id_C]) - C[thread_id_C, local_id_C] += ( - A[thread_id_A, local_id_A] * B[thread_id_B, local_id_B] - ) - - -@T.prim_func -def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (32, 8), "float16", align=128, offset_factor=16, scope="warp") - B = T.match_buffer(b, (32, 8), "float16", align=128, offset_factor=16, scope="warp") - C = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp") - - with T.block("root"): - T.reads(C[0:32, 0:8], A[0:32, 0:8], B[0:32, 0:8]) - T.writes(C[0:32, 0:8]) - tx = T.env_thread("threadIdx.x") - T.launch_thread(tx, 32) - - T.evaluate( - T.ptx_mma( - "m16n8k16", - "row", - "col", - "fp16", - "fp16", - "fp16", - A.data, - A.elem_offset + tx * 8, - B.data, - B.elem_offset + tx * 8, - C.data, - C.elem_offset + tx * 8, - False, - dtype="float16", - ) - ) - - T.evaluate( - T.ptx_mma( - "m16n8k16", - "row", - "col", - "fp16", - "fp16", - "fp16", - A.data, - A.elem_offset + tx * 8, - B.data, - B.elem_offset + tx * 8 + 4, - C.data, - C.elem_offset + tx * 8 + 4, - False, - dtype="float16", - ) - ) - - @T.prim_func def mma_store_desc(a: T.handle, c: T.handle) -> None: C_warp = T.match_buffer(a, [32, 8], dtype="float16", scope="warp") @@ -160,7 +83,6 @@ def mma_fill_impl(a: T.handle) -> None: T.evaluate(T.mma_fill(8, C_warp.data, C_warp.elem_offset, dtype="float16")) -tir.TensorIntrin.register("mma_sync", mma_sync_desc, mma_sync_impl) tir.TensorIntrin.register("mma_store", mma_store_desc, mma_store_impl) tir.TensorIntrin.register("mma_fill", mma_fill_desc, mma_fill_impl) @@ -306,7 +228,7 @@ def index_map(i, j): sch.tensorize(loop_a, LDMATRIX_16x16_A_INTRIN) sch.tensorize(loop_b, LDMATRIX_16x16_B_INTRIN) - sch.tensorize(sch.get_loops(block_inner)[-3], "mma_sync") + sch.tensorize(sch.get_loops(block_inner)[-3], MMA_f16f16f16_INTRIN) sch.tensorize(sch.get_loops(block_init_c)[-2], "mma_fill") sch.tensorize(sch.get_loops(C_warp)[-2], "mma_store") diff --git a/tests/python/unittest/test_mma_16x8x16_fp16_4k_tune_trans.py b/tests/python/unittest/test_mma_16x8x16_fp16_4k_tune_trans.py index 688bf1bf18cd..d716016a6130 100644 --- a/tests/python/unittest/test_mma_16x8x16_fp16_4k_tune_trans.py +++ b/tests/python/unittest/test_mma_16x8x16_fp16_4k_tune_trans.py @@ -6,6 +6,7 @@ from tvm.tir.tensor_intrin.cuda import ( LDMATRIX_16x16_A_INTRIN, LDMATRIX_16x16_B_TRANS_INTRIN, + MMA_f16f16f16_TRANS_INTRIN, shared_16x16_to_ldmatrix_32x8_layout, ) from tvm import meta_schedule as ms @@ -13,84 +14,6 @@ import numpy as np -@T.prim_func -def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (32, 8), "float16", align=128, offset_factor=16, scope="warp") - B = T.match_buffer(b, (32, 8), "float16", align=128, offset_factor=16, scope="warp") - C = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp") - - with T.block("root"): - T.reads(C[0:32, 0:8], A[0:32, 0:8], B[0:32, 0:8]) - T.writes(C[0:32, 0:8]) - for i, j, k in T.grid(16, 16, 16): - with T.block("C"): - i, j, k = T.axis.remap("SSR", [i, j, k]) - thread_id_C, local_id_C = shared_16x16_to_ldmatrix_32x8_layout(i, j) - thread_id_A, local_id_A = shared_16x16_to_ldmatrix_32x8_layout(i, k) - thread_id_B, local_id_B = shared_16x16_to_ldmatrix_32x8_layout(j, k) - - T.reads( - C[thread_id_C, local_id_C], - A[thread_id_A, local_id_A], - B[thread_id_B, local_id_B], - ) - T.writes(C[thread_id_C, local_id_C]) - C[thread_id_C, local_id_C] += ( - A[thread_id_A, local_id_A] * B[thread_id_B, local_id_B] - ) - - -@T.prim_func -def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (32, 8), "float16", align=128, offset_factor=16, scope="warp") - B = T.match_buffer(b, (32, 8), "float16", align=128, offset_factor=16, scope="warp") - C = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp") - - with T.block("root"): - T.reads(C[0:32, 0:8], A[0:32, 0:8], B[0:32, 0:8]) - T.writes(C[0:32, 0:8]) - tx = T.env_thread("threadIdx.x") - T.launch_thread(tx, 32) - - T.evaluate( - T.ptx_mma( - "m16n8k16", - "row", - "col", - "fp16", - "fp16", - "fp16", - A.data, - A.elem_offset + tx * 8, - B.data, - B.elem_offset + tx * 8, - C.data, - C.elem_offset + tx * 8, - False, - dtype="float16", - ) - ) - - T.evaluate( - T.ptx_mma( - "m16n8k16", - "row", - "col", - "fp16", - "fp16", - "fp16", - A.data, - A.elem_offset + tx * 8, - B.data, - B.elem_offset + tx * 8 + 4, - C.data, - C.elem_offset + tx * 8 + 4, - False, - dtype="float16", - ) - ) - - @T.prim_func def mma_store_desc(a: T.handle, c: T.handle) -> None: C_warp = T.match_buffer(a, [32, 8], dtype="float16", scope="warp") @@ -160,7 +83,6 @@ def mma_fill_impl(a: T.handle) -> None: T.evaluate(T.mma_fill(8, C_warp.data, C_warp.elem_offset, dtype="float16")) -tir.TensorIntrin.register("mma_sync", mma_sync_desc, mma_sync_impl) tir.TensorIntrin.register("mma_store", mma_store_desc, mma_store_impl) tir.TensorIntrin.register("mma_fill", mma_fill_desc, mma_fill_impl) @@ -308,7 +230,7 @@ def index_map(i, j): sch.tensorize(loop_a, LDMATRIX_16x16_A_INTRIN) sch.tensorize(loop_b, LDMATRIX_16x16_B_TRANS_INTRIN) - sch.tensorize(sch.get_loops(block_inner)[-3], "mma_sync") + sch.tensorize(sch.get_loops(block_inner)[-3], MMA_f16f16f16_TRANS_INTRIN) sch.tensorize(sch.get_loops(block_init_c)[-2], "mma_fill") sch.tensorize(sch.get_loops(C_warp)[-2], "mma_store") diff --git a/tests/python/unittest/test_mma_16x8x32_4k_tune.py b/tests/python/unittest/test_mma_16x8x32_4k_tune.py index 9a21cc6f3402..b504114872cf 100644 --- a/tests/python/unittest/test_mma_16x8x32_4k_tune.py +++ b/tests/python/unittest/test_mma_16x8x32_4k_tune.py @@ -5,6 +5,7 @@ from tvm.tir.tensor_intrin.cuda import ( LDMATRIX_16x32_A_INTRIN, LDMATRIX_32x16_B_INTRIN, + MMA_i8i8i32_INTRIN, shared_16x16_to_ldmatrix_32x8_layout, shared_32x16_to_ldmatrix_32x16_layout, shared_16x32_to_ldmatrix_32x16_layout, @@ -15,85 +16,6 @@ import numpy as np -@T.prim_func -def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (32, 16), "int8", align=128, offset_factor=16, scope="warp") - B = T.match_buffer(b, (32, 16), "int8", align=128, offset_factor=16, scope="warp") - C = T.match_buffer(c, (32, 8), "int32", align=128, offset_factor=16, scope="warp") - - with T.block("root"): - T.reads(C[0:32, 0:8], A[0:32, 0:16], B[0:32, 0:16]) - T.writes(C[0:32, 0:8]) - for i, j, k in T.grid(16, 16, 32): - with T.block("C"): - i, j, k = T.axis.remap("SSR", [i, j, k]) - - thread_id_C, local_id_C = shared_16x16_to_ldmatrix_32x8_layout(i, j) - thread_id_A, local_id_A = shared_16x32_to_ldmatrix_32x16_layout(i, k) - thread_id_B, local_id_B = shared_32x16_to_ldmatrix_32x16_layout(k, j) - - T.reads( - C[thread_id_C, local_id_C], - A[thread_id_A, local_id_A], - B[thread_id_B, local_id_B], - ) - T.writes(C[thread_id_C, local_id_C]) - C[thread_id_C, local_id_C] += T.cast(A[thread_id_A, local_id_A], "int32") * T.cast( - B[thread_id_B, local_id_B], "int32" - ) - - -@T.prim_func -def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (32, 16), "int8", align=128, offset_factor=16, scope="warp") - B = T.match_buffer(b, (32, 16), "int8", align=128, offset_factor=16, scope="warp") - C = T.match_buffer(c, (32, 8), "int32", align=128, offset_factor=16, scope="warp") - - with T.block("root"): - T.reads(C[0:32, 0:8], A[0:32, 0:16], B[0:32, 0:16]) - T.writes(C[0:32, 0:8]) - tx = T.env_thread("threadIdx.x") - T.launch_thread(tx, 32) - - T.evaluate( - T.ptx_mma( - "m16n8k32", - "row", - "col", - "int8", - "int8", - "int32", - A.data, - A.elem_offset + tx * 16, - B.data, - B.elem_offset + tx * 16, - C.data, - C.elem_offset + tx * 8, - False, - dtype="int32", - ) - ) - - T.evaluate( - T.ptx_mma( - "m16n8k32", - "row", - "col", - "int8", - "int8", - "int32", - A.data, - A.elem_offset + tx * 16, - B.data, - B.elem_offset + tx * 16 + 8, - C.data, - C.elem_offset + tx * 8 + 4, - False, - dtype="int32", - ) - ) - - @T.prim_func def mma_store_desc(a: T.handle, c: T.handle) -> None: C_warp = T.match_buffer(a, [32, 8], dtype="int32", scope="warp") @@ -163,7 +85,6 @@ def mma_fill_impl(a: T.handle) -> None: T.evaluate(T.mma_fill(8, C_warp.data, C_warp.elem_offset, dtype="int32")) -tir.TensorIntrin.register("mma_sync", mma_sync_desc, mma_sync_impl) tir.TensorIntrin.register("mma_store", mma_store_desc, mma_store_impl) tir.TensorIntrin.register("mma_fill", mma_fill_desc, mma_fill_impl) @@ -327,7 +248,7 @@ def index_map_C(i, j): sch.tensorize(loop_a, LDMATRIX_16x32_A_INTRIN) sch.tensorize(loop_b, LDMATRIX_32x16_B_INTRIN) - sch.tensorize(sch.get_loops(block_inner)[-3], "mma_sync") + sch.tensorize(sch.get_loops(block_inner)[-3], MMA_i8i8i32_INTRIN) # "mma_sync") sch.tensorize(sch.get_loops(block_init_c)[-2], "mma_fill") sch.tensorize(sch.get_loops(C_warp)[-2], "mma_store") diff --git a/tests/python/unittest/test_mma_16x8x32_4k_tune_trans.py b/tests/python/unittest/test_mma_16x8x32_4k_tune_trans.py index 31122dffec13..d2b4c6b9cb26 100644 --- a/tests/python/unittest/test_mma_16x8x32_4k_tune_trans.py +++ b/tests/python/unittest/test_mma_16x8x32_4k_tune_trans.py @@ -6,6 +6,7 @@ from tvm.tir.tensor_intrin.cuda import ( LDMATRIX_16x32_A_INTRIN, LDMATRIX_16x32_B_TRANS_INTRIN, + MMA_i8i8i32_TRANS_INTRIN, shared_16x16_to_ldmatrix_32x8_layout, shared_16x32_to_ldmatrix_32x16_layout, ) @@ -15,83 +16,6 @@ import numpy as np -@T.prim_func -def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (32, 16), "int8", align=128, offset_factor=16, scope="warp") - B = T.match_buffer(b, (32, 16), "int8", align=128, offset_factor=16, scope="warp") - C = T.match_buffer(c, (32, 8), "int32", align=128, offset_factor=16, scope="warp") - - with T.block("root"): - T.reads(C[0:32, 0:8], A[0:32, 0:16], B[0:32, 0:16]) - T.writes(C[0:32, 0:8]) - for i, j, k in T.grid(16, 16, 32): - with T.block("C"): - i, j, k = T.axis.remap("SSR", [i, j, k]) - thread_id_C, local_id_C = shared_16x16_to_ldmatrix_32x8_layout(i, j) - thread_id_A, local_id_A = shared_16x32_to_ldmatrix_32x16_layout(i, k) - thread_id_B, local_id_B = shared_16x32_to_ldmatrix_32x16_layout(j, k) - - T.reads( - C[thread_id_C, local_id_C], - A[thread_id_A, local_id_A], - B[thread_id_B, local_id_B], - ) - T.writes(C[thread_id_C, local_id_C]) - C[thread_id_C, local_id_C] += T.cast(A[thread_id_A, local_id_A], "int32") * T.cast( - B[thread_id_B, local_id_B], "int32" - ) - -@T.prim_func -def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (32, 16), "int8", align=128, offset_factor=16, scope="warp") - B = T.match_buffer(b, (32, 16), "int8", align=128, offset_factor=16, scope="warp") - C = T.match_buffer(c, (32, 8), "int32", align=128, offset_factor=16, scope="warp") - - with T.block("root"): - T.reads(C[0:32, 0:8], A[0:32, 0:16], B[0:32, 0:16]) - T.writes(C[0:32, 0:8]) - tx = T.env_thread("threadIdx.x") - T.launch_thread(tx, 32) - - T.evaluate( - T.ptx_mma( - "m16n8k32", - "row", - "col", - "int8", - "int8", - "int32", - A.data, - A.elem_offset + tx * 16, - B.data, - B.elem_offset + tx * 16, - C.data, - C.elem_offset + tx * 8, - False, - dtype="int32", - ) - ) - - T.evaluate( - T.ptx_mma( - "m16n8k32", - "row", - "col", - "int8", - "int8", - "int32", - A.data, - A.elem_offset + tx * 16, - B.data, - B.elem_offset + tx * 16 + 8, - C.data, - C.elem_offset + tx * 8 + 4, - False, - dtype="int32", - ) - ) - - @T.prim_func def mma_store_desc(a: T.handle, c: T.handle) -> None: C_warp = T.match_buffer(a, [32, 8], dtype="int32", scope="warp") @@ -161,7 +85,6 @@ def mma_fill_impl(a: T.handle) -> None: T.evaluate(T.mma_fill(8, C_warp.data, C_warp.elem_offset, dtype="int32")) -tir.TensorIntrin.register("mma_sync", mma_sync_desc, mma_sync_impl) tir.TensorIntrin.register("mma_store", mma_store_desc, mma_store_impl) tir.TensorIntrin.register("mma_fill", mma_fill_desc, mma_fill_impl) @@ -296,7 +219,6 @@ def tile_wmma_fragment(block_read, height, width): loop_a = tile_wmma_fragment(A_warp, 16, 32) loop_b = tile_wmma_fragment(B_warp, 16, 32) - def index_map_A_B(i, j): return ( i // 16, @@ -304,7 +226,6 @@ def index_map_A_B(i, j): *shared_16x32_to_ldmatrix_32x16_layout(i % 16, j % 32), ) - def index_map_C(i, j): return ( i // 16, @@ -318,7 +239,7 @@ def index_map_C(i, j): sch.tensorize(loop_a, LDMATRIX_16x32_A_INTRIN) sch.tensorize(loop_b, LDMATRIX_16x32_B_TRANS_INTRIN) - sch.tensorize(sch.get_loops(block_inner)[-3], "mma_sync") + sch.tensorize(sch.get_loops(block_inner)[-3], MMA_i8i8i32_TRANS_INTRIN) sch.tensorize(sch.get_loops(block_init_c)[-2], "mma_fill") sch.tensorize(sch.get_loops(C_warp)[-2], "mma_store") diff --git a/tests/python/unittest/test_mma_16x8x8_4k_tune.py b/tests/python/unittest/test_mma_16x8x8_4k_tune.py index ada530208089..703a219810a4 100644 --- a/tests/python/unittest/test_mma_16x8x8_4k_tune.py +++ b/tests/python/unittest/test_mma_16x8x8_4k_tune.py @@ -188,7 +188,9 @@ def mma_store_impl(a: T.handle, c: T.handle) -> None: s0 = T.var("int32") C_warp = T.match_buffer(a, [32, 4], dtype="float32", scope="warp", offset_factor=1) - C = T.match_buffer(c, [16, 8], dtype="float32", scope="global",offset_factor=1, strides=[s1, s0]) + C = T.match_buffer( + c, [16, 8], dtype="float32", scope="global", offset_factor=1, strides=[s1, s0] + ) with T.block("root"): T.reads(C_warp[0:32, 0:4]) @@ -196,7 +198,11 @@ def mma_store_impl(a: T.handle, c: T.handle) -> None: tx = T.env_thread("threadIdx.x") T.launch_thread(tx, 32) - T.evaluate(T.mma_store(16, 8, C.access_ptr("w"), C_warp.data, C_warp.elem_offset, s1, dtype="float32")) + T.evaluate( + T.mma_store( + 16, 8, C.access_ptr("w"), C_warp.data, C_warp.elem_offset, s1, dtype="float32" + ) + ) @T.prim_func @@ -211,8 +217,12 @@ def mma_fill_desc(a: T.handle) -> None: i_init = T.axis.spatial(16, i1 // 2 * 8 + i0 // 4) j_init = T.axis.spatial(8, (i0 % 4) * 2 + i1 % 2) T.reads() - T.writes(C_warp[i_init % 8 * 4 + j_init % 8 // 2, i_init % 16 // 8 * 2 + j_init % 2]) - C_warp[i_init % 8 * 4 + j_init % 8 // 2, i_init % 16 // 8 * 2 + j_init % 2] = T.float32(0) + T.writes( + C_warp[i_init % 8 * 4 + j_init % 8 // 2, i_init % 16 // 8 * 2 + j_init % 2] + ) + C_warp[ + i_init % 8 * 4 + j_init % 8 // 2, i_init % 16 // 8 * 2 + j_init % 2 + ] = T.float32(0) @T.prim_func