From 81a40a8e3f914d544fb0427ddfa39795e94b9176 Mon Sep 17 00:00:00 2001 From: Andrew Zhao Luo Date: Tue, 28 Sep 2021 10:29:04 -0700 Subject: [PATCH] fix squeeze --- python/tvm/relay/frontend/onnx.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 6791cb0fb900..1d5c1d560653 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -37,9 +37,21 @@ from .. import random as _random from .. import ty as _ty from .. import vision as _vision -from .common import (AttrCvt, Renamer, fold_constant, get_name, get_relay_op, - gru_cell, infer_channels, infer_shape, infer_type, - infer_value, lstm_cell, new_var, unbind) +from .common import ( + AttrCvt, + Renamer, + fold_constant, + get_name, + get_relay_op, + gru_cell, + infer_channels, + infer_shape, + infer_type, + infer_value, + lstm_cell, + new_var, + unbind, +) __all__ = ["from_onnx"] @@ -1500,7 +1512,7 @@ class Squeeze(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): axis = attr.get("axes", None) - return _op.squeeze(*inputs, axis) + return _op.squeeze(inputs[0], axis) @classmethod def _impl_v13(cls, inputs, attr, params): @@ -1510,7 +1522,7 @@ def _impl_v13(cls, inputs, attr, params): if isinstance(axis, _expr.Constant): constant_axes = list(inputs[1].data.numpy()) constant_axes = list(map(int, constant_axes)) - return _op.squeeze(*inputs, constant_axes) + return _op.squeeze(inputs[0], constant_axes) rank = _op.shape_of(_op.shape_of(inputs[0], dtype), dtype) axis = _op.where(axis < _op.const(0, dtype), axis + rank, axis)