From b6b4376314fd4979e233e8f4b4cf1717bdb27e12 Mon Sep 17 00:00:00 2001 From: HongYu <20734616+james77777778@users.noreply.github.com> Date: Wed, 26 Jul 2023 11:44:14 +0800 Subject: [PATCH] Convert RandomZoom to backend-agnostic and improve `affine_transform` (#574) * Convert RandomZoom * Fix docstring * Update docstring * Update docstring * Address comments * Update `affine_transform` * Update * Fix fill_mode in torch * Update `affine_transform` * Revert RandomTranslation tests * Fix typo * Remove docstring * Redirect `fill_mode=reflect` to `fill_mode=mirror` using torch * Update docstring * Update tests because using torch `fill_mode=reflect` is actually `fill_mode=mirror` --- keras_core/backend/torch/image.py | 71 +++-- .../preprocessing/random_translation.py | 28 +- .../preprocessing/random_translation_test.py | 129 +++++--- .../layers/preprocessing/random_zoom.py | 289 ++++++++++++------ .../layers/preprocessing/random_zoom_test.py | 21 +- keras_core/ops/image.py | 16 +- keras_core/ops/image_test.py | 127 +++++--- 7 files changed, 483 insertions(+), 198 deletions(-) diff --git a/keras_core/backend/torch/image.py b/keras_core/backend/torch/image.py index a948b9a22..9c85861d6 100644 --- a/keras_core/backend/torch/image.py +++ b/keras_core/backend/torch/image.py @@ -86,8 +86,8 @@ def resize( "constant": "zeros", "nearest": "border", # "wrap", not supported by torch - # "mirror", not supported by torch - "reflect": "reflection", + "mirror": "reflection", # torch's reflection is mirror in other backends + "reflect": "reflection", # if fill_mode==reflect, redirect to mirror } @@ -122,7 +122,7 @@ def _apply_grid_transform( grid, mode=interpolation, padding_mode=fill_mode, - align_corners=False, + align_corners=True, ) # Fill with required color if fill_value is not None: @@ -187,9 +187,9 @@ def affine_transform( f"transform.shape={transform.shape}" ) - if fill_mode != "constant": + # the default fill_value of tnn.grid_sample is "zeros" + if fill_mode != "constant" or (fill_mode == "constant" and fill_value == 0): fill_value = None - fill_mode = AFFINE_TRANSFORM_FILL_MODES[fill_mode] # unbatched case need_squeeze = False @@ -202,23 +202,60 @@ def affine_transform( if data_format == "channels_last": image = image.permute((0, 3, 1, 2)) + batch_size = image.shape[0] + h, w, c = image.shape[-2], image.shape[-1], image.shape[-3] + + # get indices + shape = [h, w, c] # (H, W, C) + meshgrid = torch.meshgrid( + *[torch.arange(size) for size in shape], indexing="ij" + ) + indices = torch.concatenate( + [torch.unsqueeze(x, dim=-1) for x in meshgrid], dim=-1 + ) + indices = torch.tile(indices, (batch_size, 1, 1, 1, 1)) + indices = indices.to(transform) + + # swap the values + a0 = transform[:, 0].clone() + a2 = transform[:, 2].clone() + b1 = transform[:, 4].clone() + b2 = transform[:, 5].clone() + transform[:, 0] = b1 + transform[:, 2] = b2 + transform[:, 4] = a0 + transform[:, 5] = a2 + # deal with transform - h, w = image.shape[2], image.shape[3] - theta = torch.zeros((image.shape[0], 2, 3)).to(transform) - theta[:, 0, 0] = transform[:, 0] - theta[:, 0, 1] = transform[:, 1] * h / w - theta[:, 0, 2] = ( - transform[:, 2] * 2 / w + theta[:, 0, 0] + theta[:, 0, 1] - 1 + transform = torch.nn.functional.pad( + transform, pad=[0, 1, 0, 0], mode="constant", value=1 ) - theta[:, 1, 0] = transform[:, 3] * w / h - theta[:, 1, 1] = transform[:, 4] - theta[:, 1, 2] = ( - transform[:, 5] * 2 / h + theta[:, 1, 0] + theta[:, 1, 1] - 1 + transform = torch.reshape(transform, (batch_size, 3, 3)) + offset = transform[:, 
0:2, 2].clone() + offset = torch.nn.functional.pad(offset, pad=[0, 1, 0, 0]) + transform[:, 0:2, 2] = 0 + + # transform the indices + coordinates = torch.einsum("Bhwij, Bjk -> Bhwik", indices, transform) + coordinates = torch.moveaxis(coordinates, source=-1, destination=1) + coordinates += torch.reshape(a=offset, shape=(*offset.shape, 1, 1, 1)) + coordinates = coordinates[:, 0:2, ..., 0] + coordinates = coordinates.permute((0, 2, 3, 1)) + + # normalize coordinates + coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / (w - 1) * 2.0 - 1.0 + coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / (h - 1) * 2.0 - 1.0 + grid = torch.stack( + [coordinates[:, :, :, 1], coordinates[:, :, :, 0]], dim=-1 ) - grid = tnn.affine_grid(theta, image.shape) affined = _apply_grid_transform( - image, grid, interpolation, fill_mode, fill_value + image, + grid, + interpolation=interpolation, + # if fill_mode==reflect, redirect to mirror + fill_mode=AFFINE_TRANSFORM_FILL_MODES[fill_mode], + fill_value=fill_value, ) if data_format == "channels_last": diff --git a/keras_core/layers/preprocessing/random_translation.py b/keras_core/layers/preprocessing/random_translation.py index 4127bb9cf..5bfff919e 100644 --- a/keras_core/layers/preprocessing/random_translation.py +++ b/keras_core/layers/preprocessing/random_translation.py @@ -48,17 +48,23 @@ class RandomTranslation(TFDataLayer): left by 20%, and shifted right by 30%. `width_factor=0.2` results in an output height shifted left or right by 20%. fill_mode: Points outside the boundaries of the input are filled - according to the given mode - (one of `{"constant", "reflect", "wrap", "nearest"}`). - - *reflect*: `(d c b a | a b c d | d c b a)` The input is extended - by reflecting about the edge of the last pixel. - - *constant*: `(k k k k | a b c d | k k k k)` The input is extended - by filling all values beyond the edge with the same constant - value k = 0. - - *wrap*: `(a b c d | a b c d | a b c d)` The input is extended by - wrapping around to the opposite edge. - - *nearest*: `(a a a a | a b c d | d d d d)` The input is extended - by the nearest pixel. + according to the given mode. Available methods are `"constant"`, + `"nearest"`, `"wrap"` and `"reflect"`. Defaults to `"constant"`. + - `"reflect"`: `(d c b a | a b c d | d c b a)` + The input is extended by reflecting about the edge of the last + pixel. + - `"constant"`: `(k k k k | a b c d | k k k k)` + The input is extended by filling all values beyond + the edge with the same constant value k specified by + `fill_value`. + - `"wrap"`: `(a b c d | a b c d | a b c d)` + The input is extended by wrapping around to the opposite edge. + - `"nearest"`: `(a a a a | a b c d | d d d d)` + The input is extended by the nearest pixel. + Note that when using torch backend, `"reflect"` is redirected to + `"mirror"` `(c d c b | a b c d | c b a b)` because torch does not + support `"reflect"`. + Note that torch backend does not support `"wrap"`. interpolation: Interpolation mode. Supported values: `"nearest"`, `"bilinear"`. seed: Integer. Used to create a random seed. 
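For reference, here is a small stand-alone illustration (not part of the patch) of the padding patterns the docstring above describes, using NumPy's pad modes as stand-ins: np.pad's "symmetric" mode matches this layer's "reflect" diagram (edge pixel repeated), while np.pad's "reflect" mode matches the "mirror" behaviour the torch backend now falls back to.

    import numpy as np

    row = np.array([0, 1, 2, 3])  # a b c d
    print(np.pad(row, 4, mode="symmetric"))  # [3 2 1 0 | 0 1 2 3 | 3 2 1 0] -> "reflect"
    print(np.pad(row, 4, mode="reflect"))    # [2 3 2 1 | 0 1 2 3 | 2 1 0 1] -> "mirror"
    print(np.pad(row, 4, mode="wrap"))       # [0 1 2 3 | 0 1 2 3 | 0 1 2 3] -> "wrap"
    print(np.pad(row, 4, mode="edge"))       # [0 0 0 0 | 0 1 2 3 | 3 3 3 3] -> "nearest"
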
diff --git a/keras_core/layers/preprocessing/random_translation_test.py b/keras_core/layers/preprocessing/random_translation_test.py index f88ac0357..73a2ff0c8 100644 --- a/keras_core/layers/preprocessing/random_translation_test.py +++ b/keras_core/layers/preprocessing/random_translation_test.py @@ -58,15 +58,27 @@ def test_random_translation_with_inference_mode(self): @parameterized.parameters(["channels_first", "channels_last"]) def test_random_translation_up_numeric_reflect(self, data_format): input_image = np.arange(0, 25) - expected_output = np.asarray( - [ - [5, 6, 7, 8, 9], - [10, 11, 12, 13, 14], - [15, 16, 17, 18, 19], - [20, 21, 22, 23, 24], - [20, 21, 22, 23, 24], - ] - ) + if backend.backend() == "torch": + # redirect fill_mode=reflect to fill_mode=mirror + expected_output = np.asarray( + [ + [5, 6, 7, 8, 9], + [10, 11, 12, 13, 14], + [15, 16, 17, 18, 19], + [20, 21, 22, 23, 24], + [15, 16, 17, 18, 19], + ] + ) + else: + expected_output = np.asarray( + [ + [5, 6, 7, 8, 9], + [10, 11, 12, 13, 14], + [15, 16, 17, 18, 19], + [20, 21, 22, 23, 24], + [20, 21, 22, 23, 24], + ] + ) if data_format == "channels_last": input_image = np.reshape(input_image, (1, 5, 5, 1)) expected_output = backend.convert_to_tensor( @@ -133,15 +145,27 @@ def test_random_translation_up_numeric_constant(self, data_format): def test_random_translation_down_numeric_reflect(self, data_format): input_image = np.arange(0, 25) # Shifting by .2 * 5 = 1 pixel. - expected_output = np.asarray( - [ - [0, 1, 2, 3, 4], - [0, 1, 2, 3, 4], - [5, 6, 7, 8, 9], - [10, 11, 12, 13, 14], - [15, 16, 17, 18, 19], - ] - ) + if backend.backend() == "torch": + # redirect fill_mode=reflect to fill_mode=mirror + expected_output = np.asarray( + [ + [5, 6, 7, 8, 9], + [0, 1, 2, 3, 4], + [5, 6, 7, 8, 9], + [10, 11, 12, 13, 14], + [15, 16, 17, 18, 19], + ] + ) + else: + expected_output = np.asarray( + [ + [0, 1, 2, 3, 4], + [0, 1, 2, 3, 4], + [5, 6, 7, 8, 9], + [10, 11, 12, 13, 14], + [15, 16, 17, 18, 19], + ] + ) if data_format == "channels_last": input_image = np.reshape(input_image, (1, 5, 5, 1)) expected_output = backend.convert_to_tensor( @@ -172,18 +196,33 @@ def test_random_translation_asymmetric_size_numeric_reflect( ): input_image = np.arange(0, 16) # Shifting by .2 * 5 = 1 pixel. - expected_output = np.asarray( - [ - [6, 7], - [4, 5], - [2, 3], - [0, 1], - [0, 1], - [2, 3], - [4, 5], - [6, 7], - ] - ) + if backend.backend() == "torch": + # redirect fill_mode=reflect to fill_mode=mirror + expected_output = np.asarray( + [ + [8, 9], + [6, 7], + [4, 5], + [2, 3], + [0, 1], + [2, 3], + [4, 5], + [6, 7], + ] + ) + else: + expected_output = np.asarray( + [ + [6, 7], + [4, 5], + [2, 3], + [0, 1], + [0, 1], + [2, 3], + [4, 5], + [6, 7], + ] + ) if data_format == "channels_last": input_image = np.reshape(input_image, (1, 8, 2, 1)) expected_output = backend.convert_to_tensor( @@ -251,15 +290,27 @@ def test_random_translation_down_numeric_constant(self, data_format): def test_random_translation_left_numeric_reflect(self, data_format): input_image = np.arange(0, 25) # Shifting by .2 * 5 = 1 pixel. 
- expected_output = np.asarray( - [ - [1, 2, 3, 4, 4], - [6, 7, 8, 9, 9], - [11, 12, 13, 14, 14], - [16, 17, 18, 19, 19], - [21, 22, 23, 24, 24], - ] - ) + if backend.backend() == "torch": + # redirect fill_mode=reflect to fill_mode=mirror + expected_output = np.asarray( + [ + [1, 2, 3, 4, 3], + [6, 7, 8, 9, 8], + [11, 12, 13, 14, 13], + [16, 17, 18, 19, 18], + [21, 22, 23, 24, 23], + ] + ) + else: + expected_output = np.asarray( + [ + [1, 2, 3, 4, 4], + [6, 7, 8, 9, 9], + [11, 12, 13, 14, 14], + [16, 17, 18, 19, 19], + [21, 22, 23, 24, 24], + ] + ) if data_format == "channels_last": input_image = np.reshape(input_image, (1, 5, 5, 1)) expected_output = backend.convert_to_tensor( diff --git a/keras_core/layers/preprocessing/random_zoom.py b/keras_core/layers/preprocessing/random_zoom.py index 385464460..f5afb1fda 100644 --- a/keras_core/layers/preprocessing/random_zoom.py +++ b/keras_core/layers/preprocessing/random_zoom.py @@ -1,14 +1,11 @@ -import numpy as np - from keras_core import backend from keras_core.api_export import keras_core_export -from keras_core.layers.layer import Layer -from keras_core.utils import backend_utils -from keras_core.utils.module_utils import tensorflow as tf +from keras_core.layers.preprocessing.tf_data_layer import TFDataLayer +from keras_core.random.seed_generator import SeedGenerator @keras_core_export("keras_core.layers.RandomZoom") -class RandomZoom(Layer): +class RandomZoom(TFDataLayer): """A preprocessing layer which randomly zooms images during training. This layer will randomly zoom in or out on each axis of an image @@ -18,75 +15,85 @@ class RandomZoom(Layer): of integer or floating point dtype. By default, the layer will output floats. - **Note:** This layer wraps `tf.keras.layers.RandomZoom`. It cannot - be used as part of the compiled computation graph of a model with - any backend other than TensorFlow. - It can however be used with any backend when running eagerly. - It can also always be used as part of an input preprocessing pipeline - with any backend (outside the model itself), which is how we recommend - to use this layer. + Input shape: + 3D (unbatched) or 4D (batched) tensor with shape: + `(..., height, width, channels)`, in `"channels_last"` format, + or `(..., channels, height, width)`, in `"channels_first"` format. + + Output shape: + 3D (unbatched) or 4D (batched) tensor with shape: + `(..., target_height, target_width, channels)`, + or `(..., channels, target_height, target_width)`, + in `"channels_first"` format. **Note:** This layer is safe to use inside a `tf.data` pipeline (independently of which backend you're using). Args: - height_factor: a float represented as fraction of value, - or a tuple of size 2 representing lower and upper bound - for zooming vertically. When represented as a single float, - this value is used for both the upper and - lower bound. A positive value means zooming out, - while a negative value - means zooming in. For instance, `height_factor=(0.2, 0.3)` - result in an output zoomed out by a random amount - in the range `[+20%, +30%]`. - `height_factor=(-0.3, -0.2)` result in an output zoomed - in by a random amount in the range `[+20%, +30%]`. - width_factor: a float represented as fraction of value, - or a tuple of size 2 representing lower and upper bound - for zooming horizontally. When - represented as a single float, this value is used - for both the upper and - lower bound. For instance, `width_factor=(0.2, 0.3)` - result in an output - zooming out between 20% to 30%. 
- `width_factor=(-0.3, -0.2)` result in an - output zooming in between 20% to 30%. `None` means - i.e., zooming vertical and horizontal directions - by preserving the aspect ratio. Defaults to `None`. - fill_mode: Points outside the boundaries of the input are - filled according to the given mode - (one of `{"constant", "reflect", "wrap", "nearest"}`). - - *reflect*: `(d c b a | a b c d | d c b a)` - The input is extended by reflecting about - the edge of the last pixel. - - *constant*: `(k k k k | a b c d | k k k k)` + height_factor: a float represented as fraction of value, or a tuple of + size 2 representing lower and upper bound for zooming vertically. + When represented as a single float, this value is used for both the + upper and lower bound. A positive value means zooming out, while a + negative value means zooming in. For instance, + `height_factor=(0.2, 0.3)` result in an output zoomed out by a + random amount in the range `[+20%, +30%]`. + `height_factor=(-0.3, -0.2)` result in an output zoomed in by a + random amount in the range `[+20%, +30%]`. + width_factor: a float represented as fraction of value, or a tuple of + size 2 representing lower and upper bound for zooming horizontally. + When represented as a single float, this value is used for both the + upper and lower bound. For instance, `width_factor=(0.2, 0.3)` + result in an output zooming out between 20% to 30%. + `width_factor=(-0.3, -0.2)` result in an output zooming in between + 20% to 30%. `None` means i.e., zooming vertical and horizontal + directions by preserving the aspect ratio. Defaults to `None`. + fill_mode: Points outside the boundaries of the input are filled + according to the given mode. Available methods are `"constant"`, + `"nearest"`, `"wrap"` and `"reflect"`. Defaults to `"constant"`. + - `"reflect"`: `(d c b a | a b c d | d c b a)` + The input is extended by reflecting about the edge of the last + pixel. + - `"constant"`: `(k k k k | a b c d | k k k k)` The input is extended by filling all values beyond - the edge with the same constant value k = 0. - - *wrap*: `(a b c d | a b c d | a b c d)` The input is extended by - wrapping around to the opposite edge. - - *nearest*: `(a a a a | a b c d | d d d d)` + the edge with the same constant value k specified by + `fill_value`. + - `"wrap"`: `(a b c d | a b c d | a b c d)` + The input is extended by wrapping around to the opposite edge. + - `"nearest"`: `(a a a a | a b c d | d d d d)` The input is extended by the nearest pixel. + Note that when using torch backend, `"reflect"` is redirected to + `"mirror"` `(c d c b | a b c d | c b a b)` because torch does not + support `"reflect"`. + Note that torch backend does not support `"wrap"`. interpolation: Interpolation mode. Supported values: `"nearest"`, `"bilinear"`. seed: Integer. Used to create a random seed. fill_value: a float represents the value to be filled outside the boundaries when `fill_mode="constant"`. + data_format: string, either `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. `"channels_last"` + corresponds to inputs with shape `(batch, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch, channels, height, width)`. It defaults to the + `image_data_format` value found in your Keras config file at + `~/.keras/keras.json`. If you never set it, then it will be + `"channels_last"`. + **kwargs: Base layer keyword arguments, such as `name` and `dtype`. 
Example: >>> input_img = np.random.random((32, 224, 224, 3)) >>> layer = keras_core.layers.RandomZoom(.5, .2) >>> out_img = layer(input_img) - - Input shape: - 3D (unbatched) or 4D (batched) tensor with shape: - `(..., height, width, channels)`, in `"channels_last"` format. - - Output shape: - 3D (unbatched) or 4D (batched) tensor with shape: - `(..., height, width, channels)`, in `"channels_last"` format. """ + _FACTOR_VALIDATION_ERROR = ( + "The `factor` argument should be a number (or a list of two numbers) " + "in the range [-1.0, 1.0]. " + ) + _SUPPORTED_FILL_MODE = ("reflect", "wrap", "constant", "nearest") + _SUPPORTED_INTERPOLATION = ("nearest", "bilinear") + def __init__( self, height_factor, @@ -95,46 +102,156 @@ def __init__( interpolation="bilinear", seed=None, fill_value=0.0, - name=None, + data_format=None, **kwargs, ): - if not tf.available: - raise ImportError( - "Layer RandomZoom requires TensorFlow. " - "Install it via `pip install tensorflow`." + super().__init__(**kwargs) + self.height_factor = height_factor + self.height_lower, self.height_upper = self._set_factor( + height_factor, "height_factor" + ) + self.width_factor = width_factor + if width_factor is not None: + self.width_lower, self.width_upper = self._set_factor( + width_factor, "width_factor" + ) + if fill_mode not in self._SUPPORTED_FILL_MODE: + raise NotImplementedError( + f"Unknown `fill_mode` {fill_mode}. Expected of one " + f"{self._SUPPORTED_FILL_MODE}." + ) + if interpolation not in self._SUPPORTED_INTERPOLATION: + raise NotImplementedError( + f"Unknown `interpolation` {interpolation}. Expected of one " + f"{self._SUPPORTED_INTERPOLATION}." ) - super().__init__(name=name, **kwargs) - self.seed = seed or backend.random.make_default_seed() - self.layer = tf.keras.layers.RandomZoom( - height_factor=height_factor, - width_factor=width_factor, - fill_mode=fill_mode, - interpolation=interpolation, - seed=self.seed, - name=name, - fill_value=fill_value, - **kwargs, - ) - self._allow_non_tensor_positional_args = True - self._convert_input_args = False + self.fill_mode = fill_mode + self.fill_value = fill_value + self.interpolation = interpolation + self.seed = seed + self.generator = SeedGenerator(seed) + self.data_format = backend.standardize_data_format(data_format) + self.supports_jit = False + def _set_factor(self, factor, factor_name): + if isinstance(factor, (tuple, list)): + if len(factor) != 2: + raise ValueError( + self._FACTOR_VALIDATION_ERROR + + f"Received: {factor_name}={factor}" + ) + self._check_factor_range(factor[0]) + self._check_factor_range(factor[1]) + lower, upper = sorted(factor) + elif isinstance(factor, (int, float)): + self._check_factor_range(factor) + factor = abs(factor) + lower, upper = [-factor, factor] + else: + raise ValueError( + self._FACTOR_VALIDATION_ERROR + + f"Received: {factor_name}={factor}" + ) + return lower, upper + + def _check_factor_range(self, input_number): + if input_number > 1.0 or input_number < -1.0: + raise ValueError( + self._FACTOR_VALIDATION_ERROR + + f"Received: input_number={input_number}" + ) + def call(self, inputs, training=True): - if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)): - inputs = tf.convert_to_tensor(backend.convert_to_numpy(inputs)) - outputs = self.layer.call(inputs, training=training) - if ( - backend.backend() != "tensorflow" - and not backend_utils.in_tf_graph() - ): - outputs = backend.convert_to_tensor(outputs) + inputs = self.backend.cast(inputs, self.compute_dtype) + if training: + return 
self._randomly_zoom_inputs(inputs) + else: + return inputs + + def _randomly_zoom_inputs(self, inputs): + unbatched = len(inputs.shape) == 3 + if unbatched: + inputs = self.backend.numpy.expand_dims(inputs, axis=0) + + batch_size = self.backend.shape(inputs)[0] + if self.data_format == "channels_first": + height = inputs.shape[-2] + width = inputs.shape[-1] + else: + height = inputs.shape[-3] + width = inputs.shape[-2] + + seed_generator = self._get_seed_generator(self.backend._backend) + height_zoom = self.backend.random.uniform( + minval=1.0 + self.height_lower, + maxval=1.0 + self.height_upper, + shape=[batch_size, 1], + seed=seed_generator, + ) + if self.width_factor is not None: + width_zoom = self.backend.random.uniform( + minval=1.0 + self.width_lower, + maxval=1.0 + self.width_upper, + shape=[batch_size, 1], + seed=seed_generator, + ) + else: + width_zoom = height_zoom + zooms = self.backend.cast( + self.backend.numpy.concatenate([width_zoom, height_zoom], axis=1), + dtype="float32", + ) + + outputs = self.backend.image.affine_transform( + inputs, + transform=self._get_zoom_matrix(zooms, height, width), + interpolation=self.interpolation, + fill_mode=self.fill_mode, + fill_value=self.fill_value, + data_format=self.data_format, + ) + + if unbatched: + outputs = self.backend.numpy.squeeze(outputs, axis=0) return outputs + def _get_zoom_matrix(self, zooms, image_height, image_width): + num_zooms = self.backend.shape(zooms)[0] + # The zoom matrix looks like: + # [[zx 0 0] + # [0 zy 0] + # [0 0 1]] + # where the last entry is implicit. + # zoom matrices are always float32. + x_offset = ((image_width - 1.0) / 2.0) * (1.0 - zooms[:, 0:1]) + y_offset = ((image_height - 1.0) / 2.0) * (1.0 - zooms[:, 1:]) + return self.backend.numpy.concatenate( + [ + zooms[:, 0:1], + self.backend.numpy.zeros((num_zooms, 1)), + x_offset, + self.backend.numpy.zeros((num_zooms, 1)), + zooms[:, 1:], + y_offset, + self.backend.numpy.zeros((num_zooms, 2)), + ], + axis=1, + ) + def compute_output_shape(self, input_shape): - return tuple(self.layer.compute_output_shape(input_shape)) + return input_shape def get_config(self): - config = self.layer.get_config() - config.update({"seed": self.seed}) - return config + base_config = super().get_config() + config = { + "height_factor": self.height_factor, + "width_factor": self.width_factor, + "fill_mode": self.fill_mode, + "interpolation": self.interpolation, + "seed": self.seed, + "fill_value": self.fill_value, + "data_format": self.data_format, + } + return {**base_config, **config} diff --git a/keras_core/layers/preprocessing/random_zoom_test.py b/keras_core/layers/preprocessing/random_zoom_test.py index 331a14ef3..db438a7c6 100644 --- a/keras_core/layers/preprocessing/random_zoom_test.py +++ b/keras_core/layers/preprocessing/random_zoom_test.py @@ -34,9 +34,9 @@ def test_random_zoom_out_correctness(self): expected_output = np.asarray( [ [0, 0, 0, 0, 0], - [0, 5, 7, 9, 0], - [0, 10, 12, 14, 0], - [0, 20, 22, 24, 0], + [0, 2.7, 4.5, 6.3, 0], + [0, 10.2, 12.0, 13.8, 0], + [0, 17.7, 19.5, 21.3, 0], [0, 0, 0, 0, 0], ] ) @@ -48,7 +48,7 @@ def test_random_zoom_out_correctness(self): init_kwargs={ "height_factor": (0.5, 0.5), "width_factor": (0.8, 0.8), - "interpolation": "nearest", + "interpolation": "bilinear", "fill_mode": "constant", }, input_shape=None, @@ -62,11 +62,11 @@ def test_random_zoom_in_correctness(self): input_image = np.reshape(np.arange(0, 25), (1, 5, 5, 1)) expected_output = np.asarray( [ - [6, 7, 7, 8, 8], - [11, 12, 12, 13, 13], - [11, 12, 12, 13, 13], - 
[16, 17, 17, 18, 18], - [16, 17, 17, 18, 18], + [6.0, 6.5, 7.0, 7.5, 8.0], + [8.5, 9.0, 9.5, 10.0, 10.5], + [11.0, 11.5, 12.0, 12.5, 13.0], + [13.5, 14.0, 14.5, 15.0, 15.5], + [16.0, 16.5, 17.0, 17.5, 18.0], ] ) expected_output = backend.convert_to_tensor( @@ -77,7 +77,8 @@ def test_random_zoom_in_correctness(self): init_kwargs={ "height_factor": (-0.5, -0.5), "width_factor": (-0.5, -0.5), - "interpolation": "nearest", + "interpolation": "bilinear", + "fill_mode": "constant", }, input_shape=None, input_data=input_image, diff --git a/keras_core/ops/image.py b/keras_core/ops/image.py index 5383d6d99..a7af118b8 100644 --- a/keras_core/ops/image.py +++ b/keras_core/ops/image.py @@ -184,7 +184,21 @@ def affine_transform( fill_mode: Points outside the boundaries of the input are filled according to the given mode. Available methods are `"constant"`, `"nearest"`, `"wrap"` and `"reflect"`. Defaults to `"constant"`. - Note that `"wrap"` is not supported by Torch backend. + - `"reflect"`: `(d c b a | a b c d | d c b a)` + The input is extended by reflecting about the edge of the last + pixel. + - `"constant"`: `(k k k k | a b c d | k k k k)` + The input is extended by filling all values beyond + the edge with the same constant value k specified by + `fill_value`. + - `"wrap"`: `(a b c d | a b c d | a b c d)` + The input is extended by wrapping around to the opposite edge. + - `"nearest"`: `(a a a a | a b c d | d d d d)` + The input is extended by the nearest pixel. + Note that when using torch backend, `"reflect"` is redirected to + `"mirror"` `(c d c b | a b c d | c b a b)` because torch does not + support `"reflect"`. + Note that torch backend does not support `"wrap"`. fill_value: Value used for points outside the boundaries of the input if `fill_mode="constant"`. Defaults to `0`. data_format: string, either `"channels_last"` or `"channels_first"`. 
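As a rough usage sketch (not part of the patch), the snippet below builds a zoom-style transform vector the same way `RandomZoom._get_zoom_matrix` does above and passes it to `ops.image.affine_transform` with the signature shown in this hunk; the zoom factor, image contents, and centered-zoom offsets are illustrative assumptions, not values taken from this change.

    import numpy as np
    from keras_core import ops

    image = np.arange(25, dtype="float32").reshape(1, 5, 5, 1)  # (batch, H, W, C)

    # Zoom factors > 1 zoom out, < 1 zoom in; the offsets keep the zoom
    # centered, mirroring RandomZoom._get_zoom_matrix.
    zx = zy = 2.0
    x_offset = ((5 - 1) / 2.0) * (1.0 - zx)
    y_offset = ((5 - 1) / 2.0) * (1.0 - zy)

    # The 1x8 vector [a0, a1, a2, b0, b1, b2, 0, 0] maps an output pixel (x, y)
    # back to the input pixel (a0*x + a1*y + a2, b0*x + b1*y + b2).
    transform = np.array(
        [[zx, 0.0, x_offset, 0.0, zy, y_offset, 0.0, 0.0]], dtype="float32"
    )

    zoomed = ops.image.affine_transform(
        image,
        transform,
        interpolation="bilinear",
        fill_mode="constant",
        fill_value=0.0,
        data_format="channels_last",
    )
    print(zoomed.shape)  # (1, 5, 5, 1)
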
diff --git a/keras_core/ops/image_test.py b/keras_core/ops/image_test.py index f5c409a7f..6452df04a 100644 --- a/keras_core/ops/image_test.py +++ b/keras_core/ops/image_test.py @@ -1,5 +1,6 @@ import numpy as np import pytest +import scipy.ndimage import tensorflow as tf from absl.testing import parameterized @@ -51,6 +52,57 @@ def test_extract_patches(self): self.assertEqual(out.shape, (4, 4, 75)) +AFFINE_TRANSFORM_INTERPOLATIONS = { # map to order + "nearest": 0, + "bilinear": 1, +} +AFFINE_TRANSFORM_FILL_MODES = { + "constant": "grid-constant", + "nearest": "nearest", + "wrap": "grid-wrap", + "mirror": "mirror", + "reflect": "reflect", +} + + +def _compute_affine_transform_coordinates(image, transform): + need_squeeze = False + if len(image.shape) == 3: # unbatched + need_squeeze = True + image = np.expand_dims(image, axis=0) + transform = np.expand_dims(transform, axis=0) + batch_size = image.shape[0] + # get indices + meshgrid = np.meshgrid( + *[np.arange(size) for size in image.shape[1:]], indexing="ij" + ) + indices = np.concatenate( + [np.expand_dims(x, axis=-1) for x in meshgrid], axis=-1 + ) + indices = np.tile(indices, (batch_size, 1, 1, 1, 1)) + # swap the values + transform[:, 4], transform[:, 0] = ( + transform[:, 0].copy(), + transform[:, 4].copy(), + ) + transform[:, 5], transform[:, 2] = ( + transform[:, 2].copy(), + transform[:, 5].copy(), + ) + # deal with transform + transform = np.pad(transform, pad_width=[[0, 0], [0, 1]], constant_values=1) + transform = np.reshape(transform, (batch_size, 3, 3)) + offset = np.pad(transform[:, 0:2, 2], pad_width=[[0, 0], [0, 1]]) + transform[:, 0:2, 2] = 0 + # transform the indices + coordinates = np.einsum("Bhwij, Bjk -> Bhwik", indices, transform) + coordinates = np.moveaxis(coordinates, source=-1, destination=1) + coordinates += np.reshape(a=offset, newshape=(*offset.shape, 1, 1, 1)) + if need_squeeze: + coordinates = np.squeeze(coordinates, axis=0) + return coordinates + + class ImageOpsCorrectnessTest(testing.TestCase, parameterized.TestCase): @parameterized.parameters( [ @@ -135,22 +187,34 @@ def test_resize(self, interpolation, antialias, data_format): ("nearest", "nearest", "channels_last"), ("bilinear", "wrap", "channels_last"), ("nearest", "wrap", "channels_last"), + ("bilinear", "mirror", "channels_last"), + ("nearest", "mirror", "channels_last"), ("bilinear", "reflect", "channels_last"), ("nearest", "reflect", "channels_last"), ("bilinear", "constant", "channels_first"), ] ) def test_affine_transform(self, interpolation, fill_mode, data_format): - if fill_mode == "wrap" and backend.backend() == "torch": + if backend.backend() == "torch" and fill_mode == "wrap": + self.skipTest( + "In torch backend, applying affine_transform with " + "fill_mode=wrap is not supported" + ) + if backend.backend() == "torch" and fill_mode == "reflect": self.skipTest( - "Applying affine transform with fill_mode=wrap is not support" - " in torch backend" + "In torch backend, applying affine_transform with " + "fill_mode=reflect is redirected to fill_mode=mirror" ) - if fill_mode == "wrap" and backend.backend() in ("jax", "numpy"): + if backend.backend() == "tensorflow" and fill_mode == "mirror": self.skipTest( - "The numerical results of applying affine transform with " - "fill_mode=wrap in tensorflow is inconsistent with jax and " - "numpy backends" + "In tensorflow backend, applying affine_transform with " + "fill_mode=mirror is not supported" + ) + if backend.backend() == "tensorflow" and fill_mode == "wrap": + self.skipTest( + "In tensorflow 
backend, the numerical results of applying " + "affine_transform with fill_mode=wrap is inconsistent with" + "scipy" ) # Unbatched case @@ -169,24 +233,18 @@ def test_affine_transform(self, interpolation, fill_mode, data_format): ) if data_format == "channels_first": x = np.transpose(x, (1, 2, 0)) - ref_out = tf.raw_ops.ImageProjectiveTransformV3( - images=tf.expand_dims(x, axis=0), - transforms=tf.cast(tf.expand_dims(transform, axis=0), tf.float32), - output_shape=tf.shape(x)[:-1], - fill_value=0, - interpolation=interpolation.upper(), - fill_mode=fill_mode.upper(), + coordinates = _compute_affine_transform_coordinates(x, transform) + ref_out = scipy.ndimage.map_coordinates( + x, + coordinates, + order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation], + mode=AFFINE_TRANSFORM_FILL_MODES[fill_mode], + prefilter=False, ) - ref_out = ref_out[0] if data_format == "channels_first": ref_out = np.transpose(ref_out, (2, 0, 1)) self.assertEqual(tuple(out.shape), tuple(ref_out.shape)) - if backend.backend() == "torch": - # TODO: cannot pass with torch backend - with self.assertRaises(AssertionError): - self.assertAllClose(ref_out, out, atol=0.3) - else: - self.assertAllClose(ref_out, out, atol=0.3) + self.assertAllClose(ref_out, out, atol=0.3) # Batched case if data_format == "channels_first": @@ -204,23 +262,24 @@ def test_affine_transform(self, interpolation, fill_mode, data_format): ) if data_format == "channels_first": x = np.transpose(x, (0, 2, 3, 1)) - ref_out = tf.raw_ops.ImageProjectiveTransformV3( - images=x, - transforms=tf.cast(transform, tf.float32), - output_shape=tf.shape(x)[1:-1], - fill_value=0, - interpolation=interpolation.upper(), - fill_mode=fill_mode.upper(), + coordinates = _compute_affine_transform_coordinates(x, transform) + ref_out = np.stack( + [ + scipy.ndimage.map_coordinates( + x[i], + coordinates[i], + order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation], + mode=AFFINE_TRANSFORM_FILL_MODES[fill_mode], + prefilter=False, + ) + for i in range(x.shape[0]) + ], + axis=0, ) if data_format == "channels_first": ref_out = np.transpose(ref_out, (0, 3, 1, 2)) self.assertEqual(tuple(out.shape), tuple(ref_out.shape)) - if backend.backend() == "torch": - # TODO: cannot pass with torch backend - with self.assertRaises(AssertionError): - self.assertAllClose(ref_out, out, atol=0.3) - else: - self.assertAllClose(ref_out, out, atol=0.3) + self.assertAllClose(ref_out, out, atol=0.3) @parameterized.parameters( [