diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py index 3f72bdc4b667..7cb4b09b8805 100644 --- a/python/tvm/topi/nn/conv2d.py +++ b/python/tvm/topi/nn/conv2d.py @@ -176,10 +176,13 @@ def _get_workload(data, kernel, stride, padding, dilation, out_dtype, data_layou else: KH, KW, CIG, CO = get_const_tuple(kernel.shape) - pt, pl, pb, pr = get_pad_tuple(padding, (get_const_int(KH), get_const_int(KW))) dilation_h, dilation_w = ( dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation) ) + pt, pl, pb, pr = get_pad_tuple( + padding, + (get_const_int((KH - 1) * dilation_h + 1), get_const_int((KW - 1) * dilation_w + 1)), + ) GRPS = CI // CIG if isinstance(stride, (tuple, list)): HSTR, WSTR = stride diff --git a/python/tvm/topi/nn/depthwise_conv2d.py b/python/tvm/topi/nn/depthwise_conv2d.py index a3639b57e7e0..48ffb8c6d9ff 100644 --- a/python/tvm/topi/nn/depthwise_conv2d.py +++ b/python/tvm/topi/nn/depthwise_conv2d.py @@ -24,7 +24,7 @@ from .dilate import dilate from .pad import pad from .utils import get_pad_tuple -from ..utils import simplify +from ..utils import simplify, get_const_tuple # workload description of depthwise-conv2d Workload = namedtuple( @@ -50,11 +50,47 @@ ) -def _get_workload(data, kernel, stride, padding, dilation, out_dtype): - """Get the workload structure.""" - _, in_channel, height, width = [x.value for x in data.shape] - channel, channel_multiplier, kh, kw = [x.value for x in kernel.shape] - out_channel = channel * channel_multiplier +def _get_workload(data, kernel, stride, padding, dilation, out_dtype, data_layout="NCHW"): + """Get the workload structure for a depthwise conv2d. + + Input data and filter should use NCHW layout. + """ + if data_layout == "NCHW": + _, in_channel, height, width = get_const_tuple(data.shape) + filter_channel, channel_multiplier, kh, kw = get_const_tuple(kernel.shape) + elif data_layout == "NHWC": + _, height, width, in_channel = get_const_tuple(data.shape) + kh, kw, filter_channel, channel_multiplier = get_const_tuple(kernel.shape) + elif data_layout == "NCHWc": + _, in_channel_chunk, height, width, in_channel_block = get_const_tuple(data.shape) + in_channel = in_channel_chunk * in_channel_block + ( + filter_channel_chunk, + cm_chunk, + kh, + kw, + cm_block, + filter_channel_block, + ) = get_const_tuple(kernel.shape) + filter_channel = filter_channel_chunk * filter_channel_block + channel_multiplier = cm_chunk * cm_block + + assert ( + in_channel_block == filter_channel_block + ), "Incorrect dimensions, data has block size {}, but filter has block size {}".format( + in_channel_block, filter_channel_block + ) + + else: + raise ValueError("Data layout {} not supported".format(data_layout)) + + assert ( + in_channel == filter_channel + ), "Incorrect dimensions, data has {} channels but filter expects {} channels".format( + in_channel, filter_channel + ) + + out_channel = filter_channel * channel_multiplier dilation_h, dilation_w = ( dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation) ) @@ -102,8 +138,8 @@ def depthwise_conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=No Filter : tvm.te.Tensor 4-D with shape [in_channel, channel_multiplier, filter_height, filter_width] - stride : tuple of two ints - The spatial stride along height and width + stride : int or a list/tuple of two ints + The spatial stride, or (stride_height, stride_width). 
padding : int or str
        Padding size, or ['VALID', 'SAME']
diff --git a/python/tvm/topi/nn/mapping.py b/python/tvm/topi/nn/mapping.py
index c048fc86d4d5..0e0b1825df30 100644
--- a/python/tvm/topi/nn/mapping.py
+++ b/python/tvm/topi/nn/mapping.py
@@ -29,7 +29,7 @@ def scale_shift_nchw(Input, Scale, Shift):
     Parameters
     ----------
     Input : tvm.te.Tensor
-        Input tensor, layout is NCHW
+        4-D input tensor, NCHW layout [batch, channel, height, width]
     Scale : tvm.te.Tensor
         Scale tensor, 1-D of size channel number
@@ -54,7 +54,7 @@ def scale_shift_nhwc(Input, Scale, Shift):
     Parameters
     ----------
     Input : tvm.te.Tensor
-        Input tensor, layout is NHWC
+        4-D input tensor, NHWC layout [batch, height, width, channel]
     Scale : tvm.te.Tensor
         Scale tensor, 1-D of size channel number
@@ -70,3 +70,30 @@ def scale_shift_nhwc(Input, Scale, Shift):
     return te.compute(
         Input.shape, lambda b, i, j, c: Input[b, i, j, c] * Scale[c] + Shift[c], name="ScaleShift"
     )
+
+
+@tvm.te.tag_scope(tag=tag.BROADCAST)
+def scale_shift_nchwc(Input, Scale, Shift):
+    """Batch normalization operator in inference.
+
+    Parameters
+    ----------
+    Input : tvm.te.Tensor
+        5-D input tensor, NCHWc layout [batch, channel_chunk, height, width, channel_block]
+
+    Scale : tvm.te.Tensor
+        Scale tensor, 2-D of size [channel_chunk, channel_block]
+
+    Shift : tvm.te.Tensor
+        Shift tensor, 2-D of size [channel_chunk, channel_block]
+
+    Returns
+    -------
+    Output : tvm.te.Tensor
+        Output tensor, layout is NCHWc
+    """
+    return te.compute(
+        Input.shape,
+        lambda b, cc, i, j, cb: Input[b, cc, i, j, cb] * Scale[cc, cb] + Shift[cc, cb],
+        name="ScaleShift",
+    )
diff --git a/python/tvm/topi/testing/__init__.py b/python/tvm/topi/testing/__init__.py
index 610c51668835..d10c49f5c084 100644
--- a/python/tvm/topi/testing/__init__.py
+++ b/python/tvm/topi/testing/__init__.py
@@ -32,7 +32,11 @@ from .conv1d_transpose_ncw_python import conv1d_transpose_ncw_python
 from .correlation_nchw_python import correlation_nchw_python
 from .deformable_conv2d_python import deformable_conv2d_nchw_python, deformable_conv2d_nhwc_python
-from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc
+from .depthwise_conv2d_python import (
+    depthwise_conv2d_python_nchw,
+    depthwise_conv2d_python_nhwc,
+    depthwise_conv2d_python_nchwc,
+)
 from .dilate_python import dilate_python
 from .softmax_python import softmax_python, log_softmax_python
 from .resize_python import resize1d_python, resize2d_python, resize3d_python
diff --git a/python/tvm/topi/testing/depthwise_conv2d_python.py b/python/tvm/topi/testing/depthwise_conv2d_python.py
index 02964ecfae3b..1ec64b7e7b82 100644
--- a/python/tvm/topi/testing/depthwise_conv2d_python.py
+++ b/python/tvm/topi/testing/depthwise_conv2d_python.py
@@ -89,6 +89,63 @@ def depthwise_conv2d_python_nchw(input_np, filter_np, stride, padding):
     return output_np
+def depthwise_conv2d_python_nchwc(input_np, filter_np, stride, padding):
+    """Depthwise convolution operator in NCHWc layout.
+ + Parameters + ---------- + input_np : numpy.ndarray + 5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block] + + filter_np : numpy.ndarray + 6-D with shape [out_channel_chunk, channel_multiplier_chunk, + filter_height, filter_width, + channel_multiplier_block, out_channel_block] + + stride : list / tuple of 2 ints + [stride_height, stride_width] + + padding : str + 'VALID' or 'SAME' + + Returns + ------- + output_np : np.ndarray + 5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block] + """ + # Transform to NCHW + batch_size, in_channel_chunk, in_height, in_width, in_channel_block = input_np.shape + input_nchw = input_np.transpose(0, 1, 4, 2, 3).reshape( + (batch_size, in_channel_chunk * in_channel_block, in_height, in_width) + ) + + ( + out_channel_chunk, + channel_multiplier_chunk, + filter_height, + filter_width, + channel_multiplier_block, + out_channel_block, + ) = filter_np.shape + filter_nchw = filter_np.transpose(0, 5, 1, 4, 2, 3).reshape( + ( + out_channel_chunk * out_channel_block, + channel_multiplier_chunk * channel_multiplier_block, + filter_height, + filter_width, + ) + ) + + # Perform conv2d + output_np = depthwise_conv2d_python_nchw(input_nchw, filter_nchw, stride, padding) + + # Transform back + batch_size, out_channel, out_height, out_width = output_np.shape + return output_np.reshape( + (batch_size, out_channel_chunk, out_channel_block, out_height, out_width) + ).transpose(0, 1, 3, 4, 2) + + def depthwise_conv2d_python_nhwc(input_np, filter_np, stride, padding): """Depthwise convolution operator in nchw layout. diff --git a/python/tvm/topi/x86/group_conv2d.py b/python/tvm/topi/x86/group_conv2d.py index 0501c5534cf2..0e10052e2428 100644 --- a/python/tvm/topi/x86/group_conv2d.py +++ b/python/tvm/topi/x86/group_conv2d.py @@ -43,7 +43,9 @@ def schedule_group_conv2d_nchw(outs): return schedule_group_conv2d_nchwc(outs) -def _get_default_config(cfg, data, kernel, strides, padding, groups, out_dtype, layout="NCHW"): +def _get_default_config( + cfg, data, kernel, strides, padding, dilation, groups, out_dtype, layout="NCHW" +): """ Get default schedule config for the workload """ @@ -55,7 +57,7 @@ def _get_default_config(cfg, data, kernel, strides, padding, groups, out_dtype, static_data_shape.append(dim) data = te.placeholder(static_data_shape, dtype=data.dtype) - wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype, layout) + wkl = _get_conv2d_workload(data, kernel, strides, padding, dilation, out_dtype, layout) _fallback_schedule(cfg, wkl) @@ -159,6 +161,7 @@ def group_conv2d_nchw_spatial_pack( ), strides, padding, + dilation, groups, out_dtype, ) diff --git a/tests/python/topi/python/test_topi_conv2d_nchw.py b/tests/python/topi/python/test_topi_conv2d_nchw.py index 8dbe94b45a2f..2a4865c6dd8d 100644 --- a/tests/python/topi/python/test_topi_conv2d_nchw.py +++ b/tests/python/topi/python/test_topi_conv2d_nchw.py @@ -16,13 +16,15 @@ # under the License. 
"""Example code to do convolution.""" +import sys + +import pytest import numpy as np + import tvm -from tvm import te -from tvm import autotvm -from tvm import topi +from tvm import autotvm, te, topi import tvm.topi.testing -from tvm.contrib.pickle_memoize import memoize +from tvm.contrib import cudnn from tvm.topi.nn.utils import get_pad_tuple from tvm.topi.utils import get_const_tuple from tvm.topi.nn.conv2d import _get_workload @@ -30,238 +32,272 @@ import tvm.testing +dtype = tvm.testing.parameter("float32") -def verify_conv2d_nchw( - batch, - in_channel, - in_size, - num_filter, - kernel, - stride, - padding, - dilation=1, - add_bias=False, - add_relu=False, - use_cudnn=False, -): - pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel)) - padding_sum = pad_top + pad_left + pad_bottom + pad_right - print( - "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" - % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation) - ) +@tvm.testing.fixture +def input_shape(batch, in_channel, in_size): + return (batch, in_channel, in_size, in_size) - in_height = in_width = in_size - - A = te.placeholder((batch, in_channel, in_height, in_width), name="A") - W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W") - bias = te.placeholder((num_filter, 1, 1), name="bias") - - a_shape = get_const_tuple(A.shape) - w_shape = get_const_tuple(W.shape) - bias_shape = get_const_tuple(bias.shape) - dtype = A.dtype - - @memoize("topi.tests.test_topi_conv2d_nchw.verify_conv2d_nchw") - def get_ref_data(): - a_np = np.random.uniform(size=a_shape).astype(dtype) - w_np = np.random.uniform(size=w_shape).astype(dtype) - b_np = np.random.uniform(size=bias_shape).astype(dtype) - dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)) - c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding) - if add_bias: - c_np += b_np - if add_relu: - c_np = np.maximum(c_np, 0) - return a_np, w_np, b_np, c_np - - a_np, w_np, b_np, c_np = get_ref_data() - - def verify_workload_padding(): - _, _, out_height, out_width = get_const_tuple(c_np.shape) - wkl = _get_workload(A, W, (stride, stride), padding, dilation, dtype) - - # check if tile_ow candidates are the factors of the right output weight. 
- cfg = autotvm.get_config() - _fallback_schedule(cfg, wkl) - ow_tile = np.prod(cfg["tile_ow"].size) - tvm.testing.assert_allclose(ow_tile, out_width) +@tvm.testing.fixture +def weight_shape(num_filter, in_channel, kernel): + return (num_filter, in_channel, kernel, kernel) - def check_target(target): - dev = tvm.device(target, 0) - if not tvm.testing.device_enabled(target): - print("Skip because %s is not enabled" % target) - return - print("Running on target: %s" % target) - if "cudnn" in target: - fcompute, fschedule = topi.cuda.conv2d_cudnn, topi.cuda.schedule_conv2d_cudnn - else: - fcompute, fschedule = tvm.topi.testing.get_conv2d_nchw_implement(target) +@tvm.testing.fixture +def bias_shape(num_filter): + return (num_filter, 1, 1) - with tvm.target.Target(target): - if "cudnn" in target: - C = fcompute( - A, W, (stride, stride), padding, (dilation, dilation), 1, "NCHW", dtype - ) + +@tvm.testing.fixture(cache_return_value=True) +def ref_data( + input_shape, + weight_shape, + bias_shape, + dtype, + stride, + padding, + dilation, + add_bias, + apply_relu, +): + a_np = np.random.uniform(size=input_shape).astype(dtype) + w_np = np.random.uniform(size=weight_shape).astype(dtype) + b_np = np.random.uniform(size=bias_shape).astype(dtype) + dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)) + c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding) + + if add_bias: + c_np = c_np + b_np + if apply_relu: + c_np = np.maximum(c_np, 0) + return a_np, w_np, b_np, c_np + + +class BaseConv2DTests: + add_bias = tvm.testing.parameter(False) + apply_relu = tvm.testing.parameter(False) + dilation = tvm.testing.parameter(1) + batch = tvm.testing.parameter(1) + + def test_conv2d_nchw( + self, + target, + dev, + batch, + in_channel, + in_size, + num_filter, + kernel, + stride, + padding, + dtype, + ref_data, + dilation, + add_bias, + apply_relu, + ): + target = tvm.target.Target(target) + is_cudnn_target = target.kind.name == "cuda" and "cudnn" in target.attrs.get("libs", []) + + pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel)) + padding_sum = pad_top + pad_left + pad_bottom + pad_right + + a_np, w_np, b_np, c_np = ref_data + + A = te.placeholder(a_np.shape, name="A", dtype=dtype) + W = te.placeholder(w_np.shape, name="W", dtype=dtype) + bias = te.placeholder(b_np.shape, name="bias") + + with autotvm.tophub.context(target): # load tophub pre-tuned parameters + if is_cudnn_target: + fcompute, fschedule = topi.cuda.conv2d_cudnn, topi.cuda.schedule_conv2d_cudnn else: - C = fcompute(A, W, (stride, stride), padding, (dilation, dilation), dtype) - if add_bias: - C = topi.add(C, bias) - if add_relu: - C = topi.nn.relu(C) - s = fschedule([C]) - - if "llvm" in target: - verify_workload_padding() - - a = tvm.nd.array(a_np, dev) - w = tvm.nd.array(w_np, dev) - b = tvm.nd.array(b_np, dev) - - c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev) - if add_bias: + fcompute, fschedule = tvm.topi.testing.get_conv2d_nchw_implement(target) + + with target: + if is_cudnn_target: + C = fcompute( + A, W, (stride, stride), padding, (dilation, dilation), 1, "NCHW", dtype + ) + else: + C = fcompute(A, W, (stride, stride), padding, (dilation, dilation), dtype) + if add_bias: + C = topi.add(C, bias) + if apply_relu: + C = topi.nn.relu(C) + s = fschedule([C]) + + a = tvm.nd.array(a_np, dev) + w = tvm.nd.array(w_np, dev) + b = tvm.nd.array(b_np, dev) + + c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev) func = tvm.build( 
s, [A, W, bias, C], target, - name="relu_%d_%d_%d_%d_%d_%d_%d_%d" + name="conv2d_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation), ) func(a, w, b, c) - else: - func = tvm.build( - s, - [A, W, C], - target, - name="relu_%d_%d_%d_%d_%d_%d_%d_%d" - % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation), - ) - func(a, w, c) - tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-4) + tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-4) + + @tvm.testing.parametrize_targets("llvm") + def test_workload_padding( + self, + target, + input_shape, + weight_shape, + stride, + padding, + dilation, + dtype, + ref_data, + ): + a_np, w_np, b_np, c_np = ref_data + _, _, out_height, out_width = c_np.shape + + A = te.placeholder(input_shape, name="A", dtype=dtype) + W = te.placeholder(weight_shape, name="W", dtype=dtype) - for target, dev in tvm.testing.enabled_targets(): - with autotvm.tophub.context(target): # load tophub pre-tuned parameters - check_target(target) - - if use_cudnn: - check_target("cuda -model=unknown -libs=cudnn") - if ("opencl", tvm.device("opencl")) in tvm.testing.enabled_targets(): - check_target("opencl -device=intel_graphics") - - -@tvm.testing.uses_gpu -def test_conv2d_nchw(): - # ResNet18 workloads - verify_conv2d_nchw(1, 3, 224, 64, 7, 2, 3) - verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1) - verify_conv2d_nchw(1, 64, 56, 64, 1, 1, 0) - verify_conv2d_nchw(1, 64, 56, 128, 3, 2, 1) - verify_conv2d_nchw(1, 64, 56, 128, 1, 2, 0) - verify_conv2d_nchw(1, 128, 28, 128, 3, 1, 1) - verify_conv2d_nchw(1, 128, 28, 256, 3, 2, 1) - verify_conv2d_nchw(1, 128, 28, 256, 1, 2, 0) - verify_conv2d_nchw(1, 256, 14, 256, 3, 1, 1) - verify_conv2d_nchw(1, 256, 14, 512, 3, 2, 1) - verify_conv2d_nchw(1, 256, 14, 512, 1, 2, 0) - verify_conv2d_nchw(1, 512, 7, 512, 3, 1, 1) - - # bias, relu - verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1, add_relu=True) - verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1, add_bias=True) - verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1, add_bias=True, add_relu=True) - - # dilation = 2 - verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1, dilation=2) - - # batch size - verify_conv2d_nchw(4, 64, 56, 64, 3, 1, 1) - verify_conv2d_nchw(9, 64, 56, 64, 3, 1, 1) - - # weird workloads - verify_conv2d_nchw(2, 2, 2, 2, 2, 2, 2) - verify_conv2d_nchw(3, 3, 3, 3, 3, 3, 3) - verify_conv2d_nchw(4, 4, 4, 4, 4, 4, 4) - verify_conv2d_nchw(5, 5, 5, 5, 5, 5, 5) - verify_conv2d_nchw(6, 6, 6, 6, 6, 6, 6) - - # disable these tests due to some bugs of llvm with nvptx - # verify_conv2d_nchw(1, 1, 1, 1, 1, 1, 1, dilation=1) - # verify_conv2d_nchw(1, 1, 1, 1, 1, 1, 1, dilation=2) - # verify_conv2d_nchw(2, 13, 71, 59, 3, 1, 1) - - # inception v3 workloads - verify_conv2d_nchw(1, 3, 299, 32, 3, 2, 0) - verify_conv2d_nchw(1, 32, 149, 32, 3, 1, 0) - verify_conv2d_nchw(1, 32, 147, 64, 3, 1, 1) - verify_conv2d_nchw(1, 64, 73, 80, 1, 1, 0) - verify_conv2d_nchw(1, 80, 73, 192, 3, 1, 0) - verify_conv2d_nchw(1, 192, 35, 64, 1, 1, 0) - verify_conv2d_nchw(1, 192, 35, 48, 1, 1, 0) - verify_conv2d_nchw(1, 48, 35, 64, 5, 1, 2) - verify_conv2d_nchw(1, 64, 35, 96, 3, 1, 1) - verify_conv2d_nchw(1, 96, 35, 96, 3, 1, 1) - verify_conv2d_nchw(1, 192, 35, 32, 1, 1, 0) - verify_conv2d_nchw(1, 256, 35, 64, 1, 1, 0) - verify_conv2d_nchw(1, 256, 35, 48, 1, 1, 0) - verify_conv2d_nchw(1, 288, 35, 64, 1, 1, 0) - verify_conv2d_nchw(1, 288, 35, 48, 1, 1, 0) - verify_conv2d_nchw(1, 288, 35, 384, 3, 2, 0) - verify_conv2d_nchw(1, 96, 35, 96, 3, 2, 0) - verify_conv2d_nchw(1, 768, 17, 
192, 1, 1, 0) - verify_conv2d_nchw(1, 768, 17, 128, 1, 1, 0) - verify_conv2d_nchw(1, 128, 17, 128, 1, 1, 0) - verify_conv2d_nchw(1, 128, 17, 192, 7, 1, 3) - verify_conv2d_nchw(1, 128, 17, 128, 7, 1, 3) - verify_conv2d_nchw(1, 128, 17, 192, 1, 1, 0) - verify_conv2d_nchw(1, 768, 17, 160, 1, 1, 0) - # disable these tests due to some bugs of llvm with nvptx - # verify_conv2d_nchw(1, 160, 17, 160, 1, 1, 0) - verify_conv2d_nchw(1, 160, 17, 192, 7, 1, 3) - verify_conv2d_nchw(1, 160, 17, 160, 7, 1, 3) - verify_conv2d_nchw(1, 160, 17, 192, 1, 1, 0) - verify_conv2d_nchw(1, 192, 17, 192, 1, 1, 0) - verify_conv2d_nchw(1, 192, 17, 192, 7, 1, 3) - verify_conv2d_nchw(1, 192, 17, 320, 3, 2, 0) - verify_conv2d_nchw(1, 192, 17, 192, 3, 2, 0) - verify_conv2d_nchw(1, 1280, 8, 320, 1, 1, 0) - verify_conv2d_nchw(1, 1280, 8, 384, 1, 1, 0) - verify_conv2d_nchw(1, 384, 8, 384, 1, 1, 0) - verify_conv2d_nchw(1, 384, 8, 384, 3, 1, 1) - verify_conv2d_nchw(1, 1280, 8, 448, 1, 1, 0) - verify_conv2d_nchw(1, 448, 8, 384, 3, 1, 1) - verify_conv2d_nchw(1, 1280, 8, 192, 1, 1, 0) - verify_conv2d_nchw(1, 2048, 8, 320, 1, 1, 0) - verify_conv2d_nchw(1, 2048, 8, 384, 1, 1, 0) - verify_conv2d_nchw(1, 2048, 8, 448, 1, 1, 0) - verify_conv2d_nchw(1, 2048, 8, 192, 1, 1, 0) - verify_conv2d_nchw(1, 1024, 19, 84, 3, 1, 1) - verify_conv2d_nchw(1, 2048, 10, 126, 3, 1, 1) - verify_conv2d_nchw(1, 512, 5, 126, 3, 1, 1) - verify_conv2d_nchw(1, 256, 3, 126, 3, 1, 1) - - # Asymmetric padding - verify_conv2d_nchw(1, 3, 35, 64, 7, 2, (0, 0, 1, 1)) - verify_conv2d_nchw(1, 64, 8, 128, 3, 1, (3, 3, 2, 2)) - verify_conv2d_nchw(1, 64, 8, 64, 1, 1, (1, 2, 2, 1)) - verify_conv2d_nchw(1, 64, 17, 192, 1, 1, (1, 2)) - verify_conv2d_nchw(1, 64, 8, 64, 3, 1, (3, 1)) - verify_conv2d_nchw(1, 128, 8, 384, 3, 1, (0, 2)) - verify_conv2d_nchw(1, 64, 35, 64, 3, 1, (1, 2), use_cudnn=True) - verify_conv2d_nchw(1, 64, 8, 64, 1, 1, "VALID") - verify_conv2d_nchw(1, 388, 8, 64, 3, 1, "VALID") - verify_conv2d_nchw(1, 64, 10, 48, 3, 1, "VALID", use_cudnn=True) - verify_conv2d_nchw(1, 512, 19, 64, 1, 1, "SAME") - verify_conv2d_nchw(1, 64, 5, 32, 2, 1, "SAME") - verify_conv2d_nchw(1, 64, 8, 64, 3, 1, "SAME", use_cudnn=True) - verify_conv2d_nchw(1, 64, 8, 64, 3, 1, (1, 2, 2, 1), add_relu=True) - verify_conv2d_nchw(1, 64, 8, 64, 5, 2, (1, 3), add_bias=True) - verify_conv2d_nchw(1, 64, 8, 64, 3, 1, "VALID", add_bias=True, add_relu=True) - verify_conv2d_nchw(1, 64, 8, 64, 24, 1, "SAME", add_bias=True, add_relu=True) - verify_conv2d_nchw(1, 32, 35, 64, 7, 2, (0, 0, 2, 2)) + with tvm.target.Target(target): + wkl = _get_workload(A, W, (stride, stride), padding, dilation, dtype) + + # check if tile_ow candidates are the factors of the right output weight. 
+ cfg = autotvm.get_config() + _fallback_schedule(cfg, wkl) + ow_tile = np.prod(cfg["tile_ow"].size) + + tvm.testing.assert_allclose(ow_tile, out_width) + + +class TestResNet18Workloads(BaseConv2DTests): + in_channel, in_size, num_filter, kernel, stride, padding = tvm.testing.parameters( + (3, 224, 64, 7, 2, 3), + (64, 56, 64, 3, 1, 1), + (64, 56, 64, 1, 1, 0), + (64, 56, 128, 3, 2, 1), + (64, 56, 128, 1, 2, 0), + (128, 28, 128, 3, 1, 1), + (128, 28, 256, 3, 2, 1), + (128, 28, 256, 1, 2, 0), + (256, 14, 256, 3, 1, 1), + (256, 14, 512, 3, 2, 1), + (256, 14, 512, 1, 2, 0), + (512, 7, 512, 3, 1, 1), + ) + + +class TestInceptionV3Workloads(BaseConv2DTests): + in_channel, in_size, num_filter, kernel, stride, padding = tvm.testing.parameters( + (3, 299, 32, 3, 2, 0), + (32, 149, 32, 3, 1, 0), + (32, 147, 64, 3, 1, 1), + (64, 73, 80, 1, 1, 0), + (80, 73, 192, 3, 1, 0), + (192, 35, 64, 1, 1, 0), + (192, 35, 48, 1, 1, 0), + (48, 35, 64, 5, 1, 2), + (64, 35, 96, 3, 1, 1), + (96, 35, 96, 3, 1, 1), + (192, 35, 32, 1, 1, 0), + (256, 35, 64, 1, 1, 0), + (256, 35, 48, 1, 1, 0), + (288, 35, 64, 1, 1, 0), + (288, 35, 48, 1, 1, 0), + (288, 35, 384, 3, 2, 0), + (96, 35, 96, 3, 2, 0), + (768, 17, 192, 1, 1, 0), + (768, 17, 128, 1, 1, 0), + (128, 17, 128, 1, 1, 0), + (128, 17, 192, 7, 1, 3), + (128, 17, 128, 7, 1, 3), + (128, 17, 192, 1, 1, 0), + (768, 17, 160, 1, 1, 0), + # disable these tests due to some bugs of llvm with nvptx + # (160, 17, 160, 1, 1, 0), + (160, 17, 192, 7, 1, 3), + (160, 17, 160, 7, 1, 3), + (160, 17, 192, 1, 1, 0), + (192, 17, 192, 1, 1, 0), + (192, 17, 192, 7, 1, 3), + (192, 17, 320, 3, 2, 0), + (192, 17, 192, 3, 2, 0), + (1280, 8, 320, 1, 1, 0), + (1280, 8, 384, 1, 1, 0), + (384, 8, 384, 1, 1, 0), + (384, 8, 384, 3, 1, 1), + (1280, 8, 448, 1, 1, 0), + (448, 8, 384, 3, 1, 1), + (1280, 8, 192, 1, 1, 0), + (2048, 8, 320, 1, 1, 0), + (2048, 8, 384, 1, 1, 0), + (2048, 8, 448, 1, 1, 0), + (2048, 8, 192, 1, 1, 0), + (1024, 19, 84, 3, 1, 1), + (2048, 10, 126, 3, 1, 1), + (512, 5, 126, 3, 1, 1), + (256, 3, 126, 3, 1, 1), + ) + + +class TestWeirdWorkloads(BaseConv2DTests): + batch, in_channel, in_size, num_filter, kernel, stride, padding = tvm.testing.parameters( + (2, 2, 2, 2, 2, 2, 2), + (3, 3, 3, 3, 3, 3, 3), + (4, 4, 4, 4, 4, 4, 4), + (5, 5, 5, 5, 5, 5, 5), + (6, 6, 6, 6, 6, 6, 6), + # disable these tests due to some bugs of llvm with nvptx + # (1, 1, 1, 1, 1, 1, 1), + # (2, 13, 71, 59, 3, 1, 1), + ) + + +class TestAsymmetricPadding(BaseConv2DTests): + dilation = tvm.testing.parameter(1, 2) + in_channel, in_size, num_filter, kernel, stride, padding = tvm.testing.parameters( + (3, 35, 64, 7, 2, (0, 0, 1, 1)), + (64, 8, 128, 3, 1, (3, 3, 2, 2)), + (64, 8, 64, 1, 1, (1, 2, 2, 1)), + (64, 17, 192, 1, 1, (1, 2)), + (64, 8, 64, 3, 1, (3, 1)), + (128, 8, 384, 3, 1, (0, 2)), + (64, 35, 64, 3, 1, (1, 2)), + (64, 8, 64, 1, 1, "VALID"), + (388, 8, 64, 3, 1, "VALID"), + (64, 10, 48, 3, 1, "VALID"), + (512, 19, 64, 1, 1, "SAME"), + (64, 5, 32, 2, 1, "SAME"), + (64, 8, 64, 3, 1, "SAME"), + (64, 8, 64, 3, 1, (1, 2, 2, 1)), + (64, 8, 64, 5, 2, (1, 3)), + (64, 8, 64, 3, 1, "VALID"), + (64, 8, 64, 24, 1, "SAME"), + (32, 35, 64, 7, 2, (0, 0, 2, 2)), + ) + + +class TestBatchSize(BaseConv2DTests): + in_channel, in_size, num_filter, kernel, stride, padding = tvm.testing.parameters( + (64, 56, 64, 3, 1, 1), + ) + batch = tvm.testing.parameter(1, 4, 9) + + +class TestBiasRelu(BaseConv2DTests): + add_relu = tvm.testing.parameter(True, False) + add_bias = tvm.testing.parameter(True, False) + in_channel, in_size, 
num_filter, kernel, stride, padding = tvm.testing.parameters( + (64, 56, 64, 3, 1, 1), + (64, 8, 64, 3, 1, (1, 2, 2, 1)), + (64, 8, 64, 5, 2, (1, 3)), + (64, 8, 64, 3, 1, "VALID"), + (64, 8, 64, 24, 1, "SAME"), + ) if __name__ == "__main__": - test_conv2d_nchw() + sys.exit(pytest.main(sys.argv)) diff --git a/tests/python/topi/python/test_topi_depthwise_conv2d.py b/tests/python/topi/python/test_topi_depthwise_conv2d.py index 76093c51b4c8..092ac9df5f9a 100644 --- a/tests/python/topi/python/test_topi_depthwise_conv2d.py +++ b/tests/python/topi/python/test_topi_depthwise_conv2d.py @@ -14,561 +14,375 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + +import sys + +import numpy as np +import pytest + import tvm -from tvm import te -from tvm import autotvm -from tvm import topi +import tvm.testing import tvm.topi.testing -import numpy as np + +from tvm import autotvm, te, topi from tvm.topi.utils import get_const_tuple from tvm.topi.nn.utils import get_pad_tuple from tvm.contrib.pickle_memoize import memoize from tvm.topi.nn.depthwise_conv2d import _get_workload from tvm.topi.x86.depthwise_conv2d import _fallback_schedule -import tvm.testing -_depthwise_conv2d_nchw_implement = { - "generic": [(topi.nn.depthwise_conv2d_nchw, topi.generic.schedule_depthwise_conv2d_nchw)], - "arm_cpu": [ - (topi.arm_cpu.depthwise_conv2d_nchw, topi.arm_cpu.schedule_depthwise_conv2d_nchw), - ( - topi.arm_cpu.depthwise_conv2d_nchw_spatial_pack, - topi.arm_cpu.schedule_depthwise_conv2d_nchw_spatial_pack, - ), - ], - "gpu": [(topi.cuda.depthwise_conv2d_nchw, topi.cuda.schedule_depthwise_conv2d_nchw)], - "mali": [(topi.mali.depthwise_conv2d_nchw, topi.mali.schedule_depthwise_conv2d_nchw)], - "bifrost": [(topi.nn.depthwise_conv2d_nchw, topi.bifrost.schedule_depthwise_conv2d_nchw)], - "intel_graphics": [ - ( - topi.intel_graphics.depthwise_conv2d_nchw, - topi.intel_graphics.schedule_depthwise_conv2d_nchw, - ) - ], +_depthwise_conv2d_implement = { + "NCHW": { + "generic": [(topi.nn.depthwise_conv2d_nchw, topi.generic.schedule_depthwise_conv2d_nchw)], + "arm_cpu": [ + (topi.arm_cpu.depthwise_conv2d_nchw, topi.arm_cpu.schedule_depthwise_conv2d_nchw), + ( + topi.arm_cpu.depthwise_conv2d_nchw_spatial_pack, + topi.arm_cpu.schedule_depthwise_conv2d_nchw_spatial_pack, + ), + ], + "gpu": [(topi.cuda.depthwise_conv2d_nchw, topi.cuda.schedule_depthwise_conv2d_nchw)], + "mali": [(topi.mali.depthwise_conv2d_nchw, topi.mali.schedule_depthwise_conv2d_nchw)], + "bifrost": [(topi.nn.depthwise_conv2d_nchw, topi.bifrost.schedule_depthwise_conv2d_nchw)], + "intel_graphics": [ + ( + topi.intel_graphics.depthwise_conv2d_nchw, + topi.intel_graphics.schedule_depthwise_conv2d_nchw, + ) + ], + }, + "NHWC": { + "generic": [(topi.nn.depthwise_conv2d_nhwc, topi.generic.schedule_depthwise_conv2d_nhwc)], + "arm_cpu": [ + ( + topi.arm_cpu.compute_depthwise_conv2d_nhwc, + topi.arm_cpu.schedule_depthwise_conv2d_nhwc, + ) + ], + "gpu": [(topi.nn.depthwise_conv2d_nhwc, topi.cuda.schedule_depthwise_conv2d_nhwc)], + }, + "NCHWc": { + "generic": [(topi.x86.depthwise_conv2d_NCHWc, topi.x86.schedule_depthwise_conv2d_NCHWc)], + }, } -_depthwise_conv2d_nhwc_implement = { - "generic": (topi.nn.depthwise_conv2d_nhwc, topi.generic.schedule_depthwise_conv2d_nhwc), - "arm_cpu": ( - topi.arm_cpu.compute_depthwise_conv2d_nhwc, - topi.arm_cpu.schedule_depthwise_conv2d_nhwc, - ), - "gpu": (topi.nn.depthwise_conv2d_nhwc, topi.cuda.schedule_depthwise_conv2d_nhwc), -} +in_dtype, 
out_dtype = tvm.testing.parameters(("float32", "float32")) -def compile_depthwise_NHWC_int8_arm( - batch, - in_channel, - in_size, - kernel, - depth_multiplier, - stride, - padding, - add_bias=False, - dilation=1, -): - pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel)) - padding_sum = pad_top + pad_left + pad_bottom + pad_right - - in_height = in_width = in_size - A = te.placeholder((batch, in_height, in_width, in_channel), name="A", dtype="int16") - W = te.placeholder((kernel, kernel, in_channel, depth_multiplier), name="W", dtype="int16") - bias = te.placeholder((in_channel * depth_multiplier,), name="bias", dtype="int32") - dtype = "int32" - - target = "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu" - compute = topi.arm_cpu.compute_depthwise_conv2d_nhwc - schedule = topi.arm_cpu.schedule_depthwise_conv2d_nhwc - - if not tvm.testing.device_enabled(target): - print("Skip because %s is not enabled" % target) - return - - print("Compiling on arm AArch64 target: %s" % target) - with tvm.target.Target(target): - assert topi.arm_cpu.arm_utils.is_aarch64_arm(), "AArch64 target not recognized" - - C = compute(A, W, (stride, stride), padding, (dilation, dilation), dtype) - if add_bias: - C += bias - ins_outs = [A, W, bias, C] - else: - ins_outs = [A, W, C] - - s = schedule([C]) - - func = tvm.build( - s, - ins_outs, - target, - name="depthwise_conv2d", - ) - - -def depthwise_conv2d_with_workload_nchw( - target, - dev, - batch, - in_channel, - in_height, - channel_multiplier, - filter_height, - stride, - padding, - dilation=1, -): - in_width = in_height - filter_channel = in_channel - filter_width = filter_height - stride_h = stride_w = stride - - if dilation == 1: - # here we transform the padding argument from 'str' to 'tuple' , - # because we need this to match the "workload" tuple to the records in TopHub - padt, padl, padb, padr = get_pad_tuple(padding, (filter_height, filter_width)) - padding_args = (padt, padl, padb, padr) - else: - padding_args = padding - - # placeholder - Input = te.placeholder((batch, in_channel, in_height, in_width), name="Input") - Filter = te.placeholder( - (filter_channel, channel_multiplier, filter_height, filter_width), name="Filter" - ) - Scale = te.placeholder((in_channel * channel_multiplier,), name="Scale") - Shift = te.placeholder((in_channel * channel_multiplier,), name="Shift") - dtype = "float32" - - with autotvm.tophub.context(target): # load tophub pre-tuned parameters - impl_list = tvm.topi.testing.dispatch(target, _depthwise_conv2d_nchw_implement)[:] - if target == "llvm" and channel_multiplier == 1 and dilation == 1: - impl_list.append( - (topi.x86.depthwise_conv2d_nchw, topi.x86.schedule_depthwise_conv2d_nchw) - ) +@tvm.testing.fixture +def input_shape(layout, batch, in_channel, in_size, filter_shape): + if layout == "NCHW": + return (batch, in_channel, in_size, in_size) + elif layout == "NHWC": + return (batch, in_size, in_size, in_channel) + elif layout == "NCHWc": + oc_block = filter_shape[-1] + ic_block = next(bn for bn in range(oc_block, 0, -1) if in_channel % bn == 0) + return (batch, in_channel // ic_block, in_size, in_size, ic_block) - for fcompute, fschedule in impl_list: - with tvm.target.Target(target): - # declare - DepthwiseConv2d = fcompute( - Input, Filter, (stride_h, stride_w), padding_args, dilation, dtype - ) - ScaleShift = topi.nn.scale_shift_nchw(DepthwiseConv2d, Scale, Shift) - Relu = topi.nn.relu(ScaleShift) - # schedule - s1 = fschedule(DepthwiseConv2d) - s2 = fschedule(ScaleShift) - s3 = 
fschedule(Relu) - # build the kernels - f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], target) - f2 = tvm.build(s2, [Input, Filter, Scale, Shift, ScaleShift], target) - f3 = tvm.build(s3, [Input, Filter, Scale, Shift, Relu], target) - - # Prepare pod type for test data closure - input_shape = get_const_tuple(Input.shape) - filter_shape = get_const_tuple(Filter.shape) - scale_shape = get_const_tuple(Scale.shape) - shift_shape = get_const_tuple(Shift.shape) - scale_shift_shape = get_const_tuple(ScaleShift.shape) - - # Use memoize, pickle the test data for next time use. - @memoize("topi.tests.test_topi_depthwise_conv2d.nchw") - def get_ref_data(): - input_np = np.random.uniform(size=input_shape).astype(dtype) - filter_np = np.random.uniform(size=filter_shape).astype(dtype) - dilated_filter_np = tvm.topi.testing.dilate_python( - filter_np, (1, 1, dilation, dilation) - ) - scale_np = np.random.uniform(size=scale_shape).astype(dtype) - shift_np = np.random.uniform(size=shift_shape).astype(dtype) - # correctness with scipy - depthwise_conv2d_scipy = tvm.topi.testing.depthwise_conv2d_python_nchw( - input_np, dilated_filter_np, stride, padding - ) - scale_shift_scipy = np.zeros(shape=scale_shift_shape) - for c in range(in_channel * channel_multiplier): - scale_shift_scipy[:, c, :, :] = ( - depthwise_conv2d_scipy[:, c, :, :] * scale_np[c] + shift_np[c] - ) - relu_scipy = np.maximum(scale_shift_scipy, 0) - return ( - input_np, - filter_np, - scale_np, - shift_np, - depthwise_conv2d_scipy, - scale_shift_scipy, - relu_scipy, - ) - # Get the test data - ( - input_np, - filter_np, - scale_np, - shift_np, - depthwise_conv2d_scipy, - scale_shift_scipy, - relu_scipy, - ) = get_ref_data() - - def verify_workload_padding(): - _, _, out_height, out_width = get_const_tuple(depthwise_conv2d_scipy.shape) - wkl = _get_workload( - Input, Filter, (stride_h, stride_w), padding_args, dilation, dtype - ) - - # check if tile_ow candidates are the factors of the right output weight. 
- with tvm.target.Target(target): - cfg = autotvm.get_config() - _fallback_schedule(cfg, wkl) - ow_tile = np.prod(cfg["tile_ow"].size) - - tvm.testing.assert_allclose(ow_tile, out_width) - - if "llvm" in target: - verify_workload_padding() - - input_tvm = tvm.nd.array(input_np, dev) - filter_tvm = tvm.nd.array(filter_np, dev) - scale_tvm = tvm.nd.array(scale_np, dev) - shift_tvm = tvm.nd.array(shift_np, dev) - depthwise_conv2d_tvm = tvm.nd.array( - np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape), dtype=DepthwiseConv2d.dtype), - dev, - ) - scale_shift_tvm = tvm.nd.array( - np.zeros(shape=get_const_tuple(ScaleShift.shape), dtype=ScaleShift.dtype), dev - ) - relu_tvm = tvm.nd.array( - np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), dev - ) - # launch kernel 1 (depthwise_conv2d) - timer_1 = f1.time_evaluator(f1.entry_name, dev, number=1) - tcost_1 = timer_1(input_tvm, filter_tvm, depthwise_conv2d_tvm).mean - # launch kernel 2 (depthwise_conv2d + scale_shift) - timer_2 = f2.time_evaluator(f2.entry_name, dev, number=1) - tcost_2 = timer_2(input_tvm, filter_tvm, scale_tvm, shift_tvm, scale_shift_tvm).mean - # launch kernel 3 (depthwise_conv2d + scale_shift + relu) - timer_3 = f3.time_evaluator(f3.entry_name, dev, number=1) - tcost_3 = timer_3(input_tvm, filter_tvm, scale_tvm, shift_tvm, relu_tvm).mean - tvm.testing.assert_allclose( - depthwise_conv2d_tvm.numpy(), depthwise_conv2d_scipy, rtol=1e-5 - ) - tvm.testing.assert_allclose(scale_shift_tvm.numpy(), scale_shift_scipy, rtol=1e-5) - tvm.testing.assert_allclose(relu_tvm.numpy(), relu_scipy, rtol=1e-5) - - -def depthwise_conv2d_with_workload_nhwc( - target, - dev, - batch, - in_channel, - in_height, - channel_multiplier, - filter_height, - stride_h, +@tvm.testing.fixture +def filter_shape(layout, in_channel, channel_multiplier, kernel): + filter_channel = in_channel + if layout == "NCHW": + return (filter_channel, channel_multiplier, kernel, kernel) + elif layout == "NHWC": + return (kernel, kernel, filter_channel, channel_multiplier) + elif layout == "NCHWc": + out_channel = in_channel * channel_multiplier + # For testing the functionality, we choose an arbitrary block + # size that can divide out_channel, regardless of the + # performance. 
+ oc_block = next(bn for bn in range(16, 0, -1) if out_channel % bn == 0) + return (out_channel // oc_block, 1, kernel, kernel, 1, oc_block) + + +@tvm.testing.fixture +def scale_shape(layout, in_channel, channel_multiplier, filter_shape): + out_channel = in_channel * channel_multiplier + + if layout in ("NCHW", "NHWC"): + return (out_channel,) + + if layout == "NCHWc": + oc_block = filter_shape[-1] + return (out_channel // oc_block, oc_block) + + raise ValueError("Unknown layout {}".format(layout)) + + +@tvm.testing.fixture +def shift_shape(scale_shape): + return scale_shape + + +@tvm.testing.fixture(cache_return_value=True) +def ref_data( + in_dtype, + out_dtype, + layout, + input_shape, + filter_shape, + dilation, + stride, padding, - dilation=1, + scale_shape, + shift_shape, + use_scale_shift, + apply_relu, ): - in_width = in_height - filter_channel = in_channel - filter_width = filter_height - stride_w = stride_h - - if dilation == 1: - # here we transform the padding argument from 'str' to 'tuple' , - # because we need this to match the "workload" tuple to the records in TopHub - pad_h, pad_w, _, _ = get_pad_tuple(padding, (filter_height, filter_width)) - padding_args = (pad_h, pad_w) - else: - padding_args = padding - - # placeholder - Input = te.placeholder((batch, in_height, in_width, in_channel), name="Input") - Filter = te.placeholder( - (filter_height, filter_width, filter_channel, channel_multiplier), name="Filter" + input_np = np.random.uniform(size=input_shape).astype(in_dtype) + filter_np = np.random.uniform(size=filter_shape).astype(in_dtype) + scale_np = np.random.uniform(size=scale_shape).astype(out_dtype) + shift_np = np.random.uniform(size=shift_shape).astype(out_dtype) + if layout == "NCHW": + np_depthwise_conv2d = tvm.topi.testing.depthwise_conv2d_python_nchw + dilation = (1, 1, dilation, dilation) + reshape = (1, -1, 1, 1) + elif layout == "NHWC": + np_depthwise_conv2d = tvm.topi.testing.depthwise_conv2d_python_nhwc + dilation = (dilation, dilation, 1, 1) + reshape = (1, 1, 1, -1) + elif layout == "NCHWc": + np_depthwise_conv2d = tvm.topi.testing.depthwise_conv2d_python_nchwc + dilation = (1, 1, dilation, dilation, 1, 1) + reshape = (1, scale_shape[0], 1, 1, scale_shape[1]) + + dilated_filter_np = tvm.topi.testing.dilate_python(filter_np, dilation) + output_np = np_depthwise_conv2d(input_np, dilated_filter_np, stride, padding) + + if use_scale_shift: + output_np = output_np * scale_np.reshape(reshape) + shift_np.reshape(reshape) + if apply_relu: + output_np = np.maximum(output_np, 0) + + return ( + input_np, + filter_np, + scale_np, + shift_np, + output_np, ) - Scale = te.placeholder((in_channel * channel_multiplier,), name="Scale") - Shift = te.placeholder((in_channel * channel_multiplier,), name="Shift") - dtype = "float32" - with autotvm.tophub.context(target): # load tophub pre-tuned parameters - fcompute, fschedule = tvm.topi.testing.dispatch(target, _depthwise_conv2d_nhwc_implement) - with tvm.target.Target(target): - # declare - DepthwiseConv2d = fcompute( - Input, Filter, (stride_h, stride_w), padding_args, dilation, dtype - ) - ScaleShift = topi.nn.scale_shift_nhwc(DepthwiseConv2d, Scale, Shift) - Relu = topi.nn.relu(ScaleShift) - # schedule - s1 = fschedule(DepthwiseConv2d) - s2 = fschedule(ScaleShift) - s3 = fschedule(Relu) - # build the kernels - f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], target) - f2 = tvm.build(s2, [Input, Filter, Scale, Shift, ScaleShift], target) - f3 = tvm.build(s3, [Input, Filter, Scale, Shift, Relu], target) - - # Prepare 
pod type for test data closure - input_shape = get_const_tuple(Input.shape) - filter_shape = get_const_tuple(Filter.shape) - scale_shape = get_const_tuple(Scale.shape) - shift_shape = get_const_tuple(Shift.shape) - scale_shift_shape = get_const_tuple(ScaleShift.shape) - - # Use memoize, pickle the test data for next time use. - @memoize("topi.tests.test_topi_depthwise_conv2d.nhwc.v2") - def get_ref_data(): - input_np = np.random.uniform(size=input_shape).astype(dtype) - filter_np = np.random.uniform(size=filter_shape).astype(dtype) - dilated_filter_np = tvm.topi.testing.dilate_python( - filter_np, (dilation, dilation, 1, 1) - ) - scale_np = np.random.uniform(size=scale_shape).astype(dtype) - shift_np = np.random.uniform(size=shift_shape).astype(dtype) - # correctness with scipy - depthwise_conv2d_scipy = tvm.topi.testing.depthwise_conv2d_python_nhwc( - input_np, dilated_filter_np, stride=[stride_h, stride_w], padding=padding - ) - scale_shift_scipy = np.zeros(shape=scale_shift_shape) - for c in range(in_channel * channel_multiplier): - scale_shift_scipy[:, :, :, c] = ( - depthwise_conv2d_scipy[:, :, :, c] * scale_np[c] + shift_np[c] - ) - relu_scipy = np.maximum(scale_shift_scipy, 0) - return ( - input_np, - filter_np, - scale_np, - shift_np, - depthwise_conv2d_scipy, - scale_shift_scipy, - relu_scipy, - ) +class BaseDepthwiseConv2D: + """Provides the test_conv2d test function, to be used by other test classes. - # Get the test data - ( - input_np, - filter_np, - scale_np, - shift_np, - depthwise_conv2d_scipy, - scale_shift_scipy, - relu_scipy, - ) = get_ref_data() - - # prepare data - input_tvm = tvm.nd.array(input_np, dev) - filter_tvm = tvm.nd.array(filter_np, dev) - scale_tvm = tvm.nd.array(scale_np, dev) - shift_tvm = tvm.nd.array(shift_np, dev) - depthwise_conv2d_tvm = tvm.nd.array( - np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape), dtype=DepthwiseConv2d.dtype), dev - ) - scale_shift_tvm = tvm.nd.array( - np.zeros(shape=get_const_tuple(ScaleShift.shape), dtype=ScaleShift.dtype), dev - ) - relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), dev) - # launch kernel 1 (depthwise_conv2d) - timer_1 = f1.time_evaluator(f1.entry_name, dev, number=1) - tcost_1 = timer_1(input_tvm, filter_tvm, depthwise_conv2d_tvm).mean - # launch kernel 2 (depthwise_conv2d + scale_shift) - timer_2 = f2.time_evaluator(f2.entry_name, dev, number=1) - tcost_2 = timer_2(input_tvm, filter_tvm, scale_tvm, shift_tvm, scale_shift_tvm).mean - # launch kernel 3 (depthwise_conv2d + scale_shift + relu) - timer_3 = f3.time_evaluator(f3.entry_name, dev, number=1) - tcost_3 = timer_3(input_tvm, filter_tvm, scale_tvm, shift_tvm, relu_tvm).mean - relu_scipy = np.maximum(scale_shift_scipy, 0) - tvm.testing.assert_allclose(depthwise_conv2d_tvm.numpy(), depthwise_conv2d_scipy, rtol=1e-5) - tvm.testing.assert_allclose(scale_shift_tvm.numpy(), scale_shift_scipy, rtol=1e-5) - tvm.testing.assert_allclose(relu_tvm.numpy(), relu_scipy, rtol=1e-5) - - -def _transform_data(data, bn): - # NCHW -> NCHW[x]c - batch_size, channel, height, width = data.shape - data = np.reshape(data, (batch_size, channel // bn, bn, height, width)) - data = np.transpose(data, (0, 1, 3, 4, 2)) - return data - - -def _transform_kernel(kernel, bn): - # channel, channel_multiplier, kh, kw -> out_channel_chunk, kh, kw, out_channel_block - channel, channel_multiplier, kh, kw = kernel.shape - out_channel = channel * channel_multiplier - kernel = np.reshape(kernel, (out_channel // bn, bn, kh, kw)) - kernel = 
np.transpose(kernel, (0, 2, 3, 1)) - out_channel_chunk, kh, kw, out_channel_block = kernel.shape - return kernel.reshape(out_channel_chunk, 1, kh, kw, 1, out_channel_block) - - -def depthwise_conv2d_with_workload_NCHWc( - target, - dev, - batch, - in_channel, - in_height, - channel_multiplier, - filter_height, - stride, - padding, - dilation=1, -): - in_width = in_height - filter_channel = in_channel - filter_width = filter_height - stride_h = stride_w = stride - - assert ( - channel_multiplier == 1 - ), "depthwise_conv2d_NCHWc currently does not support channel multiplier > 1." - pad_h, pad_w, _, _ = get_pad_tuple(padding, (filter_height, filter_width)) - padding_args = (pad_h, pad_w) - - out_channel = filter_channel * channel_multiplier - # for testing functionality, - # we choose arbitrary block size that can divide the channel, - # regardless of the performance. - oc_block = 1 - for bn in range(16, 0, -1): - if out_channel % bn == 0: - oc_block = bn - break - - ic_block = 1 - for bn in range(oc_block, 0, -1): - if in_channel % bn == 0: - ic_block = bn - break - - # placeholder - Input = te.placeholder( - (batch, in_channel // ic_block, in_height, in_width, ic_block), name="Input" - ) - Filter = te.placeholder( - (out_channel // oc_block, 1, filter_height, filter_width, 1, oc_block), name="Filter" - ) - in_layout = "NCHW%dc" % ic_block - out_layout = "NCHW%dc" % oc_block - dtype = "float32" + Test parameter sets are split out into different classes for + readability (e.g. used for mobilenet), and for restrictions + (e.g. implemented only for llvm). + """ - with autotvm.tophub.context(target): # load tophub pre-tuned parameters - dev = tvm.device(target, 0) - with tvm.target.Target(target): - # declare - DepthwiseConv2d = topi.x86.depthwise_conv2d_NCHWc( + layout = tvm.testing.parameter("NCHW", "NHWC") + + (batch, in_channel, in_size, channel_multiplier, kernel, stride) = tvm.testing.parameters( + (1, 728, 32, 1, 3, 1), + (4, 256, 64, 2, 5, 2), + ) + padding = tvm.testing.parameter("SAME", "VALID") + dilation = tvm.testing.parameter(1, 2) + + use_scale_shift = tvm.testing.parameter(True, False, ids=["with_scale_shift", "no_scale_shift"]) + apply_relu = tvm.testing.parameter(True, False, ids=["with_relu", "no_relu"]) + + run_after_compile = True + + def test_conv2d( + self, + request, + target, + dev, + in_dtype, + out_dtype, + layout, + input_shape, + filter_shape, + scale_shape, + shift_shape, + use_scale_shift, + apply_relu, + batch, + in_channel, + channel_multiplier, + kernel, + stride, + padding, + dilation, + ): + # Transform the padding argument from 'str' to 'tuple' to + # match the "workload" tuple in TopHub. Which padding_args to + # use for each layout chosen to reproduce previous behavior. 
+ if dilation == 1: + padding_args = get_pad_tuple(padding, (kernel, kernel)) + padding_args_i = [0, 1, 2, 3] if layout == "NCHW" else [0, 1] + padding_args = [padding_args[i] for i in padding_args_i] + else: + padding_args = padding + + # placeholder + Input = te.placeholder(input_shape, name="Input", dtype=in_dtype) + Filter = te.placeholder(filter_shape, name="Filter", dtype=in_dtype) + Scale = te.placeholder(scale_shape, name="Scale", dtype=out_dtype) + Shift = te.placeholder(shift_shape, name="Shift", dtype=out_dtype) + + if layout == "NCHW": + topi_scale_shift = topi.nn.scale_shift_nchw + fcompute_args = (Input, Filter, stride, padding_args, dilation, out_dtype) + + elif layout == "NHWC": + topi_scale_shift = topi.nn.scale_shift_nhwc + fcompute_args = (Input, Filter, stride, padding_args, dilation, out_dtype) + + elif layout == "NCHWc": + topi_scale_shift = topi.nn.scale_shift_nchwc + in_layout = "NCHW{}c".format(input_shape[-1]) + out_layout = "NCHW{}c".format(filter_shape[-1]) + fcompute_args = ( Input, Filter, - (stride_h, stride_w), + stride, padding, - (dilation, dilation), + dilation, in_layout, out_layout, - dtype, - ) - # TODO: add scale_shift implement for NCHWc and add test here - Relu = topi.nn.relu(DepthwiseConv2d) - # schedule - s1 = topi.x86.schedule_depthwise_conv2d_NCHWc(DepthwiseConv2d) - s2 = topi.x86.schedule_depthwise_conv2d_NCHWc(Relu) - # build the kernels - f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], target) - f2 = tvm.build(s2, [Input, Filter, Relu], target) - - # Prepare pod type for test data closure - input_shape = (batch, in_channel, in_height, in_width) - filter_shape = (filter_channel, channel_multiplier, filter_height, filter_width) - - # Use memoize, pickle the test data for next time use. - @memoize("topi.tests.test_topi_depthwise_conv2d.NCHWc") - def get_ref_data(): - input_np = np.random.uniform(size=input_shape).astype(dtype) - filter_np = np.random.uniform(size=filter_shape).astype(dtype) - # correctness with scipy - dw_np = tvm.topi.testing.dilate_python(filter_np, (1, 1, dilation, dilation)).astype( - dtype - ) - depthwise_conv2d_scipy = tvm.topi.testing.depthwise_conv2d_python_nchw( - input_np, dw_np, stride, padding - ) - relu_scipy = np.maximum(depthwise_conv2d_scipy, 0) - return ( - _transform_data(input_np, ic_block), - _transform_kernel(filter_np, oc_block), - _transform_data(depthwise_conv2d_scipy, oc_block), - _transform_data(relu_scipy, oc_block), + out_dtype, ) - # Get the test data - (input_np, filter_np, depthwise_conv2d_scipy, relu_scipy) = get_ref_data() - - input_tvm = tvm.nd.array(input_np, dev) - filter_tvm = tvm.nd.array(filter_np, dev) - - depthwise_conv2d_tvm = tvm.nd.array( - np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape), dtype=DepthwiseConv2d.dtype), dev - ) - relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), dev) - # launch kernel 1 (depthwise_conv2d) - f1(input_tvm, filter_tvm, depthwise_conv2d_tvm) - # launch kernel 2 (depthwise_conv2d + relu) - f2(input_tvm, filter_tvm, relu_tvm) - tvm.testing.assert_allclose(depthwise_conv2d_tvm.numpy(), depthwise_conv2d_scipy, rtol=1e-5) - tvm.testing.assert_allclose(relu_tvm.numpy(), relu_scipy, rtol=1e-5) - - -@tvm.testing.parametrize_targets -def test_depthwise_conv2d_nchw(target, dev): - # mobilenet workloads - depthwise_conv2d_with_workload_nchw(target, dev, 1, 32, 112, 1, 3, 1, "SAME") - depthwise_conv2d_with_workload_nchw(target, dev, 1, 64, 112, 1, 3, 2, "SAME") - depthwise_conv2d_with_workload_nchw(target, dev, 1, 128, 56, 
1, 3, 1, "SAME") - depthwise_conv2d_with_workload_nchw(target, dev, 1, 128, 56, 1, 3, 2, "SAME") - depthwise_conv2d_with_workload_nchw(target, dev, 1, 256, 28, 1, 3, 1, "SAME") - depthwise_conv2d_with_workload_nchw(target, dev, 1, 256, 28, 1, 3, 2, "SAME") - depthwise_conv2d_with_workload_nchw(target, dev, 1, 512, 14, 1, 3, 1, "SAME") - depthwise_conv2d_with_workload_nchw(target, dev, 1, 512, 14, 1, 3, 2, "SAME") - depthwise_conv2d_with_workload_nchw(target, dev, 1, 1024, 7, 1, 3, 1, "SAME") - - depthwise_conv2d_with_workload_nchw(target, dev, 1, 728, 32, 1, 3, 1, "SAME") - depthwise_conv2d_with_workload_nchw(target, dev, 4, 256, 64, 2, 5, 2, "SAME") - depthwise_conv2d_with_workload_nchw(target, dev, 1, 728, 32, 1, 3, 1, "VALID") - depthwise_conv2d_with_workload_nchw(target, dev, 4, 256, 64, 2, 5, 2, "VALID") - # dilation = 2 - depthwise_conv2d_with_workload_nchw(target, dev, 1, 728, 64, 1, 3, 1, "SAME", dilation=2) - - -@tvm.testing.parametrize_targets -def test_depthwise_conv2d_nhwc(target, dev): - depthwise_conv2d_with_workload_nhwc(target, dev, 1, 728, 32, 1, 3, 1, "SAME") - depthwise_conv2d_with_workload_nhwc(target, dev, 4, 256, 64, 2, 5, 2, "SAME") - depthwise_conv2d_with_workload_nhwc(target, dev, 1, 728, 32, 1, 3, 1, "VALID") - depthwise_conv2d_with_workload_nhwc(target, dev, 4, 256, 64, 2, 5, 2, "VALID") - - # dilation = 2 - # disabled because it uses too large shared memory on cuda - # depthwise_conv2d_with_workload_nhwc(target, dev, 1, 728, 64, 1, 3, 1, "SAME", dilation=2) - - -# test llvm only for now since depthwise_conv2d_NCHWc implement is missing in other backend. + with autotvm.tophub.context(target): # load tophub pre-tuned parameters + impl_list = tvm.topi.testing.dispatch(target, _depthwise_conv2d_implement[layout])[:] + if target == "llvm" and layout == "NCHW" and channel_multiplier == 1 and dilation == 1: + impl_list.append( + (topi.x86.depthwise_conv2d_nchw, topi.x86.schedule_depthwise_conv2d_nchw) + ) + + for fcompute, fschedule in impl_list: + with tvm.target.Target(target): + # Declare, build schedule + C = fcompute(*fcompute_args) + if use_scale_shift: + C = topi_scale_shift(C, Scale, Shift) + if apply_relu: + C = topi.nn.relu(C) + + s = fschedule(C) + + # Build and run + f = tvm.build(s, [Input, Filter, Scale, Shift, C], target) + + if self.run_after_compile: + input_np, filter_np, scale_np, shift_np, output_np = request.getfixturevalue( + "ref_data" + ) + input_tvm = tvm.nd.array(input_np, dev) + filter_tvm = tvm.nd.array(filter_np, dev) + scale_tvm = tvm.nd.array(scale_np, dev) + shift_tvm = tvm.nd.array(shift_np, dev) + output_tvm = tvm.nd.array( + np.zeros(shape=get_const_tuple(C.shape), dtype=C.dtype), + dev, + ) + + f(input_tvm, filter_tvm, scale_tvm, shift_tvm, output_tvm) + tvm.testing.assert_allclose(output_np, output_tvm.numpy(), rtol=1e-5) + + +class TestDepthwiseConv2D(BaseDepthwiseConv2D): + """Test variety of parameters, defined in BaseDepthwiseConv2D. 
Also + has llvm-specific tests for workload padding.""" + + @tvm.testing.parametrize_targets("llvm") + def test_workload_padding( + self, + out_dtype, + layout, + input_shape, + filter_shape, + target, + ref_data, + stride, + padding, + dilation, + ): + input_np, filter_np, scale_np, shift_np, output_np = ref_data + if layout == "NCHW": + _, _, out_height, out_width = output_np.shape + elif layout == "NHWC": + _, out_height, out_width, _ = output_np.shape + elif layout == "NCHWc": + _, _, out_height, out_width, _ = output_np.shape + + Input = te.placeholder(input_shape, name="Input") + Filter = te.placeholder(filter_shape, name="Filter") + wkl = _get_workload(Input, Filter, (stride, stride), padding, dilation, out_dtype, layout) + + # check if tile_ow candidates are the factors of the right output weight. + with tvm.target.Target(target): + cfg = autotvm.get_config() + _fallback_schedule(cfg, wkl) + ow_tile = np.prod(cfg["tile_ow"].size) + + tvm.testing.assert_allclose(ow_tile, out_width) + + +class TestDepthwiseConv2D_MobilenetWorkloads(BaseDepthwiseConv2D): + """Extra tests to verify functionality for workloads used by mobilenet.""" + + layout = tvm.testing.parameter("NCHW") + + batch = tvm.testing.parameter(1) + channel_multiplier = tvm.testing.parameter(1) + kernel = tvm.testing.parameter(3) + padding = tvm.testing.parameter("SAME") + dilation = tvm.testing.parameter(1) + + in_channel, in_size, stride = tvm.testing.parameters( + (32, 112, 1), + (64, 112, 2), + (128, 56, 1), + (128, 56, 2), + (256, 28, 1), + (256, 28, 2), + (512, 14, 1), + (512, 14, 2), + (1024, 7, 1), + ) + + @tvm.testing.parametrize_targets("llvm") -def test_depthwise_conv2d_nchwc(target, dev): - # NCHW[x]c - depthwise_conv2d_with_workload_NCHWc(target, dev, 1, 728, 32, 1, 3, 1, "SAME", dilation=2) - depthwise_conv2d_with_workload_NCHWc(target, dev, 1, 728, 32, 1, 3, 1, "SAME") - depthwise_conv2d_with_workload_NCHWc(target, dev, 1, 728, 32, 1, 3, 1, "VALID") +class TestDepthwiseConv2D_NCHWc(BaseDepthwiseConv2D): + """Tests specific to NCHWc layouts. + + Once the implementation supports channel_multiplier>1 and GPU + devices, this class can be merged into TestDepthwiseConv2D. + """ + + # depthwise_conv2d_NCHWc currently does not support channel multiplier > 1 + layout = tvm.testing.parameter("NCHWc") + (batch, in_channel, in_size, channel_multiplier, kernel, stride) = tvm.testing.parameters( + (1, 728, 32, 1, 3, 1), + ) + + +@tvm.testing.parametrize_targets("llvm -device=arm_cpu -mtriple=aarch64-linux-gnu") +class TestDepthwiseConv2DArmCompile(BaseDepthwiseConv2D): + """Compile-only tests for cross-compiling to ARM.""" + layout = tvm.testing.parameter("NHWC", "NCHW") + batch = tvm.testing.parameter(1) + dilation = tvm.testing.parameter(1) + in_dtype, out_dtype = tvm.testing.parameters(("int16", "int32")) + in_channel = tvm.testing.parameter(728) + in_size = tvm.testing.parameter(32) + kernel = tvm.testing.parameter(1) + channel_multiplier = tvm.testing.parameter(1, 3) + stride = tvm.testing.parameter(1) + padding = tvm.testing.parameter("SAME") + use_scale_shift = tvm.testing.parameter(True, False, ids=["with_scale_shift", "no_scale_shift"]) -def test_depthwise_conv2d_arm(): - # Test compilation on arm targets - compile_depthwise_NHWC_int8_arm(1, 728, 32, 1, 3, 1, "SAME") - compile_depthwise_NHWC_int8_arm(1, 728, 32, 1, 1, 1, "SAME", True) + run_after_compile = False if __name__ == "__main__": - test_depthwise_conv2d() + sys.exit(pytest.main(sys.argv))
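
For reference, the two NumPy-only sketches below illustrate the arithmetic exercised by the changes above: first, why _get_workload in conv2d.py now derives the pad amounts from the dilated kernel extent (KH - 1) * dilation_h + 1 rather than the raw kernel size; second, the NCHW <-> NCHWc reshuffling used by the new depthwise_conv2d_python_nchwc reference and the NCHWc test fixtures. They are not part of the patch, and the helper names (same_pad, nchw_to_nchwc, nchwc_to_nchw) are hypothetical.

    import numpy as np

    # "SAME"-style padding at stride 1: with kernel=3, dilation=2 the filter taps
    # span 5 input pixels, so the pad must be computed from that dilated extent.
    def same_pad(extent):
        pad = extent - 1
        return pad // 2, pad - pad // 2

    kernel, dilation, in_size = 3, 2, 8
    dilated_extent = (kernel - 1) * dilation + 1          # matches (KH - 1) * dilation_h + 1
    before, after = same_pad(dilated_extent)
    out_size = in_size + before + after - dilated_extent + 1
    assert out_size == in_size                            # output size is preserved

    # Padding from the raw kernel size (the old behaviour) under-pads and shrinks the output.
    under_before, under_after = same_pad(kernel)
    assert in_size + under_before + under_after - dilated_extent + 1 == in_size - 2

    # NCHW <-> NCHWc packing round trip, mirroring the transpose/reshape in
    # depthwise_conv2d_python_nchwc and the NCHWc input fixtures.
    def nchw_to_nchwc(x, c_block):
        n, c, h, w = x.shape
        return x.reshape(n, c // c_block, c_block, h, w).transpose(0, 1, 3, 4, 2)

    def nchwc_to_nchw(x):
        n, c_chunk, h, w, c_block = x.shape
        return x.transpose(0, 1, 4, 2, 3).reshape(n, c_chunk * c_block, h, w)

    data = np.random.uniform(size=(1, 8, 5, 5)).astype("float32")
    packed = nchw_to_nchwc(data, c_block=4)
    assert packed.shape == (1, 2, 5, 5, 4)
    np.testing.assert_array_equal(nchwc_to_nchw(packed), data)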