diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 8e626f52d528..c8fbd5a5c10c 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -42,18 +42,11 @@
 from ..prelude import Prelude, StaticTensorArrayOps
 from . import qnn_torch
+from .pytorch_utils import is_version_greater_than
 
 __all__ = ["from_pytorch"]
 
 
-def _is_version_greater_than(ver):
-    import torch
-    from packaging import version
-
-    # Torch version > 1.4 changed upsampling API
-    return version.parse(torch.__version__) > version.parse(ver)
-
-
 # List ADT utilities
 def _infer_type_with_prelude(val, prelude):
     body = _infer_type(val, prelude.mod)
@@ -1882,7 +1875,7 @@ def func(x):
 
         if _is_quantized_tensor(data, prelude):
             # Torch version > 1.4 changed upsampling API
-            if _is_version_greater_than("1.4.0"):
+            if is_version_greater_than("1.4.0"):
                 num_inputs = 7
             else:
                 num_inputs = 5
@@ -2714,7 +2707,7 @@ def _run_jit_passes(graph):
     """ The inline pass is necessary to unwrap prim::CallMethod """
     import torch
 
-    if _is_version_greater_than("1.5.0"):
+    if is_version_greater_than("1.5.0"):
        # This is required for torchvision detection models from 1.6 above
        # It is the same as _jit_pass_inline, except that it has some special
        # case behaviors for some ops such as aten::__interpolate()
@@ -3069,8 +3062,6 @@ def convert_params(graph, state_dict):
             full_attr_node_name = _get_output_name(getattrs[-1])
 
             if full_attr.endswith("_packed_params"):  # for quantized models
-                err_msg = "parameter %s not found in state dict" % full_attr
-                assert full_attr in state_dict, err_msg
                 packed_param_map[full_attr_node_name] = full_attr
             elif full_attr in state_dict:
                 if full_attr in vars_by_name:
diff --git a/python/tvm/relay/frontend/pytorch_utils.py b/python/tvm/relay/frontend/pytorch_utils.py
new file mode 100644
index 000000000000..e0c8f8da7d62
--- /dev/null
+++ b/python/tvm/relay/frontend/pytorch_utils.py
@@ -0,0 +1,25 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=import-outside-toplevel
+""" Common utilities used by PyTorch frontend """
+
+
+def is_version_greater_than(ver):
+    import torch
+    from packaging import version
+
+    return version.parse(torch.__version__) > version.parse(ver)
diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py
index 121307385d7e..ca67391cebc7 100644
--- a/python/tvm/relay/frontend/qnn_torch.py
+++ b/python/tvm/relay/frontend/qnn_torch.py
@@ -26,6 +26,8 @@
 from tvm.relay import op as _op
 from tvm.relay.frontend.common import infer_shape
 
+from .pytorch_utils import is_version_greater_than
+
 
 class QNNParam:
     """ A placeholder for weight quantization parameters """
@@ -46,59 +48,95 @@ def __init__(self, weight, bias, scale, zero_point, param_key):
         self.zero_point = _expr.const(zero_point, dtype="int32")
 
 
-def _unpack_quant_params(param_name, packed_params, unpack_func):
-    # Torch stores quantized params in a custom packed format,
-    # need to unpack and retrieve them as numpy arrays
-    qweight, bias = unpack_func(packed_params)
-    weight_np = qweight.dequantize().numpy()
+class ConvPackedParam(QNNParam):
+    """A placeholder for quantized conv2d op attributes
+    As of PyTorch 1.6, attributes of quantized conv2d ops, like
+    stride, padding etc are stored in ConvPackedParams objects,
+    together with weights and quantization parameters
+    """
+
+    def __init__(
+        self, weight_np, bias, scale, zero_point, param_name, stride, padding, dilation, groups
+    ):
+        super().__init__(weight_np, bias, scale, zero_point, param_name)
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+        self.groups = groups
+
 
+def _get_quant_params(qweight):
     import torch
 
+    weight_np = qweight.dequantize().numpy()
+
     if qweight.qscheme() == torch.per_tensor_affine:
-        param = QNNParam(
-            weight_np, bias, qweight.q_scale(), int(qweight.q_zero_point()), param_name
-        )
-    else:
-        scales = qweight.q_per_channel_scales().numpy()
-        zero_points = qweight.q_per_channel_zero_points().numpy()
-        # This is an assumption posed by QNN
-        msg = "The values of zero points should be all zero for per channel"
-        assert np.all(zero_points == 0), msg
-        param = QNNParam(weight_np, bias, scales, 0, param_name)
+        return weight_np, qweight.q_scale(), int(qweight.q_zero_point())
+
+    scales = qweight.q_per_channel_scales().numpy()
+    zero_points = qweight.q_per_channel_zero_points().numpy()
+    # This is an assumption posed by QNN
+    msg = "The values of zero points should be all zero for per channel"
+    assert np.all(zero_points == 0), msg
+    return weight_np, scales, 0
+
 
-    return param
+def make_qnn_param(param_name, qweight, bias):
+    weight_np, scale, zero_point = _get_quant_params(qweight)
+    return QNNParam(weight_np, bias, scale, zero_point, param_name)
+
+
+def make_conv_packed_param(param_name, qweight, bias, packed_params):
+    weight_np, scale, zero_point = _get_quant_params(qweight)
+    stride = packed_params.stride()
+    padding = packed_params.padding()
+    dilation = packed_params.dilation()
+    groups = packed_params.groups()
+    return ConvPackedParam(
+        weight_np, bias, scale, zero_point, param_name, stride, padding, dilation, groups
+    )
 
 
 def get_weight_quant_params(script_module):
     """ Retrive and unpack weight parameters from quantized modules """
-    conv_packed_params = []
-    linear_packed_params = []
-
     import torch
 
-    # conv and linear requires different unpacking function
-    # extract all conv and linear parameters separately to distinguish them
-    for name, m in script_module.named_modules():
-        if isinstance(m, torch.jit.RecursiveScriptModule):
-            if "Conv" in m.original_name:
-                conv_packed_params.append((name, m.state_dict()))
-            elif m.original_name == "LinearPackedParams":
-                linear_packed_params.append((name, m.state_dict()))
+    param_name = "_packed_params"
+    quant_params = {}
+
+    def filter_func(named_module):
+        m = named_module[1]
+        return isinstance(m, torch.jit.RecursiveScriptModule) and (
+            ("Conv" in m.original_name) or (m.original_name == "LinearPackedParams")
+        )
 
-    pairs = [
-        (torch.ops.quantized.conv2d_unpack, conv_packed_params),
-        (torch.ops.quantized.linear_unpack, linear_packed_params),
-    ]
+    for name, m in filter(filter_func, script_module.named_modules()):
+        key = name + "." + param_name
+        state_dict = m.state_dict()
 
-    quant_params = {}
-    param_name = "_packed_params"
-    for unpack_func, params in pairs:
-        for name, state_dict in params:
+        if len(state_dict) == 0 and not hasattr(m, param_name):
+            # for v1.6 and above
+            # This case seems to happen if a model is serialized
+            # and loaded back
+            # This module can be safely ignored
+            continue
+
+        if len(state_dict) == 0 and hasattr(m, param_name):
+            # for v1.6 and above
+            packed_params = m._packed_params
+        else:
             assert len(state_dict) == 1
-            assert param_name in state_dict
-            key = name + "." + param_name
-            packed_param = state_dict[param_name]
-            quant_params[key] = _unpack_quant_params(key, packed_param, unpack_func)
+            packed_params = list(state_dict.values())[0]
+
+        if "Conv" in m.original_name and len(state_dict) == 0:
+            qweight, bias = torch.ops.quantized.conv2d_unpack(packed_params)
+            quant_params[key] = make_conv_packed_param(key, qweight, bias, packed_params)
+        elif "Conv" in m.original_name:
+            qweight, bias = torch.ops.quantized.conv2d_unpack(packed_params)
+            quant_params[key] = make_qnn_param(key, qweight, bias)
+        elif m.original_name == "LinearPackedParams":
+            qweight, bias = torch.ops.quantized.linear_unpack(packed_params)
+            quant_params[key] = make_qnn_param(key, qweight, bias)
 
     return quant_params
@@ -113,8 +151,12 @@ def add_quant_params_to_outputs(outputs, packed_param_map, quant_params):
         qweight = relay.qnn.op.quantize(
             qparam.weight_var, qparam.scale, qparam.zero_point, out_dtype="int8", axis=0
         )
-        param_tup = (qweight, qparam.scale, qparam.zero_point, qparam.bias_var)
-        outputs[node_name] = param_tup
+        params = [qweight, qparam.scale, qparam.zero_point, qparam.bias_var]
+
+        if isinstance(quant_params[packed_param_name], ConvPackedParam):
+            params += [qparam.stride, qparam.padding, qparam.dilation, qparam.groups]
+
+        outputs[node_name] = params
 
 
 def _get_quant_param_for_input(input_value):
@@ -129,10 +171,17 @@ def _get_quant_param_for_input(input_value):
     # Indices for output scale and zp
     # For example, in quantized::conv2d(%input, %1, %2, %3, %4, %5, %6, %7),
     # 6th and 7th arg are output scale and zp respectively.
+
+    # PyTorch 1.6 changed qconv API
+    if is_version_greater_than("1.5.0"):
+        qconv_indices = (2, 3)
+    else:
+        qconv_indices = (6, 7)
+
     output_quant_param_indices = {
         "aten::quantize_per_tensor": (1, 2),
-        "quantized::conv2d": (6, 7),
-        "quantized::conv2d_relu": (6, 7),
+        "quantized::conv2d": qconv_indices,
+        "quantized::conv2d_relu": qconv_indices,
         "quantized::linear": (2, 3),
         "quantized::linear_relu": (2, 3),
         "quantized::add_relu": (2, 3),
@@ -458,24 +507,40 @@ def _impl(inputs, _):
         # inputs[7]: output_zero_point
         # inputs[8]: input_scale (added manually by frontend)
         # inputs[9]: input_zero_point (added manually by frontend)
-        weight = inputs[1][0]
-        weight_scale = inputs[1][1]
-        weight_zero_point = inputs[1][2]
-
-        output_scale = _expr.const(inputs[6])
-        output_zero_point = _expr.const(inputs[7])
+        conv_params = inputs[1]
+        weight = conv_params[0]
+        weight_scale = conv_params[1]
+        weight_zero_point = conv_params[2]
+        bias = conv_params[3]
+
+        if len(conv_params) > 4:
+            # Torch 1.6 or newer case
+            strides = conv_params[4]
+            padding = conv_params[5]
+            dilation = conv_params[6]
+            groups = conv_params[7]
+
+            output_scale = _expr.const(inputs[2])
+            output_zero_point = _expr.const(inputs[3])
+
+            assert len(inputs) == 6, "Input quant params not found in op inputs"
+
+            # These are manually added by add_input_quant_params_to_op_inputs above
+            # In torch, they are retrieved from QTensor data structure at runtime
+            input_scale = _expr.const(inputs[4])
+            input_zero_point = _expr.const(inputs[5])
+        else:
+            strides = inputs[2]
+            padding = inputs[3]
+            dilation = inputs[4]
+            groups = inputs[5]
+            output_scale = _expr.const(inputs[6])
+            output_zero_point = _expr.const(inputs[7])
 
-        assert len(inputs) == 10, "Input quant params not found in op inputs"
-        # These are manually added by add_input_quant_params_to_op_inputs above
-        # In torch, they are retrieved from QTensor data structure at runtime
-        input_scale = _expr.const(inputs[8])
-        input_zero_point = _expr.const(inputs[9])
+            assert len(inputs) == 10, "Input quant params not found in op inputs"
 
-        strides, padding, dilation = inputs[2], inputs[3], inputs[4]
-        strides = inputs[2]
-        padding = inputs[3]
-        dilation = inputs[4]
-        groups = inputs[5]
+            input_scale = _expr.const(inputs[8])
+            input_zero_point = _expr.const(inputs[9])
 
         weight_shape = infer_shape(weight)
         kernel_size = (weight_shape[2], weight_shape[3])
@@ -507,11 +572,10 @@ def _impl(inputs, _):
             groups=groups,
             channels=out_channels,
         )
-        bias_var = inputs[1][3]
 
         return _do_bias_and_requantize(
             conv_out,
-            bias_var,
+            bias,
             input_scale,
             weight_scale,
             output_scale,
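
A minimal sketch (not part of the patch, helper name hypothetical) of the two quantized::conv2d input layouts this change assumes. Up to PyTorch 1.5 the op carries stride/padding/dilation/groups as separate inputs, so output scale and zero point sit at indices 6 and 7 and the op has 10 inputs once the frontend appends the input quant params; from 1.6 those attributes travel inside the packed parameter object, the op shrinks to 6 inputs, and output scale and zero point move to indices 2 and 3. The is_version_greater_than helper added in pytorch_utils.py selects between the layouts:

from tvm.relay.frontend.pytorch_utils import is_version_greater_than


def qconv_output_quant_param_indices():
    # <= 1.5: quantized::conv2d(%input, %packed_weight, %stride, %padding,
    #                           %dilation, %groups, %out_scale, %out_zero_point)
    # >= 1.6: quantized::conv2d(%input, %packed_params, %out_scale, %out_zero_point)
    if is_version_greater_than("1.5.0"):
        return (2, 3)  # out_scale, out_zero_point
    return (6, 7)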
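A second sketch, mirroring make_conv_packed_param above: with PyTorch 1.6+ the packed object found on a scripted quantized conv submodule (m._packed_params in get_weight_quant_params) yields both the quantized weight/bias and the conv attributes. The standalone helper name below is hypothetical; the torch calls are the ones the patch itself uses:

import torch


def unpack_conv_packed_params(packed_params):
    # Same unpack op the patch calls to recover the quantized weight and bias
    qweight, bias = torch.ops.quantized.conv2d_unpack(packed_params)
    weight_np = qweight.dequantize().numpy()
    # From 1.6 onwards the conv attributes are stored alongside the weights
    attrs = (
        packed_params.stride(),
        packed_params.padding(),
        packed_params.dilation(),
        packed_params.groups(),
    )
    return weight_np, bias, attrs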