Skip to content

Commit

Permalink
use vectorlow/high in arm intrin
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Apr 7, 2022
1 parent 625cd27 commit d8e43ec
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 15 deletions.
4 changes: 3 additions & 1 deletion python/tvm/script/tir/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ def Select(cond: PrimExpr, if_body: PrimExpr, else_body: PrimExpr) -> PrimExpr:
def if_then_else(cond: PrimExpr, t: PrimExpr, f: PrimExpr, dtype: str) -> PrimExpr: ...
def evaluate(value: PrimExpr) -> None: ...
# Type stub for T.reinterpret: takes an expression plus a target dtype string
# and returns a PrimExpr. In the arm_cpu.py hunk of this commit it is called as
# T.reinterpret(vec_ai32, dtype="int8x8"), where vec_ai32 is T.broadcast(re_int32, 2).
def reinterpret(value: PrimExpr, dtype: str) -> PrimExpr: ...
# Type stub for T.vectorlow, added by this commit: takes an expression plus a
# result dtype string and returns a PrimExpr. The arm_cpu.py hunk calls it as
# T.vectorlow(vec_b, dtype="int8x8") on a vector loaded with dtype="int8x16".
# NOTE(review): presumably yields the lower half of the lanes of `value` --
# confirm against the TIR builtin documentation.
def vectorlow(value: PrimExpr, dtype: str) -> PrimExpr: ...
# Type stub for T.vectorhigh, added by this commit: takes an expression plus a
# result dtype string and returns a PrimExpr. The arm_cpu.py hunk calls it as
# T.vectorhigh(vec_b, dtype="int8x8") on a vector loaded with dtype="int8x16".
# NOTE(review): presumably yields the upper half of the lanes of `value` --
# confirm against the TIR builtin documentation.
def vectorhigh(value: PrimExpr, dtype: str) -> PrimExpr: ...
def store(
var: Var, index: PrimExpr, value: PrimExpr, predicate: Union[PrimExpr, builtins.bool] = True
) -> None: ...
Expand All @@ -143,7 +145,7 @@ def preflattened_buffer(
) -> Buffer: ...

"""
Intrinsics - tvm builtin
Intrinsics - tvm builtin
"""

def tvm_thread_allreduce(
Expand Down
27 changes: 15 additions & 12 deletions python/tvm/tir/tensor_intrin/arm_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,45 +51,48 @@ def dot_product_4x4_i8i8i32_neon(
vec_ai32 = T.broadcast(re_int32, 2)
vec_a = T.reinterpret(vec_ai32, dtype="int8x8")

vec_b = B.vload([0, 0], dtype="int8x8")
vec_b = B.vload([0, 0], dtype="int8x16")

# TODO(masahi): Remove duplication when inlined function call is supported
vec_b_low = T.vectorlow(vec_b, dtype="int8x8")

multiply = T.call_llvm_pure_intrin(
multiply_low = T.call_llvm_pure_intrin(
T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.smull.v8i16"),
T.uint32(2),
vec_a,
vec_b,
vec_b_low,
dtype="int16x8",
)

pair1 = T.call_llvm_pure_intrin(
pairwise_reduction_low = T.call_llvm_pure_intrin(
T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.saddlp.v4i32.v8i16"),
T.uint32(1),
multiply,
multiply_low,
dtype="int32x4",
)

vec_b_2 = B.vload([2, 0], dtype="int8x8")
vec_b_high = T.vectorhigh(vec_b, dtype="int8x8")

multiply_2 = T.call_llvm_pure_intrin(
multiply_high = T.call_llvm_pure_intrin(
T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.smull.v8i16"),
T.uint32(2),
vec_a,
vec_b_2,
vec_b_high,
dtype="int16x8",
)

pair2 = T.call_llvm_pure_intrin(
pairwise_reduction_high = T.call_llvm_pure_intrin(
T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.saddlp.v4i32.v8i16"),
T.uint32(1),
multiply_2,
multiply_high,
dtype="int32x4",
)

C[T.ramp(T.int32(0), 1, 4)] += T.call_llvm_pure_intrin(
T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.addp.v4i32"),
T.uint32(2),
pair1,
pair2,
pairwise_reduction_low,
pairwise_reduction_high,
dtype="int32x4",
)

Expand Down
3 changes: 1 addition & 2 deletions tests/python/unittest/test_tir_schedule_tensorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,5 +593,4 @@ def test_tensorize_arm_dot():


if __name__ == "__main__":
# sys.exit(pytest.main([__file__] + sys.argv[1:]))
test_tensorize_arm_dot()
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 comments on commit d8e43ec

Please sign in to comment.