Skip to content

Commit

Permalink
Add pad to size utility.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 532121341
  • Loading branch information
pabloduque0 authored and PIXDev committed May 15, 2023
1 parent 0ff5593 commit 1cbc89c
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 1 deletion.
57 changes: 56 additions & 1 deletion dm_pix/_src/augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
"""

import functools
from typing import Callable, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Optional, Sequence, Tuple, Union

import chex
from dm_pix._src import color_conversion
Expand Down Expand Up @@ -291,6 +291,61 @@ def center_crop(
)


def pad_to_size(
image: chex.Array,
target_height: int,
target_width: int,
*,
mode: str = "constant",
pad_kwargs: Optional[Any] = None,
channel_axis: int = -1,
) -> chex.Array:
"""Pads an image to the given size keeping the original image centered.
For different padding methods and kwargs please see:
https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.pad.html
In case of odd size difference along any dimension the bottom/right side gets
the extra padding pixel.
Target size can be smaller than original size which results in a no-op for
such dimension.
Args:
image: a JAX array representing an image. Assumes that the image is either
...HWC or ...CHW.
target_height: target height to pad the image to.
target_width: target width to pad the image to.
mode: Mode for padding the images, see jax.numpy.pad for details. Default is
`constant`.
pad_kwargs: Keyword arguments to pass jax.numpy.pad, see documentation for
options.
channel_axis: the index of the channel axis.
Returns:
The padded image(s).
"""
chex.assert_rank(image, {3, 4})
batch, height, width, _ = _get_dimension_values(
image=image, channel_axis=channel_axis
)
delta_width = max(target_width - width, 0)
delta_height = max(target_height - height, 0)
if delta_width == 0 and delta_height == 0:
return image

left = delta_width // 2
right = max(target_width - (left + width), 0)
top = delta_height // 2
bottom = max(target_height - (top + height), 0)

pad_width = ((top, bottom), (left, right), (0, 0))
if batch:
pad_width = ((0, 0), *pad_width)

return jnp.pad(image, pad_width=pad_width, mode=mode, **pad_kwargs or {})


def flip_left_right(
image: chex.Array,
*,
Expand Down
47 changes: 47 additions & 0 deletions dm_pix/_src/augment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,29 @@ def test_center_crop(self, images_list, height, width):
)
self._test_fn(images_list, jax_fn=center_crop, reference_fn=reference)

@parameterized.product(
images_list=(_RAND_FLOATS_IN_RANGE, _RAND_FLOATS_OUT_OF_RANGE),
target_height=(156, 131, 200, 251),
target_width=(156, 111, 200, 251),
)
def test_pad_to_size(self, images_list, target_height, target_width):
pad_fn = functools.partial(
augment.pad_to_size,
target_height=target_height,
target_width=target_width,
mode="constant",
pad_kwargs={"constant_values": 0},
)
# We have to rely on `resize_with_crop_or_pad` as there are no pad to size
# equivalents.
reference_fn = functools.partial(
tf.image.resize_with_crop_or_pad,
target_height=target_height,
target_width=target_width,
)

self._test_fn(images_list, jax_fn=pad_fn, reference_fn=reference_fn)


class TestMatchReference(_ImageAugmentationTest):

Expand Down Expand Up @@ -487,6 +510,30 @@ def test_center_crop_size_bigger_than_original(
self.assertEqual(output.shape[1], expected_height)
self.assertEqual(output.shape[2], expected_width)

@parameterized.product(
images_list=(_RAND_FLOATS_IN_RANGE, _RAND_FLOATS_OUT_OF_RANGE),
target_height=(55, 84),
target_width=(55, 84),
expected_height=(131, 131),
expected_width=(111, 111),
)
def test_pad_to_size_when_target_size_smaller_than_original(
self,
images_list,
target_height,
target_width,
expected_height,
expected_width,
):
output = augment.pad_to_size(
image=jnp.array(images_list),
target_height=target_height,
target_width=target_width,
)

self.assertEqual(output.shape[1], expected_height)
self.assertEqual(output.shape[2], expected_width)


if __name__ == "__main__":
jax.config.update("jax_default_matmul_precision", "float32")
Expand Down
7 changes: 7 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Augmentations
flip_left_right
flip_up_down
gaussian_blur
pad_to_size
random_brightness
random_contrast
random_crop
Expand Down Expand Up @@ -84,6 +85,12 @@ gaussian_blur

.. autofunction:: gaussian_blur

pad_to_size
~~~~~~~~~~~~~

.. autofunction:: pad_to_size


random_brightness
~~~~~~~~~~~~~~~~~

Expand Down

0 comments on commit 1cbc89c

Please sign in to comment.