[ONNX] Fix a bug with reshape imports when an initialized target shape is used more than once (apache#7109)

* Fix a bug with reshape imports when an initialized target shape is used more than once

* run autoformat
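
The bug in one sentence: the ONNX importer resolves a Reshape whose target shape is a graph initializer by looking the name up in the params dict, and popping that entry on first use meant any later Reshape sharing the same initializer missed the lookup. A minimal sketch of the failure mode (plain Python, not TVM code; params and ref_in are illustrative stand-ins for the importer's state):

    import numpy as np

    # Stand-in for the importer's initializer table, keyed by tensor name.
    params = {"ref_in": np.array([6, 2, 4, 3], dtype="int32")}

    # First Reshape import: pop() hands back the target shape but also
    # deletes the entry from the table.
    shape = tuple(params.pop("ref_in"))  # (6, 2, 4, 3); "ref_in" is gone now

    # Second Reshape import sharing the same initializer: the membership
    # test fails, so the constant shape is no longer folded and the shared
    # initializer is left unbound.
    print("ref_in" in params)  # False -> the dynamic-reshape branch is taken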
Matthew Brookhart authored and trevor-m committed Jan 21, 2021
1 parent 9953027 commit 2135dbc
Showing 2 changed files with 37 additions and 3 deletions.
3 changes: 1 addition & 2 deletions python/tvm/relay/frontend/onnx.py
@@ -859,8 +859,7 @@ def _impl_v1(cls, inputs, attr, params):
     @classmethod
     def _impl_v5(cls, inputs, attr, params):
         if get_name(inputs[1]) in params:
-            # pop shape out of parameters since it wont be needed later.
-            shape = tuple(params.pop(inputs[1].name_hint).asnumpy().astype("int32"))
+            shape = tuple(params[inputs[1].name_hint].asnumpy().astype("int32"))
             out = _op.reshape(inputs[0], shape)
         else:
             out = _op.reshape(*inputs)
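
Note on the fix: indexing with params[inputs[1].name_hint] instead of params.pop(...) leaves the initializer in the table, so every Reshape that references it can still fold the constant target shape. The deleted comment's assumption that the shape "wont be needed later" only held when the initializer had a single consumer.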
37 changes: 36 additions & 1 deletion tests/python/frontend/onnx/test_forward.py
@@ -73,7 +73,6 @@ def get_tvm_output(
     input_names, shape_dict = get_input_data_shape_dict(graph_def, input_data)

     mod, params = relay.frontend.from_onnx(graph_def, shape_dict, opset=opset)
-
     with tvm.transform.PassContext(opt_level=1):
         graph, lib, params = relay.build(mod, target, params=params)

@@ -234,6 +233,42 @@ def test_reshape():
         tvm.testing.assert_allclose(ref_shape, tvm_out.shape)


+@tvm.testing.uses_gpu
+def test_double_reshape():
+    in_shape = (4, 3, 3, 4)
+    ref_shape = (6, 2, 4, 3)
+
+    ref_array = np.array(ref_shape)
+    ref_node = onnx.helper.make_node(
+        "Constant",
+        inputs=[],
+        outputs=["ref_in"],
+        value=onnx.helper.make_tensor(
+            name="const_tensor",
+            data_type=onnx.TensorProto.INT32,
+            dims=ref_array.shape,
+            vals=ref_array.flatten().astype(int),
+        ),
+    )
+    reshape_node1 = helper.make_node("Reshape", ["in", "ref_in"], ["out1"])
+    reshape_node2 = helper.make_node("Reshape", ["in", "ref_in"], ["out2"])
+    add_node = helper.make_node("Add", ["out1", "out2"], ["out"])
+
+    graph = helper.make_graph(
+        [ref_node, reshape_node1, reshape_node2, add_node],
+        "reshape_test",
+        inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(ref_shape))],
+    )
+
+    model = helper.make_model(graph, producer_name="reshape_test")
+
+    for target, ctx in tvm.testing.enabled_targets():
+        x = np.random.uniform(size=in_shape).astype("float32")
+        tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, "float32")
+        tvm.testing.assert_allclose(ref_shape, tvm_out.shape)
+
+
 # TODO(mbrookhart): enable once VM supports heterogeneous execution
 # @tvm.testing.uses_gpu
 def test_expand():
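
The new test pins down exactly the shared-shape case: a single Constant target shape feeds two Reshape nodes whose outputs are then added, so the shape lookup happens twice. Assuming a TVM development checkout with pytest and the onnx package installed, it can be run on its own with:

    pytest tests/python/frontend/onnx/test_forward.py::test_double_reshape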
