diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 19ca6129ecbe..1ff428ce333c 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2752,6 +2752,46 @@ Expr MakeSliceLike(Expr data, Expr shape_like, Array axes) { return Call(op, {data, shape_like}, Attrs(attrs), {}); } +Array> SliceLikeInferCorrectLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types) { + Array new_axes; + if (old_in_layouts.defined() && new_in_layouts.defined()) { + ICHECK_EQ(new_in_layouts.size(), 2); + ICHECK_EQ(new_in_layouts[0]->name, new_in_layouts[1]->name); + ICHECK_EQ(old_in_layouts.size(), 2); + ICHECK_EQ(old_in_layouts[0]->name, old_in_layouts[1]->name); + + auto old_layout = old_in_layouts[0]; + auto new_layout = new_in_layouts[0]; + + // Discard "const" qualifier. + auto* params = const_cast(attrs.as()); + ICHECK(params != nullptr); + + for (auto axis : params->axes) { + auto new_axis = new_layout.IndexOf(old_layout[axis->value]); + // Cannot find the target axis in the new layout. + if (new_axis == -1) { + new_axes.clear(); + break; + } + new_axes.push_back(new_axis); + } + if (!new_axes.empty()) { + params->axes = std::move(new_axes); + return Array>({{new_layout, new_layout}, {new_layout}}); + } + } + + if (old_in_layouts.defined()) { + ICHECK_EQ(old_in_layouts.size(), 2); + return {{old_in_layouts[0], old_in_layouts[1]}, {old_in_layouts[1]}}; + } + return Array>({{Layout::Undef(), Layout::Undef()}, {Layout::Undef()}}); +} + Array SliceLikeCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { const auto* param = attrs.as(); @@ -2801,6 +2841,7 @@ RELAY_REGISTER_OP("slice_like") .set_support_level(10) .add_type_rel("SliceLike", SliceLikeRel) .set_attr("FTVMCompute", SliceLikeCompute) + .set_attr("FInferCorrectLayout", SliceLikeInferCorrectLayout) .set_attr("TOpPattern", kInjective); // relay.layout_transform diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py index 6765d1f69b00..4c4bb9dee937 100644 --- a/tests/python/relay/test_pass_convert_op_layout.py +++ b/tests/python/relay/test_pass_convert_op_layout.py @@ -499,6 +499,75 @@ def before(): assert len(has_lt) == 1 +def test_slice_like_convert_layout(): + def verify_slice_like(after, expected_axes): + # Verify if the slice_like after the convert layout has the expected axes. + has_expected = list() + checker = lambda x: has_expected.append( + isinstance(x, tvm.relay.expr.Call) + and x.op.name == "slice_like" + and str(x.attrs.axes) == str(expected_axes) + ) + relay.analysis.post_order_visit(after, checker) + assert any(has_expected) + + def func_nhwc(): + x = relay.var("x", shape=(1, 56, 56, 64)) + weight1 = relay.var("weight1", shape=(3, 3, 64, 32)) + y = relay.nn.conv2d( + x, + weight1, + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NHWC", + kernel_layout="HWIO", + ) + out = relay.slice_like(y, y, axes=[1, 2]) + return relay.Function(analysis.free_vars(out), out) + + after = run_opt_pass(func_nhwc(), transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]})) + verify_slice_like(after, [2, 3]) + + def func_nchw(): + x = relay.var("x", shape=(1, 64, 56, 56)) + weight1 = relay.var("weight1", shape=(32, 64, 3, 3)) + y = relay.nn.conv2d( + x, + weight1, + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NCHW", + kernel_layout="OIHW", + ) + out = relay.slice_like(y, y, axes=[2, 3]) + return relay.Function(analysis.free_vars(out), out) + + after = run_opt_pass(func_nchw(), transform.ConvertLayout({"nn.conv2d": ["NHWC", "default"]})) + verify_slice_like(after, [1, 2]) + + def func_vars(): + x = relay.var("x", shape=(1, 56, 56, 64)) + weight1 = relay.var("weight1", shape=(3, 3, 64, 32)) + y = relay.nn.conv2d( + x, + weight1, + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NHWC", + kernel_layout="HWIO", + ) + # z has no layout information so convert layout won't happen. + z = relay.var("y", shape=(1, 56, 56, 32)) + out = relay.slice_like(y, z, axes=[1, 2]) + return relay.Function(analysis.free_vars(out), out) + + after = run_opt_pass(func_vars(), transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]})) + verify_slice_like(after, [1, 2]) + + def test_resnet_convert_layout(): def before(): x = relay.var("x", shape=(1, 56, 56, 64)) @@ -1412,6 +1481,7 @@ def expected(): test_conv_concat_convert_layout() test_dual_path_convert_layout() test_bn_convert_layout() + test_slice_like_convert_layout() test_resnet_convert_layout() test_scalar_convert_layout() test_conv_bn_convert_layout()