From 30d5631cfc781b67346d53dd35cd899acb97bcb7 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Tue, 26 Jan 2021 21:10:02 -0800 Subject: [PATCH 1/4] Add test for array loop. --- python/tvm/relay/frontend/onnx.py | 8 +++- tests/python/frontend/onnx/test_forward.py | 56 ++++++++++++++++++++++ 2 files changed, 62 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 7a3b168fc8fd..cbf766d4f7da 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2227,8 +2227,12 @@ def body_fn(*loop_inputs): # Add new scan outputs to tracking combined_scan_outputs = [] for i, scan in enumerate(scan_outputs): - new_scan = _op.expand_dims(new_scan_outputs[i], axis=0) - combined_scan = _op.concatenate([scan, new_scan], axis=0) + new_scan = new_scan_outputs[i] + new_scan_shape = _op.shape_of(new_scan, dtype=iter_dtype) + scan_broadcast = _op.concatenate([_op.reshape(loop_count, [1]), new_scan_shape], axis=0) + new_scan = _op.expand_dims(new_scan, axis=0) + scan_with_shape = _op.broadcast_to(scan, scan_broadcast) + combined_scan = _op.concatenate([scan_with_shape, new_scan], axis=0) combined_scan_outputs.append(combined_scan) # Increment counter. diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 20937d2060c5..d39fdc54fcb2 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -3708,11 +3708,67 @@ def verify_count_loop(): verify_with_ort_with_inputs(loop_model, input_vals, use_vm=True, freeze_params=True) +def verify_tensor_loop(): + y_in = helper.make_tensor_value_info("y_in", TensorProto.FLOAT, [3, 3, 3, 3]) + y_out = helper.make_tensor_value_info("y_out", TensorProto.FLOAT, [3, 3, 3, 3]) + scan_out = helper.make_tensor_value_info("scan_out", TensorProto.FLOAT, [3, 3, 3, 3]) + cond_in = helper.make_tensor_value_info("cond_in", TensorProto.BOOL, []) + cond_out = helper.make_tensor_value_info("cond_out", TensorProto.BOOL, []) + iter_count = helper.make_tensor_value_info("iter_count", TensorProto.INT64, []) + + y = np.random.normal(size=[3, 3, 3, 3]).astype(np.float32) + + iter_cast_node = helper.make_node( + "Cast", inputs=["iter_count"], outputs=["iter_cast"], to=onnx.TensorProto.FLOAT + ) + + y_add_node = helper.make_node("Add", inputs=["y_in", "iter_cast"], outputs=["y_out"]) + + identity_node = helper.make_node("Identity", inputs=["cond_in"], outputs=["cond_out"]) + + scan_identity_node = helper.make_node("Identity", inputs=["y_out"], outputs=["scan_out"]) + + loop_body = helper.make_graph( + [identity_node, iter_cast_node, y_add_node, scan_identity_node], + "loop_body", + [iter_count, cond_in, y_in], + [cond_out, y_out, scan_out], + ) + + loop_node = helper.make_node( + "Loop", inputs=["trip_count", "cond", "y"], outputs=["res_y", "res_scan"], body=loop_body + ) + + trip_count = np.array(5).astype(np.int64) + cond = np.array(1).astype(np.bool) + loop_graph = onnx.helper.make_graph( + [loop_node], + "loop_outer", + inputs=[ + onnx.helper.make_tensor_value_info("trip_count", onnx.TensorProto.INT64, []), + onnx.helper.make_tensor_value_info("cond", onnx.TensorProto.BOOL, []), + onnx.helper.make_tensor_value_info("y", onnx.TensorProto.FLOAT, [3, 3, 3, 3]), + ], + outputs=[ + onnx.helper.make_tensor_value_info("res_y", onnx.TensorProto.FLOAT, [3, 3, 3, 3]), + onnx.helper.make_tensor_value_info("res_scan", onnx.TensorProto.FLOAT, [5, 3, 3, 3, 3]), + ], + ) + loop_model = onnx.helper.make_model(loop_graph) + + 
trip_count = np.array(5).astype(np.int64) + cond = np.array(1).astype(np.bool) + input_vals = [trip_count, cond, y] + verify_with_ort_with_inputs(loop_model, input_vals, use_vm=True, freeze_params=True, convert_to_static=True) + + def test_loop(): # Test a loop that exits once a condition is met. verify_cond_loop() # Test a loop that exits after a fixed number of iterations. verify_count_loop() + # Test a loop that uses an array output. + verify_tensor_loop() def verify_if(cond_array): From 7ee26d18419eacf6e9f03545b7630fb7f527ecf3 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Wed, 27 Jan 2021 13:02:01 -0800 Subject: [PATCH 2/4] Fixed scalar issue. --- python/tvm/relay/frontend/onnx.py | 12 +- tests/python/frontend/onnx/test_forward.py | 162 ++++++++++----------- 2 files changed, 88 insertions(+), 86 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index cbf766d4f7da..78888d6ce336 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2227,12 +2227,14 @@ def body_fn(*loop_inputs): # Add new scan outputs to tracking combined_scan_outputs = [] for i, scan in enumerate(scan_outputs): + rank = len(infer_shape(scan)) - 1 new_scan = new_scan_outputs[i] - new_scan_shape = _op.shape_of(new_scan, dtype=iter_dtype) - scan_broadcast = _op.concatenate([_op.reshape(loop_count, [1]), new_scan_shape], axis=0) - new_scan = _op.expand_dims(new_scan, axis=0) - scan_with_shape = _op.broadcast_to(scan, scan_broadcast) - combined_scan = _op.concatenate([scan_with_shape, new_scan], axis=0) + expand_scan = _op.expand_dims(new_scan, axis=0) + if rank > 0: + new_scan_shape = _op.shape_of(new_scan, dtype=iter_dtype) + scan_broadcast = _op.concatenate([_op.reshape(loop_count, [1]), new_scan_shape], axis=0) + scan = _op.broadcast_to(scan, scan_broadcast) + combined_scan = _op.concatenate([scan, expand_scan], axis=0) combined_scan_outputs.append(combined_scan) # Increment counter. 
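
For reference, the scan-output handling above has a straightforward NumPy analogue: every iteration's y_out is stacked along a new leading axis, so after trip_count iterations the scan output has shape (trip_count, *y.shape). The broadcast_to / rank guard in the Relay code is only needed because the running scan tensor's shape is symbolic inside the while_loop; scalar outputs skip the broadcast since their shape_of is empty. The sketch below (the helper name run_loop is illustrative, not part of the patch) shows the semantics that verify_tensor_loop checks against onnxruntime.

import numpy as np

def run_loop(trip_count, y):
    # NumPy analogue of the Loop body in verify_tensor_loop:
    #   y_out = y_in + cast(iter_count); scan_out = y_out
    scan = []
    for i in range(trip_count):
        y = y + np.float32(i)
        scan.append(y)
    # res_y is the final carried value; res_scan stacks one slice per iteration.
    return y, np.stack(scan, axis=0)

y = np.random.normal(size=[3, 3, 3, 3]).astype(np.float32)
res_y, res_scan = run_loop(5, y)
assert res_scan.shape == (5, 3, 3, 3, 3)
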
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index d39fdc54fcb2..57221028b726 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -3654,14 +3654,14 @@ def verify_cond_loop(): def verify_count_loop(): - y_in = helper.make_tensor_value_info("y_in", TensorProto.FLOAT, [1]) - y_out = helper.make_tensor_value_info("y_out", TensorProto.FLOAT, [1]) - scan_out = helper.make_tensor_value_info("scan_out", TensorProto.FLOAT, [1]) + y_in = helper.make_tensor_value_info("y_in", TensorProto.FLOAT, []) + y_out = helper.make_tensor_value_info("y_out", TensorProto.FLOAT, []) + scan_out = helper.make_tensor_value_info("scan_out", TensorProto.FLOAT, []) cond_in = helper.make_tensor_value_info("cond_in", TensorProto.BOOL, []) cond_out = helper.make_tensor_value_info("cond_out", TensorProto.BOOL, []) iter_count = helper.make_tensor_value_info("iter_count", TensorProto.INT64, []) - y = np.array([-2]).astype(np.float32) + y = np.array(-2).astype(np.float32) iter_cast_node = helper.make_node( "Cast", inputs=["iter_count"], outputs=["iter_cast"], to=onnx.TensorProto.FLOAT @@ -3693,11 +3693,11 @@ def verify_count_loop(): inputs=[ onnx.helper.make_tensor_value_info("trip_count", onnx.TensorProto.INT64, []), onnx.helper.make_tensor_value_info("cond", onnx.TensorProto.BOOL, []), - onnx.helper.make_tensor_value_info("y", onnx.TensorProto.FLOAT, [1]), + onnx.helper.make_tensor_value_info("y", onnx.TensorProto.FLOAT, []), ], outputs=[ - onnx.helper.make_tensor_value_info("res_y", onnx.TensorProto.FLOAT, [1]), - onnx.helper.make_tensor_value_info("res_scan", onnx.TensorProto.FLOAT, [5, 1]), + onnx.helper.make_tensor_value_info("res_y", onnx.TensorProto.FLOAT, []), + onnx.helper.make_tensor_value_info("res_scan", onnx.TensorProto.FLOAT, [5]), ], ) loop_model = onnx.helper.make_model(loop_graph) @@ -3953,78 +3953,78 @@ def verify_softplus(indata): if __name__ == "__main__": - test_flatten() - test_reshape() - test_shape() - test_expand() - test_power() - test_squeeze() - test_unsqueeze() - test_slice() - test_floor() - test_ceil() - test_round() - test_isinf() - test_isnan() - test_clip() - test_clip_min_max_as_inputs() - test_onehot() - test_matmul() - test_gather() - test_gatherelements() - test_gather_nd() - test_scatter() - test_lrn() - test_instance_norm() - test_upsample() - test_forward_min() - test_forward_max() - test_forward_mean() - test_forward_hardsigmoid() - test_forward_arg_min_max() - test_softmax() - test_constantofshape() - test_all_reduce_funcs() - test_pad() - test_split() - test_binary_ops() - test_unary_ops() - test_leaky_relu() - test_elu() - test_selu() - test_prelu() - test_ThresholdedRelu() - test_LogSoftmax() - test_resnet() - test_inception() - test_densenet() - test_sign() - test_not() - test_and() - test_tile() - test_erf() - test_where() - test_or() - test_depth_to_space() - test_space_to_depth() - test_batch_norm() - test_batch_norm_dynamic_subgraph() - test_conv() - test_convtranspose() - test_unsqueeze_constant() - test_pooling() - test_lppool() - test_lstm() - test_gru() - test_resize() - test_nonzero() - test_topk() - test_mod() - test_xor() - test_max_roi_pool() - test_roi_align() - test_range() + #test_flatten() + #test_reshape() + #test_shape() + #test_expand() + #test_power() + #test_squeeze() + #test_unsqueeze() + #test_slice() + #test_floor() + #test_ceil() + #test_round() + #test_isinf() + #test_isnan() + #test_clip() + #test_clip_min_max_as_inputs() + #test_onehot() + 
#test_matmul() + #test_gather() + #test_gatherelements() + #test_gather_nd() + #test_scatter() + #test_lrn() + #test_instance_norm() + #test_upsample() + #test_forward_min() + #test_forward_max() + #test_forward_mean() + #test_forward_hardsigmoid() + #test_forward_arg_min_max() + #test_softmax() + #test_constantofshape() + #test_all_reduce_funcs() + #test_pad() + #test_split() + #test_binary_ops() + #test_unary_ops() + #test_leaky_relu() + #test_elu() + #test_selu() + #test_prelu() + #test_ThresholdedRelu() + #test_LogSoftmax() + #test_resnet() + #test_inception() + #test_densenet() + #test_sign() + #test_not() + #test_and() + #test_tile() + #test_erf() + #test_where() + #test_or() + #test_depth_to_space() + #test_space_to_depth() + #test_batch_norm() + #test_batch_norm_dynamic_subgraph() + #test_conv() + #test_convtranspose() + #test_unsqueeze_constant() + #test_pooling() + #test_lppool() + #test_lstm() + #test_gru() + #test_resize() + #test_nonzero() + #test_topk() + #test_mod() + #test_xor() + #test_max_roi_pool() + #test_roi_align() + #test_range() test_loop() - test_size() - test_maxunpool() - test_softplus() + #test_size() + #test_maxunpool() + #test_softplus() From 64a7ada0f3e0c0ab8311b544736cc88b3cfa0786 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Wed, 27 Jan 2021 14:25:46 -0800 Subject: [PATCH 3/4] Formatting. --- python/tvm/relay/frontend/onnx.py | 4 +- tests/python/frontend/onnx/test_forward.py | 152 +++++++++++---------- 2 files changed, 80 insertions(+), 76 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 78888d6ce336..cb199eab25f2 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2232,7 +2232,9 @@ def body_fn(*loop_inputs): expand_scan = _op.expand_dims(new_scan, axis=0) if rank > 0: new_scan_shape = _op.shape_of(new_scan, dtype=iter_dtype) - scan_broadcast = _op.concatenate([_op.reshape(loop_count, [1]), new_scan_shape], axis=0) + scan_broadcast = _op.concatenate( + [_op.reshape(loop_count, [1]), new_scan_shape], axis=0 + ) scan = _op.broadcast_to(scan, scan_broadcast) combined_scan = _op.concatenate([scan, expand_scan], axis=0) combined_scan_outputs.append(combined_scan) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 57221028b726..70ac73e668b2 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -3759,7 +3759,9 @@ def verify_tensor_loop(): trip_count = np.array(5).astype(np.int64) cond = np.array(1).astype(np.bool) input_vals = [trip_count, cond, y] - verify_with_ort_with_inputs(loop_model, input_vals, use_vm=True, freeze_params=True, convert_to_static=True) + verify_with_ort_with_inputs( + loop_model, input_vals, use_vm=True, freeze_params=True, convert_to_static=True + ) def test_loop(): @@ -3953,78 +3955,78 @@ def verify_softplus(indata): if __name__ == "__main__": - #test_flatten() - #test_reshape() - #test_shape() - #test_expand() - #test_power() - #test_squeeze() - #test_unsqueeze() - #test_slice() - #test_floor() - #test_ceil() - #test_round() - #test_isinf() - #test_isnan() - #test_clip() - #test_clip_min_max_as_inputs() - #test_onehot() - #test_matmul() - #test_gather() - #test_gatherelements() - #test_gather_nd() - #test_scatter() - #test_lrn() - #test_instance_norm() - #test_upsample() - #test_forward_min() - #test_forward_max() - #test_forward_mean() - #test_forward_hardsigmoid() - #test_forward_arg_min_max() - #test_softmax() - #test_constantofshape() 
- #test_all_reduce_funcs() - #test_pad() - #test_split() - #test_binary_ops() - #test_unary_ops() - #test_leaky_relu() - #test_elu() - #test_selu() - #test_prelu() - #test_ThresholdedRelu() - #test_LogSoftmax() - #test_resnet() - #test_inception() - #test_densenet() - #test_sign() - #test_not() - #test_and() - #test_tile() - #test_erf() - #test_where() - #test_or() - #test_depth_to_space() - #test_space_to_depth() - #test_batch_norm() - #test_batch_norm_dynamic_subgraph() - #test_conv() - #test_convtranspose() - #test_unsqueeze_constant() - #test_pooling() - #test_lppool() - #test_lstm() - #test_gru() - #test_resize() - #test_nonzero() - #test_topk() - #test_mod() - #test_xor() - #test_max_roi_pool() - #test_roi_align() - #test_range() + test_flatten() + test_reshape() + test_shape() + test_expand() + test_power() + test_squeeze() + test_unsqueeze() + test_slice() + test_floor() + test_ceil() + test_round() + test_isinf() + test_isnan() + test_clip() + test_clip_min_max_as_inputs() + test_onehot() + test_matmul() + test_gather() + test_gatherelements() + test_gather_nd() + test_scatter() + test_lrn() + test_instance_norm() + test_upsample() + test_forward_min() + test_forward_max() + test_forward_mean() + test_forward_hardsigmoid() + test_forward_arg_min_max() + test_softmax() + test_constantofshape() + test_all_reduce_funcs() + test_pad() + test_split() + test_binary_ops() + test_unary_ops() + test_leaky_relu() + test_elu() + test_selu() + test_prelu() + test_ThresholdedRelu() + test_LogSoftmax() + test_resnet() + test_inception() + test_densenet() + test_sign() + test_not() + test_and() + test_tile() + test_erf() + test_where() + test_or() + test_depth_to_space() + test_space_to_depth() + test_batch_norm() + test_batch_norm_dynamic_subgraph() + test_conv() + test_convtranspose() + test_unsqueeze_constant() + test_pooling() + test_lppool() + test_lstm() + test_gru() + test_resize() + test_nonzero() + test_topk() + test_mod() + test_xor() + test_max_roi_pool() + test_roi_align() + test_range() test_loop() - #test_size() - #test_maxunpool() - #test_softplus() + test_size() + test_maxunpool() + test_softplus() From 6a003127d9942e7de7a53075bd3c1ceac99b6982 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Wed, 27 Jan 2021 14:35:33 -0800 Subject: [PATCH 4/4] Fix injective schedule for dynamic shapes. --- python/tvm/relay/frontend/onnx.py | 1 + python/tvm/topi/x86/injective.py | 27 ++++++++++++---------- tests/python/frontend/onnx/test_forward.py | 2 +- 3 files changed, 17 insertions(+), 13 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index cb199eab25f2..b1b01b87f715 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2230,6 +2230,7 @@ def body_fn(*loop_inputs): rank = len(infer_shape(scan)) - 1 new_scan = new_scan_outputs[i] expand_scan = _op.expand_dims(new_scan, axis=0) + # For non scalar outputs we need to broadcast the initial value. 
if rank > 0: new_scan_shape = _op.shape_of(new_scan, dtype=iter_dtype) scan_broadcast = _op.concatenate( diff --git a/python/tvm/topi/x86/injective.py b/python/tvm/topi/x86/injective.py index 29f903fd4e35..6492b78d6037 100644 --- a/python/tvm/topi/x86/injective.py +++ b/python/tvm/topi/x86/injective.py @@ -17,6 +17,7 @@ # pylint: disable=invalid-name """x86 declaration and schedules.""" from tvm import te +from tvm.tir import IntImm from ..utils import is_empty_shape @@ -100,18 +101,20 @@ def schedule_concatenate(outs): def vectorize(sch, tensor, vectorize_limit): """Internal vectorization function for concatenate.""" inner_axis = s[tensor].op.axis[len(s[tensor].op.axis) - 1] - inner_length = tensor.shape[len(tensor.shape) - 1].value - if inner_length <= vectorize_limit: - sch[tensor].vectorize(inner_axis) - else: - split_factor = 1 - for i in range(vectorize_limit, 1, -1): - if inner_length % i == 0: - split_factor = i - break - if split_factor > 1: - _, inner_i = sch[tensor].split(inner_axis, split_factor) - sch[tensor].vectorize(inner_i) + # Check that the tensor shape is static. Otherwise skip vectorization. + if isinstance(tensor.shape[len(tensor.shape) - 1], IntImm): + inner_length = tensor.shape[len(tensor.shape) - 1].value + if inner_length <= vectorize_limit: + sch[tensor].vectorize(inner_axis) + else: + split_factor = 1 + for i in range(vectorize_limit, 1, -1): + if inner_length % i == 0: + split_factor = i + break + if split_factor > 1: + _, inner_i = sch[tensor].split(inner_axis, split_factor) + sch[tensor].vectorize(inner_i) outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs x = outs[0] diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 70ac73e668b2..c666604d0e89 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -3767,7 +3767,7 @@ def verify_tensor_loop(): def test_loop(): # Test a loop that exits once a condition is met. verify_cond_loop() - # Test a loop that exits after a fixed number of iterations. + # Test a loop that exits after a fixed number of iterations with scalar outputs. verify_count_loop() # Test a loop that uses an array output. verify_tensor_loop()
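
The x86 schedule_concatenate change above follows the same theme: with dynamically shaped loop outputs, the innermost extent of the concatenate output can be a symbolic expression rather than a constant, so calling .value on it, as the old code did, raises an error. The patch therefore only vectorizes when that extent is a compile-time IntImm. A minimal sketch of the check, assuming a standard TVM build (the helper name innermost_static_extent is illustrative, not part of the patch):

from tvm import te, tir

def innermost_static_extent(tensor):
    # Return the innermost extent as a Python int when it is a
    # compile-time constant (tir.IntImm), or None when it is dynamic.
    extent = tensor.shape[len(tensor.shape) - 1]
    if isinstance(extent, tir.IntImm):
        return extent.value
    return None

A = te.placeholder((8, 32), name="A")           # static inner dimension
B = te.placeholder((8, te.var("n")), name="B")  # dynamic inner dimension
print(innermost_static_extent(A))  # 32 -> safe to vectorize
print(innermost_static_extent(B))  # None -> skip vectorization

Skipping vectorization for dynamic extents is conservative, but it keeps the schedule valid for the shapes produced by the new Loop scan outputs.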