diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc
index 693589fecfb4..c9f14c91c7b1 100644
--- a/src/relay/op/tensor/reduce.cc
+++ b/src/relay/op/tensor/reduce.cc
@@ -116,13 +116,14 @@ Array<Integer> GetExcludeAxes(size_t indim, const Array<Integer>& inaxis) {
 }
 
 // Return the modified layout for AlterOpLayout pass.
+template <typename T>
 InferCorrectLayoutOutput ReduceInferCorrectLayout(const Attrs& attrs,
                                                   const Array<Layout>& new_in_layouts,
                                                   const Array<Layout>& old_in_layouts,
                                                   const Array<tvm::relay::Type>& old_in_types) {
-  const auto* attrs_ptr = attrs.as<ReduceAttrs>();
+  const auto* attrs_ptr = attrs.as<T>();
   ICHECK(attrs_ptr);
-  ObjectPtr<ReduceAttrs> params = make_object<ReduceAttrs>(*attrs_ptr);
+  ObjectPtr<T> params = make_object<T>(*attrs_ptr);
 
   // Get the reduce axes.
   Array<Array<IndexExpr>> old_in_shapes;
@@ -152,11 +153,14 @@ InferCorrectLayoutOutput ReduceInferCorrectLayout(const Attrs& attrs,
     for (auto iter_var : layout->axes) {
       const auto& layout_axis = LayoutAxis::Get(iter_var);
       const std::string& layout_dim = layout_axis.name();
-      if (old_r_dims.count(layout_dim)) {
-        new_r_axes.push_back(tvm::Integer(axis_index));
-      }
       // Collect only the primal axis.
       if (layout_axis.IsPrimal()) {
+        if (old_r_dims.count(layout_dim) && !params->exclude) {
+          new_r_axes.push_back(tvm::Integer(axis_index));
+        }
+        if (!old_r_dims.count(layout_dim) && params->exclude) {
+          new_r_axes.push_back(tvm::Integer(axis_index));
+        }
         if (!old_r_dims.count(layout_dim) || params->keepdims) {
           inferred_out_string += layout_dim;
         }
@@ -171,18 +175,24 @@ InferCorrectLayoutOutput ReduceInferCorrectLayout(const Attrs& attrs,
 
   std::string new_layout_string;
   Array<Integer> new_r_axes;
+  Array<Layout> new_input_layouts;
+
+  auto check_num_input_layouts = [](Array<Layout> in_layouts) {
+    // The second case is for variance op
+    ICHECK(in_layouts.size() == 1 || in_layouts.size() == 2);
+  };
 
   if (new_in_layouts.defined() && r_axes.size()) {
     // Adapt to new layout. The axis has to change. Record original reduce axes. Convert to the
     // modified layout axes.
-    ICHECK_EQ(new_in_layouts.size(), 1);
-    ICHECK_EQ(old_in_layouts.size(), 1);
+    check_num_input_layouts(new_in_layouts);
+    check_num_input_layouts(old_in_layouts);
 
     // Get inferred_in and inferred_out from new_in_layout.
     std::tie(inferred_in, inferred_out, new_r_axes) = infer(new_in_layouts[0]);
     params->axis = new_r_axes;
   } else if (old_in_layouts.defined()) {
-    ICHECK_EQ(old_in_layouts.size(), 1);
+    check_num_input_layouts(old_in_layouts);
 
     // If the new layout is undefined, get inferred_in and inferred_out from old_in_layout.
     if (old_in_layouts[0].defined()) {
@@ -190,7 +200,13 @@ InferCorrectLayoutOutput ReduceInferCorrectLayout(const Attrs& attrs,
     }
   }
 
-  return InferCorrectLayoutOutput({inferred_in}, {inferred_out}, Attrs(params));
+  new_input_layouts.push_back(inferred_in);
+
+  if (old_in_layouts.size() == 2) {
+    new_input_layouts.push_back(inferred_in);
+  }
+
+  return InferCorrectLayoutOutput(new_input_layouts, {inferred_out}, Attrs(params));
 }
 
 template <typename F>
@@ -389,6 +405,7 @@ values over a given axis.
     .set_support_level(4)
     .add_type_rel("ArgReduce", GenericReduceRel<ArgReduceAttrs>)
     .set_attr<FTVMCompute>("FTVMCompute", ArgMaxCompute)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ArgReduceAttrs>)
    .set_attr<TOpPattern>("TOpPattern", kCommReduce);
 
 Array<te::Tensor> ArgMinCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
@@ -405,6 +422,7 @@ values over a given axis.
     .set_support_level(4)
     .add_type_rel("ArgReduce", GenericReduceRel<ArgReduceAttrs>)
     .set_attr<FTVMCompute>("FTVMCompute", ArgMinCompute)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ArgReduceAttrs>)
     .set_attr<TOpPattern>("TOpPattern", kCommReduce);
 
 Array<te::Tensor> SumCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
@@ -433,7 +451,7 @@ Example::
     .set_attrs_type<ReduceAttrs>()
     .set_support_level(4)
     .add_type_rel("Reduce", ReduceRel)
-    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ReduceAttrs>)
     .set_attr<FTVMCompute>("FTVMCompute", SumCompute)
     .set_attr<TOpPattern>("TOpPattern", kCommReduce);
 
@@ -468,6 +486,7 @@ Example::
     .set_support_level(4)
     .add_type_rel("Reduce", ReduceRel)
     .set_attr<FTVMCompute>("FTVMCompute", AllCompute)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ReduceAttrs>)
     .set_attr<TOpPattern>("TOpPattern", kCommReduce);
 
 Array<te::Tensor> AnyCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
@@ -516,6 +535,7 @@ RELAY_REGISTER_REDUCE_OP("max")
     .set_support_level(4)
     .add_type_rel("Reduce", ReduceRel)
     .set_attr<FTVMCompute>("FTVMCompute", MaxCompute)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ReduceAttrs>)
     .set_attr<TOpPattern>("TOpPattern", kCommReduce);
 
 Array<te::Tensor> MinCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
@@ -531,6 +551,7 @@ RELAY_REGISTER_REDUCE_OP("min")
     .set_support_level(4)
     .add_type_rel("Reduce", ReduceRel)
     .set_attr<FTVMCompute>("FTVMCompute", MinCompute)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ReduceAttrs>)
     .set_attr<TOpPattern>("TOpPattern", kCommReduce);
 
 Array<te::Tensor> ProdCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
@@ -551,10 +572,10 @@ Example::
           [[1,4],[4,3],[5,2]],
           [[7,1],[7,2],[7,3]]]
 
-  mean(data, axis=1)
+  prod(data, axis=1)
   [35562240]
 
-  mean(data, axis=[1,2])
+  prod(data, axis=[1,2])
   [ 36 480 2058]
 
 )code" TVM_ADD_FILELINE)
@@ -562,6 +583,7 @@ Example::
     .set_support_level(4)
     .add_type_rel("Reduce", ReduceRel)
     .set_attr<FTVMCompute>("FTVMCompute", ProdCompute)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ReduceAttrs>)
     .set_attr<TOpPattern>("TOpPattern", kCommReduce);
 
 Array<te::Tensor> MeanCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
@@ -600,6 +622,7 @@ Example::
     .set_support_level(4)
     .add_type_rel("Reduce", ReduceRel)
     .set_attr<FTVMCompute>("FTVMCompute", MeanCompute)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ReduceAttrs>)
     .set_attr<TOpPattern>("TOpPattern", kCommReduce);
 
 bool VarianceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
@@ -675,6 +698,7 @@ RELAY_REGISTER_OP("variance")
     .add_argument("mean", "Tensor", "The mean tensor.")
     .add_type_rel("Variance", VarianceRel)
     .set_attr<FTVMCompute>("FTVMCompute", VarianceCompute)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<VarianceAttrs>)
     .set_attr<TOpPattern>("TOpPattern", kCommReduce);
 
 }  // namespace relay
diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py
index ef5824c957e8..3310b6b2ed69 100644
--- a/tests/python/relay/test_pass_alter_op_layout.py
+++ b/tests/python/relay/test_pass_alter_op_layout.py
@@ -486,8 +486,7 @@ def before():
         beta = relay.var("beta")
         y = relay.nn.batch_norm(y, gamma, beta, mean, var, axis=3)
         y = y[0]
-        y = relay.Function(analysis.free_vars(y), y)
-        return y
+        return relay.Function(analysis.free_vars(y), y)
 
     def alter_conv2d(attrs, inputs, tinfos, out_type):
         data, weight = inputs
@@ -509,9 +508,8 @@ def expected():
         bias = relay.layout_transform(bias, src_layout="NCHW", dst_layout="NCHW16c")
         add = relay.add(y, bias)
         y = relay.layout_transform(add, src_layout="NCHW16c", dst_layout="NCHW")
-        y = relay.layout_transform(y, src_layout="NCHW", dst_layout="NHWC")
-        mean = relay.mean(y, axis=3, exclude=True)
-        var = relay.variance(y, axis=3, exclude=True)
+        mean = relay.mean(y, axis=1, exclude=True)
+        var = relay.variance(y, axis=1, exclude=True)
         denom = relay.const(1.0) / relay.sqrt(var + relay.const(1e-05))
         gamma = relay.var("gamma", shape=(16,))
         denom = denom * gamma
diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py
index 2b7e3e9eb3a9..a1965aa2d0c5 100644
--- a/tests/python/relay/test_pass_convert_op_layout.py
+++ b/tests/python/relay/test_pass_convert_op_layout.py
@@ -15,6 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 """Test alter op layout pass"""
+import pytest
+
 import tvm
 from tvm import te
 
@@ -1925,37 +1927,49 @@ def infer_correct_layout_relu(attrs, new_in_layouts, old_in_layouts, old_in_types):
     assert test_infer_correct_layout_flag == True
 
 
+def test_reduce_op_convert_layout():
+    for reduce_op in [relay.argmax, relay.mean, relay.max]:
+
+        def before():
+            x = relay.var("x", shape=(1, 64, 56, 56))
+            weight = relay.var("weight", shape=(64, 64, 3, 3))
+            y = relay.nn.conv2d(
+                x,
+                weight,
+                channels=64,
+                kernel_size=(3, 3),
+                padding=(1, 1),
+                data_layout="NCHW",
+                kernel_layout="OIHW",
+            )
+            y = reduce_op(y, axis=[2, 3])
+            y = relay.Function([x, weight], y)
+            return y
+
+        def expected():
+            x = relay.var("x", shape=(1, 64, 56, 56))
+            weight = relay.var("weight", shape=(64, 64, 3, 3))
+            x = relay.layout_transform(x, "NCHW", "NHWC")
+            weight = relay.layout_transform(weight, "OIHW", "HWIO")
+            y = relay.nn.conv2d(
+                x,
+                weight,
+                channels=64,
+                kernel_size=(3, 3),
+                padding=(1, 1),
+                data_layout="NHWC",
+                kernel_layout="HWIO",
+            )
+            y = reduce_op(y, axis=[1, 2])
+            y = relay.Function(relay.analysis.free_vars(y), y)
+            return y
+
+        a = before()
+        a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NHWC", "default"]}))
+        b = run_opt_pass(expected(), transform.InferType())
+
+        assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
+
+
 if __name__ == "__main__":
-    test_qnn_binary_no_convert_layout()
-    test_no_convert_layout()
-    test_conv_convert_layout()
-    test_conv_nhwc_convert_layout()
-    test_conv_bias_pool_convert_layout()
-    test_conv_concat_convert_layout()
-    test_dual_path_convert_layout()
-    test_bn_convert_layout()
-    test_slice_like_convert_layout()
-    test_transpose_convert_layout()
-    test_resnet_convert_layout()
-    test_scalar_convert_layout()
-    test_conv_bn_convert_layout()
-    test_qnn_conv_requantize_convert_layout()
-    test_qnn_conv_concat_convert_layout()
-    test_qnn_conv_add_convert_layout()
-    test_qnn_conv_nhwc_convert_layout()
-    test_conv_convert_kernel_layout()
-    test_conv_transpose_convert_layout()
-    test_conv_roi_align_convert_layout()
-    test_conv_roi_pool_convert_layout()
-    test_conv_strided_slice_convert_layout()
-    test_deformable_conv_bias_pool_convert_layout()
-    test_default_keyword()
-    test_different_ops_convert_layout()
-    test_no_desired_layout()
-    test_convert_with_config()
-    test_conv_squeeze_convert_layout()
-    test_conv_reduce_convert_layout()
-    test_conv_strided_slice_axes_convert_layout()
-    test_image_resize_convert_layout()
-    test_conv_image_resize_convert_layout()
-    test_infer_correct_layout()
+    pytest.main([__file__])
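
Reviewer note on the `exclude` handling in `ReduceInferCorrectLayout`: `GetReduceAxes` expands `params->axis` into the set of dims actually reduced, and with `exclude=True` the remapped `params->axis` must again name the dims *kept out* of the reduction, which is exactly what the two new branches in the `infer` lambda do. The sketch below mirrors that rule in plain Python under simplifying assumptions: single-letter layouts only, so the `IsPrimal()` filter for packed sub-axes like `NCHW16c` is omitted, and `remap_reduce_axes` is an illustrative helper, not a TVM API.

```python
# Minimal sketch of the axis-remapping rule this patch adds to
# ReduceInferCorrectLayout. Assumes plain layouts such as "NCHW"/"NHWC"
# (no packed sub-axes), so every axis is treated as primal.


def remap_reduce_axes(old_layout, new_layout, axes, exclude):
    """Map reduce `axes` expressed in old_layout onto new_layout.

    With exclude=True, `axes` names the dims kept out of the reduction,
    so the result again lists the kept dims -- the second `if` branch
    added in the C++ change.
    """
    if exclude:
        # Dims actually reduced are everything except the listed axes.
        reduced = {d for i, d in enumerate(old_layout) if i not in axes}
    else:
        reduced = {old_layout[i] for i in axes}
    new_axes = []
    for axis_index, dim in enumerate(new_layout):
        if dim in reduced and not exclude:
            new_axes.append(axis_index)
        if dim not in reduced and exclude:
            new_axes.append(axis_index)
    return new_axes


# mean(y, axis=[2, 3]) on NCHW reduces H and W; in NHWC those sit at [1, 2],
# matching test_reduce_op_convert_layout above.
assert remap_reduce_axes("NCHW", "NHWC", [2, 3], exclude=False) == [1, 2]
# mean(y, axis=3, exclude=True) on NHWC keeps only C; in NCHW that is axis 1,
# matching the expected() change in test_pass_alter_op_layout.py.
assert remap_reduce_axes("NHWC", "NCHW", [3], exclude=True) == [1]
```

The same reasoning explains the `test_pass_alter_op_layout.py` change: once mean/variance stay in NCHW, `axis=3, exclude=True` becomes `axis=1, exclude=True`. The two-layout case in `check_num_input_layouts` exists because `variance` takes both `data` and `mean`, and the patch returns the same inferred layout for each input.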