[ROCM] Support dp4a on AMDGPU by sdot4 intrinsic
commit 0225f2b
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Thu Apr 14 08:56:10 2022 +0900

    share op strategy between cuda and rocm

commit 762c7e8
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Thu Apr 14 08:28:34 2022 +0900

    fixed rocm batch_matmul strategy for mixed i8i8i32

commit ce53e8d
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Thu Apr 14 06:17:30 2022 +0900

    add rocm sdot4 TIR intrin

commit f4562b9
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Thu Apr 14 06:03:44 2022 +0900

    rocm sdot4 works

commit 6cc6280
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Thu Apr 14 05:32:07 2022 +0900

    more wip

commit 0602f4a
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Thu Apr 14 03:47:37 2022 +0900

    Squashed commit of the following:

    commit 65b8bcf
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Wed Apr 13 20:36:49 2022 +0900

        [WIP] adding DP4A support to rocm

    commit 4f8f308
    Author: Masahiro Masuda <masahi129@gmail.com>
    Date:   Wed Apr 13 14:03:25 2022 +0900

        Squashed commit of the following:

        commit 1711be3
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Wed Apr 13 13:11:40 2022 +0900

            fixed condition for real

        commit 8a48fb5
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Wed Apr 13 09:57:42 2022 +0900

            Revert "Skip applying sch_rule when both ann and sch_rule are defined"

            This reverts commit 4915c6a.

        commit daea033
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Mon Apr 11 09:31:05 2022 +0900

            [Metaschedule] Support rocm and spirv

        commit eb0cae2
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Wed Apr 13 07:25:04 2022 +0900

            dp4a works

        commit 4915c6a
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Wed Apr 13 06:13:45 2022 +0900

            Skip applying sch_rule when both ann and sch_rule are defined

        commit 7b3d71c
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Wed Apr 13 04:40:31 2022 +0900

            fixed intrin description

        commit 7666cd7
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Tue Apr 12 19:59:47 2022 +0900

            add DP4A intrin

        commit 7086bdb
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Tue Apr 12 19:03:44 2022 +0900

            works

        commit db34397
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Tue Apr 12 12:49:52 2022 +0900

            more hack to tensorize loop mapping to make resnet50 e2e work

        commit 2409674
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Mon Apr 11 13:40:59 2022 +0900

            wip support pad + qnn.conv2d folding

        commit 613cb7e
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Sun Apr 10 12:04:08 2022 +0900

            hack to tensorize loop mapping to make conv2d work

        commit 9e4f9df
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Sun Apr 10 11:34:13 2022 +0900

            wrap tensorize with try/catch

        commit d4b496d
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Sun Apr 10 11:33:39 2022 +0900

            revert change in task_scheduler.cc

        commit 476129b
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Sat Apr 9 05:54:10 2022 +0900

            try / catch in ThreadedApply

        commit d8226ff
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Fri Apr 8 17:17:59 2022 +0900

            filter out invalid candidate

        commit 2632899
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Fri Apr 8 10:09:48 2022 +0900

            try graceful exit in parallel_for_dynamic

        commit 9d6741c
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Fri Apr 8 09:35:51 2022 +0900

            [QNN] Fix broadcast for invalid axis

        commit 6ccde09
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 20:51:15 2022 +0900

            refactor rewrite_tensorize

        commit 2ce2066
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 20:48:17 2022 +0900

            allow missing schedule_rule in post order apply

        commit 3a69353
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 19:42:48 2022 +0900

            refactor rewrite_tensorize

        commit 43e0b2f
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 18:25:14 2022 +0900

            rewrite_vnni -> rewrite_tensorize

        commit 823797e
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 18:12:12 2022 +0900

            VNNI -> WithIntrin

        commit 4284a47
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 17:45:41 2022 +0900

            introduce TileForIntrin

        commit b87ef32
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 17:34:04 2022 +0900

            move TilingwithTensorIntrin to auto_tensorize.cc

        commit 2fc118b
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 17:28:45 2022 +0900

            clean up headers

        commit d8b2aa3
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 17:09:32 2022 +0900

            clean up using namespace

        commit eb05d25
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 17:03:05 2022 +0900

            refactored init

        commit 5e6b0a0
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 16:57:14 2022 +0900

            compiled

        commit 2b8c430
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 12:51:55 2022 +0900

            wip MultiLevelTiling refactor

        commit 7c21a9f
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 11:58:33 2022 +0900

            function doc string not supported by tvmscript

        commit 40f9742
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 11:56:45 2022 +0900

            update vnni intrin name

        commit 4814f82
        Merge: e0c5eb8 07bbb38
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 11:44:47 2022 +0900

            Merge branch 'tir-tensor-intrin' into auto-tensorize-vnni

        commit 07bbb38
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 11:24:56 2022 +0900

            more lint fix

        commit 15e60b4
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 11:16:08 2022 +0900

            black

        commit 7a757fe
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 11:12:54 2022 +0900

            pylint

        commit 9a3e508
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 10:58:52 2022 +0900

            simplify import

        commit d8e43ec
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 10:52:50 2022 +0900

            use vectorlow/high in arm intrin

        commit 625cd27
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 10:34:57 2022 +0900

            fixed offset factor

        commit 69e72b6
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 10:12:02 2022 +0900

            Add ARM intrin

        commit 1351fde
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 08:27:27 2022 +0900

            use buffer syntax sugar

        commit 0ced85f
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 08:17:43 2022 +0900

            rename vnni.py to x86.py

        commit 38a5aca
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 07:24:44 2022 +0900

            add VNNI unittest

        commit 88b763e
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 07:10:06 2022 +0900

            refactored existing test using VNNI intrin

        commit 711a007
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 07:04:58 2022 +0900

            [TIR] Add VNNI dot product intrinsic for TIR

        commit e0c5eb8
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 11:42:26 2022 +0900

            merge fix

        commit b171748
        Merge: 71fe3bd 82e152a
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 11:33:59 2022 +0900

            Merge branch 'tir-tensor-intrin' into auto-tensorize-vnni

        commit 71fe3bd
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 06:57:38 2022 +0900

            move tensor intrin under tir

        commit 0c51bad
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 06:12:39 2022 +0900

            remove log

        commit fed910e
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 06:11:22 2022 +0900

            more revert

        commit 7150aff
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 06:10:44 2022 +0900

            revert stmt_functor change

        commit 155107b
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 06:10:09 2022 +0900

            refactored RewriteVNNI a bit

        commit ca15255
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 05:41:13 2022 +0900

            add RewriteVNNI

        commit dc9f71d
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 05:38:56 2022 +0900

            vectorized init loop

        commit fcc31ee
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 04:55:36 2022 +0900

            tensorize worked

        commit 2b53437
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Wed Apr 6 19:11:05 2022 +0900

            TilingwithTensorIntrin works

        commit 86baa31
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Wed Apr 6 08:58:27 2022 +0900

            Ported auto-tensorization code

        commit 82e152a
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 11:24:56 2022 +0900

            more lint fix

        commit 88d9bdd
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 11:16:08 2022 +0900

            black

        commit 31fe7eb
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 11:12:54 2022 +0900

            pylint

        commit 7876754
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 10:58:52 2022 +0900

            simplify import

        commit 56f2e9a
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 10:52:50 2022 +0900

            use vectorlow/high in arm intrin

        commit 995cc8d
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 10:34:57 2022 +0900

            fixed offset factor

        commit 86bbd49
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 10:12:02 2022 +0900

            Add ARM intrin

        commit 120fd96
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 08:27:27 2022 +0900

            use buffer syntax sugar

        commit 0f0682d
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 08:17:43 2022 +0900

            rename vnni.py to x86.py

        commit f88c31e
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 07:24:44 2022 +0900

            add VNNI unittest

        commit 6cc8009
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 07:10:06 2022 +0900

            refactored existing test using VNNI intrin

        commit 11a29c7
        Author: Masahiro Masuda <masahi129@gmail.com>
        Date:   Thu Apr 7 07:04:58 2022 +0900

            [TIR] Add VNNI dot product intrinsic for TIR
masahi committed Apr 14, 2022
1 parent e370ed4 commit f8bc306
Showing 14 changed files with 181 additions and 274 deletions.
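For orientation before the per-file diffs: both intrinsics involved here, CUDA's __dp4a and AMD's llvm.amdgcn.sdot4, compute a 4-way int8 dot product accumulated into an int32. A plain NumPy reference of that semantics (an illustration for this page, not code from the commit):

import numpy as np

def dp4a_reference(a, b, c):
    # C += dot(A_int8x4, B_int8x4), with the products and sum done in int32
    assert a.dtype == np.int8 and b.dtype == np.int8 and a.shape == b.shape == (4,)
    return c + int(np.dot(a.astype(np.int32), b.astype(np.int32)))

a = np.array([1, -2, 3, -4], dtype=np.int8)
b = np.array([5, 6, -7, 8], dtype=np.int8)
print(dp4a_reference(a, b, 0))  # -60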
8 changes: 4 additions & 4 deletions python/tvm/relay/op/strategy/cuda.py
@@ -145,7 +145,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
         if layout == "NCHW":
             assert kernel_layout == "OIHW"
             if (
-                (target.kind.name in ["cuda", "vulkan"])
+                (target.kind.name in ["cuda", "vulkan", "rocm"])
                 and data.dtype in ("int8", "uint8")
                 and kernel.dtype in ("int8", "uint8")
             ):
@@ -297,7 +297,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
                     Need to satisfy tensor core schedule."
                     )
         elif (
-            (target.kind.name in ["cuda", "vulkan"])
+            (target.kind.name in ["cuda", "vulkan", "rocm"])
             and layout == "NCHW4c"
             and data.dtype in ["int8", "uint8"]
         ):
@@ -376,7 +376,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
         ic_chunk = in_channels // 4
 
         if (
-            (target.kind.name in ["cuda", "vulkan"])
+            (target.kind.name in ["cuda", "vulkan", "rocm"])
             and data.dtype in ["int8", "uint8"]
             and kernel.dtype in ["int8", "uint8"]
             and channels % groups == 0
@@ -836,7 +836,7 @@ def dense_strategy_cuda(attrs, inputs, out_type, target):
     b, i = get_const_tuple(data.shape)
     o, _ = get_const_tuple(weights.shape)
     if (
-        target.kind.name in ["cuda", "vulkan"]
+        target.kind.name in ["cuda", "vulkan", "rocm"]
        and data.dtype == "int8"
        and weights.dtype == "int8"
        and out_type.dtype == "int32"
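The cuda.py change is purely a gating change: "rocm" joins "cuda" and "vulkan" in each condition that selects the int8 implementations, so the existing CUDA int8 schedules become candidates on ROCm as well. A minimal check of the gate (constructing a Target object requires neither a GPU nor a ROCm-enabled build):

import tvm

for kind in ("cuda", "vulkan", "rocm"):
    # the int8 conv2d/dense paths above are now considered for all three kinds
    assert tvm.target.Target(kind).kind.name in ["cuda", "vulkan", "rocm"]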
171 changes: 20 additions & 151 deletions python/tvm/relay/op/strategy/rocm.py
@@ -24,168 +24,42 @@
 
 from .generic import *
 from .. import op as _op
-from .cuda import judge_winograd, naive_schedule
+from .cuda import batch_matmul_strategy_cuda, conv2d_strategy_cuda, dense_strategy_cuda
 
 
 @conv2d_strategy.register("rocm")
 def conv2d_strategy_rocm(attrs, inputs, out_type, target):
     """conv2d rocm strategy"""
-    strategy = _op.OpStrategy()
-    data, kernel = inputs
-    dilation_h, dilation_w = attrs.get_int_tuple("dilation")
     groups = attrs.groups
     layout = attrs.data_layout
-    stride_h, stride_w = attrs.get_int_tuple("strides")
-    kernel_layout = attrs.kernel_layout
     padding = attrs.get_int_tuple("padding")
-    if dilation_h < 1 or dilation_w < 1:
-        raise ValueError("dilation should be positive value")
 
-    if groups == 1:
-        if layout == "NCHW":
-            # TODO(@vinx13, @icemelon9): Use conv2d_NCHWc_int8 when dtype is int8/uint8.
-            assert kernel_layout == "OIHW"
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.cuda.conv2d_nchw),
-                wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw),
-                name="conv2d_nchw.cuda",
-            )
-            _, _, kh, kw = get_const_tuple(kernel.shape)
-            if (
-                2 < kh < 8
-                and 2 < kw < 8
-                and kh == kw
-                and stride_h == 1
-                and stride_w == 1
-                and dilation_h == 1
-                and dilation_w == 1
-            ):
-                strategy.add_implementation(
-                    wrap_compute_conv2d(topi.cuda.conv2d_nchw_winograd),
-                    wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw_winograd),
-                    name="conv2d_nchw_winograd.cuda",
-                    plevel=5,
-                )
-        elif layout == "NHWC":
-            assert kernel_layout == "HWIO"
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.gpu.conv2d_nhwc),
-                wrap_topi_schedule(topi.gpu.schedule_conv2d_nhwc),
-                name="conv2d_nhwc.gpu",
-            )
-            N, H, W, _ = get_const_tuple(data.shape)
-            KH, KW, CI, CO = get_const_tuple(kernel.shape)
-
-            (_, judge_winograd_autotvm, judge_winograd_auto_scheduler,) = judge_winograd(
-                N,
-                H,
-                W,
-                KH,
-                KW,
-                CI,
-                CO,
-                padding,
-                stride_h,
-                stride_w,
-                dilation_h,
-                dilation_w,
-                data.dtype,
-                kernel.dtype,
-                pre_flag=False,
-            )
+    strategy = conv2d_strategy_cuda(attrs, inputs, out_type, target)
 
-            if judge_winograd_autotvm:
-                strategy.add_implementation(
-                    wrap_compute_conv2d(topi.cuda.conv2d_nhwc_winograd_direct),
-                    wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc_winograd_direct),
-                    name="conv2d_nhwc_winograd_direct.cuda",
-                    plevel=5,
-                )
+    # add miopen implementation
+    if (
+        "miopen" in target.libs
+        and groups == 1
+        and layout == "NCHW"
+        and padding[0] == padding[2]
+        and padding[1] == padding[3]
+    ):
+        strategy.add_implementation(
+            wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, True),
+            wrap_topi_schedule(topi.rocm.schedule_conv2d_nchw_miopen),
+            name="conv2d_nchw_miopen.rocm",
+            plevel=50,
+        )
 
-            if is_auto_scheduler_enabled() and judge_winograd_auto_scheduler:
-                strategy.add_implementation(
-                    wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc),
-                    naive_schedule,  # this implementation should never be picked by autotvm
-                    name="conv2d_nhwc.winograd",
-                    plevel=15,
-                )
-        elif layout == "HWCN":
-            assert kernel_layout == "HWIO"
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.cuda.conv2d_hwcn),
-                wrap_topi_schedule(topi.cuda.schedule_conv2d_hwcn),
-                name="conv2d_hwcn.cuda",
-            )
-        elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
-            assert kernel_layout == "OIHW4o4i"
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.cuda.conv2d_NCHWc_int8, True),
-                wrap_topi_schedule(topi.cuda.schedule_conv2d_NCHWc_int8),
-                name="conv2d_NCHWc_int8.cuda",
-            )
-        else:
-            raise RuntimeError("Unsupported conv2d layout {} for CUDA".format(layout))
-        # add miopen implementation
-        if (
-            "miopen" in target.libs
-            and layout == "NCHW"
-            and padding[0] == padding[2]
-            and padding[1] == padding[3]
-        ):
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, True),
-                wrap_topi_schedule(topi.rocm.schedule_conv2d_nchw_miopen),
-                name="conv2d_nchw_miopen.rocm",
-                plevel=15,
-            )
-    elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
-        if layout == "NCHW":
-            assert kernel_layout == "OIHW"
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.cuda.depthwise_conv2d_nchw),
-                wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nchw),
-                name="depthwise_conv2d_nchw.cuda",
-            )
-        elif layout == "NHWC":
-            assert kernel_layout == "HWOI"
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
-                wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nhwc),
-                name="depthwise_conv2d_nhwc.cuda",
-            )
-        else:
-            raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout))
-    else:  # group_conv2d
-        if layout == "NCHW":
-            # TODO(@vinx13, @icemelon9): Use group_conv2d_NCHWc_int8 when dtype is int8/uint8.
-            assert kernel_layout == "OIHW"
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.cuda.group_conv2d_nchw, has_groups=True),
-                wrap_topi_schedule(topi.cuda.schedule_group_conv2d_nchw),
-                name="group_conv2d_nchw.cuda",
-            )
-        elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
-            assert kernel_layout == "OIHW4o4i"
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.cuda.group_conv2d_NCHWc_int8, True),
-                wrap_topi_schedule(topi.cuda.schedule_group_conv2d_NCHWc_int8),
-                name="group_conv2d_NCHWc_int8.cuda",
-            )
-        else:
-            raise RuntimeError("Unsupported group_conv2d layout {}".format(layout))
     return strategy
 
 
 @dense_strategy.register("rocm")
 def dense_strategy_rocm(attrs, inputs, out_type, target):
     """Dense strategy for ROCM"""
     assert len(inputs[0].shape) == 2 and len(inputs[1].shape) == 2, "Only support 2-dim dense"
-    strategy = _op.OpStrategy()
-    strategy.add_implementation(
-        wrap_compute_dense(topi.rocm.dense),
-        wrap_topi_schedule(topi.rocm.schedule_dense),
-        name="dense.rocm",
-    )
+    strategy = dense_strategy_cuda(attrs, inputs, out_type, target)
 
     if target.kind.name == "rocm" and "rocblas" in target.libs:
         assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported."
         strategy.add_implementation(
@@ -200,13 +74,8 @@ def dense_strategy_rocm(attrs, inputs, out_type, target):
 @batch_matmul_strategy.register("rocm")
 def batch_matmul_strategy_rocm(attrs, inputs, out_type, target):
     """Batch matmul strategy for ROCM"""
-    strategy = _op.OpStrategy()
-    strategy.add_implementation(
-        wrap_compute_batch_matmul(topi.cuda.batch_matmul),
-        wrap_topi_schedule(topi.cuda.schedule_batch_matmul),
-        name="batch_matmul.cuda",
-        plevel=10,
-    )
+    strategy = batch_matmul_strategy_cuda(attrs, inputs, out_type, target)
 
     if target.kind.name == "rocm" and "rocblas" in target.libs:
         assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported."
         strategy.add_implementation(
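With the ROCm strategies now delegating to their CUDA counterparts (and then layering the MIOpen/rocBLAS options on top), an int8 dense compiled for ROCm flows through the same i8i8i32 path gated in cuda.py above. A sketch of exercising it (assumes a ROCm-enabled TVM build; the shapes are arbitrary):

import tvm
from tvm import relay

data = relay.var("data", shape=(16, 64), dtype="int8")
weight = relay.var("weight", shape=(32, 64), dtype="int8")
dense = relay.nn.dense(data, weight, out_dtype="int32")
mod = tvm.IRModule.from_expr(dense)

# dense_strategy_rocm now reuses dense_strategy_cuda, so the int8
# implementation is among the candidates when targeting rocm.
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="rocm")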
10 changes: 8 additions & 2 deletions python/tvm/relay/qnn/op/legalizations.py
@@ -387,6 +387,12 @@ def is_aarch64_arm():
     return "aarch64" in target.attrs.get("mtriple", "")
 
 
+def is_rocm():
+    """Checks whether we are compiling for a rocm/spirv target."""
+    target = tvm.target.Target.current(allow_none=False)
+    return "rocm" in target.keys
+
+
 def is_vulkan():
     """Checks whether we are compiling for a vulkan/spirv target."""
     target = tvm.target.Target.current(allow_none=False)
@@ -456,7 +462,7 @@ def _qnn_dense_legalize_intel_cpu(attrs, inputs, types):
 
 @qnn_conv2d_legalize.register(["cuda", "gpu"])
 def _qnn_conv2d_legalize_cuda(attrs, inputs, types):
-    if is_vulkan():
+    if is_vulkan() or is_rocm():
         # prefers the dtypes to be same. Mixed type is not yet supported.
         return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d)
     if is_cuda():
@@ -467,7 +473,7 @@ def _qnn_conv2d_legalize_cuda(attrs, inputs, types):
 
 @qnn_dense_legalize.register(["cuda", "gpu"])
 def _qnn_dense_legalize_cuda(attrs, inputs, types):
-    if is_vulkan():
+    if is_vulkan() or is_rocm():
         # prefers the dtypes to be same. Mixed type is not yet supported.
         return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.dense)
     if is_cuda():
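helper_change_dtypes_to_be_same rewrites a mixed uint8/int8 QNN op so both operands share one dtype, which the i8i8i32 dot-product paths require. The underlying trick, shown here as a standalone toy (an illustration, not the TVM helper itself): shifting a uint8 tensor into int8 range is exact as long as its zero point shifts by the same 128.

import numpy as np

def uint8_to_int8(q, zero_point):
    # (q - zero_point) is unchanged when both q and zero_point shift by 128
    q_i8 = (q.astype(np.int32) - 128).astype(np.int8)
    return q_i8, zero_point - 128

q, zp = uint8_to_int8(np.array([0, 128, 255], dtype=np.uint8), zero_point=128)
print(q, zp)  # [-128    0  127] 0, the same dequantized values as before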
2 changes: 2 additions & 0 deletions python/tvm/tir/tensor_intrin/__init__.py
@@ -18,3 +18,5 @@
 """Intrinsics for tensorization."""
 from .x86 import *
 from .arm_cpu import *
+from .dot_product_common import *
+from .rocm import *
55 changes: 55 additions & 0 deletions python/tvm/tir/tensor_intrin/dot_product_common.py
@@ -0,0 +1,55 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,missing-function-docstring
"""Dot product related intrinsics."""
from tvm.script import tir as T
from .. import TensorIntrin


@T.prim_func
def dp4a_desc(
A: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
B: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
C: T.Buffer((1,), "int32", offset_factor=1, align=4, scope="local"),
) -> None:
with T.block("root"):
T.reads(C[0], A[0:4], B[0:4])
T.writes(C[0])
for i in range(0, 4):
with T.block("update"):
vi = T.axis.remap("R", [i])
C[0] = C[0] + T.cast(A[vi], "int32") * T.cast(B[vi], "int32")


@T.prim_func
def dp4a_impl(
A: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
B: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
C: T.Buffer((1,), "int32", offset_factor=1, align=4, scope="local"),
) -> None:
with T.block("root"):
T.reads(C[0], A[0:4], B[0:4])
T.writes(C[0])

C[0] += T.call_pure_extern(
"__dp4a", A.vload([0], "int8x4"), B.vload([0], "int8x4"), T.int32(0), dtype="int32"
)


DP4A_INTRIN = "dp4a"

TensorIntrin.register(DP4A_INTRIN, dp4a_desc, dp4a_impl)
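Once registered, the intrin is referenced by name from TensorIR scheduling. A hedged sketch of the intended use (the schedule and loop names are hypothetical; per dp4a_desc, A and B must first be staged into "shared" scope and C into "local", e.g. via cache_read/cache_write, before the pattern can match):

# assumes a TVM build that includes this commit
from tvm.tir.tensor_intrin.dot_product_common import DP4A_INTRIN

# Given a tir.Schedule `sch` whose innermost reduction loop `r` iterates
# a multiple of 4 int8 elements:
#     ro, ri = sch.split(r, factors=[None, 4])
#     sch.tensorize(ri, DP4A_INTRIN)
print(DP4A_INTRIN)  # "dp4a", the name the desc/impl pair was registered under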
47 changes: 47 additions & 0 deletions python/tvm/tir/tensor_intrin/rocm.py
@@ -0,0 +1,47 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,missing-function-docstring
"""Intrinsics for AMDGPU tensorization."""
from tvm.script import tir as T
from .. import TensorIntrin
from .dot_product_common import dp4a_desc


@T.prim_func
def sdot4(
A: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
B: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
C: T.Buffer((1,), "int32", offset_factor=1, align=4, scope="local"),
) -> None:
with T.block("root"):
T.reads(C[0], A[0:4], B[0:4])
T.writes(C[0])

C[0] += T.call_llvm_pure_intrin(
T.llvm_lookup_intrinsic_id("llvm.amdgcn.sdot4"),
T.uint32(4),
T.reinterpret(A.vload([0], "int8x4"), dtype="int32"),
T.reinterpret(B.vload([0], "int8x4"), dtype="int32"),
T.int32(0),
T.bool(1),
dtype="int32"
)


AMDGPU_SDOT4_INTRIN = "sdot4"

TensorIntrin.register(AMDGPU_SDOT4_INTRIN, dp4a_desc, sdot4)
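Note that sdot4 is registered against the shared dp4a_desc, so the ROCm intrin matches exactly the same TIR pattern as dp4a and only the lowering differs. llvm.amdgcn.sdot4 takes the two int8x4 operands bit-reinterpreted as i32, an i32 accumulator, and a final i1 clamp flag (T.bool(1) above selects the saturating variant). The reinterpret is a pure bit cast, as this NumPy illustration (not TVM code) shows:

import numpy as np

a = np.array([1, -2, 3, -4], dtype=np.int8)
a_i32 = a.view(np.int32)[0]  # bit-reinterpret int8x4 -> i32, like T.reinterpret
roundtrip = np.array([a_i32], dtype=np.int32).view(np.int8)
print(a_i32, roundtrip)      # round-trips to [ 1 -2  3 -4]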
2 changes: 1 addition & 1 deletion python/tvm/topi/cuda/batch_matmul.py
@@ -372,7 +372,7 @@ def _schedule_batch_matmul_int8(cfg, s, output):
     target = tvm.target.Target.current(allow_none=False)
     do_tensorize = True
 
-    if "vulkan" in target.keys:
+    if "vulkan" in target.keys or "rocm" in target.keys:
         do_tensorize = "+dotprod" in target.mattr or target.supports_integer_dot_product
 
     if do_tensorize:
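The int8 batch_matmul schedule now attempts dot-product tensorization on rocm under the same capability check already used for vulkan. A quick illustration of the target.keys gate (Target construction needs no GPU present):

import tvm

print("rocm" in tvm.target.Target("rocm").keys)      # True
print("vulkan" in tvm.target.Target("vulkan").keys)  # True
print("rocm" in tvm.target.Target("cuda").keys)      # False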