diff --git a/python/tvm/relay/op/image/_image.py b/python/tvm/relay/op/image/_image.py index 5b7fd32add4c..2071a43f828b 100644 --- a/python/tvm/relay/op/image/_image.py +++ b/python/tvm/relay/op/image/_image.py @@ -26,6 +26,7 @@ from .. import op as reg from .. import strategy from ..op import OpPattern +from .image import resize # resize @@ -58,6 +59,36 @@ def compute_resize(attrs, inputs, out_type): reg.register_injective_schedule("image.resize") +@reg.register_convert_op_layout("image.resize") +def convert_image_resize(attrs, inputs, tinfos, desired_layouts): + """Convert Layout pass registration for image resize op. + + Parameters + ---------- + attrs : tvm.ir.Attrs + Attributes of current resize op + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + tinfos : list of types + List of input and output types + desired_layouts : list of layout strings + List of layouts defining our desired + layout for the data input. + + Returns + ------- + result : tvm.relay.Expr + The transformed expr + """ + + new_attrs = dict(attrs) + assert len(desired_layouts) == 1, "Only one desired layout is expected" + desired_layout = str(desired_layouts[0]) + assert desired_layout != "default", "Layout cannot be default" + new_attrs["layout"] = desired_layout + return resize(*inputs, **new_attrs) + + @script def _resize_shape_func(image_shape, size, batch_axis, height_axis, width_axis, channel_axis): out = output_tensor((4,), "int64") diff --git a/src/relay/op/image/resize.cc b/src/relay/op/image/resize.cc index 9c3d60198add..2c90d7b8a057 100644 --- a/src/relay/op/image/resize.cc +++ b/src/relay/op/image/resize.cc @@ -33,6 +33,31 @@ namespace relay { TVM_REGISTER_NODE_TYPE(ResizeAttrs); +template +Array > ResizeInferCorrectLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types) { + // NOTE: Discard "const" qualifier here. + T* params = const_cast(attrs.as()); + + if (new_in_layouts.defined()) { + ICHECK_EQ(new_in_layouts.size(), 1); + + Layout raw_layout(params->layout); + Layout new_layout = new_in_layouts[0]; + Layout old_layout = old_in_layouts[0]; + if (!new_layout.Equals(old_layout) && raw_layout.Equals(old_layout) && + new_layout->axes.size() == old_layout->axes.size()) { + // Follow input layout + params->layout = new_layout.name(); + } + } + + Layout inferred_layout(params->layout); + return Array >{{inferred_layout}, {inferred_layout}}; +} + bool ResizeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { ICHECK_EQ(types.size(), 2); @@ -102,6 +127,7 @@ RELAY_REGISTER_OP("image.resize") .add_argument("data", "Tensor", "The input tensor.") .set_support_level(5) .add_type_rel("Resize", ResizeRel) + .set_attr("FInferCorrectLayout", ResizeInferCorrectLayout) .set_attr("TOpPattern", kInjective); TVM_REGISTER_NODE_TYPE(Resize3dAttrs); diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py index 4710d50ea8e4..88590c946e88 100644 --- a/tests/python/relay/test_pass_convert_op_layout.py +++ b/tests/python/relay/test_pass_convert_op_layout.py @@ -1797,6 +1797,90 @@ def expected(): _test_conv_reduce_convert_layout2() +def test_image_resize_convert_layout(): + def _test_image_resize_convert_layout_nchw_to_nhwc(): + def before(): + x = relay.var("x", shape=(1, 2, 4, 4)) + y = relay.image.resize(x, (8, 8)) + y = relay.Function([x], y) + return y + + def expected(): + x = relay.var("x", shape=(1, 2, 4, 4)) + x = relay.layout_transform(x, "NCHW", "NHWC") + y = relay.image.resize(x, (8, 8), layout="NHWC") + y = relay.layout_transform(y, "NHWC", "NCHW") + y = relay.Function(relay.analysis.free_vars(y), y) + return y + + a = before() + a = run_opt_pass(a, transform.ConvertLayout({"image.resize": ["NHWC"]})) + b = run_opt_pass(expected(), transform.InferType()) + + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + + def _test_image_resize_convert_layout_nhwc_to_nchw(): + def before(): + x = relay.var("x", shape=(1, 4, 4, 2)) + y = relay.image.resize(x, (8, 8), layout="NHWC") + y = relay.Function([x], y) + return y + + def expected(): + x = relay.var("x", shape=(1, 4, 4, 2)) + x = relay.layout_transform(x, "NHWC", "NCHW") + y = relay.image.resize(x, (8, 8), layout="NCHW") + y = relay.layout_transform(y, "NCHW", "NHWC") + y = relay.Function(relay.analysis.free_vars(y), y) + return y + + a = before() + a = run_opt_pass(a, transform.ConvertLayout({"image.resize": ["NCHW"]})) + b = run_opt_pass(expected(), transform.InferType()) + + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + + _test_image_resize_convert_layout_nchw_to_nhwc() + _test_image_resize_convert_layout_nhwc_to_nchw() + + +def test_conv_image_resize_convert_layout(): + """Check that layout transforms are propagated through image resize.""" + + def before(): + x = relay.var("x", shape=(1, 56, 56, 64)) + weight = relay.var("weight", shape=(3, 3, 64, 64)) + y = relay.nn.conv2d( + x, + weight, + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NHWC", + kernel_layout="HWIO", + ) + y = relay.image.resize(y, (112, 112), layout="NHWC") + y = relay.Function(analysis.free_vars(y), y) + return y + + def expected(): + x = relay.var("x", shape=(1, 56, 56, 64)) + w = relay.var("weight", shape=(3, 3, 64, 64)) + x = relay.layout_transform(x, "NHWC", "NCHW") + w = relay.layout_transform(w, "HWIO", "OIHW") + y = relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1)) + y = relay.image.resize(y, (112, 112), layout="NCHW") + y = relay.layout_transform(y, "NCHW", "NHWC") + y = relay.Function(analysis.free_vars(y), y) + return y + + a = before() + a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]})) + b = run_opt_pass(expected(), transform.InferType()) + + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + + if __name__ == "__main__": test_qnn_binary_no_convert_layout() test_no_convert_layout() @@ -1828,3 +1912,5 @@ def expected(): test_conv_squeeze_convert_layout() test_conv_reduce_convert_layout() test_conv_strided_slice_axes_convert_layout() + test_image_resize_convert_layout() + test_conv_image_resize_convert_layout()