Skip to content

Commit

Permalink
move version check function to a common utils
Browse files Browse the repository at this point in the history
  • Loading branch information
masa committed Oct 16, 2020
1 parent 801b882 commit 1a18815
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 20 deletions.
13 changes: 3 additions & 10 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,11 @@
from ..prelude import Prelude, StaticTensorArrayOps

from . import qnn_torch
from .pytorch_utils import is_version_greater_than

__all__ = ["from_pytorch"]


def _is_version_greater_than(ver):
import torch
from packaging import version

# Torch version > 1.4 changed upsampling API
return version.parse(torch.__version__) > version.parse(ver)


# List ADT utilities
def _infer_type_with_prelude(val, prelude):
body = _infer_type(val, prelude.mod)
Expand Down Expand Up @@ -1882,7 +1875,7 @@ def func(x):

if _is_quantized_tensor(data, prelude):
# Torch version > 1.4 changed upsampling API
if _is_version_greater_than("1.4.0"):
if is_version_greater_than("1.4.0"):
num_inputs = 7
else:
num_inputs = 5
Expand Down Expand Up @@ -2714,7 +2707,7 @@ def _run_jit_passes(graph):
""" The inline pass is necessary to unwrap prim::CallMethod """
import torch

if _is_version_greater_than("1.5.0"):
if is_version_greater_than("1.5.0"):
# This is required for torchvision detection models from 1.6 above
# It is the same as _jit_pass_inline, except that it has some special
# case behaviors for some ops such as aten::__interpolate()
Expand Down
24 changes: 24 additions & 0 deletions python/tvm/relay/frontend/pytorch_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
""" Common utilities used by PyTorch frontend """


def is_version_greater_than(ver):
import torch
from packaging import version

return version.parse(torch.__version__) > version.parse(ver)
14 changes: 4 additions & 10 deletions python/tvm/relay/frontend/qnn_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,7 @@
from tvm.relay import op as _op
from tvm.relay.frontend.common import infer_shape

from packaging import version


def _is_newer_than_1_5():
import torch

return version.parse(torch.__version__) > version.parse("1.5.0")
from .pytorch_utils import is_version_greater_than


class QNNParam:
Expand All @@ -55,7 +49,7 @@ def __init__(self, weight, bias, scale, zero_point, param_key):


class ConvPackedParam(QNNParam):
"""A placeholder for quantized conv2d op attributs
"""A placeholder for quantized conv2d op attributes
As of PyTorch 1.6, attributes of quantized conv2d ops, like
stride, padding etc are stored in ConvPackedParams objects,
together with weights and quantization parameters
Expand Down Expand Up @@ -179,7 +173,7 @@ def _get_quant_param_for_input(input_value):
# 6th and 7th arg are output scale and zp respectively.

# PyTorch 1.6 changed qconv API
if _is_newer_than_1_5():
if is_version_greater_than("1.5.0"):
qconv_indices = (2, 3)
else:
qconv_indices = (6, 7)
Expand Down Expand Up @@ -532,7 +526,7 @@ def _impl(inputs, _):
assert len(inputs) == 6, "Input quant params not found in op inputs"

# These are manually added by add_input_quant_params_to_op_inputs above
# In torch, they are retrieved from QTensor data structure at runt
# In torch, they are retrieved from QTensor data structure at runtime
input_scale = _expr.const(inputs[4])
input_zero_point = _expr.const(inputs[5])
else:
Expand Down

0 comments on commit 1a18815

Please sign in to comment.