fix(torch-frontend): fix test_torch_max_pool2d (#28790)
Co-authored-by: Sam-Armstrong <samuel_e_armstrong@yahoo.co.uk>
nicolasb0 and Sam-Armstrong authored Jul 15, 2024
1 parent 911b0f7 commit adfd2fc
Showing 5 changed files with 87 additions and 21 deletions.
41 changes: 33 additions & 8 deletions ivy/functional/backends/paddle/experimental/layers.py
@@ -1,6 +1,8 @@
# global
import math
from typing import Optional, Union, Tuple, List, Literal, Sequence, Callable
import paddle
from ivy.functional.ivy.experimental.layers import _padding_ceil_mode
from ivy.functional.ivy.layers import (
    _handle_padding,
    _depth_max_pooling_helper,
@@ -150,14 +152,37 @@ def max_pool2d(
" backend"
)

padding = (
[item for sublist in padding for item in sublist]
if not isinstance(padding, str)
else padding
) # paddle's expected format
res = paddle.nn.functional.max_pool2d(
x, kernel, strides, padding=padding, ceil_mode=ceil_mode
)
x_shape = list(x.shape[2:])
if not depth_pooling:
new_kernel = [
kernel[i] + (kernel[i] - 1) * (dilation[i] - 1) for i in range(dims)
]
if isinstance(padding, str):
pad_h = _handle_padding(x_shape[0], strides[0], new_kernel[0], padding)
pad_w = _handle_padding(x_shape[1], strides[1], new_kernel[1], padding)
padding = [
(pad_h // 2, pad_h - pad_h // 2),
(pad_w // 2, pad_w - pad_w // 2),
]

if ceil_mode:
for i in range(dims):
padding[i] = _padding_ceil_mode(
x_shape[i], new_kernel[i], padding[i], strides[i]
)
# paddle pad takes width padding first, then height padding
padding = (padding[1], padding[0])
pad_list = [item for sublist in padding for item in sublist]
x = paddle.nn.functional.pad(x, pad_list, value=-math.inf)
else:
if isinstance(padding, list) and any(
[item != 0 for sublist in padding for item in sublist]
):
raise NotImplementedError(
"Nonzero explicit padding is not supported for depthwise max pooling"
)

res = paddle.nn.functional.max_pool2d(x, kernel, strides, padding="VALID")

if depth_pooling:
res = paddle.transpose(res, perm=[0, 2, 1, 3])
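The change above stops forwarding padding and ceil_mode directly to paddle.nn.functional.max_pool2d. Instead it dilates the kernel, resolves string padding into explicit (before, after) pairs per dimension, widens those pairs when ceil_mode needs an extra window, pads the input with -inf, and then pools with padding="VALID". A minimal sketch of the "SAME"-padding arithmetic involved, assuming the standard TF-style formula (same_padding below is an illustrative helper, not ivy's _handle_padding):

def same_padding(size, stride, kernel):
    # TF-style "SAME" padding: enough total padding that ceil(size / stride)
    # windows fit over the padded input
    if size % stride == 0:
        return max(kernel - stride, 0)
    return max(kernel - (size % stride), 0)

# 10-pixel axis, 3-wide kernel dilated by 2, stride 2
size, stride, kernel, dilation = 10, 2, 3, 2
effective_kernel = kernel + (kernel - 1) * (dilation - 1)   # 5
total = same_padding(size, stride, effective_kernel)        # 3
before, after = total // 2, total - total // 2              # (1, 2), filled with -inf
print(effective_kernel, (before, after))                    # 5 (1, 2)

Note that the dilated kernel size only feeds the padding computation here; actual dilated pooling is still unsupported on this backend, which is why the test change at the bottom skips paddle whenever dilation > 1.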
46 changes: 39 additions & 7 deletions ivy/functional/frontends/torch/nn/functional/pooling_functions.py
@@ -8,6 +8,7 @@
from ivy.functional.frontends.torch.func_wrapper import (
    to_ivy_arrays_and_back,
)
from ivy.functional.ivy.experimental.layers import _padding_ceil_mode


@with_unsupported_dtypes(
@@ -291,6 +292,29 @@ def max_pool2d(
    if isinstance(stride, (list, tuple)) and len(stride) == 1:
        stride = stride[0]

    DIMS = 2
    x_shape = list(input.shape[2:])
    new_kernel = [
        kernel_size[i] + (kernel_size[i] - 1) * (dilation[i] - 1) for i in range(DIMS)
    ]

    if isinstance(padding, int):
        padding = [(padding,) * 2] * DIMS
    elif isinstance(padding, (list, tuple)) and len(padding) == DIMS:
        padding = [(padding[i],) * 2 for i in range(DIMS)]

    if isinstance(stride, int):
        stride = (stride,) * DIMS

    if ceil_mode:
        for i in range(DIMS):
            padding[i] = _padding_ceil_mode(
                x_shape[i], new_kernel[i], padding[i], stride[i]
            )
    # torch pad takes width padding first, then height padding
    padding = (padding[1], padding[0])
    pad_array = ivy.flatten(padding)

    in_shape = input.shape
    H = in_shape[-2]
    W = in_shape[-1]
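The block just added normalises torch-style padding (an int, or one value per spatial dimension) into explicit (before, after) pairs, optionally widens them for ceil_mode, and flattens them in the width-first order expected by the pad calls in the next hunk. Roughly, for the 2-D case:

DIMS = 2
padding = 1                                     # torch-style scalar padding
if isinstance(padding, int):
    padding = [(padding,) * 2] * DIMS           # [(1, 1), (1, 1)] as (before, after) per dim
elif isinstance(padding, (list, tuple)) and len(padding) == DIMS:
    padding = [(padding[i],) * 2 for i in range(DIMS)]

padding = (padding[1], padding[0])              # width pair first, then height pair
pad_array = [p for pair in padding for p in pair]
print(pad_array)                                # [1, 1, 1, 1]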
@@ -300,20 +324,28 @@
    # for each position in the sliding window
    input_indices = torch_frontend.arange(0, n_indices, dtype=torch_frontend.int64)
    input_indices = input_indices.reshape((1, 1, H, W))

    # find the indices of the max value for each position of the sliding window
    input = torch_frontend.nn.functional.pad(
        input,
        pad_array,
        value=float("-inf"),
    )

    input_indices = torch_frontend.nn.functional.pad(
        input_indices,
        pad_array,
        value=0,
    )

    unfolded_indices = torch_frontend.nn.functional.unfold(
        input_indices,
        kernel_size=kernel_size,
        padding=padding,
        padding=0,
        dilation=dilation,
        stride=stride,
    ).permute((0, 2, 1))[0]

    # find the indices of the max value for each position of the sliding window
    input = torch_frontend.nn.functional.pad(
        input,
        [padding] * 4 if isinstance(padding, int) else padding * 2,
        value=float("-inf"),
    )
    unfolded_values = torch_frontend.nn.functional.unfold(
        input, kernel_size=kernel_size, padding=0, dilation=dilation, stride=stride
    )
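This hunk pads the values with -inf and the index map with 0 using the same pad_array, then unfolds both with padding=0, so every sliding window sees value and index entries at matching positions and the argmax over a value window can be translated back into an input index. A self-contained sketch of the idea in plain PyTorch (square kernel and stride, no dilation or padding; max_pool2d_indices_sketch is an illustrative helper, not the frontend implementation):

import torch
import torch.nn.functional as F

def max_pool2d_indices_sketch(x, kernel, stride):
    # unfold both the values and a map of flat input indices, then use the
    # per-window argmax of the values to pick the matching index
    n, c, h, w = x.shape
    idx = torch.arange(h * w, dtype=x.dtype).reshape(1, 1, h, w).expand(n, c, h, w)
    cols_v = F.unfold(x, kernel_size=kernel, stride=stride)               # (n, c*k*k, L)
    cols_i = F.unfold(idx.contiguous(), kernel_size=kernel, stride=stride)
    k2 = kernel * kernel
    cols_v = cols_v.reshape(n, c, k2, -1)
    cols_i = cols_i.reshape(n, c, k2, -1)
    arg = cols_v.argmax(dim=2, keepdim=True)
    out_h = (h - kernel) // stride + 1
    out_w = (w - kernel) // stride + 1
    values = cols_v.gather(2, arg).reshape(n, c, out_h, out_w)
    indices = cols_i.gather(2, arg).long().reshape(n, c, out_h, out_w)
    return values, indices

x = torch.randn(1, 1, 4, 4)
vals, idxs = max_pool2d_indices_sketch(x, kernel=2, stride=2)
ref_vals, ref_idxs = F.max_pool2d(x, 2, 2, return_indices=True)
print(torch.equal(vals, ref_vals), torch.equal(idxs, ref_idxs))           # True True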
8 changes: 6 additions & 2 deletions ivy/functional/ivy/experimental/layers.py
@@ -2017,15 +2017,19 @@ def _output_ceil_shape(w, f, p, s):


def _padding_ceil_mode(w, f, p, s, return_added_padding=False):
    remaining_pixels = (w - f + p[0]) % s
    remaining_pixels = (w - f + sum(p)) % s
    added_padding = 0
    # if the additional pixels potentially captured thanks to ceil mode
    # are all in the padding then no padding is added
    if remaining_pixels <= p[1] and s + p[1] - remaining_pixels >= f:
        return (p, added_padding) if return_added_padding else p
    if s > 1 and remaining_pixels != 0 and f > 1:
        input_size = w + sum(p)
        # making sure that the remaining pixels are supposed
        # to be covered by the window
        # they won't be covered if stride is big enough to skip them
        if input_size - remaining_pixels - (f - 1) + s > input_size:
            return p
            return (p, added_padding) if return_added_padding else p
    output_shape = _output_ceil_shape(
        w,
        f,
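Using sum(p) rather than p[0] makes the leftover-pixel count account for padding on both sides. The point of this helper is that ceil_mode can add one extra output position that floor division would drop, and that extra window only becomes valid if the padding is widened; a small worked example with the standard 1-D pooling output-length formula:

import math

def pooled_len(w, f, pad_total, s, ceil_mode):
    # standard pooling output length; ceil vs floor is the only difference
    length = (w + pad_total - f) / s + 1
    return math.ceil(length) if ceil_mode else math.floor(length)

w, f, s, p = 5, 2, 2, (0, 0)
remaining = (w - f + sum(p)) % s                       # (5 - 2 + 0) % 2 == 1 leftover pixel
print(pooled_len(w, f, sum(p), s, ceil_mode=False))    # 2 windows
print(pooled_len(w, f, sum(p), s, ceil_mode=True))     # 3 windows, so one extra pad pixel is needed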
9 changes: 6 additions & 3 deletions ivy/functional/ivy/general.py
@@ -2879,15 +2879,18 @@ def set_item(
    ivy.array([[ 0, -1, 20],
               [10, 10, 10]])
    """
    # TODO: we may be able to remove this logic by instead tracing _parse_query as a node in the graph??
    # TODO: we may be able to remove this logic by instead tracing _parse_query
    # as a node in the graph??
    if isinstance(query, (list, tuple)) and any(
        [q is Ellipsis or (isinstance(q, slice) and q.stop is None) for q in query]
    ):
        # use numpy for item setting when an ellipsis or unbounded slice is present,
        # as they would otherwise cause static dim sizes to be traced into the graph
        # NOTE: this does however cause tf.function to be incompatible
        np_array = x.numpy()
        np_array[query] = np.asarray(val)
        x_stop_gradient = ivy.stop_gradient(x, preserve_type=False)
        np_array = x_stop_gradient.numpy()
        val_stop_gradient = ivy.stop_gradient(val, preserve_type=False)
        np_array[query] = np.asarray(val_stop_gradient)
        return ivy.array(np_array)

    if copy:
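The set_item change routes both x and val through ivy.stop_gradient before the numpy conversion, presumably because arrays still attached to an autograd graph cannot be converted directly. A minimal illustration of the failure mode this sidesteps, using plain torch (detach() being roughly what stop_gradient amounts to on the torch backend):

import torch

x = torch.zeros(2, 3, requires_grad=True)
# x.numpy() raises: "Can't call numpy() on Tensor that requires grad ..."
np_array = x.detach().numpy()      # fine once the tensor is detached from the graph
np_array[0, 1] = -1.0
print(np_array)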
@@ -1,5 +1,5 @@
# global
from hypothesis import strategies as st
from hypothesis import assume, strategies as st

# local
import ivy_tests.test_ivy.helpers as helpers
@@ -495,6 +495,8 @@ def test_torch_max_pool2d(
    dtype, x, kernel, stride, padding, dilation = x_k_s_p
    if not isinstance(padding, int):
        padding = [pad[0] for pad in padding]
    # TODO: Remove this once the paddle backend supports dilation
    assume(not (backend_fw == "paddle" and max(list(dilation)) > 1))
    helpers.test_frontend_function(
        input_dtypes=dtype,
        backend_to_test=backend_fw,
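The test now pairs the generated dilation values with hypothesis's assume, so paddle-plus-dilation combinations are discarded rather than reported as failures until the backend supports dilation. A minimal illustration of the pattern, with a made-up test body:

from hypothesis import assume, given, strategies as st

@given(
    backend=st.sampled_from(["torch", "paddle"]),
    dilation=st.integers(min_value=1, max_value=3),
)
def test_example(backend, dilation):
    # assume() silently rejects this generated case and asks hypothesis for
    # a different one instead of recording a failure
    assume(not (backend == "paddle" and dilation > 1))
    assert backend != "paddle" or dilation == 1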
