From 0225f2bfe3f413cd4764c2dba6c922af2520146b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 14 Apr 2022 08:56:10 +0900 Subject: [PATCH] share op strategy between cuda and rocm --- python/tvm/relay/op/strategy/rocm.py | 181 +++----------------- python/tvm/topi/cuda/dense.py | 2 +- python/tvm/topi/rocm/dense.py | 79 +-------- tests/python/topi/python/test_topi_dense.py | 1 - 4 files changed, 22 insertions(+), 241 deletions(-) diff --git a/python/tvm/relay/op/strategy/rocm.py b/python/tvm/relay/op/strategy/rocm.py index 106ae63fded80..a6cc94d2b116c 100644 --- a/python/tvm/relay/op/strategy/rocm.py +++ b/python/tvm/relay/op/strategy/rocm.py @@ -24,155 +24,33 @@ from .generic import * from .. import op as _op -from .cuda import judge_winograd, naive_schedule +from .cuda import batch_matmul_strategy_cuda, conv2d_strategy_cuda, dense_strategy_cuda @conv2d_strategy.register("rocm") def conv2d_strategy_rocm(attrs, inputs, out_type, target): """conv2d rocm strategy""" - strategy = _op.OpStrategy() - data, kernel = inputs - dilation_h, dilation_w = attrs.get_int_tuple("dilation") groups = attrs.groups layout = attrs.data_layout - stride_h, stride_w = attrs.get_int_tuple("strides") - kernel_layout = attrs.kernel_layout padding = attrs.get_int_tuple("padding") - if dilation_h < 1 or dilation_w < 1: - raise ValueError("dilation should be positive value") - if groups == 1: - if layout == "NCHW": - # TODO(@vinx13, @icemelon9): Use conv2d_NCHWc_int8 when dtype is int8/uint8. - assert kernel_layout == "OIHW" - strategy.add_implementation( - wrap_compute_conv2d(topi.cuda.conv2d_nchw), - wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw), - name="conv2d_nchw.cuda", - ) - _, _, kh, kw = get_const_tuple(kernel.shape) - if ( - 2 < kh < 8 - and 2 < kw < 8 - and kh == kw - and stride_h == 1 - and stride_w == 1 - and dilation_h == 1 - and dilation_w == 1 - ): - strategy.add_implementation( - wrap_compute_conv2d(topi.cuda.conv2d_nchw_winograd), - wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw_winograd), - name="conv2d_nchw_winograd.cuda", - plevel=5, - ) - elif layout == "NHWC": - assert kernel_layout == "HWIO" - strategy.add_implementation( - wrap_compute_conv2d(topi.gpu.conv2d_nhwc), - wrap_topi_schedule(topi.gpu.schedule_conv2d_nhwc), - name="conv2d_nhwc.gpu", - ) - N, H, W, _ = get_const_tuple(data.shape) - KH, KW, CI, CO = get_const_tuple(kernel.shape) + strategy = conv2d_strategy_cuda(attrs, inputs, out_type, target) - (_, judge_winograd_autotvm, judge_winograd_auto_scheduler,) = judge_winograd( - N, - H, - W, - KH, - KW, - CI, - CO, - padding, - stride_h, - stride_w, - dilation_h, - dilation_w, - data.dtype, - kernel.dtype, - pre_flag=False, - ) - - if judge_winograd_autotvm: - strategy.add_implementation( - wrap_compute_conv2d(topi.cuda.conv2d_nhwc_winograd_direct), - wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc_winograd_direct), - name="conv2d_nhwc_winograd_direct.cuda", - plevel=5, - ) + # add miopen implementation + if ( + "miopen" in target.libs + and groups == 1 + and layout == "NCHW" + and padding[0] == padding[2] + and padding[1] == padding[3] + ): + strategy.add_implementation( + wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, True), + wrap_topi_schedule(topi.rocm.schedule_conv2d_nchw_miopen), + name="conv2d_nchw_miopen.rocm", + plevel=50, + ) - if is_auto_scheduler_enabled() and judge_winograd_auto_scheduler: - strategy.add_implementation( - wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc), - naive_schedule, # this implementation should never be picked by autotvm - name="conv2d_nhwc.winograd", - plevel=15, - ) - elif layout == "HWCN": - assert kernel_layout == "HWIO" - strategy.add_implementation( - wrap_compute_conv2d(topi.cuda.conv2d_hwcn), - wrap_topi_schedule(topi.cuda.schedule_conv2d_hwcn), - name="conv2d_hwcn.cuda", - ) - elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]: - assert kernel_layout == "OIHW4o4i" - strategy.add_implementation( - wrap_compute_conv2d(topi.cuda.conv2d_NCHWc_int8, True), - wrap_topi_schedule(topi.cuda.schedule_conv2d_NCHWc_int8), - name="conv2d_NCHWc_int8.cuda", - ) - else: - raise RuntimeError("Unsupported conv2d layout {} for CUDA".format(layout)) - # add miopen implementation - if ( - "miopen" in target.libs - and layout == "NCHW" - and padding[0] == padding[2] - and padding[1] == padding[3] - ): - strategy.add_implementation( - wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, True), - wrap_topi_schedule(topi.rocm.schedule_conv2d_nchw_miopen), - name="conv2d_nchw_miopen.rocm", - plevel=15, - ) - elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups): - if layout == "NCHW": - assert kernel_layout == "OIHW" - strategy.add_implementation( - wrap_compute_conv2d(topi.cuda.depthwise_conv2d_nchw), - wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nchw), - name="depthwise_conv2d_nchw.cuda", - ) - elif layout == "NHWC": - assert kernel_layout == "HWOI" - strategy.add_implementation( - wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc), - wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nhwc), - name="depthwise_conv2d_nhwc.cuda", - ) - else: - raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout)) - else: # group_conv2d - if layout == "NCHW": - # TODO(@vinx13, @icemelon9): Use group_conv2d_NCHWc_int8 when dtype is int8/uint8. - assert kernel_layout == "OIHW" - strategy.add_implementation( - wrap_compute_conv2d(topi.cuda.group_conv2d_nchw, has_groups=True), - wrap_topi_schedule(topi.cuda.schedule_group_conv2d_nchw), - name="group_conv2d_nchw.cuda", - ) - elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]: - assert kernel_layout == "OIHW4o4i" - strategy.add_implementation( - wrap_compute_conv2d(topi.cuda.group_conv2d_NCHWc_int8, True), - wrap_topi_schedule(topi.cuda.schedule_group_conv2d_NCHWc_int8), - name="group_conv2d_NCHWc_int8.cuda", - ) - else: - raise RuntimeError("Unsupported group_conv2d layout {}".format(layout)) return strategy @@ -180,22 +58,8 @@ def conv2d_strategy_rocm(attrs, inputs, out_type, target): def dense_strategy_rocm(attrs, inputs, out_type, target): """Dense strategy for ROCM""" assert len(inputs[0].shape) == 2 and len(inputs[1].shape) == 2, "Only support 2-dim dense" - strategy = _op.OpStrategy() - strategy.add_implementation( - wrap_compute_dense(topi.rocm.dense), - wrap_topi_schedule(topi.rocm.schedule_dense), - name="dense.rocm", - ) - data, weights = inputs - if (data.dtype == "int8" - and weights.dtype == "int8" - and out_type.dtype == "int32" - ): - strategy.add_implementation( - wrap_compute_dense(topi.cuda.dense_int8), - wrap_topi_schedule(topi.cuda.schedule_dense_int8), - name="dense_int8.rocm", - ) + strategy = dense_strategy_cuda(attrs, inputs, out_type, target) + if target.kind.name == "rocm" and "rocblas" in target.libs: assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported." strategy.add_implementation( @@ -210,13 +74,8 @@ def dense_strategy_rocm(attrs, inputs, out_type, target): @batch_matmul_strategy.register("rocm") def batch_matmul_strategy_rocm(attrs, inputs, out_type, target): """Batch matmul strategy for ROCM""" - strategy = _op.OpStrategy() - strategy.add_implementation( - wrap_compute_batch_matmul(topi.cuda.batch_matmul, need_out_dtype=True), - wrap_topi_schedule(topi.cuda.schedule_batch_matmul), - name="batch_matmul.cuda", - plevel=10, - ) + strategy = batch_matmul_strategy_cuda(attrs, inputs, out_type, target) + if target.kind.name == "rocm" and "rocblas" in target.libs: assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported." strategy.add_implementation( diff --git a/python/tvm/topi/cuda/dense.py b/python/tvm/topi/cuda/dense.py index 582d3e62303a4..e7e651eefd8a5 100644 --- a/python/tvm/topi/cuda/dense.py +++ b/python/tvm/topi/cuda/dense.py @@ -175,7 +175,7 @@ def _schedule_dense_int8(cfg, s, output): do_tensorize = True # if "vulkan" in target.keys or "rocm" in target.keys: # do_tensorize = "+dotprod" in target.mattr or target.supports_integer_dot_product - assert False + if do_tensorize: dtypes = (data.dtype, weight.dtype) s[CC].tensorize(ki, dp4a("shared", "shared", "local", dtypes)) diff --git a/python/tvm/topi/rocm/dense.py b/python/tvm/topi/rocm/dense.py index 2f3ce77cc7bac..983f235f0ec8e 100644 --- a/python/tvm/topi/rocm/dense.py +++ b/python/tvm/topi/rocm/dense.py @@ -19,85 +19,8 @@ from tvm import te from tvm import autotvm from tvm.contrib import rocblas -from .. import generic, nn +from .. import generic from .. import tag -from ..utils import traverse_inline - - -@autotvm.register_topi_compute("dense.rocm") -def dense(cfg, data, weight, bias=None, out_dtype=None): - """Dense operator for rocm backend. - - Parameters - ---------- - data : tvm.te.Tensor - 2-D with shape [batch, in_dim] - - weight : tvm.te.Tensor - 2-D with shape [out_dim, in_dim] - - bias : tvm.te.Tensor, optional - 1-D with shape [out_dim] - - out_dtype : str - The output type. This is used for mixed precision. - - Returns - ------- - output : tvm.te.Tensor - 2-D with shape [batch, out_dim] - """ - assert len(data.shape) == 2 and len(weight.shape) == 2, "only support 2-dim dense" - if bias is not None: - assert len(bias.shape) == 1 - if out_dtype is None: - out_dtype = data.dtype - return nn.dense(data, weight, bias, out_dtype) - - -@autotvm.register_topi_schedule("dense.rocm") -def schedule_dense(cfg, outs): - """Schedule for dense operator. - - Parameters - ---------- - outs: Array of Tensor - The computation graph description of dense - in the format of an array of tensors. - - Returns - ------- - s: Schedule - The computation schedule for dense. - """ - outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs - s = te.create_schedule([x.op for x in outs]) - - def _callback(op): - if op.tag == "dense": - Dense = op.output(0) - num_thread = 64 - k = Dense.op.reduce_axis[0] - ko, kf = s[Dense].split(k, factor=num_thread) - DenseF = s.rfactor(Dense, kf) - - if Dense.op in s.outputs: - Out = Dense - else: - Out = outs[0].op.output(0) - s[Dense].compute_at(s[Out], s[Out].op.axis[1]) - s[Out].bind(s[Out].op.axis[0], te.thread_axis("blockIdx.y")) - s[Out].bind(s[Out].op.axis[1], te.thread_axis("blockIdx.x")) - - tx = s[Dense].op.reduce_axis[0] - thread_x = te.thread_axis("threadIdx.x") - s[Dense].bind(tx, thread_x) - s[DenseF].compute_at(s[Dense], tx) - s[Dense].set_store_predicate(thread_x.var.equal(0)) - s[Out].set_store_predicate(thread_x.var.equal(0)) - - traverse_inline(s, outs[0].op, _callback) - return s @autotvm.register_topi_compute("dense_rocblas.rocm") diff --git a/tests/python/topi/python/test_topi_dense.py b/tests/python/topi/python/test_topi_dense.py index 8f58415da3297..2826d70ba0eda 100644 --- a/tests/python/topi/python/test_topi_dense.py +++ b/tests/python/topi/python/test_topi_dense.py @@ -52,7 +52,6 @@ ], "mali": [(topi.mali.dense, topi.mali.schedule_dense)], "bifrost": [(topi.bifrost.dense, topi.bifrost.schedule_dense)], - "rocm": [(topi.rocm.dense, topi.rocm.schedule_dense)], "hls": [(topi.nn.dense, topi.hls.schedule_dense)], }