Skip to content

Commit

Permalink
update test
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 16, 2022
1 parent ad85036 commit 8f67fc8
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 6 deletions.
2 changes: 1 addition & 1 deletion python/tvm/script/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,7 @@ def transform_Assign(self, node):
body = self.parse_body(node)

for var, value in zip(node.lhs, out):
self.context.remove_symbol(var.name)
self.context.remove_symbol(var.id.name)

return body

Expand Down
38 changes: 33 additions & 5 deletions tests/python/unittest/test_tvmscript_syntax_sugar.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,13 +265,13 @@ def constant_binds_wrapped():
assert_structural_equal(constant_binds, constant_binds_wrapped)


def test():
def test_func_call():
def shared_16x16_to_ldmatrix_32x8_layout(i, j):
thread_id = 4 * (i % 8) + (j % 8) // 2
return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 2)
thread_id = (i % 8) * 4 + (j % 8) // 2
return thread_id, (j // 8) * 4 + (i // 8) * 2 + (j % 2)

@T.prim_func
def mma_sync_m16n16k16desc(a: T.handle, b: T.handle, c: T.handle) -> None:
def mma_sync_m16n16k16_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
B = T.match_buffer(b, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
C = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
Expand All @@ -293,7 +293,35 @@ def mma_sync_m16n16k16desc(a: T.handle, b: T.handle, c: T.handle) -> None:
)
T.writes(C[thread_id_C, local_id_C])

C[thread_id_C, local_id_C] += A[thread_id_A, local_id_A] * B[thread_id_B, local_id_B]
C[thread_id_C, local_id_C] += (
A[thread_id_A, local_id_A] * B[thread_id_B, local_id_B]
)

@T.prim_func
def mma_sync_m16n16k16_desc_manual(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
B = T.match_buffer(b, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
C = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp")

with T.block("root"):
T.reads(C[0:32, 0:8], A[0:32, 0:8], B[0:32, 0:8])
T.writes(C[0:32, 0:8])
for i, j, k in T.grid(16, 16, 16):
with T.block("C"):
i, j, k = T.axis.remap("SSR", [i, j, k])
T.reads(
C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2],
A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2],
B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2],
)
T.writes(C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2])
C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2] = (
C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2]
+ A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2]
* B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2]
)

assert_structural_equal(mma_sync_m16n16k16_desc, mma_sync_m16n16k16_desc_manual)


if __name__ == "__main__":
Expand Down

0 comments on commit 8f67fc8

Please sign in to comment.