Skip to content

Commit

Permalink
[Frontend][Tensorflow] Support explicit_paddings for TF 2.x (apache#7445
Browse files Browse the repository at this point in the history
)

* Ignore some TF2.0 attributes

* Support explicit padding for conv2d, max_pool, conv3d

* Remove conv3d explicit padding test since TF API doesn't allow it
  • Loading branch information
Trevor Morris authored and trevor-m committed Mar 2, 2021
1 parent ba66c50 commit ddaac21
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 5 deletions.
44 changes: 39 additions & 5 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,13 @@ def _impl(inputs, attr, params, mod):
pad_h = _get_pad_pair(in_w, kernel_w, stride_w)

attr["padding"] = [pad_v[0], pad_h[0], pad_v[1], pad_h[1]]
elif attr["padding"] == "EXPLICIT":
paddings = attr["explicit_paddings"]
assert len(paddings) == 8
if flip_layout or attr["data_format"] == "NHWC":
attr["padding"] = [paddings[2], paddings[4], paddings[3], paddings[5]]
else:
attr["padding"] = [paddings[4], paddings[6], paddings[5], paddings[7]]
else:
msg = 'Value {} in attribute "padding" of operator Pooling is ' "not valid."
raise tvm.error.OpAttributeInvalid(msg.format(attr["padding"]))
Expand All @@ -278,7 +285,7 @@ def _impl(inputs, attr, params, mod):
out = AttrCvt(
op_name=_dimension_picker(name),
transforms={"kernel_shape": "pool_size", "data_format": "layout"},
ignores=["ksize"],
ignores=["ksize", "explicit_paddings"],
extras={"ceil_mode": False},
custom_check=_dimension_constraint(),
)(inputs, attr)
Expand Down Expand Up @@ -418,6 +425,13 @@ def _impl(inputs, attr, params, mod):
pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w)

attr["padding"] = [pad_v[0], pad_h[0], pad_v[1], pad_h[1]]
elif attr["padding"] == "EXPLICIT":
paddings = attr["explicit_paddings"]
assert len(paddings) == 8
if flip_layout or attr["data_format"] == "NHWC":
attr["padding"] = [paddings[2], paddings[4], paddings[3], paddings[5]]
else:
attr["padding"] = [paddings[4], paddings[6], paddings[5], paddings[7]]
else:
msg = 'Value {} in attribute "padding" of operator Conv is not ' "valid."
raise tvm.error.OpAttributeInvalid(msg.format(attr["padding"]))
Expand Down Expand Up @@ -626,7 +640,27 @@ def _impl(inputs, attr, params, mod):
pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w)

attr["padding"] = [pad_d[0], pad_v[0], pad_h[0], pad_d[1], pad_v[1], pad_h[1]]

elif attr["padding"] == "EXPLICIT":
paddings = attr["explicit_paddings"]
assert len(paddings) == 10
if flip_layout or attr["data_format"] == "NDHWC":
attr["padding"] = [
paddings[2],
paddings[4],
paddings[6],
paddings[3],
paddings[5],
paddings[7],
]
else:
attr["padding"] = [
paddings[4],
paddings[6],
paddings[8],
paddings[5],
paddings[7],
paddings[9],
]
else:
msg = 'Value {} in attribute "padding" of operator Conv is not ' "valid."
raise tvm.error.OpAttributeInvalid(msg.format(attr["padding"]))
Expand Down Expand Up @@ -1445,9 +1479,9 @@ def _squeeze():
def _impl(inputs, attr, params, mod):
if len(attr["squeeze_dims"]) == 0:
attr["squeeze_dims"] = None
return AttrCvt(op_name="squeeze", transforms={"squeeze_dims": "axis"}, ignores=["T"])(
inputs, attr
)
return AttrCvt(
op_name="squeeze", transforms={"squeeze_dims": "axis"}, ignores=["T", "_cloned"]
)(inputs, attr)

return _impl

Expand Down
40 changes: 40 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,16 @@ def test_forward_pooling():
pooling_type=pool_type,
dilation_rate=[2],
)
# Explicit padding
if package_version.parse(tf.VERSION) >= package_version.parse("2.4.1"):
_test_pooling(
input_shape=[2, 9, 10, 2],
window_shape=[4, 4],
padding=[[0, 0], [0, 1], [2, 3], [0, 0]],
pooling_type="MAX",
dilation_rate=[1, 1],
strides=[1, 1],
)


#######################################################################
Expand Down Expand Up @@ -830,6 +840,36 @@ def test_forward_convolution():
[4, 8, 8, 176],
add_shapes_to_graph_def=False,
)
# Explicit padding
if package_version.parse(tf.VERSION) >= package_version.parse("2.4.1"):
_test_convolution(
"conv",
[4, 8, 8, 16],
[1, 1, 16, 32],
[1, 1],
[1, 1],
[[0, 0], [2, 3], [0, 1], [0, 0]],
"NHWC",
)
_test_convolution(
"depthwise",
[4, 8, 8, 16],
[1, 1, 16, 1],
[1, 1],
[1, 1],
[[0, 0], [2, 3], [0, 1], [0, 0]],
"NHWC",
)
_test_convolution(
"conv_transpose",
[4, 8, 8, 32],
[3, 3, 176, 32],
[1, 1],
[2, 2],
[[0, 0], [1, 0], [1, 0], [0, 0]],
"NHWC",
[4, 16, 16, 176],
)


#######################################################################
Expand Down

0 comments on commit ddaac21

Please sign in to comment.