diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index b34e6c723645..b656e738eadc 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -268,6 +268,13 @@ def _impl(inputs, attr, params, mod): pad_h = _get_pad_pair(in_w, kernel_w, stride_w) attr["padding"] = [pad_v[0], pad_h[0], pad_v[1], pad_h[1]] + elif attr["padding"] == "EXPLICIT": + paddings = attr["explicit_paddings"] + assert len(paddings) == 8 + if flip_layout or attr["data_format"] == "NHWC": + attr["padding"] = [paddings[2], paddings[4], paddings[3], paddings[5]] + else: + attr["padding"] = [paddings[4], paddings[6], paddings[5], paddings[7]] else: msg = 'Value {} in attribute "padding" of operator Pooling is ' "not valid." raise tvm.error.OpAttributeInvalid(msg.format(attr["padding"])) @@ -278,7 +285,7 @@ def _impl(inputs, attr, params, mod): out = AttrCvt( op_name=_dimension_picker(name), transforms={"kernel_shape": "pool_size", "data_format": "layout"}, - ignores=["ksize"], + ignores=["ksize", "explicit_paddings"], extras={"ceil_mode": False}, custom_check=_dimension_constraint(), )(inputs, attr) @@ -418,6 +425,13 @@ def _impl(inputs, attr, params, mod): pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w) attr["padding"] = [pad_v[0], pad_h[0], pad_v[1], pad_h[1]] + elif attr["padding"] == "EXPLICIT": + paddings = attr["explicit_paddings"] + assert len(paddings) == 8 + if flip_layout or attr["data_format"] == "NHWC": + attr["padding"] = [paddings[2], paddings[4], paddings[3], paddings[5]] + else: + attr["padding"] = [paddings[4], paddings[6], paddings[5], paddings[7]] else: msg = 'Value {} in attribute "padding" of operator Conv is not ' "valid." raise tvm.error.OpAttributeInvalid(msg.format(attr["padding"])) @@ -626,7 +640,27 @@ def _impl(inputs, attr, params, mod): pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w) attr["padding"] = [pad_d[0], pad_v[0], pad_h[0], pad_d[1], pad_v[1], pad_h[1]] - + elif attr["padding"] == "EXPLICIT": + paddings = attr["explicit_paddings"] + assert len(paddings) == 10 + if flip_layout or attr["data_format"] == "NDHWC": + attr["padding"] = [ + paddings[2], + paddings[4], + paddings[6], + paddings[3], + paddings[5], + paddings[7], + ] + else: + attr["padding"] = [ + paddings[4], + paddings[6], + paddings[8], + paddings[5], + paddings[7], + paddings[9], + ] else: msg = 'Value {} in attribute "padding" of operator Conv is not ' "valid." raise tvm.error.OpAttributeInvalid(msg.format(attr["padding"])) @@ -1416,9 +1450,9 @@ def _squeeze(): def _impl(inputs, attr, params, mod): if len(attr["squeeze_dims"]) == 0: attr["squeeze_dims"] = None - return AttrCvt(op_name="squeeze", transforms={"squeeze_dims": "axis"}, ignores=["T"])( - inputs, attr - ) + return AttrCvt( + op_name="squeeze", transforms={"squeeze_dims": "axis"}, ignores=["T", "_cloned"] + )(inputs, attr) return _impl diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 34ee0f3528ae..1bd286040934 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -414,6 +414,16 @@ def test_forward_pooling(): pooling_type=pool_type, dilation_rate=[2], ) + # Explicit padding + if package_version.parse(tf.VERSION) >= package_version.parse("2.4.1"): + _test_pooling( + input_shape=[2, 9, 10, 2], + window_shape=[4, 4], + padding=[[0, 0], [0, 1], [2, 3], [0, 0]], + pooling_type="MAX", + dilation_rate=[1, 1], + strides=[1, 1], + ) ####################################################################### @@ -830,6 +840,36 @@ def test_forward_convolution(): [4, 8, 8, 176], add_shapes_to_graph_def=False, ) + # Explicit padding + if package_version.parse(tf.VERSION) >= package_version.parse("2.4.1"): + _test_convolution( + "conv", + [4, 8, 8, 16], + [1, 1, 16, 32], + [1, 1], + [1, 1], + [[0, 0], [2, 3], [0, 1], [0, 0]], + "NHWC", + ) + _test_convolution( + "depthwise", + [4, 8, 8, 16], + [1, 1, 16, 1], + [1, 1], + [1, 1], + [[0, 0], [2, 3], [0, 1], [0, 0]], + "NHWC", + ) + _test_convolution( + "conv_transpose", + [4, 8, 8, 32], + [3, 3, 176, 32], + [1, 1], + [2, 2], + [[0, 0], [1, 0], [1, 0], [0, 0]], + "NHWC", + [4, 16, 16, 176], + ) #######################################################################