Skip to content

Commit

Permalink
fixed offset factor
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Apr 7, 2022
1 parent 86bbd49 commit 995cc8d
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 5 deletions.
12 changes: 9 additions & 3 deletions python/tvm/tir/tensor_intrin/arm_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@

@T.prim_func
def dot_product_4x4_i8i8i32_desc(
A: T.Buffer[(4,), "int8"], B: T.Buffer[(4, 4), "int8"], C: T.Buffer[(4,), "int32"]
A: T.Buffer((4,), "int8", offset_factor=1),
B: T.Buffer((4, 4), "int8", offset_factor=1),
C: T.Buffer((4,), "int32", offset_factor=1),
) -> None:
with T.block("root"):
T.reads(C[0:4], A[0:4], B[0:4, 0:4])
Expand All @@ -36,7 +38,9 @@ def dot_product_4x4_i8i8i32_desc(

@T.prim_func
def dot_product_4x4_i8i8i32_neon(
A: T.Buffer[(4,), "int8"], B: T.Buffer[(4, 4), "int8"], C: T.Buffer[(4,), "int32"]
A: T.Buffer((4,), "int8", offset_factor=1),
B: T.Buffer((4, 4), "int8", offset_factor=1),
C: T.Buffer((4,), "int32", offset_factor=1),
) -> None:
with T.block("root"):
T.reads(C[0:4], A[0:4], B[0:4, 0:4])
Expand Down Expand Up @@ -92,7 +96,9 @@ def dot_product_4x4_i8i8i32_neon(

@T.prim_func
def dot_product_4x4_i8i8i32_sdot(
A: T.Buffer[(4,), "int8"], B: T.Buffer[(4, 4), "int8"], C: T.Buffer[(4,), "int32"]
A: T.Buffer((4,), "int8", offset_factor=1),
B: T.Buffer((4, 4), "int8", offset_factor=1),
C: T.Buffer((4,), "int32", offset_factor=1),
) -> None:
with T.block("root"):
T.reads(C[0:4], A[0:4], B[0:4, 0:4])
Expand Down
8 changes: 6 additions & 2 deletions python/tvm/tir/tensor_intrin/x86.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@

@T.prim_func
def dot_product_16x4_u8i8i32_desc(
A: T.Buffer[(4,), "uint8"], B: T.Buffer[(16, 4), "int8"], C: T.Buffer[(16,), "int32"]
A: T.Buffer((4,), "uint8", offset_factor=1),
B: T.Buffer((16, 4), "int8", offset_factor=1),
C: T.Buffer((16,), "int32", offset_factor=1),
) -> None:
with T.block("root"):
T.reads(C[0:16], A[0:4], B[0:16, 0:4])
Expand All @@ -40,7 +42,9 @@ def dot_product_16x4_u8i8i32_desc(

@T.prim_func
def dot_product_16x4_u8i8i32_vnni(
A: T.Buffer[(4,), "uint8"], B: T.Buffer[(16, 4), "int8"], C: T.Buffer[(16,), "int32"]
A: T.Buffer((4,), "uint8", offset_factor=1),
B: T.Buffer((16, 4), "int8", offset_factor=1),
C: T.Buffer((16,), "int32", offset_factor=1),
) -> None:
with T.block("root"):
T.reads(C[0:16], A[0:4], B[0:16, 0:4])
Expand Down

0 comments on commit 995cc8d

Please sign in to comment.