Skip to content

Commit

Permalink
[torch_xla2] Fix max_pool2d_with_indices lowering (#7905)
Browse files Browse the repository at this point in the history
  • Loading branch information
chunnienc authored Aug 23, 2024
1 parent 7cbedca commit 64a1b7f
Showing 1 changed file with 49 additions and 37 deletions.
86 changes: 49 additions & 37 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,6 +799,47 @@ def _aten_cat(tensors, dims=0):
return jnp.concatenate(tensors, dims)


def _ceil_mode_padding(
padding: list[int],
input_shape: list[int],
kernel_size: list[int],
stride: list[int],
ceil_mode: bool,
):
"""Creates low and high padding specification for the given padding (which is symmetric) and ceil mode.
Additional high padding could be required when ceil mode is set.
"""
ceil_mode_padding = []
for i in range(len(padding)):
left_padding = padding[i]
right_padding = left_padding

input_size = input_shape[2 + i]
output_size_rem = (input_size + 2 * left_padding - kernel_size[i]) % stride[
i
]
if ceil_mode and output_size_rem != 0:
extra_padding = stride[i] - output_size_rem
new_output_size = (
input_size
+ left_padding
+ right_padding
+ extra_padding
- kernel_size[i]
+ stride[i]
- 1
) // stride[i] + 1
# Ensure that the last pooling starts inside the image.
size_to_compare = input_size + left_padding

if (new_output_size - 1) * stride[i] < size_to_compare:
right_padding += extra_padding

ceil_mode_padding.append((left_padding, right_padding))
return ceil_mode_padding


@op(torch.ops.aten.max_pool2d_with_indices)
@op(torch.ops.aten.max_pool3d_with_indices)
def _aten_max_pool2d_with_indices(
Expand All @@ -808,9 +849,14 @@ def _aten_max_pool2d_with_indices(
kernel_size = tuple(kernel_size)
strides = tuple(strides)
if isinstance(padding, int):
padding = tuple((padding, padding) for _ in range(len(kernel_size)))
elif isinstance(padding, list):
padding = tuple((p, p) for p in padding)
padding = [padding for _ in range(len(kernel_size))]

input_shape = inputs.shape
if num_batch_dims == 0:
input_shape = [1, *input_shape]
padding = _ceil_mode_padding(
padding, input_shape, kernel_size, strides, ceil_mode
)

window_shape = kernel_size
num_batch_dims = inputs.ndim - (len(window_shape) + 1)
Expand Down Expand Up @@ -1402,40 +1448,6 @@ def adaptive_kernel_size(input_shape, output_shape):
return y


def _ceil_mode_padding(
padding: list[int],
input_shape: list[int],
kernel_size: list[int],
stride: list[int],
ceil_mode: bool,
):
"""Creates low and high padding specification for the given padding (which is symmetric) and ceil mode.
Additional high padding could be required when ceil mode is set.
"""
ceil_mode_padding = []
for i in range(len(padding)):
left_padding = padding[i]
right_padding = left_padding

input_size = input_shape[2 + i]
output_size_rem = (input_size + 2 * left_padding -
kernel_size[i]) % stride[i]
if ceil_mode and output_size_rem != 0:
extra_padding = stride[i] - output_size_rem
new_output_size = (input_size + left_padding + right_padding +
extra_padding - kernel_size[i] + stride[i] -
1) // stride[i] + 1
# Ensure that the last pooling starts inside the image.
size_to_compare = input_size + left_padding

if (new_output_size - 1) * stride[i] < size_to_compare:
right_padding += extra_padding

ceil_mode_padding.append((left_padding, right_padding))
return ceil_mode_padding


# aten.avg_pool2d
@op(torch.ops.aten.avg_pool2d)
@op(torch.ops.aten.avg_pool3d)
Expand Down

0 comments on commit 64a1b7f

Please sign in to comment.