[Frontend][PyTorch] support for quantized conv_transpose2d op #9133

Merged: 3 commits, Sep 29, 2021
100 changes: 96 additions & 4 deletions python/tvm/relay/frontend/qnn_torch.py
@@ -51,12 +51,25 @@ class ConvPackedParam(QNNParam):
together with weights and quantization parameters
"""

def __init__(self, weight_np, bias, scale, zero_point, stride, padding, dilation, groups):
def __init__(
self,
weight_np,
bias,
scale,
zero_point,
stride,
padding,
dilation,
groups,
output_padding,
):
super().__init__(weight_np, bias, scale, zero_point)
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups
# Used only for conv_transpose2d
self.output_padding = output_padding


def _get_quant_params(qweight):
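For reference, output_padding matters only for transposed convolution: several input sizes of the corresponding forward convolution collapse to the same output size, and output_padding selects which inverse shape to produce. A minimal sketch of the standard shape arithmetic (the formula PyTorch documents for nn.ConvTranspose2d), not taken from this patch:

def conv_transpose2d_out_dim(in_dim, kernel, stride=1, padding=0, dilation=1, output_padding=0):
    # Output size along one spatial axis of a transposed convolution.
    return (in_dim - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1

# With stride=2 and kernel=3, a 5-wide input can invert to width 11 or 12:
assert conv_transpose2d_out_dim(5, 3, stride=2) == 11
assert conv_transpose2d_out_dim(5, 3, stride=2, output_padding=1) == 12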
@@ -86,7 +99,18 @@ def make_conv_packed_param(qweight, bias, packed_params):
padding = packed_params.padding()
dilation = packed_params.dilation()
groups = packed_params.groups()
return ConvPackedParam(weight_np, bias, scale, zero_point, stride, padding, dilation, groups)
output_padding = packed_params.output_padding()
return ConvPackedParam(
weight_np,
bias,
scale,
zero_point,
stride,
padding,
dilation,
groups,
output_padding,
)


def get_weight_quant_params(script_module, packed_param_names):
@@ -208,7 +232,13 @@ def add_quant_params_to_outputs(
params = [qweight, qparam.scale, qparam.zero_point, qbias]

if isinstance(quant_params[packed_param_name], ConvPackedParam):
params += [qparam.stride, qparam.padding, qparam.dilation, qparam.groups]
params += [
qparam.stride,
qparam.padding,
qparam.dilation,
qparam.groups,
qparam.output_padding,
]

outputs[node_name] = params
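The params list built here has a fixed index layout that the new converter below relies on when it unpacks its packed argument; written out for reference:

# Index layout produced above for a ConvPackedParam:
#   0: quantized weight   4: stride     8: output_padding (new in this PR)
#   1: weight scale       5: padding
#   2: weight zero point  6: dilation
#   3: bias               7: groups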

@@ -246,6 +276,7 @@ def _get_quant_param_for_input(input_value):
"quantized::mul_scalar": (2, 3),
"quantized::add_scalar": (2, 3),
"quantized::hardswish": (1, 2),
"quantized::conv_transpose2d": qconv_indices,
}

def dfs(current_node):
@@ -416,6 +447,7 @@ def add_input_quant_params_to_op_inputs(graph):
"quantized::relu6": 1,
"quantized::hardswish": 1,
"aten::hardsigmoid": 1,
"quantized::conv_transpose2d": 1,
}

need_input_quant_param = set(num_quantized_inputs.keys())
@@ -457,7 +489,7 @@ def add_input_quant_params_to_op_inputs(graph):
node.addInput(scale)
node.addInput(zp)

if "conv2d" in operator or "linear" in operator:
if "conv" in operator or "linear" in operator:
# This is required for quantizing the bias
input_scales_for_bias[node.inputsAt(1).debugName()] = scale.node().f("value")

@@ -983,6 +1015,65 @@ def _impl(inputs, _):
return _impl


def _quantized_conv_transpose2d(with_relu=False):
def _impl(inputs, _):
# Refer to aten/src/ATen/native/quantized/cpu/qconv.cpp
# Supported in Torch 1.7 or newer
conv_params = inputs[1]
weight = conv_params[0]
weight_scale = conv_params[1]
weight_zero_point = conv_params[2]
bias = conv_params[3]

strides = conv_params[4]
padding = conv_params[5]
dilation = conv_params[6]
groups = conv_params[7]
output_padding = conv_params[8]

output_scale = _expr.const(inputs[2])
output_zero_point = _expr.const(inputs[3])

assert len(inputs) == 6, "Input quant params not found in op inputs"

# These quant params are manually added by add_input_quant_params_to_op_inputs above.
# In Torch, they are retrieved from the QTensor data structure at runtime.
input_scale = _expr.const(inputs[4])
input_zero_point = _expr.const(inputs[5])

weight_shape = list(infer_shape(weight))

# Swap the I and O dims to match the OIHW shape relay expects
weight_shape[0], weight_shape[1] = weight_shape[1], weight_shape[0]

kernel_size = (weight_shape[2], weight_shape[3])
out_channels = weight_shape[0]

conv_out = relay.qnn.op.conv2d_transpose(
inputs[0],
weight,
input_zero_point,
weight_zero_point,
input_scale,
weight_scale,
kernel_size=kernel_size,
dilation=dilation,
strides=strides,
padding=padding,
groups=groups,
channels=out_channels,
output_padding=output_padding,
out_dtype="int32",
kernel_layout="OIHW",
)

return _do_bias_and_requantize(
conv_out, bias, input_scale, weight_scale, output_scale, output_zero_point, with_relu
)

return _impl


convert_map = {
"aten::quantize_per_tensor": _quantize_per_tensor(),
"quantized::conv2d_relu": _quantized_conv2d(with_relu=True),
@@ -1000,4 +1091,5 @@ def _impl(inputs, _):
"quantized::relu6": _relu6(),
"quantized::linear_dynamic": _linear_dynamic(),
"quantized::hardswish": _hswish(),
"quantized::conv_transpose2d": _quantized_conv_transpose2d(),
}
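With the entry registered in convert_map, a traced quantized model containing quantized::conv_transpose2d imports through the frontend's usual entry point. A minimal usage sketch; the input name and shape are illustrative, and model is assumed to be an already-quantized torch.nn.Module in eval mode:

import torch
from tvm import relay

inp = torch.rand((1, 3, 224, 224))
script_module = torch.jit.trace(model, inp).eval()

with torch.no_grad():
    mod, params = relay.frontend.from_pytorch(script_module, [("input", (1, 3, 224, 224))])
# mod should now contain qnn.conv2d_transpose followed by the bias add and requantize.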
26 changes: 25 additions & 1 deletion tests/python/frontend/pytorch/qnn_test.py
@@ -98,6 +98,20 @@ def fuse_model(self):
fuse_modules(self.conv, indices, inplace=True)


class ConvTranspose(nn.Module):
def __init__(self):
super().__init__()
layers = [nn.ConvTranspose2d(3, 32, 3, bias=True)]
self.conv = nn.Sequential(*layers)
self.quant_wrap = QuantWrapper(self.conv)

def forward(self, x):
return self.quant_wrap(x)

def fuse_model(self):
pass


class Linear(nn.Module):
def __init__(self, with_relu=False):
super().__init__()
@@ -276,6 +290,7 @@ def test_quantized_modules():
("conv_bn_relu" + postfix, imagenet_ishape, ConvBn(with_relu=True), per_channel),
("linear" + postfix, (16, 16), Linear(), per_channel),
("linear_relu" + postfix, (16, 16), Linear(with_relu=True), per_channel),
("conv_transpose", imagenet_ishape, ConvTranspose(), False),
("hsigmoid", imagenet_ishape, Hsigmoid(add_stub=True), False),
("hswish", imagenet_ishape, Hswish(add_stub=True), False),
("semodule", (1, 16, 64, 64), SqueezeExcite(16, add_stub=True), False),
@@ -287,7 +302,15 @@ def test_quantized_modules():
raw_module.eval()
inp = torch.rand(ishape)

quantize_model(raw_module, inp, per_channel=per_channel)
# Quantized conv_transpose2d is supported only with the qnnpack engine before torch v1.8.0.
if module_name == "conv_transpose" and not is_version_greater_than("1.7.1"):
prev_engine = torch.backends.quantized.engine
torch.backends.quantized.engine = "qnnpack"
quantize_model(raw_module, inp, per_channel=per_channel)
torch.backends.quantized.engine = prev_engine
else:
quantize_model(raw_module, inp, per_channel=per_channel)

script_module = torch.jit.trace(raw_module, inp).eval()

with torch.no_grad():
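quantize_model above is the test helper defined earlier in qnn_test.py; a rough eager-mode equivalent using PyTorch's post-training quantization APIs, with the function name and single-batch calibration being illustrative:

import torch
from torch.quantization import convert, get_default_qconfig, prepare

def quantize_model_sketch(model, inp, engine="qnnpack"):
    torch.backends.quantized.engine = engine
    model.qconfig = get_default_qconfig(engine)
    model.eval()
    prepare(model, inplace=True)
    model(inp)  # one calibration pass to record activation ranges
    convert(model, inplace=True)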
@@ -314,6 +337,7 @@ def test_quantized_modules():
conv_bn_relu 0.3700896 0.010921672 0.7489366477964451
linear 0.15987062 0.009231662 0.794921875
linear_relu 0.14180502 0.0053220326 0.8828125
conv_transpose 0.0033792555 4.4658788e-07 0.9998678439971806
conv_bn, per_channel 0.01654929 2.9486866e-06 0.9998218235127019
conv_bn_relu, per_channel 0.009089053 1.4926576e-06 0.9998357732732732
linear, per_channel 0.0 0.0 1.0