From adfd2fc15422090b3ece44a9d2368f393ba3ba0a Mon Sep 17 00:00:00 2001 From: nicolasb0 <48060924+nicolasb0@users.noreply.github.com> Date: Mon, 15 Jul 2024 04:29:33 +0200 Subject: [PATCH] fix(torch-frontend): fix test_torch_max_pool2d (#28790) Co-authored-by: Sam-Armstrong --- .../backends/paddle/experimental/layers.py | 41 +++++++++++++---- .../torch/nn/functional/pooling_functions.py | 46 ++++++++++++++++--- ivy/functional/ivy/experimental/layers.py | 8 +++- ivy/functional/ivy/general.py | 9 ++-- .../test_functional/test_pooling_functions.py | 4 +- 5 files changed, 87 insertions(+), 21 deletions(-) diff --git a/ivy/functional/backends/paddle/experimental/layers.py b/ivy/functional/backends/paddle/experimental/layers.py index 2464dffbc17f6..e2240922838a6 100644 --- a/ivy/functional/backends/paddle/experimental/layers.py +++ b/ivy/functional/backends/paddle/experimental/layers.py @@ -1,6 +1,8 @@ # global +import math from typing import Optional, Union, Tuple, List, Literal, Sequence, Callable import paddle +from ivy.functional.ivy.experimental.layers import _padding_ceil_mode from ivy.functional.ivy.layers import ( _handle_padding, _depth_max_pooling_helper, @@ -150,14 +152,37 @@ def max_pool2d( " backend" ) - padding = ( - [item for sublist in padding for item in sublist] - if not isinstance(padding, str) - else padding - ) # paddle's expected format - res = paddle.nn.functional.max_pool2d( - x, kernel, strides, padding=padding, ceil_mode=ceil_mode - ) + x_shape = list(x.shape[2:]) + if not depth_pooling: + new_kernel = [ + kernel[i] + (kernel[i] - 1) * (dilation[i] - 1) for i in range(dims) + ] + if isinstance(padding, str): + pad_h = _handle_padding(x_shape[0], strides[0], new_kernel[0], padding) + pad_w = _handle_padding(x_shape[1], strides[1], new_kernel[1], padding) + padding = [ + (pad_h // 2, pad_h - pad_h // 2), + (pad_w // 2, pad_w - pad_w // 2), + ] + + if ceil_mode: + for i in range(dims): + padding[i] = _padding_ceil_mode( + x_shape[i], new_kernel[i], padding[i], strides[i] + ) + # paddle pad takes width padding first, then height padding + padding = (padding[1], padding[0]) + pad_list = [item for sublist in padding for item in sublist] + x = paddle.nn.functional.pad(x, pad_list, value=-math.inf) + else: + if isinstance(padding, list) and any( + [item != 0 for sublist in padding for item in sublist] + ): + raise NotImplementedError( + "Nonzero explicit padding is not supported for depthwise max pooling" + ) + + res = paddle.nn.functional.max_pool2d(x, kernel, strides, padding="VALID") if depth_pooling: res = paddle.transpose(res, perm=[0, 2, 1, 3]) diff --git a/ivy/functional/frontends/torch/nn/functional/pooling_functions.py b/ivy/functional/frontends/torch/nn/functional/pooling_functions.py index c442c55afc917..6555dcf1eaf7f 100644 --- a/ivy/functional/frontends/torch/nn/functional/pooling_functions.py +++ b/ivy/functional/frontends/torch/nn/functional/pooling_functions.py @@ -8,6 +8,7 @@ from ivy.functional.frontends.torch.func_wrapper import ( to_ivy_arrays_and_back, ) +from ivy.functional.ivy.experimental.layers import _padding_ceil_mode @with_unsupported_dtypes( @@ -291,6 +292,29 @@ def max_pool2d( if isinstance(stride, (list, tuple)) and len(stride) == 1: stride = stride[0] + DIMS = 2 + x_shape = list(input.shape[2:]) + new_kernel = [ + kernel_size[i] + (kernel_size[i] - 1) * (dilation[i] - 1) for i in range(DIMS) + ] + + if isinstance(padding, int): + padding = [(padding,) * 2] * DIMS + elif isinstance(padding, (list, tuple)) and len(padding) == DIMS: + padding = [(padding[i],) * 2 for i in range(DIMS)] + + if isinstance(stride, int): + stride = (stride,) * DIMS + + if ceil_mode: + for i in range(DIMS): + padding[i] = _padding_ceil_mode( + x_shape[i], new_kernel[i], padding[i], stride[i] + ) + # torch pad takes width padding first, then height padding + padding = (padding[1], padding[0]) + pad_array = ivy.flatten(padding) + in_shape = input.shape H = in_shape[-2] W = in_shape[-1] @@ -300,20 +324,28 @@ def max_pool2d( # for each position in the sliding window input_indices = torch_frontend.arange(0, n_indices, dtype=torch_frontend.int64) input_indices = input_indices.reshape((1, 1, H, W)) + + # find the indices of the max value for each position of the sliding window + input = torch_frontend.nn.functional.pad( + input, + pad_array, + value=float("-inf"), + ) + + input_indices = torch_frontend.nn.functional.pad( + input_indices, + pad_array, + value=0, + ) + unfolded_indices = torch_frontend.nn.functional.unfold( input_indices, kernel_size=kernel_size, - padding=padding, + padding=0, dilation=dilation, stride=stride, ).permute((0, 2, 1))[0] - # find the indices of the max value for each position of the sliding window - input = torch_frontend.nn.functional.pad( - input, - [padding] * 4 if isinstance(padding, int) else padding * 2, - value=float("-inf"), - ) unfolded_values = torch_frontend.nn.functional.unfold( input, kernel_size=kernel_size, padding=0, dilation=dilation, stride=stride ) diff --git a/ivy/functional/ivy/experimental/layers.py b/ivy/functional/ivy/experimental/layers.py index d70889c3745c9..c78b823779337 100644 --- a/ivy/functional/ivy/experimental/layers.py +++ b/ivy/functional/ivy/experimental/layers.py @@ -2017,15 +2017,19 @@ def _output_ceil_shape(w, f, p, s): def _padding_ceil_mode(w, f, p, s, return_added_padding=False): - remaining_pixels = (w - f + p[0]) % s + remaining_pixels = (w - f + sum(p)) % s added_padding = 0 + # if the additional pixels potentially captured thanks to ceil mode + # are all in the padding then no padding is added + if remaining_pixels <= p[1] and s + p[1] - remaining_pixels >= f: + return (p, added_padding) if return_added_padding else p if s > 1 and remaining_pixels != 0 and f > 1: input_size = w + sum(p) # making sure that the remaining pixels are supposed # to be covered by the window # they won't be covered if stride is big enough to skip them if input_size - remaining_pixels - (f - 1) + s > input_size: - return p + return (p, added_padding) if return_added_padding else p output_shape = _output_ceil_shape( w, f, diff --git a/ivy/functional/ivy/general.py b/ivy/functional/ivy/general.py index d312ecdb72daf..198ce630eb909 100644 --- a/ivy/functional/ivy/general.py +++ b/ivy/functional/ivy/general.py @@ -2879,15 +2879,18 @@ def set_item( ivy.array([[ 0, -1, 20], [10, 10, 10]]) """ - # TODO: we may be able to remove this logic by instead tracing _parse_query as a node in the graph?? + # TODO: we may be able to remove this logic by instead tracing _parse_query + # as a node in the graph?? if isinstance(query, (list, tuple)) and any( [q is Ellipsis or (isinstance(q, slice) and q.stop is None) for q in query] ): # use numpy for item setting when an ellipsis or unbounded slice is present, # as they would otherwise cause static dim sizes to be traced into the graph # NOTE: this does however cause tf.function to be incompatible - np_array = x.numpy() - np_array[query] = np.asarray(val) + x_stop_gradient = ivy.stop_gradient(x, preserve_type=False) + np_array = x_stop_gradient.numpy() + val_stop_gradient = ivy.stop_gradient(val, preserve_type=False) + np_array[query] = np.asarray(val_stop_gradient) return ivy.array(np_array) if copy: diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_pooling_functions.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_pooling_functions.py index 665e54d9156c5..1f55c74e0d547 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_pooling_functions.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_pooling_functions.py @@ -1,5 +1,5 @@ # global -from hypothesis import strategies as st +from hypothesis import assume, strategies as st # local import ivy_tests.test_ivy.helpers as helpers @@ -495,6 +495,8 @@ def test_torch_max_pool2d( dtype, x, kernel, stride, padding, dilation = x_k_s_p if not isinstance(padding, int): padding = [pad[0] for pad in padding] + # TODO: Remove this once the paddle backend supports dilation + assume(not (backend_fw == "paddle" and max(list(dilation)) > 1)) helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw,