Skip to content

Commit

Permalink
fix squeeze
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrew Zhao Luo committed Sep 28, 2021
1 parent 4447044 commit 81a40a8
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,21 @@
from .. import random as _random
from .. import ty as _ty
from .. import vision as _vision
from .common import (AttrCvt, Renamer, fold_constant, get_name, get_relay_op,
gru_cell, infer_channels, infer_shape, infer_type,
infer_value, lstm_cell, new_var, unbind)
from .common import (
AttrCvt,
Renamer,
fold_constant,
get_name,
get_relay_op,
gru_cell,
infer_channels,
infer_shape,
infer_type,
infer_value,
lstm_cell,
new_var,
unbind,
)

__all__ = ["from_onnx"]

Expand Down Expand Up @@ -1500,7 +1512,7 @@ class Squeeze(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
axis = attr.get("axes", None)
return _op.squeeze(*inputs, axis)
return _op.squeeze(inputs[0], axis)

@classmethod
def _impl_v13(cls, inputs, attr, params):
Expand All @@ -1510,7 +1522,7 @@ def _impl_v13(cls, inputs, attr, params):
if isinstance(axis, _expr.Constant):
constant_axes = list(inputs[1].data.numpy())
constant_axes = list(map(int, constant_axes))
return _op.squeeze(*inputs, constant_axes)
return _op.squeeze(inputs[0], constant_axes)

rank = _op.shape_of(_op.shape_of(inputs[0], dtype), dtype)
axis = _op.where(axis < _op.const(0, dtype), axis + rank, axis)
Expand Down

0 comments on commit 81a40a8

Please sign in to comment.