diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py
index 7186287429191..bbd26681f20b8 100644
--- a/python/tvm/tir/tensor_intrin/cuda.py
+++ b/python/tvm/tir/tensor_intrin/cuda.py
@@ -354,14 +354,14 @@ def mma_fill_impl(a: T.handle) -> None:
     return mma_fill_desc, mma_fill_impl
 
 
-def get_mma_store_intrin(dtype, local_size):
+def get_mma_store_intrin(dtype, local_size, scope="global"):
     # Assume M = N = 16
     index_map = shared_16x16_to_ldmatrix_32x8_layout
 
     @T.prim_func
     def mma_store_desc(a: T.handle, c: T.handle) -> None:
         C_warp = T.match_buffer(a, [WARP_SIZE, local_size], dtype=dtype, scope="warp")
-        C = T.match_buffer(c, [M_DIM, N_DIM], dtype=dtype, scope="global")
+        C = T.match_buffer(c, [M_DIM, N_DIM], dtype=dtype, scope=scope)
 
         with T.block("root"):
             T.reads(C_warp[0:WARP_SIZE, 0:local_size])
@@ -454,11 +454,17 @@ def mma_store_impl(a: T.handle, c: T.handle) -> None:
 MMA_fill_16x16_i32_INTRIN = "mma_fill_16x16_i32"
 TensorIntrin.register(MMA_fill_16x16_i32_INTRIN, *get_mma_fill_intrin("int32", 8))
 
-MMA_store_16x16_f32_INTRIN = "mma_store_16x16_f32"
-TensorIntrin.register(MMA_store_16x16_f32_INTRIN, *get_mma_store_intrin("float32", 8))
+MMA_store_16x16_f32_global_INTRIN = "mma_store_16x16_f32_global_"
+TensorIntrin.register(
+    MMA_store_16x16_f32_global_INTRIN, *get_mma_store_intrin("float32", 8, "global")
+)
 
-MMA_store_16x16_f16_INTRIN = "mma_store_16x16_f16"
-TensorIntrin.register(MMA_store_16x16_f16_INTRIN, *get_mma_store_intrin("float16", 8))
+MMA_store_16x16_f16_global_INTRIN = "mma_store_16x16_f16_global_"
+TensorIntrin.register(
+    MMA_store_16x16_f16_global_INTRIN, *get_mma_store_intrin("float16", 8, "global")
+)
 
-MMA_store_16x16_i32_INTRIN = "mma_store_16x16_i32"
-TensorIntrin.register(MMA_store_16x16_i32_INTRIN, *get_mma_store_intrin("int32", 8))
+MMA_store_16x16_i32_global_INTRIN = "mma_store_16x16_i32_global_"
+TensorIntrin.register(
+    MMA_store_16x16_i32_global_INTRIN, *get_mma_store_intrin("int32", 8, "global")
+)
diff --git a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py
index d6b14bc1a8cfa..cd37866aabea6 100644
--- a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py
+++ b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py
@@ -33,9 +33,9 @@
     MMA_fill_16x16_f32_INTRIN,
     MMA_fill_16x16_f16_INTRIN,
     MMA_fill_16x16_i32_INTRIN,
-    MMA_store_16x16_f32_INTRIN,
-    MMA_store_16x16_f16_INTRIN,
-    MMA_store_16x16_i32_INTRIN,
+    MMA_store_16x16_f32_global_INTRIN,
+    MMA_store_16x16_f16_global_INTRIN,
+    MMA_store_16x16_i32_global_INTRIN,
     shared_16x16_to_ldmatrix_32x8_layout,
     shared_32x16_to_ldmatrix_32x16_layout,
     shared_16x32_to_ldmatrix_32x16_layout,
@@ -249,7 +249,7 @@ def index_map(i, j):
         LDMATRIX_16x16_B_INTRIN,
         MMA_f16f16f32_INTRIN,
         MMA_fill_16x16_f32_INTRIN,
-        MMA_store_16x16_f32_INTRIN,
+        MMA_store_16x16_f32_global_INTRIN,
     )
 
     if measure_perf:
@@ -270,7 +270,7 @@ def index_map(i, j):
         LDMATRIX_16x16_B_TRANS_INTRIN,
         MMA_f16f16f32_TRANS_INTRIN,
         MMA_fill_16x16_f32_INTRIN,
-        MMA_store_16x16_f32_INTRIN,
+        MMA_store_16x16_f32_global_INTRIN,
     )
 
     if measure_perf:
@@ -305,7 +305,7 @@ def index_map(i, j):
         LDMATRIX_16x16_B_INTRIN,
         MMA_f16f16f16_INTRIN,
         MMA_fill_16x16_f16_INTRIN,
-        MMA_store_16x16_f16_INTRIN,
+        MMA_store_16x16_f16_global_INTRIN,
     )
 
     if measure_perf:
@@ -326,7 +326,7 @@ def index_map(i, j):
         LDMATRIX_16x16_B_TRANS_INTRIN,
         MMA_f16f16f16_TRANS_INTRIN,
         MMA_fill_16x16_f16_INTRIN,
-        MMA_store_16x16_f16_INTRIN,
+        MMA_store_16x16_f16_global_INTRIN,
     )
 
     if measure_perf:
@@ -375,7 +375,7 @@ def index_map_C(i, j):
         LDMATRIX_32x16_B_INTRIN,
         MMA_i8i8i32_INTRIN,
         MMA_fill_16x16_i32_INTRIN,
-        MMA_store_16x16_i32_INTRIN,
+        MMA_store_16x16_i32_global_INTRIN,
     )
 
     if measure_perf:
@@ -396,7 +396,7 @@ def index_map_C(i, j):
         LDMATRIX_16x32_B_TRANS_INTRIN,
         MMA_i8i8i32_TRANS_INTRIN,
         MMA_fill_16x16_i32_INTRIN,
-        MMA_store_16x16_i32_INTRIN,
+        MMA_store_16x16_i32_global_INTRIN,
    )
 
     if measure_perf:
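A minimal usage sketch of what the new scope parameter enables, not part of the
diff above: a store intrinsic targeting a different destination scope could be
registered the same way. The "shared" variant and its name below are hypothetical
and assume the store implementation honors the same scope argument.

    # Hypothetical registration of a shared-memory store variant using the
    # now-parametrized get_mma_store_intrin (name and scope are assumptions).
    MMA_store_16x16_f32_shared_INTRIN = "mma_store_16x16_f32_shared_"
    TensorIntrin.register(
        MMA_store_16x16_f32_shared_INTRIN, *get_mma_store_intrin("float32", 8, "shared")
    )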