Skip to content

Commit

Permalink
add DP4A intrin
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Apr 12, 2022
1 parent 7086bdb commit 7666cd7
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 0 deletions.
1 change: 1 addition & 0 deletions python/tvm/tir/tensor_intrin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@
"""Intrinsics for tensorization."""
from .x86 import *
from .arm_cpu import *
from .dot_product_common import *
56 changes: 56 additions & 0 deletions python/tvm/tir/tensor_intrin/dot_product_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,missing-function-docstring
"""Dot product related intrinsics."""
from tvm.script import tir as T
from .. import TensorIntrin


@T.prim_func
def dp4a_desc(
A: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
B: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
C: T.Buffer((), "int32", offset_factor=1, align=4, scope="local"),
) -> None:
with T.block("root"):
T.reads(C[()], A[0:4], B[0:4])
T.writes(C[()])
for i in range(0, 4):
with T.block("update"):
vi = T.axis.remap("R", [i])
C[()] = C[()] + T.cast(A[vi], "int32") * T.cast(B[vi], "int32")


@T.prim_func
def dp4a_impl(
A: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
B: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
C: T.Buffer((), "int32", offset_factor=1, align=4, scope="local"),
) -> None:
with T.block("root"):
T.reads(C[()], A[0:4], B[0:4])
T.writes(C[()])

A_i8x4 = B.vload([0], "int8x4")
B_i8x4 = B.vload([0], "int8x4")

T.evaluate(T.call_pure_extern("__dp4a", A_i8x4, B_i8x4, T.int32(0), dtype="int32"))


DP4A_INTRIN = "dp4a"

TensorIntrin.register(DP4A_INTRIN, dp4a_desc, dp4a_impl)

0 comments on commit 7666cd7

Please sign in to comment.