From 738348db11de025525ec1d30662de06ebb660e41 Mon Sep 17 00:00:00 2001 From: Kanghwan Jang Date: Thu, 1 Jul 2021 14:36:40 -0700 Subject: [PATCH 1/3] [ONNX] Wrap 'If' if it has multiple outputs Without this wrapper, an assertion in from_onnx() will fail with the error message showing ""Number of output mismatch" --- python/tvm/relay/frontend/onnx.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 7135fccdf43b..b5cb1e240cd3 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2846,7 +2846,10 @@ def _impl_v1(cls, inputs, attr, params): graph_scope._nodes.update({var.name_hint: var}) # Now we can construct the relay if statement and return. - return _expr.If(cond, then_expr, else_expr) + ret = _expr.If(cond, then_expr, else_expr) + if len(then_branch.output) > 1: + ret = _expr.TupleWrapper(ret, len(then_branch.output)) + return ret class NonMaxSuppression(OnnxOpConverter): From 08788d5365d0664e96d7405cd05df369865e54ec Mon Sep 17 00:00:00 2001 From: Kanghwan Jang Date: Fri, 2 Jul 2021 16:27:11 -0700 Subject: [PATCH 2/3] [ONNX] Test If nodes with multiple output tensors --- tests/python/frontend/onnx/test_forward.py | 56 ++++++++++++++-------- 1 file changed, 36 insertions(+), 20 deletions(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 52c3346e5807..99b4194f4b31 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4065,29 +4065,43 @@ def test_loop(): verify_tensor_loop() -def verify_if(cond_array): +def verify_if(cond_array, num_outputs): # Given a bool scalar input cond. # return constant tensor x if cond is True, otherwise return constant tensor y. - then_out = onnx.helper.make_tensor_value_info("then_out", onnx.TensorProto.FLOAT, [5]) - else_out = onnx.helper.make_tensor_value_info("else_out", onnx.TensorProto.FLOAT, [5]) - x = np.array([1, 2, 3, 4, 5]).astype(np.float32) - y = np.array([5, 4, 3, 2, 1]).astype(np.float32) + def append_constant_nodes(nodes, outputs, expected, name): + outputs.append(onnx.helper.make_tensor_value_info(name, onnx.TensorProto.FLOAT, [5])) - then_const_node = onnx.helper.make_node( - "Constant", inputs=[], outputs=["then_out"], value=numpy_helper.from_array(x) - ) + expected.append(np.random.randn(5).astype("float32")) - else_const_node = onnx.helper.make_node( - "Constant", inputs=[], outputs=["else_out"], value=numpy_helper.from_array(y) - ) + nodes.append(onnx.helper.make_node( + "Constant", inputs=[], outputs=[name], value=numpy_helper.from_array(expected[-1]) + )) + + if_outputs = [] + graph_outputs = [] + + then_nodes, then_outs, then_expected = [], [], [] + else_nodes, else_outs, else_expected = [], [], [] + + for i in range(num_outputs): + append_constant_nodes(then_nodes, then_outs, then_expected, "then_out{}".format(i)) + append_constant_nodes(else_nodes, else_outs, else_expected, "else_out{}".format(i)) - then_body = onnx.helper.make_graph([then_const_node], "then_body", [], [then_out]) + if_outputs.append("res{}".format(i)) + graph_outputs.append( + onnx.helper.make_tensor_value_info("res{}".format(i), onnx.TensorProto.FLOAT, [5]), + ) - else_body = onnx.helper.make_graph([else_const_node], "else_body", [], [else_out]) + then_body = onnx.helper.make_graph(then_nodes, "then_body", [], then_outs) + else_body = onnx.helper.make_graph(else_nodes, "else_body", [], else_outs) if_node = onnx.helper.make_node( - "If", inputs=["cond"], outputs=["res"], then_branch=then_body, else_branch=else_body + "If", + inputs=["cond"], + outputs=if_outputs, + then_branch=then_body, + else_branch=else_body ) if_graph = onnx.helper.make_graph( @@ -4096,9 +4110,7 @@ def verify_if(cond_array): inputs=[ onnx.helper.make_tensor_value_info("cond", onnx.TensorProto.BOOL, []), ], - outputs=[ - onnx.helper.make_tensor_value_info("res", onnx.TensorProto.FLOAT, [5]), - ], + outputs=graph_outputs, ) if_model = onnx.helper.make_model(if_graph) @@ -4106,12 +4118,14 @@ def verify_if(cond_array): cond = np.array([1]).astype("bool") else: cond = np.array(1).astype("bool") - correct_out = x if cond else y + correct_out = then_expected if cond else else_expected # TODO(jwfromm): Onnxruntime 1.0.0 is buggy with If statements. Replace this with # verify_with_ort once we update versions. for target, dev in tvm.testing.enabled_targets(): tvm_out = get_tvm_output_with_vm(if_model, [cond], target, dev, freeze_params=True) + if not isinstance(tvm_out, list): + tvm_out = [tvm_out] for i in range(len(tvm_out)): tvm.testing.assert_allclose(correct_out[i], tvm_out[i], rtol=1e-05, atol=1e-05) @@ -4119,8 +4133,10 @@ def verify_if(cond_array): @tvm.testing.uses_gpu def test_if(): # Confirm that if works with cond as an array or scalar. - verify_if(cond_array=False) - verify_if(cond_array=True) + verify_if(cond_array=False, num_outputs=1) + verify_if(cond_array=False, num_outputs=2) + verify_if(cond_array=True, num_outputs=1) + verify_if(cond_array=True, num_outputs=2) @tvm.testing.uses_gpu From 32e4c4cb63aaeb037a316121c5dd269433bb34f0 Mon Sep 17 00:00:00 2001 From: Kanghwan Jang Date: Mon, 5 Jul 2021 16:51:38 -0700 Subject: [PATCH 3/3] Fix formatting issues --- tests/python/frontend/onnx/test_forward.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 99b4194f4b31..b47e32c38e8d 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4074,9 +4074,11 @@ def append_constant_nodes(nodes, outputs, expected, name): expected.append(np.random.randn(5).astype("float32")) - nodes.append(onnx.helper.make_node( - "Constant", inputs=[], outputs=[name], value=numpy_helper.from_array(expected[-1]) - )) + nodes.append( + onnx.helper.make_node( + "Constant", inputs=[], outputs=[name], value=numpy_helper.from_array(expected[-1]) + ) + ) if_outputs = [] graph_outputs = [] @@ -4097,11 +4099,7 @@ def append_constant_nodes(nodes, outputs, expected, name): else_body = onnx.helper.make_graph(else_nodes, "else_body", [], else_outs) if_node = onnx.helper.make_node( - "If", - inputs=["cond"], - outputs=if_outputs, - then_branch=then_body, - else_branch=else_body + "If", inputs=["cond"], outputs=if_outputs, then_branch=then_body, else_branch=else_body ) if_graph = onnx.helper.make_graph(