def pad_to_size(
    image: chex.Array,
    target_height: int,
    target_width: int,
    *,
    mode: str = "constant",
    pad_kwargs: Optional[Any] = None,
    channel_axis: int = -1,
) -> chex.Array:
  """Pads an image to the given size keeping the original image centered.

  For different padding methods and kwargs please see:
  https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.pad.html

  In case of odd size difference along any dimension the bottom/right side gets
  the extra padding pixel.

  Target size can be smaller than original size which results in a no-op for
  such dimension. If no dimension needs padding, the input is returned as is.

  Padding is applied to the height and width axes only; the channel axis (as
  given by `channel_axis`) and any leading batch axis are never padded.

  Args:
    image: a JAX array representing an image. Assumes that the image is either
      ...HWC or ...CHW. Must have rank 3 (single image) or 4 (batched).
    target_height: target height to pad the image to.
    target_width: target width to pad the image to.
    mode: Mode for padding the images, see jax.numpy.pad for details. Default is
      `constant`.
    pad_kwargs: Keyword arguments to pass jax.numpy.pad, see documentation for
      options. `None` is treated as no extra keyword arguments.
    channel_axis: the index of the channel axis.

  Returns:
    The padded image(s).
  """
  chex.assert_rank(image, {3, 4})
  _, height, width, _ = _get_dimension_values(
      image=image, channel_axis=channel_axis
  )
  delta_height = max(target_height - height, 0)
  delta_width = max(target_width - width, 0)
  if delta_height == 0 and delta_width == 0:
    return image

  # Floor half goes to the top/left; the remainder (one extra pixel for odd
  # differences) goes to the bottom/right.
  top = delta_height // 2
  bottom = delta_height - top
  left = delta_width // 2
  right = delta_width - left

  # Locate the height/width axes from `channel_axis` rather than assuming a
  # channel-last layout: once the channel axis is removed, the last two
  # remaining axes are height and width (a batch axis, if any, precedes them).
  num_axes = image.ndim
  channel_index = channel_axis + num_axes if channel_axis < 0 else channel_axis
  non_channel_axes = [
      axis for axis in range(num_axes) if axis != channel_index
  ]
  height_axis, width_axis = non_channel_axes[-2:]

  pad_width = [(0, 0)] * num_axes
  pad_width[height_axis] = (top, bottom)
  pad_width[width_axis] = (left, right)

  return jnp.pad(
      image, pad_width=tuple(pad_width), mode=mode, **(pad_kwargs or {})
  )
@parameterized.product(
    images_list=(_RAND_FLOATS_IN_RANGE, _RAND_FLOATS_OUT_OF_RANGE),
    target_height=(55, 84),
    target_width=(55, 84),
    # `product` takes the Cartesian product of all arguments, so the expected
    # size is listed once; repeating the same value would only generate
    # duplicate, identical test cases.
    expected_height=(131,),
    expected_width=(111,),
)
def test_pad_to_size_when_target_size_smaller_than_original(
    self,
    images_list,
    target_height,
    target_width,
    expected_height,
    expected_width,
):
  """Checks that a target smaller than the image leaves its size unchanged."""
  output = augment.pad_to_size(
      image=jnp.array(images_list),
      target_height=target_height,
      target_width=target_width,
  )

  # Axis 0 is the stacked-images axis; axes 1 and 2 are height and width.
  self.assertEqual(output.shape[1], expected_height)
  self.assertEqual(output.shape[2], expected_width)