diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 6da06ac4a20b6..c6b08708a671c 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -2809,7 +2809,7 @@ def convert_transpose_conv(self, op): # Weights weights_tensor_type = weights_tensor.tensor.Type() # weights tensor type should be UINT8 (quantization) or FLOAT32 - assert weights_tensor_type in (TensorType.UINT8, TensorType.FLOAT32) + assert weights_tensor_type in (TensorType.INT8, TensorType.UINT8, TensorType.FLOAT32) weight_tensor_type_str = self.get_tensor_type_str(weights_tensor_type) weight_value_ohwi = self.get_tensor_value(weights_tensor) # Relay kernel_layout should be OIHW @@ -2831,19 +2831,40 @@ def convert_transpose_conv(self, op): else: padding = (0, 0, 0, 0) - out = _op.nn.conv2d_transpose( - in_expr, - weight_expr_iohw, - strides=(stride_h, stride_w), - padding=padding, - channels=int(out_channels), - kernel_size=(int(kernel_h), int(kernel_w)), - data_layout="NHWC", - kernel_layout="OIHW", - out_dtype=output_tensor_type_str, - ) + if input_tensor.qnn_params: + input_zero_point = input_tensor.qnn_params["zero_point"] + kernel_zero_point = weights_tensor.qnn_params["zero_point"] + input_scale = input_tensor.qnn_params["scale"] + kernel_scale = weights_tensor.qnn_params["scale"] + out = _qnn.op.conv2d_transpose( + in_expr, + weight_expr_iohw, + input_zero_point, + kernel_zero_point, + input_scale, + kernel_scale, + strides=(stride_h, stride_w), + padding=padding, + channels=int(out_channels), + kernel_size=(int(kernel_h), int(kernel_w)), + data_layout="NHWC", + kernel_layout="OIHW", + out_dtype="int32", + ) + else: + out = _op.nn.conv2d_transpose( + in_expr, + weight_expr_iohw, + strides=(stride_h, stride_w), + padding=padding, + channels=int(out_channels), + kernel_size=(int(kernel_h), int(kernel_w)), + data_layout="NHWC", + kernel_layout="OIHW", + out_dtype=output_tensor_type_str, + ) - # if we have bias + # Checking if there is a fused bias if len(input_tensors) == 4: bias_tensor = input_tensors[3] bias_tensor_type = bias_tensor.tensor.Type() @@ -2856,6 +2877,31 @@ def convert_transpose_conv(self, op): channel_axis = 3 out = _op.nn.bias_add(out, bias_expr, axis=channel_axis) + if output_tensor.qnn_params: + # Calculate the intermediate scale and zero point of the int32 output. 
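Note: for readers of the requantize step that follows, here is a minimal NumPy sketch (illustrative only, not the frontend's actual helper) of the scale algebra it relies on. The int32 accumulator produced by `qnn.conv2d_transpose` behaves as if quantized with scale `input_scale * kernel_scale` and zero point 0, so mapping it onto the TFLite output tensor's quantization parameters is a single affine step. The helper name and the simple round-half-to-even rounding below are assumptions for illustration.

```python
import numpy as np

def requantize_int32(acc, input_scale, kernel_scale, out_scale, out_zp, out_dtype=np.uint8):
    """Hypothetical per-tensor requantize mirroring what qnn.requantize does below."""
    acc_scale = input_scale * kernel_scale          # effective scale of the int32 accumulator
    real = acc.astype(np.float64) * acc_scale       # dequantize (accumulator zero point is 0)
    q = np.round(real / out_scale) + out_zp         # re-quantize to the output parameters
    info = np.iinfo(out_dtype)
    return np.clip(q, info.min, info.max).astype(out_dtype)

acc = np.array([[1234, -56], [789, 0]], dtype=np.int32)
print(requantize_int32(acc, input_scale=0.5, kernel_scale=0.25, out_scale=0.1, out_zp=128))
```

The actual `qnn.requantize` call emitted below also handles a per-channel weight scale, which is why it passes `axis=3` for the NHWC output.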
+ data_scale = input_tensor.qnn_params["scale"] + data_scale_val = get_scalar_from_constant(data_scale) + + weight_scale = weights_tensor.qnn_params["scale"] + # If weight scale is scalar, it is per-tensor quantization + if isinstance(weight_scale, float): + weight_scale_val = get_scalar_from_constant(weight_scale) + else: + weight_scale_val = get_tensor_from_constant(weight_scale) + + new_input_scale_val = data_scale_val * weight_scale_val + new_input_scale = relay.const(new_input_scale_val, "float32") + new_input_zero_point = relay.const(0, "int32") + + out = _qnn.op.requantize( + out, + input_scale=new_input_scale, + input_zero_point=new_input_zero_point, + output_scale=output_tensor.qnn_params["scale"], + output_zero_point=output_tensor.qnn_params["zero_point"], + out_dtype=output_tensor_type_str, + axis=3, + ) return out def convert_quantize(self, op): diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py index 50e5a02f84c00..6e24b8d487ab3 100644 --- a/python/tvm/relay/qnn/op/legalizations.py +++ b/python/tvm/relay/qnn/op/legalizations.py @@ -32,6 +32,12 @@ def legalize_qnn_conv2d(attrs, inputs, types): return qnn_conv2d_legalize(attrs, inputs, types) +# Registering QNN Conv2DTranspose legalization function. +@reg.register_qnn_legalize("qnn.conv2d_transpose") +def legalize_qnn_conv2d_transpose(attrs, inputs, types): + return qnn_conv2d_transpose_legalize(attrs, inputs, types) + + # Registering QNN dense legalization function. @reg.register_qnn_legalize("qnn.dense") def legalize_qnn_dense(attrs, inputs, types): @@ -46,6 +52,22 @@ def qnn_conv2d_legalize(attrs, inputs, types): return None +# Generic QNN Conv2Transpose legalization function. +@tvm.target.generic_func +def qnn_conv2d_transpose_legalize(attrs, inputs, types): + # Collect the input exprs. + data, kernel, input_zero_point, kernel_zero_point, _, _ = inputs + + shift_data = relay.subtract( + relay.cast(data, dtype="int16"), relay.cast(input_zero_point, "int16") + ) + shift_kernel = relay.subtract( + relay.cast(kernel, dtype="int16"), relay.cast(kernel_zero_point, "int16") + ) + new_attrs = {k: attrs[k] for k in attrs.keys()} + return relay.nn.conv2d_transpose(shift_data, shift_kernel, **new_attrs) + + # Generic QNN Conv2D legalization function. @tvm.target.generic_func def qnn_dense_legalize(attrs, inputs, types): diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py index 9a8f22bfb9bc6..b208aea464caa 100644 --- a/python/tvm/relay/qnn/op/qnn.py +++ b/python/tvm/relay/qnn/op/qnn.py @@ -296,6 +296,101 @@ def conv2d( ) +def conv2d_transpose( + data, + weight, + input_zero_point, + kernel_zero_point, + input_scale, + kernel_scale, + strides=(1, 1), + padding=(0, 0), + dilation=(1, 1), + groups=1, + channels=None, + kernel_size=None, + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="", + output_padding=(0, 0), + out_dtype="", +): + """This operator deconvolves quantized data with quantized kernel. The scale of + the output quantized tensor is the product of the kernel_scale and + input_scale of the input quantized tensors. The zero point of the output + quantized tensor is 0. By default, the dtype of output is int32. Please also + refer to Requantize operator to understand how to scale back the int32 + output to (u)int8. + + Parameters + ---------- + data : tvm.relay.Expr + The input data to the operator. + + weight : tvm.relay.Expr + The weight expressions. + + strides : Tuple[int], optional + The strides of convolution. 
+ + padding : Tuple[int], optional + The padding of convolution on both sides of inputs. + + dilation : Tuple[int], optional + Specifies the dilation rate to be used for dilated convolution. + + channels : int, optional + Number of output channels of this convolution. + + kernel_size : tuple of int, optional + The spatial of the convolution kernel. + + groups : int, optional + Number of groups for grouped convolution. + + data_layout : str, optional + Layout of the input. + + kernel_layout : str, optional + Layout of the weight. + + out_layout : Optional[str] + Layout of the output, by default, out_layout is the same as data_layout + + output_padding : Tuple[int], optional + Used to disambiguate the output shape. + + out_dtype : str, optional + Specifies the output data type for mixed precision conv2d. + + Returns + ------- + result : tvm.relay.Expr + The computed result. + """ + # convert 2-way padding to 4-way padding + padding = get_pad_tuple2d(padding) + return _make.conv2d_transpose( + data, + weight, + input_zero_point, + kernel_zero_point, + input_scale, + kernel_scale, + strides, + padding, + dilation, + groups, + channels, + kernel_size, + data_layout, + kernel_layout, + out_layout, + output_padding, + out_dtype, + ) + + def add( lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_scale, output_zero_point ): diff --git a/src/relay/qnn/op/convolution_transpose.cc b/src/relay/qnn/op/convolution_transpose.cc new file mode 100644 index 0000000000000..64250d7005ee3 --- /dev/null +++ b/src/relay/qnn/op/convolution_transpose.cc @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/qnn/op/convolution.cc + * \brief Property def of qnn convolution operator. 
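Note: returning to the Python API added above, here is a short usage sketch of `relay.qnn.op.conv2d_transpose`. The shapes and quantization parameters mirror the `test_no_zero_point` case from the unit test added later in this diff; nothing here assumes any API beyond the wrapper just shown.

```python
import tvm
from tvm import relay

data = relay.var("data", shape=(2, 1, 2, 4), dtype="uint8")     # NCHW
kernel = relay.var("kernel", shape=(1, 3, 2, 2), dtype="uint8")  # OIHW
out = relay.qnn.op.conv2d_transpose(
    data,
    kernel,
    input_zero_point=relay.const(0, "int32"),
    kernel_zero_point=relay.const(0, "int32"),
    input_scale=relay.const(1.0, "float32"),
    kernel_scale=relay.const(1.0, "float32"),
    kernel_size=(2, 2),
    strides=(1, 1),
    padding=(0, 0),
    dilation=(1, 1),
    data_layout="NCHW",
    kernel_layout="OIHW",
    out_dtype="int32",
)
mod = tvm.IRModule.from_expr(relay.Function(relay.analysis.free_vars(out), out))
print(mod)  # the qnn.conv2d_transpose call, before legalization expands it
```

During compilation, the generic legalization registered above rewrites this call into a zero-point-shifted `nn.conv2d_transpose` on int16 operands.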
+ */ +#include +#include +#include +#include +#include +#include +#include + +#include "../../op/nn/convolution.h" +#include "../../transforms/pattern_utils.h" +#include "../utils.h" + +namespace tvm { +namespace relay { +namespace qnn { + +// relay.op.qnn.conv2d_transpose + +inline Expr MakeQnnConv2DTranspose(Expr data, Expr weight, Expr input_zero_point, + Expr kernel_zero_point, Expr input_scale, Expr kernel_scale, + Array strides, Array padding, + Array dilation, int groups, IndexExpr channels, + Array kernel_size, std::string data_layout, + std::string kernel_layout, std::string out_layout, + Array output_padding, DataType out_dtype) { + auto attrs = make_object(); + attrs->strides = std::move(strides); + attrs->padding = std::move(padding); + attrs->dilation = std::move(dilation); + attrs->groups = groups; + attrs->channels = std::move(channels); + attrs->kernel_size = std::move(kernel_size); + attrs->data_layout = std::move(data_layout); + attrs->kernel_layout = std::move(kernel_layout); + attrs->out_layout = std::move(out_layout); + attrs->output_padding = std::move(output_padding); + attrs->out_dtype = std::move(out_dtype); + const Op& op = Op::Get("qnn.conv2d_transpose"); + return Call(op, {data, weight, input_zero_point, kernel_zero_point, input_scale, kernel_scale}, + Attrs(attrs), {}); +} + +Array> QnnConvTransposeInferCorrectLayout( + const Attrs& attrs, const Array& new_in_layouts, const Array& old_in_layouts, + const Array& old_in_types) { + // Use Relay Conv2D Infer correct layout. + auto layouts = ConvInferCorrectLayout(attrs, new_in_layouts, old_in_layouts, + old_in_types); + + // Fill the layouts of remaining input tensors - scales and zero points. The layouts of these + // tensors can be treated as channel layout. + Layout channel_layout = Layout("C"); + Array input_layouts = {layouts[0][0], layouts[0][1], channel_layout, + channel_layout, channel_layout, channel_layout}; + Array output_layouts = layouts[1]; + return {input_layouts, output_layouts}; +} + +bool QnnConv2DTransposeRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + ICHECK_EQ(types.size(), 7); + const auto* data = types[0].as(); + const auto* weight = types[1].as(); + if (data == nullptr || weight == nullptr) return false; + const auto* param = attrs.as(); + ICHECK(param != nullptr) << "Conv2DTransposeAttrs cannot be nullptr."; + ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8)) + << "Expected qnn conv2d type(int8, uint8) for input but was " << data->dtype; + ICHECK(weight->dtype == DataType::Int(8) || weight->dtype == DataType::UInt(8)) + << "Expected qnn conv2d type(int8, uint8) for weight but was " << weight->dtype; + ICHECK(param->out_dtype == DataType::Int(16) || param->out_dtype == DataType::Int(32)) + << "Expected qnn conv2d type(int32, int16) for output but was " << param->out_dtype; + ICHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0."; + + // Check the types of scale and zero points. + ICHECK(IsScalarType(types[2], DataType::Int(32))); // input_zero_point + ICHECK(IsScalarType(types[3], DataType::Int(32))); // kernel_zero_point + ICHECK(IsScalarType(types[4], DataType::Float(32))); // input_scale + // Kernel scale can be a vector of length output_channels or a scalar. 
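Note: on the scalar-versus-vector kernel scale mentioned in the comment above, here is a small NumPy illustration (not TVM code) of what a per-channel `kernel_scale` means for the int32 result: each output channel `o` carries the effective scale `input_scale * kernel_scale[o]`.

```python
import numpy as np

input_scale = 0.5
kernel_scale = np.array([0.1, 0.2, 0.4], dtype="float32")  # one scale per output channel
acc = np.full((1, 3, 4, 4), 100, dtype="int32")            # pretend NCHW int32 accumulator

# Dequantize channel-wise: broadcast the per-channel scale over H and W.
real = acc.astype("float32") * (input_scale * kernel_scale)[None, :, None, None]
print(real[0, :, 0, 0])  # -> [ 5. 10. 20.]
```

Accordingly, the type relation below accepts `kernel_scale` either as a scalar or as a 1-D tensor whose length matches the channel count derived from the 'O' (and, for grouped convolution, 'I') axis of `kernel_layout`.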
+ if (param->groups == 1) { + size_t axis = param->kernel_layout.find('O'); + ICHECK(axis != std::string::npos) << "Kernel layout attribute is not defined"; + AssignType(types[5], DataType::Float(32), weight->shape[axis], reporter); // kernel scale + } else { + // Here, total number of output channels depend on depth multiplier. + size_t o_axis = param->kernel_layout.find('O'); + size_t i_axis = param->kernel_layout.find('I'); + ICHECK(o_axis != std::string::npos || i_axis != std::string::npos) + << "Kernel layout attribute is not defined"; + AssignType(types[5], DataType::Float(32), weight->shape[i_axis] * weight->shape[o_axis], + reporter); // kernel scale + } + + // Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay + // Conv2D infer type function. + Array tensor_types = {types[0], types[1], types[6]}; + return Conv2DTransposeRel(tensor_types, 3, attrs, reporter); +} + +RELAY_REGISTER_OP("qnn.conv2d_transpose") + .describe(R"code(Quantized transposed 2D convolution layer (sometimes called Deconvolution). +This operator deconvolves quantized weight with quantized data. The scale of the +output quantized tensor is the product of the weight_scale and input_scale of +the input quantized tensors. The zero point of the output quantized tensor is +0. By default, the dtype of output is int32. Please also refer to Requantize +operator to understand how to scale back the int32 output to (u)int8. +- **data**: This depends on the `layout` parameter. Input is 4D array of shape + (batch_size, in_channels, height, width) if `layout` is `NCHW`. +- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1]) +- **out**: This depends on the `layout` parameter. Output is 4D array of shape + (batch_size, channels, out_height, out_width) if `layout` is `NCHW`. 
+)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(6) + .add_argument("data", "Tensor", "The quantized input data tensor.") + .add_argument("weight", "Tensor", "The quantized weight tensor.") + .add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.") + .add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.") + .add_argument("weight_scale", "Tensor", "The quantization scale of the weight tensor.") + .add_argument("weight_zero_point", "Tensor", + "The quantization zero_point of the weight tensor.") + .set_support_level(11) + .add_type_rel("QnnConv2DTranspose", QnnConv2DTransposeRel) + .set_attr("TNonComputational", true) + .set_attr("FInferCorrectLayout", QnnConvTransposeInferCorrectLayout); + +TVM_REGISTER_GLOBAL("relay.qnn.op._make.conv2d_transpose").set_body_typed(MakeQnnConv2DTranspose); + +} // namespace qnn +} // namespace relay +} // namespace tvm diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 89ae348993315..0f849461b29f9 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -1115,7 +1115,9 @@ def test_forward_convolution(): # --------------------- -def _test_transpose_conv(tensor_in_sizes, filter_in_sizes, output_shape, strides, padding): +def _test_transpose_conv( + tensor_in_sizes, filter_in_sizes, output_shape, strides, padding, quantized=False +): """ One iteration of transpose convolution with given shapes and attributes """ total_size_1 = 1 @@ -1124,53 +1126,124 @@ def _test_transpose_conv(tensor_in_sizes, filter_in_sizes, output_shape, strides total_size_1 *= s for s in filter_in_sizes: total_size_2 *= s - # Initializes the input tensor with array containing incrementing - # numbers from 1. - data_array = [f * 1.0 for f in range(1, total_size_1 + 1)] - filter_array = [f * 1.0 for f in range(1, total_size_2 + 1)] with tf.Graph().as_default(): - in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype="float32") - in_filter = constant_op.constant(filter_array, shape=filter_in_sizes, dtype="float32") - strides = [1] + strides + [1] - # in_filter layout is HWOI - out = nn_ops.conv2d_transpose( - in_data, in_filter, output_shape=output_shape, strides=strides, padding=padding - ) - data_array = np.reshape(data_array, tensor_in_sizes).astype("float32") - compare_tflite_with_tvm(data_array, "Placeholder:0", [in_data], [out]) + if quantized: + # Initializes the input tensor with array containing incrementing + # numbers from 1. 
+ data_array = [max(f, 255) for f in range(1, total_size_1 + 1)] + filter_array = [max(f, 255) for f in range(1, total_size_2 + 1)] + data_array = np.reshape(data_array, tensor_in_sizes).astype("uint8") + filter_array = np.reshape(filter_array, filter_in_sizes).astype("uint8") + + in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype="float32", name="in_data") + inq_data = tf.quantization.fake_quant_with_min_max_args( + in_data, min=-100, max=100, name="q_data" + ) + input_range = {"q_data": (-100, 100)} + + in_filter = constant_op.constant( + filter_array, shape=filter_in_sizes, dtype="float32", name="in_filter" + ) + inq_filter = tf.quantization.fake_quant_with_min_max_args( + in_filter, min=-100, max=100, name="q_filter" + ) + + strides = [1] + strides + [1] + + out = nn_ops.conv2d_transpose( + inq_data, inq_filter, output_shape=output_shape, strides=strides, padding=padding + ) + out = tf.quantization.fake_quant_with_min_max_args(out, min=-100, max=100, name="out") + compare_tflite_with_tvm( + [data_array], ["q_data"], [inq_data], [out], quantized=True, input_range=input_range + ) + else: + # Initializes the input tensor with array containing incrementing + # numbers from 1. + data_array = [f * 1.0 for f in range(1, total_size_1 + 1)] + filter_array = [f * 1.0 for f in range(1, total_size_2 + 1)] + + in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype="float32", name="in_data") + in_filter = constant_op.constant( + filter_array, shape=filter_in_sizes, dtype="float32", name="in_filter" + ) + strides = [1] + strides + [1] + # in_filter layout is HWOI + out = nn_ops.conv2d_transpose( + in_data, in_filter, output_shape=output_shape, strides=strides, padding=padding + ) + data_array = np.reshape(data_array, tensor_in_sizes).astype("float32") + compare_tflite_with_tvm([data_array], ["in_data"], [in_data], [out]) def test_forward_transpose_conv(): - # kernel 3x3, padding VALID - _test_transpose_conv([4, 32, 32, 16], [3, 3, 5, 16], [4, 34, 34, 5], [1, 1], "VALID") - _test_transpose_conv([1, 32, 32, 16], [3, 3, 5, 16], [1, 65, 65, 5], [2, 2], "VALID") - _test_transpose_conv([1, 32, 32, 16], [3, 3, 5, 16], [1, 65, 34, 5], [2, 1], "VALID") - - # kernel 3x3, padding SAME - _test_transpose_conv([4, 32, 32, 16], [3, 3, 5, 16], [4, 32, 32, 5], [1, 1], "SAME") - _test_transpose_conv([1, 32, 32, 16], [3, 3, 5, 16], [1, 64, 64, 5], [2, 2], "SAME") - _test_transpose_conv([1, 32, 32, 16], [3, 3, 5, 16], [1, 64, 32, 5], [2, 1], "SAME") - - # kernel 2x2, padding VALID - _test_transpose_conv([4, 32, 32, 16], [2, 2, 5, 16], [4, 33, 33, 5], [1, 1], "VALID") - _test_transpose_conv([1, 32, 32, 16], [2, 2, 5, 16], [1, 64, 64, 5], [2, 2], "VALID") - _test_transpose_conv([1, 32, 32, 16], [2, 2, 5, 16], [1, 64, 33, 5], [2, 1], "VALID") - - # kernel 2x2, padding SAME - _test_transpose_conv([4, 32, 32, 16], [2, 2, 5, 16], [4, 32, 32, 5], [1, 1], "SAME") - _test_transpose_conv([1, 32, 32, 16], [2, 2, 5, 16], [1, 64, 64, 5], [2, 2], "SAME") - _test_transpose_conv([1, 32, 32, 16], [2, 2, 5, 16], [1, 64, 32, 5], [2, 1], "SAME") - - # kernel 1x1, padding VALID - _test_transpose_conv([4, 32, 32, 16], [1, 1, 5, 16], [4, 32, 32, 5], [1, 1], "VALID") - _test_transpose_conv([1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 63, 5], [2, 2], "VALID") - _test_transpose_conv([1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 32, 5], [2, 1], "VALID") - - # kernel 1x1, padding SAME - _test_transpose_conv([4, 32, 32, 16], [1, 1, 5, 16], [4, 32, 32, 5], [1, 1], "SAME") - _test_transpose_conv([1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 63, 5], 
[2, 2], "SAME") - _test_transpose_conv([1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 32, 5], [2, 1], "SAME") + for quantized in [True, False]: + # kernel 3x3, padding VALID + _test_transpose_conv( + [4, 32, 32, 16], [3, 3, 5, 16], [4, 34, 34, 5], [1, 1], "VALID", quantized + ) + _test_transpose_conv( + [1, 32, 32, 16], [3, 3, 5, 16], [1, 65, 65, 5], [2, 2], "VALID", quantized + ) + _test_transpose_conv( + [1, 32, 32, 16], [3, 3, 5, 16], [1, 65, 34, 5], [2, 1], "VALID", quantized + ) + + # kernel 3x3, padding SAME + _test_transpose_conv( + [4, 32, 32, 16], [3, 3, 5, 16], [4, 32, 32, 5], [1, 1], "SAME", quantized + ) + _test_transpose_conv( + [1, 32, 32, 16], [3, 3, 5, 16], [1, 64, 64, 5], [2, 2], "SAME", quantized + ) + _test_transpose_conv( + [1, 32, 32, 16], [3, 3, 5, 16], [1, 64, 32, 5], [2, 1], "SAME", quantized + ) + + # kernel 2x2, padding VALID + _test_transpose_conv( + [4, 32, 32, 16], [2, 2, 5, 16], [4, 33, 33, 5], [1, 1], "VALID", quantized + ) + _test_transpose_conv( + [1, 32, 32, 16], [2, 2, 5, 16], [1, 64, 64, 5], [2, 2], "VALID", quantized + ) + _test_transpose_conv( + [1, 32, 32, 16], [2, 2, 5, 16], [1, 64, 33, 5], [2, 1], "VALID", quantized + ) + + # kernel 2x2, padding SAME + _test_transpose_conv( + [4, 32, 32, 16], [2, 2, 5, 16], [4, 32, 32, 5], [1, 1], "SAME", quantized + ) + _test_transpose_conv( + [1, 32, 32, 16], [2, 2, 5, 16], [1, 64, 64, 5], [2, 2], "SAME", quantized + ) + _test_transpose_conv( + [1, 32, 32, 16], [2, 2, 5, 16], [1, 64, 32, 5], [2, 1], "SAME", quantized + ) + + # kernel 1x1, padding VALID + _test_transpose_conv( + [4, 32, 32, 16], [1, 1, 5, 16], [4, 32, 32, 5], [1, 1], "VALID", quantized + ) + _test_transpose_conv( + [1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 63, 5], [2, 2], "VALID", quantized + ) + _test_transpose_conv( + [1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 32, 5], [2, 1], "VALID", quantized + ) + + # kernel 1x1, padding SAME + _test_transpose_conv( + [4, 32, 32, 16], [1, 1, 5, 16], [4, 32, 32, 5], [1, 1], "SAME", quantized + ) + _test_transpose_conv( + [1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 63, 5], [2, 2], "SAME", quantized + ) + _test_transpose_conv( + [1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 32, 5], [2, 1], "SAME", quantized + ) ####################################################################### diff --git a/tests/python/relay/test_op_qnn_conv2_transpose.py b/tests/python/relay/test_op_qnn_conv2_transpose.py new file mode 100644 index 0000000000000..a86f9e1c6a800 --- /dev/null +++ b/tests/python/relay/test_op_qnn_conv2_transpose.py @@ -0,0 +1,638 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import tvm +from tvm import te +import numpy as np +from tvm import relay +from tvm.relay import transform +from tvm.relay.testing import run_infer_type +from tvm.contrib import graph_runtime +from tvm.relay.testing.temp_op_attr import TempOpAttr + + +def get_ref_func( + data, + kernel, + input_zero_point, + kernel_zero_point, + input_scale, + kernel_scale, + kernel_size, + padding, + strides, + dilation, + data_layout, + kernel_layout, + out_dtype, + groups, + channels=None, +): + casted_data = relay.op.cast(data, "int32") + casted_kernel = relay.op.cast(kernel, "int32") + shifted_data = relay.op.subtract(casted_data, relay.const(input_zero_point, "int32")) + shifted_kernel = relay.op.subtract(casted_kernel, relay.const(kernel_zero_point, "int32")) + func = relay.op.nn.conv2d_transpose( + shifted_data, + shifted_kernel, + padding=padding, + strides=strides, + dilation=dilation, + groups=groups, + channels=channels, + kernel_size=kernel_size, + out_dtype=out_dtype, + data_layout=data_layout, + kernel_layout=kernel_layout, + ) + + func = relay.Function(relay.analysis.free_vars(func), func) + return func + + +def get_qnn_func( + data, + kernel, + input_zero_point, + kernel_zero_point, + input_scale, + kernel_scale, + kernel_size, + padding, + strides, + dilation, + data_layout, + kernel_layout, + out_dtype, + channels, + groups, +): + func = relay.qnn.op.conv2d_transpose( + data, + kernel, + input_zero_point=relay.const(input_zero_point, "int32"), + kernel_zero_point=relay.const(kernel_zero_point, "int32"), + input_scale=relay.const(input_scale, "float32"), + kernel_scale=relay.const(kernel_scale, "float32"), + kernel_size=kernel_size, + strides=strides, + dilation=dilation, + padding=padding, + out_dtype=out_dtype, + groups=groups, + channels=channels, + data_layout=data_layout, + kernel_layout=kernel_layout, + ) + + mod = relay.Function(relay.analysis.free_vars(func), func) + mod = tvm.IRModule.from_expr(mod) + return mod + + +def get_funcs( + data_shape, + data_dtype, + kernel_shape, + kernel_dtype, + input_zero_point, + kernel_zero_point, + input_scale, + kernel_scale, + kernel_size, + padding, + strides, + dilation, + data_layout, + kernel_layout, + out_dtype, + groups=1, + channels=None, +): + data = relay.var("data", shape=data_shape, dtype=data_dtype) + kernel = relay.var("kernel", shape=kernel_shape, dtype=kernel_dtype) + + ref_func = get_ref_func( + data, + kernel, + input_zero_point, + kernel_zero_point, + input_scale, + kernel_scale, + kernel_size, + padding, + strides, + dilation, + data_layout, + kernel_layout, + out_dtype, + groups, + channels, + ) + ref_func = run_infer_type(ref_func) + ref_func = tvm.IRModule.from_expr(ref_func) + qnn_func = get_qnn_func( + data, + kernel, + input_zero_point, + kernel_zero_point, + input_scale, + kernel_scale, + kernel_size, + padding, + strides, + dilation, + data_layout, + kernel_layout, + out_dtype, + channels, + groups, + ) + + return (ref_func, qnn_func) + + +def verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype): + def get_inputs(data_shape, data_dtype, kernel_shape, kernel_dtype): + # Keeping inputs multiple of 4 because of a bug in Average Pool2d + # https://discuss.tvm.apache.org/t/pool2d-gives-bad-output-for-integer-inputs/3377 + low = -128 + high = 127 + if data_dtype == "uint8": + low = 0 + high = 255 + golden_data = np.random.randint(low=low, high=high, size=data_shape).astype(data_dtype) + low = -128 + high = 127 + if kernel_dtype == "uint8": + low = 0 + high = 255 + golden_weight = 
np.random.randint(low=low, high=high, size=kernel_shape).astype( + kernel_dtype + ) + return (golden_data, golden_weight) + + def get_output(func, golden_inputs): + with tvm.transform.PassContext(opt_level=2): + golden_data, golden_weight = golden_inputs + params = {"kernel": golden_weight} + graph, lib, params = relay.build(func, "llvm", params=params) + mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0)) + mod.set_input("data", golden_data) + mod.set_input(**params) + mod.run() + res = mod.get_output(0).asnumpy() + return res + + golden_inputs = get_inputs(data_shape, data_dtype, kernel_shape, kernel_dtype) + golden_output = get_output(ref_func, golden_inputs) + qnn_output = get_output(qnn_func, golden_inputs) + np.testing.assert_equal(qnn_output, golden_output) + + +def test_no_zero_point(): + # uint8 input + data_shape = (2, 1, 2, 4) + data_dtype = "uint8" + kernel_shape = (1, 3, 2, 2) + kernel_dtype = "uint8" + ref_func, qnn_func = get_funcs( + data_shape=data_shape, + data_dtype=data_dtype, + kernel_shape=kernel_shape, + kernel_dtype=kernel_dtype, + input_zero_point=0, + kernel_zero_point=0, + input_scale=1.0, + kernel_scale=1.0, + kernel_size=(2, 2), + padding=(0, 0), + strides=(1, 1), + dilation=(1, 1), + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="int32", + ) + verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype) + + # int8 input + data_shape = (2, 1, 2, 4) + data_dtype = "int8" + kernel_shape = (1, 3, 2, 2) + kernel_dtype = "int8" + ref_func, qnn_func = get_funcs( + data_shape=data_shape, + data_dtype=data_dtype, + kernel_shape=kernel_shape, + kernel_dtype=kernel_dtype, + input_zero_point=0, + kernel_zero_point=0, + input_scale=1.0, + kernel_scale=1.0, + kernel_size=(2, 2), + padding=(0, 0), + strides=(1, 1), + dilation=(1, 1), + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="int32", + ) + verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype) + + +def test_kernel_zero_point(): + # uint8 input + data_shape = (2, 4, 2, 4) + data_dtype = "uint8" + kernel_shape = (4, 3, 2, 2) + kernel_dtype = "uint8" + ref_func, qnn_func = get_funcs( + data_shape=data_shape, + data_dtype=data_dtype, + kernel_shape=kernel_shape, + kernel_dtype=kernel_dtype, + input_zero_point=0, + kernel_zero_point=1, + input_scale=1.0, + kernel_scale=1.0, + kernel_size=(2, 2), + padding=(0, 0), + strides=(1, 1), + dilation=(1, 1), + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="int32", + ) + verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype) + + # int8 input + data_shape = (2, 1, 2, 4) + data_dtype = "int8" + kernel_shape = (1, 3, 2, 2) + kernel_dtype = "int8" + ref_func, qnn_func = get_funcs( + data_shape=data_shape, + data_dtype=data_dtype, + kernel_shape=kernel_shape, + kernel_dtype=kernel_dtype, + input_zero_point=0, + kernel_zero_point=5, + input_scale=1.0, + kernel_scale=1.0, + kernel_size=(2, 2), + padding=(0, 0), + strides=(1, 1), + dilation=(1, 1), + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="int32", + ) + verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype) + + +def test_input_zero_point(): + # uint8 input + data_shape = (2, 4, 2, 4) + data_dtype = "uint8" + kernel_shape = (4, 3, 2, 2) + kernel_dtype = "uint8" + ref_func, qnn_func = get_funcs( + data_shape=data_shape, + data_dtype=data_dtype, + kernel_shape=kernel_shape, + kernel_dtype=kernel_dtype, + input_zero_point=5, + kernel_zero_point=0, + input_scale=1.0, + kernel_scale=1.0, + kernel_size=(2, 2), 
+ padding=(0, 0), + strides=(1, 1), + dilation=(1, 1), + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="int32", + ) + verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype) + + # int8 input + data_shape = (2, 4, 2, 4) + data_dtype = "int8" + kernel_shape = (4, 3, 2, 2) + kernel_dtype = "int8" + ref_func, qnn_func = get_funcs( + data_shape=data_shape, + data_dtype=data_dtype, + kernel_shape=kernel_shape, + kernel_dtype=kernel_dtype, + input_zero_point=5, + kernel_zero_point=0, + input_scale=1.0, + kernel_scale=1.0, + kernel_size=(2, 2), + padding=(0, 0), + strides=(1, 1), + dilation=(1, 1), + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="int32", + ) + verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype) + + +def test_both_zero_point(): + # uint8 input + data_shape = (2, 4, 2, 4) + data_dtype = "uint8" + kernel_shape = (4, 3, 2, 2) + kernel_dtype = "uint8" + ref_func, qnn_func = get_funcs( + data_shape=data_shape, + data_dtype=data_dtype, + kernel_shape=kernel_shape, + kernel_dtype=kernel_dtype, + input_zero_point=5, + kernel_zero_point=3, + input_scale=1.0, + kernel_scale=1.0, + kernel_size=(2, 2), + padding=(0, 0), + strides=(1, 1), + dilation=(1, 1), + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="int32", + ) + verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype) + + # int8 input + data_shape = (2, 4, 2, 4) + data_dtype = "int8" + kernel_shape = (4, 3, 2, 2) + kernel_dtype = "int8" + ref_func, qnn_func = get_funcs( + data_shape=data_shape, + data_dtype=data_dtype, + kernel_shape=kernel_shape, + kernel_dtype=kernel_dtype, + input_zero_point=5, + kernel_zero_point=3, + input_scale=1.0, + kernel_scale=1.0, + kernel_size=(2, 2), + padding=(0, 0), + strides=(1, 1), + dilation=(1, 1), + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="int32", + ) + verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype) + + +def test_layout(): + # uint8 input + data_shape = (2, 2, 4, 4) # NHWC + data_dtype = "uint8" + kernel_shape = (2, 2, 3, 4) # HWIO + kernel_dtype = "uint8" + ref_func, qnn_func = get_funcs( + data_shape=data_shape, + data_dtype=data_dtype, + kernel_shape=kernel_shape, + kernel_dtype=kernel_dtype, + input_zero_point=5, + kernel_zero_point=3, + input_scale=1.0, + kernel_scale=1.0, + kernel_size=(2, 2), + padding=(0, 0), + strides=(1, 1), + dilation=(1, 1), + data_layout="NHWC", + kernel_layout="HWIO", + out_dtype="int32", + ) + verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype) + + data_shape = (2, 2, 4, 3) # NHWC + data_dtype = "uint8" + kernel_shape = (2, 2, 1, 3) # HWIO + kernel_dtype = "uint8" + ref_func, qnn_func = get_funcs( + data_shape=data_shape, + data_dtype=data_dtype, + kernel_shape=kernel_shape, + kernel_dtype=kernel_dtype, + input_zero_point=5, + kernel_zero_point=3, + input_scale=1.0, + kernel_scale=1.0, + kernel_size=(2, 2), + padding=(0, 0), + strides=(1, 1), + dilation=(1, 1), + data_layout="NHWC", + kernel_layout="HWIO", + out_dtype="int32", + ) + verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype) + + +def test_padding(): + # uint8 input + data_shape = (1, 4, 2, 2) + data_dtype = "uint8" + kernel_shape = (4, 3, 2, 2) + kernel_dtype = "uint8" + ref_func, qnn_func = get_funcs( + data_shape=data_shape, + data_dtype=data_dtype, + kernel_shape=kernel_shape, + kernel_dtype=kernel_dtype, + input_zero_point=8, + kernel_zero_point=5, + input_scale=1.0, + kernel_scale=1.0, + kernel_size=(2, 
2), + padding=(1, 1), + strides=(1, 1), + dilation=(1, 1), + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="int32", + ) + verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype) + + # Try different layout + data_shape = (2, 2, 4, 4) # NHWC + data_dtype = "uint8" + kernel_shape = (2, 2, 3, 4) # HWIO + kernel_dtype = "uint8" + ref_func, qnn_func = get_funcs( + data_shape=data_shape, + data_dtype=data_dtype, + kernel_shape=kernel_shape, + kernel_dtype=kernel_dtype, + input_zero_point=8, + kernel_zero_point=3, + input_scale=1.0, + kernel_scale=1.0, + kernel_size=(2, 2), + padding=(1, 1), + strides=(1, 1), + dilation=(1, 1), + data_layout="NHWC", + kernel_layout="HWIO", + out_dtype="int32", + ) + verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype) + + # Try asymmetric padding + data_shape = (2, 8, 6, 4) # NHWC + data_dtype = "uint8" + kernel_shape = (2, 2, 3, 4) # HWIO + kernel_dtype = "uint8" + ref_func, qnn_func = get_funcs( + data_shape=data_shape, + data_dtype=data_dtype, + kernel_shape=kernel_shape, + kernel_dtype=kernel_dtype, + input_zero_point=8, + kernel_zero_point=3, + input_scale=1.0, + kernel_scale=1.0, + kernel_size=(2, 2), + padding=(1, 1, 2, 2), + strides=(1, 1), + dilation=(1, 1), + data_layout="NHWC", + kernel_layout="HWIO", + out_dtype="int32", + ) + verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype) + + +def test_const_folding(): + data_shape = (2, 4, 2, 4) + data_dtype = "uint8" + kernel_shape = (4, 3, 2, 2) + kernel_dtype = "uint8" + + golden_weight = np.random.randint(low=0, high=255, size=kernel_shape).astype(kernel_dtype) + data = relay.var("data", shape=data_shape, dtype=data_dtype) + kernel = relay.const(golden_weight) + qnn_func = get_qnn_func( + data, + kernel, + input_zero_point=8, + kernel_zero_point=3, + kernel_size=(2, 2), + input_scale=1.0, + kernel_scale=1.0, + padding=(0, 0), + strides=(1, 1), + dilation=(1, 1), + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="int32", + channels=kernel_shape[1], + groups=1, + ) + folded_mod = transform.FoldConstant()(qnn_func) + folded_func = folded_mod["main"] + assert "reshape" not in folded_func.astext() + + +def test_broadcast_layout(): + # Test broadcast support for NHWC layout. 
+ data_shape = (1, 229, 229, 3) # NHWC + data_dtype = "uint8" + kernel_shape = (7, 7, 64, 3) # HWIO + kernel_dtype = "int8" + _, qnn_func = get_funcs( + data_shape=data_shape, + data_dtype=data_dtype, + kernel_shape=kernel_shape, + kernel_dtype=kernel_dtype, + input_zero_point=8, + kernel_zero_point=3, + input_scale=1.0, + kernel_scale=1.0, + kernel_size=(7, 7), + padding=(1, 1), + strides=(1, 1), + dilation=(1, 1), + data_layout="NHWC", + kernel_layout="HWIO", + out_dtype="int32", + ) + func = qnn_func["main"].body + bias = relay.var("bias", shape=(64,), dtype="int32") + bias2 = relay.var("bias2", shape=(1, 233, 233, 64), dtype="int32") + + # Check broadcast support on both lhs and rhs + func = relay.add(func, bias2) + func = relay.add(bias2, func) + func = relay.add(bias, func) + func = relay.add(func, bias) + func = relay.Function(relay.analysis.free_vars(func), func) + mod = tvm.IRModule.from_expr(func) + with tvm.transform.PassContext(opt_level=3): + graph, lib, params = relay.build(mod, "llvm -mcpu=skylake-avx512") + + +def test_per_channel_kernel_scale(): + data_shape = (2, 1, 2, 4) + data_dtype = "uint8" + kernel_shape = (1, 3, 2, 2) + kernel_dtype = "uint8" + data = relay.var("data", shape=data_shape, dtype=data_dtype) + kernel = relay.var("kernel", shape=kernel_shape, dtype=kernel_dtype) + kernel_scales = [2, 2, 2] + kernel_scales = relay.const(np.array(kernel_scales).astype("float32")) + func = relay.qnn.op.conv2d_transpose( + data, + kernel, + input_zero_point=relay.const(0, "int32"), + kernel_zero_point=relay.const(0, "int32"), + input_scale=relay.const(2.0, "float32"), + kernel_scale=kernel_scales, + kernel_size=(2, 2), + channels=kernel_shape[0], + padding=(0, 0), + strides=(1, 1), + dilation=(1, 1), + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="int32", + ) + + mod = relay.Function(relay.analysis.free_vars(func), func) + mod = tvm.IRModule.from_expr(mod) + + +if __name__ == "__main__": + test_no_zero_point() + test_input_zero_point() + test_kernel_zero_point() + test_both_zero_point() + test_layout() + test_padding() + test_const_folding() + test_broadcast_layout() + test_per_channel_kernel_scale()
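Note: finally, a self-contained NumPy check, independent of TVM, of the algebraic identity that both the reference functions above and the generic legalization rely on. Because a transposed convolution is bilinear in data and kernel, running it on the zero-point-shifted operands equals the expanded quantized form. The tiny single-channel, stride-1 implementation below is illustrative only.

```python
import numpy as np

def conv2d_transpose_2d(x, k):
    """Naive single-channel transposed convolution, stride 1, no padding (scatter form)."""
    h, w = x.shape
    kh, kw = k.shape
    out = np.zeros((h + kh - 1, w + kw - 1), dtype=np.int64)
    for i in range(h):
        for j in range(w):
            out[i : i + kh, j : j + kw] += x[i, j] * k
    return out

q_x = np.random.randint(0, 256, size=(2, 4)).astype(np.int64)  # uint8-range data
q_w = np.random.randint(0, 256, size=(2, 2)).astype(np.int64)  # uint8-range kernel
zp_x, zp_w = 5, 3

# Zero-point-shifted reference, as in get_ref_func / the legalization above.
ref = conv2d_transpose_2d(q_x - zp_x, q_w - zp_w)

# Expanded form: x*w - zp_w*(x*1) - zp_x*(1*w) + zp_x*zp_w*(1*1), term by term.
direct = (
    conv2d_transpose_2d(q_x, q_w)
    - zp_w * conv2d_transpose_2d(q_x, np.ones_like(q_w))
    - zp_x * conv2d_transpose_2d(np.ones_like(q_x), q_w)
    + zp_x * zp_w * conv2d_transpose_2d(np.ones_like(q_x), np.ones_like(q_w))
)
np.testing.assert_array_equal(ref, direct)
print("zero-point shift matches the expanded quantized form")
```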