Skip to content

Commit

Permalink
Allow condition in if op to be an array. (#7215)
Browse files Browse the repository at this point in the history
  • Loading branch information
jwfromm authored Jan 5, 2021
1 parent 8b44741 commit 197594b
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
3 changes: 3 additions & 0 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2266,6 +2266,9 @@ class If(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
cond = inputs[0]
# Convert array to bool if needed.
if len(infer_shape(cond)) > 0:
cond = _op.take(cond, _expr.const(0, dtype="int64"))
then_branch = attr.get("then_branch", None)
else_branch = attr.get("else_branch", None)
assert then_branch is not None and else_branch is not None
Expand Down
15 changes: 12 additions & 3 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -3969,8 +3969,7 @@ def test_loop():
verify_count_loop()


@tvm.testing.uses_gpu
def test_if():
def verify_if(cond_array):
# Given a bool scalar input cond.
# return constant tensor x if cond is True, otherwise return constant tensor y.
then_out = onnx.helper.make_tensor_value_info("then_out", onnx.TensorProto.FLOAT, [5])
Expand Down Expand Up @@ -4007,7 +4006,10 @@ def test_if():
)

if_model = onnx.helper.make_model(if_graph)
cond = np.array(1).astype("bool")
if cond_array:
cond = np.array([1]).astype("bool")
else:
cond = np.array(1).astype("bool")
correct_out = x if cond else y

for target, ctx in tvm.testing.enabled_targets():
Expand All @@ -4016,6 +4018,13 @@ def test_if():
tvm.testing.assert_allclose(correct_out[i], tvm_out[i], rtol=1e-05, atol=1e-05)


@tvm.testing.uses_gpu
def test_if():
# Confirm that if works with cond as an array or scalar.
verify_if(cond_array=False)
verify_if(cond_array=True)


@tvm.testing.uses_gpu
def test_size():
def verify_size(indata):
Expand Down

0 comments on commit 197594b

Please sign in to comment.