Skip to content

Commit

Permalink
Add ARM intrin
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Apr 7, 2022
1 parent 120fd96 commit 86bbd49
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 12 deletions.
127 changes: 127 additions & 0 deletions python/tvm/tir/tensor_intrin/arm_cpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from .. import TensorIntrin
from tvm.script import tir as T


@T.prim_func
def dot_product_4x4_i8i8i32_desc(
A: T.Buffer[(4,), "int8"], B: T.Buffer[(4, 4), "int8"], C: T.Buffer[(4,), "int32"]
) -> None:
with T.block("root"):
T.reads(C[0:4], A[0:4], B[0:4, 0:4])
T.writes(C[0:4])
for i in T.serial(0, 4):
with T.init():
C[i] = T.int32(0)
for k in T.serial(0, 4):
with T.block("update"):
vi, vk = T.axis.remap("SR", [i, k])
C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")


@T.prim_func
def dot_product_4x4_i8i8i32_neon(
A: T.Buffer[(4,), "int8"], B: T.Buffer[(4, 4), "int8"], C: T.Buffer[(4,), "int32"]
) -> None:
with T.block("root"):
T.reads(C[0:4], A[0:4], B[0:4, 0:4])
T.writes(C[0:4])

A_int8 = A.vload([0], "int8x4")
re_int32 = T.reinterpret(A_int8, dtype="int32")
vec_ai32 = T.broadcast(re_int32, 2)
vec_a = T.reinterpret(vec_ai32, dtype="int8x8")

vec_b = B.vload([0, 0], dtype="int8x8")

multiply = T.call_llvm_pure_intrin(
T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.smull.v8i16"),
T.uint32(2),
vec_a,
vec_b,
dtype="int16x8",
)

pair1 = T.call_llvm_pure_intrin(
T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.saddlp.v4i32.v8i16"),
T.uint32(1),
multiply,
dtype="int32x4",
)

vec_b_2 = B.vload([2, 0], dtype="int8x8")

multiply_2 = T.call_llvm_pure_intrin(
T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.smull.v8i16"),
T.uint32(2),
vec_a,
vec_b_2,
dtype="int16x8",
)

pair2 = T.call_llvm_pure_intrin(
T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.saddlp.v4i32.v8i16"),
T.uint32(1),
multiply_2,
dtype="int32x4",
)

C[T.ramp(T.int32(0), 1, 4)] += T.call_llvm_pure_intrin(
T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.addp.v4i32"),
T.uint32(2),
pair1,
pair2,
dtype="int32x4",
)


@T.prim_func
def dot_product_4x4_i8i8i32_sdot(
A: T.Buffer[(4,), "int8"], B: T.Buffer[(4, 4), "int8"], C: T.Buffer[(4,), "int32"]
) -> None:
with T.block("root"):
T.reads(C[0:4], A[0:4], B[0:4, 0:4])
T.writes(C[0:4])

A_i8x4 = A.vload([0], "int8x4")
A_i32 = T.reinterpret(A_i8x4, dtype="int32")
vec_ai32 = T.broadcast(A_i32, 4)
vec_a = T.reinterpret(vec_ai32, dtype="int8x16")

vec_b = B.vload([0, 0], dtype="int8x16")

C[T.ramp(T.int32(0), 1, 4)] += T.call_llvm_pure_intrin(
T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.sdot.v4i32.v16i8"),
T.uint32(3),
T.int32x4(0),
vec_a,
vec_b,
dtype="int32x4",
)


ARM_DOT_4x4_i8_NEON_INTRIN = "dot_4x4_i8i8s32_neon"
ARM_DOT_4x4_i8_SDOT_INTRIN = "dot_4x4_i8i8s32_sdot"

TensorIntrin.register(
ARM_DOT_4x4_i8_NEON_INTRIN, dot_product_4x4_i8i8i32_desc, dot_product_4x4_i8i8i32_neon
)

TensorIntrin.register(
ARM_DOT_4x4_i8_SDOT_INTRIN, dot_product_4x4_i8i8i32_desc, dot_product_4x4_i8i8i32_sdot
)
8 changes: 2 additions & 6 deletions python/tvm/tir/tensor_intrin/x86.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,9 @@ def dot_product_16x4_u8i8i32_desc(


@T.prim_func
def dot_product_16x4_u8i8i32_vnni_impl(
def dot_product_16x4_u8i8i32_vnni(
A: T.Buffer[(4,), "uint8"], B: T.Buffer[(16, 4), "int8"], C: T.Buffer[(16,), "int32"]
) -> None:
A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
C = T.match_buffer(c, (16,), "int32", offset_factor=1)

with T.block("root"):
T.reads(C[0:16], A[0:4], B[0:16, 0:4])
T.writes(C[0:16])
Expand All @@ -69,5 +65,5 @@ def dot_product_16x4_u8i8i32_vnni_impl(
VNNI_DOT_16x4_INTRIN = "dot_16x4_vnni"

TensorIntrin.register(
VNNI_DOT_16x4_INTRIN, dot_product_16x4_u8i8i32_desc, dot_product_16x4_u8i8i32_vnni_impl
VNNI_DOT_16x4_INTRIN, dot_product_16x4_u8i8i32_desc, dot_product_16x4_u8i8i32_vnni
)
38 changes: 32 additions & 6 deletions tests/python/unittest/test_tir_schedule_tensorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from tvm.script import tir as T
from tvm.tir.schedule.testing import verify_trace_roundtrip
from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN
from tvm.tir.tensor_intrin.arm_cpu import ARM_DOT_4x4_i8_NEON_INTRIN, ARM_DOT_4x4_i8_SDOT_INTRIN

# fmt: off
# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks
Expand Down Expand Up @@ -532,10 +533,9 @@ def test_tensorize_with_annotation():
verify_trace_roundtrip(sch=s, mod=func)


def test_tensorize_vnni():
n, m, k = 128, 128, 128
X = te.placeholder((m, k), name="X", dtype="uint8")
packed_W = te.placeholder((n // 16, k // 4, 16, 4), name="packedW", dtype="int8")
def get_matmul_packed(m, n, k, lhs_type, int32_lanes):
X = te.placeholder((m, k), name="X", dtype=lhs_type)
packed_W = te.placeholder((n // int32_lanes, k // 4, int32_lanes, 4), name="packedW", dtype="int8")

ak = te.reduce_axis((0, k), name="k")
matmul = te.compute(
Expand All @@ -550,7 +550,13 @@ def test_tensorize_vnni():
name="compute",
)

func = te.create_prim_func([X, packed_W, matmul])
return te.create_prim_func([X, packed_W, matmul])


def test_tensorize_vnni():
m, n, k = 128, 128, 128

func = get_matmul_packed(m, n, k, "uint8", 16)

sch = tir.Schedule(func, debug_mask="all")
block = sch.get_block("compute")
Expand All @@ -566,6 +572,26 @@ def test_tensorize_vnni():
verify_trace_roundtrip(sch=sch, mod=func)


def test_tensorize_arm_dot():
m, n, k = 128, 128, 128

func = get_matmul_packed(m, n, k, "int8", 4)

for intrin in [ARM_DOT_4x4_i8_SDOT_INTRIN, ARM_DOT_4x4_i8_NEON_INTRIN]:
sch = tir.Schedule(func, debug_mask="all")
block = sch.get_block("compute")
_, j, k = sch.get_loops(block)

_, ji = sch.split(j, factors=[None, 4])
ko, ki = sch.split(k, factors=[None, 4])
sch.reorder(ko, ji, ki)

sch.decompose_reduction(block, ko)
sch.tensorize(ji, intrin)

verify_trace_roundtrip(sch=sch, mod=func)


if __name__ == "__main__":
# sys.exit(pytest.main([__file__] + sys.argv[1:]))
test_tensorize_vnni()
test_tensorize_arm_dot()

0 comments on commit 86bbd49

Please sign in to comment.