diff --git a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py
index ba870a4ddd503..cbede480ebea4 100644
--- a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py
+++ b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py
@@ -68,7 +68,7 @@ def maybe_swap(i, j):
         return i, j
 
     c = te.compute(
-        (n, m),
+        (m, n),
         lambda i, j: te.sum(maybe_cast(a[i, k]) * maybe_cast(b[maybe_swap(k, j)]), axis=[k]),
         name="C",
     )
@@ -132,7 +132,8 @@ def fetch_to_shared(block, idx, ndim):
         sch.bind(f_2, "threadIdx.x")
         sch.bind(f_1, "threadIdx.y")
         sch.vectorize(f_3)
-        sch.storage_align(block_read, 0, axis=-2, factor=32, offset=8)
+        offset = 8 if in_dtype == "float16" else 16
+        sch.storage_align(block_read, 0, axis=-2, factor=32, offset=offset)
 
         return block_read
 
@@ -180,8 +181,6 @@ def tile_wmma_fragment(block_read, height, width):
     sch.tensorize(sch.get_loops(block_init_c)[-2], mma_fill_intrin)
     sch.tensorize(sch.get_loops(C_warp)[-2], mma_store_intrin)
 
-    # print(sch.mod.script())
-
     f = tvm.build(sch.mod["main"], target="cuda", name="dense")
     dev = tvm.device("cuda", 0)
 
@@ -189,27 +188,35 @@ def tile_wmma_fragment(block_read, height, width):
         a_np = np.random.uniform(size=(M, K)).astype("float16")
 
         if b_transposed:
-            b_np = np.random.uniform(size=(N, K)).astype("float16").transpose()
+            b_np = np.random.uniform(size=(N, K)).astype("float16")
+            c_np = np.dot(a_np.astype("float32"), b_np.astype("float32").transpose()).astype(
+                out_dtype
+            )
         else:
             b_np = np.random.uniform(size=(K, N)).astype("float16")
-
-        c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")).astype(out_dtype)
+            c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")).astype(out_dtype)
     else:
         a_np = np.random.randint(-128, 128, (M, K)).astype("int8")
 
         if b_transposed:
-            b_np = np.random.randint(-128, 128, (N, K)).astype("int8").transpose()
+            b_np = np.random.randint(-128, 128, (N, K)).astype("int8")
+            c_np = np.dot(a_np.astype("float32"), b_np.astype("float32").transpose()).astype(
+                "int32"
+            )
         else:
             b_np = np.random.randint(-128, 128, (K, N)).astype("int8")
-
-        c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")).astype("int32")
+            c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")).astype("int32")
 
     a = tvm.nd.array(a_np, dev)
     b = tvm.nd.array(b_np, dev)
    c = tvm.nd.array(np.zeros((M, N), dtype=out_dtype), dev)
 
     f(a, b, c)
-    tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
+
+    if out_dtype != "float16":
+        # The numpy reference is computed with fp32 precision (otherwise too slow).
+        # So there is non-trivial accuracy difference if TVM result is computed with fp16 accumulation.
+        tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
 
     return lambda: f.time_evaluator(f.entry_name, dev, number=500)(a, b, c)
 
@@ -372,7 +379,7 @@ def index_map_C(i, j):
     )
 
     if measure_perf:
-        print("f16f16f32_m16n16k16: %f GFLOPS" % (gflops / (timer().mean)))
+        print("i8i8i32_m16n16k32: %f GOPS" % (gflops / (timer().mean)))
 
     timer = run_test(
         k_inner,
@@ -393,7 +400,7 @@ def index_map_C(i, j):
     )
 
     if measure_perf:
-        print("f16f16f32_m16n16k16_trans: %f GFLOPS" % (gflops / (timer().mean)))
+        print("i8i8i32_m16n16k32_trans: %f GOPS" % (gflops / (timer().mean)))
 
 
 if __name__ == "__main__":
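
A note on the storage_align change in fetch_to_shared: the dtype-dependent offset (8 fp16 elements vs. 16 int8 elements) presumably keeps the shared-memory row padding at a constant 16 bytes for both input dtypes, the usual trick for avoiding bank conflicts ahead of the ldmatrix loads. Below is a minimal, purely illustrative sketch of that arithmetic, assuming TVM's documented storage_align semantics (the aligned stride satisfies stride % factor == offset); the helper name and the 32-element row length are hypothetical and not taken from the test.

# Illustrative sketch only: shows how the dtype-dependent offset keeps the
# padding at 16 bytes, under storage_align's rule that the aligned stride
# satisfies stride % factor == offset.
def aligned_stride(row_elems, elem_bytes, factor=32):
    offset = 8 if elem_bytes == 2 else 16  # mirrors the patch: fp16 -> 8, int8 -> 16
    stride = row_elems
    while stride % factor != offset:
        stride += 1
    return stride

# Hypothetical 32-element rows:
# fp16: stride 40 elements -> 8 extra elements = 16 bytes of padding
# int8: stride 48 elements -> 16 extra elements = 16 bytes of padding
assert aligned_stride(32, elem_bytes=2) == 40
assert aligned_stride(32, elem_bytes=1) == 48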