From 3a0ab23402cb4337da2c0a790b6a494c9edffdc9 Mon Sep 17 00:00:00 2001 From: aleju Date: Sun, 27 Oct 2019 20:57:56 +0100 Subject: [PATCH 1/3] Add apply_jigsaw(), generate_jigsaw_destinations() --- changelogs/master/added/20191027_jigsaw.md | 4 + checks/check_jigsaw.py | 25 +++ imgaug/augmenters/geometric.py | 159 ++++++++++++++++++ test/augmenters/test_geometric.py | 180 +++++++++++++++++++++ 4 files changed, 368 insertions(+) create mode 100644 changelogs/master/added/20191027_jigsaw.md create mode 100644 checks/check_jigsaw.py diff --git a/changelogs/master/added/20191027_jigsaw.md b/changelogs/master/added/20191027_jigsaw.md new file mode 100644 index 000000000..1fd716988 --- /dev/null +++ b/changelogs/master/added/20191027_jigsaw.md @@ -0,0 +1,4 @@ +# Jigsaw Augmenter #476 + +* Added function `imgaug.augmenters.geometric.apply_jigsaw()`. +* Added function `imgaug.augmenters.geometric.generate_jigsaw_destinations()`. diff --git a/checks/check_jigsaw.py b/checks/check_jigsaw.py new file mode 100644 index 000000000..25b4d4660 --- /dev/null +++ b/checks/check_jigsaw.py @@ -0,0 +1,25 @@ +from __future__ import print_function, division, absolute_import +import imgaug as ia +import imgaug.augmenters as iaa +import timeit + + +def main(): + gen_time = timeit.timeit( + "iaa.generate_jigsaw_destinations(10, 10, 2, rng)", + number=128, + setup=( + "import imgaug.augmenters as iaa; " + "import imgaug.random as iarandom; " + "rng = iarandom.RNG(0)" + ) + ) + print("Time to generate 128x dest:", gen_time) + image = ia.quokka_square((200, 200)) + destinations = iaa.generate_jigsaw_destinations(10, 10, 1, random_state=1) + image_jig = iaa.apply_jigsaw(image, destinations) + ia.imshow(image_jig) + + +if __name__ == "__main__": + main() diff --git a/imgaug/augmenters/geometric.py b/imgaug/augmenters/geometric.py index d3398f20d..ab854e0dd 100644 --- a/imgaug/augmenters/geometric.py +++ b/imgaug/augmenters/geometric.py @@ -39,6 +39,7 @@ from imgaug.augmentables.polys import _ConcavePolygonRecoverer from .. import parameters as iap from .. import dtypes as iadt +from .. import random as iarandom _VALID_DTYPES_CV2_ORDER_0 = {"uint8", "uint16", "int8", "int16", "int32", @@ -355,6 +356,164 @@ def _compute_affine_warp_output_shape(matrix, input_shape): return matrix, output_shape +# TODO allow -1 destinations +def apply_jigsaw(arr, destinations): + """Move cells of an image similar to a jigsaw puzzle. + + This function will split the image into ``rows x cols`` cells and + move each cell to the target index given in `destinations`. + + dtype support:: + + * ``uint8``: yes; fully tested + * ``uint16``: yes; fully tested + * ``uint32``: yes; fully tested + * ``uint64``: yes; fully tested + * ``int8``: yes; fully tested + * ``int16``: yes; fully tested + * ``int32``: yes; fully tested + * ``int64``: yes; fully tested + * ``float16``: yes; fully tested + * ``float32``: yes; fully tested + * ``float64``: yes; fully tested + * ``float128``: yes; fully tested + * ``bool``: yes; fully tested + + Parameters + ---------- + arr : ndarray + Array with at least two dimensions denoting height and width. + + destinations : ndarray + 2-dimensional array containing for each cell the id of the destination + cell. The order is expected to a flattened c-order, i.e. row by row. + The height of the image must be evenly divisible by the number of + rows in this array. Analogous for the width and columns. + + Returns + ------- + ndarray + Modified image with cells moved according to `destioations`. + + """ + nb_rows, nb_cols = destinations.shape[0:2] + + assert arr.ndim >= 2, ( + "Expected array with at least two dimensions, but got %d with " + "shape %s." % (arr.ndim, arr.shape)) + assert (arr.shape[0] % nb_rows) == 0, ( + "Expected image height to by divisible by number of rows, but got " + "height %d and %d rows. Use cropping or padding to modify the image " + "height or change the number of rows." % (arr.shape[0], nb_rows) + ) + assert (arr.shape[1] % nb_cols) == 0, ( + "Expected image width to by divisible by number of columns, but got " + "width %d and %d columns. Use cropping or padding to modify the image " + "width or change the number of columns." % (arr.shape[1], nb_cols) + ) + + cell_height = arr.shape[0] // nb_rows + cell_width = arr.shape[1] // nb_cols + + dest_rows, dest_cols = np.unravel_index( + destinations.flatten(), (nb_rows, nb_cols)) + + result = np.zeros_like(arr) + i = 0 + for source_row in np.arange(nb_rows): + for source_col in np.arange(nb_cols): + # TODO vectorize coords computation + dest_row, dest_col = dest_rows[i], dest_cols[i] + + source_y1 = source_row * cell_height + source_y2 = source_y1 + cell_height + source_x1 = source_col * cell_width + source_x2 = source_x1 + cell_width + + dest_y1 = dest_row * cell_height + dest_y2 = dest_y1 + cell_height + dest_x1 = dest_col * cell_width + dest_x2 = dest_x1 + cell_width + + source = arr[source_y1:source_y2, source_x1:source_x2] + result[dest_y1:dest_y2, dest_x1:dest_x2] = source + + i += 1 + + return result + + +def generate_jigsaw_destinations(nb_rows, nb_cols, max_steps, random_state, + connectivity=4): + """Generate a destination pattern for :func:`apply_jigsaw`. + + Parameters + ---------- + nb_rows : int + Number of rows to split the image into. + + nb_cols : int + Number of columns to split the image into. + + max_steps : int + Maximum number of cells that each cell may be moved. + + random_state : None or int or imgaug.random.RNG or numpy.random.Generator or numpy.random.bit_generator.BitGenerator or numpy.random.SeedSequence or numpy.random.RandomState + RNG or seed to use. If ``None`` the global RNG will be used. + + connectivity : int, optional + Whether a diagonal move of a cell counts as one step + (``connectivity=8``) or two steps (``connectivity=4``). + + Returns + ------- + ndarray + 2-dimensional array containing for each cell the id of the target + cell. + + """ + assert connectivity in (4, 8), ( + "Expected connectivity of 4 or 8, got %d." % (connectivity,)) + random_state = iarandom.RNG(random_state) + steps = random_state.integers(0, max_steps, size=(nb_rows, nb_cols), + endpoint=True) + directions = random_state.integers(0, connectivity, + size=(nb_rows, nb_cols, max_steps), + endpoint=False) + destinations = np.arange(nb_rows*nb_cols).reshape((nb_rows, nb_cols)) + + for step in np.arange(max_steps): + directions_step = directions[:, :, step] + + for y in np.arange(nb_rows): + for x in np.arange(nb_cols): + if steps[y, x] > 0: + y_target, x_target = { + 0: (y-1, x+0), + 1: (y+0, x+1), + 2: (y+1, x+0), + 3: (y+0, x-1), + 4: (y-1, x-1), + 5: (y-1, x+1), + 6: (y+1, x+1), + 7: (y+1, x-1) + }[directions_step[y, x]] + y_target = max(min(y_target, nb_rows-1), 0) + x_target = max(min(x_target, nb_cols-1), 0) + + target_steps = steps[y_target, x_target] + if (y, x) != (y_target, x_target) and target_steps >= 1: + source_dest = destinations[y, x] + target_dest = destinations[y_target, x_target] + destinations[y, x] = target_dest + destinations[y_target, x_target] = source_dest + + steps[y, x] -= 1 + steps[y_target, x_target] -= 1 + + return destinations + + class _AffineSamplingResult(object): def __init__(self, scale=None, translate=None, rotate=None, shear=None, cval=None, mode=None, order=None): diff --git a/test/augmenters/test_geometric.py b/test/augmenters/test_geometric.py index 8f91f692c..83b59c944 100644 --- a/test/augmenters/test_geometric.py +++ b/test/augmenters/test_geometric.py @@ -9075,3 +9075,183 @@ def test___repr___and___str__(self): assert aug.__repr__() == expected assert aug.__str__() == expected + + +class Test_apply_jigsaw(unittest.TestCase): + def test_no_movement(self): + dtypes = ["bool", + "uint8", "uint16", "uint32", "uint64", + "int8", "int16", "int32", "int64", + "float16", "float32", "float64", "float128"] + + for dtype in dtypes: + with self.subTest(dtype=dtype): + arr = np.arange(20*20*1).reshape((20, 20, 1)) + if dtype == "bool": + mask = np.logical_or( + arr % 4 == 0, + arr % 7 == 0) + arr[mask] = 1 + arr[~mask] = 0 + arr = arr.astype(dtype) + min_value, center_value, max_value = \ + iadt.get_value_range_of_dtype(dtype) + arr[0, 0] = min_value + arr[0, 1] = max_value + + destinations = np.arange(5*5).reshape((5, 5)) + + observed = iaa.apply_jigsaw(arr, destinations) + + if arr.dtype.kind != "f": + assert np.array_equal(observed, arr) + else: + atol = 1e-4 if dtype == "float16" else 1e-8 + assert np.allclose(observed, arr, rtol=0, atol=atol) + + def test_no_movement_zero_sized_axes(self): + sizes = [ + (0, 1), + (1, 0), + (0, 0) + ] + + dtype = "uint8" + for size in sizes: + with self.subTest(size=size): + arr = np.zeros(size, dtype=dtype) + destinations = np.arange(1*1).reshape((1, 1)) + + observed = iaa.apply_jigsaw(arr, destinations) + + assert np.array_equal(observed, arr) + + def _test_two_cells_moved__n_channels(self, nb_channels): + dtypes = ["bool", + "uint8", "uint16", "uint32", "uint64", + "int8", "int16", "int32", "int64", + "float16", "float32", "float64", "float128"] + + for dtype in dtypes: + with self.subTest(dtype=dtype): + c = 1 if nb_channels is None else nb_channels + arr = np.arange(20*20*c) + if dtype == "bool": + mask = np.logical_or( + arr % 4 == 0, + arr % 7 == 0) + arr[mask] = 1 + arr[~mask] = 0 + if nb_channels is not None: + arr = arr.reshape((20, 20, c)) + else: + arr = arr.reshape((20, 20)) + arr = arr.astype(dtype) + min_value, center_value, max_value = \ + iadt.get_value_range_of_dtype(dtype) + arr[0, 0] = min_value + arr[0, 1] = max_value + + destinations = np.arange(5*5).reshape((5, 5)) + destinations[0, 0] = 4 # cell 0 will be filled with 4 + destinations[0, 4] = 0 # cell 4 will be filled with 0 + destinations[0, 1] = 6 # cell 1 will be filled with 6 + destinations[1, 1] = 1 # cell 6 will be filled with 1 + + observed = iaa.apply_jigsaw(arr, destinations) + + cell_0_obs = observed[0:4, 0:4] + cell_0_exp = arr[0:4, 16:20] + cell_4_obs = observed[0:4, 16:20] + cell_4_exp = arr[0:4, 0:4] + cell_1_obs = observed[0:4, 4:8] + cell_1_exp = arr[4:8, 4:8] + cell_6_obs = observed[4:8, 4:8] + cell_6_exp = arr[0:4, 4:8] + cell_2_obs = observed[0:4, 8:12] + cell_2_exp = arr[0:4, 8:12] + if arr.dtype.kind != "f": + assert np.array_equal(cell_0_obs, cell_0_exp) + assert np.array_equal(cell_4_obs, cell_4_exp) + assert np.array_equal(cell_1_obs, cell_1_exp) + assert np.array_equal(cell_6_obs, cell_6_exp) + assert np.array_equal(cell_2_obs, cell_2_exp) + else: + atol = 1e-4 if dtype == "float16" else 1e-8 + kwargs = {"rtol": 0, "atol": atol} + assert np.allclose(cell_0_obs, cell_0_exp, **kwargs) + assert np.allclose(cell_4_obs, cell_4_exp, **kwargs) + assert np.allclose(cell_1_obs, cell_1_exp, **kwargs) + assert np.allclose(cell_6_obs, cell_6_exp, **kwargs) + assert np.allclose(cell_2_obs, cell_2_exp, **kwargs) + + assert observed.shape == arr.shape + assert observed.dtype.name == dtype + + def test_two_cells_moved__no_channels(self): + self._test_two_cells_moved__n_channels(None) + + def test_two_cells_moved__1_channel(self): + self._test_two_cells_moved__n_channels(1) + + def test_two_cells_moved__3_channels(self): + self._test_two_cells_moved__n_channels(3) + + +class Test_generate_jigsaw_destinations(unittest.TestCase): + def test_max_steps_0(self): + rng = iarandom.RNG(0) + max_steps = 0 + rows = 10 + cols = 20 + + observed = iaa.generate_jigsaw_destinations(rows, cols, max_steps, rng, + connectivity=8) + + assert np.array_equal( + observed, + np.arange(rows*cols).reshape((rows, cols))) + + def test_max_steps_1(self): + rng = iarandom.RNG(0) + max_steps = 1 + rows = 10 + cols = 20 + + observed = iaa.generate_jigsaw_destinations(rows, cols, max_steps, rng, + connectivity=8) + + yy = (observed // cols).reshape((rows, cols)) + xx = np.mod(observed, cols).reshape((rows, cols)) + yy_expected = np.tile(np.arange(rows).reshape((rows, 1)), (1, cols)) + xx_expected = np.tile(np.arange(cols).reshape((1, cols)), (rows, 1)) + + yy_diff = yy_expected - yy + xx_diff = xx_expected - xx + dist = np.sqrt(yy_diff ** 2 + xx_diff ** 2) + + assert np.min(dist) <= 0.01 + assert np.any(dist >= np.sqrt(2) - 1e-4) + assert np.max(dist) <= np.sqrt(2) + 1e-4 + + def test_max_steps_1_connectivity_4(self): + rng = iarandom.RNG(0) + max_steps = 1 + rows = 10 + cols = 20 + + observed = iaa.generate_jigsaw_destinations(rows, cols, max_steps, rng, + connectivity=4) + + yy = (observed // cols).reshape((rows, cols)) + xx = np.mod(observed, cols).reshape((rows, cols)) + yy_expected = np.tile(np.arange(rows).reshape((rows, 1)), (1, cols)) + xx_expected = np.tile(np.arange(cols).reshape((1, cols)), (rows, 1)) + + yy_diff = yy_expected - yy + xx_diff = xx_expected - xx + dist = np.sqrt(yy_diff ** 2 + xx_diff ** 2) + + assert np.min(dist) <= 0.01 + assert np.any(dist >= 0.99) + assert np.max(dist) <= 1.01 From 5c1a279826c4322dcac1bf3a3e548702373be700 Mon Sep 17 00:00:00 2001 From: aleju Date: Sun, 27 Oct 2019 21:49:39 +0100 Subject: [PATCH 2/3] Add apply_jigsaw_to_coords() --- changelogs/master/added/20191027_jigsaw.md | 1 + imgaug/augmenters/geometric.py | 62 ++++++++++++++++++++- test/augmenters/test_geometric.py | 63 ++++++++++++++++++++++ 3 files changed, 125 insertions(+), 1 deletion(-) diff --git a/changelogs/master/added/20191027_jigsaw.md b/changelogs/master/added/20191027_jigsaw.md index 1fd716988..b34da0b22 100644 --- a/changelogs/master/added/20191027_jigsaw.md +++ b/changelogs/master/added/20191027_jigsaw.md @@ -1,4 +1,5 @@ # Jigsaw Augmenter #476 * Added function `imgaug.augmenters.geometric.apply_jigsaw()`. +* Added function `imgaug.augmenters.geometric.apply_jigsaw_to_coords()`. * Added function `imgaug.augmenters.geometric.generate_jigsaw_destinations()`. diff --git a/imgaug/augmenters/geometric.py b/imgaug/augmenters/geometric.py index ab854e0dd..8cf4de2e9 100644 --- a/imgaug/augmenters/geometric.py +++ b/imgaug/augmenters/geometric.py @@ -393,7 +393,7 @@ def apply_jigsaw(arr, destinations): Returns ------- ndarray - Modified image with cells moved according to `destioations`. + Modified image with cells moved according to `destinations`. """ nb_rows, nb_cols = destinations.shape[0:2] @@ -443,6 +443,66 @@ def apply_jigsaw(arr, destinations): return result +def apply_jigsaw_to_coords(coords, destinations, image_shape): + """Move coordinates on an image similar to a jigsaw puzzle. + + This is the same as :func:`apply_jigsaw`, but moves coordinates within + the cells. + + Parameters + ---------- + coords : ndarray + ``(N, 2)`` array denoting xy-coordinates. + + destinations : ndarray + See :func:`apply_jigsaw`. + + image_shape : tuple of int + ``(height, width, ...)`` shape of the image on which the + coordinates are placed. Only height and width are required. + + Returns + ------- + ndarray + Moved coordinates. + + """ + nb_rows, nb_cols = destinations.shape[0:2] + + height, width = image_shape[0:2] + cell_height = height // nb_rows + cell_width = width // nb_cols + + dest_rows, dest_cols = np.unravel_index( + destinations.flatten(), (nb_rows, nb_cols)) + + result = np.copy(coords) + + # TODO vectorize this loop + for i, (x, y) in enumerate(coords): + ooi_x = (x < 0 or x >= width) + ooi_y = (y < 0 or y >= height) + if ooi_x or ooi_y: + continue + + source_row = int(y // cell_height) + source_col = int(x // cell_width) + source_cell_idx = (source_row * nb_cols) + source_col + dest_row = dest_rows[source_cell_idx] + dest_col = dest_cols[source_cell_idx] + + source_y1 = source_row * cell_height + source_x1 = source_col * cell_width + + dest_y1 = dest_row * cell_height + dest_x1 = dest_col * cell_width + + result[i, 0] = dest_x1 + (x - source_x1) + result[i, 1] = dest_y1 + (y - source_y1) + + return result + + def generate_jigsaw_destinations(nb_rows, nb_cols, max_steps, random_state, connectivity=4): """Generate a destination pattern for :func:`apply_jigsaw`. diff --git a/test/augmenters/test_geometric.py b/test/augmenters/test_geometric.py index 83b59c944..89de79d76 100644 --- a/test/augmenters/test_geometric.py +++ b/test/augmenters/test_geometric.py @@ -9198,6 +9198,69 @@ def test_two_cells_moved__3_channels(self): self._test_two_cells_moved__n_channels(3) +class Test_apply_jigsaw_to_coords(unittest.TestCase): + def test_no_movement(self): + arr = np.float32([ + (0.0, 0.0), + (5.0, 5.0), + (25.0, 50.5), + (10.01, 21.0) + ]) + destinations = np.arange(10*10).reshape((10, 10)) + + observed = iaa.apply_jigsaw_to_coords(arr, destinations, (50, 100)) + + assert np.allclose(observed, arr) + + def test_with_movement(self): + arr = np.float32([ + (0.0, 0.0), # in cell (0,0) = idx 0 + (5.0, 5.0), # in cell (0,0) = idx 0 + (25.0, 50.5), # in cell (5,2) = idx 52 + (10.01, 21.0) # in cell (2,1) = idx 12 + ]) + destinations = np.arange(10*10).reshape((10, 10)) + destinations[0, 0] = 1 + destinations[0, 1] = 0 + destinations[5, 2] = 7 + destinations[0, 7] = 52 + + observed = iaa.apply_jigsaw_to_coords(arr, destinations, (100, 100)) + + expected = np.float32([ + (10.0, 0.0), + (15.0, 5.0), + (75.0, 0.5), + (10.01, 21.0) + ]) + assert np.allclose(observed, expected) + + def test_with_movement_non_square_image(self): + arr = np.float32([ + (0.5, 0.6), # in cell (0,0) = idx 0 + (180.7, 90.8), # in cell (9,9) = idx 99 + ]) + destinations = np.arange(10*10).reshape((10, 10)) + destinations[0, 0] = 99 + destinations[9, 9] = 0 + + observed = iaa.apply_jigsaw_to_coords(arr, destinations, (100, 200)) + + expected = np.float32([ + (180+0.5, 90+0.6), + (0+0.7, 0+0.8) + ]) + assert np.allclose(observed, expected) + + def test_empty_coords(self): + arr = np.zeros((0, 2), dtype=np.float32) + destinations = np.arange(10*10).reshape((10, 10)) + + observed = iaa.apply_jigsaw_to_coords(arr, destinations, (100, 100)) + + assert np.allclose(observed, arr) + + class Test_generate_jigsaw_destinations(unittest.TestCase): def test_max_steps_0(self): rng = iarandom.RNG(0) From a6f0cda8c7a4b526e158d51f8e8661b3e7e4401c Mon Sep 17 00:00:00 2001 From: aleju Date: Sun, 27 Oct 2019 22:23:55 +0100 Subject: [PATCH 3/3] Add Jigsaw --- checks/check_jigsaw.py | 8 +- imgaug/augmenters/geometric.py | 269 ++++++++++++++++++++- test/augmenters/test_geometric.py | 374 ++++++++++++++++++++++++++++++ 3 files changed, 649 insertions(+), 2 deletions(-) diff --git a/checks/check_jigsaw.py b/checks/check_jigsaw.py index 25b4d4660..d372b8b97 100644 --- a/checks/check_jigsaw.py +++ b/checks/check_jigsaw.py @@ -5,6 +5,12 @@ def main(): + image = ia.quokka_square((200, 200)) + aug = iaa.Jigsaw(10, 10) + + images_aug = aug(images=[image] * 16) + ia.imshow(ia.draw_grid(images_aug)) + gen_time = timeit.timeit( "iaa.generate_jigsaw_destinations(10, 10, 2, rng)", number=128, @@ -15,7 +21,7 @@ def main(): ) ) print("Time to generate 128x dest:", gen_time) - image = ia.quokka_square((200, 200)) + destinations = iaa.generate_jigsaw_destinations(10, 10, 1, random_state=1) image_jig = iaa.apply_jigsaw(image, destinations) ia.imshow(image_jig) diff --git a/imgaug/augmenters/geometric.py b/imgaug/augmenters/geometric.py index 8cf4de2e9..e7716f34c 100644 --- a/imgaug/augmenters/geometric.py +++ b/imgaug/augmenters/geometric.py @@ -20,6 +20,7 @@ * ElasticTransformation * Rot90 * WithPolarWarping + * Jigsaw """ from __future__ import print_function, division, absolute_import @@ -35,6 +36,7 @@ from . import meta from . import blur as blur_lib +from . import size as size_lib import imgaug as ia from imgaug.augmentables.polys import _ConcavePolygonRecoverer from .. import parameters as iap @@ -4112,7 +4114,6 @@ class Rot90(meta.Augmenter): random_state : None or int or imgaug.random.RNG or numpy.random.Generator or numpy.random.bit_generator.BitGenerator or numpy.random.SeedSequence or numpy.random.RandomState, optional See :func:`imgaug.augmenters.meta.Augmenter.__init__`. - Examples -------- >>> import imgaug.augmenters as iaa @@ -4910,3 +4911,269 @@ def __str__(self): ")") return pattern % (self.__class__.__name__, self.name, self.children, self.deterministic) + + +class Jigsaw(meta.Augmenter): + """Move cells within images similar to jigsaw patterns. + + .. note:: + + This augmenter will by default pad images until their height is a + multiple of `nb_rows`. Analogous for `nb_cols`. + + .. note:: + + This augmenter will resize heatmaps and segmentation maps to the + image size, then apply similar padding as for the corresponding images + and resize back to the original map size. That also means that images + may change in shape (due to padding), but heatmaps/segmaps will not + change. For heatmaps/segmaps, this deviates from pad augmenters that + will change images and heatmaps/segmaps in corresponding ways and then + keep the heatmaps/segmaps at the new size. + + .. warning:: + + This augmenter currently only supports augmentation of images, + heatmaps, segmentation maps and keypoints. Other augmentables will + produce errors. + + dtype support:: + + See :func:`apply_jigsaw`. + + Parameters + ---------- + nb_rows : int or list of int or tuple of int or imgaug.parameters.StochasticParameter + How many rows the jigsaw pattern should have. + + * If a single ``int``, then that value will be used for all images. + * If a tuple ``(a, b)``, then a random value will be uniformly + sampled per image from the discrete interval ``[a..b]``. + * If a list, then for each image a random value will be sampled + from that list. + * If ``StochasticParameter``, then that parameter is queried per + image to sample the value to use. + + nb_cols : int or list of int or tuple of int or imgaug.parameters.StochasticParameter + How many cols the jigsaw pattern should have. + + * If a single ``int``, then that value will be used for all images. + * If a tuple ``(a, b)``, then a random value will be uniformly + sampled per image from the discrete interval ``[a..b]``. + * If a list, then for each image a random value will be sampled + from that list. + * If ``StochasticParameter``, then that parameter is queried per + image to sample the value to use. + + max_steps : int or list of int or tuple of int or imgaug.parameters.StochasticParameter, optional + How many steps each jigsaw cell may be moved. + + * If a single ``int``, then that value will be used for all images. + * If a tuple ``(a, b)``, then a random value will be uniformly + sampled per image from the discrete interval ``[a..b]``. + * If a list, then for each image a random value will be sampled + from that list. + * If ``StochasticParameter``, then that parameter is queried per + image to sample the value to use. + + allow_pad : bool, optional + Whether to allow automatically padding images until they are evenly + divisible by ``nb_rows`` and ``nb_cols``. + + name : None or str, optional + See :func:`imgaug.augmenters.meta.Augmenter.__init__`. + + deterministic : bool, optional + See :func:`imgaug.augmenters.meta.Augmenter.__init__`. + + random_state : None or int or imgaug.random.RNG or numpy.random.Generator or numpy.random.bit_generator.BitGenerator or numpy.random.SeedSequence or numpy.random.RandomState, optional + See :func:`imgaug.augmenters.meta.Augmenter.__init__`. + + Examples + -------- + >>> import numpy as np + >>> import imgaug.augmenters as iaa + >>> image = np.mod(np.arange(100*100*3), 255).astype(np.uint8) + >>> image = image.reshape((100, 100, 3)) + >>> images = [image] * 16 + >>> aug = iaa.Jigsaw(nb_rows=10, nb_cols=10) + >>> images_aug = aug(images=images) + + Create a jigsaw augmenter and use it to augment a simple example batch + of ``16`` ``100x100x3`` images. Each image will be split into ``10x10`` + cells, i.e. each cell will be ``10px`` high and ``10px`` wide. + + >>> aug = iaa.Jigsaw(nb_rows=(1, 4), nb_cols=(1, 4)) + + Create a jigsaw augmenter that splits each image into a maximum of ``4x4`` + cells. + + >>> aug = iaa.Jigsaw(nb_rows=10, nb_cols=10, max_steps=(1, 5)) + + Create a jigsaw augmenter that moves the cells in each image by a random + amount between ``1`` and ``5`` times (decided per image). Some images will + be barely changed, some will be fairly distorted. + + """ + + def __init__(self, nb_rows, nb_cols, max_steps=2, allow_pad=True, + name=None, deterministic=False, random_state=None): + super(Jigsaw, self).__init__( + name=name, deterministic=deterministic, random_state=random_state) + + self.nb_rows = iap.handle_discrete_param( + nb_rows, "nb_rows", value_range=(1, None), tuple_to_uniform=True, + list_to_choice=True, allow_floats=False) + self.nb_cols = iap.handle_discrete_param( + nb_cols, "nb_cols", value_range=(1, None), tuple_to_uniform=True, + list_to_choice=True, allow_floats=False) + self.max_steps = iap.handle_discrete_param( + max_steps, "max_steps", value_range=(0, None), + tuple_to_uniform=True, list_to_choice=True, allow_floats=False) + self.allow_pad = allow_pad + + def _augment_batch(self, batch, random_state, parents, hooks): + samples = self._draw_samples(batch, random_state) + + # We resize here heatmaps/segmaps early to the image size in order to + # avoid problems where the jigsaw cells don't fit perfectly into + # the heatmap/segmap arrays or there are minor padding-related + # differences. + # TODO This step could most likely be avoided. + # TODO add something like + # 'with batch.maps_resized_to_image_sizes(): ...' + batch, maps_shapes_orig = self._resize_maps(batch) + + if self.allow_pad: + # this is a bit more difficult than one might expect, because we + # (a) might have different numbers of rows/cols per image + # (b) might have different shapes per image + # (c) have non-image data that also requires padding + # TODO enable support for stochastic parameters in + # PadToMultiplesOf, then we can simple use two + # DeterministicLists here to generate rowwise values + + for i in np.arange(len(samples.destinations)): + padder = size_lib.CenterPadToMultiplesOf( + width_multiple=samples.nb_cols[i], + height_multiple=samples.nb_rows[i]) + row = batch.subselect_rows_by_indices([i]) + row = padder.augment_batch(row, parents=parents + [self], + hooks=hooks) + batch = batch.invert_subselect_rows_by_indices_([i], row) + + if batch.images is not None: + for i, image in enumerate(batch.images): + image[...] = apply_jigsaw(image, samples.destinations[i]) + + if batch.heatmaps is not None: + for i, heatmap in enumerate(batch.heatmaps): + heatmap.arr_0to1 = apply_jigsaw(heatmap.arr_0to1, + samples.destinations[i]) + + if batch.segmentation_maps is not None: + for i, segmap in enumerate(batch.segmentation_maps): + segmap.arr = apply_jigsaw(segmap.arr, samples.destinations[i]) + + if batch.keypoints is not None: + for i, kpsoi in enumerate(batch.keypoints): + xy = kpsoi.to_xy_array() + xy[...] = apply_jigsaw_to_coords(xy, + samples.destinations[i], + image_shape=kpsoi.shape) + kpsoi.fill_from_xy_array_(xy) + + has_other_cbaoi = any([getattr(batch, attr_name) is not None + for attr_name + in ["bounding_boxes", "polygons", + "line_strings"]]) + if has_other_cbaoi: + raise NotImplementedError( + "Jigsaw currently only supports augmentation of images " + "and keypoints.") + + # We don't crop back to the original size, partly because it is + # rather cumbersome to implement, partly because the padded + # borders might have been moved into the inner parts of the image + + batch = self._invert_resize_maps(batch, maps_shapes_orig) + + return batch + + def _draw_samples(self, batch, random_state): + nb_images = batch.nb_rows + nb_rows = self.nb_rows.draw_samples((nb_images,), + random_state=random_state) + nb_cols = self.nb_cols.draw_samples((nb_images,), + random_state=random_state) + max_steps = self.max_steps.draw_samples((nb_images,), + random_state=random_state) + destinations = [] + for i in np.arange(nb_images): + destinations.append( + generate_jigsaw_destinations( + nb_rows[i], nb_cols[i], max_steps[i], + random_state=random_state) + ) + + samples = _JigsawSamples(nb_rows, nb_cols, max_steps, destinations) + return samples + + @classmethod + def _resize_maps(cls, batch): + # skip computation of rowwise shapes + if batch.heatmaps is None and batch.segmentation_maps is None: + return batch, (None, None) + + image_shapes = batch.get_rowwise_shapes() + batch.heatmaps, heatmaps_shapes_orig = cls._resize_maps_single_list( + batch.heatmaps, "arr_0to1", image_shapes) + batch.segmentation_maps, sm_shapes_orig = cls._resize_maps_single_list( + batch.segmentation_maps, "arr", image_shapes) + + return batch, (heatmaps_shapes_orig, sm_shapes_orig) + + @classmethod + def _resize_maps_single_list(cls, augmentables, arr_attr_name, + image_shapes): + if augmentables is None: + return None, None + + shapes_orig = [] + augms_resized = [] + for augmentable, image_shape in zip(augmentables, image_shapes): + shape_orig = getattr(augmentable, arr_attr_name).shape + augm_rs = augmentable.resize(image_shape[0:2]) + augms_resized.append(augm_rs) + shapes_orig.append(shape_orig) + return augms_resized, shapes_orig + + @classmethod + def _invert_resize_maps(cls, batch, shapes_orig): + batch.heatmaps = cls._invert_resize_maps_single_list( + batch.heatmaps, shapes_orig[0]) + batch.segmentation_maps = cls._invert_resize_maps_single_list( + batch.segmentation_maps, shapes_orig[1]) + + return batch + + @classmethod + def _invert_resize_maps_single_list(cls, augmentables, shapes_orig): + if shapes_orig is None: + return None + + augms_resized = [] + for augmentable, shape_orig in zip(augmentables, shapes_orig): + augms_resized.append(augmentable.resize(shape_orig[0:2])) + return augms_resized + + def get_parameters(self): + return [self.nb_rows, self.nb_cols, self.max_steps, self.allow_pad] + + +class _JigsawSamples(object): + def __init__(self, nb_rows, nb_cols, max_steps, destinations): + self.nb_rows = nb_rows + self.nb_cols = nb_cols + self.max_steps = max_steps + self.destinations = destinations diff --git a/test/augmenters/test_geometric.py b/test/augmenters/test_geometric.py index 89de79d76..6149e93a9 100644 --- a/test/augmenters/test_geometric.py +++ b/test/augmenters/test_geometric.py @@ -29,6 +29,7 @@ array_equal_lists, keypoints_equal, reseed, assert_cbaois_equal) from imgaug.augmentables.heatmaps import HeatmapsOnImage from imgaug.augmentables.segmaps import SegmentationMapsOnImage +import imgaug.augmenters.geometric as geometriclib def _assert_same_min_max(observed, actual): @@ -9318,3 +9319,376 @@ def test_max_steps_1_connectivity_4(self): assert np.min(dist) <= 0.01 assert np.any(dist >= 0.99) assert np.max(dist) <= 1.01 + + +class TestJigsaw(unittest.TestCase): + def setUp(self): + reseed() + + def test___init___defaults(self): + aug = iaa.Jigsaw(nb_rows=1, nb_cols=2) + assert aug.nb_rows.value == 1 + assert aug.nb_cols.value == 2 + assert aug.max_steps.value == 2 + assert aug.allow_pad is True + + def test___init___custom(self): + aug = iaa.Jigsaw(nb_rows=1, nb_cols=2, max_steps=3, allow_pad=False) + assert aug.nb_rows.value == 1 + assert aug.nb_cols.value == 2 + assert aug.max_steps.value == 3 + assert aug.allow_pad is False + + def test__draw_samples(self): + aug = iaa.Jigsaw(nb_rows=(1, 5), nb_cols=(1, 6), max_steps=(1, 3)) + batch = mock.Mock() + batch.nb_rows = 100 + + samples = aug._draw_samples(batch, iarandom.RNG(0)) + + assert len(np.unique(samples.nb_rows)) > 1 + assert len(np.unique(samples.nb_cols)) > 1 + assert len(np.unique(samples.max_steps)) > 1 + assert np.all(samples.nb_rows >= 1) + assert np.all(samples.nb_rows <= 5) + assert np.all(samples.nb_cols >= 1) + assert np.all(samples.nb_cols <= 6) + assert np.all(samples.max_steps >= 1) + assert np.all(samples.max_steps <= 3) + + all_same = True + first = samples.destinations[0] + for dest in samples.destinations: + this_same = (dest.shape == first.shape + and np.array_equal(dest, first)) + all_same = all_same and this_same + assert not all_same + + def test_images_without_shifts(self): + aug = iaa.Jigsaw(nb_rows=2, nb_cols=2, max_steps=0) + image = np.mod(np.arange(20*20*3), 255).astype(np.uint8) + image = image.reshape((20, 20, 3)) + + image_aug = aug(image=image) + + assert image_aug.dtype.name == "uint8" + assert image_aug.shape == (20, 20, 3) + assert np.array_equal(image_aug, image) + + def test_heatmaps_without_shifts(self): + aug = iaa.Jigsaw(nb_rows=2, nb_cols=2, max_steps=0) + arr = np.linspace(0, 1.0, 20*20*1).astype(np.float32) + arr = arr.reshape((20, 20, 1)) + heatmap = ia.HeatmapsOnImage(arr, shape=(20, 20, 3)) + + heatmap_aug = aug(heatmaps=heatmap) + + assert heatmap_aug.shape == (20, 20, 3) + assert np.allclose(heatmap_aug.arr_0to1, heatmap.arr_0to1) + + def test_segmaps_without_shifts(self): + aug = iaa.Jigsaw(nb_rows=2, nb_cols=2, max_steps=0) + arr = np.zeros((20, 20, 1), dtype=np.int32) + arr[0:10, :] = 1 + arr[10:20, 10:20] = 2 + arr = arr.reshape((20, 20, 1)) + segmap = ia.SegmentationMapsOnImage(arr, shape=(20, 20, 3)) + + segmap_aug = aug(segmentation_maps=segmap) + + assert segmap_aug.shape == (20, 20, 3) + assert np.array_equal(segmap_aug.arr, segmap.arr) + + def test_keypoints_without_shifts(self): + aug = iaa.Jigsaw(nb_rows=2, nb_cols=2, max_steps=0) + kpsoi = ia.KeypointsOnImage.from_xy_array([ + (0, 0), + (5.5, 3.5), + (12.1, 23.5) + ], shape=(20, 20, 3)) + + kpsoi_aug = aug(keypoints=kpsoi) + + assert kpsoi_aug.shape == (20, 20, 3) + assert np.allclose(kpsoi_aug.to_xy_array(), kpsoi.to_xy_array()) + + def test_images_with_shifts(self): + # these rows/cols/max_steps parameters are mostly ignored due to the + # mocked _draw_samples method below + aug = iaa.Jigsaw(nb_rows=2, nb_cols=2, max_steps=1) + image = np.mod(np.arange(19*19*3), 255).astype(np.uint8) + image = image.reshape((19, 19, 3)) + destinations = np.array([ + [3, 1], + [2, 0] + ], dtype=np.int32) + + old_func = aug._draw_samples + + def _mocked_draw_samples(batch, random_state): + samples = old_func(batch, random_state) + return geometriclib._JigsawSamples( + nb_rows=samples.nb_rows, + nb_cols=samples.nb_cols, + max_steps=samples.max_steps, + destinations=[destinations]) + + aug._draw_samples = _mocked_draw_samples + + image_aug = aug(image=image) + + expected = iaa.pad(image, bottom=1, right=1, cval=0) + expected = iaa.apply_jigsaw(expected, destinations) + assert np.array_equal(image_aug, expected) + + def test_heatmaps_with_shifts(self): + # these rows/cols/max_steps parameters are mostly ignored due to the + # mocked _draw_samples method below + aug = iaa.Jigsaw(nb_rows=2, nb_cols=2, max_steps=1) + arr = np.linspace(0, 1.0, 18*18*1).astype(np.float32) + arr = arr.reshape((18, 18, 1)) + heatmap = ia.HeatmapsOnImage(arr, shape=(19, 19, 3)) + destinations = np.array([ + [3, 1], + [2, 0] + ], dtype=np.int32) + + old_func = aug._draw_samples + + def _mocked_draw_samples(batch, random_state): + samples = old_func(batch, random_state) + return geometriclib._JigsawSamples( + nb_rows=samples.nb_rows, + nb_cols=samples.nb_cols, + max_steps=samples.max_steps, + destinations=[destinations]) + + aug._draw_samples = _mocked_draw_samples + + heatmap_aug = aug(heatmaps=heatmap) + + expected = ia.imresize_single_image(arr, (19, 19), + interpolation="cubic") + expected = np.clip(expected, 0, 1.0) + expected = iaa.pad(expected, bottom=1, right=1, cval=0.0) + expected = iaa.apply_jigsaw(expected, destinations) + expected = ia.imresize_single_image(expected, (18, 18), + interpolation="cubic") + expected = np.clip(expected, 0, 1.0) + assert np.allclose(heatmap_aug.arr_0to1, expected) + + def test_segmaps_with_shifts(self): + # these rows/cols/max_steps parameters are mostly ignored due to the + # mocked _draw_samples method below + aug = iaa.Jigsaw(nb_rows=2, nb_cols=2, max_steps=1) + arr = np.zeros((18, 18, 1), dtype=np.int32) + arr[0:10, :] = 1 + arr[10:18, 10:18] = 2 + arr = arr.reshape((18, 18, 1)) + segmap = ia.SegmentationMapsOnImage(arr, shape=(19, 19, 3)) + destinations = np.array([ + [3, 1], + [2, 0] + ], dtype=np.int32) + + old_func = aug._draw_samples + + def _mocked_draw_samples(batch, random_state): + samples = old_func(batch, random_state) + return geometriclib._JigsawSamples( + nb_rows=samples.nb_rows, + nb_cols=samples.nb_cols, + max_steps=samples.max_steps, + destinations=[destinations]) + + aug._draw_samples = _mocked_draw_samples + + segmap_aug = aug(segmentation_maps=segmap) + + expected = ia.imresize_single_image(arr, (19, 19), + interpolation="nearest") + expected = iaa.pad(expected, bottom=1, right=1, cval=0) + expected = iaa.apply_jigsaw(expected, destinations) + expected = ia.imresize_single_image(expected, (18, 18), + interpolation="nearest") + assert np.array_equal(segmap_aug.arr, expected) + + def test_keypoints_with_shifts(self): + # these rows/cols/max_steps parameters are mostly ignored due to the + # mocked _draw_samples method below + aug = iaa.Jigsaw(nb_rows=5, nb_cols=5, max_steps=1) + kpsoi = ia.KeypointsOnImage.from_xy_array([ + (0, 0), + (5.5, 3.5), + (4.0, 12.5), + (11.1, 11.2), + (12.1, 23.5) + ], shape=(18, 18, 3)) + destinations = np.array([ + [3, 1], + [2, 0] + ], dtype=np.int32) + + old_func = aug._draw_samples + + def _mocked_draw_samples(batch, random_state): + samples = old_func(batch, random_state) + return geometriclib._JigsawSamples( + nb_rows=samples.nb_rows, + nb_cols=samples.nb_cols, + max_steps=samples.max_steps, + destinations=[destinations]) + + aug._draw_samples = _mocked_draw_samples + + kpsoi_aug = aug(keypoints=kpsoi) + + expected = kpsoi.deepcopy() + expected.shape = (20, 20, 3) + # (0.0, 0.0) to cell at bottom-right, 1px pad at top and left + expected.keypoints[0].x = 10.0 + (0.0 - 0.0) + 1.0 + expected.keypoints[0].y = 10.0 + (0.0 - 0.0) + 1.0 + # (5.5, 3.5) to cell at bottom-right, 1px pad at top and left + expected.keypoints[1].x = 10.0 + (5.5 - 0.0) + 1.0 + expected.keypoints[1].y = 10.0 + (3.5 - 0.0) + 1.0 + # (4.0, 12.5) not moved to other cell, but 1px pad at top and left + expected.keypoints[2].x = 4.0 + 1.0 + expected.keypoints[2].y = 12.5 + 1.0 + # (11.0, 11.0) to cell at top-left, 1px pad at top and left + expected.keypoints[3].x = 0.0 + (11.1 - 10.0) + 1.0 + expected.keypoints[3].y = 0.0 + (11.2 - 10.0) + 1.0 + # (12.1, 23.5) not moved to other cell, but 1px pad at top and left + expected.keypoints[4].x = 12.1 + 1.0 + expected.keypoints[4].y = 23.5 + 1.0 + expected.shape = (20, 20, 3) + assert kpsoi_aug.shape == (20, 20, 3) + assert np.allclose(kpsoi_aug.to_xy_array(), expected.to_xy_array()) + + def test_images_and_heatmaps_aligned(self): + nb_changed = 0 + rs = iarandom.RNG(0) + for _ in np.arange(10): + aug = iaa.Jigsaw(nb_rows=(2, 5), nb_cols=(2, 5), max_steps=(0, 3)) + image_small = rs.integers(0, 10, size=(10, 15)).astype(np.float32) + image_small = image_small / 10.0 + image = ia.imresize_single_image(image_small, (20, 30), + interpolation="cubic") + image = np.clip(image, 0, 1.0) + hm = ia.HeatmapsOnImage(image_small, shape=(20, 30)) + + images_aug, hms_aug = aug(images=[image, image, image], + heatmaps=[hm, hm, hm]) + + for image_aug, hm_aug in zip(images_aug, hms_aug): + # TODO added squeeze here because get_arr() falsely returns + # (H,W,1) for 2D inputs + arr = np.squeeze(hm_aug.get_arr()) + image_aug_rs = ia.imresize_single_image( + image_aug.astype(np.float32), + arr.shape[0:2], + interpolation="cubic") + image_aug_rs = np.clip(image_aug_rs, 0, 1.0) + overlap = np.average(np.isclose(image_aug_rs, arr)) + + assert overlap > 0.99 + if not np.array_equal(arr, hm.get_arr()): + nb_changed += 1 + assert nb_changed > 5 + + def test_images_and_segmaps_aligned(self): + nb_changed = 0 + rs = iarandom.RNG(0) + for _ in np.arange(10): + aug = iaa.Jigsaw(nb_rows=(2, 5), nb_cols=(2, 5), max_steps=(0, 3)) + image_small = rs.integers(0, 10, size=(10, 15)) + image = ia.imresize_single_image(image_small, (20, 30), + interpolation="nearest") + image = image.astype(np.uint8) + segm = ia.SegmentationMapsOnImage(image_small, shape=(20, 30)) + + images_aug, sms_aug = aug(images=[image, image, image], + segmentation_maps=[segm, segm, segm]) + + for image_aug, sm_aug in zip(images_aug, sms_aug): + arr = sm_aug.get_arr() + image_aug_rs = ia.imresize_single_image( + image_aug, arr.shape[0:2], interpolation="nearest") + overlap = np.average(image_aug_rs == arr) + + assert overlap > 0.99 + if not np.array_equal(arr, segm.arr): + nb_changed += 1 + assert nb_changed > 5 + + def test_images_and_keypoints_aligned(self): + rs = iarandom.RNG(0) + for _ in np.arange(10): + aug = iaa.Jigsaw(nb_rows=(2, 5), nb_cols=(2, 5), max_steps=(0, 3)) + y = rs.integers(0, 20, size=(1,), endpoint=False) + x = rs.integers(0, 30, size=(1,), endpoint=False) + kpsoi = ia.KeypointsOnImage([ia.Keypoint(x=x, y=y)], shape=(20, 30)) + image = np.zeros((20, 30), dtype=np.uint8) + image[y, x] = 255 + + images_aug, kpsois_aug = aug(images=[image, image, image], + keypoints=[kpsoi, kpsoi, kpsoi]) + + for image_aug, kpsoi_aug in zip(images_aug, kpsois_aug): + x_aug = kpsoi_aug.keypoints[0].x + y_aug = kpsoi_aug.keypoints[0].y + idx = np.argmax(image_aug) + y_aug_img, x_aug_img = np.unravel_index(idx, + image_aug.shape) + dist = np.sqrt((x_aug - x_aug_img)**2 + (y_aug - y_aug_img)**2) + assert dist < 1.5 + + def test_no_error_for_1x1_grids(self): + aug = iaa.Jigsaw(nb_rows=1, nb_cols=1, max_steps=2) + image = np.mod(np.arange(19*19*3), 255).astype(np.uint8) + image = image.reshape((19, 19, 3)) + kpsoi = ia.KeypointsOnImage.from_xy_array([ + (0, 0), + (5.5, 3.5), + (4.0, 12.5), + (11.1, 11.2), + (12.1, 23.5) + ], shape=(19, 19, 3)) + + image_aug, kpsoi_aug = aug(image=image, keypoints=kpsoi) + + assert np.array_equal(image_aug, image) + assert np.allclose(kpsoi_aug.to_xy_array(), kpsoi.to_xy_array()) + + def test_zero_sized_axes(self): + shapes = [ + (0, 0), + (0, 1), + (1, 0), + (0, 1, 0), + (1, 0, 0), + (0, 1, 1), + (1, 0, 1) + ] + + for shape in shapes: + with self.subTest(shape=shape): + for _ in sm.xrange(3): + image = np.zeros(shape, dtype=np.uint8) + aug = iaa.Jigsaw(nb_rows=2, nb_cols=2, max_steps=2) + + image_aug = aug(image=image) + + # (2, 2, [C]) here, because rows/cols are padded to be + # multiple of nb_rows and nb_cols + shape_exp = tuple([2, 2] + list(shape[2:])) + assert image_aug.dtype.name == "uint8" + assert np.array_equal(image_aug, + np.zeros(shape_exp, dtype=np.uint8)) + + def test_get_parameters(self): + aug = iaa.Jigsaw(nb_rows=1, nb_cols=2) + params = aug.get_parameters() + assert params[0] is aug.nb_rows + assert params[1] is aug.nb_cols + assert params[2] is aug.max_steps + assert params[3] is True