Refactor torch's affine_transform (#929)
* Refactor torch's `affine_transform`

* Update docstring

* Update RandomTranslation test
james77777778 authored Sep 20, 2023
1 parent 4c3697f commit 5af4344
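For context, a minimal sketch of how the refactored backend entry point is called. The 8-element transform layout ([a0, a1, a2, b0, b1, b2, c0, c1], identity here) and the shapes are illustrative assumptions based on the signature visible in the diff below:

```python
import numpy as np

from keras_core.backend.torch.image import affine_transform

# One HWC image and a flattened affine matrix; identity transform here.
image = np.random.uniform(size=(32, 32, 3)).astype("float32")
transform = np.array([1, 0, 0, 0, 1, 0, 0, 0], dtype="float32")

out = affine_transform(
    image,
    transform,
    interpolation="bilinear",
    fill_mode="constant",
    fill_value=0,
    data_format="channels_last",
)
print(out.shape)  # torch.Size([32, 32, 3])
```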
Showing 4 changed files with 94 additions and 222 deletions.
165 changes: 51 additions & 114 deletions keras_core/backend/torch/image.py
@@ -3,7 +3,6 @@
 import operator
 
 import torch
-import torch.nn.functional as tnn
 
 from keras_core.backend.torch.core import convert_to_tensor
 
@@ -82,78 +81,19 @@ def resize(
     return resized
 
 
-AFFINE_TRANSFORM_INTERPOLATIONS = (
-    "nearest",
-    "bilinear",
-)
+AFFINE_TRANSFORM_INTERPOLATIONS = {
+    "nearest": 0,
+    "bilinear": 1,
+}
 AFFINE_TRANSFORM_FILL_MODES = {
-    "constant": "zeros",
-    "nearest": "border",
-    # "wrap", not supported by torch
-    "mirror": "reflection",  # torch's reflection is mirror in other backends
-    "reflect": "reflection",  # if fill_mode==reflect, redirect to mirror
+    "constant",
+    "nearest",
+    "wrap",
+    "mirror",
+    "reflect",
 }
-
-
-def _apply_grid_transform(
-    img,
-    grid,
-    interpolation="bilinear",
-    fill_mode="zeros",
-    fill_value=None,
-):
-    """
-    Modified from https://github.com/pytorch/vision/blob/main/torchvision/transforms/v2/functional/_geometry.py
-    """  # noqa: E501
-
-    # We are using context knowledge that grid should have float dtype
-    fp = img.dtype == grid.dtype
-    float_img = img if fp else img.to(grid.dtype)
-
-    shape = float_img.shape
-    # Append a dummy mask for customized fill colors, should be faster than
-    # grid_sample() twice
-    if fill_value is not None:
-        mask = torch.ones(
-            (shape[0], 1, shape[2], shape[3]),
-            dtype=float_img.dtype,
-            device=float_img.device,
-        )
-        float_img = torch.cat((float_img, mask), dim=1)
-
-    float_img = tnn.grid_sample(
-        float_img,
-        grid,
-        mode=interpolation,
-        padding_mode=fill_mode,
-        align_corners=True,
-    )
-    # Fill with required color
-    if fill_value is not None:
-        float_img, mask = torch.tensor_split(float_img, indices=(-1,), dim=-3)
-        mask = mask.expand_as(float_img)
-        fill_list = (
-            fill_value
-            if isinstance(fill_value, (tuple, list))
-            else [float(fill_value)]
-        )
-        fill_img = torch.tensor(
-            fill_list, dtype=float_img.dtype, device=float_img.device
-        ).view(1, -1, 1, 1)
-        if interpolation == "nearest":
-            bool_mask = mask < 0.5
-            float_img[bool_mask] = fill_img.expand_as(float_img)[bool_mask]
-        else:  # 'bilinear'
-            # The following is mathematically equivalent to:
-            # img * mask + (1.0 - mask) * fill =
-            # img * mask - fill * mask + fill =
-            # mask * (img - fill) + fill
-            float_img = float_img.sub_(fill_img).mul_(mask).add_(fill_img)
-
-    img = float_img.round_().to(img.dtype) if not fp else float_img
-    return img
 
 
 def affine_transform(
     image,
     transform,
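With `_apply_grid_transform` gone, the `grid_sample` coordinate convention goes with it. A small sketch of the difference, grounded in the normalization lines removed further down (`W` is an illustrative size):

```python
import torch

# Old path: pixel coordinates in [0, W - 1] had to be rescaled to [-1, 1]
# for tnn.grid_sample(..., align_corners=True), exactly as the removed
# code did: x_norm = x / (W - 1) * 2.0 - 1.0
W = 5
x = torch.arange(W, dtype=torch.float32)
x_norm = x / (W - 1) * 2.0 - 1.0
print(x_norm)  # tensor([-1.0000, -0.5000,  0.0000,  0.5000,  1.0000])

# New path: map_coordinates consumes raw pixel coordinates, so the
# normalization step disappears entirely.
```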
@@ -162,17 +102,16 @@ def affine_transform(
     fill_value=0,
     data_format="channels_last",
 ):
-    if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS:
+    if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS.keys():
         raise ValueError(
             "Invalid value for argument `interpolation`. Expected of one "
-            f"{AFFINE_TRANSFORM_INTERPOLATIONS}. Received: "
+            f"{set(AFFINE_TRANSFORM_INTERPOLATIONS.keys())}. Received: "
             f"interpolation={interpolation}"
         )
-    if fill_mode not in AFFINE_TRANSFORM_FILL_MODES.keys():
+    if fill_mode not in AFFINE_TRANSFORM_FILL_MODES:
         raise ValueError(
             "Invalid value for argument `fill_mode`. Expected of one "
-            f"{set(AFFINE_TRANSFORM_FILL_MODES.keys())}. "
-            f"Received: fill_mode={fill_mode}"
+            f"{AFFINE_TRANSFORM_FILL_MODES}. Received: fill_mode={fill_mode}"
         )
 
     image = convert_to_tensor(image)
@@ -191,10 +130,6 @@ def affine_transform(
             f"transform.shape={transform.shape}"
         )
 
-    # the default fill_value of tnn.grid_sample is "zeros"
-    if fill_mode != "constant" or (fill_mode == "constant" and fill_value == 0):
-        fill_value = None
-
     # unbatched case
     need_squeeze = False
     if image.ndim == 3:
@@ -203,22 +138,23 @@ def affine_transform(
     if transform.ndim == 1:
         transform = transform.unsqueeze(dim=0)
 
-    if data_format == "channels_last":
-        image = image.permute((0, 3, 1, 2))
+    if data_format == "channels_first":
+        image = image.permute((0, 2, 3, 1))
 
     batch_size = image.shape[0]
-    h, w, c = image.shape[-2], image.shape[-1], image.shape[-3]
 
     # get indices
-    shape = [h, w, c]  # (H, W, C)
     meshgrid = torch.meshgrid(
-        *[torch.arange(size) for size in shape], indexing="ij"
+        *[
+            torch.arange(size, dtype=transform.dtype, device=transform.device)
+            for size in image.shape[1:]
+        ],
+        indexing="ij",
     )
     indices = torch.concatenate(
         [torch.unsqueeze(x, dim=-1) for x in meshgrid], dim=-1
     )
     indices = torch.tile(indices, (batch_size, 1, 1, 1, 1))
-    indices = indices.to(transform)
 
     # swap the values
     a0 = transform[:, 0].clone()
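What the meshgrid/tile block above builds, at toy scale. The identity matrix stands in for the padded 3x3 transform assembled later in the function; shapes are illustrative:

```python
import torch

B, H, W, C = 1, 2, 3, 1
transform = torch.eye(3).unsqueeze(0)  # (B, 3, 3); identity for the demo

# One (y, x, c) index triple per output position, batched.
meshgrid = torch.meshgrid(
    *[torch.arange(s, dtype=transform.dtype) for s in (H, W, C)],
    indexing="ij",
)
indices = torch.concatenate(
    [torch.unsqueeze(m, dim=-1) for m in meshgrid], dim=-1
)  # (H, W, C, 3)
indices = torch.tile(indices, (B, 1, 1, 1, 1))  # (B, H, W, C, 3)

# Matrix-multiply every index triple by the per-batch transform.
coordinates = torch.einsum("Bhwij, Bjk -> Bhwik", indices, transform)
coordinates = torch.moveaxis(coordinates, source=-1, destination=1)
print(coordinates.shape)  # torch.Size([1, 3, 2, 3, 1])
```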
@@ -243,27 +179,23 @@ def affine_transform(
     coordinates = torch.einsum("Bhwij, Bjk -> Bhwik", indices, transform)
     coordinates = torch.moveaxis(coordinates, source=-1, destination=1)
     coordinates += torch.reshape(a=offset, shape=(*offset.shape, 1, 1, 1))
-    coordinates = coordinates[:, 0:2, ..., 0]
-    coordinates = coordinates.permute((0, 2, 3, 1))
 
-    # normalize coordinates
-    coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / (w - 1) * 2.0 - 1.0
-    coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / (h - 1) * 2.0 - 1.0
-    grid = torch.stack(
-        [coordinates[:, :, :, 1], coordinates[:, :, :, 0]], dim=-1
-    )
-
-    affined = _apply_grid_transform(
-        image,
-        grid,
-        interpolation=interpolation,
-        # if fill_mode==reflect, redirect to mirror
-        fill_mode=AFFINE_TRANSFORM_FILL_MODES[fill_mode],
-        fill_value=fill_value,
+    # Note: torch.stack is faster than torch.vmap when the batch size is small.
+    affined = torch.stack(
+        [
+            map_coordinates(
+                image[i],
+                coordinates[i],
+                order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation],
+                fill_mode=fill_mode,
+                fill_value=fill_value,
+            )
+            for i in range(len(image))
+        ],
     )
 
-    if data_format == "channels_last":
-        affined = affined.permute((0, 2, 3, 1))
+    if data_format == "channels_first":
+        affined = affined.permute((0, 3, 1, 2))
     if need_squeeze:
         affined = affined.squeeze(dim=0)
     return affined
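On the `torch.stack` note above: the refactor maps a per-sample function over the batch with a plain Python loop and re-batches via `torch.stack`, which avoids `torch.vmap`'s per-call overhead for small batches. A rough illustration (assumes torch >= 2.0 for `torch.vmap`; the doubling function is a stand-in for the real `map_coordinates` call):

```python
import torch

def per_sample(x):
    # Stand-in for the per-image map_coordinates call.
    return x * 2.0

batch = torch.randn(4, 8, 8, 3)

# Loop-and-stack, as in the refactored code above:
out_stack = torch.stack([per_sample(batch[i]) for i in range(len(batch))])

# vmap equivalent; typically only pays off for larger batches:
out_vmap = torch.vmap(per_sample)(batch)

assert torch.allclose(out_stack, out_vmap)
```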
@@ -282,7 +214,8 @@ def _reflect_index_fixer(index, size):
 
 
 _INDEX_FIXERS = {
-    "constant": lambda index, size: index,
+    # we need to take care of out-of-bound indices in torch
+    "constant": lambda index, size: torch.clip(index, 0, size - 1),
    "nearest": lambda index, size: torch.clip(index, 0, size - 1),
    "wrap": lambda index, size: index % size,
    "mirror": _mirror_index_fixer,
@@ -301,8 +234,7 @@ def _nearest_indices_and_weights(coordinate):
         coordinate if _is_integer(coordinate) else torch.round(coordinate)
     )
     index = coordinate.to(torch.int32)
-    weight = torch.tensor(1).to(torch.int32)
-    return [(index, weight)]
+    return [(index, 1)]
 
 
 def _linear_indices_and_weights(coordinate):
@@ -318,7 +250,9 @@ def map_coordinates(
 ):
     input_arr = convert_to_tensor(input)
     coordinate_arrs = [convert_to_tensor(c) for c in coordinates]
-    fill_value = convert_to_tensor(fill_value, input_arr.dtype)
+    # skip tensor creation as possible
+    if isinstance(fill_value, (int, float)) and _is_integer(input_arr):
+        fill_value = int(fill_value)
 
     if len(coordinates) != len(input_arr.shape):
         raise ValueError(
@@ -330,23 +264,26 @@ def map_coordinates(
     if index_fixer is None:
         raise ValueError(
             "Invalid value for argument `fill_mode`. Expected one of "
-            f"{set(_INDEX_FIXERS.keys())}. Received: "
-            f"fill_mode={fill_mode}"
+            f"{set(_INDEX_FIXERS.keys())}. Received: fill_mode={fill_mode}"
         )
 
-    def is_valid(index, size):
-        if fill_mode == "constant":
-            return (0 <= index) & (index < size)
-        else:
-            return True
-
     if order == 0:
         interp_fun = _nearest_indices_and_weights
     elif order == 1:
         interp_fun = _linear_indices_and_weights
     else:
         raise NotImplementedError("map_coordinates currently requires order<=1")
 
+    if fill_mode == "constant":
+
+        def is_valid(index, size):
+            return (0 <= index) & (index < size)
+
+    else:
+
+        def is_valid(index, size):
+            return True
+
     valid_1d_interpolations = []
     for coordinate, size in zip(coordinate_arrs, input_arr.shape):
         interp_nodes = interp_fun(coordinate)
[Remaining diff truncated; the other three changed files are not shown.]
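To close, a hedged sketch of calling the refactored `map_coordinates` directly; the positional/keyword layout is inferred from the hunks above, and the half-pixel shift is illustrative:

```python
import torch

from keras_core.backend.torch.image import map_coordinates

x = torch.tensor([0.0, 1.0, 2.0, 3.0])
# Sample half a pixel to the right with linear interpolation (order=1).
coords = [torch.tensor([0.5, 1.5, 2.5, 3.5])]

y = map_coordinates(x, coords, order=1, fill_mode="constant", fill_value=0.0)
# The last sample blends 3.0 with the out-of-bounds fill 0.0:
print(y)  # tensor([0.5000, 1.5000, 2.5000, 1.5000])
```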
