Skip to content

Commit

Permalink
[Relay][Frontend][Onnx] Robustify Loop Importer (apache#7353)
Browse files Browse the repository at this point in the history
* Add test for array loop.

* Fixed scalar issue.

* Formatting.

* Fix injective schedule for dynamic shapes.
  • Loading branch information
jwfromm authored and trevor-m committed Mar 2, 2021
1 parent 8e23ea5 commit 1ea436c
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 22 deletions.
13 changes: 11 additions & 2 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2227,8 +2227,17 @@ def body_fn(*loop_inputs):
# Add new scan outputs to tracking
combined_scan_outputs = []
for i, scan in enumerate(scan_outputs):
new_scan = _op.expand_dims(new_scan_outputs[i], axis=0)
combined_scan = _op.concatenate([scan, new_scan], axis=0)
rank = len(infer_shape(scan)) - 1
new_scan = new_scan_outputs[i]
expand_scan = _op.expand_dims(new_scan, axis=0)
# For non scalar outputs we need to broadcast the initial value.
if rank > 0:
new_scan_shape = _op.shape_of(new_scan, dtype=iter_dtype)
scan_broadcast = _op.concatenate(
[_op.reshape(loop_count, [1]), new_scan_shape], axis=0
)
scan = _op.broadcast_to(scan, scan_broadcast)
combined_scan = _op.concatenate([scan, expand_scan], axis=0)
combined_scan_outputs.append(combined_scan)

# Increment counter.
Expand Down
27 changes: 15 additions & 12 deletions python/tvm/topi/x86/injective.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# pylint: disable=invalid-name
"""x86 declaration and schedules."""
from tvm import te
from tvm.tir import IntImm
from ..utils import is_empty_shape


Expand Down Expand Up @@ -100,18 +101,20 @@ def schedule_concatenate(outs):
def vectorize(sch, tensor, vectorize_limit):
"""Internal vectorization function for concatenate."""
inner_axis = s[tensor].op.axis[len(s[tensor].op.axis) - 1]
inner_length = tensor.shape[len(tensor.shape) - 1].value
if inner_length <= vectorize_limit:
sch[tensor].vectorize(inner_axis)
else:
split_factor = 1
for i in range(vectorize_limit, 1, -1):
if inner_length % i == 0:
split_factor = i
break
if split_factor > 1:
_, inner_i = sch[tensor].split(inner_axis, split_factor)
sch[tensor].vectorize(inner_i)
# Check that the tensor shape is static. Otherwise skip vectorization.
if isinstance(tensor.shape[len(tensor.shape) - 1], IntImm):
inner_length = tensor.shape[len(tensor.shape) - 1].value
if inner_length <= vectorize_limit:
sch[tensor].vectorize(inner_axis)
else:
split_factor = 1
for i in range(vectorize_limit, 1, -1):
if inner_length % i == 0:
split_factor = i
break
if split_factor > 1:
_, inner_i = sch[tensor].split(inner_axis, split_factor)
sch[tensor].vectorize(inner_i)

outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
x = outs[0]
Expand Down
74 changes: 66 additions & 8 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -3654,14 +3654,14 @@ def verify_cond_loop():


def verify_count_loop():
y_in = helper.make_tensor_value_info("y_in", TensorProto.FLOAT, [1])
y_out = helper.make_tensor_value_info("y_out", TensorProto.FLOAT, [1])
scan_out = helper.make_tensor_value_info("scan_out", TensorProto.FLOAT, [1])
y_in = helper.make_tensor_value_info("y_in", TensorProto.FLOAT, [])
y_out = helper.make_tensor_value_info("y_out", TensorProto.FLOAT, [])
scan_out = helper.make_tensor_value_info("scan_out", TensorProto.FLOAT, [])
cond_in = helper.make_tensor_value_info("cond_in", TensorProto.BOOL, [])
cond_out = helper.make_tensor_value_info("cond_out", TensorProto.BOOL, [])
iter_count = helper.make_tensor_value_info("iter_count", TensorProto.INT64, [])

y = np.array([-2]).astype(np.float32)
y = np.array(-2).astype(np.float32)

iter_cast_node = helper.make_node(
"Cast", inputs=["iter_count"], outputs=["iter_cast"], to=onnx.TensorProto.FLOAT
Expand Down Expand Up @@ -3693,11 +3693,11 @@ def verify_count_loop():
inputs=[
onnx.helper.make_tensor_value_info("trip_count", onnx.TensorProto.INT64, []),
onnx.helper.make_tensor_value_info("cond", onnx.TensorProto.BOOL, []),
onnx.helper.make_tensor_value_info("y", onnx.TensorProto.FLOAT, [1]),
onnx.helper.make_tensor_value_info("y", onnx.TensorProto.FLOAT, []),
],
outputs=[
onnx.helper.make_tensor_value_info("res_y", onnx.TensorProto.FLOAT, [1]),
onnx.helper.make_tensor_value_info("res_scan", onnx.TensorProto.FLOAT, [5, 1]),
onnx.helper.make_tensor_value_info("res_y", onnx.TensorProto.FLOAT, []),
onnx.helper.make_tensor_value_info("res_scan", onnx.TensorProto.FLOAT, [5]),
],
)
loop_model = onnx.helper.make_model(loop_graph)
Expand All @@ -3708,11 +3708,69 @@ def verify_count_loop():
verify_with_ort_with_inputs(loop_model, input_vals, use_vm=True, freeze_params=True)


def verify_tensor_loop():
y_in = helper.make_tensor_value_info("y_in", TensorProto.FLOAT, [3, 3, 3, 3])
y_out = helper.make_tensor_value_info("y_out", TensorProto.FLOAT, [3, 3, 3, 3])
scan_out = helper.make_tensor_value_info("scan_out", TensorProto.FLOAT, [3, 3, 3, 3])
cond_in = helper.make_tensor_value_info("cond_in", TensorProto.BOOL, [])
cond_out = helper.make_tensor_value_info("cond_out", TensorProto.BOOL, [])
iter_count = helper.make_tensor_value_info("iter_count", TensorProto.INT64, [])

y = np.random.normal(size=[3, 3, 3, 3]).astype(np.float32)

iter_cast_node = helper.make_node(
"Cast", inputs=["iter_count"], outputs=["iter_cast"], to=onnx.TensorProto.FLOAT
)

y_add_node = helper.make_node("Add", inputs=["y_in", "iter_cast"], outputs=["y_out"])

identity_node = helper.make_node("Identity", inputs=["cond_in"], outputs=["cond_out"])

scan_identity_node = helper.make_node("Identity", inputs=["y_out"], outputs=["scan_out"])

loop_body = helper.make_graph(
[identity_node, iter_cast_node, y_add_node, scan_identity_node],
"loop_body",
[iter_count, cond_in, y_in],
[cond_out, y_out, scan_out],
)

loop_node = helper.make_node(
"Loop", inputs=["trip_count", "cond", "y"], outputs=["res_y", "res_scan"], body=loop_body
)

trip_count = np.array(5).astype(np.int64)
cond = np.array(1).astype(np.bool)
loop_graph = onnx.helper.make_graph(
[loop_node],
"loop_outer",
inputs=[
onnx.helper.make_tensor_value_info("trip_count", onnx.TensorProto.INT64, []),
onnx.helper.make_tensor_value_info("cond", onnx.TensorProto.BOOL, []),
onnx.helper.make_tensor_value_info("y", onnx.TensorProto.FLOAT, [3, 3, 3, 3]),
],
outputs=[
onnx.helper.make_tensor_value_info("res_y", onnx.TensorProto.FLOAT, [3, 3, 3, 3]),
onnx.helper.make_tensor_value_info("res_scan", onnx.TensorProto.FLOAT, [5, 3, 3, 3, 3]),
],
)
loop_model = onnx.helper.make_model(loop_graph)

trip_count = np.array(5).astype(np.int64)
cond = np.array(1).astype(np.bool)
input_vals = [trip_count, cond, y]
verify_with_ort_with_inputs(
loop_model, input_vals, use_vm=True, freeze_params=True, convert_to_static=True
)


def test_loop():
# Test a loop that exits once a condition is met.
verify_cond_loop()
# Test a loop that exits after a fixed number of iterations.
# Test a loop that exits after a fixed number of iterations with scalar outputs.
verify_count_loop()
# Test a loop that uses an array output.
verify_tensor_loop()


def verify_if(cond_array):
Expand Down

0 comments on commit 1ea436c

Please sign in to comment.