share op strategy between cuda and rocm
masahi committed Apr 14, 2022
1 parent 762c7e8 commit 0225f2b
Showing 4 changed files with 22 additions and 241 deletions.
181 changes: 20 additions & 161 deletions python/tvm/relay/op/strategy/rocm.py
@@ -24,178 +24,42 @@

from .generic import *
from .. import op as _op
from .cuda import judge_winograd, naive_schedule
from .cuda import batch_matmul_strategy_cuda, conv2d_strategy_cuda, dense_strategy_cuda


@conv2d_strategy.register("rocm")
def conv2d_strategy_rocm(attrs, inputs, out_type, target):
"""conv2d rocm strategy"""
strategy = _op.OpStrategy()
data, kernel = inputs
dilation_h, dilation_w = attrs.get_int_tuple("dilation")
groups = attrs.groups
layout = attrs.data_layout
stride_h, stride_w = attrs.get_int_tuple("strides")
kernel_layout = attrs.kernel_layout
padding = attrs.get_int_tuple("padding")
if dilation_h < 1 or dilation_w < 1:
raise ValueError("dilation should be positive value")

if groups == 1:
if layout == "NCHW":
# TODO(@vinx13, @icemelon9): Use conv2d_NCHWc_int8 when dtype is int8/uint8.
assert kernel_layout == "OIHW"
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_nchw),
wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw),
name="conv2d_nchw.cuda",
)
_, _, kh, kw = get_const_tuple(kernel.shape)
if (
2 < kh < 8
and 2 < kw < 8
and kh == kw
and stride_h == 1
and stride_w == 1
and dilation_h == 1
and dilation_w == 1
):
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_nchw_winograd),
wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw_winograd),
name="conv2d_nchw_winograd.cuda",
plevel=5,
)
elif layout == "NHWC":
assert kernel_layout == "HWIO"
strategy.add_implementation(
wrap_compute_conv2d(topi.gpu.conv2d_nhwc),
wrap_topi_schedule(topi.gpu.schedule_conv2d_nhwc),
name="conv2d_nhwc.gpu",
)
N, H, W, _ = get_const_tuple(data.shape)
KH, KW, CI, CO = get_const_tuple(kernel.shape)
strategy = conv2d_strategy_cuda(attrs, inputs, out_type, target)

(_, judge_winograd_autotvm, judge_winograd_auto_scheduler,) = judge_winograd(
N,
H,
W,
KH,
KW,
CI,
CO,
padding,
stride_h,
stride_w,
dilation_h,
dilation_w,
data.dtype,
kernel.dtype,
pre_flag=False,
)

if judge_winograd_autotvm:
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_nhwc_winograd_direct),
wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc_winograd_direct),
name="conv2d_nhwc_winograd_direct.cuda",
plevel=5,
)
# add miopen implementation
if (
"miopen" in target.libs
and groups == 1
and layout == "NCHW"
and padding[0] == padding[2]
and padding[1] == padding[3]
):
strategy.add_implementation(
wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, True),
wrap_topi_schedule(topi.rocm.schedule_conv2d_nchw_miopen),
name="conv2d_nchw_miopen.rocm",
plevel=50,
)

if is_auto_scheduler_enabled() and judge_winograd_auto_scheduler:
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc),
naive_schedule, # this implementation should never be picked by autotvm
name="conv2d_nhwc.winograd",
plevel=15,
)
elif layout == "HWCN":
assert kernel_layout == "HWIO"
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_hwcn),
wrap_topi_schedule(topi.cuda.schedule_conv2d_hwcn),
name="conv2d_hwcn.cuda",
)
elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
assert kernel_layout == "OIHW4o4i"
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_NCHWc_int8, True),
wrap_topi_schedule(topi.cuda.schedule_conv2d_NCHWc_int8),
name="conv2d_NCHWc_int8.cuda",
)
else:
raise RuntimeError("Unsupported conv2d layout {} for CUDA".format(layout))
# add miopen implementation
if (
"miopen" in target.libs
and layout == "NCHW"
and padding[0] == padding[2]
and padding[1] == padding[3]
):
strategy.add_implementation(
wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, True),
wrap_topi_schedule(topi.rocm.schedule_conv2d_nchw_miopen),
name="conv2d_nchw_miopen.rocm",
plevel=15,
)
elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
if layout == "NCHW":
assert kernel_layout == "OIHW"
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.depthwise_conv2d_nchw),
wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nchw),
name="depthwise_conv2d_nchw.cuda",
)
elif layout == "NHWC":
assert kernel_layout == "HWOI"
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nhwc),
name="depthwise_conv2d_nhwc.cuda",
)
else:
raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout))
else: # group_conv2d
if layout == "NCHW":
# TODO(@vinx13, @icemelon9): Use group_conv2d_NCHWc_int8 when dtype is int8/uint8.
assert kernel_layout == "OIHW"
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.group_conv2d_nchw, has_groups=True),
wrap_topi_schedule(topi.cuda.schedule_group_conv2d_nchw),
name="group_conv2d_nchw.cuda",
)
elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
assert kernel_layout == "OIHW4o4i"
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.group_conv2d_NCHWc_int8, True),
wrap_topi_schedule(topi.cuda.schedule_group_conv2d_NCHWc_int8),
name="group_conv2d_NCHWc_int8.cuda",
)
else:
raise RuntimeError("Unsupported group_conv2d layout {}".format(layout))
return strategy


@dense_strategy.register("rocm")
def dense_strategy_rocm(attrs, inputs, out_type, target):
"""Dense strategy for ROCM"""
assert len(inputs[0].shape) == 2 and len(inputs[1].shape) == 2, "Only support 2-dim dense"
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_dense(topi.rocm.dense),
wrap_topi_schedule(topi.rocm.schedule_dense),
name="dense.rocm",
)
data, weights = inputs
if (data.dtype == "int8"
and weights.dtype == "int8"
and out_type.dtype == "int32"
):
strategy.add_implementation(
wrap_compute_dense(topi.cuda.dense_int8),
wrap_topi_schedule(topi.cuda.schedule_dense_int8),
name="dense_int8.rocm",
)
strategy = dense_strategy_cuda(attrs, inputs, out_type, target)

if target.kind.name == "rocm" and "rocblas" in target.libs:
assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported."
strategy.add_implementation(
@@ -210,13 +74,8 @@ def dense_strategy_rocm(attrs, inputs, out_type, target):
@batch_matmul_strategy.register("rocm")
def batch_matmul_strategy_rocm(attrs, inputs, out_type, target):
"""Batch matmul strategy for ROCM"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_batch_matmul(topi.cuda.batch_matmul, need_out_dtype=True),
wrap_topi_schedule(topi.cuda.schedule_batch_matmul),
name="batch_matmul.cuda",
plevel=10,
)
strategy = batch_matmul_strategy_cuda(attrs, inputs, out_type, target)

if target.kind.name == "rocm" and "rocblas" in target.libs:
assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported."
strategy.add_implementation(
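Taken together, the kept and added lines in this hunk amount to the following shape for the new conv2d_strategy_rocm — a condensed sketch assembled from the diff above (build the shared CUDA strategy first, then layer a ROCm-specific MIOpen implementation on top at a higher priority level), not an exact copy of the committed file:

@conv2d_strategy.register("rocm")
def conv2d_strategy_rocm(attrs, inputs, out_type, target):
    """conv2d rocm strategy"""
    groups = attrs.groups
    layout = attrs.data_layout
    padding = attrs.get_int_tuple("padding")

    # Reuse the CUDA implementations wholesale.
    strategy = conv2d_strategy_cuda(attrs, inputs, out_type, target)

    # Add the MIOpen implementation on top, at a priority (plevel=50) that
    # outranks the generic CUDA schedules when the library is available.
    if (
        "miopen" in target.libs
        and groups == 1
        and layout == "NCHW"
        and padding[0] == padding[2]
        and padding[1] == padding[3]
    ):
        strategy.add_implementation(
            wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, True),
            wrap_topi_schedule(topi.rocm.schedule_conv2d_nchw_miopen),
            name="conv2d_nchw_miopen.rocm",
            plevel=50,
        )
    return strategy

The dense and batch_matmul strategies below follow the same pattern, delegating to dense_strategy_cuda and batch_matmul_strategy_cuda and then adding rocBLAS implementations when "rocblas" is in target.libs.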
2 changes: 1 addition & 1 deletion python/tvm/topi/cuda/dense.py
@@ -175,7 +175,7 @@ def _schedule_dense_int8(cfg, s, output):
do_tensorize = True
# if "vulkan" in target.keys or "rocm" in target.keys:
# do_tensorize = "+dotprod" in target.mattr or target.supports_integer_dot_product
assert False

if do_tensorize:
dtypes = (data.dtype, weight.dtype)
s[CC].tensorize(ki, dp4a("shared", "shared", "local", dtypes))
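For reference, the check described by the commented-out lines in this hunk would read roughly as sketched below. This is an assumption based on those comments (and on the Target object exposing keys, mattr, and supports_integer_dot_product as they imply), not what this commit ships — the hunk as committed hard-codes do_tensorize and leaves an assert False in place:

# Sketch: gate dp4a tensorization on integer dot-product support,
# following the commented-out guard in the hunk above.
do_tensorize = True
if "vulkan" in target.keys or "rocm" in target.keys:
    do_tensorize = "+dotprod" in target.mattr or target.supports_integer_dot_product
if do_tensorize:
    dtypes = (data.dtype, weight.dtype)
    s[CC].tensorize(ki, dp4a("shared", "shared", "local", dtypes))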
79 changes: 1 addition & 78 deletions python/tvm/topi/rocm/dense.py
@@ -19,85 +19,8 @@
from tvm import te
from tvm import autotvm
from tvm.contrib import rocblas
from .. import generic, nn
from .. import generic
from .. import tag
from ..utils import traverse_inline


@autotvm.register_topi_compute("dense.rocm")
def dense(cfg, data, weight, bias=None, out_dtype=None):
"""Dense operator for rocm backend.
Parameters
----------
data : tvm.te.Tensor
2-D with shape [batch, in_dim]
weight : tvm.te.Tensor
2-D with shape [out_dim, in_dim]
bias : tvm.te.Tensor, optional
1-D with shape [out_dim]
out_dtype : str
The output type. This is used for mixed precision.
Returns
-------
output : tvm.te.Tensor
2-D with shape [batch, out_dim]
"""
assert len(data.shape) == 2 and len(weight.shape) == 2, "only support 2-dim dense"
if bias is not None:
assert len(bias.shape) == 1
if out_dtype is None:
out_dtype = data.dtype
return nn.dense(data, weight, bias, out_dtype)


@autotvm.register_topi_schedule("dense.rocm")
def schedule_dense(cfg, outs):
"""Schedule for dense operator.
Parameters
----------
outs: Array of Tensor
The computation graph description of dense
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for dense.
"""
outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
s = te.create_schedule([x.op for x in outs])

def _callback(op):
if op.tag == "dense":
Dense = op.output(0)
num_thread = 64
k = Dense.op.reduce_axis[0]
ko, kf = s[Dense].split(k, factor=num_thread)
DenseF = s.rfactor(Dense, kf)

if Dense.op in s.outputs:
Out = Dense
else:
Out = outs[0].op.output(0)
s[Dense].compute_at(s[Out], s[Out].op.axis[1])
s[Out].bind(s[Out].op.axis[0], te.thread_axis("blockIdx.y"))
s[Out].bind(s[Out].op.axis[1], te.thread_axis("blockIdx.x"))

tx = s[Dense].op.reduce_axis[0]
thread_x = te.thread_axis("threadIdx.x")
s[Dense].bind(tx, thread_x)
s[DenseF].compute_at(s[Dense], tx)
s[Dense].set_store_predicate(thread_x.var.equal(0))
s[Out].set_store_predicate(thread_x.var.equal(0))

traverse_inline(s, outs[0].op, _callback)
return s


@autotvm.register_topi_compute("dense_rocblas.rocm")
1 change: 0 additions & 1 deletion tests/python/topi/python/test_topi_dense.py
@@ -52,7 +52,6 @@
],
"mali": [(topi.mali.dense, topi.mali.schedule_dense)],
"bifrost": [(topi.bifrost.dense, topi.bifrost.schedule_dense)],
"rocm": [(topi.rocm.dense, topi.rocm.schedule_dense)],
"hls": [(topi.nn.dense, topi.hls.schedule_dense)],
}

