diff --git a/tests/python/unittest/test_mma_16x8x16_4k_tune_trans.py b/tests/python/unittest/test_mma_16x8x16_4k_tune_trans.py
index d94aefe3959f..c3d3e848b30e 100644
--- a/tests/python/unittest/test_mma_16x8x16_4k_tune_trans.py
+++ b/tests/python/unittest/test_mma_16x8x16_4k_tune_trans.py
@@ -8,6 +8,17 @@
 import numpy as np
 
 
+def shared_16x16_to_ldmatrix_32x8_layout(i, j):
+    thread_id = 4 * (i % 8) + (j % 8) // 2
+    return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 2)
+
+
+@tvm._ffi.register_func("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout")
+def index_map_shared_16x16_to_ldmatrix_32x8_layout(i, j):
+    thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(i, j)
+    return tvm.runtime.convert([thread_id, local_id])
+
+
 @T.prim_func
 def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None:
     A_shared = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared")
@@ -21,10 +32,10 @@ def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None:
             with T.block("A_shared_warp"):
                 v0, v1 = T.axis.remap("SS", [ax0, ax1])
                 T.reads(A_shared[v0, v1])
-                T.writes(A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
-                A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2] = A_shared[
-                    v0, v1
-                ]
+
+                thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(v0, v1)
+                T.writes(A_warp[thread_id, local_id])
+                A_warp[thread_id, local_id] = A_shared[v0, v1]
 
 
 @T.prim_func
@@ -74,10 +85,9 @@ def ldmatrix_b_desc(a: T.handle, c: T.handle) -> None:
             with T.block("B_shared_warp"):
                 v0, v1 = T.axis.remap("SS", [ax0, ax1])
                 T.reads(B_shared[v0, v1])
-                T.writes(B_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
-                B_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2] = B_shared[
-                    v0, v1
-                ]
+                thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(v0, v1)
+                T.writes(B_warp[thread_id, local_id])
+                B_warp[thread_id, local_id] = B_shared[v0, v1]
 
 
 @T.prim_func
@@ -126,21 +136,20 @@ def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
         for i, j, k in T.grid(16, 16, 16):
             with T.block("C"):
                 i, j, k = T.axis.remap("SSR", [i, j, k])
+
+                thread_id_C, local_id_C = shared_16x16_to_ldmatrix_32x8_layout(i, j)
+                thread_id_A, local_id_A = shared_16x16_to_ldmatrix_32x8_layout(i, k)
+                thread_id_B, local_id_B = shared_16x16_to_ldmatrix_32x8_layout(j, k)
+
                 T.reads(
-                    C[i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 2],
-                    A[i % 8 * 4 + k % 8 // 2, k % 16 // 8 * 4 + i % 16 // 8 * 2 + k % 2],
-                    B[j % 8 * 4 + k % 8 // 2, k % 16 // 8 * 4 + j % 16 // 8 * 2 + k % 2],
-                )
-                T.writes(C[i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 2])
-                C[i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 8 % 2] = C[
-                    i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 8 % 2
-                ] + T.cast(
-                    A[i % 8 * 4 + k % 8 // 2, k % 16 // 8 * 4 + i % 16 // 8 * 2 + k % 8 % 2],
-                    "float32",
-                ) * T.cast(
-                    B[j % 8 * 4 + k % 8 // 2, k % 16 // 8 * 4 + j % 16 // 8 * 2 + k % 8 % 2],
-                    "float32",
+                    C[thread_id_C, local_id_C],
+                    A[thread_id_A, local_id_A],
+                    B[thread_id_B, local_id_B],
                 )
+                T.writes(C[thread_id_C, local_id_C])
+                C[thread_id_C, local_id_C] += T.cast(
+                    A[thread_id_A, local_id_A], "float32"
+                ) * T.cast(B[thread_id_B, local_id_B], "float32")
 
 
 @T.prim_func
@@ -202,14 +211,13 @@ def mma_store_desc(a: T.handle, c: T.handle) -> None:
     with T.block("root"):
         T.reads(C_warp[0:32, 0:8])
         T.writes(C[0:16, 0:16])
-        for ax1_0, i0, i1 in T.grid(2, 32, 4):
+        for i0, i1 in T.grid(16, 16):
             with T.block("C_warp"):
-                v0 = T.axis.spatial(16, i1 // 2 * 8 + i0 // 4)
-                v1 = T.axis.spatial(16, ax1_0 * 8 + i0 % 4 * 2 + i1 % 2)
-
-                T.reads(C_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
+                v0, v1 = T.axis.remap("SS", [i0, i1])
+                thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(v0, v1)
+                T.reads(C_warp[thread_id, local_id])
                 T.writes(C[v0, v1])
-                C[v0, v1] = C_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2]
+                C[v0, v1] = C_warp[thread_id, local_id]
 
 
 @T.prim_func
@@ -242,21 +250,13 @@ def mma_fill_desc(a: T.handle) -> None:
     with T.block("root"):
        T.reads()
         T.writes(C_warp[0:32, 0:8])
-        for i0, i1 in T.grid(32, 8):
+        for i0, i1 in T.grid(16, 16):
             with T.block("C_warp"):
-                i_init = T.axis.spatial(16, i1 // 4 * 8 + i0 // 4)
-                j_init = T.axis.spatial(16, (i0 % 4) * 4 + i1 % 4)
+                i_init, j_init = T.axis.remap("SS", [i0, i1])
+                thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(i_init, j_init)
                 T.reads()
-                T.writes(
-                    C_warp[
-                        i_init % 8 * 4 + j_init % 8 // 2,
-                        j_init % 16 // 8 * 4 + i_init % 16 // 8 * 2 + j_init % 2,
-                    ]
-                )
-                C_warp[
-                    i_init % 8 * 4 + j_init % 8 // 2,
-                    j_init % 16 // 8 * 4 + i_init % 16 // 8 * 2 + j_init % 8 % 2,
-                ] = T.float32(0)
+                T.writes(C_warp[thread_id, local_id])
+                C_warp[thread_id, local_id] = T.float32(0)
 
 
 @T.prim_func
@@ -397,7 +397,8 @@ def fetch_to_shared(block, idx, ndim):
     jo, ji = sch.split(jj, factors=[None, 16])
     sch.reorder(io, jo, ii, ji)
 
-    block_init_c = sch.decompose_reduction(block_outer, sch.get_loops(block_outer)[3])
+    sch.decompose_reduction(block_outer, sch.get_loops(block_outer)[3])
+    block_init_c = sch.get_block("C_init")
 
     def tile_wmma_fragment(block_read, height):
         i, j = sch.get_loops(block_read)[-2:]
@@ -406,47 +407,26 @@ def tile_wmma_fragment(block_read, height):
         sch.reorder(i0, j0, i1, j1)
         return i1
 
-    def shared_16x16_to_ldmatrix_32x8_layout(i, j):
-        i_0 = i // 16
-        j_0 = j // 16
-
-        i = i % 16
-        j = j % 16
-
-        thread_id = 4 * (i % 8) + (j % 8) // 2
-        return i_0, j_0, thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 8) % 2
 
     loop_a = tile_wmma_fragment(A_warp, 16)
     loop_b = tile_wmma_fragment(B_warp, 16)
 
-    sch.transform_layout(A_warp, 0, "write", index_map=shared_16x16_to_ldmatrix_32x8_layout)
-    sch.transform_layout(B_warp, 0, "write", index_map=shared_16x16_to_ldmatrix_32x8_layout)
-    sch.transform_layout(C_warp, 0, "read", index_map=shared_16x16_to_ldmatrix_32x8_layout)
+    def index_map(i, j):
+        return (
+            i // 16,
+            j // 16,
+            *shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16),
+        )
+
+    sch.transform_layout(A_warp, 0, "write", index_map)
+    sch.transform_layout(B_warp, 0, "write", index_map)
+    sch.transform_layout(C_warp, 0, "read", index_map)
 
     sch.tensorize(loop_a, "mma.ldmatrix_a")
     sch.tensorize(loop_b, "mma.ldmatrix_b")
-
-    mma_loop = sch.get_loops(block_inner)[-3]
-    sch.tensorize(mma_loop, "mma_sync")
-
-    block_init_c = sch.get_block("C_init")
-    init_loop1, init_loop2 = sch.get_loops(block_init_c)[-2:]
-    f_0, f_1 = sch.split(init_loop1, factors=[None, 8])
-    f_2, f_3 = sch.split(init_loop2, factors=[None, 4])
-    sch.reorder(f_1, f_2, f_0, f_3)
-    fused_1 = sch.fuse(f_1, f_2)
-    fused_2 = sch.fuse(f_0, f_3)
-    sch.tensorize(fused_1, "mma_fill")
-
-    warp_loop1, warp_loop2 = sch.get_loops(C_warp)[-2:]
-    f_0, f_1 = sch.split(warp_loop1, factors=[None, 8])
-    outer, f_2, f_3 = sch.split(warp_loop2, factors=[2, 4, 2])
-    sch.reorder(outer, f_1, f_2, f_0, f_3)
-    fused_1 = sch.fuse(f_1, f_2)
-    fused_2 = sch.fuse(f_0, f_3)
-    sch.tensorize(outer, "mma_store")
-    # print(sch.mod.script())
-    # return
+    sch.tensorize(sch.get_loops(block_inner)[-3], "mma_sync")
+    sch.tensorize(sch.get_loops(block_init_c)[-2], "mma_fill")
+    sch.tensorize(sch.get_loops(C_warp)[-2], "mma_store")
 
 
 ir_module = tvm.IRModule({"main": workload})
@@ -454,42 +434,42 @@ def shared_16x16_to_ldmatrix_32x8_layout(i, j):
 schedule(sch)
 print(sch.mod.script())
 
-if tune:
-    with tempfile.TemporaryDirectory() as work_dir:
-        sch = ms.tune_tir(
-            mod=workload,
-            target=tvm.target.Target("nvidia/geforce-rtx-3070"),
-            config=ms.TuneConfig(
-                strategy="evolutionary",
-                num_trials_per_iter=32,
-                max_trials_per_task=128,
-                max_trials_global=128,
-            ),
-            work_dir=work_dir,
-            space=ms.space_generator.ScheduleFn(schedule),
-        )
-        if sch is None:
-            print("No valid schedule found!")
-        else:
-            print(sch.mod.script())
-            print(sch.trace)
-
-
-dev = tvm.device("cuda", 0)
-a_np = np.random.uniform(size=(N, K)).astype("float16")
-b_np = np.random.uniform(size=(K, M)).astype("float16")
-c_np = np.dot(a_np.astype("float32"), b_np.astype("float32").transpose())
-a = tvm.nd.array(a_np, dev)
-b = tvm.nd.array(b_np, dev)
-c = tvm.nd.array(np.zeros((M, N), dtype="float32"), dev)
-f = tvm.build(sch.mod["main"], target="cuda", name="dense")
-
-print(f.imported_modules[0].get_source())
-f(a, b, c)
-tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
-print("ok")
-
-evaluator = f.time_evaluator(f.entry_name, dev, number=1000)
-gflops = (N * M * K) * 2 / 1e9
-time_ms = evaluator(a, b, c).mean * 1e3
-print("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms, gflops / (time_ms / 1e3)))
+# if tune:
+#     with tempfile.TemporaryDirectory() as work_dir:
+#         sch = ms.tune_tir(
+#             mod=workload,
+#             target=tvm.target.Target("nvidia/geforce-rtx-3070"),
+#             config=ms.TuneConfig(
+#                 strategy="evolutionary",
+#                 num_trials_per_iter=32,
+#                 max_trials_per_task=128,
+#                 max_trials_global=128,
+#             ),
+#             work_dir=work_dir,
+#             space=ms.space_generator.ScheduleFn(schedule),
+#         )
+#         if sch is None:
+#             print("No valid schedule found!")
+#         else:
+#             print(sch.mod.script())
+#             print(sch.trace)
+
+
+# dev = tvm.device("cuda", 0)
+# a_np = np.random.uniform(size=(N, K)).astype("float16")
+# b_np = np.random.uniform(size=(K, M)).astype("float16")
+# c_np = np.dot(a_np.astype("float32"), b_np.astype("float32").transpose())
+# a = tvm.nd.array(a_np, dev)
+# b = tvm.nd.array(b_np, dev)
+# c = tvm.nd.array(np.zeros((M, N), dtype="float32"), dev)
+# f = tvm.build(sch.mod["main"], target="cuda", name="dense")

+# print(f.imported_modules[0].get_source())
+# f(a, b, c)
+# tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
+# print("ok")

+# evaluator = f.time_evaluator(f.entry_name, dev, number=1000)
+# gflops = (N * M * K) * 2 / 1e9
+# time_ms = evaluator(a, b, c).mean * 1e3
+# print("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms, gflops / (time_ms / 1e3)))
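
As a standalone sanity check (not part of the patch, and independent of TVM), the layout helper introduced above can be verified to be a bijection from the 16x16 tile onto the 32-thread by 8-element warp buffer with a few lines of plain Python. The function body is copied verbatim from the diff; the assertions around it are illustrative only.

def shared_16x16_to_ldmatrix_32x8_layout(i, j):
    thread_id = 4 * (i % 8) + (j % 8) // 2
    return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 2)

# Every (i, j) in the 16x16 tile should land on a distinct (thread_id, local_id)
# pair, with thread_id in [0, 32) and local_id in [0, 8).
seen = set()
for i in range(16):
    for j in range(16):
        thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(i, j)
        assert 0 <= thread_id < 32 and 0 <= local_id < 8
        seen.add((thread_id, local_id))
assert len(seen) == 256  # 16 * 16 source elements, no collisions

The same property is what makes the index_map used in the schedule (i // 16, j // 16, followed by this 16x16 mapping) a valid layout transform to hand to transform_layout.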