Add initial support for quantized transpose convolution in Relay
This work is based on @jainris's initial PR: apache#6523

I added a relay.qnn.conv2d_transpose node. The strategy I followed is to
cast the quantized inputs to int16, subtract the zero points, and invoke
nn.conv2d_transpose, which already exists in Relay; a usage sketch follows
the list below. Main changes:

- The node declaration lives in src/relay/qnn/op/convolution_transpose.cc.
- The int8 -> int16 cast and subsequent zero-point subtraction live in
  python/tvm/relay/qnn/op/legalizations.py.
- I added and tested the operator in the TFLite front-end.
- I added a Relay unit test for qnn.conv2d_transpose.
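
A minimal usage sketch of the new node (not part of this commit; the shapes
and quantization parameters below are made up) showing how it is built from
Python and then rewritten by the legalization pass:

import tvm
from tvm import relay

data = relay.var("data", shape=(1, 10, 10, 64), dtype="uint8")
kernel = relay.var("kernel", shape=(64, 32, 3, 3), dtype="uint8")
out = relay.qnn.op.conv2d_transpose(
    data,
    kernel,
    input_zero_point=relay.const(128, "int32"),
    kernel_zero_point=relay.const(127, "int32"),
    input_scale=relay.const(0.5, "float32"),
    kernel_scale=relay.const(0.25, "float32"),
    strides=(2, 2),
    channels=32,
    kernel_size=(3, 3),
    data_layout="NHWC",
    kernel_layout="OIHW",
    out_dtype="int32",
)
mod = tvm.IRModule.from_expr(relay.Function(relay.analysis.free_vars(out), out))
# Legalize rewrites qnn.conv2d_transpose into an int16 cast, zero-point
# subtraction, and a plain nn.conv2d_transpose.
mod = relay.qnn.transform.Legalize()(mod)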

Co-authored-by: Rishabh Jain
Giuseppe Rossini committed Nov 11, 2020
1 parent b7318a7 commit 6da5514
Showing 6 changed files with 1,084 additions and 56 deletions.
72 changes: 59 additions & 13 deletions python/tvm/relay/frontend/tflite.py
@@ -2809,7 +2809,7 @@ def convert_transpose_conv(self, op):
         # Weights
         weights_tensor_type = weights_tensor.tensor.Type()
         # weights tensor type should be UINT8 (quantization) or FLOAT32
-        assert weights_tensor_type in (TensorType.UINT8, TensorType.FLOAT32)
+        assert weights_tensor_type in (TensorType.INT8, TensorType.UINT8, TensorType.FLOAT32)
         weight_tensor_type_str = self.get_tensor_type_str(weights_tensor_type)
         weight_value_ohwi = self.get_tensor_value(weights_tensor)
         # Relay kernel_layout should be OIHW
@@ -2831,19 +2831,40 @@ def convert_transpose_conv(self, op):
         else:
             padding = (0, 0, 0, 0)

-        out = _op.nn.conv2d_transpose(
-            in_expr,
-            weight_expr_iohw,
-            strides=(stride_h, stride_w),
-            padding=padding,
-            channels=int(out_channels),
-            kernel_size=(int(kernel_h), int(kernel_w)),
-            data_layout="NHWC",
-            kernel_layout="OIHW",
-            out_dtype=output_tensor_type_str,
-        )
+        if input_tensor.qnn_params:
+            input_zero_point = input_tensor.qnn_params["zero_point"]
+            kernel_zero_point = weights_tensor.qnn_params["zero_point"]
+            input_scale = input_tensor.qnn_params["scale"]
+            kernel_scale = weights_tensor.qnn_params["scale"]
+            out = _qnn.op.conv2d_transpose(
+                in_expr,
+                weight_expr_iohw,
+                input_zero_point,
+                kernel_zero_point,
+                input_scale,
+                kernel_scale,
+                strides=(stride_h, stride_w),
+                padding=padding,
+                channels=int(out_channels),
+                kernel_size=(int(kernel_h), int(kernel_w)),
+                data_layout="NHWC",
+                kernel_layout="OIHW",
+                out_dtype="int32",
+            )
+        else:
+            out = _op.nn.conv2d_transpose(
+                in_expr,
+                weight_expr_iohw,
+                strides=(stride_h, stride_w),
+                padding=padding,
+                channels=int(out_channels),
+                kernel_size=(int(kernel_h), int(kernel_w)),
+                data_layout="NHWC",
+                kernel_layout="OIHW",
+                out_dtype=output_tensor_type_str,
+            )

-        # if we have bias
+        # Checking if there is a fused bias
         if len(input_tensors) == 4:
             bias_tensor = input_tensors[3]
             bias_tensor_type = bias_tensor.tensor.Type()
@@ -2856,6 +2877,31 @@ def convert_transpose_conv(self, op):
             channel_axis = 3
             out = _op.nn.bias_add(out, bias_expr, axis=channel_axis)

+        if output_tensor.qnn_params:
+            # Calculate the intermediate scale and zero point of the int32 output.
+            data_scale = input_tensor.qnn_params["scale"]
+            data_scale_val = get_scalar_from_constant(data_scale)
+
+            weight_scale = weights_tensor.qnn_params["scale"]
+            # If weight scale is scalar, it is per-tensor quantization
+            if isinstance(weight_scale, float):
+                weight_scale_val = get_scalar_from_constant(weight_scale)
+            else:
+                weight_scale_val = get_tensor_from_constant(weight_scale)
+
+            new_input_scale_val = data_scale_val * weight_scale_val
+            new_input_scale = relay.const(new_input_scale_val, "float32")
+            new_input_zero_point = relay.const(0, "int32")
+
+            out = _qnn.op.requantize(
+                out,
+                input_scale=new_input_scale,
+                input_zero_point=new_input_zero_point,
+                output_scale=output_tensor.qnn_params["scale"],
+                output_zero_point=output_tensor.qnn_params["zero_point"],
+                out_dtype=output_tensor_type_str,
+                axis=3,
+            )
         return out

     def convert_quantize(self, op):
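The intermediate quantization parameters chosen above follow from the integer
arithmetic: the int32 accumulator sums (q_d - z_d) * (q_k - z_k) terms, so it
carries scale data_scale * weight_scale and zero point 0. A small worked check
(illustrative values only, not from the commit):

data_scale, kernel_scale = 0.5, 0.25   # made-up per-tensor scales
acc_scale = data_scale * kernel_scale  # scale of the int32 accumulator
acc_zero_point = 0                     # zero points were removed up front

# Requantize one accumulator value to uint8 with made-up output params.
output_scale, output_zero_point = 0.1, 128
acc = 42                               # an int32 accumulator value
real_value = acc_scale * (acc - acc_zero_point)               # 5.25
q_out = round(real_value / output_scale) + output_zero_point  # 180, fits in uint8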
22 changes: 22 additions & 0 deletions python/tvm/relay/qnn/op/legalizations.py
@@ -32,6 +32,12 @@ def legalize_qnn_conv2d(attrs, inputs, types):
     return qnn_conv2d_legalize(attrs, inputs, types)


+# Registering QNN Conv2DTranspose legalization function.
+@reg.register_qnn_legalize("qnn.conv2d_transpose")
+def legalize_qnn_conv2d_transpose(attrs, inputs, types):
+    return qnn_conv2d_transpose_legalize(attrs, inputs, types)
+
+
 # Registering QNN dense legalization function.
 @reg.register_qnn_legalize("qnn.dense")
 def legalize_qnn_dense(attrs, inputs, types):
@@ -46,6 +52,22 @@ def qnn_conv2d_legalize(attrs, inputs, types):
     return None


+# Generic QNN Conv2DTranspose legalization function.
+@tvm.target.generic_func
+def qnn_conv2d_transpose_legalize(attrs, inputs, types):
+    # Collect the input exprs.
+    data, kernel, input_zero_point, kernel_zero_point, _, _ = inputs
+
+    shift_data = relay.subtract(
+        relay.cast(data, dtype="int16"), relay.cast(input_zero_point, "int16")
+    )
+    shift_kernel = relay.subtract(
+        relay.cast(kernel, dtype="int16"), relay.cast(kernel_zero_point, "int16")
+    )
+    new_attrs = {k: attrs[k] for k in attrs.keys()}
+    return relay.nn.conv2d_transpose(shift_data, shift_kernel, **new_attrs)
+
+
 # Generic QNN Dense legalization function.
 @tvm.target.generic_func
 def qnn_dense_legalize(attrs, inputs, types):
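A small numeric illustration (not from the commit) of why the legalization
shifts into int16: a (u)int8 value minus a (u)int8 zero point spans
[-255, 255], which no 8-bit type can represent, while int16 holds it with
headroom for the convolution that follows:

import numpy as np

q = np.array([0, 3, 255], dtype=np.uint8)   # quantized uint8 values
zero_point = np.int16(128)                  # made-up zero point
shifted = q.astype(np.int16) - zero_point   # [-128, -125, 127]
assert shifted.dtype == np.int16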
95 changes: 95 additions & 0 deletions python/tvm/relay/qnn/op/qnn.py
@@ -296,6 +296,101 @@ def conv2d(
     )


+def conv2d_transpose(
+    data,
+    weight,
+    input_zero_point,
+    kernel_zero_point,
+    input_scale,
+    kernel_scale,
+    strides=(1, 1),
+    padding=(0, 0),
+    dilation=(1, 1),
+    groups=1,
+    channels=None,
+    kernel_size=None,
+    data_layout="NCHW",
+    kernel_layout="OIHW",
+    out_layout="",
+    output_padding=(0, 0),
+    out_dtype="",
+):
+    """This operator deconvolves quantized data with quantized kernel. The scale of
+    the output quantized tensor is the product of the kernel_scale and
+    input_scale of the input quantized tensors. The zero point of the output
+    quantized tensor is 0. By default, the dtype of output is int32. Please also
+    refer to Requantize operator to understand how to scale back the int32
+    output to (u)int8.
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The input data to the operator.
+
+    weight : tvm.relay.Expr
+        The weight expressions.
+
+    input_zero_point : tvm.relay.Expr
+        The zero point of the data distribution.
+
+    kernel_zero_point : tvm.relay.Expr
+        The zero point of the quantized kernel distribution.
+
+    input_scale : tvm.relay.Expr
+        The quantization scale of the input tensor.
+
+    kernel_scale : tvm.relay.Expr
+        The quantization scale of the kernel tensor; a scalar for per-tensor
+        quantization or a vector over output channels for per-channel
+        quantization.
+
+    strides : Tuple[int], optional
+        The strides of convolution.
+
+    padding : Tuple[int], optional
+        The padding of convolution on both sides of inputs.
+
+    dilation : Tuple[int], optional
+        Specifies the dilation rate to be used for dilated convolution.
+
+    channels : int, optional
+        Number of output channels of this convolution.
+
+    kernel_size : tuple of int, optional
+        The spatial dimensions of the convolution kernel.
+
+    groups : int, optional
+        Number of groups for grouped convolution.
+
+    data_layout : str, optional
+        Layout of the input.
+
+    kernel_layout : str, optional
+        Layout of the weight.
+
+    out_layout : Optional[str]
+        Layout of the output, by default, out_layout is the same as data_layout.
+
+    output_padding : Tuple[int], optional
+        Used to disambiguate the output shape.
+
+    out_dtype : str, optional
+        Specifies the output data type for mixed precision conv2d.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The computed result.
+    """
+    # convert 2-way padding to 4-way padding
+    padding = get_pad_tuple2d(padding)
+    return _make.conv2d_transpose(
+        data,
+        weight,
+        input_zero_point,
+        kernel_zero_point,
+        input_scale,
+        kernel_scale,
+        strides,
+        padding,
+        dilation,
+        groups,
+        channels,
+        kernel_size,
+        data_layout,
+        kernel_layout,
+        out_layout,
+        output_padding,
+        out_dtype,
+    )


def add(
    lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_scale, output_zero_point
):
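As the docstring notes, the int32 result is normally scaled back to (u)int8
with qnn.requantize. A sketch of that pairing, mirroring the TFLite front-end
change above (all constants are made up):

from tvm import relay

deconv = relay.var("deconv", shape=(1, 20, 20, 32), dtype="int32")
out = relay.qnn.op.requantize(
    deconv,
    input_scale=relay.const(0.5 * 0.25, "float32"),  # data_scale * kernel_scale
    input_zero_point=relay.const(0, "int32"),        # accumulator zero point
    output_scale=relay.const(0.1, "float32"),
    output_zero_point=relay.const(128, "int32"),
    axis=3,                                          # channel axis for NHWC
    out_dtype="uint8",
)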
154 changes: 154 additions & 0 deletions src/relay/qnn/op/convolution_transpose.cc
@@ -0,0 +1,154 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/qnn/op/convolution_transpose.cc
 * \brief Property def of qnn transpose convolution operator.
 */
#include <tvm/relay/analysis.h>
#include <tvm/relay/base.h>
#include <tvm/relay/op.h>
#include <tvm/relay/qnn/attrs.h>
#include <tvm/relay/transform.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/data_layout.h>

#include "../../op/nn/convolution.h"
#include "../../transforms/pattern_utils.h"
#include "../utils.h"

namespace tvm {
namespace relay {
namespace qnn {

// relay.op.qnn.conv2d_transpose

inline Expr MakeQnnConv2DTranspose(Expr data, Expr weight, Expr input_zero_point,
                                   Expr kernel_zero_point, Expr input_scale, Expr kernel_scale,
                                   Array<IndexExpr> strides, Array<IndexExpr> padding,
                                   Array<IndexExpr> dilation, int groups, IndexExpr channels,
                                   Array<IndexExpr> kernel_size, std::string data_layout,
                                   std::string kernel_layout, std::string out_layout,
                                   Array<IndexExpr> output_padding, DataType out_dtype) {
  auto attrs = make_object<Conv2DTransposeAttrs>();
  attrs->strides = std::move(strides);
  attrs->padding = std::move(padding);
  attrs->dilation = std::move(dilation);
  attrs->groups = groups;
  attrs->channels = std::move(channels);
  attrs->kernel_size = std::move(kernel_size);
  attrs->data_layout = std::move(data_layout);
  attrs->kernel_layout = std::move(kernel_layout);
  attrs->out_layout = std::move(out_layout);
  attrs->output_padding = std::move(output_padding);
  attrs->out_dtype = std::move(out_dtype);
  const Op& op = Op::Get("qnn.conv2d_transpose");
  return Call(op, {data, weight, input_zero_point, kernel_zero_point, input_scale, kernel_scale},
              Attrs(attrs), {});
}

Array<Array<Layout>> QnnConvTransposeInferCorrectLayout(
    const Attrs& attrs, const Array<Layout>& new_in_layouts, const Array<Layout>& old_in_layouts,
    const Array<tvm::relay::Type>& old_in_types) {
  // Use Relay Conv2D Infer correct layout.
  auto layouts = ConvInferCorrectLayout<Conv2DTransposeAttrs>(attrs, new_in_layouts, old_in_layouts,
                                                              old_in_types);

  // Fill the layouts of remaining input tensors - scales and zero points. The layouts of these
  // tensors can be treated as channel layout.
  Layout channel_layout = Layout("C");
  Array<Layout> input_layouts = {layouts[0][0],  layouts[0][1],  channel_layout,
                                 channel_layout, channel_layout, channel_layout};
  Array<Layout> output_layouts = layouts[1];
  return {input_layouts, output_layouts};
}

bool QnnConv2DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                           const TypeReporter& reporter) {
  ICHECK_EQ(types.size(), 7);
  const auto* data = types[0].as<TensorTypeNode>();
  const auto* weight = types[1].as<TensorTypeNode>();
  if (data == nullptr || weight == nullptr) return false;
  const auto* param = attrs.as<Conv2DTransposeAttrs>();
  ICHECK(param != nullptr) << "Conv2DTransposeAttrs cannot be nullptr.";
  ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8))
      << "Expected qnn conv2d_transpose type(int8, uint8) for input but was " << data->dtype;
  ICHECK(weight->dtype == DataType::Int(8) || weight->dtype == DataType::UInt(8))
      << "Expected qnn conv2d_transpose type(int8, uint8) for weight but was " << weight->dtype;
  ICHECK(param->out_dtype == DataType::Int(16) || param->out_dtype == DataType::Int(32))
      << "Expected qnn conv2d_transpose type(int16, int32) for output but was "
      << param->out_dtype;
  ICHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0.";

  // Check the types of scale and zero points.
  ICHECK(IsScalarType(types[2], DataType::Int(32)));    // input_zero_point
  ICHECK(IsScalarType(types[3], DataType::Int(32)));    // kernel_zero_point
  ICHECK(IsScalarType(types[4], DataType::Float(32)));  // input_scale
  // Kernel scale can be a vector of length output_channels or a scalar.
  if (param->groups == 1) {
    size_t axis = param->kernel_layout.find('O');
    ICHECK(axis != std::string::npos) << "Kernel layout attribute is not defined";
    AssignType(types[5], DataType::Float(32), weight->shape[axis], reporter);  // kernel scale
  } else {
    // Here, the total number of output channels depends on the depth multiplier.
    size_t o_axis = param->kernel_layout.find('O');
    size_t i_axis = param->kernel_layout.find('I');
    ICHECK(o_axis != std::string::npos && i_axis != std::string::npos)
        << "Kernel layout attribute is not defined";
    AssignType(types[5], DataType::Float(32), weight->shape[i_axis] * weight->shape[o_axis],
               reporter);  // kernel scale
  }

  // Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay
  // Conv2D infer type function.
  Array<Type> tensor_types = {types[0], types[1], types[6]};
  return Conv2DTransposeRel<Conv2DTransposeAttrs>(tensor_types, 3, attrs, reporter);
}

RELAY_REGISTER_OP("qnn.conv2d_transpose")
    .describe(R"code(Quantized transposed 2D convolution layer (sometimes called Deconvolution).
This operator deconvolves quantized data with quantized weight. The scale of the
output quantized tensor is the product of the weight_scale and input_scale of
the input quantized tensors. The zero point of the output quantized tensor is
0. By default, the dtype of output is int32. Please also refer to Requantize
operator to understand how to scale back the int32 output to (u)int8.

- **data**: This depends on the `layout` parameter. Input is 4D array of shape
            (batch_size, in_channels, height, width) if `layout` is `NCHW`.
- **weight**: (in_channels, channels, kernel_size[0], kernel_size[1])
- **out**: This depends on the `layout` parameter. Output is 4D array of shape
           (batch_size, channels, out_height, out_width) if `layout` is `NCHW`.
)code" TVM_ADD_FILELINE)
    .set_attrs_type<Conv2DTransposeAttrs>()
    .set_num_inputs(6)
    .add_argument("data", "Tensor", "The quantized input data tensor.")
    .add_argument("weight", "Tensor", "The quantized weight tensor.")
    .add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.")
    .add_argument("weight_zero_point", "Tensor",
                  "The quantization zero_point of the weight tensor.")
    .add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.")
    .add_argument("weight_scale", "Tensor", "The quantization scale of the weight tensor.")
    .set_support_level(11)
    .add_type_rel("QnnConv2DTranspose", QnnConv2DTransposeRel)
    .set_attr<TNonComputational>("TNonComputational", true)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", QnnConvTransposeInferCorrectLayout);

TVM_REGISTER_GLOBAL("relay.qnn.op._make.conv2d_transpose").set_body_typed(MakeQnnConv2DTranspose);

} // namespace qnn
} // namespace relay
} // namespace tvm
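
The TVM_REGISTER_GLOBAL line above is what backs the Python wrapper in
qnn.py: _make.conv2d_transpose resolves to this packed function. A quick,
illustrative way to inspect the binding from Python:

import tvm

f = tvm.get_global_func("relay.qnn.op._make.conv2d_transpose")
print(f)  # a tvm.runtime.PackedFunc wrapping MakeQnnConv2DTranspose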