diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 08da62e640e10..4253d93f6500b 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -145,7 +145,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target): if layout == "NCHW": assert kernel_layout == "OIHW" if ( - (target.kind.name in ["cuda", "vulkan"]) + (target.kind.name in ["cuda", "vulkan", "rocm"]) and data.dtype in ("int8", "uint8") and kernel.dtype in ("int8", "uint8") ): @@ -297,7 +297,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target): Need to satisfy tensor core schedule." ) elif ( - (target.kind.name in ["cuda", "vulkan"]) + (target.kind.name in ["cuda", "vulkan", "rocm"]) and layout == "NCHW4c" and data.dtype in ["int8", "uint8"] ): @@ -376,7 +376,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target): ic_chunk = in_channels // 4 if ( - (target.kind.name in ["cuda", "vulkan"]) + (target.kind.name in ["cuda", "vulkan", "rocm"]) and data.dtype in ["int8", "uint8"] and kernel.dtype in ["int8", "uint8"] and channels % groups == 0 @@ -836,7 +836,7 @@ def dense_strategy_cuda(attrs, inputs, out_type, target): b, i = get_const_tuple(data.shape) o, _ = get_const_tuple(weights.shape) if ( - target.kind.name in ["cuda", "vulkan"] + target.kind.name in ["cuda", "vulkan", "rocm"] and data.dtype == "int8" and weights.dtype == "int8" and out_type.dtype == "int32" diff --git a/python/tvm/relay/op/strategy/rocm.py b/python/tvm/relay/op/strategy/rocm.py index 1453128eeb677..a6cc94d2b116c 100644 --- a/python/tvm/relay/op/strategy/rocm.py +++ b/python/tvm/relay/op/strategy/rocm.py @@ -24,155 +24,33 @@ from .generic import * from .. import op as _op -from .cuda import judge_winograd, naive_schedule +from .cuda import batch_matmul_strategy_cuda, conv2d_strategy_cuda, dense_strategy_cuda @conv2d_strategy.register("rocm") def conv2d_strategy_rocm(attrs, inputs, out_type, target): """conv2d rocm strategy""" - strategy = _op.OpStrategy() - data, kernel = inputs - dilation_h, dilation_w = attrs.get_int_tuple("dilation") groups = attrs.groups layout = attrs.data_layout - stride_h, stride_w = attrs.get_int_tuple("strides") - kernel_layout = attrs.kernel_layout padding = attrs.get_int_tuple("padding") - if dilation_h < 1 or dilation_w < 1: - raise ValueError("dilation should be positive value") - - if groups == 1: - if layout == "NCHW": - # TODO(@vinx13, @icemelon9): Use conv2d_NCHWc_int8 when dtype is int8/uint8. 
- assert kernel_layout == "OIHW" - strategy.add_implementation( - wrap_compute_conv2d(topi.cuda.conv2d_nchw), - wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw), - name="conv2d_nchw.cuda", - ) - _, _, kh, kw = get_const_tuple(kernel.shape) - if ( - 2 < kh < 8 - and 2 < kw < 8 - and kh == kw - and stride_h == 1 - and stride_w == 1 - and dilation_h == 1 - and dilation_w == 1 - ): - strategy.add_implementation( - wrap_compute_conv2d(topi.cuda.conv2d_nchw_winograd), - wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw_winograd), - name="conv2d_nchw_winograd.cuda", - plevel=5, - ) - elif layout == "NHWC": - assert kernel_layout == "HWIO" - strategy.add_implementation( - wrap_compute_conv2d(topi.gpu.conv2d_nhwc), - wrap_topi_schedule(topi.gpu.schedule_conv2d_nhwc), - name="conv2d_nhwc.gpu", - ) - N, H, W, _ = get_const_tuple(data.shape) - KH, KW, CI, CO = get_const_tuple(kernel.shape) - (_, judge_winograd_autotvm, judge_winograd_auto_scheduler,) = judge_winograd( - N, - H, - W, - KH, - KW, - CI, - CO, - padding, - stride_h, - stride_w, - dilation_h, - dilation_w, - data.dtype, - kernel.dtype, - pre_flag=False, - ) + strategy = conv2d_strategy_cuda(attrs, inputs, out_type, target) - if judge_winograd_autotvm: - strategy.add_implementation( - wrap_compute_conv2d(topi.cuda.conv2d_nhwc_winograd_direct), - wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc_winograd_direct), - name="conv2d_nhwc_winograd_direct.cuda", - plevel=5, - ) + # add miopen implementation + if ( + "miopen" in target.libs + and groups == 1 + and layout == "NCHW" + and padding[0] == padding[2] + and padding[1] == padding[3] + ): + strategy.add_implementation( + wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, True), + wrap_topi_schedule(topi.rocm.schedule_conv2d_nchw_miopen), + name="conv2d_nchw_miopen.rocm", + plevel=50, + ) - if is_auto_scheduler_enabled() and judge_winograd_auto_scheduler: - strategy.add_implementation( - wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc), - naive_schedule, # this implementation should never be picked by autotvm - name="conv2d_nhwc.winograd", - plevel=15, - ) - elif layout == "HWCN": - assert kernel_layout == "HWIO" - strategy.add_implementation( - wrap_compute_conv2d(topi.cuda.conv2d_hwcn), - wrap_topi_schedule(topi.cuda.schedule_conv2d_hwcn), - name="conv2d_hwcn.cuda", - ) - elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]: - assert kernel_layout == "OIHW4o4i" - strategy.add_implementation( - wrap_compute_conv2d(topi.cuda.conv2d_NCHWc_int8, True), - wrap_topi_schedule(topi.cuda.schedule_conv2d_NCHWc_int8), - name="conv2d_NCHWc_int8.cuda", - ) - else: - raise RuntimeError("Unsupported conv2d layout {} for CUDA".format(layout)) - # add miopen implementation - if ( - "miopen" in target.libs - and layout == "NCHW" - and padding[0] == padding[2] - and padding[1] == padding[3] - ): - strategy.add_implementation( - wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, True), - wrap_topi_schedule(topi.rocm.schedule_conv2d_nchw_miopen), - name="conv2d_nchw_miopen.rocm", - plevel=15, - ) - elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups): - if layout == "NCHW": - assert kernel_layout == "OIHW" - strategy.add_implementation( - wrap_compute_conv2d(topi.cuda.depthwise_conv2d_nchw), - wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nchw), - name="depthwise_conv2d_nchw.cuda", - ) - elif layout == "NHWC": - assert kernel_layout == "HWOI" - strategy.add_implementation( - wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc), - 
wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nhwc), - name="depthwise_conv2d_nhwc.cuda", - ) - else: - raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout)) - else: # group_conv2d - if layout == "NCHW": - # TODO(@vinx13, @icemelon9): Use group_conv2d_NCHWc_int8 when dtype is int8/uint8. - assert kernel_layout == "OIHW" - strategy.add_implementation( - wrap_compute_conv2d(topi.cuda.group_conv2d_nchw, has_groups=True), - wrap_topi_schedule(topi.cuda.schedule_group_conv2d_nchw), - name="group_conv2d_nchw.cuda", - ) - elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]: - assert kernel_layout == "OIHW4o4i" - strategy.add_implementation( - wrap_compute_conv2d(topi.cuda.group_conv2d_NCHWc_int8, True), - wrap_topi_schedule(topi.cuda.schedule_group_conv2d_NCHWc_int8), - name="group_conv2d_NCHWc_int8.cuda", - ) - else: - raise RuntimeError("Unsupported group_conv2d layout {}".format(layout)) return strategy @@ -180,12 +58,8 @@ def conv2d_strategy_rocm(attrs, inputs, out_type, target): def dense_strategy_rocm(attrs, inputs, out_type, target): """Dense strategy for ROCM""" assert len(inputs[0].shape) == 2 and len(inputs[1].shape) == 2, "Only support 2-dim dense" - strategy = _op.OpStrategy() - strategy.add_implementation( - wrap_compute_dense(topi.rocm.dense), - wrap_topi_schedule(topi.rocm.schedule_dense), - name="dense.rocm", - ) + strategy = dense_strategy_cuda(attrs, inputs, out_type, target) + if target.kind.name == "rocm" and "rocblas" in target.libs: assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported." strategy.add_implementation( @@ -200,13 +74,8 @@ def dense_strategy_rocm(attrs, inputs, out_type, target): @batch_matmul_strategy.register("rocm") def batch_matmul_strategy_rocm(attrs, inputs, out_type, target): """Batch matmul strategy for ROCM""" - strategy = _op.OpStrategy() - strategy.add_implementation( - wrap_compute_batch_matmul(topi.cuda.batch_matmul), - wrap_topi_schedule(topi.cuda.schedule_batch_matmul), - name="batch_matmul.cuda", - plevel=10, - ) + strategy = batch_matmul_strategy_cuda(attrs, inputs, out_type, target) + if target.kind.name == "rocm" and "rocblas" in target.libs: assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported." strategy.add_implementation( diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py index 93b1ad7a44a89..0d198c470bb66 100644 --- a/python/tvm/relay/qnn/op/legalizations.py +++ b/python/tvm/relay/qnn/op/legalizations.py @@ -387,6 +387,12 @@ def is_aarch64_arm(): return "aarch64" in target.attrs.get("mtriple", "") +def is_rocm(): + """Checks whether we are compiling for a rocm/spirv target.""" + target = tvm.target.Target.current(allow_none=False) + return "rocm" in target.keys + + def is_vulkan(): """Checks whether we are compiling for a vulkan/spirv target.""" target = tvm.target.Target.current(allow_none=False) @@ -456,7 +462,7 @@ def _qnn_dense_legalize_intel_cpu(attrs, inputs, types): @qnn_conv2d_legalize.register(["cuda", "gpu"]) def _qnn_conv2d_legalize_cuda(attrs, inputs, types): - if is_vulkan(): + if is_vulkan() or is_rocm(): # prefers the dtypes to be same. Mixed type is not yet supported. 
return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d) if is_cuda(): @@ -467,7 +473,7 @@ def _qnn_conv2d_legalize_cuda(attrs, inputs, types): @qnn_dense_legalize.register(["cuda", "gpu"]) def _qnn_dense_legalize_cuda(attrs, inputs, types): - if is_vulkan(): + if is_vulkan() or is_rocm(): # prefers the dtypes to be same. Mixed type is not yet supported. return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.dense) if is_cuda(): diff --git a/python/tvm/tir/tensor_intrin/__init__.py b/python/tvm/tir/tensor_intrin/__init__.py index 62159851b3d47..4115c3b900709 100644 --- a/python/tvm/tir/tensor_intrin/__init__.py +++ b/python/tvm/tir/tensor_intrin/__init__.py @@ -18,3 +18,5 @@ """Intrinsics for tensorization.""" from .x86 import * from .arm_cpu import * +from .dot_product_common import * +from .rocm import * diff --git a/python/tvm/tir/tensor_intrin/dot_product_common.py b/python/tvm/tir/tensor_intrin/dot_product_common.py new file mode 100644 index 0000000000000..c531b80380e3c --- /dev/null +++ b/python/tvm/tir/tensor_intrin/dot_product_common.py @@ -0,0 +1,55 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,missing-function-docstring +"""Dot product related intrinsics.""" +from tvm.script import tir as T +from .. import TensorIntrin + + +@T.prim_func +def dp4a_desc( + A: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"), + B: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"), + C: T.Buffer((1,), "int32", offset_factor=1, align=4, scope="local"), +) -> None: + with T.block("root"): + T.reads(C[0], A[0:4], B[0:4]) + T.writes(C[0]) + for i in range(0, 4): + with T.block("update"): + vi = T.axis.remap("R", [i]) + C[0] = C[0] + T.cast(A[vi], "int32") * T.cast(B[vi], "int32") + + +@T.prim_func +def dp4a_impl( + A: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"), + B: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"), + C: T.Buffer((1,), "int32", offset_factor=1, align=4, scope="local"), +) -> None: + with T.block("root"): + T.reads(C[0], A[0:4], B[0:4]) + T.writes(C[0]) + + C[0] += T.call_pure_extern( + "__dp4a", A.vload([0], "int8x4"), B.vload([0], "int8x4"), T.int32(0), dtype="int32" + ) + + +DP4A_INTRIN = "dp4a" + +TensorIntrin.register(DP4A_INTRIN, dp4a_desc, dp4a_impl) diff --git a/python/tvm/tir/tensor_intrin/rocm.py b/python/tvm/tir/tensor_intrin/rocm.py new file mode 100644 index 0000000000000..2095eb1635215 --- /dev/null +++ b/python/tvm/tir/tensor_intrin/rocm.py @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,missing-function-docstring +"""Intrinsics for AMDGPU tensorization.""" +from tvm.script import tir as T +from .. import TensorIntrin +from .dot_product_common import dp4a_desc + + +@T.prim_func +def sdot4( + A: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"), + B: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"), + C: T.Buffer((1,), "int32", offset_factor=1, align=4, scope="local"), +) -> None: + with T.block("root"): + T.reads(C[0], A[0:4], B[0:4]) + T.writes(C[0]) + + C[0] += T.call_llvm_pure_intrin( + T.llvm_lookup_intrinsic_id("llvm.amdgcn.sdot4"), + T.uint32(4), + T.reinterpret(A.vload([0], "int8x4"), dtype="int32"), + T.reinterpret(B.vload([0], "int8x4"), dtype="int32"), + T.int32(0), + T.bool(1), + dtype="int32" + ) + + +AMDGPU_SDOT4_INTRIN = "sdot4" + +TensorIntrin.register(AMDGPU_SDOT4_INTRIN, dp4a_desc, sdot4) diff --git a/python/tvm/topi/cuda/batch_matmul.py b/python/tvm/topi/cuda/batch_matmul.py index 5fce9d7a3f5de..859db6f00ebb3 100644 --- a/python/tvm/topi/cuda/batch_matmul.py +++ b/python/tvm/topi/cuda/batch_matmul.py @@ -372,7 +372,7 @@ def _schedule_batch_matmul_int8(cfg, s, output): target = tvm.target.Target.current(allow_none=False) do_tensorize = True - if "vulkan" in target.keys: + if "vulkan" in target.keys or "rocm" in target.keys: do_tensorize = "+dotprod" in target.mattr or target.supports_integer_dot_product if do_tensorize: diff --git a/python/tvm/topi/cuda/conv2d_alter_op.py b/python/tvm/topi/cuda/conv2d_alter_op.py index eaafe15e96003..7f52685e5d6db 100644 --- a/python/tvm/topi/cuda/conv2d_alter_op.py +++ b/python/tvm/topi/cuda/conv2d_alter_op.py @@ -34,7 +34,7 @@ @nn.conv2d_alter_layout.register(["cuda", "gpu"]) def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): target = tvm.target.Target.current(allow_none=False) - doit = "vulkan" in target.keys or "cuda" in target.keys + doit = "vulkan" in target.keys or "cuda" in target.keys or "rocm" in target.keys if not doit: return None dispatch_ctx = autotvm.task.DispatchContext.current @@ -87,7 +87,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): if cfg.is_fallback: # if is fallback, clear query cache and return None autotvm.task.clear_fallback_cache(target, workload) do_new_layout = False - if "vulkan" in target.keys: + if "vulkan" in target.keys or "rocm" in target.keys: do_new_layout = "+dotprod" in target.mattr or target.supports_integer_dot_product if not do_new_layout: return None @@ -351,7 +351,7 @@ def _conv2d_legalize(attrs, inputs, arg_types): """ target = tvm.target.Target.current(allow_none=False) - doit = "vulkan" in target.keys or "cuda" in target.keys + doit = "vulkan" in target.keys or "cuda" in target.keys or "rocm" in target.keys if not doit: return None # Dilation not supported yet. 
Return None if dilation is not (1, 1) diff --git a/python/tvm/topi/cuda/conv2d_int8.py b/python/tvm/topi/cuda/conv2d_int8.py index 15120f6a2532b..3c530445e92f0 100644 --- a/python/tvm/topi/cuda/conv2d_int8.py +++ b/python/tvm/topi/cuda/conv2d_int8.py @@ -312,8 +312,8 @@ def _schedule_conv2d_NCHWc_int8(cfg, s, output): _, rc_block = s[conv].split(rc_block, factor=4) target = tvm.target.Target.current(allow_none=False) do_tensorize = True - if "vulkan" in target.keys: - do_tensorize = "+dotprod" in target.mattr or target.supports_integer_dot_product + # if "vulkan" in target.keys or "rocm" in target.keys: + # do_tensorize = "+dotprod" in target.mattr or target.supports_integer_dot_product if do_tensorize: dtypes = (pad_data.dtype, packed_kernel.dtype) s[conv].tensorize(rc_block, dp4a("shared", "shared", "local", dtypes)) diff --git a/python/tvm/topi/cuda/dense.py b/python/tvm/topi/cuda/dense.py index 862e7b5bc59d3..e7e651eefd8a5 100644 --- a/python/tvm/topi/cuda/dense.py +++ b/python/tvm/topi/cuda/dense.py @@ -173,8 +173,9 @@ def _schedule_dense_int8(cfg, s, output): ko, kt = cfg["tile_k"].apply(s, CC, ko) target = tvm.target.Target.current(allow_none=False) do_tensorize = True - if "vulkan" in target.keys: - do_tensorize = "+dotprod" in target.mattr or target.supports_integer_dot_product + # if "vulkan" in target.keys or "rocm" in target.keys: + # do_tensorize = "+dotprod" in target.mattr or target.supports_integer_dot_product + if do_tensorize: dtypes = (data.dtype, weight.dtype) s[CC].tensorize(ki, dp4a("shared", "shared", "local", dtypes)) diff --git a/python/tvm/topi/cuda/tensor_intrin.py b/python/tvm/topi/cuda/tensor_intrin.py index c0596fc432623..6bb143140a416 100644 --- a/python/tvm/topi/cuda/tensor_intrin.py +++ b/python/tvm/topi/cuda/tensor_intrin.py @@ -71,7 +71,11 @@ def _instr(index): vec_y = yy.vload(0, dtype=vec_y_dtype) prev_z = 0 if index == 0 else zz.vload(0) - new_z = tvm.tir.call_pure_extern(zz_dtype, "__dp4a", vec_x, vec_y, prev_z) + # new_z = tvm.tir.call_pure_extern(zz_dtype, "__dp4a", vec_x, vec_y, prev_z) + new_z = tvm.tir.call_llvm_pure_intrin(zz_dtype, "llvm.amdgcn.sdot4", tvm.tir.const(4, "uint32"), + tvm.tir.call_intrin("int32", "tir.reinterpret", vec_x), + tvm.tir.call_intrin("int32", "tir.reinterpret", vec_y), + prev_z, True) ib.emit(zz.vstore(0, new_z)) return ib.get() diff --git a/python/tvm/topi/rocm/dense.py b/python/tvm/topi/rocm/dense.py index 2f3ce77cc7bac..983f235f0ec8e 100644 --- a/python/tvm/topi/rocm/dense.py +++ b/python/tvm/topi/rocm/dense.py @@ -19,85 +19,8 @@ from tvm import te from tvm import autotvm from tvm.contrib import rocblas -from .. import generic, nn +from .. import generic from .. import tag -from ..utils import traverse_inline - - -@autotvm.register_topi_compute("dense.rocm") -def dense(cfg, data, weight, bias=None, out_dtype=None): - """Dense operator for rocm backend. - - Parameters - ---------- - data : tvm.te.Tensor - 2-D with shape [batch, in_dim] - - weight : tvm.te.Tensor - 2-D with shape [out_dim, in_dim] - - bias : tvm.te.Tensor, optional - 1-D with shape [out_dim] - - out_dtype : str - The output type. This is used for mixed precision. 
- - Returns - ------- - output : tvm.te.Tensor - 2-D with shape [batch, out_dim] - """ - assert len(data.shape) == 2 and len(weight.shape) == 2, "only support 2-dim dense" - if bias is not None: - assert len(bias.shape) == 1 - if out_dtype is None: - out_dtype = data.dtype - return nn.dense(data, weight, bias, out_dtype) - - -@autotvm.register_topi_schedule("dense.rocm") -def schedule_dense(cfg, outs): - """Schedule for dense operator. - - Parameters - ---------- - outs: Array of Tensor - The computation graph description of dense - in the format of an array of tensors. - - Returns - ------- - s: Schedule - The computation schedule for dense. - """ - outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs - s = te.create_schedule([x.op for x in outs]) - - def _callback(op): - if op.tag == "dense": - Dense = op.output(0) - num_thread = 64 - k = Dense.op.reduce_axis[0] - ko, kf = s[Dense].split(k, factor=num_thread) - DenseF = s.rfactor(Dense, kf) - - if Dense.op in s.outputs: - Out = Dense - else: - Out = outs[0].op.output(0) - s[Dense].compute_at(s[Out], s[Out].op.axis[1]) - s[Out].bind(s[Out].op.axis[0], te.thread_axis("blockIdx.y")) - s[Out].bind(s[Out].op.axis[1], te.thread_axis("blockIdx.x")) - - tx = s[Dense].op.reduce_axis[0] - thread_x = te.thread_axis("threadIdx.x") - s[Dense].bind(tx, thread_x) - s[DenseF].compute_at(s[Dense], tx) - s[Dense].set_store_predicate(thread_x.var.equal(0)) - s[Out].set_store_predicate(thread_x.var.equal(0)) - - traverse_inline(s, outs[0].op, _callback) - return s @autotvm.register_topi_compute("dense_rocblas.rocm") diff --git a/tests/python/topi/python/test_topi_conv2d_int8.py b/tests/python/topi/python/test_topi_conv2d_int8.py index 860118531e513..b93236b8cee61 100644 --- a/tests/python/topi/python/test_topi_conv2d_int8.py +++ b/tests/python/topi/python/test_topi_conv2d_int8.py @@ -346,45 +346,45 @@ def get_ref_data(): tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5) targets = [ - ( - "cuda", - lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o), - topi.cuda.schedule_conv2d_NCHWc_int8, - 4, - False, - ), - # Disable on CI since it does not support spirv int8 dot product # ( - # "vulkan -from_device=0", + # "cuda", # lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o), # topi.cuda.schedule_conv2d_NCHWc_int8, # 4, # False, # ), + # Disable on CI since it does not support spirv int8 dot product + ( + "rocm", + lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o), + topi.cuda.schedule_conv2d_NCHWc_int8, + 4, + False, + ), ] build_only_aarch64 = platform.machine() != "aarch64" - targets.append( - ( - "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod", - topi.arm_cpu.conv2d_NCHWc_int8, - topi.arm_cpu.schedule_conv2d_NCHWc_int8, - 8, - build_only_aarch64, - ) - ) - - if in_dtype == "int8": - targets.append( - ( - "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon", - topi.arm_cpu.conv2d_NCHWc_int8, - topi.arm_cpu.schedule_conv2d_NCHWc_int8, - 8, - build_only_aarch64, - ) - ) + # targets.append( + # ( + # "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod", + # topi.arm_cpu.conv2d_NCHWc_int8, + # topi.arm_cpu.schedule_conv2d_NCHWc_int8, + # 8, + # build_only_aarch64, + # ) + # ) + + # if in_dtype == "int8": + # targets.append( + # ( + # "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon", + # topi.arm_cpu.conv2d_NCHWc_int8, + # topi.arm_cpu.schedule_conv2d_NCHWc_int8, + # 8, + # 
build_only_aarch64, + # ) + # ) for target, compute, schedule, oc_block_factor, build_only in targets: check_target(target, compute, schedule, oc_block_factor, build_only) @@ -517,6 +517,7 @@ def test_conv2d_nchw(in_dtype): with Int8Fallback(): # ResNet18 workloads where channels in / out are multiple of oc_block_factor verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 3, 1, 1) + return verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 64, 1, 1, 0) verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 128, 3, 2, 1) verify_conv2d_NCHWc_int8(in_dtype, 1, 64, 56, 128, 1, 2, 0) diff --git a/tests/python/topi/python/test_topi_dense.py b/tests/python/topi/python/test_topi_dense.py index 8f58415da3297..2826d70ba0eda 100644 --- a/tests/python/topi/python/test_topi_dense.py +++ b/tests/python/topi/python/test_topi_dense.py @@ -52,7 +52,6 @@ ], "mali": [(topi.mali.dense, topi.mali.schedule_dense)], "bifrost": [(topi.bifrost.dense, topi.bifrost.schedule_dense)], - "rocm": [(topi.rocm.dense, topi.rocm.schedule_dense)], "hls": [(topi.nn.dense, topi.hls.schedule_dense)], }
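
Not part of the patch above, but as a rough illustration of the path these changes enable: the sketch below builds a small int8 NCHW conv2d through Relay for a rocm target, which is where the updated conv2d strategy, the qnn legalization, and (after layout alteration to NCHW4c) the dp4a/sdot4 tensorization are expected to kick in. The shapes, the random weights, and the assumption of a ROCm-enabled TVM build with an sdot4-capable GPU are illustrative, not taken from the patch.

# Hedged sketch, not part of this patch: exercise the rocm int8 conv2d path.
# Assumes TVM was built with ROCm support and the GPU implements sdot4.
import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_executor

data = relay.var("data", shape=(1, 64, 56, 56), dtype="int8")
weight = relay.var("weight", shape=(64, 64, 3, 3), dtype="int8")
conv = relay.nn.conv2d(
    data, weight, kernel_size=(3, 3), padding=(1, 1), channels=64, out_dtype="int32"
)
mod = tvm.IRModule.from_expr(relay.Function([data, weight], conv))
params = {"weight": np.random.randint(-8, 8, size=(64, 64, 3, 3)).astype("int8")}

target = tvm.target.Target("rocm")
with tvm.transform.PassContext(opt_level=3):
    # opt_level=3 enables AlterOpLayout, so the _alter_conv2d_layout change above
    # can rewrite the conv to NCHW4c and select the int8 schedule on rocm.
    lib = relay.build(mod, target=target, params=params)

dev = tvm.rocm(0)
module = graph_executor.GraphModule(lib["default"](dev))
module.set_input("data", np.random.randint(-8, 8, size=(1, 64, 56, 56)).astype("int8"))
module.run()
out = module.get_output(0).numpy()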