
Commit

all tests working
masahi committed May 17, 2022
1 parent 5e086cf commit 328d0aa
Showing 1 changed file with 20 additions and 13 deletions.
33 changes: 20 additions & 13 deletions tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py
@@ -68,7 +68,7 @@ def maybe_swap(i, j):
         return i, j

     c = te.compute(
-        (n, m),
+        (m, n),
         lambda i, j: te.sum(maybe_cast(a[i, k]) * maybe_cast(b[maybe_swap(k, j)]), axis=[k]),
         name="C",
     )
@@ -132,7 +132,8 @@ def fetch_to_shared(block, idx, ndim):
         sch.bind(f_2, "threadIdx.x")
         sch.bind(f_1, "threadIdx.y")
         sch.vectorize(f_3)
-        sch.storage_align(block_read, 0, axis=-2, factor=32, offset=8)
+        offset = 8 if in_dtype == "float16" else 16
+        sch.storage_align(block_read, 0, axis=-2, factor=32, offset=offset)

         return block_read
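
The hunk above makes the storage_align offset depend on the input dtype. One plausible reading, which is an assumption on my part and not stated in the commit, is that the per-row padding of the shared-memory buffer is meant to stay a constant 16 bytes (one 128-bit vector) regardless of element size. A minimal sketch of that arithmetic, outside the diff:

    # Sketch only: assumes the goal is a fixed 16-byte pad per 32-element-aligned row.
    DTYPE_BYTES = {"float16": 2, "int8": 1}

    def pad_bytes(in_dtype: str) -> int:
        offset = 8 if in_dtype == "float16" else 16  # element offset passed to storage_align
        return offset * DTYPE_BYTES[in_dtype]        # physical padding is the same for both dtypes

    assert pad_bytes("float16") == pad_bytes("int8") == 16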

@@ -180,36 +181,42 @@ def tile_wmma_fragment(block_read, height, width):
     sch.tensorize(sch.get_loops(block_init_c)[-2], mma_fill_intrin)
     sch.tensorize(sch.get_loops(C_warp)[-2], mma_store_intrin)

     # print(sch.mod.script())

     f = tvm.build(sch.mod["main"], target="cuda", name="dense")
     dev = tvm.device("cuda", 0)

     if in_dtype == "float16":
         a_np = np.random.uniform(size=(M, K)).astype("float16")

         if b_transposed:
-            b_np = np.random.uniform(size=(N, K)).astype("float16").transpose()
+            b_np = np.random.uniform(size=(N, K)).astype("float16")
+            c_np = np.dot(a_np.astype("float32"), b_np.astype("float32").transpose()).astype(
+                out_dtype
+            )
         else:
             b_np = np.random.uniform(size=(K, N)).astype("float16")

-        c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")).astype(out_dtype)
+            c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")).astype(out_dtype)
     else:
         a_np = np.random.randint(-128, 128, (M, K)).astype("int8")

         if b_transposed:
-            b_np = np.random.randint(-128, 128, (N, K)).astype("int8").transpose()
+            b_np = np.random.randint(-128, 128, (N, K)).astype("int8")
+            c_np = np.dot(a_np.astype("float32"), b_np.astype("float32").transpose()).astype(
+                "int32"
+            )
         else:
             b_np = np.random.randint(-128, 128, (K, N)).astype("int8")

-        c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")).astype("int32")
+            c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")).astype("int32")

     a = tvm.nd.array(a_np, dev)
     b = tvm.nd.array(b_np, dev)
     c = tvm.nd.array(np.zeros((M, N), dtype=out_dtype), dev)

     f(a, b, c)
-    tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)

+    if out_dtype != "float16":
+        # The numpy reference is computed with fp32 precision (otherwise too slow).
+        # So there is non-trivial accuracy difference if TVM result is computed with fp16 accumulation.
+        tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)

     return lambda: f.time_evaluator(f.entry_name, dev, number=500)(a, b, c)
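
The new comment in the hunk above says the fp32-accumulated NumPy reference cannot be expected to match an fp16-accumulated TVM result within rtol=1e-3, so the check is skipped when out_dtype is "float16". A standalone sketch of that effect, not part of the diff; the reduction length, seed, and data are illustrative assumptions, not the test's actual workload:

    import numpy as np

    rng = np.random.default_rng(0)
    a = rng.uniform(size=4096).astype("float16")
    b = rng.uniform(size=4096).astype("float16")

    acc_fp16 = np.float16(0.0)
    for x, y in zip(a, b):
        acc_fp16 = np.float16(acc_fp16 + np.float16(x) * np.float16(y))  # fp16 accumulation

    ref_fp32 = float(np.dot(a.astype("float32"), b.astype("float32")))   # fp32 reference
    rel_err = abs(float(acc_fp16) - ref_fp32) / ref_fp32
    print(rel_err)  # typically well above the 1e-3 tolerance used by assert_allclose above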

@@ -372,7 +379,7 @@ def index_map_C(i, j):
     )

     if measure_perf:
-        print("f16f16f32_m16n16k16: %f GFLOPS" % (gflops / (timer().mean)))
+        print("i8i8i32_m16n16k32: %f GOPS" % (gflops / (timer().mean)))

     timer = run_test(
         k_inner,
@@ -393,7 +400,7 @@ def index_map_C(i, j):
     )

     if measure_perf:
-        print("f16f16f32_m16n16k16_trans: %f GFLOPS" % (gflops / (timer().mean)))
+        print("i8i8i32_m16n16k32_trans: %f GOPS" % (gflops / (timer().mean)))


 if __name__ == "__main__":
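
Taken together, the main behavioural change of this commit is that B is now handed to the kernel in its stored layout, and only the NumPy reference transposes it when b_transposed is True, with the reference computed per branch. A standalone sketch of that logic for the float16 path; M, N, K and out_dtype are illustrative assumptions, not the test's actual problem sizes:

    import numpy as np

    M, N, K = 64, 64, 64
    out_dtype = "float32"
    a_np = np.random.uniform(size=(M, K)).astype("float16")

    for b_transposed in (True, False):
        if b_transposed:
            b_np = np.random.uniform(size=(N, K)).astype("float16")  # stored as (N, K)
            c_np = np.dot(a_np.astype("float32"), b_np.astype("float32").transpose()).astype(out_dtype)
        else:
            b_np = np.random.uniform(size=(K, N)).astype("float16")  # stored as (K, N)
            c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")).astype(out_dtype)
        assert c_np.shape == (M, N)  # reference keeps the (m, n) output shape fixed in the first hunk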

