From 961db913f16c5334d4ab391dd7452a58adb30a86 Mon Sep 17 00:00:00 2001
From: "Joe (Chien-Chun) Chou"
Date: Fri, 8 Oct 2021 18:02:42 -0700
Subject: [PATCH] [bug] pooling convert layout bug in pooling.cc and in
 test_pass_convert_op_layout.py

---
 src/relay/op/nn/pooling.cc                    |  3 +-
 .../relay/test_pass_convert_op_layout.py      | 31 +++++++++++--------
 2 files changed, 20 insertions(+), 14 deletions(-)

diff --git a/src/relay/op/nn/pooling.cc b/src/relay/op/nn/pooling.cc
index 0d40caa15052..5d6be2e74cb4 100644
--- a/src/relay/op/nn/pooling.cc
+++ b/src/relay/op/nn/pooling.cc
@@ -49,7 +49,8 @@ InferCorrectLayoutOutput PoolInferCorrectLayout(const Attrs& attrs,
   ICHECK(attrs_ptr);
   ObjectPtr<T> params = make_object<T>(*attrs_ptr);
 
-  if (new_in_layouts.defined()) {
+  // set to new_in_layouts[0].name() only when params->layout == ""
+  if ((params->layout == "") && new_in_layouts.defined()) {
     // Set the pool with the new layout.
     ICHECK_EQ(new_in_layouts.size(), 1);
     params->layout = new_in_layouts[0].name();
diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py
index 9b4d154360b2..a4e2a0f00e81 100644
--- a/tests/python/relay/test_pass_convert_op_layout.py
+++ b/tests/python/relay/test_pass_convert_op_layout.py
@@ -234,9 +234,9 @@ def expected():
         # a useless tuple, which will be eliminated
         y = relay.Tuple([y])[0]
         y = relay.nn.relu(y)
-        y = relay.nn.max_pool2d(y, pool_size=(2, 2))
-        y = relay.cast(y, "int32")
         y = relay.layout_transform(y, "NCHW", "NHWC")
+        y = relay.nn.max_pool2d(y, pool_size=(2, 2), layout="NHWC")
+        y = relay.cast(y, "int32")
         y = relay.nn.batch_flatten(y)
         y = relay.Function(analysis.free_vars(y), y)
         return y
@@ -245,7 +245,7 @@ def expected():
 
     a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}))
     b = run_opt_pass(expected(), transform.InferType())
-    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\n\n Expected = \n" + str(b)
 
 
 def test_conv_concat_convert_layout():
@@ -330,7 +330,7 @@ def before(N, CI, H, W, CO, KH, KW, layout):
         y = relay.Function(analysis.free_vars(y), y)
         return y
 
-    def expected(N, CI, H, W, CO, KH, KW, OH, OW, src_layout, dst_layout):
+    def expected(N, CI, H, W, CO, KH, KW, OH, OW, src_layout, dst_layout, max_pool_layout=None):
         layout_map = {"src": {}, "dst": {}}
         if src_layout == "NCHW":
             nchw = layout_map["src"]
@@ -386,11 +386,10 @@ def expected(N, CI, H, W, CO, KH, KW, OH, OW, src_layout, dst_layout):
         )
         y = relay.add(y, bias)
         y = relay.nn.relu(y)
-        y = relay.nn.max_pool2d(y, pool_size=(2, 2), layout=layout_map["dst"]["data_layout"])
+        if max_pool_layout != layout_map["dst"]["data_layout"]:
+            y = relay.layout_transform(y, layout_map["dst"]["data_layout"], max_pool_layout)
+        y = relay.nn.max_pool2d(y, pool_size=(2, 2), layout=max_pool_layout)
         y = relay.cast(y, "int32")
-        y = relay.layout_transform(
-            y, layout_map["dst"]["data_layout"], layout_map["src"]["data_layout"]
-        )
         y = relay.nn.batch_flatten(y)
         y = relay.Function(analysis.free_vars(y), y)
         return y
@@ -398,16 +397,22 @@ def expected(N, CI, H, W, CO, KH, KW, OH, OW, src_layout, dst_layout):
     # NHWC -> NCHW
     a = before(1, 3, 224, 224, 32, 3, 3, "NHWC")
     a = run_opt_pass(a, transform.ConvertLayout({"nn.deformable_conv2d": ["NCHW", "default"]}))
+    # - in the before() func, its last argument "NHWC" is also the layout of max_pool
     b = run_opt_pass(
-        expected(1, 3, 224, 224, 32, 3, 3, 222, 222, "NHWC", "NCHW"), transform.InferType()
+        # max_pool has its own layout argument
+        expected(1, 3, 224, 224, 32, 3, 3, 222, 222, "NHWC", "NCHW", max_pool_layout="NHWC"),
+        transform.InferType(),
     )
-    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\n\n Expected = \n" + str(b)
 
     # NCHW -> NHWC
     a = before(1, 3, 224, 224, 32, 3, 3, "NCHW")
     a = run_opt_pass(a, transform.ConvertLayout({"nn.deformable_conv2d": ["NHWC", "default"]}))
+    # - in the before() func, its last argument "NCHW" is also the layout of max_pool
     b = run_opt_pass(
-        expected(1, 3, 224, 224, 32, 3, 3, 222, 222, "NCHW", "NHWC"), transform.InferType()
+        # max_pool has its own layout argument
+        expected(1, 3, 224, 224, 32, 3, 3, 222, 222, "NCHW", "NHWC", max_pool_layout="NCHW"),
+        transform.InferType(),
     )
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
@@ -691,15 +696,15 @@ def expected():
         y2 = relay.nn.conv2d(x, weight2, channels=32, kernel_size=(1, 1))
         y2 = relay.nn.relu(y2)
         y = y + y2
-        y = relay.nn.global_max_pool2d(y)
         y = relay.layout_transform(y, "NCHW", "NHWC")
+        y = relay.nn.global_max_pool2d(y, layout="NHWC")
         return relay.Function(analysis.free_vars(y), y)
 
     a = before()
     a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}))
     b = run_opt_pass(expected(), transform.InferType())
 
-    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\n\n Expected = \n" + str(b)
 
 
 def test_scalar_convert_layout():