Skip to content

Commit

Permalink
use buffer syntax sugar
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Apr 6, 2022
1 parent 0f0682d commit 120fd96
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 10 deletions.
18 changes: 10 additions & 8 deletions python/tvm/tir/tensor_intrin/x86.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,9 @@


@T.prim_func
def dot_product_16x4_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
C = T.match_buffer(c, (16,), "int32", offset_factor=1)

def dot_product_16x4_u8i8i32_desc(
A: T.Buffer[(4,), "uint8"], B: T.Buffer[(16, 4), "int8"], C: T.Buffer[(16,), "int32"]
) -> None:
with T.block("root"):
T.reads(C[0:16], A[0:4], B[0:16, 0:4])
T.writes(C[0:16])
Expand All @@ -41,7 +39,9 @@ def dot_product_16x4_desc(a: T.handle, b: T.handle, c: T.handle) -> None:


@T.prim_func
def dot_product_16x4_vnni_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
def dot_product_16x4_u8i8i32_vnni_impl(
A: T.Buffer[(4,), "uint8"], B: T.Buffer[(16, 4), "int8"], C: T.Buffer[(16,), "int32"]
) -> None:
A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
C = T.match_buffer(c, (16,), "int32", offset_factor=1)
Expand All @@ -66,6 +66,8 @@ def dot_product_16x4_vnni_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
)


VNNI_INTRIN = "dot_16x4_vnni"
VNNI_DOT_16x4_INTRIN = "dot_16x4_vnni"

TensorIntrin.register(VNNI_INTRIN, dot_product_16x4_desc, dot_product_16x4_vnni_impl)
TensorIntrin.register(
VNNI_DOT_16x4_INTRIN, dot_product_16x4_u8i8i32_desc, dot_product_16x4_u8i8i32_vnni_impl
)
2 changes: 1 addition & 1 deletion tests/python/unittest/test_meta_schedule_tune_relay.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from tvm.target.target import Target
from tvm.tir.schedule import BlockRV, Schedule
from tvm.tir.schedule.trace import Trace
from tvm.tir.tensor_intrin.x86 import VNNI_INTRIN
from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN


logging.basicConfig()
Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_tir_schedule_tensorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from tvm import tir, te
from tvm.script import tir as T
from tvm.tir.schedule.testing import verify_trace_roundtrip
from tvm.tir.tensor_intrin.x86 import VNNI_INTRIN
from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN

# fmt: off
# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks
Expand Down

0 comments on commit 120fd96

Please sign in to comment.