From 9965650eb166651e6f13e6e6d16c5316fef16e94 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Tue, 15 Dec 2020 04:56:09 -0700 Subject: [PATCH] [ONNX] Fix a bug with reshape imports when an initialized target shape is used more than once (#7109) * Fix a bug with reshape imports when an initialized target shape is used more than once * run autoformat --- python/tvm/relay/frontend/onnx.py | 3 +- tests/python/frontend/onnx/test_forward.py | 37 +++++++++++++++++++++- 2 files changed, 37 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 23102aaa9d32..cbec32240589 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -859,8 +859,7 @@ def _impl_v1(cls, inputs, attr, params): @classmethod def _impl_v5(cls, inputs, attr, params): if get_name(inputs[1]) in params: - # pop shape out of parameters since it wont be needed later. - shape = tuple(params.pop(inputs[1].name_hint).asnumpy().astype("int32")) + shape = tuple(params[inputs[1].name_hint].asnumpy().astype("int32")) out = _op.reshape(inputs[0], shape) else: out = _op.reshape(*inputs) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index bae50c9d85f4..33dd048896b6 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -73,7 +73,6 @@ def get_tvm_output( input_names, shape_dict = get_input_data_shape_dict(graph_def, input_data) mod, params = relay.frontend.from_onnx(graph_def, shape_dict, opset=opset) - with tvm.transform.PassContext(opt_level=1): graph, lib, params = relay.build(mod, target, params=params) @@ -234,6 +233,42 @@ def test_reshape(): tvm.testing.assert_allclose(ref_shape, tvm_out.shape) +@tvm.testing.uses_gpu +def test_double_reshape(): + in_shape = (4, 3, 3, 4) + ref_shape = (6, 2, 4, 3) + + ref_array = np.array(ref_shape) + ref_node = onnx.helper.make_node( + "Constant", + inputs=[], + outputs=["ref_in"], + value=onnx.helper.make_tensor( + name="const_tensor", + data_type=onnx.TensorProto.INT32, + dims=ref_array.shape, + vals=ref_array.flatten().astype(int), + ), + ) + reshape_node1 = helper.make_node("Reshape", ["in", "ref_in"], ["out1"]) + reshape_node2 = helper.make_node("Reshape", ["in", "ref_in"], ["out2"]) + add_node = helper.make_node("Add", ["out1", "out2"], ["out"]) + + graph = helper.make_graph( + [ref_node, reshape_node1, reshape_node2, add_node], + "reshape_test", + inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(ref_shape))], + ) + + model = helper.make_model(graph, producer_name="reshape_test") + + for target, ctx in tvm.testing.enabled_targets(): + x = np.random.uniform(size=in_shape).astype("int32") + tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, "float32") + tvm.testing.assert_allclose(ref_shape, tvm_out.shape) + + # TODO(mbrookhart): enable once VM supports heterogenous execution # @tvm.testing.uses_gpu def test_expand():