diff --git a/python/tvm/tir/tensor_intrin/dot_product_common.py b/python/tvm/tir/tensor_intrin/dot_product_common.py index 7f9fe86ba727f..c531b80380e3c 100644 --- a/python/tvm/tir/tensor_intrin/dot_product_common.py +++ b/python/tvm/tir/tensor_intrin/dot_product_common.py @@ -45,10 +45,9 @@ def dp4a_impl( T.reads(C[0], A[0:4], B[0:4]) T.writes(C[0]) - A_i8x4 = B.vload([0], "int8x4") - B_i8x4 = B.vload([0], "int8x4") - - C[0] = T.call_pure_extern("__dp4a", A_i8x4, B_i8x4, C[0], dtype="int32") + C[0] += T.call_pure_extern( + "__dp4a", A.vload([0], "int8x4"), B.vload([0], "int8x4"), T.int32(0), dtype="int32" + ) DP4A_INTRIN = "dp4a"