From 711a0076d9be2b9aa80ada67e1edda5ba1fdf1fd Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 7 Apr 2022 07:04:58 +0900 Subject: [PATCH 01/12] [TIR] Add VNNI dot product intrinsic for TIR --- python/tvm/script/tir/special_stmt.py | 9 ++-- python/tvm/tir/tensor_intrin/__init__.py | 19 +++++++ python/tvm/tir/tensor_intrin/vnni.py | 69 ++++++++++++++++++++++++ 3 files changed, 92 insertions(+), 5 deletions(-) create mode 100644 python/tvm/tir/tensor_intrin/__init__.py create mode 100644 python/tvm/tir/tensor_intrin/vnni.py diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/tir/special_stmt.py index 0148bd0b4243..708acf6aa9a2 100644 --- a/python/tvm/script/tir/special_stmt.py +++ b/python/tvm/script/tir/special_stmt.py @@ -25,10 +25,9 @@ import tvm.tir from tvm.runtime import Object, String -from tvm import te from tvm.target import Target from tvm.ir import Span -from tvm.tir import IntImm, IterVar +from tvm.tir import IntImm, IterVar, Var from .node import BufferSlice from .utils import buffer_slice_to_region @@ -800,7 +799,7 @@ def var(dtype, span): self.context.report_error( f"VarDef expected assign to only one var, but got {names}", span ) - v = te.var(names[0], dtype, span=span) + v = Var(names[0], dtype, span=span) self.context.update_symbol(v.name, v, self.node) super().__init__(var, def_symbol=True) @@ -821,7 +820,7 @@ def buffer_var(dtype, storage_scope, span): f"VarDef expected assign to only one var, but got {names}", span ) ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype), storage_scope) - v = te.var(names[0], ptr_type, span=span) + v = Var(names[0], ptr_type, span=span) self.context.update_symbol(v.name, v, self.node) super().__init__(buffer_var, def_symbol=True) @@ -841,7 +840,7 @@ def env_thread(env_name, span): self.context.report_error( f"VarDef expected assign to only one var, but got {names}", span ) - v = te.var(names[0], span=span) + v = Var(names[0], span=span) self.context.func_var_env_dict[v] = env_name self.context.update_symbol(v.name, v, self.node) diff --git a/python/tvm/tir/tensor_intrin/__init__.py b/python/tvm/tir/tensor_intrin/__init__.py new file mode 100644 index 000000000000..eff2653bee5c --- /dev/null +++ b/python/tvm/tir/tensor_intrin/__init__.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-import +"""Intrinsics for tensorization.""" +from . import vnni diff --git a/python/tvm/tir/tensor_intrin/vnni.py b/python/tvm/tir/tensor_intrin/vnni.py new file mode 100644 index 000000000000..c7cf864694d9 --- /dev/null +++ b/python/tvm/tir/tensor_intrin/vnni.py @@ -0,0 +1,69 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from .. import TensorIntrin +from tvm.script import tir as T + + +@T.prim_func +def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (4,), "uint8", offset_factor=1) + B = T.match_buffer(b, (16, 4), "int8", offset_factor=1) + C = T.match_buffer(c, (16,), "int32", offset_factor=1) + + with T.block("root"): + T.reads(C[0:16], A[0:4], B[0:16, 0:4]) + T.writes(C[0:16]) + for i in T.serial(0, 16): + with T.init(): + C[i] = T.int32(0) + for k in T.serial(0, 4): + with T.block("update"): + vi, vk = T.axis.remap("SR", [i, k]) + C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32") + + +@T.prim_func +def dot_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (4,), "uint8", offset_factor=1) + B = T.match_buffer(b, (16, 4), "int8", offset_factor=1) + C = T.match_buffer(c, (16,), "int32", offset_factor=1) + + with T.block("root"): + T.reads(C[0:16], A[0:4], B[0:16, 0:4]) + T.writes(C[0:16]) + + A_u8x4 = A.vload([0], "uint8x4") + A_i32 = T.reinterpret(A_u8x4, dtype="int32") + + B_i8x64 = B.vload([0, 0], dtype="int8x64") + B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16") + + C[ + T.ramp(T.int32(0), 1, 16) + ] += T.call_llvm_pure_intrin( # Note: this is an update += + T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"), + T.uint32(0), + T.int32x16(0), + T.broadcast(A_i32, 16), + B_i32x16, + dtype="int32x16", + ) + + +TensorIntrin.register( + "dot_16x1x16_uint8_int8_int32_cascadelake", dot_product_desc, dot_product_intrin +) From 88b763ec48c20cf68db8bc3bae3fa3ae78996ee8 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 7 Apr 2022 07:10:06 +0900 Subject: [PATCH 02/12] refactored existing test using VNNI intrin --- python/tvm/tir/tensor_intrin/vnni.py | 14 +++-- .../unittest/test_meta_schedule_tune_relay.py | 57 +------------------ 2 files changed, 10 insertions(+), 61 deletions(-) diff --git a/python/tvm/tir/tensor_intrin/vnni.py b/python/tvm/tir/tensor_intrin/vnni.py index c7cf864694d9..6f1d77ab8af0 100644 --- a/python/tvm/tir/tensor_intrin/vnni.py +++ b/python/tvm/tir/tensor_intrin/vnni.py @@ -18,6 +18,10 @@ from tvm.script import tir as T +# Tensorized intrinsic description and VNNI-specific implementation. +# Equivalent to the ones in topi/x86/tensor_intrin.py + + @T.prim_func def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (4,), "uint8", offset_factor=1) @@ -52,9 +56,7 @@ def dot_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None: B_i8x64 = B.vload([0, 0], dtype="int8x64") B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16") - C[ - T.ramp(T.int32(0), 1, 16) - ] += T.call_llvm_pure_intrin( # Note: this is an update += + C[T.ramp(T.int32(0), 1, 16)] += T.call_llvm_pure_intrin( # Note: this is an update += T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"), T.uint32(0), T.int32x16(0), @@ -64,6 +66,6 @@ def dot_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None: ) -TensorIntrin.register( - "dot_16x1x16_uint8_int8_int32_cascadelake", dot_product_desc, dot_product_intrin -) +INTRIN_NAME = "dot_16x1x16_uint8_int8_int32_cascadelake" + +TensorIntrin.register(INTRIN_NAME, dot_product_desc, dot_product_intrin) diff --git a/tests/python/unittest/test_meta_schedule_tune_relay.py b/tests/python/unittest/test_meta_schedule_tune_relay.py index 76cd82920c35..50f826378c61 100644 --- a/tests/python/unittest/test_meta_schedule_tune_relay.py +++ b/tests/python/unittest/test_meta_schedule_tune_relay.py @@ -42,6 +42,8 @@ from tvm.target.target import Target from tvm.tir.schedule import BlockRV, Schedule from tvm.tir.schedule.trace import Trace +from tvm.tir.tensor_intrin.vnni import INTRIN_NAME as VNNI_INTRIN + logging.basicConfig() logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) @@ -332,57 +334,6 @@ def get_output(data, lib): assert np.allclose(actual_output, expected_output, rtol=1e-4, atol=2e-4) -# Tensorized intrinsic description and VNNI-specific implementation. -# Equivalent to the ones in topi/x86/tensor_intrin.py - - -@T.prim_func -def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (4,), "uint8", offset_factor=1) - B = T.match_buffer(b, (16, 4), "int8", offset_factor=1) - C = T.match_buffer(c, (16,), "int32", offset_factor=1) - - with T.block("root"): - T.reads(C[0:16], A[0:4], B[0:16, 0:4]) - T.writes(C[0:16]) - for i in T.serial(0, 16): - with T.init(): - C[i] = T.int32(0) - for k in T.serial(0, 4): - with T.block("update"): - vi, vk = T.axis.remap("SR", [i, k]) - C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32") - - -@T.prim_func -def dot_product_vnni(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (4,), "uint8", offset_factor=1) - B = T.match_buffer(b, (16, 4), "int8", offset_factor=1) - C = T.match_buffer(c, (16,), "int32", offset_factor=1) - - with T.block("root"): - T.reads(C[0:16], A[0:4], B[0:16, 0:4]) - T.writes(C[0:16]) - - A_u8x4 = A.vload([0], "uint8x4") - A_i32 = T.reinterpret(A_u8x4, dtype="int32") - - B_i8x64 = B.vload([0, 0], dtype="int8x64") - B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16") - - C[T.ramp(T.int32(0), 1, 16)] += T.call_llvm_pure_intrin( # Note: this is an update += - T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"), - T.uint32(0), - T.int32x16(0), - T.broadcast(A_i32, 16), - B_i32x16, - dtype="int32x16", - ) - - -VNNI_INTRIN = "dot_16x1x16_uint8_int8_int32_cascadelake" - - def schedule_dense(dense_block, M, do_tune, sch): """ Manually schedule a dense block, created from TE compute op via CreatePrimFunc, @@ -550,10 +501,6 @@ def schedule_fn(task, sch): @pytest.mark.skip("Requires cascadelake") def test_tune_relay_manual_tir_vnni(): - # Register a pair of an intrinsic description for 16x4 dot product, and its - # VNNI-specific implementation. - tir.TensorIntrin.register(VNNI_INTRIN, dot_product_desc, dot_product_vnni) - manual_tir_common(do_tune=False) """ From 38a5aca87ec438446593a3af17760339211f5ad9 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 7 Apr 2022 07:24:44 +0900 Subject: [PATCH 03/12] add VNNI unittest --- .../unittest/test_tir_schedule_tensorize.py | 40 ++++++++++++++++++- 1 file changed, 38 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_tir_schedule_tensorize.py b/tests/python/unittest/test_tir_schedule_tensorize.py index 5cef8d63587d..11f19e934e02 100644 --- a/tests/python/unittest/test_tir_schedule_tensorize.py +++ b/tests/python/unittest/test_tir_schedule_tensorize.py @@ -19,9 +19,10 @@ import pytest import tvm import tvm.testing -from tvm import tir +from tvm import tir, te from tvm.script import tir as T from tvm.tir.schedule.testing import verify_trace_roundtrip +from tvm.tir.tensor_intrin.vnni import INTRIN_NAME as VNNI_INTRIN # fmt: off # pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks @@ -531,5 +532,40 @@ def test_tensorize_with_annotation(): verify_trace_roundtrip(sch=s, mod=func) +def test_tensorize_vnni(): + n, m, k = 128, 128, 128 + X = te.placeholder((m, k), name="X", dtype="uint8") + packed_W = te.placeholder((n // 16, k // 4, 16, 4), name="packedW", dtype="int8") + + ak = te.reduce_axis((0, k), name="k") + matmul = te.compute( + (m, n), + lambda i, j: te.sum( + X[i, ak].astype("int32") + * packed_W[ + tvm.tir.indexdiv(j, 16), tvm.tir.indexdiv(ak, 4), j % 16, ak % 4 + ].astype("int32"), + axis=ak, + ), + name="compute", + ) + + func = te.create_prim_func([X, packed_W, matmul]) + + sch = tir.Schedule(func, debug_mask="all") + block = sch.get_block("compute") + _, j, k = sch.get_loops(block) + + _, ji = sch.split(j, factors=[None, 16]) + ko, ki = sch.split(k, factors=[None, 4]) + sch.reorder(ko, ji, ki) + + sch.decompose_reduction(block, ko) + sch.tensorize(ji, VNNI_INTRIN) + + verify_trace_roundtrip(sch=sch, mod=func) + + if __name__ == "__main__": - sys.exit(pytest.main([__file__] + sys.argv[1:])) + # sys.exit(pytest.main([__file__] + sys.argv[1:])) + test_tensorize_vnni() From 0ced85fd097ed48aad8714912718d8735791e1fb Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 7 Apr 2022 08:17:43 +0900 Subject: [PATCH 04/12] rename vnni.py to x86.py --- python/tvm/tir/tensor_intrin/__init__.py | 2 +- python/tvm/tir/tensor_intrin/{vnni.py => x86.py} | 8 ++++---- tests/python/unittest/test_meta_schedule_tune_relay.py | 2 +- tests/python/unittest/test_tir_schedule_tensorize.py | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) rename python/tvm/tir/tensor_intrin/{vnni.py => x86.py} (89%) diff --git a/python/tvm/tir/tensor_intrin/__init__.py b/python/tvm/tir/tensor_intrin/__init__.py index eff2653bee5c..78089517b6cf 100644 --- a/python/tvm/tir/tensor_intrin/__init__.py +++ b/python/tvm/tir/tensor_intrin/__init__.py @@ -16,4 +16,4 @@ # under the License. # pylint: disable=unused-import """Intrinsics for tensorization.""" -from . import vnni +from . import x86 diff --git a/python/tvm/tir/tensor_intrin/vnni.py b/python/tvm/tir/tensor_intrin/x86.py similarity index 89% rename from python/tvm/tir/tensor_intrin/vnni.py rename to python/tvm/tir/tensor_intrin/x86.py index 6f1d77ab8af0..84b86ed6b202 100644 --- a/python/tvm/tir/tensor_intrin/vnni.py +++ b/python/tvm/tir/tensor_intrin/x86.py @@ -23,7 +23,7 @@ @T.prim_func -def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None: +def dot_product_16x4_desc(a: T.handle, b: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (4,), "uint8", offset_factor=1) B = T.match_buffer(b, (16, 4), "int8", offset_factor=1) C = T.match_buffer(c, (16,), "int32", offset_factor=1) @@ -41,7 +41,7 @@ def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None: @T.prim_func -def dot_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None: +def dot_product_16x4_vnni_impl(a: T.handle, b: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (4,), "uint8", offset_factor=1) B = T.match_buffer(b, (16, 4), "int8", offset_factor=1) C = T.match_buffer(c, (16,), "int32", offset_factor=1) @@ -66,6 +66,6 @@ def dot_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None: ) -INTRIN_NAME = "dot_16x1x16_uint8_int8_int32_cascadelake" +VNNI_INTRIN = "dot_16x4_vnni" -TensorIntrin.register(INTRIN_NAME, dot_product_desc, dot_product_intrin) +TensorIntrin.register(VNNI_INTRIN, dot_product_16x4_desc, dot_product_16x4_vnni_impl) diff --git a/tests/python/unittest/test_meta_schedule_tune_relay.py b/tests/python/unittest/test_meta_schedule_tune_relay.py index 50f826378c61..fa59badc5da8 100644 --- a/tests/python/unittest/test_meta_schedule_tune_relay.py +++ b/tests/python/unittest/test_meta_schedule_tune_relay.py @@ -42,7 +42,7 @@ from tvm.target.target import Target from tvm.tir.schedule import BlockRV, Schedule from tvm.tir.schedule.trace import Trace -from tvm.tir.tensor_intrin.vnni import INTRIN_NAME as VNNI_INTRIN +from tvm.tir.tensor_intrin.x86 import VNNI_INTRIN logging.basicConfig() diff --git a/tests/python/unittest/test_tir_schedule_tensorize.py b/tests/python/unittest/test_tir_schedule_tensorize.py index 11f19e934e02..548543630145 100644 --- a/tests/python/unittest/test_tir_schedule_tensorize.py +++ b/tests/python/unittest/test_tir_schedule_tensorize.py @@ -22,7 +22,7 @@ from tvm import tir, te from tvm.script import tir as T from tvm.tir.schedule.testing import verify_trace_roundtrip -from tvm.tir.tensor_intrin.vnni import INTRIN_NAME as VNNI_INTRIN +from tvm.tir.tensor_intrin.x86 import VNNI_INTRIN # fmt: off # pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks From 1351fdea6b22f231a290a6c28e06732c9cf993cf Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 7 Apr 2022 08:27:27 +0900 Subject: [PATCH 05/12] use buffer syntax sugar --- python/tvm/tir/tensor_intrin/x86.py | 18 ++++++++++-------- .../unittest/test_meta_schedule_tune_relay.py | 2 +- .../unittest/test_tir_schedule_tensorize.py | 2 +- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/python/tvm/tir/tensor_intrin/x86.py b/python/tvm/tir/tensor_intrin/x86.py index 84b86ed6b202..6fda9484df42 100644 --- a/python/tvm/tir/tensor_intrin/x86.py +++ b/python/tvm/tir/tensor_intrin/x86.py @@ -23,11 +23,9 @@ @T.prim_func -def dot_product_16x4_desc(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (4,), "uint8", offset_factor=1) - B = T.match_buffer(b, (16, 4), "int8", offset_factor=1) - C = T.match_buffer(c, (16,), "int32", offset_factor=1) - +def dot_product_16x4_u8i8i32_desc( + A: T.Buffer[(4,), "uint8"], B: T.Buffer[(16, 4), "int8"], C: T.Buffer[(16,), "int32"] +) -> None: with T.block("root"): T.reads(C[0:16], A[0:4], B[0:16, 0:4]) T.writes(C[0:16]) @@ -41,7 +39,9 @@ def dot_product_16x4_desc(a: T.handle, b: T.handle, c: T.handle) -> None: @T.prim_func -def dot_product_16x4_vnni_impl(a: T.handle, b: T.handle, c: T.handle) -> None: +def dot_product_16x4_u8i8i32_vnni_impl( + A: T.Buffer[(4,), "uint8"], B: T.Buffer[(16, 4), "int8"], C: T.Buffer[(16,), "int32"] +) -> None: A = T.match_buffer(a, (4,), "uint8", offset_factor=1) B = T.match_buffer(b, (16, 4), "int8", offset_factor=1) C = T.match_buffer(c, (16,), "int32", offset_factor=1) @@ -66,6 +66,8 @@ def dot_product_16x4_vnni_impl(a: T.handle, b: T.handle, c: T.handle) -> None: ) -VNNI_INTRIN = "dot_16x4_vnni" +VNNI_DOT_16x4_INTRIN = "dot_16x4_vnni" -TensorIntrin.register(VNNI_INTRIN, dot_product_16x4_desc, dot_product_16x4_vnni_impl) +TensorIntrin.register( + VNNI_DOT_16x4_INTRIN, dot_product_16x4_u8i8i32_desc, dot_product_16x4_u8i8i32_vnni_impl +) diff --git a/tests/python/unittest/test_meta_schedule_tune_relay.py b/tests/python/unittest/test_meta_schedule_tune_relay.py index fa59badc5da8..a9da41f7e6aa 100644 --- a/tests/python/unittest/test_meta_schedule_tune_relay.py +++ b/tests/python/unittest/test_meta_schedule_tune_relay.py @@ -42,7 +42,7 @@ from tvm.target.target import Target from tvm.tir.schedule import BlockRV, Schedule from tvm.tir.schedule.trace import Trace -from tvm.tir.tensor_intrin.x86 import VNNI_INTRIN +from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN logging.basicConfig() diff --git a/tests/python/unittest/test_tir_schedule_tensorize.py b/tests/python/unittest/test_tir_schedule_tensorize.py index 548543630145..3abdb0e93c61 100644 --- a/tests/python/unittest/test_tir_schedule_tensorize.py +++ b/tests/python/unittest/test_tir_schedule_tensorize.py @@ -22,7 +22,7 @@ from tvm import tir, te from tvm.script import tir as T from tvm.tir.schedule.testing import verify_trace_roundtrip -from tvm.tir.tensor_intrin.x86 import VNNI_INTRIN +from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN # fmt: off # pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks From 69e72b6b612588e670937e003435afa647030ceb Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 7 Apr 2022 10:12:02 +0900 Subject: [PATCH 06/12] Add ARM intrin --- python/tvm/tir/tensor_intrin/arm_cpu.py | 127 ++++++++++++++++++ python/tvm/tir/tensor_intrin/x86.py | 8 +- .../unittest/test_tir_schedule_tensorize.py | 38 +++++- 3 files changed, 161 insertions(+), 12 deletions(-) create mode 100644 python/tvm/tir/tensor_intrin/arm_cpu.py diff --git a/python/tvm/tir/tensor_intrin/arm_cpu.py b/python/tvm/tir/tensor_intrin/arm_cpu.py new file mode 100644 index 000000000000..fa28cd80c682 --- /dev/null +++ b/python/tvm/tir/tensor_intrin/arm_cpu.py @@ -0,0 +1,127 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from .. import TensorIntrin +from tvm.script import tir as T + + +@T.prim_func +def dot_product_4x4_i8i8i32_desc( + A: T.Buffer[(4,), "int8"], B: T.Buffer[(4, 4), "int8"], C: T.Buffer[(4,), "int32"] +) -> None: + with T.block("root"): + T.reads(C[0:4], A[0:4], B[0:4, 0:4]) + T.writes(C[0:4]) + for i in T.serial(0, 4): + with T.init(): + C[i] = T.int32(0) + for k in T.serial(0, 4): + with T.block("update"): + vi, vk = T.axis.remap("SR", [i, k]) + C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32") + + +@T.prim_func +def dot_product_4x4_i8i8i32_neon( + A: T.Buffer[(4,), "int8"], B: T.Buffer[(4, 4), "int8"], C: T.Buffer[(4,), "int32"] +) -> None: + with T.block("root"): + T.reads(C[0:4], A[0:4], B[0:4, 0:4]) + T.writes(C[0:4]) + + A_int8 = A.vload([0], "int8x4") + re_int32 = T.reinterpret(A_int8, dtype="int32") + vec_ai32 = T.broadcast(re_int32, 2) + vec_a = T.reinterpret(vec_ai32, dtype="int8x8") + + vec_b = B.vload([0, 0], dtype="int8x8") + + multiply = T.call_llvm_pure_intrin( + T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.smull.v8i16"), + T.uint32(2), + vec_a, + vec_b, + dtype="int16x8", + ) + + pair1 = T.call_llvm_pure_intrin( + T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.saddlp.v4i32.v8i16"), + T.uint32(1), + multiply, + dtype="int32x4", + ) + + vec_b_2 = B.vload([2, 0], dtype="int8x8") + + multiply_2 = T.call_llvm_pure_intrin( + T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.smull.v8i16"), + T.uint32(2), + vec_a, + vec_b_2, + dtype="int16x8", + ) + + pair2 = T.call_llvm_pure_intrin( + T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.saddlp.v4i32.v8i16"), + T.uint32(1), + multiply_2, + dtype="int32x4", + ) + + C[T.ramp(T.int32(0), 1, 4)] += T.call_llvm_pure_intrin( + T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.addp.v4i32"), + T.uint32(2), + pair1, + pair2, + dtype="int32x4", + ) + + +@T.prim_func +def dot_product_4x4_i8i8i32_sdot( + A: T.Buffer[(4,), "int8"], B: T.Buffer[(4, 4), "int8"], C: T.Buffer[(4,), "int32"] +) -> None: + with T.block("root"): + T.reads(C[0:4], A[0:4], B[0:4, 0:4]) + T.writes(C[0:4]) + + A_i8x4 = A.vload([0], "int8x4") + A_i32 = T.reinterpret(A_i8x4, dtype="int32") + vec_ai32 = T.broadcast(A_i32, 4) + vec_a = T.reinterpret(vec_ai32, dtype="int8x16") + + vec_b = B.vload([0, 0], dtype="int8x16") + + C[T.ramp(T.int32(0), 1, 4)] += T.call_llvm_pure_intrin( + T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.sdot.v4i32.v16i8"), + T.uint32(3), + T.int32x4(0), + vec_a, + vec_b, + dtype="int32x4", + ) + + +ARM_DOT_4x4_i8_NEON_INTRIN = "dot_4x4_i8i8s32_neon" +ARM_DOT_4x4_i8_SDOT_INTRIN = "dot_4x4_i8i8s32_sdot" + +TensorIntrin.register( + ARM_DOT_4x4_i8_NEON_INTRIN, dot_product_4x4_i8i8i32_desc, dot_product_4x4_i8i8i32_neon +) + +TensorIntrin.register( + ARM_DOT_4x4_i8_SDOT_INTRIN, dot_product_4x4_i8i8i32_desc, dot_product_4x4_i8i8i32_sdot +) diff --git a/python/tvm/tir/tensor_intrin/x86.py b/python/tvm/tir/tensor_intrin/x86.py index 6fda9484df42..1d6accd9191b 100644 --- a/python/tvm/tir/tensor_intrin/x86.py +++ b/python/tvm/tir/tensor_intrin/x86.py @@ -39,13 +39,9 @@ def dot_product_16x4_u8i8i32_desc( @T.prim_func -def dot_product_16x4_u8i8i32_vnni_impl( +def dot_product_16x4_u8i8i32_vnni( A: T.Buffer[(4,), "uint8"], B: T.Buffer[(16, 4), "int8"], C: T.Buffer[(16,), "int32"] ) -> None: - A = T.match_buffer(a, (4,), "uint8", offset_factor=1) - B = T.match_buffer(b, (16, 4), "int8", offset_factor=1) - C = T.match_buffer(c, (16,), "int32", offset_factor=1) - with T.block("root"): T.reads(C[0:16], A[0:4], B[0:16, 0:4]) T.writes(C[0:16]) @@ -69,5 +65,5 @@ def dot_product_16x4_u8i8i32_vnni_impl( VNNI_DOT_16x4_INTRIN = "dot_16x4_vnni" TensorIntrin.register( - VNNI_DOT_16x4_INTRIN, dot_product_16x4_u8i8i32_desc, dot_product_16x4_u8i8i32_vnni_impl + VNNI_DOT_16x4_INTRIN, dot_product_16x4_u8i8i32_desc, dot_product_16x4_u8i8i32_vnni ) diff --git a/tests/python/unittest/test_tir_schedule_tensorize.py b/tests/python/unittest/test_tir_schedule_tensorize.py index 3abdb0e93c61..b0a4a40b3daa 100644 --- a/tests/python/unittest/test_tir_schedule_tensorize.py +++ b/tests/python/unittest/test_tir_schedule_tensorize.py @@ -23,6 +23,7 @@ from tvm.script import tir as T from tvm.tir.schedule.testing import verify_trace_roundtrip from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN +from tvm.tir.tensor_intrin.arm_cpu import ARM_DOT_4x4_i8_NEON_INTRIN, ARM_DOT_4x4_i8_SDOT_INTRIN # fmt: off # pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks @@ -532,10 +533,9 @@ def test_tensorize_with_annotation(): verify_trace_roundtrip(sch=s, mod=func) -def test_tensorize_vnni(): - n, m, k = 128, 128, 128 - X = te.placeholder((m, k), name="X", dtype="uint8") - packed_W = te.placeholder((n // 16, k // 4, 16, 4), name="packedW", dtype="int8") +def get_matmul_packed(m, n, k, lhs_type, int32_lanes): + X = te.placeholder((m, k), name="X", dtype=lhs_type) + packed_W = te.placeholder((n // int32_lanes, k // 4, int32_lanes, 4), name="packedW", dtype="int8") ak = te.reduce_axis((0, k), name="k") matmul = te.compute( @@ -550,7 +550,13 @@ def test_tensorize_vnni(): name="compute", ) - func = te.create_prim_func([X, packed_W, matmul]) + return te.create_prim_func([X, packed_W, matmul]) + + +def test_tensorize_vnni(): + m, n, k = 128, 128, 128 + + func = get_matmul_packed(m, n, k, "uint8", 16) sch = tir.Schedule(func, debug_mask="all") block = sch.get_block("compute") @@ -566,6 +572,26 @@ def test_tensorize_vnni(): verify_trace_roundtrip(sch=sch, mod=func) +def test_tensorize_arm_dot(): + m, n, k = 128, 128, 128 + + func = get_matmul_packed(m, n, k, "int8", 4) + + for intrin in [ARM_DOT_4x4_i8_SDOT_INTRIN, ARM_DOT_4x4_i8_NEON_INTRIN]: + sch = tir.Schedule(func, debug_mask="all") + block = sch.get_block("compute") + _, j, k = sch.get_loops(block) + + _, ji = sch.split(j, factors=[None, 4]) + ko, ki = sch.split(k, factors=[None, 4]) + sch.reorder(ko, ji, ki) + + sch.decompose_reduction(block, ko) + sch.tensorize(ji, intrin) + + verify_trace_roundtrip(sch=sch, mod=func) + + if __name__ == "__main__": # sys.exit(pytest.main([__file__] + sys.argv[1:])) - test_tensorize_vnni() + test_tensorize_arm_dot() From 625cd2774ec455307646b0c26bb3971d89613d1e Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 7 Apr 2022 10:34:57 +0900 Subject: [PATCH 07/12] fixed offset factor --- python/tvm/tir/tensor_intrin/arm_cpu.py | 12 +++++++++--- python/tvm/tir/tensor_intrin/x86.py | 8 ++++++-- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/python/tvm/tir/tensor_intrin/arm_cpu.py b/python/tvm/tir/tensor_intrin/arm_cpu.py index fa28cd80c682..054b678b44e2 100644 --- a/python/tvm/tir/tensor_intrin/arm_cpu.py +++ b/python/tvm/tir/tensor_intrin/arm_cpu.py @@ -20,7 +20,9 @@ @T.prim_func def dot_product_4x4_i8i8i32_desc( - A: T.Buffer[(4,), "int8"], B: T.Buffer[(4, 4), "int8"], C: T.Buffer[(4,), "int32"] + A: T.Buffer((4,), "int8", offset_factor=1), + B: T.Buffer((4, 4), "int8", offset_factor=1), + C: T.Buffer((4,), "int32", offset_factor=1), ) -> None: with T.block("root"): T.reads(C[0:4], A[0:4], B[0:4, 0:4]) @@ -36,7 +38,9 @@ def dot_product_4x4_i8i8i32_desc( @T.prim_func def dot_product_4x4_i8i8i32_neon( - A: T.Buffer[(4,), "int8"], B: T.Buffer[(4, 4), "int8"], C: T.Buffer[(4,), "int32"] + A: T.Buffer((4,), "int8", offset_factor=1), + B: T.Buffer((4, 4), "int8", offset_factor=1), + C: T.Buffer((4,), "int32", offset_factor=1), ) -> None: with T.block("root"): T.reads(C[0:4], A[0:4], B[0:4, 0:4]) @@ -92,7 +96,9 @@ def dot_product_4x4_i8i8i32_neon( @T.prim_func def dot_product_4x4_i8i8i32_sdot( - A: T.Buffer[(4,), "int8"], B: T.Buffer[(4, 4), "int8"], C: T.Buffer[(4,), "int32"] + A: T.Buffer((4,), "int8", offset_factor=1), + B: T.Buffer((4, 4), "int8", offset_factor=1), + C: T.Buffer((4,), "int32", offset_factor=1), ) -> None: with T.block("root"): T.reads(C[0:4], A[0:4], B[0:4, 0:4]) diff --git a/python/tvm/tir/tensor_intrin/x86.py b/python/tvm/tir/tensor_intrin/x86.py index 1d6accd9191b..c0c551071c80 100644 --- a/python/tvm/tir/tensor_intrin/x86.py +++ b/python/tvm/tir/tensor_intrin/x86.py @@ -24,7 +24,9 @@ @T.prim_func def dot_product_16x4_u8i8i32_desc( - A: T.Buffer[(4,), "uint8"], B: T.Buffer[(16, 4), "int8"], C: T.Buffer[(16,), "int32"] + A: T.Buffer((4,), "uint8", offset_factor=1), + B: T.Buffer((16, 4), "int8", offset_factor=1), + C: T.Buffer((16,), "int32", offset_factor=1), ) -> None: with T.block("root"): T.reads(C[0:16], A[0:4], B[0:16, 0:4]) @@ -40,7 +42,9 @@ def dot_product_16x4_u8i8i32_desc( @T.prim_func def dot_product_16x4_u8i8i32_vnni( - A: T.Buffer[(4,), "uint8"], B: T.Buffer[(16, 4), "int8"], C: T.Buffer[(16,), "int32"] + A: T.Buffer((4,), "uint8", offset_factor=1), + B: T.Buffer((16, 4), "int8", offset_factor=1), + C: T.Buffer((16,), "int32", offset_factor=1), ) -> None: with T.block("root"): T.reads(C[0:16], A[0:4], B[0:16, 0:4]) From d8e43ecf1c0a79a2c195ff31e1e699a447a11335 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 7 Apr 2022 10:52:50 +0900 Subject: [PATCH 08/12] use vectorlow/high in arm intrin --- python/tvm/script/tir/__init__.pyi | 4 ++- python/tvm/tir/tensor_intrin/arm_cpu.py | 27 ++++++++++--------- .../unittest/test_tir_schedule_tensorize.py | 3 +-- 3 files changed, 19 insertions(+), 15 deletions(-) diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi index 1be249bc9e89..3eb383ed9974 100644 --- a/python/tvm/script/tir/__init__.pyi +++ b/python/tvm/script/tir/__init__.pyi @@ -124,6 +124,8 @@ def Select(cond: PrimExpr, if_body: PrimExpr, else_body: PrimExpr) -> PrimExpr: def if_then_else(cond: PrimExpr, t: PrimExpr, f: PrimExpr, dtype: str) -> PrimExpr: ... def evaluate(value: PrimExpr) -> None: ... def reinterpret(value: PrimExpr, dtype: str) -> PrimExpr: ... +def vectorlow(value: PrimExpr, dtype: str) -> PrimExpr: ... +def vectorhigh(value: PrimExpr, dtype: str) -> PrimExpr: ... def store( var: Var, index: PrimExpr, value: PrimExpr, predicate: Union[PrimExpr, builtins.bool] = True ) -> None: ... @@ -143,7 +145,7 @@ def preflattened_buffer( ) -> Buffer: ... """ -Intrinsics - tvm builtin +Intrinsics - tvm builtin """ def tvm_thread_allreduce( diff --git a/python/tvm/tir/tensor_intrin/arm_cpu.py b/python/tvm/tir/tensor_intrin/arm_cpu.py index 054b678b44e2..141658d68367 100644 --- a/python/tvm/tir/tensor_intrin/arm_cpu.py +++ b/python/tvm/tir/tensor_intrin/arm_cpu.py @@ -51,45 +51,48 @@ def dot_product_4x4_i8i8i32_neon( vec_ai32 = T.broadcast(re_int32, 2) vec_a = T.reinterpret(vec_ai32, dtype="int8x8") - vec_b = B.vload([0, 0], dtype="int8x8") + vec_b = B.vload([0, 0], dtype="int8x16") + + # TODO(masahi): Remove duplication when inlined function call is supported + vec_b_low = T.vectorlow(vec_b, dtype="int8x8") - multiply = T.call_llvm_pure_intrin( + multiply_low = T.call_llvm_pure_intrin( T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.smull.v8i16"), T.uint32(2), vec_a, - vec_b, + vec_b_low, dtype="int16x8", ) - pair1 = T.call_llvm_pure_intrin( + pairwise_reduction_low = T.call_llvm_pure_intrin( T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.saddlp.v4i32.v8i16"), T.uint32(1), - multiply, + multiply_low, dtype="int32x4", ) - vec_b_2 = B.vload([2, 0], dtype="int8x8") + vec_b_high = T.vectorhigh(vec_b, dtype="int8x8") - multiply_2 = T.call_llvm_pure_intrin( + multiply_high = T.call_llvm_pure_intrin( T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.smull.v8i16"), T.uint32(2), vec_a, - vec_b_2, + vec_b_high, dtype="int16x8", ) - pair2 = T.call_llvm_pure_intrin( + pairwise_reduction_high = T.call_llvm_pure_intrin( T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.saddlp.v4i32.v8i16"), T.uint32(1), - multiply_2, + multiply_high, dtype="int32x4", ) C[T.ramp(T.int32(0), 1, 4)] += T.call_llvm_pure_intrin( T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.addp.v4i32"), T.uint32(2), - pair1, - pair2, + pairwise_reduction_low, + pairwise_reduction_high, dtype="int32x4", ) diff --git a/tests/python/unittest/test_tir_schedule_tensorize.py b/tests/python/unittest/test_tir_schedule_tensorize.py index b0a4a40b3daa..315aff4c38f0 100644 --- a/tests/python/unittest/test_tir_schedule_tensorize.py +++ b/tests/python/unittest/test_tir_schedule_tensorize.py @@ -593,5 +593,4 @@ def test_tensorize_arm_dot(): if __name__ == "__main__": - # sys.exit(pytest.main([__file__] + sys.argv[1:])) - test_tensorize_arm_dot() + sys.exit(pytest.main([__file__] + sys.argv[1:])) From 9a3e508b6f4529158e703b4617f2ddaa351a89eb Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 7 Apr 2022 10:58:52 +0900 Subject: [PATCH 09/12] simplify import --- python/tvm/tir/tensor_intrin/__init__.py | 3 ++- tests/python/unittest/test_tir_schedule_tensorize.py | 9 ++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/python/tvm/tir/tensor_intrin/__init__.py b/python/tvm/tir/tensor_intrin/__init__.py index 78089517b6cf..62159851b3d4 100644 --- a/python/tvm/tir/tensor_intrin/__init__.py +++ b/python/tvm/tir/tensor_intrin/__init__.py @@ -16,4 +16,5 @@ # under the License. # pylint: disable=unused-import """Intrinsics for tensorization.""" -from . import x86 +from .x86 import * +from .arm_cpu import * diff --git a/tests/python/unittest/test_tir_schedule_tensorize.py b/tests/python/unittest/test_tir_schedule_tensorize.py index 315aff4c38f0..482d6f3db574 100644 --- a/tests/python/unittest/test_tir_schedule_tensorize.py +++ b/tests/python/unittest/test_tir_schedule_tensorize.py @@ -22,8 +22,11 @@ from tvm import tir, te from tvm.script import tir as T from tvm.tir.schedule.testing import verify_trace_roundtrip -from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN -from tvm.tir.tensor_intrin.arm_cpu import ARM_DOT_4x4_i8_NEON_INTRIN, ARM_DOT_4x4_i8_SDOT_INTRIN +from tvm.tir.tensor_intrin import ( + VNNI_DOT_16x4_INTRIN, + ARM_DOT_4x4_i8_NEON_INTRIN, + ARM_DOT_4x4_i8_SDOT_INTRIN, +) # fmt: off # pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks @@ -567,7 +570,7 @@ def test_tensorize_vnni(): sch.reorder(ko, ji, ki) sch.decompose_reduction(block, ko) - sch.tensorize(ji, VNNI_INTRIN) + sch.tensorize(ji, VNNI_DOT_16x4_INTRIN) verify_trace_roundtrip(sch=sch, mod=func) From 7a757fe53758e06418ea1367b348b47c8cd2dcf9 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 7 Apr 2022 11:12:54 +0900 Subject: [PATCH 10/12] pylint --- python/tvm/tir/tensor_intrin/arm_cpu.py | 14 ++++++++++++++ python/tvm/tir/tensor_intrin/x86.py | 10 +++++++++- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/python/tvm/tir/tensor_intrin/arm_cpu.py b/python/tvm/tir/tensor_intrin/arm_cpu.py index 141658d68367..6aebc5be3417 100644 --- a/python/tvm/tir/tensor_intrin/arm_cpu.py +++ b/python/tvm/tir/tensor_intrin/arm_cpu.py @@ -14,16 +14,24 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=invalid-name +"""Intrinsics for x86 tensorization.""" from .. import TensorIntrin from tvm.script import tir as T +# TODO(masahi): Parametrize the TVMScript description of dot product by +# shape and dtype, and share the common description with x86. + @T.prim_func def dot_product_4x4_i8i8i32_desc( A: T.Buffer((4,), "int8", offset_factor=1), B: T.Buffer((4, 4), "int8", offset_factor=1), C: T.Buffer((4,), "int32", offset_factor=1), ) -> None: + """ + A description for 4x4 dot product. + """ with T.block("root"): T.reads(C[0:4], A[0:4], B[0:4, 0:4]) T.writes(C[0:4]) @@ -42,6 +50,9 @@ def dot_product_4x4_i8i8i32_neon( B: T.Buffer((4, 4), "int8", offset_factor=1), C: T.Buffer((4,), "int32", offset_factor=1), ) -> None: + """ + A implementation for 4x4 dot product, applicable for any ARM CPUs supporting NEON. + """ with T.block("root"): T.reads(C[0:4], A[0:4], B[0:4, 0:4]) T.writes(C[0:4]) @@ -103,6 +114,9 @@ def dot_product_4x4_i8i8i32_sdot( B: T.Buffer((4, 4), "int8", offset_factor=1), C: T.Buffer((4,), "int32", offset_factor=1), ) -> None: + """ + A implementation for 4x4 dot product, applicable for ARM CPUs supporting sdot. + """ with T.block("root"): T.reads(C[0:4], A[0:4], B[0:4, 0:4]) T.writes(C[0:4]) diff --git a/python/tvm/tir/tensor_intrin/x86.py b/python/tvm/tir/tensor_intrin/x86.py index c0c551071c80..cd06490633c2 100644 --- a/python/tvm/tir/tensor_intrin/x86.py +++ b/python/tvm/tir/tensor_intrin/x86.py @@ -14,8 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from .. import TensorIntrin +# pylint: disable=invalid-name +"""Intrinsics for x86 tensorization.""" from tvm.script import tir as T +from .. import TensorIntrin # Tensorized intrinsic description and VNNI-specific implementation. @@ -28,6 +30,9 @@ def dot_product_16x4_u8i8i32_desc( B: T.Buffer((16, 4), "int8", offset_factor=1), C: T.Buffer((16,), "int32", offset_factor=1), ) -> None: + """ + A description for 16x4 dot product. + """ with T.block("root"): T.reads(C[0:16], A[0:4], B[0:16, 0:4]) T.writes(C[0:16]) @@ -46,6 +51,9 @@ def dot_product_16x4_u8i8i32_vnni( B: T.Buffer((16, 4), "int8", offset_factor=1), C: T.Buffer((16,), "int32", offset_factor=1), ) -> None: + """ + A VNNI-specific implmementation for 16x4 dot product. + """ with T.block("root"): T.reads(C[0:16], A[0:4], B[0:16, 0:4]) T.writes(C[0:16]) From 15e60b42362cc64b1428b219c8eada414d1b8372 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 7 Apr 2022 11:16:08 +0900 Subject: [PATCH 11/12] black --- python/tvm/tir/tensor_intrin/arm_cpu.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tvm/tir/tensor_intrin/arm_cpu.py b/python/tvm/tir/tensor_intrin/arm_cpu.py index 6aebc5be3417..a360bf609b18 100644 --- a/python/tvm/tir/tensor_intrin/arm_cpu.py +++ b/python/tvm/tir/tensor_intrin/arm_cpu.py @@ -23,6 +23,7 @@ # TODO(masahi): Parametrize the TVMScript description of dot product by # shape and dtype, and share the common description with x86. + @T.prim_func def dot_product_4x4_i8i8i32_desc( A: T.Buffer((4,), "int8", offset_factor=1), From 07bbb38f7fb52db4a2ecde3d5c87cf4d5cd000a1 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 7 Apr 2022 11:24:56 +0900 Subject: [PATCH 12/12] more lint fix --- python/tvm/script/tir/special_stmt.py | 2 +- python/tvm/tir/tensor_intrin/arm_cpu.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/tir/special_stmt.py index 708acf6aa9a2..3d0fb407ef3f 100644 --- a/python/tvm/script/tir/special_stmt.py +++ b/python/tvm/script/tir/special_stmt.py @@ -840,7 +840,7 @@ def env_thread(env_name, span): self.context.report_error( f"VarDef expected assign to only one var, but got {names}", span ) - v = Var(names[0], span=span) + v = Var(names[0], dtype="int32", span=span) self.context.func_var_env_dict[v] = env_name self.context.update_symbol(v.name, v, self.node) diff --git a/python/tvm/tir/tensor_intrin/arm_cpu.py b/python/tvm/tir/tensor_intrin/arm_cpu.py index a360bf609b18..df658edf323c 100644 --- a/python/tvm/tir/tensor_intrin/arm_cpu.py +++ b/python/tvm/tir/tensor_intrin/arm_cpu.py @@ -15,9 +15,9 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=invalid-name -"""Intrinsics for x86 tensorization.""" -from .. import TensorIntrin +"""Intrinsics for ARM tensorization.""" from tvm.script import tir as T +from .. import TensorIntrin # TODO(masahi): Parametrize the TVMScript description of dot product by