Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 17, 2022
1 parent 12a376a commit 576f841
Showing 1 changed file with 94 additions and 114 deletions.
208 changes: 94 additions & 114 deletions tests/python/unittest/test_mma_16x8x16_4k_tune_trans.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,17 @@
import numpy as np


def shared_16x16_to_ldmatrix_32x8_layout(i, j):
thread_id = 4 * (i % 8) + (j % 8) // 2
return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 2)


@tvm._ffi.register_func("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout")
def index_map_shared_16x16_to_ldmatrix_32x8_layout(i, j):
thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(i, j)
return tvm.runtime.convert([thread_id, local_id])


@T.prim_func
def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None:
A_shared = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared")
Expand All @@ -21,10 +32,10 @@ def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None:
with T.block("A_shared_warp"):
v0, v1 = T.axis.remap("SS", [ax0, ax1])
T.reads(A_shared[v0, v1])
T.writes(A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2] = A_shared[
v0, v1
]

thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(v0, v1)
T.writes(A_warp[thread_id, local_id])
A_warp[thread_id, local_id] = A_shared[v0, v1]


@T.prim_func
Expand Down Expand Up @@ -74,10 +85,9 @@ def ldmatrix_b_desc(a: T.handle, c: T.handle) -> None:
with T.block("B_shared_warp"):
v0, v1 = T.axis.remap("SS", [ax0, ax1])
T.reads(B_shared[v0, v1])
T.writes(B_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
B_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2] = B_shared[
v0, v1
]
thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(v0, v1)
T.writes(B_warp[thread_id, local_id])
B_warp[thread_id, local_id] = B_shared[v0, v1]


@T.prim_func
Expand Down Expand Up @@ -126,21 +136,20 @@ def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
for i, j, k in T.grid(16, 16, 16):
with T.block("C"):
i, j, k = T.axis.remap("SSR", [i, j, k])

thread_id_C, local_id_C = shared_16x16_to_ldmatrix_32x8_layout(i, j)
thread_id_A, local_id_A = shared_16x16_to_ldmatrix_32x8_layout(i, k)
thread_id_B, local_id_B = shared_16x16_to_ldmatrix_32x8_layout(j, k)

T.reads(
C[i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 2],
A[i % 8 * 4 + k % 8 // 2, k % 16 // 8 * 4 + i % 16 // 8 * 2 + k % 2],
B[j % 8 * 4 + k % 8 // 2, k % 16 // 8 * 4 + j % 16 // 8 * 2 + k % 2],
)
T.writes(C[i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 2])
C[i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 8 % 2] = C[
i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 8 % 2
] + T.cast(
A[i % 8 * 4 + k % 8 // 2, k % 16 // 8 * 4 + i % 16 // 8 * 2 + k % 8 % 2],
"float32",
) * T.cast(
B[j % 8 * 4 + k % 8 // 2, k % 16 // 8 * 4 + j % 16 // 8 * 2 + k % 8 % 2],
"float32",
C[thread_id_C, local_id_C],
A[thread_id_A, local_id_A],
B[thread_id_B, local_id_B],
)
T.writes(C[thread_id_C, local_id_C])
C[thread_id_C, local_id_C] += T.cast(
A[thread_id_A, local_id_A], "float32"
) * T.cast(B[thread_id_B, local_id_B], "float32")


@T.prim_func
Expand Down Expand Up @@ -202,14 +211,13 @@ def mma_store_desc(a: T.handle, c: T.handle) -> None:
with T.block("root"):
T.reads(C_warp[0:32, 0:8])
T.writes(C[0:16, 0:16])
for ax1_0, i0, i1 in T.grid(2, 32, 4):
for i0, i1 in T.grid(16, 16):
with T.block("C_warp"):
v0 = T.axis.spatial(16, i1 // 2 * 8 + i0 // 4)
v1 = T.axis.spatial(16, ax1_0 * 8 + i0 % 4 * 2 + i1 % 2)

T.reads(C_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
v0, v1 = T.axis.remap("SS", [i0, i1])
thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(v0, v1)
T.reads(C_warp[thread_id, local_id])
T.writes(C[v0, v1])
C[v0, v1] = C_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2]
C[v0, v1] = C_warp[thread_id, local_id]


@T.prim_func
Expand Down Expand Up @@ -242,21 +250,13 @@ def mma_fill_desc(a: T.handle) -> None:
with T.block("root"):
T.reads()
T.writes(C_warp[0:32, 0:8])
for i0, i1 in T.grid(32, 8):
for i0, i1 in T.grid(16, 16):
with T.block("C_warp"):
i_init = T.axis.spatial(16, i1 // 4 * 8 + i0 // 4)
j_init = T.axis.spatial(16, (i0 % 4) * 4 + i1 % 4)
i_init, j_init = T.axis.remap("SS", [i0, i1])
thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(i_init, j_init)
T.reads()
T.writes(
C_warp[
i_init % 8 * 4 + j_init % 8 // 2,
j_init % 16 // 8 * 4 + i_init % 16 // 8 * 2 + j_init % 2,
]
)
C_warp[
i_init % 8 * 4 + j_init % 8 // 2,
j_init % 16 // 8 * 4 + i_init % 16 // 8 * 2 + j_init % 8 % 2,
] = T.float32(0)
T.writes(C_warp[thread_id, local_id])
C_warp[thread_id, local_id] = T.float32(0)


@T.prim_func
Expand Down Expand Up @@ -397,7 +397,8 @@ def fetch_to_shared(block, idx, ndim):
jo, ji = sch.split(jj, factors=[None, 16])
sch.reorder(io, jo, ii, ji)

block_init_c = sch.decompose_reduction(block_outer, sch.get_loops(block_outer)[3])
sch.decompose_reduction(block_outer, sch.get_loops(block_outer)[3])
block_init_c = sch.get_block("C_init")

def tile_wmma_fragment(block_read, height):
i, j = sch.get_loops(block_read)[-2:]
Expand All @@ -406,90 +407,69 @@ def tile_wmma_fragment(block_read, height):
sch.reorder(i0, j0, i1, j1)
return i1

def shared_16x16_to_ldmatrix_32x8_layout(i, j):
i_0 = i // 16
j_0 = j // 16

i = i % 16
j = j % 16

thread_id = 4 * (i % 8) + (j % 8) // 2
return i_0, j_0, thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 8) % 2

loop_a = tile_wmma_fragment(A_warp, 16)
loop_b = tile_wmma_fragment(B_warp, 16)

sch.transform_layout(A_warp, 0, "write", index_map=shared_16x16_to_ldmatrix_32x8_layout)
sch.transform_layout(B_warp, 0, "write", index_map=shared_16x16_to_ldmatrix_32x8_layout)
sch.transform_layout(C_warp, 0, "read", index_map=shared_16x16_to_ldmatrix_32x8_layout)
def index_map(i, j):
return (
i // 16,
j // 16,
*shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16),
)

sch.transform_layout(A_warp, 0, "write", index_map)
sch.transform_layout(B_warp, 0, "write", index_map)
sch.transform_layout(C_warp, 0, "read", index_map)

sch.tensorize(loop_a, "mma.ldmatrix_a")
sch.tensorize(loop_b, "mma.ldmatrix_b")

mma_loop = sch.get_loops(block_inner)[-3]
sch.tensorize(mma_loop, "mma_sync")

block_init_c = sch.get_block("C_init")
init_loop1, init_loop2 = sch.get_loops(block_init_c)[-2:]
f_0, f_1 = sch.split(init_loop1, factors=[None, 8])
f_2, f_3 = sch.split(init_loop2, factors=[None, 4])
sch.reorder(f_1, f_2, f_0, f_3)
fused_1 = sch.fuse(f_1, f_2)
fused_2 = sch.fuse(f_0, f_3)
sch.tensorize(fused_1, "mma_fill")

warp_loop1, warp_loop2 = sch.get_loops(C_warp)[-2:]
f_0, f_1 = sch.split(warp_loop1, factors=[None, 8])
outer, f_2, f_3 = sch.split(warp_loop2, factors=[2, 4, 2])
sch.reorder(outer, f_1, f_2, f_0, f_3)
fused_1 = sch.fuse(f_1, f_2)
fused_2 = sch.fuse(f_0, f_3)
sch.tensorize(outer, "mma_store")
# print(sch.mod.script())
# return
sch.tensorize(sch.get_loops(block_inner)[-3], "mma_sync")
sch.tensorize(sch.get_loops(block_init_c)[-2], "mma_fill")
sch.tensorize(sch.get_loops(C_warp)[-2], "mma_store")


ir_module = tvm.IRModule({"main": workload})
sch = tvm.tir.Schedule(ir_module)
schedule(sch)
print(sch.mod.script())

if tune:
with tempfile.TemporaryDirectory() as work_dir:
sch = ms.tune_tir(
mod=workload,
target=tvm.target.Target("nvidia/geforce-rtx-3070"),
config=ms.TuneConfig(
strategy="evolutionary",
num_trials_per_iter=32,
max_trials_per_task=128,
max_trials_global=128,
),
work_dir=work_dir,
space=ms.space_generator.ScheduleFn(schedule),
)
if sch is None:
print("No valid schedule found!")
else:
print(sch.mod.script())
print(sch.trace)


dev = tvm.device("cuda", 0)
a_np = np.random.uniform(size=(N, K)).astype("float16")
b_np = np.random.uniform(size=(K, M)).astype("float16")
c_np = np.dot(a_np.astype("float32"), b_np.astype("float32").transpose())
a = tvm.nd.array(a_np, dev)
b = tvm.nd.array(b_np, dev)
c = tvm.nd.array(np.zeros((M, N), dtype="float32"), dev)
f = tvm.build(sch.mod["main"], target="cuda", name="dense")

print(f.imported_modules[0].get_source())
f(a, b, c)
tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
print("ok")

evaluator = f.time_evaluator(f.entry_name, dev, number=1000)
gflops = (N * M * K) * 2 / 1e9
time_ms = evaluator(a, b, c).mean * 1e3
print("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms, gflops / (time_ms / 1e3)))
# if tune:
# with tempfile.TemporaryDirectory() as work_dir:
# sch = ms.tune_tir(
# mod=workload,
# target=tvm.target.Target("nvidia/geforce-rtx-3070"),
# config=ms.TuneConfig(
# strategy="evolutionary",
# num_trials_per_iter=32,
# max_trials_per_task=128,
# max_trials_global=128,
# ),
# work_dir=work_dir,
# space=ms.space_generator.ScheduleFn(schedule),
# )
# if sch is None:
# print("No valid schedule found!")
# else:
# print(sch.mod.script())
# print(sch.trace)


# dev = tvm.device("cuda", 0)
# a_np = np.random.uniform(size=(N, K)).astype("float16")
# b_np = np.random.uniform(size=(K, M)).astype("float16")
# c_np = np.dot(a_np.astype("float32"), b_np.astype("float32").transpose())
# a = tvm.nd.array(a_np, dev)
# b = tvm.nd.array(b_np, dev)
# c = tvm.nd.array(np.zeros((M, N), dtype="float32"), dev)
# f = tvm.build(sch.mod["main"], target="cuda", name="dense")

# print(f.imported_modules[0].get_source())
# f(a, b, c)
# tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
# print("ok")

# evaluator = f.time_evaluator(f.entry_name, dev, number=1000)
# gflops = (N * M * K) * 2 / 1e9
# time_ms = evaluator(a, b, c).mean * 1e3
# print("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms, gflops / (time_ms / 1e3)))

0 comments on commit 576f841

Please sign in to comment.