diff --git a/python/tvm/tir/tensor_intrin/arm_cpu.py b/python/tvm/tir/tensor_intrin/arm_cpu.py
new file mode 100644
index 000000000000..fa28cd80c682
--- /dev/null
+++ b/python/tvm/tir/tensor_intrin/arm_cpu.py
@@ -0,0 +1,127 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from .. import TensorIntrin
+from tvm.script import tir as T
+
+
+@T.prim_func
+def dot_product_4x4_i8i8i32_desc(
+    A: T.Buffer[(4,), "int8"], B: T.Buffer[(4, 4), "int8"], C: T.Buffer[(4,), "int32"]
+) -> None:
+    with T.block("root"):
+        T.reads(C[0:4], A[0:4], B[0:4, 0:4])
+        T.writes(C[0:4])
+        for i in T.serial(0, 4):
+            with T.init():
+                C[i] = T.int32(0)
+            for k in T.serial(0, 4):
+                with T.block("update"):
+                    vi, vk = T.axis.remap("SR", [i, k])
+                    C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")
+
+
+@T.prim_func
+def dot_product_4x4_i8i8i32_neon(
+    A: T.Buffer[(4,), "int8"], B: T.Buffer[(4, 4), "int8"], C: T.Buffer[(4,), "int32"]
+) -> None:
+    with T.block("root"):
+        T.reads(C[0:4], A[0:4], B[0:4, 0:4])
+        T.writes(C[0:4])
+
+        A_int8 = A.vload([0], "int8x4")
+        re_int32 = T.reinterpret(A_int8, dtype="int32")
+        vec_ai32 = T.broadcast(re_int32, 2)
+        vec_a = T.reinterpret(vec_ai32, dtype="int8x8")
+
+        vec_b = B.vload([0, 0], dtype="int8x8")
+
+        multiply = T.call_llvm_pure_intrin(
+            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.smull.v8i16"),
+            T.uint32(2),
+            vec_a,
+            vec_b,
+            dtype="int16x8",
+        )
+
+        pair1 = T.call_llvm_pure_intrin(
+            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.saddlp.v4i32.v8i16"),
+            T.uint32(1),
+            multiply,
+            dtype="int32x4",
+        )
+
+        vec_b_2 = B.vload([2, 0], dtype="int8x8")
+
+        multiply_2 = T.call_llvm_pure_intrin(
+            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.smull.v8i16"),
+            T.uint32(2),
+            vec_a,
+            vec_b_2,
+            dtype="int16x8",
+        )
+
+        pair2 = T.call_llvm_pure_intrin(
+            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.saddlp.v4i32.v8i16"),
+            T.uint32(1),
+            multiply_2,
+            dtype="int32x4",
+        )
+
+        C[T.ramp(T.int32(0), 1, 4)] += T.call_llvm_pure_intrin(
+            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.addp.v4i32"),
+            T.uint32(2),
+            pair1,
+            pair2,
+            dtype="int32x4",
+        )
+
+
+@T.prim_func
+def dot_product_4x4_i8i8i32_sdot(
+    A: T.Buffer[(4,), "int8"], B: T.Buffer[(4, 4), "int8"], C: T.Buffer[(4,), "int32"]
+) -> None:
+    with T.block("root"):
+        T.reads(C[0:4], A[0:4], B[0:4, 0:4])
+        T.writes(C[0:4])
+
+        A_i8x4 = A.vload([0], "int8x4")
+        A_i32 = T.reinterpret(A_i8x4, dtype="int32")
+        vec_ai32 = T.broadcast(A_i32, 4)
+        vec_a = T.reinterpret(vec_ai32, dtype="int8x16")
+
+        vec_b = B.vload([0, 0], dtype="int8x16")
+
+        C[T.ramp(T.int32(0), 1, 4)] += T.call_llvm_pure_intrin(
+            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.sdot.v4i32.v16i8"),
+            T.uint32(3),
+            T.int32x4(0),
+            vec_a,
+            vec_b,
+            dtype="int32x4",
+        )
+
+
+ARM_DOT_4x4_i8_NEON_INTRIN = "dot_4x4_i8i8s32_neon"
+ARM_DOT_4x4_i8_SDOT_INTRIN = "dot_4x4_i8i8s32_sdot"
+
+TensorIntrin.register(
+    ARM_DOT_4x4_i8_NEON_INTRIN, dot_product_4x4_i8i8i32_desc, dot_product_4x4_i8i8i32_neon
+)
+
+TensorIntrin.register(
+    ARM_DOT_4x4_i8_SDOT_INTRIN, dot_product_4x4_i8i8i32_desc, dot_product_4x4_i8i8i32_sdot
+)
diff --git a/python/tvm/tir/tensor_intrin/x86.py b/python/tvm/tir/tensor_intrin/x86.py
index 6fda9484df42..1d6accd9191b 100644
--- a/python/tvm/tir/tensor_intrin/x86.py
+++ b/python/tvm/tir/tensor_intrin/x86.py
@@ -39,13 +39,9 @@ def dot_product_16x4_u8i8i32_desc(
 
 
 @T.prim_func
-def dot_product_16x4_u8i8i32_vnni_impl(
+def dot_product_16x4_u8i8i32_vnni(
     A: T.Buffer[(4,), "uint8"], B: T.Buffer[(16, 4), "int8"], C: T.Buffer[(16,), "int32"]
 ) -> None:
-    A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
-    B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
-    C = T.match_buffer(c, (16,), "int32", offset_factor=1)
-
     with T.block("root"):
         T.reads(C[0:16], A[0:4], B[0:16, 0:4])
         T.writes(C[0:16])
@@ -69,5 +65,5 @@ def dot_product_16x4_u8i8i32_vnni_impl(
 VNNI_DOT_16x4_INTRIN = "dot_16x4_vnni"
 
 TensorIntrin.register(
-    VNNI_DOT_16x4_INTRIN, dot_product_16x4_u8i8i32_desc, dot_product_16x4_u8i8i32_vnni_impl
+    VNNI_DOT_16x4_INTRIN, dot_product_16x4_u8i8i32_desc, dot_product_16x4_u8i8i32_vnni
 )
diff --git a/tests/python/unittest/test_tir_schedule_tensorize.py b/tests/python/unittest/test_tir_schedule_tensorize.py
index 3abdb0e93c61..b0a4a40b3daa 100644
--- a/tests/python/unittest/test_tir_schedule_tensorize.py
+++ b/tests/python/unittest/test_tir_schedule_tensorize.py
@@ -23,6 +23,7 @@
 from tvm.script import tir as T
 from tvm.tir.schedule.testing import verify_trace_roundtrip
 from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN
+from tvm.tir.tensor_intrin.arm_cpu import ARM_DOT_4x4_i8_NEON_INTRIN, ARM_DOT_4x4_i8_SDOT_INTRIN
 
 # fmt: off
 # pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks
@@ -532,10 +533,9 @@ def test_tensorize_with_annotation():
     verify_trace_roundtrip(sch=s, mod=func)
 
 
-def test_tensorize_vnni():
-    n, m, k = 128, 128, 128
-    X = te.placeholder((m, k), name="X", dtype="uint8")
-    packed_W = te.placeholder((n // 16, k // 4, 16, 4), name="packedW", dtype="int8")
+def get_matmul_packed(m, n, k, lhs_type, int32_lanes):
+    X = te.placeholder((m, k), name="X", dtype=lhs_type)
+    packed_W = te.placeholder((n // int32_lanes, k // 4, int32_lanes, 4), name="packedW", dtype="int8")
 
     ak = te.reduce_axis((0, k), name="k")
     matmul = te.compute(
@@ -550,7 +550,13 @@ def test_tensorize_vnni():
         name="compute",
     )
 
-    func = te.create_prim_func([X, packed_W, matmul])
+    return te.create_prim_func([X, packed_W, matmul])
+
+
+def test_tensorize_vnni():
+    m, n, k = 128, 128, 128
+
+    func = get_matmul_packed(m, n, k, "uint8", 16)
 
     sch = tir.Schedule(func, debug_mask="all")
     block = sch.get_block("compute")
@@ -566,6 +572,26 @@ def test_tensorize_vnni():
     verify_trace_roundtrip(sch=sch, mod=func)
 
 
+def test_tensorize_arm_dot():
+    m, n, k = 128, 128, 128
+
+    func = get_matmul_packed(m, n, k, "int8", 4)
+
+    for intrin in [ARM_DOT_4x4_i8_SDOT_INTRIN, ARM_DOT_4x4_i8_NEON_INTRIN]:
+        sch = tir.Schedule(func, debug_mask="all")
+        block = sch.get_block("compute")
+        _, j, k = sch.get_loops(block)
+
+        _, ji = sch.split(j, factors=[None, 4])
+        ko, ki = sch.split(k, factors=[None, 4])
+        sch.reorder(ko, ji, ki)
+
+        sch.decompose_reduction(block, ko)
+        sch.tensorize(ji, intrin)
+
+        verify_trace_roundtrip(sch=sch, mod=func)
+
+
 if __name__ == "__main__":
     # sys.exit(pytest.main([__file__] + sys.argv[1:]))
-    test_tensorize_vnni()
+    test_tensorize_arm_dot()