Skip to content

Commit

Permalink
[ONNX] Support SequenceLength op (apache#13863)
Browse files Browse the repository at this point in the history
* add SequenceLength op

* add SequenceLength test

* graph fix

---------

Co-authored-by: Valery Chernov <valery.chernov@deelvin.com>
  • Loading branch information
2 people authored and fzi-peccia committed Mar 27, 2023
1 parent e5cc9ca commit 5b669b5
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
10 changes: 10 additions & 0 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -6148,6 +6148,15 @@ def _impl_v11(cls, inputs, attr, params):
return _expr.Tuple(inputs)


class SequenceLength(OnnxOpConverter):
"""Operator converter for sequence length op."""

@classmethod
def _impl_v11(cls, inputs, attr, params):
# Get length of input sequence
return _expr.const(len(inputs[0]), dtype="int64")


class SequenceInsert(OnnxOpConverter):
"""Operator converter for sequence insert op."""

Expand Down Expand Up @@ -6483,6 +6492,7 @@ def _get_convert_map(opset):
"LinearRegressor": LinearRegressor.get_converter(opset),
# Sequence operators
"SequenceConstruct": SequenceConstruct.get_converter(opset),
"SequenceLength": SequenceLength.get_converter(opset),
"SequenceInsert": SequenceInsert.get_converter(opset),
"ConcatFromSequence": ConcatFromSequence.get_converter(opset),
"SplitToSequence": SplitToSequence.get_converter(opset),
Expand Down
21 changes: 19 additions & 2 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -7760,10 +7760,16 @@ def verify_sequence_ops(tensor_shape, num_tensors, axis=0, position=0, new_axis=
"SplitToSequence", inputs=["concat_sequence"], outputs=["split_sequence"], axis=axis
)

# Test tensor extraction from sequence
at_node = helper.make_node(
"SequenceAt", inputs=["split_sequence", "position"], outputs=["output"]
)

# Test sequence length
length_node = helper.make_node(
"SequenceLength", inputs=["split_sequence"], outputs=["output_2"]
)

if new_axis is not None:
new_axis_attr = helper.make_attribute("new_axis", new_axis)
concat_node.attribute.append(new_axis_attr)
Expand All @@ -7781,9 +7787,20 @@ def verify_sequence_ops(tensor_shape, num_tensors, axis=0, position=0, new_axis=
output_shape[axis] = num_tensors + 1
else:
output_shape[axis] = (num_tensors + 1) * output_shape[axis]
graph_outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, output_shape)]
graph_outputs = [
helper.make_tensor_value_info("output", TensorProto.FLOAT, output_shape),
helper.make_tensor_value_info("output_2", TensorProto.INT64, []),
]

graph_nodes = [position_node, construct_node, insert_node, concat_node, split_node, at_node]
graph_nodes = [
position_node,
construct_node,
insert_node,
concat_node,
split_node,
at_node,
length_node,
]

graph = helper.make_graph(
graph_nodes,
Expand Down

0 comments on commit 5b669b5

Please sign in to comment.