From 6cc80094adac398762924b0b31a4c741417ba9dc Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Thu, 7 Apr 2022 07:10:06 +0900
Subject: [PATCH] refactored existing test using VNNI intrin

---
 python/tvm/tir/tensor_intrin/vnni.py          | 14 +++--
 .../unittest/test_meta_schedule_tune_relay.py | 57 +------------------
 2 files changed, 10 insertions(+), 61 deletions(-)

diff --git a/python/tvm/tir/tensor_intrin/vnni.py b/python/tvm/tir/tensor_intrin/vnni.py
index c7cf864694d9..6f1d77ab8af0 100644
--- a/python/tvm/tir/tensor_intrin/vnni.py
+++ b/python/tvm/tir/tensor_intrin/vnni.py
@@ -18,6 +18,10 @@ from tvm.script import tir as T
 
+# Tensorized intrinsic description and VNNI-specific implementation.
+# Equivalent to the ones in topi/x86/tensor_intrin.py
+
+
 @T.prim_func
 def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
     A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
     B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
     C = T.match_buffer(c, (16,), "int32", offset_factor=1)
@@ -52,9 +56,7 @@ def dot_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
         B_i8x64 = B.vload([0, 0], dtype="int8x64")
         B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16")
 
-        C[
-            T.ramp(T.int32(0), 1, 16)
-        ] += T.call_llvm_pure_intrin(  # Note: this is an update +=
+        C[T.ramp(T.int32(0), 1, 16)] += T.call_llvm_pure_intrin(  # Note: this is an update +=
             T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"),
             T.uint32(0),
             T.int32x16(0),
@@ -64,6 +66,6 @@ def dot_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
         )
 
 
-TensorIntrin.register(
-    "dot_16x1x16_uint8_int8_int32_cascadelake", dot_product_desc, dot_product_intrin
-)
+INTRIN_NAME = "dot_16x1x16_uint8_int8_int32_cascadelake"
+
+TensorIntrin.register(INTRIN_NAME, dot_product_desc, dot_product_intrin)
diff --git a/tests/python/unittest/test_meta_schedule_tune_relay.py b/tests/python/unittest/test_meta_schedule_tune_relay.py
index 76cd82920c35..50f826378c61 100644
--- a/tests/python/unittest/test_meta_schedule_tune_relay.py
+++ b/tests/python/unittest/test_meta_schedule_tune_relay.py
@@ -42,6 +42,8 @@ from tvm.target.target import Target
 from tvm.tir.schedule import BlockRV, Schedule
 from tvm.tir.schedule.trace import Trace
+from tvm.tir.tensor_intrin.vnni import INTRIN_NAME as VNNI_INTRIN
+
 
 logging.basicConfig()
 logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG)
 
@@ -332,57 +334,6 @@ def get_output(data, lib):
     assert np.allclose(actual_output, expected_output, rtol=1e-4, atol=2e-4)
 
 
-# Tensorized intrinsic description and VNNI-specific implementation.
-# Equivalent to the ones in topi/x86/tensor_intrin.py
-
-
-@T.prim_func
-def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
-    B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
-    C = T.match_buffer(c, (16,), "int32", offset_factor=1)
-
-    with T.block("root"):
-        T.reads(C[0:16], A[0:4], B[0:16, 0:4])
-        T.writes(C[0:16])
-        for i in T.serial(0, 16):
-            with T.init():
-                C[i] = T.int32(0)
-            for k in T.serial(0, 4):
-                with T.block("update"):
-                    vi, vk = T.axis.remap("SR", [i, k])
-                    C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")
-
-
-@T.prim_func
-def dot_product_vnni(a: T.handle, b: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
-    B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
-    C = T.match_buffer(c, (16,), "int32", offset_factor=1)
-
-    with T.block("root"):
-        T.reads(C[0:16], A[0:4], B[0:16, 0:4])
-        T.writes(C[0:16])
-
-        A_u8x4 = A.vload([0], "uint8x4")
-        A_i32 = T.reinterpret(A_u8x4, dtype="int32")
-
-        B_i8x64 = B.vload([0, 0], dtype="int8x64")
-        B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16")
-
-        C[T.ramp(T.int32(0), 1, 16)] += T.call_llvm_pure_intrin(  # Note: this is an update +=
-            T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"),
-            T.uint32(0),
-            T.int32x16(0),
-            T.broadcast(A_i32, 16),
-            B_i32x16,
-            dtype="int32x16",
-        )
-
-
-VNNI_INTRIN = "dot_16x1x16_uint8_int8_int32_cascadelake"
-
-
 def schedule_dense(dense_block, M, do_tune, sch):
     """
     Manually schedule a dense block, created from TE compute op via CreatePrimFunc,
@@ -550,10 +501,6 @@ def schedule_fn(task, sch):
 
 @pytest.mark.skip("Requires cascadelake")
 def test_tune_relay_manual_tir_vnni():
-    # Register a pair of an intrinsic description for 16x4 dot product, and its
-    # VNNI-specific implementation.
-    tir.TensorIntrin.register(VNNI_INTRIN, dot_product_desc, dot_product_vnni)
-
     manual_tir_common(do_tune=False)
 
     """
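
--
Reviewer note: moving the registration into python/tvm/tir/tensor_intrin/vnni.py
means that importing the module registers the intrinsic as a side effect, which
is why the explicit tir.TensorIntrin.register(...) call is dropped from the
test, and exporting INTRIN_NAME avoids duplicating the magic string. For
context, here is a minimal sketch of how a schedule consumes the exported name
(the block and loop handles below are hypothetical; the real scheduling lives
in schedule_dense in this test file):

    from tvm import tir

    # Importing the module registers the intrinsic as a side effect.
    from tvm.tir.tensor_intrin.vnni import INTRIN_NAME as VNNI_INTRIN

    def tensorize_dense(sch: tir.Schedule) -> None:
        # Assumes the dense block has already been tiled so that its two
        # innermost loops form the 16x4 dot product matched by
        # dot_product_desc (16 lanes of i, 4-wide reduction over k).
        block = sch.get_block("compute")  # hypothetical block name
        outer = sch.get_loops(block)[-2]  # the 16-iteration i loop
        sch.tensorize(outer, VNNI_INTRIN)

tensorize pattern-matches the loop nest against dot_product_desc and swaps in
the VNNI implementation, so downstream schedules only ever need the registered
name.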