add support for 1.6 quantized models

masa committed Sep 30, 2020
1 parent 0535fd1 commit 880b3a4
Showing 2 changed files with 130 additions and 63 deletions.
2 changes: 0 additions & 2 deletions python/tvm/relay/frontend/pytorch.py
@@ -3025,8 +3025,6 @@ def convert_params(graph, state_dict):
        full_attr_node_name = _get_output_name(getattrs[-1])

        if full_attr.endswith("_packed_params"):  # for quantized models
            err_msg = "parameter %s not found in state dict" % full_attr
            assert full_attr in state_dict, err_msg
            packed_param_map[full_attr_node_name] = full_attr
        elif full_attr in state_dict:
            if full_attr in vars_by_name:
191 changes: 130 additions & 61 deletions python/tvm/relay/frontend/qnn_torch.py
@@ -26,6 +26,14 @@
from tvm.relay import op as _op
from tvm.relay.frontend.common import infer_shape

from packaging import version


def _is_newer_than_1_5():
    import torch

    return version.parse(torch.__version__) > version.parse("1.5.0")
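A quick illustration of how this gate behaves (not part of the diff; the
_check helper is hypothetical, pulled out so the comparison can be exercised
without importing torch, reusing the packaging import above):

def _check(torch_version_str):
    # same comparison as _is_newer_than_1_5, parameterized for testing
    return version.parse(torch_version_str) > version.parse("1.5.0")

assert _check("1.6.0")
assert not _check("1.5.0")
assert _check("1.5.1")  # any release after 1.5.0 counts as "newer"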


class QNNParam:
    """ A placeholder for weight quantization parameters """
@@ -46,59 +54,94 @@ def __init__(self, weight, bias, scale, zero_point, param_key):
        self.zero_point = _expr.const(zero_point, dtype="int32")


def _unpack_quant_params(param_name, packed_params, unpack_func):
    # Torch stores quantized params in a custom packed format,
    # need to unpack and retrieve them as numpy arrays
    import torch

    qweight, bias = unpack_func(packed_params)
    weight_np = qweight.dequantize().numpy()

    if qweight.qscheme() == torch.per_tensor_affine:
        param = QNNParam(
            weight_np, bias, qweight.q_scale(), int(qweight.q_zero_point()), param_name
        )
    else:
        scales = qweight.q_per_channel_scales().numpy()
        zero_points = qweight.q_per_channel_zero_points().numpy()
        # This is an assumption posed by QNN
        msg = "The values of zero points should be all zero for per channel"
        assert np.all(zero_points == 0), msg
        param = QNNParam(weight_np, bias, scales, 0, param_name)

    return param
class ConvPackedParam(QNNParam):
    """A placeholder for quantized conv2d op attributes
    As of PyTorch 1.6, attributes of quantized conv2d ops, like
    stride, padding, etc. are stored in ConvPackedParams objects,
    together with weights and quantization parameters
    """

    def __init__(
        self, weight_np, bias, scale, zero_point, param_name, stride, padding, dilation, groups
    ):
        super().__init__(weight_np, bias, scale, zero_point, param_name)
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
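For context, a minimal sketch of the PyTorch 1.6 behavior this class mirrors,
using stock quantized-op calls (the conv2d_prepack argument order here is an
assumption; only conv2d_unpack appears in this file):

import torch

# quantize a conv weight, then prepack it the way quantized conv ops do
w = torch.quantize_per_tensor(torch.randn(8, 3, 3, 3), scale=0.1, zero_point=0, dtype=torch.qint8)
packed = torch.ops.quantized.conv2d_prepack(w, None, [2, 2], [1, 1], [1, 1], 1)

# ConvPackedParams carries the conv attributes alongside the weight
print(packed.stride(), packed.padding(), packed.dilation(), packed.groups())

# and the weight/bias round-trip back out, as in get_weight_quant_params below
qweight, bias = torch.ops.quantized.conv2d_unpack(packed)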


def get_quant_params(qweight):
    import torch

    weight_np = qweight.dequantize().numpy()

    if qweight.qscheme() == torch.per_tensor_affine:
        return weight_np, qweight.q_scale(), int(qweight.q_zero_point())

    scales = qweight.q_per_channel_scales().numpy()
    zero_points = qweight.q_per_channel_zero_points().numpy()
    # This is an assumption posed by QNN
    msg = "The values of zero points should be all zero for per channel"
    assert np.all(zero_points == 0), msg
    return weight_np, scales, 0
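For reference, the two qscheme branches map onto stock PyTorch quantized-tensor
APIs; a minimal sketch:

import numpy as np
import torch

w = torch.randn(4, 4)

# per-tensor: scalar scale / zero point
qw = torch.quantize_per_tensor(w, scale=0.05, zero_point=0, dtype=torch.qint8)
assert qw.qscheme() == torch.per_tensor_affine
print(qw.q_scale(), int(qw.q_zero_point()))  # 0.05 0

# per-channel: one scale / zero point per output channel (axis 0),
# and QNN additionally requires the zero points to all be zero
scales = torch.tensor([0.1, 0.2, 0.3, 0.4], dtype=torch.float64)
zps = torch.zeros(4, dtype=torch.int64)
qw_pc = torch.quantize_per_channel(w, scales, zps, axis=0, dtype=torch.qint8)
print(qw_pc.q_per_channel_scales().numpy())  # [0.1 0.2 0.3 0.4]
assert np.all(qw_pc.q_per_channel_zero_points().numpy() == 0)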

def make_qnn_param(param_name, qweight, bias):
    weight_np, scale, zero_point = get_quant_params(qweight)
    return QNNParam(weight_np, bias, scale, zero_point, param_name)


def make_conv_packed_param(param_name, qweight, bias, packed_params):
    weight_np, scale, zero_point = get_quant_params(qweight)
    stride = packed_params.stride()
    padding = packed_params.padding()
    dilation = packed_params.dilation()
    groups = packed_params.groups()
    return ConvPackedParam(
        weight_np, bias, scale, zero_point, param_name, stride, padding, dilation, groups
    )


def get_weight_quant_params(script_module):
    """ Retrieve and unpack weight parameters from quantized modules """
    conv_packed_params = []
    linear_packed_params = []

    import torch

    # conv and linear require different unpacking functions
    # extract all conv and linear parameters separately to distinguish them
    for name, m in script_module.named_modules():
        if isinstance(m, torch.jit.RecursiveScriptModule):
            if "Conv" in m.original_name:
                conv_packed_params.append((name, m.state_dict()))
            elif m.original_name == "LinearPackedParams":
                linear_packed_params.append((name, m.state_dict()))

    pairs = [
        (torch.ops.quantized.conv2d_unpack, conv_packed_params),
        (torch.ops.quantized.linear_unpack, linear_packed_params),
    ]

    quant_params = {}
    param_name = "_packed_params"
    for unpack_func, params in pairs:
        for name, state_dict in params:
            assert len(state_dict) == 1
            assert param_name in state_dict
            key = name + "." + param_name
            packed_param = state_dict[param_name]
            quant_params[key] = _unpack_quant_params(key, packed_param, unpack_func)

    return quant_params


def get_weight_quant_params(script_module):
    """ Retrieve and unpack weight parameters from quantized modules """
    import torch

    param_name = "_packed_params"
    quant_params = {}

    def filter_func(named_module):
        m = named_module[1]
        return isinstance(m, torch.jit.RecursiveScriptModule) and (
            ("Conv" in m.original_name) or (m.original_name == "LinearPackedParams")
        )

    for name, m in filter(filter_func, script_module.named_modules()):
        key = name + "." + param_name
        state_dict = m.state_dict()

        if len(state_dict) == 0 and not hasattr(m, param_name):
            # for v1.6 and above
            # This case seems to happen if a model is serialized
            # and loaded back
            # This module can be safely ignored
            continue
        elif len(state_dict) == 0 and hasattr(m, param_name):
            # for v1.6 and above
            packed_params = m._packed_params
        else:
            assert len(state_dict) == 1
            packed_params = list(state_dict.values())[0]

        if "Conv" in m.original_name and len(state_dict) == 0:
            qweight, bias = torch.ops.quantized.conv2d_unpack(packed_params)
            quant_params[key] = make_conv_packed_param(key, qweight, bias, packed_params)
        elif "Conv" in m.original_name:
            qweight, bias = torch.ops.quantized.conv2d_unpack(packed_params)
            quant_params[key] = make_qnn_param(key, qweight, bias)
        elif m.original_name == "LinearPackedParams":
            qweight, bias = torch.ops.quantized.linear_unpack(packed_params)
            quant_params[key] = make_qnn_param(key, qweight, bias)

    return quant_params
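A usage sketch for the function above, under stated assumptions: TinyNet is a
made-up toy module, and the "fbgemm" qconfig assumes an x86 build of torch
with eager-mode quantization support:

import torch

class TinyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(3, 8, 3)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        return self.dequant(self.conv(self.quant(x)))

net = TinyNet().eval()
net.qconfig = torch.quantization.get_default_qconfig("fbgemm")
torch.quantization.prepare(net, inplace=True)
net(torch.randn(1, 3, 32, 32))  # one dummy calibration batch
torch.quantization.convert(net, inplace=True)

script_module = torch.jit.trace(net, torch.randn(1, 3, 32, 32))
for key, qparam in get_weight_quant_params(script_module).items():
    print(key, type(qparam).__name__)  # ConvPackedParam on torch >= 1.6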

@@ -113,8 +156,12 @@ def add_quant_params_to_outputs(outputs, packed_param_map, quant_params):
        qweight = relay.qnn.op.quantize(
            qparam.weight_var, qparam.scale, qparam.zero_point, out_dtype="int8", axis=0
        )
        param_tup = (qweight, qparam.scale, qparam.zero_point, qparam.bias_var)
        outputs[node_name] = param_tup

        params = [qweight, qparam.scale, qparam.zero_point, qparam.bias_var]

        if isinstance(quant_params[packed_param_name], ConvPackedParam):
            params += [qparam.stride, qparam.padding, qparam.dilation, qparam.groups]

        outputs[node_name] = params
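        # Resulting layout of outputs[node_name] (illustrative summary):
        #   QNNParam:        [qweight, scale, zero_point, bias]
        #   ConvPackedParam: [qweight, scale, zero_point, bias,
        #                     stride, padding, dilation, groups]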


def _get_quant_param_for_input(input_value):
@@ -129,10 +176,17 @@ def _get_quant_param_for_input(input_value):
    # Indices for output scale and zp
    # For example, in quantized::conv2d(%input, %1, %2, %3, %4, %5, %6, %7),
    # 6th and 7th arg are output scale and zp respectively.

    # PyTorch 1.6 changed qconv API
    if _is_newer_than_1_5():
        qconv_indices = (2, 3)
    else:
        qconv_indices = (6, 7)
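    # For reference (reconstructed from the index tables in this function;
    # treat the exact argument names as illustrative):
    #   torch <= 1.5: quantized::conv2d(%input, %packed_weight, %stride,
    #                     %padding, %dilation, %groups, %out_scale, %out_zp)
    #   torch >= 1.6: quantized::conv2d(%input, %packed_params, %out_scale, %out_zp)
    # i.e. output scale/zp move from positions (6, 7) to (2, 3) because stride,
    # padding, dilation and groups now live inside the packed params.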

    output_quant_param_indices = {
        "aten::quantize_per_tensor": (1, 2),
        "quantized::conv2d": qconv_indices,
        "quantized::conv2d_relu": qconv_indices,
        "quantized::linear": (2, 3),
        "quantized::linear_relu": (2, 3),
        "quantized::add_relu": (2, 3),
@@ -458,24 +512,40 @@ def _impl(inputs, _):
        # inputs[7]: output_zero_point
        # inputs[8]: input_scale (added manually by frontend)
        # inputs[9]: input_zero_point (added manually by frontend)
        weight = inputs[1][0]
        weight_scale = inputs[1][1]
        weight_zero_point = inputs[1][2]

        output_scale = _expr.const(inputs[6])
        output_zero_point = _expr.const(inputs[7])

        assert len(inputs) == 10, "Input quant params not found in op inputs"
        # These are manually added by add_input_quant_params_to_op_inputs above
        # In torch, they are retrieved from QTensor data structure at runtime
        input_scale = _expr.const(inputs[8])
        input_zero_point = _expr.const(inputs[9])

        strides, padding, dilation = inputs[2], inputs[3], inputs[4]
        groups = inputs[5]

        conv_params = inputs[1]
        weight = conv_params[0]
        weight_scale = conv_params[1]
        weight_zero_point = conv_params[2]
        bias = conv_params[3]

        if len(conv_params) > 4:
            # Torch 1.6 or newer case
            strides = conv_params[4]
            padding = conv_params[5]
            dilation = conv_params[6]
            groups = conv_params[7]

            output_scale = _expr.const(inputs[2])
            output_zero_point = _expr.const(inputs[3])

            assert len(inputs) == 6, "Input quant params not found in op inputs"

            # These are manually added by add_input_quant_params_to_op_inputs above
            # In torch, they are retrieved from QTensor data structure at runtime
            input_scale = _expr.const(inputs[4])
            input_zero_point = _expr.const(inputs[5])
        else:
            strides = inputs[2]
            padding = inputs[3]
            dilation = inputs[4]
            groups = inputs[5]

            output_scale = _expr.const(inputs[6])
            output_zero_point = _expr.const(inputs[7])

            assert len(inputs) == 10, "Input quant params not found in op inputs"
            # These are manually added by add_input_quant_params_to_op_inputs above
            # In torch, they are retrieved from QTensor data structure at runtime
            input_scale = _expr.const(inputs[8])
            input_zero_point = _expr.const(inputs[9])

        weight_shape = infer_shape(weight)
        kernel_size = (weight_shape[2], weight_shape[3])
@@ -507,11 +577,10 @@ def _impl(inputs, _):
            groups=groups,
            channels=out_channels,
        )
        bias_var = inputs[1][3]

        return _do_bias_and_requantize(
            conv_out,
            bias,
            input_scale,
            weight_scale,
            output_scale,