Make the dp4a accumulator buffer C a 1-element buffer (indexed as C[0])
instead of a 0-dim buffer, and store the __dp4a result back into C[0],
passing C[0] as the accumulator operand instead of discarding the call
result. Also load A_i8x4 from A (it was loaded from B), so dp4a_impl
matches the A[vi] * B[vi] reduction described by dp4a_desc.

diff --git a/python/tvm/tir/tensor_intrin/dot_product_common.py b/python/tvm/tir/tensor_intrin/dot_product_common.py
index 9dad5bd475a26..7f9fe86ba727f 100644
--- a/python/tvm/tir/tensor_intrin/dot_product_common.py
+++ b/python/tvm/tir/tensor_intrin/dot_product_common.py
@@ -24,31 +24,31 @@
 def dp4a_desc(
     A: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
     B: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
-    C: T.Buffer((), "int32", offset_factor=1, align=4, scope="local"),
+    C: T.Buffer((1,), "int32", offset_factor=1, align=4, scope="local"),
 ) -> None:
     with T.block("root"):
-        T.reads(C[()], A[0:4], B[0:4])
-        T.writes(C[()])
+        T.reads(C[0], A[0:4], B[0:4])
+        T.writes(C[0])
         for i in range(0, 4):
             with T.block("update"):
                 vi = T.axis.remap("R", [i])
-                C[()] = C[()] + T.cast(A[vi], "int32") * T.cast(B[vi], "int32")
+                C[0] = C[0] + T.cast(A[vi], "int32") * T.cast(B[vi], "int32")
 
 
 @T.prim_func
 def dp4a_impl(
     A: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
     B: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
-    C: T.Buffer((), "int32", offset_factor=1, align=4, scope="local"),
+    C: T.Buffer((1,), "int32", offset_factor=1, align=4, scope="local"),
 ) -> None:
     with T.block("root"):
-        T.reads(C[()], A[0:4], B[0:4])
-        T.writes(C[()])
+        T.reads(C[0], A[0:4], B[0:4])
+        T.writes(C[0])
 
-        A_i8x4 = B.vload([0], "int8x4")
+        A_i8x4 = A.vload([0], "int8x4")
         B_i8x4 = B.vload([0], "int8x4")
 
-        T.evaluate(T.call_pure_extern("__dp4a", A_i8x4, B_i8x4, T.int32(0), dtype="int32"))
+        C[0] = T.call_pure_extern("__dp4a", A_i8x4, B_i8x4, C[0], dtype="int32")
 
 
 DP4A_INTRIN = "dp4a"