diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index c3108ff890b1..f876b1d14fa1 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2879,7 +2879,10 @@ def _impl_v1(cls, inputs, attr, params): graph_scope._nodes.update({var.name_hint: var}) # Now we can construct the relay if statement and return. - return _expr.If(cond, then_expr, else_expr) + ret = _expr.If(cond, then_expr, else_expr) + if len(then_branch.output) > 1: + ret = _expr.TupleWrapper(ret, len(then_branch.output)) + return ret class NonMaxSuppression(OnnxOpConverter): diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 9a97f895eaea..c5407697de46 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4108,29 +4108,41 @@ def test_loop(): verify_tensor_loop() -def verify_if(cond_array): +def verify_if(cond_array, num_outputs): # Given a bool scalar input cond. # return constant tensor x if cond is True, otherwise return constant tensor y. - then_out = onnx.helper.make_tensor_value_info("then_out", onnx.TensorProto.FLOAT, [5]) - else_out = onnx.helper.make_tensor_value_info("else_out", onnx.TensorProto.FLOAT, [5]) - x = np.array([1, 2, 3, 4, 5]).astype(np.float32) - y = np.array([5, 4, 3, 2, 1]).astype(np.float32) + def append_constant_nodes(nodes, outputs, expected, name): + outputs.append(onnx.helper.make_tensor_value_info(name, onnx.TensorProto.FLOAT, [5])) - then_const_node = onnx.helper.make_node( - "Constant", inputs=[], outputs=["then_out"], value=numpy_helper.from_array(x) - ) + expected.append(np.random.randn(5).astype("float32")) - else_const_node = onnx.helper.make_node( - "Constant", inputs=[], outputs=["else_out"], value=numpy_helper.from_array(y) - ) + nodes.append( + onnx.helper.make_node( + "Constant", inputs=[], outputs=[name], value=numpy_helper.from_array(expected[-1]) + ) + ) + + if_outputs = [] + graph_outputs = [] - then_body = onnx.helper.make_graph([then_const_node], "then_body", [], [then_out]) + then_nodes, then_outs, then_expected = [], [], [] + else_nodes, else_outs, else_expected = [], [], [] - else_body = onnx.helper.make_graph([else_const_node], "else_body", [], [else_out]) + for i in range(num_outputs): + append_constant_nodes(then_nodes, then_outs, then_expected, "then_out{}".format(i)) + append_constant_nodes(else_nodes, else_outs, else_expected, "else_out{}".format(i)) + + if_outputs.append("res{}".format(i)) + graph_outputs.append( + onnx.helper.make_tensor_value_info("res{}".format(i), onnx.TensorProto.FLOAT, [5]), + ) + + then_body = onnx.helper.make_graph(then_nodes, "then_body", [], then_outs) + else_body = onnx.helper.make_graph(else_nodes, "else_body", [], else_outs) if_node = onnx.helper.make_node( - "If", inputs=["cond"], outputs=["res"], then_branch=then_body, else_branch=else_body + "If", inputs=["cond"], outputs=if_outputs, then_branch=then_body, else_branch=else_body ) if_graph = onnx.helper.make_graph( @@ -4139,9 +4151,7 @@ def verify_if(cond_array): inputs=[ onnx.helper.make_tensor_value_info("cond", onnx.TensorProto.BOOL, []), ], - outputs=[ - onnx.helper.make_tensor_value_info("res", onnx.TensorProto.FLOAT, [5]), - ], + outputs=graph_outputs, ) if_model = onnx.helper.make_model(if_graph) @@ -4149,12 +4159,14 @@ def verify_if(cond_array): cond = np.array([1]).astype("bool") else: cond = np.array(1).astype("bool") - correct_out = x if cond else y + correct_out = then_expected if cond else else_expected # TODO(jwfromm): Onnxruntime 1.0.0 is buggy with If statements. Replace this with # verify_with_ort once we update versions. for target, dev in tvm.testing.enabled_targets(): tvm_out = get_tvm_output_with_vm(if_model, [cond], target, dev, freeze_params=True) + if not isinstance(tvm_out, list): + tvm_out = [tvm_out] for i in range(len(tvm_out)): tvm.testing.assert_allclose(correct_out[i], tvm_out[i], rtol=1e-05, atol=1e-05) @@ -4162,8 +4174,10 @@ def verify_if(cond_array): @tvm.testing.uses_gpu def test_if(): # Confirm that if works with cond as an array or scalar. - verify_if(cond_array=False) - verify_if(cond_array=True) + verify_if(cond_array=False, num_outputs=1) + verify_if(cond_array=False, num_outputs=2) + verify_if(cond_array=True, num_outputs=1) + verify_if(cond_array=True, num_outputs=2) @tvm.testing.uses_gpu