Skip to content

Commit

Permalink
graph fix
Browse files Browse the repository at this point in the history
  • Loading branch information
vvchernov committed Jan 30, 2023
1 parent 8cfc1b9 commit 7106a48
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -7766,8 +7766,8 @@ def verify_sequence_ops(tensor_shape, num_tensors, axis=0, position=0, new_axis=
)

# Test sequence length
split_node = helper.make_node(
"SequenceLength", inputs=["concat_sequence"], outputs=["output_2"]
length_node = helper.make_node(
"SequenceLength", inputs=["split_sequence"], outputs=["output_2"]
)

if new_axis is not None:
Expand All @@ -7787,10 +7787,20 @@ def verify_sequence_ops(tensor_shape, num_tensors, axis=0, position=0, new_axis=
output_shape[axis] = num_tensors + 1
else:
output_shape[axis] = (num_tensors + 1) * output_shape[axis]
graph_outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, output_shape),
helper.make_tensor_value_info("output_2", TensorProto.INT, ())]
graph_outputs = [
helper.make_tensor_value_info("output", TensorProto.FLOAT, output_shape),
helper.make_tensor_value_info("output_2", TensorProto.INT64, []),
]

graph_nodes = [position_node, construct_node, insert_node, concat_node, split_node, at_node]
graph_nodes = [
position_node,
construct_node,
insert_node,
concat_node,
split_node,
at_node,
length_node,
]

graph = helper.make_graph(
graph_nodes,
Expand Down

0 comments on commit 7106a48

Please sign in to comment.