Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[QNN][TFLite] Added support for fused-bias and quantized input in TRANSPOSE_CONV for TFLite. #6523

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -596,11 +596,13 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode<Conv2DTransposeAttrs> {
/*! \brief Attributes used in dilate operator */
struct DilateAttrs : public tvm::AttrsNode<DilateAttrs> {
Array<IndexExpr> strides;
double dilation_value;
jainris marked this conversation as resolved.
Show resolved Hide resolved

TVM_DECLARE_ATTRS(DilateAttrs, "relay.attrs.DilateAttrs") {
TVM_ATTR_FIELD(strides)
.set_default(Array<IndexExpr>({1, 1}))
.describe("Dilation stride on each dimension, 1 means no dilation.");
TVM_ATTR_FIELD(dilation_value).set_default(0.0).describe("Value used to dilate the input.");
}
};

Expand Down
10 changes: 6 additions & 4 deletions include/tvm/topi/nn/dilate.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,19 +55,20 @@ PrimExpr all(Array<PrimExpr> args) {
}

/*!
* \brief Dilate data with zeros
* \brief Dilate data with given dilation value (0 by default).
*
* \param x The input tensor, this can have any number of
* dimensions and any layout.
* \param strides Dilation stride for each dimension. Stride 1
* means no dilation.
* \param dilation_value Value used to dilate the input.
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return The output tensor.
*/
inline Tensor dilate(const Tensor& x, Array<PrimExpr> strides, std::string name = "tensor",
std::string tag = kInjective) {
inline Tensor dilate(const Tensor& x, Array<PrimExpr> strides, double dilation_value,
std::string name = "tensor", std::string tag = kInjective) {
auto n = x->shape.size();
CHECK_EQ(n, strides.size()) << "strides size (" << strides.size()
<< ") must match dimension of x (" << n << ")";
Expand All @@ -94,7 +95,8 @@ inline Tensor dilate(const Tensor& x, Array<PrimExpr> strides, std::string name
}
if (not_zero.size() > 0) {
auto all_not_zero = all(not_zero);
return tvm::if_then_else(all_not_zero, x(index_tuple), make_const(x->dtype, 0));
return tvm::if_then_else(all_not_zero, x(index_tuple),
make_const(x->dtype, dilation_value));
}
return x(index_tuple);
},
Expand Down
111 changes: 97 additions & 14 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -2770,7 +2770,7 @@ def convert_transpose_conv(self, op):
raise ImportError("The tflite package must be installed")

input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 3, "input tensors length should be 3"
assert len(input_tensors) in (3, 4), "input tensors length should be 3 or 4"

# Input (data) Tensor. NHWC layout
input_tensor = input_tensors[2]
Expand Down Expand Up @@ -2808,8 +2808,8 @@ def convert_transpose_conv(self, op):

# Weights
weights_tensor_type = weights_tensor.tensor.Type()
# weights tensor type should be UINT8 (quantization) or FLOAT32
assert weights_tensor_type in (TensorType.UINT8, TensorType.FLOAT32)
# weights tensor type should be INT8/UINT8 (quantization) or FLOAT32
assert weights_tensor_type in (TensorType.INT8, TensorType.UINT8, TensorType.FLOAT32)
weight_tensor_type_str = self.get_tensor_type_str(weights_tensor_type)
weight_value_ohwi = self.get_tensor_value(weights_tensor)
# Relay kernel_layout should be OIHW
Expand All @@ -2831,17 +2831,100 @@ def convert_transpose_conv(self, op):
else:
padding = (0, 0, 0, 0)

out = _op.nn.conv2d_transpose(
in_expr,
weight_expr_iohw,
strides=(stride_h, stride_w),
padding=padding,
channels=int(out_channels),
kernel_size=(int(kernel_h), int(kernel_w)),
data_layout="NHWC",
kernel_layout="OIHW",
out_dtype=output_tensor_type_str,
)
if input_tensor.qnn_params:
# Quantized Transpose Conv can be implemented by using qnn.conv2d
# after some transformations of the input and kernel tensors.
# These transformations are:
# kernel is flipped and transformed and
# input is dilated and padded (upsampled)
# before passing to qnn.conv2d.
# This is equivalent to transpose convolution.

# Upsampling (Dilating and Padding) input with zero_point
input_zero_point = input_tensor.qnn_params["zero_point"]
dilated_in_expr = _op.nn.dilate(
in_expr,
strides=(1, stride_h, stride_w, 1),
dilation_value=float(input_zero_point.data.asnumpy()),
)
# qnn.conv2d pads with zero_point, so we just calculate the padding width needed
pad_top = kernel_h - 1 - padding[0]
pad_left = kernel_w - 1 - padding[1]
pad_bottom = kernel_h - 1 - padding[2]
pad_right = kernel_w - 1 - padding[3]
padding = (pad_top, pad_left, pad_bottom, pad_right)
# Transforming kernel into OIHW and flipping it (rotating by 180 degrees)
weight_value_hwio = np.transpose(weight_value_ohwi, (1, 2, 3, 0))
weight_value_hwio = np.flip(weight_value_hwio, (0, 1))
weight_expr_oihw = self.exp_tab.new_const(
weight_value_hwio, dtype=weight_tensor_type_str
)
out = _qnn.op.conv2d(
dilated_in_expr,
weight_expr_oihw,
strides=(1, 1),
padding=padding,
channels=int(out_channels),
kernel_size=(int(kernel_h), int(kernel_w)),
data_layout="NHWC",
kernel_layout="HWIO",
input_zero_point=input_zero_point,
kernel_zero_point=weights_tensor.qnn_params["zero_point"],
out_dtype="int32",
input_scale=input_tensor.qnn_params["scale"],
kernel_scale=weights_tensor.qnn_params["scale"],
)
else:
out = _op.nn.conv2d_transpose(
in_expr,
weight_expr_iohw,
strides=(stride_h, stride_w),
padding=padding,
channels=int(out_channels),
kernel_size=(int(kernel_h), int(kernel_w)),
data_layout="NHWC",
kernel_layout="OIHW",
out_dtype=output_tensor_type_str,
)

# Checking if there is a fused bias
if len(input_tensors) == 4:
bias_tensor = input_tensors[3]
bias_tensor_type = bias_tensor.tensor.Type()
# bias tensor type should be INT32 (quantization) or FLOAT32
assert bias_tensor_type in (TensorType.INT32, TensorType.FLOAT32)
bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type)
bias_expr = self.exp_tab.new_const(
self.get_tensor_value(bias_tensor), dtype=bias_tensor_type_str
)
channel_axis = 3
out = _op.nn.bias_add(out, bias_expr, axis=channel_axis)

if output_tensor.qnn_params:
# Calculate the intermediate scale and zero point of the int32 output.
data_scale = input_tensor.qnn_params["scale"]
data_scale_val = get_scalar_from_constant(data_scale)

weight_scale = weights_tensor.qnn_params["scale"]
# If weight scale is scalar, it is per-tensor quantization
if isinstance(weight_scale, float):
weight_scale_val = get_scalar_from_constant(weight_scale)
else:
weight_scale_val = get_tensor_from_constant(weight_scale)

new_input_scale_val = data_scale_val * weight_scale_val
new_input_scale = relay.const(new_input_scale_val, "float32")
new_input_zero_point = relay.const(0, "int32")

out = _qnn.op.requantize(
out,
input_scale=new_input_scale,
input_zero_point=new_input_zero_point,
output_scale=output_tensor.qnn_params["scale"],
output_zero_point=output_tensor.qnn_params["zero_point"],
out_dtype=output_tensor_type_str,
axis=3,
)

return out

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,7 @@ def compute_cross_entropy(attrs, inputs, out_dtype):
# dilate
@reg.register_compute("nn.dilate")
def compute_dilate(attrs, inputs, out_dtype):
return [topi.nn.dilate(inputs[0], attrs.strides)]
return [topi.nn.dilate(inputs[0], attrs.strides, attrs.dilation_value)]


reg.register_broadcast_schedule("nn.dilate")
Expand Down
11 changes: 7 additions & 4 deletions python/tvm/relay/op/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1549,23 +1549,26 @@ def pad(data, pad_width, pad_value=0, pad_mode="constant"):
return _make.pad(data, pad_width, pad_value, pad_mode)


def dilate(data, strides):
"""Dilate data with zeros.
def dilate(data, strides, dilation_value=0.0):
"""Dilate data with given dilation value (0 by default).

Parameters
----------
data : tvm.relay.Expr
n-D, can be any layout.

strides : <tuple of <int>
strides : tuple of <int>
Dilation stride on each dimension, 1 means no dilation.

dilation_value : int/float, optional
Value used to dilate the input.

Returns
-------
Output : tvm.relay.Expr
The computed result
"""
return _make.dilate(data, strides)
return _make.dilate(data, strides, dilation_value)


def mirror_pad(data, pad_width, mode="SYMMETRIC"):
Expand Down
9 changes: 6 additions & 3 deletions python/tvm/topi/nn/dilate.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@


@te.tag_scope(tag=tag.INJECTIVE + ",dilate")
def dilate(data, strides, name="DilatedInput"):
"""Dilate data with zeros.
def dilate(data, strides, dilation_value=0.0, name="DilatedInput"):
"""Dilate data with given dilation value (0 by default).

Parameters
----------
Expand All @@ -34,6 +34,9 @@ def dilate(data, strides, name="DilatedInput"):
strides : list / tuple of n ints
Dilation stride on each dimension, 1 means no dilation.

dilation_value : int/float, optional
Value used to dilate the input.

name : str, optional
The name prefix operators generated

Expand Down Expand Up @@ -62,7 +65,7 @@ def _dilate(*indices):
if not_zero:
not_zero = tvm.tir.all(*not_zero)
return tvm.tir.if_then_else(
not_zero, data(*index_tuple), tvm.tir.const(0.0, data.dtype)
not_zero, data(*index_tuple), tvm.tir.const(dilation_value, data.dtype)
)
return data(*index_tuple)

Expand Down
8 changes: 6 additions & 2 deletions python/tvm/topi/testing/dilate_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import numpy as np


def dilate_python(input_np, strides):
def dilate_python(input_np, strides, dilation_value=0.0):
"""Dilate operation.

Parameters
Expand All @@ -30,6 +30,9 @@ def dilate_python(input_np, strides):
strides : list / tuple of n ints
Dilation stride on each dimension, 1 means no dilation.

dilation_value : int/float, optional
Value used to dilate the input.

Returns
-------
output_np : numpy.ndarray
Expand All @@ -45,7 +48,8 @@ def dilate_python(input_np, strides):
for i in range(n):
output_size += ((input_np.shape[i] - 1) * strides[i] + 1,)
no_zero += ((range(0, output_size[i], strides[i])),)
output_np = np.zeros(shape=output_size)
output_np = np.ones(shape=output_size)
output_np = dilation_value * output_np
output_np[np.ix_(*no_zero)] = input_np

return output_np
5 changes: 3 additions & 2 deletions src/relay/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -961,9 +961,10 @@ bool DilateRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
}

// Positional relay function to create dilate operator used by frontend FFI.
Expr MakeDilate(Expr data, Array<IndexExpr> strides) {
Expr MakeDilate(Expr data, Array<IndexExpr> strides, double dilation_value = 0.0) {
auto attrs = make_object<DilateAttrs>();
attrs->strides = std::move(strides);
attrs->dilation_value = std::move(dilation_value);
static const Op& op = Op::Get("nn.dilate");
return Call(op, {data}, Attrs(attrs), {});
}
Expand All @@ -972,7 +973,7 @@ TVM_REGISTER_GLOBAL("relay.op.nn._make.dilate").set_body_typed(MakeDilate);

RELAY_REGISTER_OP("nn.dilate")
.describe(R"code(
Dilate data with zeros.
Dilate data with given dilation value (0 by default).
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("x", "1D Tensor", "Data to dilate.")
Expand Down
2 changes: 1 addition & 1 deletion src/topi/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ TVM_REGISTER_GLOBAL("topi.nn.batch_matmul").set_body([](TVMArgs args, TVMRetValu

/* Ops from nn/dilate.h */
TVM_REGISTER_GLOBAL("topi.nn.dilate").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = nn::dilate(args[0], args[1]);
*rv = nn::dilate(args[0], args[1], args[2]);
});

/* Ops from nn/flatten.h */
Expand Down
Loading