Skip to content

Commit

Permalink
Fix missing reverse_compute_at call
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 17, 2022
1 parent 93f9fe7 commit 00df308
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion tests/python/unittest/test_mma_16x8x8_4k.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def lambda_b(i, j):

# fetch to C_warp 16 * 8 -> 32 * 4
C_warp = sch.cache_write(block, 0, "warp")
# sch.reverse_compute_at(C_warp, sch.get_loops(block)[0])
sch.reverse_compute_at(C_warp, sch.get_loops(block)[0])
# reverse_compute_at places C_warp under blockIdx.x

sch.transform_layout(
Expand All @@ -307,6 +307,7 @@ def lambda_b(i, j):
fused_2 = sch.fuse(f_0, f_3)
sch.bind(fused_1, "threadIdx.x")


block_init_c = sch.decompose_reduction(block, sch.get_loops(block)[1])

block_init_c = sch.get_block("C_init")
Expand Down Expand Up @@ -345,6 +346,13 @@ def lambda_b(i, j):
print(f.imported_modules[0].get_source())
tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)

print("ok")

evaluator = f.time_evaluator(f.entry_name, dev, number=100)
gflops = (N * M * K) * 2 / 1e9
time_ms = evaluator(a, b, c).mean * 1e3
print("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms, gflops / (time_ms / 1e3)))


# Script entry point: run the MMA 16x8x8 integration test when executed directly.
if __name__ == "__main__":
    test_integration_matmul()

0 comments on commit 00df308

Please sign in to comment.