Skip to content

Commit

Permalink
Merge branch 'tir-tensor-intrin' into auto-tensorize-vnni
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Apr 7, 2022
2 parents 71fe3bd + 82e152a commit b171748
Show file tree
Hide file tree
Showing 7 changed files with 253 additions and 75 deletions.
4 changes: 3 additions & 1 deletion python/tvm/script/tir/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ def Select(cond: PrimExpr, if_body: PrimExpr, else_body: PrimExpr) -> PrimExpr:
def if_then_else(cond: PrimExpr, t: PrimExpr, f: PrimExpr, dtype: str) -> PrimExpr: ...
def evaluate(value: PrimExpr) -> None: ...
def reinterpret(value: PrimExpr, dtype: str) -> PrimExpr: ...
def vectorlow(value: PrimExpr, dtype: str) -> PrimExpr: ...
def vectorhigh(value: PrimExpr, dtype: str) -> PrimExpr: ...
def store(
var: Var, index: PrimExpr, value: PrimExpr, predicate: Union[PrimExpr, builtins.bool] = True
) -> None: ...
Expand All @@ -143,7 +145,7 @@ def preflattened_buffer(
) -> Buffer: ...

"""
Intrinsics - tvm builtin
Intrinsics - tvm builtin
"""

def tvm_thread_allreduce(
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/script/tir/special_stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,7 +840,7 @@ def env_thread(env_name, span):
self.context.report_error(
f"VarDef expected assign to only one var, but got {names}", span
)
v = Var(names[0], span=span)
v = Var(names[0], dtype="int32", span=span)
self.context.func_var_env_dict[v] = env_name
self.context.update_symbol(v.name, v, self.node)

Expand Down
4 changes: 3 additions & 1 deletion python/tvm/tir/tensor_intrin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from . import vnni
"""Intrinsics for tensorization."""
from .x86 import *
from .arm_cpu import *
151 changes: 151 additions & 0 deletions python/tvm/tir/tensor_intrin/arm_cpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Intrinsics for ARM tensorization."""
from tvm.script import tir as T
from .. import TensorIntrin


# TODO(masahi): Parametrize the TVMScript description of dot product by
# shape and dtype, and share the common description with x86.


@T.prim_func
def dot_product_4x4_i8i8i32_desc(
A: T.Buffer((4,), "int8", offset_factor=1),
B: T.Buffer((4, 4), "int8", offset_factor=1),
C: T.Buffer((4,), "int32", offset_factor=1),
) -> None:
"""
A description for 4x4 dot product.
"""
with T.block("root"):
T.reads(C[0:4], A[0:4], B[0:4, 0:4])
T.writes(C[0:4])
for i in T.serial(0, 4):
with T.init():
C[i] = T.int32(0)
for k in T.serial(0, 4):
with T.block("update"):
vi, vk = T.axis.remap("SR", [i, k])
C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")


@T.prim_func
def dot_product_4x4_i8i8i32_neon(
A: T.Buffer((4,), "int8", offset_factor=1),
B: T.Buffer((4, 4), "int8", offset_factor=1),
C: T.Buffer((4,), "int32", offset_factor=1),
) -> None:
"""
A implementation for 4x4 dot product, applicable for any ARM CPUs supporting NEON.
"""
with T.block("root"):
T.reads(C[0:4], A[0:4], B[0:4, 0:4])
T.writes(C[0:4])

A_int8 = A.vload([0], "int8x4")
re_int32 = T.reinterpret(A_int8, dtype="int32")
vec_ai32 = T.broadcast(re_int32, 2)
vec_a = T.reinterpret(vec_ai32, dtype="int8x8")

vec_b = B.vload([0, 0], dtype="int8x16")

# TODO(masahi): Remove duplication when inlined function call is supported
vec_b_low = T.vectorlow(vec_b, dtype="int8x8")

multiply_low = T.call_llvm_pure_intrin(
T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.smull.v8i16"),
T.uint32(2),
vec_a,
vec_b_low,
dtype="int16x8",
)

pairwise_reduction_low = T.call_llvm_pure_intrin(
T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.saddlp.v4i32.v8i16"),
T.uint32(1),
multiply_low,
dtype="int32x4",
)

vec_b_high = T.vectorhigh(vec_b, dtype="int8x8")

multiply_high = T.call_llvm_pure_intrin(
T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.smull.v8i16"),
T.uint32(2),
vec_a,
vec_b_high,
dtype="int16x8",
)

pairwise_reduction_high = T.call_llvm_pure_intrin(
T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.saddlp.v4i32.v8i16"),
T.uint32(1),
multiply_high,
dtype="int32x4",
)

C[T.ramp(T.int32(0), 1, 4)] += T.call_llvm_pure_intrin(
T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.addp.v4i32"),
T.uint32(2),
pairwise_reduction_low,
pairwise_reduction_high,
dtype="int32x4",
)


@T.prim_func
def dot_product_4x4_i8i8i32_sdot(
A: T.Buffer((4,), "int8", offset_factor=1),
B: T.Buffer((4, 4), "int8", offset_factor=1),
C: T.Buffer((4,), "int32", offset_factor=1),
) -> None:
"""
A implementation for 4x4 dot product, applicable for ARM CPUs supporting sdot.
"""
with T.block("root"):
T.reads(C[0:4], A[0:4], B[0:4, 0:4])
T.writes(C[0:4])

A_i8x4 = A.vload([0], "int8x4")
A_i32 = T.reinterpret(A_i8x4, dtype="int32")
vec_ai32 = T.broadcast(A_i32, 4)
vec_a = T.reinterpret(vec_ai32, dtype="int8x16")

vec_b = B.vload([0, 0], dtype="int8x16")

C[T.ramp(T.int32(0), 1, 4)] += T.call_llvm_pure_intrin(
T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.sdot.v4i32.v16i8"),
T.uint32(3),
T.int32x4(0),
vec_a,
vec_b,
dtype="int32x4",
)


ARM_DOT_4x4_i8_NEON_INTRIN = "dot_4x4_i8i8s32_neon"
ARM_DOT_4x4_i8_SDOT_INTRIN = "dot_4x4_i8i8s32_sdot"

TensorIntrin.register(
ARM_DOT_4x4_i8_NEON_INTRIN, dot_product_4x4_i8i8i32_desc, dot_product_4x4_i8i8i32_neon
)

TensorIntrin.register(
ARM_DOT_4x4_i8_SDOT_INTRIN, dot_product_4x4_i8i8i32_desc, dot_product_4x4_i8i8i32_sdot
)
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,25 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from tvm import tir
# pylint: disable=invalid-name
"""Intrinsics for x86 tensorization."""
from tvm.script import tir as T
from .. import TensorIntrin


@T.prim_func
def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
C = T.match_buffer(c, (16,), "int32", offset_factor=1)
# Tensorized intrinsic description and VNNI-specific implementation.
# Equivalent to the ones in topi/x86/tensor_intrin.py


@T.prim_func
def dot_product_16x4_u8i8i32_desc(
A: T.Buffer((4,), "uint8", offset_factor=1),
B: T.Buffer((16, 4), "int8", offset_factor=1),
C: T.Buffer((16,), "int32", offset_factor=1),
) -> None:
"""
A description for 16x4 dot product.
"""
with T.block("root"):
T.reads(C[0:16], A[0:4], B[0:16, 0:4])
T.writes(C[0:16])
Expand All @@ -37,11 +46,14 @@ def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None:


@T.prim_func
def dot_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
C = T.match_buffer(c, (16,), "int32", offset_factor=1)

def dot_product_16x4_u8i8i32_vnni(
A: T.Buffer((4,), "uint8", offset_factor=1),
B: T.Buffer((16, 4), "int8", offset_factor=1),
C: T.Buffer((16,), "int32", offset_factor=1),
) -> None:
"""
A VNNI-specific implmementation for 16x4 dot product.
"""
with T.block("root"):
T.reads(C[0:16], A[0:4], B[0:16, 0:4])
T.writes(C[0:16])
Expand All @@ -52,9 +64,7 @@ def dot_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
B_i8x64 = B.vload([0, 0], dtype="int8x64")
B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16")

C[
T.ramp(T.int32(0), 1, 16)
] += T.call_llvm_pure_intrin( # Note: this is an update +=
C[T.ramp(T.int32(0), 1, 16)] += T.call_llvm_pure_intrin( # Note: this is an update +=
T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"),
T.uint32(0),
T.int32x16(0),
Expand All @@ -64,6 +74,8 @@ def dot_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
)


tir.TensorIntrin.register(
"dot_16x1x16_uint8_int8_int32_cascadelake", dot_product_desc, dot_product_intrin
VNNI_DOT_16x4_INTRIN = "dot_16x4_vnni"

TensorIntrin.register(
VNNI_DOT_16x4_INTRIN, dot_product_16x4_u8i8i32_desc, dot_product_16x4_u8i8i32_vnni
)
57 changes: 2 additions & 55 deletions tests/python/unittest/test_meta_schedule_tune_relay.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
from tvm.target.target import Target
from tvm.tir.schedule import BlockRV, Schedule
from tvm.tir.schedule.trace import Trace
from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN


logging.basicConfig()
logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG)
Expand Down Expand Up @@ -332,57 +334,6 @@ def get_output(data, lib):
assert np.allclose(actual_output, expected_output, rtol=1e-4, atol=2e-4)


# Tensorized intrinsic description and VNNI-specific implementation.
# Equivalent to the ones in topi/x86/tensor_intrin.py


@T.prim_func
def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
C = T.match_buffer(c, (16,), "int32", offset_factor=1)

with T.block("root"):
T.reads(C[0:16], A[0:4], B[0:16, 0:4])
T.writes(C[0:16])
for i in T.serial(0, 16):
with T.init():
C[i] = T.int32(0)
for k in T.serial(0, 4):
with T.block("update"):
vi, vk = T.axis.remap("SR", [i, k])
C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")


@T.prim_func
def dot_product_vnni(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
C = T.match_buffer(c, (16,), "int32", offset_factor=1)

with T.block("root"):
T.reads(C[0:16], A[0:4], B[0:16, 0:4])
T.writes(C[0:16])

A_u8x4 = A.vload([0], "uint8x4")
A_i32 = T.reinterpret(A_u8x4, dtype="int32")

B_i8x64 = B.vload([0, 0], dtype="int8x64")
B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16")

C[T.ramp(T.int32(0), 1, 16)] += T.call_llvm_pure_intrin( # Note: this is an update +=
T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"),
T.uint32(0),
T.int32x16(0),
T.broadcast(A_i32, 16),
B_i32x16,
dtype="int32x16",
)


VNNI_INTRIN = "dot_16x1x16_uint8_int8_int32_cascadelake"


def schedule_dense(dense_block, M, do_tune, sch):
"""
Manually schedule a dense block, created from TE compute op via CreatePrimFunc,
Expand Down Expand Up @@ -550,10 +501,6 @@ def schedule_fn(task, sch):

@pytest.mark.skip("Requires cascadelake")
def test_tune_relay_manual_tir_vnni():
# Register a pair of an intrinsic description for 16x4 dot product, and its
# VNNI-specific implementation.
tir.TensorIntrin.register(VNNI_INTRIN, dot_product_desc, dot_product_vnni)

manual_tir_common(do_tune=False)

"""
Expand Down
Loading

0 comments on commit b171748

Please sign in to comment.