Skip to content

Commit

Permalink
[Relay] Fix TFlite frontend for unpack, stridedslice (#10333)
Browse files Browse the repository at this point in the history
We found this while converting an RNN model.

The relay tflite frontend use squeeze at converting unpack, but when the
unpack.axis=0, `None` is passed to relay.squeeze(), which would squeeze
all dimensions with length 1, causing different results from TFLite.

A possible fix might be, assign the unpack.axis as-is to relay.squeeze()

As for stridedslice, when the tflite frontend handles shrink_axis_mask,
the wrapped `begin` should be used, instead of the original one which
can be negative. It can cause errors at
https://github.com/apache/tvm/blob/d65ff6594d4d6db0062537a1d43c0504173b8e5c/include/tvm/topi/detail/strided_slice.h#L140

Related cases are also added to the python test.
  • Loading branch information
chiwwang authored Feb 22, 2022
1 parent 5a22c56 commit d8d28bf
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
8 changes: 4 additions & 4 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -1667,7 +1667,7 @@ def _transform_mask(stride_dim, ellipsis_mask):
if begin[index] < 0
else begin[index]
)
m_end[final_index] = begin[index] + 1
m_end[final_index] = m_begin[final_index] + 1
m_stride[final_index] = 1
fshape_indices.append(-2)
else:
Expand Down Expand Up @@ -2705,9 +2705,9 @@ def convert_unpack(self, op):
unpack_axis = unpack_options.Axis()

# Relay doesn't support 'unpack' operator so we use 'split' & 'squeeze' instead.
# We have to do 'squeeze' along the split axis but Relay expects
# squeeze_axis to be either None or List.
squeeze_axis = None if unpack_axis == 0 else [unpack_axis]
# We have to do 'squeeze' along the split axis.
# Relay expects squeeze_axis to be List.
squeeze_axis = [unpack_axis]

# Relay doesn't like TupleWrapper of 1 element so we isolate the case of unpacking
# a tensor by an axis with len(axis) == 1. For reference see convert_split().
Expand Down
6 changes: 6 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,9 @@ def test_forward_stridedslice():
_test_stridedslice(
(4, 4), [1, 0], [4, 4], [1, 1], "float32", shrink_axis_mask=2, quantized=quantized
)
_test_stridedslice(
(3, 4), [-1, 0], [0, 3], [1, 1], "float32", shrink_axis_mask=1, quantized=quantized
)


#######################################################################
Expand Down Expand Up @@ -3186,6 +3189,9 @@ def test_forward_unpack():
"""UNPACK"""
_test_unpack(np.array(np.random.uniform(0, 5, (3, 1)), dtype=np.int32), axis=1, num_unpacks=1)
_test_unpack(np.array(np.random.uniform(0, 5, (3, 4)), dtype=np.float32), axis=0, num_unpacks=3)
_test_unpack(
np.array(np.random.uniform(0, 5, (3, 1, 2)), dtype=np.float32), axis=0, num_unpacks=3
)
# tflite 1.13 doesn't accept negative axis
if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"):
_test_unpack(
Expand Down

0 comments on commit d8d28bf

Please sign in to comment.