Skip to content

Commit

Permalink
Add Jigsaw
Browse files Browse the repository at this point in the history
  • Loading branch information
aleju committed Nov 1, 2019
1 parent 8980641 commit a104368
Show file tree
Hide file tree
Showing 3 changed files with 649 additions and 2 deletions.
8 changes: 7 additions & 1 deletion checks/check_jigsaw.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@


def main():
image = ia.quokka_square((200, 200))
aug = iaa.Jigsaw(10, 10)

images_aug = aug(images=[image] * 16)
ia.imshow(ia.draw_grid(images_aug))

gen_time = timeit.timeit(
"iaa.generate_jigsaw_destinations(10, 10, 2, rng)",
number=128,
Expand All @@ -15,7 +21,7 @@ def main():
)
)
print("Time to generate 128x dest:", gen_time)
image = ia.quokka_square((200, 200))

destinations = iaa.generate_jigsaw_destinations(10, 10, 1, random_state=1)
image_jig = iaa.apply_jigsaw(image, destinations)
ia.imshow(image_jig)
Expand Down
269 changes: 268 additions & 1 deletion imgaug/augmenters/geometric.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
* PerspectiveTransform
* ElasticTransformation
* Rot90
* Jigsaw
"""
from __future__ import print_function, division, absolute_import
Expand All @@ -34,6 +35,7 @@

from . import meta
from . import blur as blur_lib
from . import size as size_lib
import imgaug as ia
from imgaug.augmentables.polys import _ConcavePolygonRecoverer
from .. import parameters as iap
Expand Down Expand Up @@ -4031,7 +4033,6 @@ class Rot90(meta.Augmenter):
random_state : None or int or imgaug.random.RNG or numpy.random.Generator or numpy.random.bit_generator.BitGenerator or numpy.random.SeedSequence or numpy.random.RandomState, optional
See :func:`imgaug.augmenters.meta.Augmenter.__init__`.
Examples
--------
>>> import imgaug.augmenters as iaa
Expand Down Expand Up @@ -4190,3 +4191,269 @@ def _augment_keypoints_by_samples(self, keypoints_on_images, ks):

def get_parameters(self):
    """Return this instance's ``k`` and ``keep_size`` attributes as a list."""
    params = [self.k, self.keep_size]
    return params


class Jigsaw(meta.Augmenter):
    """Move cells within images similar to jigsaw patterns.

    .. note::

        This augmenter will by default pad images until their height is a
        multiple of `nb_rows`. Analogous for `nb_cols`.

    .. note::

        This augmenter will resize heatmaps and segmentation maps to the
        image size, then apply similar padding as for the corresponding images
        and resize back to the original map size. That also means that images
        may change in shape (due to padding), but heatmaps/segmaps will not
        change. For heatmaps/segmaps, this deviates from pad augmenters that
        will change images and heatmaps/segmaps in corresponding ways and then
        keep the heatmaps/segmaps at the new size.

    .. warning::

        This augmenter currently only supports augmentation of images,
        heatmaps, segmentation maps and keypoints. Other augmentables
        (bounding boxes, polygons, line strings) will produce a
        ``NotImplementedError``.

    dtype support::

        See :func:`apply_jigsaw`.

    Parameters
    ----------
    nb_rows : int or list of int or tuple of int or imgaug.parameters.StochasticParameter
        How many rows the jigsaw pattern should have.

            * If a single ``int``, then that value will be used for all images.
            * If a tuple ``(a, b)``, then a random value will be uniformly
              sampled per image from the discrete interval ``[a..b]``.
            * If a list, then for each image a random value will be sampled
              from that list.
            * If ``StochasticParameter``, then that parameter is queried per
              image to sample the value to use.

    nb_cols : int or list of int or tuple of int or imgaug.parameters.StochasticParameter
        How many cols the jigsaw pattern should have.

            * If a single ``int``, then that value will be used for all images.
            * If a tuple ``(a, b)``, then a random value will be uniformly
              sampled per image from the discrete interval ``[a..b]``.
            * If a list, then for each image a random value will be sampled
              from that list.
            * If ``StochasticParameter``, then that parameter is queried per
              image to sample the value to use.

    max_steps : int or list of int or tuple of int or imgaug.parameters.StochasticParameter, optional
        How many steps each jigsaw cell may be moved.

            * If a single ``int``, then that value will be used for all images.
            * If a tuple ``(a, b)``, then a random value will be uniformly
              sampled per image from the discrete interval ``[a..b]``.
            * If a list, then for each image a random value will be sampled
              from that list.
            * If ``StochasticParameter``, then that parameter is queried per
              image to sample the value to use.

    allow_pad : bool, optional
        Whether to allow automatically padding images until they are evenly
        divisible by ``nb_rows`` and ``nb_cols``.

    name : None or str, optional
        See :func:`imgaug.augmenters.meta.Augmenter.__init__`.

    deterministic : bool, optional
        See :func:`imgaug.augmenters.meta.Augmenter.__init__`.

    random_state : None or int or imgaug.random.RNG or numpy.random.Generator or numpy.random.bit_generator.BitGenerator or numpy.random.SeedSequence or numpy.random.RandomState, optional
        See :func:`imgaug.augmenters.meta.Augmenter.__init__`.

    Examples
    --------
    >>> import numpy as np
    >>> import imgaug.augmenters as iaa
    >>> image = np.mod(np.arange(100*100*3), 255).astype(np.uint8)
    >>> image = image.reshape((100, 100, 3))
    >>> images = [image] * 16
    >>> aug = iaa.Jigsaw(nb_rows=10, nb_cols=10)
    >>> images_aug = aug(images=images)

    Create a jigsaw augmenter and use it to augment a simple example batch
    of ``16`` ``100x100x3`` images. Each image will be split into ``10x10``
    cells, i.e. each cell will be ``10px`` high and ``10px`` wide.

    >>> aug = iaa.Jigsaw(nb_rows=(1, 4), nb_cols=(1, 4))

    Create a jigsaw augmenter that splits each image into a maximum of ``4x4``
    cells.

    >>> aug = iaa.Jigsaw(nb_rows=10, nb_cols=10, max_steps=(1, 5))

    Create a jigsaw augmenter that moves the cells in each image by a random
    amount between ``1`` and ``5`` times (decided per image). Some images will
    be barely changed, some will be fairly distorted.

    """

    # Batch attributes that this augmenter cannot handle yet.
    _UNSUPPORTED_AUGMENTABLES = ("bounding_boxes", "polygons", "line_strings")

    def __init__(self, nb_rows, nb_cols, max_steps=2, allow_pad=True,
                 name=None, deterministic=False, random_state=None):
        super(Jigsaw, self).__init__(
            name=name, deterministic=deterministic, random_state=random_state)

        self.nb_rows = iap.handle_discrete_param(
            nb_rows, "nb_rows", value_range=(1, None), tuple_to_uniform=True,
            list_to_choice=True, allow_floats=False)
        self.nb_cols = iap.handle_discrete_param(
            nb_cols, "nb_cols", value_range=(1, None), tuple_to_uniform=True,
            list_to_choice=True, allow_floats=False)
        self.max_steps = iap.handle_discrete_param(
            max_steps, "max_steps", value_range=(0, None),
            tuple_to_uniform=True, list_to_choice=True, allow_floats=False)
        self.allow_pad = allow_pad

    def _augment_batch(self, batch, random_state, parents, hooks):
        """Apply the jigsaw operation to all supported data in `batch`.

        Raises
        ------
        NotImplementedError
            If the batch contains bounding boxes, polygons or line strings.

        """
        # Fail fast: check for unsupported augmentables *before* sampling or
        # touching any data, as images/heatmaps/segmaps/keypoints below are
        # modified in-place and would otherwise be left half-augmented.
        has_other_cbaoi = any(
            getattr(batch, attr_name) is not None
            for attr_name in self._UNSUPPORTED_AUGMENTABLES)
        if has_other_cbaoi:
            raise NotImplementedError(
                "Jigsaw currently only supports augmentation of images, "
                "heatmaps, segmentation maps and keypoints. "
                "Explicitly not supported are: bounding boxes, polygons "
                "and line strings.")

        samples = self._draw_samples(batch, random_state)

        # We resize here heatmaps/segmaps early to the image size in order to
        # avoid problems where the jigsaw cells don't fit perfectly into
        # the heatmap/segmap arrays or there are minor padding-related
        # differences.
        # TODO This step could most likely be avoided.
        # TODO add something like
        #      'with batch.maps_resized_to_image_sizes(): ...'
        batch, maps_shapes_orig = self._resize_maps(batch)

        if self.allow_pad:
            # this is a bit more difficult than one might expect, because we
            # (a) might have different numbers of rows/cols per image
            # (b) might have different shapes per image
            # (c) have non-image data that also requires padding
            # TODO enable support for stochastic parameters in
            #      PadToMultiplesOf, then we can simply use two
            #      DeterministicLists here to generate rowwise values

            for i in range(len(samples.destinations)):
                padder = size_lib.CenterPadToMultiplesOf(
                    width_multiple=samples.nb_cols[i],
                    height_multiple=samples.nb_rows[i])
                row = batch.subselect_rows_by_indices([i])
                row = padder.augment_batch(row, parents=parents + [self],
                                           hooks=hooks)
                batch = batch.invert_subselect_rows_by_indices_([i], row)

        if batch.images is not None:
            for image, destinations in zip(batch.images,
                                           samples.destinations):
                # write into the existing array instead of rebinding it
                image[...] = apply_jigsaw(image, destinations)

        if batch.heatmaps is not None:
            for heatmap, destinations in zip(batch.heatmaps,
                                             samples.destinations):
                heatmap.arr_0to1 = apply_jigsaw(heatmap.arr_0to1,
                                                destinations)

        if batch.segmentation_maps is not None:
            for segmap, destinations in zip(batch.segmentation_maps,
                                            samples.destinations):
                segmap.arr = apply_jigsaw(segmap.arr, destinations)

        if batch.keypoints is not None:
            for kpsoi, destinations in zip(batch.keypoints,
                                           samples.destinations):
                xy = kpsoi.to_xy_array()
                xy[...] = apply_jigsaw_to_coords(xy,
                                                 destinations,
                                                 image_shape=kpsoi.shape)
                kpsoi.fill_from_xy_array_(xy)

        # We don't crop back to the original size, partly because it is
        # rather cumbersome to implement, partly because the padded
        # borders might have been moved into the inner parts of the image

        batch = self._invert_resize_maps(batch, maps_shapes_orig)

        return batch

    def _draw_samples(self, batch, random_state):
        """Sample rows, cols, max steps and cell destinations per batch row.

        Returns
        -------
        _JigsawSamples
            The sampled ``nb_rows``, ``nb_cols`` and ``max_steps`` (one
            value per row in the batch) and a list with one result of
            :func:`generate_jigsaw_destinations` per row.

        """
        nb_images = batch.nb_rows
        # Note: the order of the draw_samples() calls is kept fixed, as each
        # call consumes values from the same random state.
        nb_rows = self.nb_rows.draw_samples((nb_images,),
                                            random_state=random_state)
        nb_cols = self.nb_cols.draw_samples((nb_images,),
                                            random_state=random_state)
        max_steps = self.max_steps.draw_samples((nb_images,),
                                                random_state=random_state)
        destinations = [
            generate_jigsaw_destinations(
                nb_rows_i, nb_cols_i, max_steps_i,
                random_state=random_state)
            for nb_rows_i, nb_cols_i, max_steps_i
            in zip(nb_rows, nb_cols, max_steps)
        ]

        return _JigsawSamples(nb_rows, nb_cols, max_steps, destinations)

    @classmethod
    def _resize_maps(cls, batch):
        """Resize heatmaps and segmentation maps to their rowwise shapes.

        Returns
        -------
        batch
            The input batch with ``heatmaps`` and ``segmentation_maps``
            replaced by their resized versions.

        tuple
            ``(heatmap_shapes, segmap_shapes)`` containing the array shapes
            before the resize. An entry is ``None`` if the corresponding
            augmentable is not contained in the batch.

        """
        # skip computation of rowwise shapes
        if batch.heatmaps is None and batch.segmentation_maps is None:
            return batch, (None, None)

        image_shapes = batch.get_rowwise_shapes()
        batch.heatmaps, heatmaps_shapes_orig = cls._resize_maps_single_list(
            batch.heatmaps, "arr_0to1", image_shapes)
        batch.segmentation_maps, sm_shapes_orig = cls._resize_maps_single_list(
            batch.segmentation_maps, "arr", image_shapes)

        return batch, (heatmaps_shapes_orig, sm_shapes_orig)

    @classmethod
    def _resize_maps_single_list(cls, augmentables, arr_attr_name,
                                 image_shapes):
        """Resize each augmentable to the height/width of its image shape.

        `arr_attr_name` is the name of the attribute holding the underlying
        array, whose shape is recorded before the resize.

        Returns ``(resized_augmentables, original_array_shapes)`` or
        ``(None, None)`` if `augmentables` is ``None``.

        """
        if augmentables is None:
            return None, None

        shapes_orig = []
        augms_resized = []
        for augmentable, image_shape in zip(augmentables, image_shapes):
            shapes_orig.append(getattr(augmentable, arr_attr_name).shape)
            augms_resized.append(augmentable.resize(image_shape[0:2]))
        return augms_resized, shapes_orig

    @classmethod
    def _invert_resize_maps(cls, batch, shapes_orig):
        """Resize heatmaps/segmaps back to the shapes in `shapes_orig`.

        `shapes_orig` is the ``(heatmap_shapes, segmap_shapes)`` tuple
        returned by :func:`Jigsaw._resize_maps`.

        """
        batch.heatmaps = cls._invert_resize_maps_single_list(
            batch.heatmaps, shapes_orig[0])
        batch.segmentation_maps = cls._invert_resize_maps_single_list(
            batch.segmentation_maps, shapes_orig[1])

        return batch

    @classmethod
    def _invert_resize_maps_single_list(cls, augmentables, shapes_orig):
        """Resize each augmentable to the height/width of its original shape.

        Returns ``None`` if `shapes_orig` is ``None``, i.e. if nothing was
        resized in the first place.

        """
        if shapes_orig is None:
            return None

        return [
            augmentable.resize(shape_orig[0:2])
            for augmentable, shape_orig in zip(augmentables, shapes_orig)
        ]

    def get_parameters(self):
        """Return ``[nb_rows, nb_cols, max_steps, allow_pad]``."""
        return [self.nb_rows, self.nb_cols, self.max_steps, self.allow_pad]


class _JigsawSamples(object):
def __init__(self, nb_rows, nb_cols, max_steps, destinations):
self.nb_rows = nb_rows
self.nb_cols = nb_cols
self.max_steps = max_steps
self.destinations = destinations
Loading

0 comments on commit a104368

Please sign in to comment.