int8 tensorize working
masahi committed May 17, 2022
1 parent 20321fa commit f70ccd0
Showing 1 changed file with 50 additions and 53 deletions.
103 changes: 50 additions & 53 deletions tests/python/unittest/test_mma_16x8x32_int8.py
@@ -23,10 +23,8 @@ def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None:
with T.block("A_shared_warp"):
v0, v1 = T.axis.remap("SS", [ax0, ax1])
T.reads(A_shared[v0, v1])
-T.writes(A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
-A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2] = A_shared[
-    v0, v1
-]
+T.writes(A_warp[v0 % 8 * 4 + v1 % 16 // 4, v1 // 16 * 8 + v0 // 8 * 4 + v1 % 4])
+A_warp[v0 % 8 * 4 + v1 % 16 // 4, v1 // 16 * 8 + v0 // 8 * 4 + v1 % 4] = A_shared[v0, v1]


@T.prim_func
@@ -65,21 +63,19 @@ def ldmatrix_a_impl(a: T.handle, c: T.handle) -> None:

@T.prim_func
def ldmatrix_b_desc(a: T.handle, c: T.handle) -> None:
-B_shared = T.match_buffer(a, (16, 32), "int8", align=128, offset_factor=16, scope="shared")
+B_shared = T.match_buffer(a, (32, 16), "int8", align=128, offset_factor=16, scope="shared")
B_warp = T.match_buffer(c, (32, 16), "int8", align=128, offset_factor=16, scope="warp")

with T.block("root"):
-T.reads(B_shared[0:16, 0:32])
+T.reads(B_shared[0:32, 0:16])
T.writes(B_warp[0:32, 0:16])

-for ax0, ax1 in T.grid(16, 32):
+for ax0, ax1 in T.grid(32, 16):
with T.block("B_shared_warp"):
v0, v1 = T.axis.remap("SS", [ax0, ax1])
T.reads(B_shared[v0, v1])
-T.writes(B_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
-B_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2] = B_shared[
-    v0, v1
-]
+T.writes(B_warp[v1 % 8 * 4 + v0 % 4, v1 // 8 * 8 + v0 // 16 * 4 + v0 % 4])
+B_warp[v1 % 8 * 4 + v0 % 4, v1 // 8 * 8 + v0 // 16 * 4 + v0 % 4] = B_shared[v0, v1]


@T.prim_func
@@ -88,7 +84,7 @@ def ldmatrix_b_impl(a: T.handle, c: T.handle) -> None:
s0 = T.var("int32")
B_shared = T.match_buffer(
a,
-    (16, 32),
+    (32, 16),
"int8",
align=128,
offset_factor=16,
@@ -97,7 +93,7 @@ def ldmatrix_b_impl(a: T.handle, c: T.handle) -> None:
)
B_warp = T.match_buffer(c, (32, 16), "int8", align=128, offset_factor=16, scope="warp")
with T.block("root"):
-T.reads(B_shared[0:16, 0:32])
+T.reads(B_shared[0:32, 0:16])
T.writes(B_warp[0:32, 0:16])
tx = T.env_thread("threadIdx.x")
T.launch_thread(tx, 32)
@@ -110,7 +106,7 @@ def ldmatrix_b_impl(a: T.handle, c: T.handle) -> None:
B_warp.data,
16 * tx,
B_shared.data,
-    32 * (tx % 16) + 16 * (tx // 16),
+    16 * tx,
dtype="int8",
)
)
@@ -125,22 +121,12 @@ def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
with T.block("root"):
T.reads(C[0:32, 0:8], A[0:32, 0:16], B[0:32, 0:16])
T.writes(C[0:32, 0:8])
-for i, j, k in T.grid(32, 8, 16):
+for i, j, k in T.grid(16, 16, 32):
with T.block("C"):
i, j, k = T.axis.remap("SSR", [i, j, k])
-T.reads(
-    C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2],
-    A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2],
-    B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2],
-)
+T.reads(C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2], A[i % 8 * 4 + k % 16 // 4, k // 16 * 8 + i // 8 * 4 + k % 4], B[j % 8 * 4 + k % 4, j // 8 * 8 + k // 16 * 4 + k % 4])
T.writes(C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2])
-C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2] = C[
-    i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2
-] + T.cast(
-    A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2], "int32"
-) * T.cast(
-    B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2], "int32"
-)
+C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2] = C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2] + T.cast(A[i % 8 * 4 + k % 16 // 4, k // 16 * 8 + i // 8 * 4 + k % 4], "int32") * T.cast(B[j % 8 * 4 + k % 4, j // 8 * 8 + k // 16 * 4 + k % 4], "int32")


@T.prim_func
@@ -271,7 +257,9 @@ def mma_fill_impl(a: T.handle) -> None:
tir.TensorIntrin.register("mma_fill", mma_fill_desc, mma_fill_impl)


-M = N = K = 16
+M = 16
+N = 16
+K = 32

def matmul_int8(n, m, k):
a = te.placeholder((n, k), name="A", dtype="int8")
@@ -300,13 +288,12 @@ def f_compute(i, j):

def fetch_to_shared(block, idx):
block_read = sch.cache_read(block, idx, "shared")
-if use_gpu:
-    sch.compute_at(block_read, i1, True)
-warp_size = 32
-loops = sch.get_loops(block_read)
-fused = sch.fuse(*loops[-2:])
-f_0, f_1 = sch.split(fused, factors=[None, warp_size])
-sch.bind(f_1, "threadIdx.x")
+sch.compute_at(block_read, i1, True)
+warp_size = 32
+loops = sch.get_loops(block_read)
+fused = sch.fuse(*loops[-2:])
+f_0, f_1 = sch.split(fused, factors=[None, warp_size])
+sch.bind(f_1, "threadIdx.x")

return block_read

@@ -320,18 +307,28 @@ def shared_16x16_to_ldmatrix_32x8_layout(i, j):
return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 8) % 2


+def shared_16x32_to_ldmatrix_32x16_layout(i, j):
+    thread_id = 4 * (i % 8) + (j % 16) // 4
+    return thread_id, 8 * (j // 16) + (i // 8) * 4 + j % 4
+
+
+def shared_32x16_to_ldmatrix_32x16_layout(i, j):
+    thread_id = (i % 4) + 4 * (j % 8)
+    return thread_id, 8 * (j // 8) + (i // 16) * 4 + i % 4
+
+
block = sch.get_block("C")

A_warp = sch.cache_read(block, 0, "warp")

-# sch.transform_layout(A_warp, 0, "write", index_map=shared_16x16_to_ldmatrix_32x8_layout)
+sch.transform_layout(A_warp, 0, "write", index_map=shared_16x32_to_ldmatrix_32x16_layout)

B_warp = sch.cache_read(block, 1, "warp")

-# sch.transform_layout(B_warp, 0, "write", index_map=shared_16x16_to_ldmatrix_32x8_layout)
+sch.transform_layout(B_warp, 0, "write", index_map=shared_32x16_to_ldmatrix_32x16_layout)

-# sch.tensorize(sch.get_loops(A_warp)[1], "mma.ldmatrix_a")
-# sch.tensorize(sch.get_loops(B_warp)[1], "mma.ldmatrix_b")
+sch.tensorize(sch.get_loops(A_warp)[1], "mma.ldmatrix_a")
+sch.tensorize(sch.get_loops(B_warp)[1], "mma.ldmatrix_b")

C_warp = sch.cache_write(block, 0, "warp")
sch.reverse_compute_at(C_warp, sch.get_loops(block)[0])
@@ -344,7 +341,7 @@ def shared_16x16_to_ldmatrix_32x8_layout(i, j):
fused_1 = sch.fuse(f_1, f_2)
fused_2 = sch.fuse(f_0, f_3)

-# sch.tensorize(outer, "mma_store")
+sch.tensorize(outer, "mma_store")

block_init_c = sch.decompose_reduction(block, sch.get_loops(block)[1])

@@ -356,25 +353,25 @@ def shared_16x16_to_ldmatrix_32x8_layout(i, j):
fused_2 = sch.fuse(f_0, f_3)
sch.tensorize(fused_1, "mma_fill")

-# sch.tensorize(sch.get_loops(block)[1], "mma.mma_sync")
+sch.tensorize(sch.get_loops(block)[1], "mma.mma_sync")

print(sch.mod.script())

# lowered = tvm.lower(sch.mod["main"])

# target = "cuda"
target = "cuda"

-# f = tvm.build(sch.mod["main"], target=target, name="dense")
-# dev = tvm.device(target, 0)
+f = tvm.build(sch.mod["main"], target=target, name="dense")
+dev = tvm.device(target, 0)

-# a_np = np.random.randint(-128, 128, (M, K)).astype("int8")
-# b_np = np.random.randint(-128, 128, (K, N)).astype("int8")
-# c_np = np.dot(a_np.astype("int3232"), b_np.astype("in32"))
+a_np = np.random.randint(-128, 128, (M, K)).astype("int8")
+b_np = np.random.randint(-128, 128, (K, N)).astype("int8")
+c_np = np.dot(a_np.astype("int32"), b_np.astype("int32"))

-# a = tvm.nd.array(a_np, dev)
-# b = tvm.nd.array(b_np, dev)
-# c = tvm.nd.array(np.zeros((16, K), dtype="float32"), dev)
+a = tvm.nd.array(a_np, dev)
+b = tvm.nd.array(b_np, dev)
+c = tvm.nd.array(np.zeros((M, N), dtype="int32"), dev)

-# # print(f.imported_modules[0].get_source())
-# f(a, b, c)
-# np.testing.assert_equal(c.numpy(), c_np)
+# print(f.imported_modules[0].get_source())
+f(a, b, c)
+np.testing.assert_equal(c.numpy(), c_np)

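As a quick standalone sanity check (not part of the commit, plain Python only), the new A-operand index map added in this diff can be enumerated to confirm that it sends the 16x32 int8 shared-memory tile one-to-one onto the 32 threads x 16 lanes of the warp-scope buffer, so no two tile elements collide after the layout transform:

# Standalone sketch: verify the A-operand index map from the diff covers
# the 32x16 warp buffer exactly once. Requires no TVM.
def shared_16x32_to_ldmatrix_32x16_layout(i, j):
    # Copied from the diff: (row, col) of the 16x32 tile -> (thread_id, lane).
    thread_id = 4 * (i % 8) + (j % 16) // 4
    return thread_id, 8 * (j // 16) + (i // 8) * 4 + j % 4


seen = set()
for i in range(16):      # rows (m) of the A tile
    for j in range(32):  # columns (k) of the A tile
        tid, lane = shared_16x32_to_ldmatrix_32x16_layout(i, j)
        assert 0 <= tid < 32 and 0 <= lane < 16
        seen.add((tid, lane))

# Every (thread, lane) slot of the 32x16 warp buffer is hit exactly once.
assert len(seen) == 32 * 16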