From 5af4344b1e5d6eef8d2a8ac2fd02e048e10af1cb Mon Sep 17 00:00:00 2001
From: HongYu <20734616+james77777778@users.noreply.github.com>
Date: Thu, 21 Sep 2023 02:26:08 +0800
Subject: [PATCH] Refactor torch's `affine_transform` (#929)

* Refactor torch's `affine_transform`

* Update docstring

* Update RandomTranslation test
---
 keras_core/backend/torch/image.py            | 165 ++++++------------
 .../preprocessing/random_translation_test.py | 129 +++++---------
 keras_core/ops/image.py                      |   4 -
 keras_core/ops/image_test.py                 |  18 +-
 4 files changed, 94 insertions(+), 222 deletions(-)

diff --git a/keras_core/backend/torch/image.py b/keras_core/backend/torch/image.py
index 7ef82eef2..aef4ea9de 100644
--- a/keras_core/backend/torch/image.py
+++ b/keras_core/backend/torch/image.py
@@ -3,7 +3,6 @@
 import operator
 
 import torch
-import torch.nn.functional as tnn
 
 from keras_core.backend.torch.core import convert_to_tensor
 
@@ -82,78 +81,19 @@ def resize(
     return resized
 
 
-AFFINE_TRANSFORM_INTERPOLATIONS = (
-    "nearest",
-    "bilinear",
-)
+AFFINE_TRANSFORM_INTERPOLATIONS = {
+    "nearest": 0,
+    "bilinear": 1,
+}
 AFFINE_TRANSFORM_FILL_MODES = {
-    "constant": "zeros",
-    "nearest": "border",
-    # "wrap", not supported by torch
-    "mirror": "reflection",  # torch's reflection is mirror in other backends
-    "reflect": "reflection",  # if fill_mode==reflect, redirect to mirror
+    "constant",
+    "nearest",
+    "wrap",
+    "mirror",
+    "reflect",
 }
 
 
-def _apply_grid_transform(
-    img,
-    grid,
-    interpolation="bilinear",
-    fill_mode="zeros",
-    fill_value=None,
-):
-    """
-    Modified from https://github.com/pytorch/vision/blob/main/torchvision/transforms/v2/functional/_geometry.py
-    """  # noqa: E501
-
-    # We are using context knowledge that grid should have float dtype
-    fp = img.dtype == grid.dtype
-    float_img = img if fp else img.to(grid.dtype)
-
-    shape = float_img.shape
-    # Append a dummy mask for customized fill colors, should be faster than
-    # grid_sample() twice
-    if fill_value is not None:
-        mask = torch.ones(
-            (shape[0], 1, shape[2], shape[3]),
-            dtype=float_img.dtype,
-            device=float_img.device,
-        )
-        float_img = torch.cat((float_img, mask), dim=1)
-
-    float_img = tnn.grid_sample(
-        float_img,
-        grid,
-        mode=interpolation,
-        padding_mode=fill_mode,
-        align_corners=True,
-    )
-    # Fill with required color
-    if fill_value is not None:
-        float_img, mask = torch.tensor_split(float_img, indices=(-1,), dim=-3)
-        mask = mask.expand_as(float_img)
-        fill_list = (
-            fill_value
-            if isinstance(fill_value, (tuple, list))
-            else [float(fill_value)]
-        )
-        fill_img = torch.tensor(
-            fill_list, dtype=float_img.dtype, device=float_img.device
-        ).view(1, -1, 1, 1)
-        if interpolation == "nearest":
-            bool_mask = mask < 0.5
-            float_img[bool_mask] = fill_img.expand_as(float_img)[bool_mask]
-        else:  # 'bilinear'
-            # The following is mathematically equivalent to:
-            # img * mask + (1.0 - mask) * fill =
-            # img * mask - fill * mask + fill =
-            # mask * (img - fill) + fill
-            float_img = float_img.sub_(fill_img).mul_(mask).add_(fill_img)
-
-    img = float_img.round_().to(img.dtype) if not fp else float_img
-    return img
-
-
 def affine_transform(
     image,
     transform,
@@ -162,17 +102,16 @@ def affine_transform(
     fill_value=0,
     data_format="channels_last",
 ):
-    if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS:
+    if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS.keys():
         raise ValueError(
             "Invalid value for argument `interpolation`. Expected one of "
-            f"{AFFINE_TRANSFORM_INTERPOLATIONS}. Received: "
+            f"{set(AFFINE_TRANSFORM_INTERPOLATIONS.keys())}. Received: "
Received: " f"interpolation={interpolation}" ) - if fill_mode not in AFFINE_TRANSFORM_FILL_MODES.keys(): + if fill_mode not in AFFINE_TRANSFORM_FILL_MODES: raise ValueError( "Invalid value for argument `fill_mode`. Expected of one " - f"{set(AFFINE_TRANSFORM_FILL_MODES.keys())}. " - f"Received: fill_mode={fill_mode}" + f"{AFFINE_TRANSFORM_FILL_MODES}. Received: fill_mode={fill_mode}" ) image = convert_to_tensor(image) @@ -191,10 +130,6 @@ def affine_transform( f"transform.shape={transform.shape}" ) - # the default fill_value of tnn.grid_sample is "zeros" - if fill_mode != "constant" or (fill_mode == "constant" and fill_value == 0): - fill_value = None - # unbatched case need_squeeze = False if image.ndim == 3: @@ -203,22 +138,23 @@ def affine_transform( if transform.ndim == 1: transform = transform.unsqueeze(dim=0) - if data_format == "channels_last": - image = image.permute((0, 3, 1, 2)) + if data_format == "channels_first": + image = image.permute((0, 2, 3, 1)) batch_size = image.shape[0] - h, w, c = image.shape[-2], image.shape[-1], image.shape[-3] # get indices - shape = [h, w, c] # (H, W, C) meshgrid = torch.meshgrid( - *[torch.arange(size) for size in shape], indexing="ij" + *[ + torch.arange(size, dtype=transform.dtype, device=transform.device) + for size in image.shape[1:] + ], + indexing="ij", ) indices = torch.concatenate( [torch.unsqueeze(x, dim=-1) for x in meshgrid], dim=-1 ) indices = torch.tile(indices, (batch_size, 1, 1, 1, 1)) - indices = indices.to(transform) # swap the values a0 = transform[:, 0].clone() @@ -243,27 +179,23 @@ def affine_transform( coordinates = torch.einsum("Bhwij, Bjk -> Bhwik", indices, transform) coordinates = torch.moveaxis(coordinates, source=-1, destination=1) coordinates += torch.reshape(a=offset, shape=(*offset.shape, 1, 1, 1)) - coordinates = coordinates[:, 0:2, ..., 0] - coordinates = coordinates.permute((0, 2, 3, 1)) - - # normalize coordinates - coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / (w - 1) * 2.0 - 1.0 - coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / (h - 1) * 2.0 - 1.0 - grid = torch.stack( - [coordinates[:, :, :, 1], coordinates[:, :, :, 0]], dim=-1 - ) - affined = _apply_grid_transform( - image, - grid, - interpolation=interpolation, - # if fill_mode==reflect, redirect to mirror - fill_mode=AFFINE_TRANSFORM_FILL_MODES[fill_mode], - fill_value=fill_value, + # Note: torch.stack is faster than torch.vmap when the batch size is small. 
+    affined = torch.stack(
+        [
+            map_coordinates(
+                image[i],
+                coordinates[i],
+                order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation],
+                fill_mode=fill_mode,
+                fill_value=fill_value,
+            )
+            for i in range(len(image))
+        ],
+    )
 
-    if data_format == "channels_last":
-        affined = affined.permute((0, 2, 3, 1))
+    if data_format == "channels_first":
+        affined = affined.permute((0, 3, 1, 2))
     if need_squeeze:
         affined = affined.squeeze(dim=0)
     return affined
@@ -282,7 +214,8 @@ def _reflect_index_fixer(index, size):
 
 
 _INDEX_FIXERS = {
-    "constant": lambda index, size: index,
+    # we need to take care of out-of-bound indices in torch
+    "constant": lambda index, size: torch.clip(index, 0, size - 1),
     "nearest": lambda index, size: torch.clip(index, 0, size - 1),
     "wrap": lambda index, size: index % size,
     "mirror": _mirror_index_fixer,
@@ -301,8 +234,7 @@ def _nearest_indices_and_weights(coordinate):
         coordinate if _is_integer(coordinate) else torch.round(coordinate)
     )
     index = coordinate.to(torch.int32)
-    weight = torch.tensor(1).to(torch.int32)
-    return [(index, weight)]
+    return [(index, 1)]
 
 
 def _linear_indices_and_weights(coordinate):
@@ -318,7 +250,9 @@ def map_coordinates(
 ):
     input_arr = convert_to_tensor(input)
     coordinate_arrs = [convert_to_tensor(c) for c in coordinates]
-    fill_value = convert_to_tensor(fill_value, input_arr.dtype)
+    # skip tensor creation when possible
+    if isinstance(fill_value, (int, float)) and _is_integer(input_arr):
+        fill_value = int(fill_value)
 
     if len(coordinates) != len(input_arr.shape):
         raise ValueError(
@@ -330,16 +264,9 @@ def map_coordinates(
     if index_fixer is None:
         raise ValueError(
             "Invalid value for argument `fill_mode`. Expected one of "
-            f"{set(_INDEX_FIXERS.keys())}. Received: "
-            f"fill_mode={fill_mode}"
+            f"{set(_INDEX_FIXERS.keys())}. Received: fill_mode={fill_mode}"
         )
 
-    def is_valid(index, size):
-        if fill_mode == "constant":
-            return (0 <= index) & (index < size)
-        else:
-            return True
-
     if order == 0:
         interp_fun = _nearest_indices_and_weights
     elif order == 1:
@@ -347,6 +274,16 @@ def is_valid(index, size):
     else:
         raise NotImplementedError("map_coordinates currently requires order<=1")
 
+    if fill_mode == "constant":
+
+        def is_valid(index, size):
+            return (0 <= index) & (index < size)
+
+    else:
+
+        def is_valid(index, size):
+            return True
+
     valid_1d_interpolations = []
     for coordinate, size in zip(coordinate_arrs, input_arr.shape):
         interp_nodes = interp_fun(coordinate)
diff --git a/keras_core/layers/preprocessing/random_translation_test.py b/keras_core/layers/preprocessing/random_translation_test.py
index 05ff78e22..4a3a4ad57 100644
--- a/keras_core/layers/preprocessing/random_translation_test.py
+++ b/keras_core/layers/preprocessing/random_translation_test.py
@@ -58,27 +58,15 @@ def test_random_translation_with_inference_mode(self):
     @parameterized.parameters(["channels_first", "channels_last"])
     def test_random_translation_up_numeric_reflect(self, data_format):
         input_image = np.arange(0, 25)
-        if backend.backend() == "torch":
-            # redirect fill_mode=reflect to fill_mode=mirror
-            expected_output = np.asarray(
-                [
-                    [5, 6, 7, 8, 9],
-                    [10, 11, 12, 13, 14],
-                    [15, 16, 17, 18, 19],
-                    [20, 21, 22, 23, 24],
-                    [15, 16, 17, 18, 19],
-                ]
-            )
-        else:
-            expected_output = np.asarray(
-                [
-                    [5, 6, 7, 8, 9],
-                    [10, 11, 12, 13, 14],
-                    [15, 16, 17, 18, 19],
-                    [20, 21, 22, 23, 24],
-                    [20, 21, 22, 23, 24],
-                ]
-            )
+        expected_output = np.asarray(
+            [
+                [5, 6, 7, 8, 9],
+                [10, 11, 12, 13, 14],
+                [15, 16, 17, 18, 19],
+                [20, 21, 22, 23, 24],
+                [20, 21, 22, 23, 24],
+            ]
+        )
         if data_format == "channels_last":
             input_image = np.reshape(input_image, (1, 5, 5, 1))
             expected_output = backend.convert_to_tensor(
@@ -145,27 +133,15 @@ def test_random_translation_up_numeric_constant(self, data_format):
     def test_random_translation_down_numeric_reflect(self, data_format):
         input_image = np.arange(0, 25)
         # Shifting by .2 * 5 = 1 pixel.
-        if backend.backend() == "torch":
-            # redirect fill_mode=reflect to fill_mode=mirror
-            expected_output = np.asarray(
-                [
-                    [5, 6, 7, 8, 9],
-                    [0, 1, 2, 3, 4],
-                    [5, 6, 7, 8, 9],
-                    [10, 11, 12, 13, 14],
-                    [15, 16, 17, 18, 19],
-                ]
-            )
-        else:
-            expected_output = np.asarray(
-                [
-                    [0, 1, 2, 3, 4],
-                    [0, 1, 2, 3, 4],
-                    [5, 6, 7, 8, 9],
-                    [10, 11, 12, 13, 14],
-                    [15, 16, 17, 18, 19],
-                ]
-            )
+        expected_output = np.asarray(
+            [
+                [0, 1, 2, 3, 4],
+                [0, 1, 2, 3, 4],
+                [5, 6, 7, 8, 9],
+                [10, 11, 12, 13, 14],
+                [15, 16, 17, 18, 19],
+            ]
+        )
         if data_format == "channels_last":
             input_image = np.reshape(input_image, (1, 5, 5, 1))
             expected_output = backend.convert_to_tensor(
@@ -196,33 +172,18 @@ def test_random_translation_asymmetric_size_numeric_reflect(
     ):
         input_image = np.arange(0, 16)
         # Shifting by 4 pixels.
-        if backend.backend() == "torch":
-            # redirect fill_mode=reflect to fill_mode=mirror
-            expected_output = np.asarray(
-                [
-                    [8, 9],
-                    [6, 7],
-                    [4, 5],
-                    [2, 3],
-                    [0, 1],
-                    [2, 3],
-                    [4, 5],
-                    [6, 7],
-                ]
-            )
-        else:
-            expected_output = np.asarray(
-                [
-                    [6, 7],
-                    [4, 5],
-                    [2, 3],
-                    [0, 1],
-                    [0, 1],
-                    [2, 3],
-                    [4, 5],
-                    [6, 7],
-                ]
-            )
+        expected_output = np.asarray(
+            [
+                [6, 7],
+                [4, 5],
+                [2, 3],
+                [0, 1],
+                [0, 1],
+                [2, 3],
+                [4, 5],
+                [6, 7],
+            ]
+        )
         if data_format == "channels_last":
             input_image = np.reshape(input_image, (1, 8, 2, 1))
             expected_output = backend.convert_to_tensor(
@@ -290,27 +251,15 @@ def test_random_translation_down_numeric_constant(self, data_format):
     def test_random_translation_left_numeric_reflect(self, data_format):
         input_image = np.arange(0, 25)
         # Shifting by .2 * 5 = 1 pixel.
-        if backend.backend() == "torch":
-            # redirect fill_mode=reflect to fill_mode=mirror
-            expected_output = np.asarray(
-                [
-                    [1, 2, 3, 4, 3],
-                    [6, 7, 8, 9, 8],
-                    [11, 12, 13, 14, 13],
-                    [16, 17, 18, 19, 18],
-                    [21, 22, 23, 24, 23],
-                ]
-            )
-        else:
-            expected_output = np.asarray(
-                [
-                    [1, 2, 3, 4, 4],
-                    [6, 7, 8, 9, 9],
-                    [11, 12, 13, 14, 14],
-                    [16, 17, 18, 19, 19],
-                    [21, 22, 23, 24, 24],
-                ]
-            )
+        expected_output = np.asarray(
+            [
+                [1, 2, 3, 4, 4],
+                [6, 7, 8, 9, 9],
+                [11, 12, 13, 14, 14],
+                [16, 17, 18, 19, 19],
+                [21, 22, 23, 24, 24],
+            ]
+        )
         if data_format == "channels_last":
             input_image = np.reshape(input_image, (1, 5, 5, 1))
             expected_output = backend.convert_to_tensor(
diff --git a/keras_core/ops/image.py b/keras_core/ops/image.py
index bc8e48d2d..d3ef02362 100644
--- a/keras_core/ops/image.py
+++ b/keras_core/ops/image.py
@@ -195,10 +195,6 @@ def affine_transform(
             The input is extended by wrapping around to the opposite edge.
         - `"nearest"`: `(a a a a | a b c d | d d d d)`
             The input is extended by the nearest pixel.
-            Note that when using torch backend, `"reflect"` is redirected to
-            `"mirror"` `(c d c b | a b c d | c b a b)` because torch does not
-            support `"reflect"`.
-            Note that torch backend does not support `"wrap"`.
         fill_value: Value used for points outside the boundaries of the input
             if `fill_mode="constant"`. Defaults to `0`.
         data_format: string, either `"channels_last"` or `"channels_first"`.
diff --git a/keras_core/ops/image_test.py b/keras_core/ops/image_test.py
index f5869f950..6afe3f344 100644
--- a/keras_core/ops/image_test.py
+++ b/keras_core/ops/image_test.py
@@ -238,16 +238,6 @@ def test_resize(self, interpolation, antialias, data_format):
         ]
     )
     def test_affine_transform(self, interpolation, fill_mode, data_format):
-        if backend.backend() == "torch" and fill_mode == "wrap":
-            self.skipTest(
-                "In torch backend, applying affine_transform with "
-                "fill_mode=wrap is not supported"
-            )
-        if backend.backend() == "torch" and fill_mode == "reflect":
-            self.skipTest(
-                "In torch backend, applying affine_transform with "
-                "fill_mode=reflect is redirected to fill_mode=mirror"
-            )
         if backend.backend() == "tensorflow" and fill_mode == "mirror":
             self.skipTest(
                 "In tensorflow backend, applying affine_transform with "
@@ -259,10 +249,10 @@ def test_affine_transform(self, interpolation, fill_mode, data_format):
                 "affine_transform with fill_mode=wrap is inconsistent with "
                 "scipy"
             )
-        # TODO: `nearest` interpolation and `nearest` fill_mode in torch and jax
-        # causes random index shifting, resulting in significant differences in
-        # output which leads to failure
-        if backend.backend() in ("torch", "jax") and interpolation == "nearest":
+        # TODO: `nearest` interpolation in jax and torch causes random index
+        # shifting, resulting in significant differences in output which leads
+        # to failure
+        if backend.backend() in ("jax", "torch") and interpolation == "nearest":
             self.skipTest(
                 f"In {backend.backend()} backend, "
                 f"interpolation={interpolation} causes index shifting and "
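
Reviewer note (not part of the patch): a quick way to sanity-check what this
refactor enables. By routing the torch backend through `map_coordinates`
instead of `torch.nn.functional.grid_sample`, `fill_mode="reflect"` and
`fill_mode="wrap"` now behave like the other backends, which is why the
torch-specific branches in the tests above could be deleted. The snippet
below is a minimal sketch, assuming `keras-core` is installed with
`KERAS_BACKEND=torch`; the length-8 transform `(a0, a1, a2, b0, b1, b2, c0,
c1)` follows the convention documented in
`keras_core.ops.image.affine_transform`, where the output point (x, y)
samples the input at ((a0*x + a1*y + a2) / k, (b0*x + b1*y + b2) / k) with
k = c0*x + c1*y + 1.

    import numpy as np
    from keras_core import ops

    # 5x5 single-channel image, values 0..24, batched, channels_last.
    image = np.arange(25, dtype="float32").reshape(1, 5, 5, 1)

    # Identity transform plus a y-translation of +1: output pixel (x, y)
    # samples input pixel (x, y + 1), shifting the content up by one row.
    transform = np.array([[1, 0, 0, 0, 1, 1, 0, 0]], dtype="float32")

    shifted = ops.image.affine_transform(
        image,
        transform,
        interpolation="nearest",
        fill_mode="reflect",  # accepted by the torch backend after this patch
    )
    print(ops.convert_to_numpy(shifted)[0, ..., 0])

If the backend matches the updated expectations in
`test_random_translation_up_numeric_reflect`, the last row should repeat
`[20, 21, 22, 23, 24]` (the out-of-range row 5 reflects back to row 4) rather
than the `[15, 16, 17, 18, 19]` produced by torch's old mirror redirect.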