[ONNX] Wrap 'If' if it has multiple outputs (#8385)
* [ONNX] Wrap 'If' if it has multiple outputs

Without this wrapper, an assertion in from_onnx() will fail with the
error message "Number of output mismatch".

* [ONNX] Test If nodes with multiple output tensors

* Fix formatting issues
karljang authored Jul 7, 2021
1 parent bd5cd9f commit 2628179
Showing 2 changed files with 38 additions and 21 deletions.
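
For context, the sketch below illustrates the scenario this commit fixes. It is not part of the diff, and the tensor names, shapes, and values are illustrative. It builds an ONNX If node whose branches each produce two outputs and imports it with relay.frontend.from_onnx, which is where the "Number of output mismatch" assertion used to fire.

# Minimal sketch (hypothetical model, illustrative names/shapes): an ONNX 'If'
# with two outputs, imported through the Relay ONNX frontend.
import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper
from tvm import relay


def branch_graph(name, values):
    # Build a branch body whose outputs are constant tensors.
    nodes, outs = [], []
    for i, val in enumerate(values):
        out_name = "{}_out{}".format(name, i)
        outs.append(helper.make_tensor_value_info(out_name, TensorProto.FLOAT, [5]))
        nodes.append(
            helper.make_node("Constant", [], [out_name], value=numpy_helper.from_array(val))
        )
    return helper.make_graph(nodes, name, [], outs)


vals = [np.random.randn(5).astype("float32") for _ in range(2)]
then_body = branch_graph("then_body", vals)
else_body = branch_graph("else_body", [v * 2 for v in vals])

if_node = helper.make_node(
    "If", ["cond"], ["res0", "res1"], then_branch=then_body, else_branch=else_body
)
graph = helper.make_graph(
    [if_node],
    "if_two_outputs",
    inputs=[helper.make_tensor_value_info("cond", TensorProto.BOOL, [])],
    outputs=[
        helper.make_tensor_value_info("res{}".format(i), TensorProto.FLOAT, [5]) for i in range(2)
    ],
)
model = helper.make_model(graph)

# Before this change, from_onnx() hit the "Number of output mismatch" assertion here;
# with the TupleWrapper fix, both branch outputs come back from the import.
mod, params = relay.frontend.from_onnx(model, {"cond": []}, freeze_params=True)
print(mod)
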
python/tvm/relay/frontend/onnx.py: 4 additions, 1 deletion
@@ -2879,7 +2879,10 @@ def _impl_v1(cls, inputs, attr, params):
             graph_scope._nodes.update({var.name_hint: var})

         # Now we can construct the relay if statement and return.
-        return _expr.If(cond, then_expr, else_expr)
+        ret = _expr.If(cond, then_expr, else_expr)
+        if len(then_branch.output) > 1:
+            ret = _expr.TupleWrapper(ret, len(then_branch.output))
+        return ret


 class NonMaxSuppression(OnnxOpConverter):
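
A side note on the wrapper used above: relay.expr.TupleWrapper makes the multi-output If indexable, so each branch output can be consumed as its own Relay expression. A minimal sketch of that behaviour (standalone illustration, not taken from the diff; the constants are made up):

# Sketch: indexing a TupleWrapper-wrapped If yields one TupleGetItem per output.
from tvm import relay

cond = relay.var("cond", shape=(), dtype="bool")
then_expr = relay.Tuple([relay.const(1.0), relay.const(2.0)])
else_expr = relay.Tuple([relay.const(3.0), relay.const(4.0)])

ret = relay.expr.TupleWrapper(relay.If(cond, then_expr, else_expr), 2)
first, second = ret[0], ret[1]  # each is a TupleGetItem on the underlying If
func = relay.Function([cond], relay.add(first, second))
print(func)

This mirrors what callers of the frontend see: when then_branch.output has more than one entry, the converted If behaves like a tuple.
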
tests/python/frontend/onnx/test_forward.py: 34 additions, 20 deletions
@@ -4108,29 +4108,41 @@ def test_loop():
     verify_tensor_loop()


-def verify_if(cond_array):
+def verify_if(cond_array, num_outputs):
     # Given a bool scalar input cond.
     # return constant tensor x if cond is True, otherwise return constant tensor y.
-    then_out = onnx.helper.make_tensor_value_info("then_out", onnx.TensorProto.FLOAT, [5])
-    else_out = onnx.helper.make_tensor_value_info("else_out", onnx.TensorProto.FLOAT, [5])

-    x = np.array([1, 2, 3, 4, 5]).astype(np.float32)
-    y = np.array([5, 4, 3, 2, 1]).astype(np.float32)
+    def append_constant_nodes(nodes, outputs, expected, name):
+        outputs.append(onnx.helper.make_tensor_value_info(name, onnx.TensorProto.FLOAT, [5]))

-    then_const_node = onnx.helper.make_node(
-        "Constant", inputs=[], outputs=["then_out"], value=numpy_helper.from_array(x)
-    )
+        expected.append(np.random.randn(5).astype("float32"))

-    else_const_node = onnx.helper.make_node(
-        "Constant", inputs=[], outputs=["else_out"], value=numpy_helper.from_array(y)
-    )
+        nodes.append(
+            onnx.helper.make_node(
+                "Constant", inputs=[], outputs=[name], value=numpy_helper.from_array(expected[-1])
+            )
+        )

+    if_outputs = []
+    graph_outputs = []

-    then_body = onnx.helper.make_graph([then_const_node], "then_body", [], [then_out])
+    then_nodes, then_outs, then_expected = [], [], []
+    else_nodes, else_outs, else_expected = [], [], []
+
-    else_body = onnx.helper.make_graph([else_const_node], "else_body", [], [else_out])
+    for i in range(num_outputs):
+        append_constant_nodes(then_nodes, then_outs, then_expected, "then_out{}".format(i))
+        append_constant_nodes(else_nodes, else_outs, else_expected, "else_out{}".format(i))
+
+        if_outputs.append("res{}".format(i))
+        graph_outputs.append(
+            onnx.helper.make_tensor_value_info("res{}".format(i), onnx.TensorProto.FLOAT, [5]),
+        )
+
+    then_body = onnx.helper.make_graph(then_nodes, "then_body", [], then_outs)
+    else_body = onnx.helper.make_graph(else_nodes, "else_body", [], else_outs)

     if_node = onnx.helper.make_node(
-        "If", inputs=["cond"], outputs=["res"], then_branch=then_body, else_branch=else_body
+        "If", inputs=["cond"], outputs=if_outputs, then_branch=then_body, else_branch=else_body
     )

     if_graph = onnx.helper.make_graph(
@@ -4139,31 +4151,33 @@ def verify_if(cond_array):
         inputs=[
             onnx.helper.make_tensor_value_info("cond", onnx.TensorProto.BOOL, []),
         ],
-        outputs=[
-            onnx.helper.make_tensor_value_info("res", onnx.TensorProto.FLOAT, [5]),
-        ],
+        outputs=graph_outputs,
     )

     if_model = onnx.helper.make_model(if_graph)
     if cond_array:
         cond = np.array([1]).astype("bool")
     else:
         cond = np.array(1).astype("bool")
-    correct_out = x if cond else y
+    correct_out = then_expected if cond else else_expected

     # TODO(jwfromm): Onnxruntime 1.0.0 is buggy with If statements. Replace this with
     # verify_with_ort once we update versions.
     for target, dev in tvm.testing.enabled_targets():
         tvm_out = get_tvm_output_with_vm(if_model, [cond], target, dev, freeze_params=True)
+        if not isinstance(tvm_out, list):
+            tvm_out = [tvm_out]
         for i in range(len(tvm_out)):
             tvm.testing.assert_allclose(correct_out[i], tvm_out[i], rtol=1e-05, atol=1e-05)


 @tvm.testing.uses_gpu
 def test_if():
     # Confirm that if works with cond as an array or scalar.
-    verify_if(cond_array=False)
-    verify_if(cond_array=True)
+    verify_if(cond_array=False, num_outputs=1)
+    verify_if(cond_array=False, num_outputs=2)
+    verify_if(cond_array=True, num_outputs=1)
+    verify_if(cond_array=True, num_outputs=2)


 @tvm.testing.uses_gpu
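Usage note (assuming a local TVM build with the ONNX frontend and pytest available): the new cases can be exercised by selecting the test above, e.g. pytest tests/python/frontend/onnx/test_forward.py -k test_if.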
