[WIP] adding DP4A support to rocm
masahi committed Apr 13, 2022
1 parent 4f8f308 commit 65b8bcf
Showing 6 changed files with 18 additions and 12 deletions.
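DP4A is the 4-element int8 dot-product-and-accumulate instruction that the int8 schedules below tensorize onto: it multiplies two packed vectors of four 8-bit integers and adds the result to a 32-bit accumulator in one step. A plain-Python sketch of the operation a single DP4A performs (names and values are illustrative only):

def dp4a(a4, b4, acc):
    # One DP4A step: 4-way int8 dot product accumulated into an int32.
    # A GPU with DP4A support does this in a single instruction.
    return acc + sum(int(a) * int(b) for a, b in zip(a4, b4))

print(dp4a([1, -2, 3, 4], [5, 6, -7, 8], 10))  # 10 + (5 - 12 - 21 + 32) = 14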
8 changes: 4 additions & 4 deletions python/tvm/relay/op/strategy/cuda.py
@@ -145,7 +145,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
if layout == "NCHW":
assert kernel_layout == "OIHW"
if (
- (target.kind.name in ["cuda", "vulkan"])
+ (target.kind.name in ["cuda", "vulkan", "rocm"])
and data.dtype in ("int8", "uint8")
and kernel.dtype in ("int8", "uint8")
):
@@ -297,7 +297,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
Need to satisfy tensor core schedule."
)
elif (
- (target.kind.name in ["cuda", "vulkan"])
+ (target.kind.name in ["cuda", "vulkan", "rocm"])
and layout == "NCHW4c"
and data.dtype in ["int8", "uint8"]
):
@@ -376,7 +376,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
ic_chunk = in_channels // 4

if (
- (target.kind.name in ["cuda", "vulkan"])
+ (target.kind.name in ["cuda", "vulkan", "rocm"])
and data.dtype in ["int8", "uint8"]
and kernel.dtype in ["int8", "uint8"]
and channels % groups == 0
@@ -836,7 +836,7 @@ def dense_strategy_cuda(attrs, inputs, out_type, target):
b, i = get_const_tuple(data.shape)
o, _ = get_const_tuple(weights.shape)
if (
- target.kind.name in ["cuda", "vulkan"]
+ target.kind.name in ["cuda", "vulkan", "rocm"]
and data.dtype == "int8"
and weights.dtype == "int8"
and out_type.dtype == "int32"
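With "rocm" added to these guards, the int8 implementations that already served cuda and vulkan become selectable when compiling for ROCm. A minimal sketch of a quantized dense workload that satisfies the updated dense_strategy_cuda condition (int8 data and weights, int32 output); the shapes are illustrative and actually building it assumes a ROCm-enabled TVM:

import tvm
from tvm import relay

# Hypothetical int8 dense workload; the reduction axis is a multiple of 4
# so the innermost groups of four can map onto DP4A.
data = relay.var("data", shape=(16, 64), dtype="int8")
weight = relay.var("weight", shape=(32, 64), dtype="int8")
dense = relay.nn.dense(data, weight, out_dtype="int32")
mod = tvm.IRModule.from_expr(relay.Function([data, weight], dense))

# target.kind.name == "rocm" is now accepted by the guard above, so the
# int8 dense implementation can be picked for this module.
lib = relay.build(mod, target="rocm")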
10 changes: 8 additions & 2 deletions python/tvm/relay/qnn/op/legalizations.py
@@ -387,6 +387,12 @@ def is_aarch64_arm():
return "aarch64" in target.attrs.get("mtriple", "")


+ def is_rocm():
+     """Checks whether we are compiling for a rocm/spirv target."""
+     target = tvm.target.Target.current(allow_none=False)
+     return "rocm" in target.keys


def is_vulkan():
"""Checks whether we are compiling for a vulkan/spirv target."""
target = tvm.target.Target.current(allow_none=False)
@@ -456,7 +462,7 @@ def _qnn_dense_legalize_intel_cpu(attrs, inputs, types):

@qnn_conv2d_legalize.register(["cuda", "gpu"])
def _qnn_conv2d_legalize_cuda(attrs, inputs, types):
- if is_vulkan():
+ if is_vulkan() or is_rocm():
# prefers the dtypes to be same. Mixed type is not yet supported.
return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d)
if is_cuda():
@@ -467,7 +473,7 @@ def _qnn_conv2d_legalize_cuda(attrs, inputs, types):

@qnn_dense_legalize.register(["cuda", "gpu"])
def _qnn_dense_legalize_cuda(attrs, inputs, types):
- if is_vulkan():
+ if is_vulkan() or is_rocm():
# prefers the dtypes to be same. Mixed type is not yet supported.
return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.dense)
if is_cuda():
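The new is_rocm helper mirrors is_vulkan: it inspects the keys of the currently scoped target, so the legalizations above now rewrite mixed-dtype qnn conv2d and dense into same-dtype ops for ROCm just as they already did for Vulkan. A small usage sketch (the printed key tuple is what a stock rocm target is expected to carry):

import tvm

with tvm.target.Target("rocm"):
    target = tvm.target.Target.current(allow_none=False)
    print(target.keys)            # e.g. ('rocm', 'gpu')
    print("rocm" in target.keys)  # True, so is_rocm() returns True here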
2 changes: 1 addition & 1 deletion python/tvm/topi/cuda/batch_matmul.py
@@ -372,7 +372,7 @@ def _schedule_batch_matmul_int8(cfg, s, output):
target = tvm.target.Target.current(allow_none=False)
do_tensorize = True

if "vulkan" in target.keys:
if "vulkan" in target.keys or "rocm" in target.keys:
do_tensorize = "+dotprod" in target.mattr or target.supports_integer_dot_product

if do_tensorize:
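The same do_tensorize guard reappears in the conv2d and dense schedules below: cuda always tensorizes, while vulkan and now rocm must advertise either the +dotprod mattr or the integer dot product capability. A sketch of that check factored into a standalone helper (the helper name is hypothetical; the attribute tests are the ones used in this diff):

import tvm

def can_tensorize_dp4a(target):
    # Hypothetical helper mirroring the guard in these schedules: vulkan and
    # rocm must report DP4A support, other GPU targets keep the default.
    if "vulkan" in target.keys or "rocm" in target.keys:
        return "+dotprod" in target.mattr or target.supports_integer_dot_product
    return True

print(can_tensorize_dp4a(tvm.target.Target("cuda")))  # True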
6 changes: 3 additions & 3 deletions python/tvm/topi/cuda/conv2d_alter_op.py
@@ -34,7 +34,7 @@
@nn.conv2d_alter_layout.register(["cuda", "gpu"])
def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
target = tvm.target.Target.current(allow_none=False)
doit = "vulkan" in target.keys or "cuda" in target.keys
doit = "vulkan" in target.keys or "cuda" in target.keys or "rocm" in target.keys
if not doit:
return None
dispatch_ctx = autotvm.task.DispatchContext.current
@@ -87,7 +87,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
if cfg.is_fallback: # if is fallback, clear query cache and return None
autotvm.task.clear_fallback_cache(target, workload)
do_new_layout = False
if "vulkan" in target.keys:
if "vulkan" in target.keys or "rocm" in target.keys:
do_new_layout = "+dotprod" in target.mattr or target.supports_integer_dot_product
if not do_new_layout:
return None
@@ -351,7 +351,7 @@ def _conv2d_legalize(attrs, inputs, arg_types):
"""

target = tvm.target.Target.current(allow_none=False)
doit = "vulkan" in target.keys or "cuda" in target.keys
doit = "vulkan" in target.keys or "cuda" in target.keys or "rocm" in target.keys
if not doit:
return None
# Dilation not supported yet. Return None if dilation is not (1, 1)
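On vulkan and rocm the new layout is only taken when the target reports DP4A support; the legalize and alter-layout hooks above then repack int8 conv2d data into NCHW4c so the innermost four channels feed one dot-product instruction (the ic_chunk = in_channels // 4 split seen in the strategy earlier). A numpy sketch of that packing, with illustrative shapes:

import numpy as np

n, c, h, w = 1, 8, 3, 3  # channel count must be a multiple of 4
data = np.arange(n * c * h * w, dtype="int8").reshape(n, c, h, w)

# NCHW -> NCHW4c: split C into C//4 chunks of 4 and move the 4 innermost.
nchw4c = data.reshape(n, c // 4, 4, h, w).transpose(0, 1, 3, 4, 2)
print(nchw4c.shape)  # (1, 2, 3, 3, 4)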
2 changes: 1 addition & 1 deletion python/tvm/topi/cuda/conv2d_int8.py
@@ -312,7 +312,7 @@ def _schedule_conv2d_NCHWc_int8(cfg, s, output):
_, rc_block = s[conv].split(rc_block, factor=4)
target = tvm.target.Target.current(allow_none=False)
do_tensorize = True
if "vulkan" in target.keys:
if "vulkan" in target.keys or "rocm" in target.keys:
do_tensorize = "+dotprod" in target.mattr or target.supports_integer_dot_product
if do_tensorize:
dtypes = (pad_data.dtype, packed_kernel.dtype)
2 changes: 1 addition & 1 deletion python/tvm/topi/cuda/dense.py
@@ -173,7 +173,7 @@ def _schedule_dense_int8(cfg, s, output):
ko, kt = cfg["tile_k"].apply(s, CC, ko)
target = tvm.target.Target.current(allow_none=False)
do_tensorize = True
if "vulkan" in target.keys:
if "vulkan" in target.keys or "rocm" in target.keys:
do_tensorize = "+dotprod" in target.mattr or target.supports_integer_dot_product
if do_tensorize:
dtypes = (data.dtype, weight.dtype)
