Skip to content

Commit

Permalink
[Torch, Quantization] Necessary workaround to prepare for 1.6 update (a…
Browse files Browse the repository at this point in the history
…pache#6602)

* add support for 1.6 quantized models

* fix lint

* move version check function to a common utils

* fix lint

Co-authored-by: masa <masa@pop-os.localdomain>
  • Loading branch information
2 people authored and Trevor Morris committed Dec 2, 2020
1 parent 9d758f0 commit e77a951
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 73 deletions.
15 changes: 3 additions & 12 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,11 @@
from ..prelude import Prelude, StaticTensorArrayOps

from . import qnn_torch
from .pytorch_utils import is_version_greater_than

__all__ = ["from_pytorch"]


def _is_version_greater_than(ver):
import torch
from packaging import version

# Torch version > 1.4 changed upsampling API
return version.parse(torch.__version__) > version.parse(ver)


# List ADT utilities
def _infer_type_with_prelude(val, prelude):
body = _infer_type(val, prelude.mod)
Expand Down Expand Up @@ -1882,7 +1875,7 @@ def func(x):

if _is_quantized_tensor(data, prelude):
# Torch version > 1.4 changed upsampling API
if _is_version_greater_than("1.4.0"):
if is_version_greater_than("1.4.0"):
num_inputs = 7
else:
num_inputs = 5
Expand Down Expand Up @@ -2714,7 +2707,7 @@ def _run_jit_passes(graph):
""" The inline pass is necessary to unwrap prim::CallMethod """
import torch

if _is_version_greater_than("1.5.0"):
if is_version_greater_than("1.5.0"):
# This is required for torchvision detection models from 1.6 above
# It is the same as _jit_pass_inline, except that it has some special
# case behaviors for some ops such as aten::__interpolate()
Expand Down Expand Up @@ -3069,8 +3062,6 @@ def convert_params(graph, state_dict):
full_attr_node_name = _get_output_name(getattrs[-1])

if full_attr.endswith("_packed_params"): # for quantized models
err_msg = "parameter %s not found in state dict" % full_attr
assert full_attr in state_dict, err_msg
packed_param_map[full_attr_node_name] = full_attr
elif full_attr in state_dict:
if full_attr in vars_by_name:
Expand Down
25 changes: 25 additions & 0 deletions python/tvm/relay/frontend/pytorch_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-outside-toplevel
""" Common utilities used by PyTorch frontend """


def is_version_greater_than(ver):
import torch
from packaging import version

return version.parse(torch.__version__) > version.parse(ver)
186 changes: 125 additions & 61 deletions python/tvm/relay/frontend/qnn_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from tvm.relay import op as _op
from tvm.relay.frontend.common import infer_shape

from .pytorch_utils import is_version_greater_than


class QNNParam:
""" A placeholder for weight quantization parameters """
Expand All @@ -46,59 +48,95 @@ def __init__(self, weight, bias, scale, zero_point, param_key):
self.zero_point = _expr.const(zero_point, dtype="int32")


def _unpack_quant_params(param_name, packed_params, unpack_func):
# Torch stores quantized params in a custom packed format,
# need to unpack and retrieve them as numpy arrays
qweight, bias = unpack_func(packed_params)
weight_np = qweight.dequantize().numpy()
class ConvPackedParam(QNNParam):
"""A placeholder for quantized conv2d op attributes
As of PyTorch 1.6, attributes of quantized conv2d ops, like
stride, padding etc are stored in ConvPackedParams objects,
together with weights and quantization parameters
"""

def __init__(
self, weight_np, bias, scale, zero_point, param_name, stride, padding, dilation, groups
):
super().__init__(weight_np, bias, scale, zero_point, param_name)
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups


def _get_quant_params(qweight):
import torch

weight_np = qweight.dequantize().numpy()

if qweight.qscheme() == torch.per_tensor_affine:
param = QNNParam(
weight_np, bias, qweight.q_scale(), int(qweight.q_zero_point()), param_name
)
else:
scales = qweight.q_per_channel_scales().numpy()
zero_points = qweight.q_per_channel_zero_points().numpy()
# This is an assumption posed by QNN
msg = "The values of zero points should be all zero for per channel"
assert np.all(zero_points == 0), msg
param = QNNParam(weight_np, bias, scales, 0, param_name)
return weight_np, qweight.q_scale(), int(qweight.q_zero_point())

scales = qweight.q_per_channel_scales().numpy()
zero_points = qweight.q_per_channel_zero_points().numpy()
# This is an assumption posed by QNN
msg = "The values of zero points should be all zero for per channel"
assert np.all(zero_points == 0), msg
return weight_np, scales, 0


return param
def make_qnn_param(param_name, qweight, bias):
weight_np, scale, zero_point = _get_quant_params(qweight)
return QNNParam(weight_np, bias, scale, zero_point, param_name)


def make_conv_packed_param(param_name, qweight, bias, packed_params):
weight_np, scale, zero_point = _get_quant_params(qweight)
stride = packed_params.stride()
padding = packed_params.padding()
dilation = packed_params.dilation()
groups = packed_params.groups()
return ConvPackedParam(
weight_np, bias, scale, zero_point, param_name, stride, padding, dilation, groups
)


def get_weight_quant_params(script_module):
""" Retrive and unpack weight parameters from quantized modules """
conv_packed_params = []
linear_packed_params = []

import torch

# conv and linear requires different unpacking function
# extract all conv and linear parameters separately to distinguish them
for name, m in script_module.named_modules():
if isinstance(m, torch.jit.RecursiveScriptModule):
if "Conv" in m.original_name:
conv_packed_params.append((name, m.state_dict()))
elif m.original_name == "LinearPackedParams":
linear_packed_params.append((name, m.state_dict()))
param_name = "_packed_params"
quant_params = {}

def filter_func(named_module):
m = named_module[1]
return isinstance(m, torch.jit.RecursiveScriptModule) and (
("Conv" in m.original_name) or (m.original_name == "LinearPackedParams")
)

pairs = [
(torch.ops.quantized.conv2d_unpack, conv_packed_params),
(torch.ops.quantized.linear_unpack, linear_packed_params),
]
for name, m in filter(filter_func, script_module.named_modules()):
key = name + "." + param_name
state_dict = m.state_dict()

quant_params = {}
param_name = "_packed_params"
for unpack_func, params in pairs:
for name, state_dict in params:
if len(state_dict) == 0 and not hasattr(m, param_name):
# for v1.6 and above
# This case seems to happen if a model is serialized
# and loaded back
# This module can be safely ignored
continue

if len(state_dict) == 0 and hasattr(m, param_name):
# for v1.6 and above
packed_params = m._packed_params
else:
assert len(state_dict) == 1
assert param_name in state_dict
key = name + "." + param_name
packed_param = state_dict[param_name]
quant_params[key] = _unpack_quant_params(key, packed_param, unpack_func)
packed_params = list(state_dict.values())[0]

if "Conv" in m.original_name and len(state_dict) == 0:
qweight, bias = torch.ops.quantized.conv2d_unpack(packed_params)
quant_params[key] = make_conv_packed_param(key, qweight, bias, packed_params)
elif "Conv" in m.original_name:
qweight, bias = torch.ops.quantized.conv2d_unpack(packed_params)
quant_params[key] = make_qnn_param(key, qweight, bias)
elif m.original_name == "LinearPackedParams":
qweight, bias = torch.ops.quantized.linear_unpack(packed_params)
quant_params[key] = make_qnn_param(key, qweight, bias)

return quant_params

Expand All @@ -113,8 +151,12 @@ def add_quant_params_to_outputs(outputs, packed_param_map, quant_params):
qweight = relay.qnn.op.quantize(
qparam.weight_var, qparam.scale, qparam.zero_point, out_dtype="int8", axis=0
)
param_tup = (qweight, qparam.scale, qparam.zero_point, qparam.bias_var)
outputs[node_name] = param_tup
params = [qweight, qparam.scale, qparam.zero_point, qparam.bias_var]

if isinstance(quant_params[packed_param_name], ConvPackedParam):
params += [qparam.stride, qparam.padding, qparam.dilation, qparam.groups]

outputs[node_name] = params


def _get_quant_param_for_input(input_value):
Expand All @@ -129,10 +171,17 @@ def _get_quant_param_for_input(input_value):
# Indices for output scale and zp
# For example, in quantized::conv2d(%input, %1, %2, %3, %4, %5, %6, %7),
# 6th and 7th arg are output scale and zp respectively.

# PyTorch 1.6 changed qconv API
if is_version_greater_than("1.5.0"):
qconv_indices = (2, 3)
else:
qconv_indices = (6, 7)

output_quant_param_indices = {
"aten::quantize_per_tensor": (1, 2),
"quantized::conv2d": (6, 7),
"quantized::conv2d_relu": (6, 7),
"quantized::conv2d": qconv_indices,
"quantized::conv2d_relu": qconv_indices,
"quantized::linear": (2, 3),
"quantized::linear_relu": (2, 3),
"quantized::add_relu": (2, 3),
Expand Down Expand Up @@ -458,24 +507,40 @@ def _impl(inputs, _):
# inputs[7]: output_zero_point
# inputs[8]: input_scale (added manually by frontend)
# inputs[9]: input_zero_point (added manually by frontend)
weight = inputs[1][0]
weight_scale = inputs[1][1]
weight_zero_point = inputs[1][2]

output_scale = _expr.const(inputs[6])
output_zero_point = _expr.const(inputs[7])
conv_params = inputs[1]
weight = conv_params[0]
weight_scale = conv_params[1]
weight_zero_point = conv_params[2]
bias = conv_params[3]

if len(conv_params) > 4:
# Torch 1.6 or newer case
strides = conv_params[4]
padding = conv_params[5]
dilation = conv_params[6]
groups = conv_params[7]

output_scale = _expr.const(inputs[2])
output_zero_point = _expr.const(inputs[3])

assert len(inputs) == 6, "Input quant params not found in op inputs"

# These are manually added by add_input_quant_params_to_op_inputs above
# In torch, they are retrieved from QTensor data structure at runtime
input_scale = _expr.const(inputs[4])
input_zero_point = _expr.const(inputs[5])
else:
strides = inputs[2]
padding = inputs[3]
dilation = inputs[4]
groups = inputs[5]
output_scale = _expr.const(inputs[6])
output_zero_point = _expr.const(inputs[7])

assert len(inputs) == 10, "Input quant params not found in op inputs"
# These are manually added by add_input_quant_params_to_op_inputs above
# In torch, they are retrieved from QTensor data structure at runtime
input_scale = _expr.const(inputs[8])
input_zero_point = _expr.const(inputs[9])
assert len(inputs) == 10, "Input quant params not found in op inputs"

strides, padding, dilation = inputs[2], inputs[3], inputs[4]
strides = inputs[2]
padding = inputs[3]
dilation = inputs[4]
groups = inputs[5]
input_scale = _expr.const(inputs[8])
input_zero_point = _expr.const(inputs[9])

weight_shape = infer_shape(weight)
kernel_size = (weight_shape[2], weight_shape[3])
Expand Down Expand Up @@ -507,11 +572,10 @@ def _impl(inputs, _):
groups=groups,
channels=out_channels,
)
bias_var = inputs[1][3]

return _do_bias_and_requantize(
conv_out,
bias_var,
bias,
input_scale,
weight_scale,
output_scale,
Expand Down

0 comments on commit e77a951

Please sign in to comment.