diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h
index 6bfdb492fed0..fbe31a305ea5 100644
--- a/include/tvm/relay/attrs/nn.h
+++ b/include/tvm/relay/attrs/nn.h
@@ -596,11 +596,13 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode<Conv2DTransposeAttrs> {
 /*! \brief Attributes used in dilate operator */
 struct DilateAttrs : public tvm::AttrsNode<DilateAttrs> {
   Array<IndexExpr> strides;
+  double dilation_value;

   TVM_DECLARE_ATTRS(DilateAttrs, "relay.attrs.DilateAttrs") {
     TVM_ATTR_FIELD(strides)
         .set_default(Array<IndexExpr>({1, 1}))
         .describe("Dilation stride on each dimension, 1 means no dilation.");
+    TVM_ATTR_FIELD(dilation_value).set_default(0.0).describe("Value used to dilate the input.");
   }
 };

diff --git a/include/tvm/topi/nn/dilate.h b/include/tvm/topi/nn/dilate.h
index a021402e097c..9b5a8047740e 100644
--- a/include/tvm/topi/nn/dilate.h
+++ b/include/tvm/topi/nn/dilate.h
@@ -55,19 +55,20 @@ PrimExpr all(Array<PrimExpr> args) {
 }

 /*!
- * \brief Dilate data with zeros
+ * \brief Dilate data with given dilation value (0 by default).
  *
  * \param x The input tensor, this can have any number of
  *          dimensions and any layout.
  * \param strides Dilation stride for each dimension. Stride 1
  *                means no dilation.
+ * \param dilation_value Value used to dilate the input.
  * \param name The name of the operation
  * \param tag The tag to mark the operation
  *
  * \return The output tensor.
  */
-inline Tensor dilate(const Tensor& x, Array<PrimExpr> strides, std::string name = "tensor",
-                     std::string tag = kInjective) {
+inline Tensor dilate(const Tensor& x, Array<PrimExpr> strides, double dilation_value,
+                     std::string name = "tensor", std::string tag = kInjective) {
   auto n = x->shape.size();
   CHECK_EQ(n, strides.size()) << "strides size (" << strides.size()
                               << ") must match dimension of x (" << n << ")";
@@ -94,7 +95,8 @@ inline Tensor dilate(const Tensor& x, Array<PrimExpr> strides, std::string name
         }
         if (not_zero.size() > 0) {
           auto all_not_zero = all(not_zero);
-          return tvm::if_then_else(all_not_zero, x(index_tuple), make_const(x->dtype, 0));
+          return tvm::if_then_else(all_not_zero, x(index_tuple),
+                                   make_const(x->dtype, dilation_value));
         }
         return x(index_tuple);
       },
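# -----------------------------------------------------------------------------
# Editor's note, not part of the patch: a minimal NumPy sketch of the semantics
# the new `dilation_value` parameter gives the dilate operator above -- the gaps
# introduced by the strides are filled with `dilation_value` instead of the
# previously hard-coded zero. The helper name and the 2x2 input are
# illustrative only.
import numpy as np

def dilate_sketch(x, strides, dilation_value=0.0):
    # Output dim i has size (x.shape[i] - 1) * strides[i] + 1; every position is
    # pre-filled with dilation_value, then the original elements are scattered
    # onto the stride-aligned positions.
    out_shape = tuple((s - 1) * st + 1 for s, st in zip(x.shape, strides))
    out = np.full(out_shape, dilation_value, dtype=x.dtype)
    out[tuple(slice(None, None, st) for st in strides)] = x
    return out

print(dilate_sketch(np.array([[1.0, 2.0], [3.0, 4.0]]), (2, 2), dilation_value=5.0))
# [[1. 5. 2.]
#  [5. 5. 5.]
#  [3. 5. 4.]]
# -----------------------------------------------------------------------------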
diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index 1b09cf307554..3ea6a4c63a8e 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -2770,7 +2770,7 @@ def convert_transpose_conv(self, op):
             raise ImportError("The tflite package must be installed")

         input_tensors = self.get_input_tensors(op)
-        assert len(input_tensors) == 3, "input tensors length should be 3"
+        assert len(input_tensors) in (3, 4), "input tensors length should be 3 or 4"

         # Input (data) Tensor. NHWC layout
         input_tensor = input_tensors[2]
@@ -2808,8 +2808,8 @@ def convert_transpose_conv(self, op):

         # Weights
         weights_tensor_type = weights_tensor.tensor.Type()
-        # weights tensor type should be UINT8 (quantization) or FLOAT32
-        assert weights_tensor_type in (TensorType.UINT8, TensorType.FLOAT32)
+        # weights tensor type should be INT8/UINT8 (quantization) or FLOAT32
+        assert weights_tensor_type in (TensorType.INT8, TensorType.UINT8, TensorType.FLOAT32)
         weight_tensor_type_str = self.get_tensor_type_str(weights_tensor_type)
         weight_value_ohwi = self.get_tensor_value(weights_tensor)
         # Relay kernel_layout should be OIHW
@@ -2831,17 +2831,100 @@ def convert_transpose_conv(self, op):
         else:
             padding = (0, 0, 0, 0)

-        out = _op.nn.conv2d_transpose(
-            in_expr,
-            weight_expr_iohw,
-            strides=(stride_h, stride_w),
-            padding=padding,
-            channels=int(out_channels),
-            kernel_size=(int(kernel_h), int(kernel_w)),
-            data_layout="NHWC",
-            kernel_layout="OIHW",
-            out_dtype=output_tensor_type_str,
-        )
+        if input_tensor.qnn_params:
+            # Quantized Transpose Conv can be implemented by using qnn.conv2d
+            # after some transformations of the input and kernel tensors.
+            # These transformations are:
+            # kernel is flipped and transformed and
+            # input is dilated and padded (upsampled)
+            # before passing to qnn.conv2d.
+            # This is equivalent to transpose convolution.
+
+            # Upsampling (Dilating and Padding) input with zero_point
+            input_zero_point = input_tensor.qnn_params["zero_point"]
+            dilated_in_expr = _op.nn.dilate(
+                in_expr,
+                strides=(1, stride_h, stride_w, 1),
+                dilation_value=float(input_zero_point.data.asnumpy()),
+            )
+            # qnn.conv2d pads with zero_point, so we just calculate the padding width needed
+            pad_top = kernel_h - 1 - padding[0]
+            pad_left = kernel_w - 1 - padding[1]
+            pad_bottom = kernel_h - 1 - padding[2]
+            pad_right = kernel_w - 1 - padding[3]
+            padding = (pad_top, pad_left, pad_bottom, pad_right)
+            # Transforming kernel into OIHW and flipping it (rotating by 180 degrees)
+            weight_value_hwio = np.transpose(weight_value_ohwi, (1, 2, 3, 0))
+            weight_value_hwio = np.flip(weight_value_hwio, (0, 1))
+            weight_expr_oihw = self.exp_tab.new_const(
+                weight_value_hwio, dtype=weight_tensor_type_str
+            )
+            out = _qnn.op.conv2d(
+                dilated_in_expr,
+                weight_expr_oihw,
+                strides=(1, 1),
+                padding=padding,
+                channels=int(out_channels),
+                kernel_size=(int(kernel_h), int(kernel_w)),
+                data_layout="NHWC",
+                kernel_layout="HWIO",
+                input_zero_point=input_zero_point,
+                kernel_zero_point=weights_tensor.qnn_params["zero_point"],
+                out_dtype="int32",
+                input_scale=input_tensor.qnn_params["scale"],
+                kernel_scale=weights_tensor.qnn_params["scale"],
+            )
+        else:
+            out = _op.nn.conv2d_transpose(
+                in_expr,
+                weight_expr_iohw,
+                strides=(stride_h, stride_w),
+                padding=padding,
+                channels=int(out_channels),
+                kernel_size=(int(kernel_h), int(kernel_w)),
+                data_layout="NHWC",
+                kernel_layout="OIHW",
+                out_dtype=output_tensor_type_str,
+            )
+
+        # Checking if there is a fused bias
+        if len(input_tensors) == 4:
+            bias_tensor = input_tensors[3]
+            bias_tensor_type = bias_tensor.tensor.Type()
+            # bias tensor type should be INT32 (quantization) or FLOAT32
+            assert bias_tensor_type in (TensorType.INT32, TensorType.FLOAT32)
+            bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type)
+            bias_expr = self.exp_tab.new_const(
+                self.get_tensor_value(bias_tensor), dtype=bias_tensor_type_str
+            )
+            channel_axis = 3
+            out = _op.nn.bias_add(out, bias_expr, axis=channel_axis)
+
+        if output_tensor.qnn_params:
+            # Calculate the intermediate scale and zero point of the int32 output.
+            data_scale = input_tensor.qnn_params["scale"]
+            data_scale_val = get_scalar_from_constant(data_scale)
+
+            weight_scale = weights_tensor.qnn_params["scale"]
+            # If weight scale is scalar, it is per-tensor quantization
+            if isinstance(weight_scale, float):
+                weight_scale_val = get_scalar_from_constant(weight_scale)
+            else:
+                weight_scale_val = get_tensor_from_constant(weight_scale)
+
+            new_input_scale_val = data_scale_val * weight_scale_val
+            new_input_scale = relay.const(new_input_scale_val, "float32")
+            new_input_zero_point = relay.const(0, "int32")
+
+            out = _qnn.op.requantize(
+                out,
+                input_scale=new_input_scale,
+                input_zero_point=new_input_zero_point,
+                output_scale=output_tensor.qnn_params["scale"],
+                output_zero_point=output_tensor.qnn_params["zero_point"],
+                out_dtype=output_tensor_type_str,
+                axis=3,
+            )

         return out

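# -----------------------------------------------------------------------------
# Editor's note, not part of the patch: the frontend change above relies on the
# identity that a transposed convolution equals an ordinary convolution over the
# stride-dilated, (kernel - 1 - pad)-padded input with a 180-degree flipped
# kernel. In the quantized path the inserted positions are filled with the
# input zero point, so after zero-point subtraction inside qnn.conv2d they
# contribute nothing, and the int32 accumulator carries scale
# input_scale * weight_scale with zero point 0 -- exactly the parameters the
# requantize call above feeds in. A single-channel float sketch of the identity
# (helper names are illustrative only):
import numpy as np

def conv2d_valid(x, k):
    # Plain stride-1 "VALID" cross-correlation, single channel.
    kh, kw = k.shape
    out = np.zeros((x.shape[0] - kh + 1, x.shape[1] - kw + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            out[i, j] = np.sum(x[i : i + kh, j : j + kw] * k)
    return out

def conv2d_transpose_ref(x, k, stride):
    # Reference transposed convolution, single channel, "VALID" padding.
    kh, kw = k.shape
    out = np.zeros(((x.shape[0] - 1) * stride + kh, (x.shape[1] - 1) * stride + kw))
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            out[i * stride : i * stride + kh, j * stride : j * stride + kw] += x[i, j] * k
    return out

rng = np.random.default_rng(0)
x, k, stride, pad = rng.standard_normal((4, 4)), rng.standard_normal((3, 3)), 2, 0

# Dilate with the stride, pad each side by kernel - 1 - pad, flip the kernel by
# 180 degrees, then run an ordinary stride-1 convolution.
dilated = np.zeros(((x.shape[0] - 1) * stride + 1, (x.shape[1] - 1) * stride + 1))
dilated[::stride, ::stride] = x
padded = np.pad(dilated, ((k.shape[0] - 1 - pad,) * 2, (k.shape[1] - 1 - pad,) * 2))
assert np.allclose(conv2d_valid(padded, np.flip(k, (0, 1))), conv2d_transpose_ref(x, k, stride))
# -----------------------------------------------------------------------------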
diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index 53442ef8d850..6694b5a5fd75 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -668,7 +668,7 @@ def compute_cross_entropy(attrs, inputs, out_dtype):
 # dilate
 @reg.register_compute("nn.dilate")
 def compute_dilate(attrs, inputs, out_dtype):
-    return [topi.nn.dilate(inputs[0], attrs.strides)]
+    return [topi.nn.dilate(inputs[0], attrs.strides, attrs.dilation_value)]


 reg.register_broadcast_schedule("nn.dilate")

diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py
index 853cd4240b48..86a76ff28fa5 100644
--- a/python/tvm/relay/op/nn/nn.py
+++ b/python/tvm/relay/op/nn/nn.py
@@ -1549,23 +1549,26 @@ def pad(data, pad_width, pad_value=0, pad_mode="constant"):
     return _make.pad(data, pad_width, pad_value, pad_mode)


-def dilate(data, strides):
-    """Dilate data with zeros.
+def dilate(data, strides, dilation_value=0.0):
+    """Dilate data with given dilation value (0 by default).

     Parameters
     ----------
     data : tvm.relay.Expr
         n-D, can be any layout.

-    strides : <tuple of <int>>
+    strides : tuple of <int>
         Dilation stride on each dimension, 1 means no dilation.

+    dilation_value : int/float, optional
+        Value used to dilate the input.
+
     Returns
     -------
     Output : tvm.relay.Expr
         The computed result
     """
-    return _make.dilate(data, strides)
+    return _make.dilate(data, strides, dilation_value)


 def mirror_pad(data, pad_width, mode="SYMMETRIC"):
diff --git a/python/tvm/topi/nn/dilate.py b/python/tvm/topi/nn/dilate.py
index 836e29a6812d..6980fea58173 100644
--- a/python/tvm/topi/nn/dilate.py
+++ b/python/tvm/topi/nn/dilate.py
@@ -23,8 +23,8 @@


 @te.tag_scope(tag=tag.INJECTIVE + ",dilate")
-def dilate(data, strides, name="DilatedInput"):
-    """Dilate data with zeros.
+def dilate(data, strides, dilation_value=0.0, name="DilatedInput"):
+    """Dilate data with given dilation value (0 by default).

     Parameters
     ----------
@@ -34,6 +34,9 @@ def dilate(data, strides, name="DilatedInput"):
     strides : list / tuple of n ints
         Dilation stride on each dimension, 1 means no dilation.

+    dilation_value : int/float, optional
+        Value used to dilate the input.
+
     name : str, optional
         The name prefix operators generated

@@ -62,7 +65,7 @@ def _dilate(*indices):
         if not_zero:
             not_zero = tvm.tir.all(*not_zero)
             return tvm.tir.if_then_else(
-                not_zero, data(*index_tuple), tvm.tir.const(0.0, data.dtype)
+                not_zero, data(*index_tuple), tvm.tir.const(dilation_value, data.dtype)
             )
         return data(*index_tuple)

diff --git a/python/tvm/topi/testing/dilate_python.py b/python/tvm/topi/testing/dilate_python.py
index b4fff24a1d43..0ae611559729 100644
--- a/python/tvm/topi/testing/dilate_python.py
+++ b/python/tvm/topi/testing/dilate_python.py
@@ -19,7 +19,7 @@
 import numpy as np


-def dilate_python(input_np, strides):
+def dilate_python(input_np, strides, dilation_value=0.0):
     """Dilate operation.

     Parameters
@@ -30,6 +30,9 @@ def dilate_python(input_np, strides):
     strides : list / tuple of n ints
         Dilation stride on each dimension, 1 means no dilation.

+    dilation_value : int/float, optional
+        Value used to dilate the input.
+
     Returns
     -------
     output_np : numpy.ndarray
@@ -45,7 +48,8 @@ def dilate_python(input_np, strides):
     for i in range(n):
         output_size += ((input_np.shape[i] - 1) * strides[i] + 1,)
         no_zero += ((range(0, output_size[i], strides[i])),)
-    output_np = np.zeros(shape=output_size)
+    output_np = np.ones(shape=output_size)
+    output_np = dilation_value * output_np
     output_np[np.ix_(*no_zero)] = input_np

     return output_np
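# -----------------------------------------------------------------------------
# Editor's note, not part of the patch: a small usage sketch of the new keyword
# at the Relay level, checked against the NumPy reference updated above.
# Assumes an LLVM-enabled TVM build; the executor API shown matches the TVM
# version this patch targets and may differ in later releases.
import numpy as np
import tvm
from tvm import relay
import tvm.topi.testing

data = relay.var("data", shape=(2, 2), dtype="float32")
func = relay.Function([data], relay.nn.dilate(data, strides=(2, 2), dilation_value=3.0))

x = np.arange(1, 5, dtype="float32").reshape(2, 2)
result = relay.create_executor("debug", target="llvm").evaluate(func)(x).asnumpy()

expected = tvm.topi.testing.dilate_python(x, (2, 2), 3.0)
np.testing.assert_allclose(result, expected)
# -----------------------------------------------------------------------------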
)code" TVM_ADD_FILELINE) .set_num_inputs(1) .add_argument("x", "1D Tensor", "Data to dilate.") diff --git a/src/topi/nn.cc b/src/topi/nn.cc index 4a209b2f2932..c03d1b056d35 100644 --- a/src/topi/nn.cc +++ b/src/topi/nn.cc @@ -75,7 +75,7 @@ TVM_REGISTER_GLOBAL("topi.nn.batch_matmul").set_body([](TVMArgs args, TVMRetValu /* Ops from nn/dilate.h */ TVM_REGISTER_GLOBAL("topi.nn.dilate").set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = nn::dilate(args[0], args[1]); + *rv = nn::dilate(args[0], args[1], args[2]); }); /* Ops from nn/flatten.h */ diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 7d674278800b..8506783c2d6d 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -1105,7 +1105,9 @@ def test_forward_convolution(): # --------------------- -def _test_transpose_conv(tensor_in_sizes, filter_in_sizes, output_shape, strides, padding): +def _test_transpose_conv( + tensor_in_sizes, filter_in_sizes, output_shape, strides, padding, quantized=False +): """ One iteration of transpose convolution with given shapes and attributes """ total_size_1 = 1 @@ -1114,53 +1116,129 @@ def _test_transpose_conv(tensor_in_sizes, filter_in_sizes, output_shape, strides total_size_1 *= s for s in filter_in_sizes: total_size_2 *= s - # Initializes the input tensor with array containing incrementing - # numbers from 1. - data_array = [f * 1.0 for f in range(1, total_size_1 + 1)] - filter_array = [f * 1.0 for f in range(1, total_size_2 + 1)] with tf.Graph().as_default(): - in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype="float32") - in_filter = constant_op.constant(filter_array, shape=filter_in_sizes, dtype="float32") - strides = [1] + strides + [1] - # in_filter layout is HWOI - out = nn_ops.conv2d_transpose( - in_data, in_filter, output_shape=output_shape, strides=strides, padding=padding - ) - data_array = np.reshape(data_array, tensor_in_sizes).astype("float32") - compare_tflite_with_tvm(data_array, "Placeholder:0", [in_data], [out]) + if quantized: + # Initializes the input tensor with array containing incrementing + # numbers from 1. + data_array = [max(f, 255) for f in range(1, total_size_1 + 1)] + filter_array = [max(f, 255) for f in range(1, total_size_2 + 1)] + data_array = np.reshape(data_array, tensor_in_sizes).astype("uint8") + filter_array = np.reshape(filter_array, filter_in_sizes).astype("uint8") + + in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype="float32", name="in_data") + inq_data = tf.quantization.fake_quant_with_min_max_args( + in_data, min=-100, max=100, name="q_data" + ) + input_range = {"q_data": (-100, 100)} + + in_filter = constant_op.constant( + filter_array, shape=filter_in_sizes, dtype="float32", name="in_filter" + ) + inq_filter = tf.quantization.fake_quant_with_min_max_args( + in_filter, min=-100, max=100, name="q_filter" + ) + + strides = [1] + strides + [1] + + out = nn_ops.conv2d_transpose( + inq_data, inq_filter, output_shape=output_shape, strides=strides, padding=padding + ) + out = tf.quantization.fake_quant_with_min_max_args(out, min=-100, max=100, name="out") + compare_tflite_with_tvm( + [data_array], ["q_data"], [inq_data], [out], quantized=True, input_range=input_range + ) + else: + # Initializes the input tensor with array containing incrementing + # numbers from 1. 
+            data_array = [f * 1.0 for f in range(1, total_size_1 + 1)]
+            filter_array = [f * 1.0 for f in range(1, total_size_2 + 1)]
+
+            in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype="float32", name="in_data")
+            in_filter = constant_op.constant(
+                filter_array, shape=filter_in_sizes, dtype="float32", name="in_filter"
+            )
+            strides = [1] + strides + [1]
+            # in_filter layout is HWOI
+            out = nn_ops.conv2d_transpose(
+                in_data, in_filter, output_shape=output_shape, strides=strides, padding=padding
+            )
+            data_array = np.reshape(data_array, tensor_in_sizes).astype("float32")
+            compare_tflite_with_tvm([data_array], ["in_data"], [in_data], [out])


 def test_forward_transpose_conv():
-    # kernel 3x3, padding VALID
-    _test_transpose_conv([4, 32, 32, 16], [3, 3, 5, 16], [4, 34, 34, 5], [1, 1], "VALID")
-    _test_transpose_conv([1, 32, 32, 16], [3, 3, 5, 16], [1, 65, 65, 5], [2, 2], "VALID")
-    _test_transpose_conv([1, 32, 32, 16], [3, 3, 5, 16], [1, 65, 34, 5], [2, 1], "VALID")
-
-    # kernel 3x3, padding SAME
-    _test_transpose_conv([4, 32, 32, 16], [3, 3, 5, 16], [4, 32, 32, 5], [1, 1], "SAME")
-    _test_transpose_conv([1, 32, 32, 16], [3, 3, 5, 16], [1, 64, 64, 5], [2, 2], "SAME")
-    _test_transpose_conv([1, 32, 32, 16], [3, 3, 5, 16], [1, 64, 32, 5], [2, 1], "SAME")
-
-    # kernel 2x2, padding VALID
-    _test_transpose_conv([4, 32, 32, 16], [2, 2, 5, 16], [4, 33, 33, 5], [1, 1], "VALID")
-    _test_transpose_conv([1, 32, 32, 16], [2, 2, 5, 16], [1, 64, 64, 5], [2, 2], "VALID")
-    _test_transpose_conv([1, 32, 32, 16], [2, 2, 5, 16], [1, 64, 33, 5], [2, 1], "VALID")
-
-    # kernel 2x2, padding SAME
-    _test_transpose_conv([4, 32, 32, 16], [2, 2, 5, 16], [4, 32, 32, 5], [1, 1], "SAME")
-    _test_transpose_conv([1, 32, 32, 16], [2, 2, 5, 16], [1, 64, 64, 5], [2, 2], "SAME")
-    _test_transpose_conv([1, 32, 32, 16], [2, 2, 5, 16], [1, 64, 32, 5], [2, 1], "SAME")
-
-    # kernel 1x1, padding VALID
-    _test_transpose_conv([4, 32, 32, 16], [1, 1, 5, 16], [4, 32, 32, 5], [1, 1], "VALID")
-    _test_transpose_conv([1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 63, 5], [2, 2], "VALID")
-    _test_transpose_conv([1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 32, 5], [2, 1], "VALID")
-
-    # kernel 1x1, padding SAME
-    _test_transpose_conv([4, 32, 32, 16], [1, 1, 5, 16], [4, 32, 32, 5], [1, 1], "SAME")
-    _test_transpose_conv([1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 63, 5], [2, 2], "SAME")
-    _test_transpose_conv([1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 32, 5], [2, 1], "SAME")
+    for quantized in [False, True]:
+        # kernel 3x3, padding VALID
+        _test_transpose_conv(
+            [4, 32, 32, 16], [3, 3, 5, 16], [4, 34, 34, 5], [1, 1], "VALID", quantized
+        )
+        _test_transpose_conv(
+            [1, 32, 32, 16], [3, 3, 5, 16], [1, 65, 65, 5], [2, 2], "VALID", quantized
+        )
+        _test_transpose_conv(
+            [1, 32, 32, 16], [3, 3, 5, 16], [1, 65, 34, 5], [2, 1], "VALID", quantized
+        )
+
+        # kernel 3x3, padding SAME
+        _test_transpose_conv(
+            [4, 32, 32, 16], [3, 3, 5, 16], [4, 32, 32, 5], [1, 1], "SAME", quantized
+        )
+        _test_transpose_conv(
+            [1, 32, 32, 16], [3, 3, 5, 16], [1, 64, 64, 5], [2, 2], "SAME", quantized
+        )
+        _test_transpose_conv(
+            [1, 32, 32, 16], [3, 3, 5, 16], [1, 64, 32, 5], [2, 1], "SAME", quantized
+        )
+
+        # kernel 2x2, padding VALID
+        _test_transpose_conv(
+            [4, 32, 32, 16], [2, 2, 5, 16], [4, 33, 33, 5], [1, 1], "VALID", quantized
+        )
+        _test_transpose_conv(
+            [1, 32, 32, 16], [2, 2, 5, 16], [1, 64, 64, 5], [2, 2], "VALID", quantized
+        )
+        _test_transpose_conv(
+            [1, 32, 32, 16], [2, 2, 5, 16], [1, 64, 33, 5], [2, 1], "VALID", quantized
+        )
+
+        # kernel 2x2, padding SAME
+        _test_transpose_conv(
+            [4, 32, 32, 16], [2, 2, 5, 16], [4, 32, 32, 5], [1, 1], "SAME", quantized
+        )
+        _test_transpose_conv(
+            [1, 32, 32, 16], [2, 2, 5, 16], [1, 64, 64, 5], [2, 2], "SAME", quantized
+        )
+        _test_transpose_conv(
+            [1, 32, 32, 16], [2, 2, 5, 16], [1, 64, 32, 5], [2, 1], "SAME", quantized
+        )
+
+        # kernel 1x1, padding VALID
+        _test_transpose_conv(
+            [4, 32, 32, 16], [1, 1, 5, 16], [4, 32, 32, 5], [1, 1], "VALID", quantized
+        )
+        _test_transpose_conv(
+            [1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 63, 5], [2, 2], "VALID", quantized
+        )
+        _test_transpose_conv(
+            [1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 32, 5], [2, 1], "VALID", quantized
+        )
+
+        # kernel 1x1, padding SAME
+        _test_transpose_conv(
+            [4, 32, 32, 16], [1, 1, 5, 16], [4, 32, 32, 5], [1, 1], "SAME", quantized
+        )
+        _test_transpose_conv(
+            [1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 63, 5], [2, 2], "SAME", quantized
+        )
+        _test_transpose_conv(
+            [1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 32, 5], [2, 1], "SAME", quantized
+        )
+
+        # asymmetric kernel (3x2)
+        _test_transpose_conv(
+            [4, 32, 32, 16], [3, 2, 5, 16], [4, 34, 33, 5], [1, 1], "VALID", quantized
+        )


 #######################################################################
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index d24a733f7655..c13f679e6108 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -740,18 +740,24 @@ def test_any_pad():
     verify_any_pad(any_dims(4), ((1, 0), (1, 3), (0, 2), (9, 0)), (13, 11, 3, 1))


-def verify_any_dilate(data_shape, strides, static_data_shape):
+def verify_any_dilate(data_shape, strides, static_data_shape, dilation_value=None):
     assert len(data_shape) == len(strides)
     mod = tvm.IRModule()
     dtype = "float32"
     data = relay.var("data", shape=data_shape, dtype=dtype)
-    y = relay.nn.dilate(data, strides)
+    if dilation_value is None:
+        y = relay.nn.dilate(data, strides)
+    else:
+        y = relay.nn.dilate(data, strides, dilation_value)
     mod["main"] = relay.Function([data], y)
     data_np = np.random.uniform(size=static_data_shape).astype(dtype)
     ref_shape = tuple(
         (static_data_shape[i] - 1) * strides[i] + 1 for i in range(len(static_data_shape))
     )
-    ref_out = np.zeros(shape=ref_shape, dtype=dtype)
+    if dilation_value is None:
+        dilation_value = 0.0
+    ref_out = np.ones(shape=ref_shape, dtype=dtype)
+    ref_out = dilation_value * ref_out
     ref_out[tuple(slice(None, None, strides[i]) for i in range(len(data_shape)))] = data_np
     check_result([data_np], mod, ref_out)

@@ -766,6 +772,7 @@ def test_any_dilate():
     verify_any_dilate(any_dims(3), (1, 1, 5), (1, 2, 3))
     verify_any_dilate(any_dims(3), (3, 7, 5), (1, 2, 3))
     verify_any_dilate(any_dims(4), (3, 7, 1, 5), (1, 2, 3, 4))
+    verify_any_dilate(any_dims(4), (3, 7, 1, 5), (1, 2, 3, 4), 1.0)


 def verify_any_softmax(data_shape, axis, static_data_shape, ref_out_shape):
diff --git a/tests/python/topi/python/test_topi_dilate.py b/tests/python/topi/python/test_topi_dilate.py
index ab5c61dce406..0ee51a6c7bf4 100644
--- a/tests/python/topi/python/test_topi_dilate.py
+++ b/tests/python/topi/python/test_topi_dilate.py
@@ -26,12 +26,18 @@ def test_dilate():
     target = "llvm"
     ctx = tvm.cpu(0)

-    def _test_dilate(input_size, strides):
+    def _test_dilate(input_size, strides, dilation_value=None):
         Input = te.placeholder((input_size))
-        Output = topi.nn.dilate(Input, strides)
+        if dilation_value is None:
+            Output = topi.nn.dilate(Input, strides)
+        else:
+            Output = topi.nn.dilate(Input, strides, dilation_value)
         schedule = te.create_schedule(Output.op)
         input_np = np.random.uniform(size=input_size).astype(Input.dtype)
-        output_np = tvm.topi.testing.dilate_python(input_np, strides)
+        if dilation_value is None:
+            output_np = tvm.topi.testing.dilate_python(input_np, strides)
+        else:
+            output_np = tvm.topi.testing.dilate_python(input_np, strides, dilation_value)
         input_tvm = tvm.nd.array(input_np, ctx=ctx)
         output_size = topi.util.get_const_tuple(Output.shape)
         output_tvm = tvm.nd.array(np.zeros(shape=output_size).astype(Output.dtype), ctx=ctx)
@@ -47,6 +53,7 @@ def _test_dilate(input_size, strides):
     _test_dilate((1, 32, 32, 3, 3), (2, 2, 2, 2, 2))
     _test_dilate((1, 32, 32, 32, 3, 3), (1, 1, 1, 2, 2, 2))
     _test_dilate((1, 32, 32, 32, 3, 3), (2, 2, 2, 1, 1, 1))
+    _test_dilate((1, 32, 32, 32, 3, 3), (2, 2, 2, 1, 1, 1), 1.0)


 if __name__ == "__main__":
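# -----------------------------------------------------------------------------
# Editor's note, not part of the patch: the output shapes hard-coded in
# test_forward_transpose_conv above follow the usual TensorFlow/TFLite shape
# rules for conv2d_transpose; a quick sketch reproducing a few of them (the
# helper name is illustrative only).
def transpose_conv_out_dim(in_dim, kernel, stride, padding):
    # VALID: (in - 1) * stride + kernel, SAME: in * stride
    if padding == "VALID":
        return (in_dim - 1) * stride + kernel
    return in_dim * stride

assert transpose_conv_out_dim(32, 3, 2, "VALID") == 65  # [1, 32, 32, 16] -> [1, 65, 65, 5]
assert transpose_conv_out_dim(32, 2, 1, "VALID") == 33  # [4, 32, 32, 16] -> [4, 33, 33, 5]
assert transpose_conv_out_dim(32, 1, 2, "VALID") == 63  # [1, 32, 32, 16] -> [1, 63, 63, 5]
assert transpose_conv_out_dim(32, 3, 2, "SAME") == 64   # [1, 32, 32, 16] -> [1, 64, 64, 5]
# -----------------------------------------------------------------------------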