Skip to content

Commit

Permalink
fixed intrin description
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Apr 12, 2022
1 parent 7666cd7 commit 7b3d71c
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions python/tvm/tir/tensor_intrin/dot_product_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,31 +24,31 @@
# Semantic description of the dp4a intrinsic: a 4-element int8 dot product
# accumulated into a single int32 value, i.e.
#     C[0] = C[0] + sum(int32(A[i]) * int32(B[i]) for i in 0..3)
# The accumulator is declared as a one-element buffer and indexed as C[0].
# The older zero-dimensional declaration (shape (), indexed as C[()]) is
# removed: keeping both would declare the parameter C twice.
def dp4a_desc(
    A: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
    B: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
    C: T.Buffer((1,), "int32", offset_factor=1, align=4, scope="local"),
) -> None:
    with T.block("root"):
        # The accumulator is both read and written; all four lanes of A and B
        # are read.
        T.reads(C[0], A[0:4], B[0:4])
        T.writes(C[0])
        for i in range(0, 4):
            with T.block("update"):
                # "R" binds the loop variable as a reduction axis.
                vi = T.axis.remap("R", [i])
                # Widen both int8 operands to int32 before multiplying.
                C[0] = C[0] + T.cast(A[vi], "int32") * T.cast(B[vi], "int32")


# Implementation of the dp4a intrinsic: loads A and B as packed int8x4
# vectors and calls the external "__dp4a" function, storing the result back
# into the one-element int32 accumulator C[0].
# The signature and the read/write regions mirror dp4a_desc (one-element
# accumulator buffer indexed as C[0]).
@T.prim_func
def dp4a_impl(
    A: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
    B: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
    C: T.Buffer((1,), "int32", offset_factor=1, align=4, scope="local"),
) -> None:
    with T.block("root"):
        T.reads(C[0], A[0:4], B[0:4])
        T.writes(C[0])

        # Each operand is loaded from its own buffer. Loading both vectors
        # from B would compute dot(B, B) instead of dot(A, B).
        A_i8x4 = A.vload([0], "int8x4")
        B_i8x4 = B.vload([0], "int8x4")

        # C[0] is passed as the third argument and the call's result is
        # written back, so the accumulation required by dp4a_desc is kept
        # (a bare T.evaluate of the call would discard the result).
        C[0] = T.call_pure_extern("__dp4a", A_i8x4, B_i8x4, C[0], dtype="int32")


# Name string for the dp4a tensor intrinsic.
# NOTE(review): presumably the key under which dp4a_desc / dp4a_impl are
# registered further down (that part of the file is collapsed in this view) —
# confirm against the registration call.
DP4A_INTRIN = "dp4a"
Expand Down

0 comments on commit 7b3d71c

Please sign in to comment.