From 52a123a02624f77d7c0d8c9b91d45e0e468b00e3 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 20 Sep 2021 20:52:04 +0900 Subject: [PATCH] Register layout conversion function to more reduce ops --- src/relay/op/tensor/reduce.cc | 19 +++-- .../relay/test_pass_convert_op_layout.py | 80 +++++++++++-------- 2 files changed, 61 insertions(+), 38 deletions(-) diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index 693589fecfb46..3b009ee7026e5 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -116,13 +116,14 @@ Array GetExcludeAxes(size_t indim, const Array& inaxis) { } // Return the modified layout for AlterOpLayout pass. +template InferCorrectLayoutOutput ReduceInferCorrectLayout(const Attrs& attrs, const Array& new_in_layouts, const Array& old_in_layouts, const Array& old_in_types) { - const auto* attrs_ptr = attrs.as(); + const auto* attrs_ptr = attrs.as(); ICHECK(attrs_ptr); - ObjectPtr params = make_object(*attrs_ptr); + ObjectPtr params = make_object(*attrs_ptr); // Get the reduce axes. Array> old_in_shapes; @@ -389,6 +390,7 @@ values over a given axis. .set_support_level(4) .add_type_rel("ArgReduce", GenericReduceRel) .set_attr("FTVMCompute", ArgMaxCompute) + .set_attr("FInferCorrectLayout", ReduceInferCorrectLayout) .set_attr("TOpPattern", kCommReduce); Array ArgMinCompute(const Attrs& attrs, const Array& inputs, @@ -405,6 +407,7 @@ values over a given axis. .set_support_level(4) .add_type_rel("ArgReduce", GenericReduceRel) .set_attr("FTVMCompute", ArgMinCompute) + .set_attr("FInferCorrectLayout", ReduceInferCorrectLayout) .set_attr("TOpPattern", kCommReduce); Array SumCompute(const Attrs& attrs, const Array& inputs, @@ -433,7 +436,7 @@ Example:: .set_attrs_type() .set_support_level(4) .add_type_rel("Reduce", ReduceRel) - .set_attr("FInferCorrectLayout", ReduceInferCorrectLayout) + .set_attr("FInferCorrectLayout", ReduceInferCorrectLayout) .set_attr("FTVMCompute", SumCompute) .set_attr("TOpPattern", kCommReduce); @@ -468,6 +471,7 @@ Example:: .set_support_level(4) .add_type_rel("Reduce", ReduceRel) .set_attr("FTVMCompute", AllCompute) + .set_attr("FInferCorrectLayout", ReduceInferCorrectLayout) .set_attr("TOpPattern", kCommReduce); Array AnyCompute(const Attrs& attrs, const Array& inputs, @@ -516,6 +520,7 @@ RELAY_REGISTER_REDUCE_OP("max") .set_support_level(4) .add_type_rel("Reduce", ReduceRel) .set_attr("FTVMCompute", MaxCompute) + .set_attr("FInferCorrectLayout", ReduceInferCorrectLayout) .set_attr("TOpPattern", kCommReduce); Array MinCompute(const Attrs& attrs, const Array& inputs, @@ -531,6 +536,7 @@ RELAY_REGISTER_REDUCE_OP("min") .set_support_level(4) .add_type_rel("Reduce", ReduceRel) .set_attr("FTVMCompute", MinCompute) + .set_attr("FInferCorrectLayout", ReduceInferCorrectLayout) .set_attr("TOpPattern", kCommReduce); Array ProdCompute(const Attrs& attrs, const Array& inputs, @@ -551,10 +557,10 @@ Example:: [[1,4],[4,3],[5,2]], [[7,1],[7,2],[7,3]]] - mean(data, axis=1) + prod(data, axis=1) [35562240] - mean(data, axis=[1,2]) + prod(data, axis=[1,2]) [ 36 480 2058] )code" TVM_ADD_FILELINE) @@ -562,6 +568,7 @@ Example:: .set_support_level(4) .add_type_rel("Reduce", ReduceRel) .set_attr("FTVMCompute", ProdCompute) + .set_attr("FInferCorrectLayout", ReduceInferCorrectLayout) .set_attr("TOpPattern", kCommReduce); Array MeanCompute(const Attrs& attrs, const Array& inputs, @@ -600,6 +607,7 @@ Example:: .set_support_level(4) .add_type_rel("Reduce", ReduceRel) .set_attr("FTVMCompute", MeanCompute) + .set_attr("FInferCorrectLayout", ReduceInferCorrectLayout) .set_attr("TOpPattern", kCommReduce); bool VarianceRel(const Array& types, int num_inputs, const Attrs& attrs, @@ -675,6 +683,7 @@ RELAY_REGISTER_OP("variance") .add_argument("mean", "Tensor", "The mean tensor.") .add_type_rel("Variance", VarianceRel) .set_attr("FTVMCompute", VarianceCompute) + .set_attr("FInferCorrectLayout", ReduceInferCorrectLayout) .set_attr("TOpPattern", kCommReduce); } // namespace relay diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py index 2b7e3e9eb3a9f..a1965aa2d0c56 100644 --- a/tests/python/relay/test_pass_convert_op_layout.py +++ b/tests/python/relay/test_pass_convert_op_layout.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """Test alter op layout pass""" +import pytest + import tvm from tvm import te @@ -1925,37 +1927,49 @@ def infer_correct_layout_relu(attrs, new_in_layouts, old_in_layouts, old_in_type assert test_infer_correct_layout_flag == True +def test_reduce_op_convert_layout(): + for reduce_op in [relay.argmax, relay.mean, relay.max]: + + def before(): + x = relay.var("x", shape=(1, 64, 56, 56)) + weight = relay.var("weight", shape=(64, 64, 3, 3)) + y = relay.nn.conv2d( + x, + weight, + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NCHW", + kernel_layout="OIHW", + ) + y = reduce_op(y, axis=[2, 3]) + y = relay.Function([x, weight], y) + return y + + def expected(): + x = relay.var("x", shape=(1, 64, 56, 56)) + weight = relay.var("weight", shape=(64, 64, 3, 3)) + x = relay.layout_transform(x, "NCHW", "NHWC") + weight = relay.layout_transform(weight, "OIHW", "HWIO") + y = relay.nn.conv2d( + x, + weight, + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NHWC", + kernel_layout="HWIO", + ) + y = reduce_op(y, axis=[1, 2]) + y = relay.Function(relay.analysis.free_vars(y), y) + return y + + a = before() + a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NHWC", "default"]})) + b = run_opt_pass(expected(), transform.InferType()) + + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + + if __name__ == "__main__": - test_qnn_binary_no_convert_layout() - test_no_convert_layout() - test_conv_convert_layout() - test_conv_nhwc_convert_layout() - test_conv_bias_pool_convert_layout() - test_conv_concat_convert_layout() - test_dual_path_convert_layout() - test_bn_convert_layout() - test_slice_like_convert_layout() - test_transpose_convert_layout() - test_resnet_convert_layout() - test_scalar_convert_layout() - test_conv_bn_convert_layout() - test_qnn_conv_requantize_convert_layout() - test_qnn_conv_concat_convert_layout() - test_qnn_conv_add_convert_layout() - test_qnn_conv_nhwc_convert_layout() - test_conv_convert_kernel_layout() - test_conv_transpose_convert_layout() - test_conv_roi_align_convert_layout() - test_conv_roi_pool_convert_layout() - test_conv_strided_slice_convert_layout() - test_deformable_conv_bias_pool_convert_layout() - test_default_keyword() - test_different_ops_convert_layout() - test_no_desired_layout() - test_convert_with_config() - test_conv_squeeze_convert_layout() - test_conv_reduce_convert_layout() - test_conv_strided_slice_axes_convert_layout() - test_image_resize_convert_layout() - test_conv_image_resize_convert_layout() - test_infer_correct_layout() + pytest.main([__file__])