From 54f1cb731d4b42a6cbc08baf144e74646400eef5 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Fri, 13 May 2022 18:23:27 +0900
Subject: [PATCH] wip

---
 .../test_mma_16x8x16_4k_tune_simple.py        | 430 ++++++++++++++++++
 1 file changed, 430 insertions(+)
 create mode 100644 tests/python/unittest/test_mma_16x8x16_4k_tune_simple.py

diff --git a/tests/python/unittest/test_mma_16x8x16_4k_tune_simple.py b/tests/python/unittest/test_mma_16x8x16_4k_tune_simple.py
new file mode 100644
index 000000000000..6a2cc795b473
--- /dev/null
+++ b/tests/python/unittest/test_mma_16x8x16_4k_tune_simple.py
@@ -0,0 +1,430 @@
+"""4096x4096x4096 fp16/fp32 matmul tensorized onto the m16n8k16 PTX MMA
+instructions (ldmatrix + mma.sync) via TIR tensor intrinsics."""
+import tempfile
+
+import numpy as np
+
+import tvm
+import tvm.meta_schedule.testing.te_workload as te_workload
+import tvm.testing
+from tvm import meta_schedule as ms
+from tvm import te, tir
+from tvm.script import tir as T
+
+
+# Tensorize description: a 16x16 fp16 tile copy from shared memory into a
+# warp-level fragment buffer, the pattern sch.tensorize matches against.
+@T.prim_func
+def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None:
+    A_shared = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared")
+    A_warp = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="warp")
+
+    with T.block("root"):
+        T.reads(A_shared[0:16, 0:16])
+        T.writes(A_warp[0:16, 0:16])
+
+        for ax0, ax1 in T.grid(16, 16):
+            with T.block("A_shared_warp"):
+                v0, v1 = T.axis.remap("SS", [ax0, ax1])
+                T.reads(A_shared[v0, v1])
+                T.writes(A_warp[v0, v1])
+                A_warp[v0, v1] = A_shared[v0, v1]
+
+
+# Implementation: one ldmatrix.x4 per warp replaces the 16x16 elementwise copy.
+@T.prim_func
+def ldmatrix_a_impl(a: T.handle, c: T.handle) -> None:
+    s1 = T.var("int32")
+    s0 = T.var("int32")
+    A_shared = T.match_buffer(
+        a,
+        (16, 16),
+        "float16",
+        align=128,
+        offset_factor=16,
+        scope="shared",
+        strides=[s1, s0],
+    )
+    A_warp = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="warp")
+    with T.block("root"):
+        T.reads(A_shared[0:16, 0:16])
+        T.writes(A_warp[0:16, 0:16])
+        tx = T.env_thread("threadIdx.x")
+        T.launch_thread(tx, 32)
+
+        T.evaluate(
+            T.ptx_ldmatrix(
+                0,  # trans=0: A is loaded non-transposed
+                4,  # .x4: four 8x8 .b16 matrices per warp
+                ".b16",
+                A_warp.data,
+                A_warp.elem_offset + 8 * tx,
+                A_shared.access_ptr("r"),
+                s1 * (tx % 16) + 8 * (tx // 16),
+                dtype="float16",
+            )
+        )
+
+
+# Same description for the B operand.
+@T.prim_func
+def ldmatrix_b_desc(a: T.handle, c: T.handle) -> None:
+    B_shared = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared")
+    B_warp = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="warp")
+
+    with T.block("root"):
+        T.reads(B_shared[0:16, 0:16])
+        T.writes(B_warp[0:16, 0:16])
+
+        for ax0, ax1 in T.grid(16, 16):
+            with T.block("B_shared_warp"):
+                v0, v1 = T.axis.remap("SS", [ax0, ax1])
+                T.reads(B_shared[v0, v1])
+                T.writes(B_warp[v0, v1])
+                B_warp[v0, v1] = B_shared[v0, v1]
+
+
+# Implementation: as for A, but with the transposed ldmatrix variant so the
+# fragment matches mma.sync's row-major x col-major operand form.
+@T.prim_func
+def ldmatrix_b_impl(a: T.handle, c: T.handle) -> None:
+    s1 = T.var("int32")
+    s0 = T.var("int32")
+    B_shared = T.match_buffer(
+        a,
+        (16, 16),
+        "float16",
+        align=128,
+        offset_factor=16,
+        scope="shared",
+        strides=[s1, s0],
+    )
+    B_warp = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="warp")
+    with T.block("root"):
+        T.reads(B_shared[0:16, 0:16])
+        T.writes(B_warp[0:16, 0:16])
+        tx = T.env_thread("threadIdx.x")
+        T.launch_thread(tx, 32)
+
+        T.evaluate(
+            T.ptx_ldmatrix(
+                1,  # trans=1: B is loaded transposed
+                4,  # .x4: four 8x8 .b16 matrices per warp
+                ".b16",
+                B_warp.data,
+                B_warp.elem_offset + 8 * tx,
+                B_shared.access_ptr("r"),
+                s1 * (tx % 16) + 8 * (tx // 16),
+                dtype="float16",
+            )
+        )
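+
+
+# Fragment layout notes (a reading of the offsets used above, not extra
+# functionality): with ldmatrix.x4 each of the 32 threads receives eight fp16
+# values, placed at the contiguous slice [8 * tx, 8 * tx + 8) of the flattened
+# (16, 16) warp buffer. On the source side, threads 0-15 supply the addresses
+# of rows 0-15 at column 0 and threads 16-31 the same rows at column 8, i.e.
+# the four 8x8 sub-tiles of the 16x16 block.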
+
+
+# Tensorize description: one 16x16x16 fp16 -> fp32 matmul-accumulate step on
+# warp-level fragments (B is indexed as B[k, j]).
+@T.prim_func
+def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="warp")
+    B = T.match_buffer(b, (16, 16), "float16", align=128, offset_factor=16, scope="warp")
+    C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=16, scope="warp")
+
+    with T.block("root"):
+        T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16])
+        T.writes(C[0:16, 0:16])
+        for i, j, k in T.grid(16, 16, 16):
+            with T.block("C"):
+                i, j, k = T.axis.remap("SSR", [i, j, k])
+                T.reads(C[i, j], A[i, k], B[k, j])
+                T.writes(C[i, j])
+                C[i, j] = C[i, j] + T.cast(A[i, k], "float32") * T.cast(B[k, j], "float32")
+
+
+# Implementation: two m16n8k16 mma.sync ops per warp cover the 16x16 output
+# tile; the "+ 4" offsets on B and C select each thread's fragment for output
+# columns 8-15.
+@T.prim_func
+def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="warp")
+    B = T.match_buffer(b, (16, 16), "float16", align=128, offset_factor=16, scope="warp")
+    C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=16, scope="warp")
+
+    with T.block("root"):
+        T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16])
+        T.writes(C[0:16, 0:16])
+        tx = T.env_thread("threadIdx.x")
+        T.launch_thread(tx, 32)
+
+        T.evaluate(
+            T.ptx_mma(
+                "m16n8k16",
+                "row",
+                "col",
+                "fp16",
+                "fp16",
+                "fp32",
+                A.data,
+                A.elem_offset + tx * 8,
+                B.data,
+                B.elem_offset + tx * 8,
+                C.data,
+                C.elem_offset + tx * 8,
+                False,
+                dtype="float32",
+            )
+        )
+
+        T.evaluate(
+            T.ptx_mma(
+                "m16n8k16",
+                "row",
+                "col",
+                "fp16",
+                "fp16",
+                "fp32",
+                A.data,
+                A.elem_offset + tx * 8,
+                B.data,
+                B.elem_offset + tx * 8 + 4,
+                C.data,
+                C.elem_offset + tx * 8 + 4,
+                False,
+                dtype="float32",
+            )
+        )
+
+
+# Tensorize description: copy a 16x16 fp32 accumulator tile from the warp
+# fragment back to global memory.
+@T.prim_func
+def mma_store_desc(a: T.handle, c: T.handle) -> None:
+    C_warp = T.match_buffer(a, [16, 16], dtype="float32", scope="warp")
+    C = T.match_buffer(c, [16, 16], dtype="float32", scope="global")
+
+    with T.block("root"):
+        T.reads(C_warp[0:16, 0:16])
+        T.writes(C[0:16, 0:16])
+        for i0, i1 in T.grid(16, 16):
+            with T.block("C_warp"):
+                v0, v1 = T.axis.remap("SS", [i0, i1])
+                T.reads(C_warp[v0, v1])
+                T.writes(C[v0, v1])
+                C[v0, v1] = C_warp[v0, v1]
+
+
+# Implementation: T.mma_store writes each thread's eight accumulators out with
+# row stride s1.
+@T.prim_func
+def mma_store_impl(a: T.handle, c: T.handle) -> None:
+    s1 = T.var("int32")
+    s0 = T.var("int32")
+
+    C_warp = T.match_buffer(a, [16, 16], dtype="float32", scope="warp", offset_factor=1)
+    C = T.match_buffer(
+        c, [16, 16], dtype="float32", scope="global", offset_factor=1, strides=[s1, s0]
+    )
+
+    with T.block("root"):
+        T.reads(C_warp[0:16, 0:16])
+        T.writes(C[0:16, 0:16])
+        tx = T.env_thread("threadIdx.x")
+        T.launch_thread(tx, 32)
+
+        T.evaluate(
+            T.mma_store(
+                16, 16, C.access_ptr("w"), C_warp.data, C_warp.elem_offset, s1, dtype="float32"
+            )
+        )
+
+
+# Tensorize description: zero-initialize a 16x16 fp32 accumulator tile.
+@T.prim_func
+def mma_fill_desc(a: T.handle) -> None:
+    C_warp = T.match_buffer(a, [16, 16], dtype="float32", scope="warp")
+
+    with T.block("root"):
+        T.reads()
+        T.writes(C_warp[0:16, 0:16])
+        for i0, i1 in T.grid(16, 16):
+            with T.block("C_warp"):
+                i, j = T.axis.remap("SS", [i0, i1])
+                T.reads()
+                T.writes(C_warp[i, j])
+                C_warp[i, j] = T.float32(0)
+
+
+# Implementation: each thread zeroes its eight fp32 accumulator registers.
+@T.prim_func
+def mma_fill_impl(a: T.handle) -> None:
+    C_warp = T.match_buffer(a, [16, 16], dtype="float32", scope="warp", offset_factor=1)
+
+    with T.block("root"):
+        T.reads()
+        T.writes(C_warp[0:16, 0:16])
+        tx = T.env_thread("threadIdx.x")
+        T.launch_thread(tx, 32)
+
+        T.evaluate(T.mma_fill(8, C_warp.data, C_warp.elem_offset, dtype="float32"))
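+
+
+# Accumulator layout notes: the 16x16 fp32 C tile holds 256 values, i.e. eight
+# per thread across the 32-thread warp, which is why mma_fill zeroes exactly 8
+# elements per thread. Each m16n8k16 op updates four of them (one 16x8 tile),
+# so the pair of ptx_mma calls in mma_sync_impl updates all eight.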
+
+
+tir.TensorIntrin.register("mma.ldmatrix_a", ldmatrix_a_desc, ldmatrix_a_impl)
+tir.TensorIntrin.register("mma.ldmatrix_b", ldmatrix_b_desc, ldmatrix_b_impl)
+tir.TensorIntrin.register("mma_sync", mma_sync_desc, mma_sync_impl)
+tir.TensorIntrin.register("mma_store", mma_store_desc, mma_store_impl)
+tir.TensorIntrin.register("mma_fill", mma_fill_desc, mma_fill_impl)
+
+N = 4096
+M = 4096
+K = 4096
+
+workload = te.create_prim_func(te_workload.matmul_fp16(n=N, m=M, k=K))
+
+# Set to True to let meta schedule search for tile sizes instead of using the
+# hand-picked factors below.
+tune = False
+
+
+def schedule(sch: tir.Schedule):
+    block = sch.get_block("C")
+    i, j, k = sch.get_loops(block)
+
+    # Split off a 16x16x16 inner tile and blockize it: the outer block gets
+    # tiled across thread blocks and warps, the inner block is tensorized.
+    i, i_tc = sch.split(i, factors=[None, 16])
+    j, j_tc = sch.split(j, factors=[None, 16])
+    k, k_tc = sch.split(k, factors=[None, 16])
+
+    sch.reorder(i, j, k, i_tc, j_tc, k_tc)
+    block_inner = sch.blockize(i_tc)
+    block_outer, block_inner = block_inner, block
+
+    if tune:
+        i_factors = sch.sample_perfect_tile(i, n=5)
+        j_factors = sch.sample_perfect_tile(j, n=5)
+        k_factors = sch.sample_perfect_tile(k, n=3)
+        num_ty = sch.get(i_factors[2]) * sch.get(j_factors[2])
+    else:
+        i_factors = [4, 8, 2, 4, 1]
+        j_factors = [1, 64, 2, 1, 2]
+        k_factors = [128, 2, 1]
+        num_ty = i_factors[2] * j_factors[2]
+
+    i0, i1, i2, i3, i4 = sch.split(i, factors=i_factors)
+    j0, j1, j2, j3, j4 = sch.split(j, factors=j_factors)
+    k0, k1, k2 = sch.split(k, factors=k_factors)
+
+    sch.reorder(
+        i0, j0,  # S => blockIdx.x
+        i1, j1,  # S => blockIdx.y
+        j2, i2,  # S => threadIdx.y
+        # cache_write here
+        k0,  # R
+        # vectorized cooperative fetching here
+        k1,  # R
+        i3, j3,  # S
+        k2,  # R
+        i4, j4,  # S
+    )
+
+    block_idx = sch.fuse(i0, j0)
+    block_idy = sch.fuse(i1, j1)
+    thread_idy = sch.fuse(j2, i2)
+    sch.bind(block_idx, "blockIdx.x")
+    sch.bind(block_idy, "blockIdx.y")
+    sch.bind(thread_idy, "threadIdx.y")
+
+    # Cooperatively fetch an operand tile into shared memory, one vectorized
+    # 8-element load per thread; storage_align pads rows (factor=32, offset=8)
+    # to avoid shared memory bank conflicts when ldmatrix reads them back.
+    def fetch_to_shared(block, idx, ndim):
+        block_read = sch.cache_read(block, idx, "shared")
+        sch.compute_at(block_read, k0)
+        vector_size = 8
+        warp_size = 32
+        fused = sch.fuse(*sch.get_loops(block_read)[-ndim:])
+        f_0, f_1, f_2, f_3 = sch.split(fused, factors=[None, num_ty, warp_size, vector_size])
+        sch.bind(f_2, "threadIdx.x")
+        sch.bind(f_1, "threadIdx.y")
+        sch.vectorize(f_3)
+        sch.storage_align(block_read, 0, axis=-2, factor=32, offset=8)
+
+        return block_read
+
+    A_sh = fetch_to_shared(block_outer, 0, 2)
+    B_sh = fetch_to_shared(block_outer, 1, 2)
+
+    A_warp = sch.cache_read(block_outer, 0, "warp")
+    B_warp = sch.cache_read(block_outer, 1, "warp")
+
+    sch.compute_at(A_warp, k1)
+    sch.compute_at(B_warp, k1)
+
+    C_warp = sch.cache_write(block_outer, 0, "warp")
+    sch.reverse_compute_at(C_warp, thread_idy)
+
+    ii, jj = sch.get_loops(C_warp)[-2:]
+    io, ii = sch.split(ii, factors=[None, 16])
+    jo, ji = sch.split(jj, factors=[None, 16])
+    sch.reorder(io, jo, ii, ji)
+
+    sch.decompose_reduction(block_outer, sch.get_loops(block_outer)[3])
+    block_init_c = sch.get_block("C_init")
+
+    def tile_wmma_fragment(block_read, height):
+        i, j = sch.get_loops(block_read)[-2:]
+        i0, i1 = sch.split(i, factors=[None, height])
+        j0, j1 = sch.split(j, factors=[None, 16])
+        sch.reorder(i0, j0, i1, j1)
+        return i1
+
+    loop_a = tile_wmma_fragment(A_warp, 16)
+    loop_b = tile_wmma_fragment(B_warp, 16)
+
+    # sch.transform_layout(A_warp, 0, "write", index_map=lambda i, j: (i // 16, j // 16, i % 16, j % 16))
+    # sch.transform_layout(B_warp, 0, "write", index_map=lambda i, j: (i // 16, j // 16, i % 16, j % 16))
+    # sch.transform_layout(C_warp, 0, "read", index_map=lambda i, j: (i // 16, j // 16, i % 16, j % 16))
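+
+    # The transform_layout calls above are intentionally left disabled in this
+    # WIP. As a plain-Python illustration (hypothetical helper, not used
+    # anywhere), the index map they would apply is
+    #
+    #     def fragment_index_map(i, j):
+    #         return (i // 16, j // 16, i % 16, j % 16)
+    #
+    # i.e. (tile row, tile column, row within the 16x16 tile, column within it).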
+
+    sch.tensorize(loop_a, "mma.ldmatrix_a")
+    sch.tensorize(loop_b, "mma.ldmatrix_b")
+    sch.tensorize(sch.get_loops(block_inner)[-3], "mma_sync")
+    sch.tensorize(sch.get_loops(block_init_c)[-2], "mma_fill")
+    sch.tensorize(sch.get_loops(C_warp)[-2], "mma_store")
+
+
+ir_module = tvm.IRModule({"main": workload})
+sch = tvm.tir.Schedule(ir_module)
+schedule(sch)
+print(sch.mod.script())
+
+if tune:
+    with tempfile.TemporaryDirectory() as work_dir:
+        sch = ms.tune_tir(
+            mod=workload,
+            target=tvm.target.Target("nvidia/geforce-rtx-3070"),
+            config=ms.TuneConfig(
+                strategy="evolutionary",
+                num_trials_per_iter=32,
+                max_trials_per_task=128,
+                max_trials_global=128,
+            ),
+            work_dir=work_dir,
+            space=ms.space_generator.ScheduleFn(schedule),
+        )
+    assert sch is not None, "No valid schedule found!"
+    print(sch.mod.script())
+    print(sch.trace)
+
+f = tvm.build(sch.mod["main"], target="cuda", name="dense")
+
+dev = tvm.device("cuda", 0)
+a_np = np.random.uniform(size=(N, K)).astype("float16")
+b_np = np.random.uniform(size=(K, M)).astype("float16")
+c_np = np.dot(a_np.astype("float32"), b_np.astype("float32"))
+a = tvm.nd.array(a_np, dev)
+b = tvm.nd.array(b_np, dev)
+c = tvm.nd.array(np.zeros((N, M), dtype="float32"), dev)
+
+# print(f.imported_modules[0].get_source())
+f(a, b, c)
+tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
+print("ok")
+
+evaluator = f.time_evaluator(f.entry_name, dev, number=1000)
+num_gflop = (N * M * K) * 2 / 1e9
+time_ms = evaluator(a, b, c).mean * 1e3
+print("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms, num_gflop / (time_ms / 1e3)))
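+
+# Note: ldmatrix plus the m16n8k16 fp16 mma.sync used here should require an
+# NVIDIA GPU with compute capability 8.0 or newer (e.g. the RTX 3070 named in
+# the tuning target), and a CUDA-enabled TVM build to compile and run.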