Skip to content

Commit

Permalink
refactored existing test using VNNI intrin
Browse files — browse the repository at this point in the history
  • Loading branch information
masahi committed Apr 7, 2022
1 parent 711a007 commit 88b763e
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 61 deletions.
14 changes: 8 additions & 6 deletions python/tvm/tir/tensor_intrin/vnni.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
from tvm.script import tir as T


# Tensorized intrinsic description and VNNI-specific implementation.
# Equivalent to the ones in topi/x86/tensor_intrin.py


@T.prim_func
def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
Expand Down Expand Up @@ -52,9 +56,7 @@ def dot_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
B_i8x64 = B.vload([0, 0], dtype="int8x64")
B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16")

C[
T.ramp(T.int32(0), 1, 16)
] += T.call_llvm_pure_intrin( # Note: this is an update +=
C[T.ramp(T.int32(0), 1, 16)] += T.call_llvm_pure_intrin( # Note: this is an update +=
T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"),
T.uint32(0),
T.int32x16(0),
Expand All @@ -64,6 +66,6 @@ def dot_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
)


TensorIntrin.register(
"dot_16x1x16_uint8_int8_int32_cascadelake", dot_product_desc, dot_product_intrin
)
INTRIN_NAME = "dot_16x1x16_uint8_int8_int32_cascadelake"

TensorIntrin.register(INTRIN_NAME, dot_product_desc, dot_product_intrin)
57 changes: 2 additions & 55 deletions tests/python/unittest/test_meta_schedule_tune_relay.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
from tvm.target.target import Target
from tvm.tir.schedule import BlockRV, Schedule
from tvm.tir.schedule.trace import Trace
from tvm.tir.tensor_intrin.vnni import INTRIN_NAME as VNNI_INTRIN


logging.basicConfig()
logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG)
Expand Down Expand Up @@ -332,57 +334,6 @@ def get_output(data, lib):
assert np.allclose(actual_output, expected_output, rtol=1e-4, atol=2e-4)


# Tensorized intrinsic description and VNNI-specific implementation.
# Equivalent to the ones in topi/x86/tensor_intrin.py


@T.prim_func
def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
C = T.match_buffer(c, (16,), "int32", offset_factor=1)

with T.block("root"):
T.reads(C[0:16], A[0:4], B[0:16, 0:4])
T.writes(C[0:16])
for i in T.serial(0, 16):
with T.init():
C[i] = T.int32(0)
for k in T.serial(0, 4):
with T.block("update"):
vi, vk = T.axis.remap("SR", [i, k])
C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")


@T.prim_func
def dot_product_vnni(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
C = T.match_buffer(c, (16,), "int32", offset_factor=1)

with T.block("root"):
T.reads(C[0:16], A[0:4], B[0:16, 0:4])
T.writes(C[0:16])

A_u8x4 = A.vload([0], "uint8x4")
A_i32 = T.reinterpret(A_u8x4, dtype="int32")

B_i8x64 = B.vload([0, 0], dtype="int8x64")
B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16")

C[T.ramp(T.int32(0), 1, 16)] += T.call_llvm_pure_intrin( # Note: this is an update +=
T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"),
T.uint32(0),
T.int32x16(0),
T.broadcast(A_i32, 16),
B_i32x16,
dtype="int32x16",
)


VNNI_INTRIN = "dot_16x1x16_uint8_int8_int32_cascadelake"


def schedule_dense(dense_block, M, do_tune, sch):
"""
Manually schedule a dense block, created from TE compute op via CreatePrimFunc,
Expand Down Expand Up @@ -550,10 +501,6 @@ def schedule_fn(task, sch):

@pytest.mark.skip("Requires cascadelake")
def test_tune_relay_manual_tir_vnni():
# Register a pair of an intrinsic description for 16x4 dot product, and its
# VNNI-specific implementation.
tir.TensorIntrin.register(VNNI_INTRIN, dot_product_desc, dot_product_vnni)

manual_tir_common(do_tune=False)

"""
Expand Down

0 comments on commit 88b763e

Please sign in to comment.