Skip to content

Commit

Permalink
Convert RandomZoom to backend-agnostic and improve affine_transform (#574)
Browse files Browse the repository at this point in the history

* Convert RandomZoom

* Fix docstring

* Update docstring

* Update docstring

* Address comments

* Update `affine_transform`

* Update

* Fix fill_mode in torch

* Update `affine_transform`

* Revert RandomTranslation tests

* Fix typo

* Remove docstring

* Redirect `fill_mode=reflect` to `fill_mode=mirror` using torch

* Update docstring

* Update tests because torch's `fill_mode=reflect` actually behaves like `fill_mode=mirror`
  • Loading branch information
james77777778 authored Jul 26, 2023
1 parent 28c29e7 commit b6b4376
Show file tree
Hide file tree
Showing 7 changed files with 483 additions and 198 deletions.
71 changes: 54 additions & 17 deletions keras_core/backend/torch/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ def resize(
"constant": "zeros",
"nearest": "border",
# "wrap", not supported by torch
# "mirror", not supported by torch
"reflect": "reflection",
"mirror": "reflection", # torch's reflection is mirror in other backends
"reflect": "reflection", # if fill_mode==reflect, redirect to mirror
}


Expand Down Expand Up @@ -122,7 +122,7 @@ def _apply_grid_transform(
grid,
mode=interpolation,
padding_mode=fill_mode,
align_corners=False,
align_corners=True,
)
# Fill with required color
if fill_value is not None:
Expand Down Expand Up @@ -187,9 +187,9 @@ def affine_transform(
f"transform.shape={transform.shape}"
)

if fill_mode != "constant":
# the default fill_value of tnn.grid_sample is "zeros"
if fill_mode != "constant" or (fill_mode == "constant" and fill_value == 0):
fill_value = None
fill_mode = AFFINE_TRANSFORM_FILL_MODES[fill_mode]

# unbatched case
need_squeeze = False
Expand All @@ -202,23 +202,60 @@ def affine_transform(
if data_format == "channels_last":
image = image.permute((0, 3, 1, 2))

batch_size = image.shape[0]
h, w, c = image.shape[-2], image.shape[-1], image.shape[-3]

# get indices
shape = [h, w, c] # (H, W, C)
meshgrid = torch.meshgrid(
*[torch.arange(size) for size in shape], indexing="ij"
)
indices = torch.concatenate(
[torch.unsqueeze(x, dim=-1) for x in meshgrid], dim=-1
)
indices = torch.tile(indices, (batch_size, 1, 1, 1, 1))
indices = indices.to(transform)

# swap the values
a0 = transform[:, 0].clone()
a2 = transform[:, 2].clone()
b1 = transform[:, 4].clone()
b2 = transform[:, 5].clone()
transform[:, 0] = b1
transform[:, 2] = b2
transform[:, 4] = a0
transform[:, 5] = a2

# deal with transform
h, w = image.shape[2], image.shape[3]
theta = torch.zeros((image.shape[0], 2, 3)).to(transform)
theta[:, 0, 0] = transform[:, 0]
theta[:, 0, 1] = transform[:, 1] * h / w
theta[:, 0, 2] = (
transform[:, 2] * 2 / w + theta[:, 0, 0] + theta[:, 0, 1] - 1
transform = torch.nn.functional.pad(
transform, pad=[0, 1, 0, 0], mode="constant", value=1
)
theta[:, 1, 0] = transform[:, 3] * w / h
theta[:, 1, 1] = transform[:, 4]
theta[:, 1, 2] = (
transform[:, 5] * 2 / h + theta[:, 1, 0] + theta[:, 1, 1] - 1
transform = torch.reshape(transform, (batch_size, 3, 3))
offset = transform[:, 0:2, 2].clone()
offset = torch.nn.functional.pad(offset, pad=[0, 1, 0, 0])
transform[:, 0:2, 2] = 0

# transform the indices
coordinates = torch.einsum("Bhwij, Bjk -> Bhwik", indices, transform)
coordinates = torch.moveaxis(coordinates, source=-1, destination=1)
coordinates += torch.reshape(a=offset, shape=(*offset.shape, 1, 1, 1))
coordinates = coordinates[:, 0:2, ..., 0]
coordinates = coordinates.permute((0, 2, 3, 1))

# normalize coordinates
coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / (w - 1) * 2.0 - 1.0
coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / (h - 1) * 2.0 - 1.0
grid = torch.stack(
[coordinates[:, :, :, 1], coordinates[:, :, :, 0]], dim=-1
)

grid = tnn.affine_grid(theta, image.shape)
affined = _apply_grid_transform(
image, grid, interpolation, fill_mode, fill_value
image,
grid,
interpolation=interpolation,
# if fill_mode==reflect, redirect to mirror
fill_mode=AFFINE_TRANSFORM_FILL_MODES[fill_mode],
fill_value=fill_value,
)

if data_format == "channels_last":
Expand Down
28 changes: 17 additions & 11 deletions keras_core/layers/preprocessing/random_translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,23 @@ class RandomTranslation(TFDataLayer):
left by 20%, and shifted right by 30%. `width_factor=0.2` results
in an output height shifted left or right by 20%.
fill_mode: Points outside the boundaries of the input are filled
according to the given mode
(one of `{"constant", "reflect", "wrap", "nearest"}`).
- *reflect*: `(d c b a | a b c d | d c b a)` The input is extended
by reflecting about the edge of the last pixel.
- *constant*: `(k k k k | a b c d | k k k k)` The input is extended
by filling all values beyond the edge with the same constant
value k = 0.
- *wrap*: `(a b c d | a b c d | a b c d)` The input is extended by
wrapping around to the opposite edge.
- *nearest*: `(a a a a | a b c d | d d d d)` The input is extended
by the nearest pixel.
according to the given mode. Available methods are `"constant"`,
`"nearest"`, `"wrap"` and `"reflect"`. Defaults to `"constant"`.
- `"reflect"`: `(d c b a | a b c d | d c b a)`
The input is extended by reflecting about the edge of the last
pixel.
- `"constant"`: `(k k k k | a b c d | k k k k)`
The input is extended by filling all values beyond
the edge with the same constant value k specified by
`fill_value`.
- `"wrap"`: `(a b c d | a b c d | a b c d)`
The input is extended by wrapping around to the opposite edge.
- `"nearest"`: `(a a a a | a b c d | d d d d)`
The input is extended by the nearest pixel.
Note that when using torch backend, `"reflect"` is redirected to
`"mirror"` `(c d c b | a b c d | c b a b)` because torch does not
support `"reflect"`.
Note that torch backend does not support `"wrap"`.
interpolation: Interpolation mode. Supported values: `"nearest"`,
`"bilinear"`.
seed: Integer. Used to create a random seed.
Expand Down
129 changes: 90 additions & 39 deletions keras_core/layers/preprocessing/random_translation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,27 @@ def test_random_translation_with_inference_mode(self):
@parameterized.parameters(["channels_first", "channels_last"])
def test_random_translation_up_numeric_reflect(self, data_format):
input_image = np.arange(0, 25)
expected_output = np.asarray(
[
[5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24],
[20, 21, 22, 23, 24],
]
)
if backend.backend() == "torch":
# redirect fill_mode=reflect to fill_mode=mirror
expected_output = np.asarray(
[
[5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24],
[15, 16, 17, 18, 19],
]
)
else:
expected_output = np.asarray(
[
[5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24],
[20, 21, 22, 23, 24],
]
)
if data_format == "channels_last":
input_image = np.reshape(input_image, (1, 5, 5, 1))
expected_output = backend.convert_to_tensor(
Expand Down Expand Up @@ -133,15 +145,27 @@ def test_random_translation_up_numeric_constant(self, data_format):
def test_random_translation_down_numeric_reflect(self, data_format):
input_image = np.arange(0, 25)
# Shifting by .2 * 5 = 1 pixel.
expected_output = np.asarray(
[
[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19],
]
)
if backend.backend() == "torch":
# redirect fill_mode=reflect to fill_mode=mirror
expected_output = np.asarray(
[
[5, 6, 7, 8, 9],
[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19],
]
)
else:
expected_output = np.asarray(
[
[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19],
]
)
if data_format == "channels_last":
input_image = np.reshape(input_image, (1, 5, 5, 1))
expected_output = backend.convert_to_tensor(
Expand Down Expand Up @@ -172,18 +196,33 @@ def test_random_translation_asymmetric_size_numeric_reflect(
):
input_image = np.arange(0, 16)
# Shifting by .2 * 5 = 1 pixel.
expected_output = np.asarray(
[
[6, 7],
[4, 5],
[2, 3],
[0, 1],
[0, 1],
[2, 3],
[4, 5],
[6, 7],
]
)
if backend.backend() == "torch":
# redirect fill_mode=reflect to fill_mode=mirror
expected_output = np.asarray(
[
[8, 9],
[6, 7],
[4, 5],
[2, 3],
[0, 1],
[2, 3],
[4, 5],
[6, 7],
]
)
else:
expected_output = np.asarray(
[
[6, 7],
[4, 5],
[2, 3],
[0, 1],
[0, 1],
[2, 3],
[4, 5],
[6, 7],
]
)
if data_format == "channels_last":
input_image = np.reshape(input_image, (1, 8, 2, 1))
expected_output = backend.convert_to_tensor(
Expand Down Expand Up @@ -251,15 +290,27 @@ def test_random_translation_down_numeric_constant(self, data_format):
def test_random_translation_left_numeric_reflect(self, data_format):
input_image = np.arange(0, 25)
# Shifting by .2 * 5 = 1 pixel.
expected_output = np.asarray(
[
[1, 2, 3, 4, 4],
[6, 7, 8, 9, 9],
[11, 12, 13, 14, 14],
[16, 17, 18, 19, 19],
[21, 22, 23, 24, 24],
]
)
if backend.backend() == "torch":
# redirect fill_mode=reflect to fill_mode=mirror
expected_output = np.asarray(
[
[1, 2, 3, 4, 3],
[6, 7, 8, 9, 8],
[11, 12, 13, 14, 13],
[16, 17, 18, 19, 18],
[21, 22, 23, 24, 23],
]
)
else:
expected_output = np.asarray(
[
[1, 2, 3, 4, 4],
[6, 7, 8, 9, 9],
[11, 12, 13, 14, 14],
[16, 17, 18, 19, 19],
[21, 22, 23, 24, 24],
]
)
if data_format == "channels_last":
input_image = np.reshape(input_image, (1, 5, 5, 1))
expected_output = backend.convert_to_tensor(
Expand Down
Loading

0 comments on commit b6b4376

Please sign in to comment.