Skip to content

Commit

Permalink
Merge pull request #466 from aleju/meanshift
Browse files Browse the repository at this point in the history
Add MeanShiftBlur
  • Loading branch information
aleju authored Nov 8, 2019
2 parents 2f28830 + 9ec9490 commit fbfc77e
Show file tree
Hide file tree
Showing 4 changed files with 377 additions and 4 deletions.
4 changes: 4 additions & 0 deletions changelogs/master/added/20191023_mean_shift_blur.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Added Augmenter `MeanShiftBlur` #466

* Added function `imgaug.augmenters.blur.blur_mean_shift_(image)`.
* Added augmenter `imgaug.augmenters.blur.MeanShiftBlur`.
14 changes: 14 additions & 0 deletions checks/check_mean_shift_blur.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from __future__ import print_function, division, absolute_import
import imgaug as ia
import imgaug.augmenters as iaa


def main():
    """Visually check ``MeanShiftBlur`` on copies of the quokka image.

    Loads the square quokka example image (requested with ``(128, 128)``),
    sends 16 references to it through a default ``MeanShiftBlur`` augmenter
    and shows the augmented results as a single grid image.
    """
    quokka = ia.quokka_square((128, 128))
    augmenter = iaa.MeanShiftBlur()

    # Same source image 16 times, so that differences in the grid come only
    # from the augmenter's sampled parameters.
    batch = [quokka] * 16
    blurred = augmenter(images=batch)

    grid = ia.draw_grid(blurred)
    ia.imshow(grid)


# Run the visual check only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
213 changes: 209 additions & 4 deletions imgaug/augmenters/blur.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
* MedianBlur
* BilateralBlur
* MotionBlur
* MeanShiftBlur
"""
from __future__ import print_function, division, absolute_import
Expand All @@ -38,10 +39,9 @@

# TODO add border mode, cval
def blur_gaussian_(image, sigma, ksize=None, backend="auto", eps=1e-3):
"""
Blur an image using gaussian blurring.
"""Blur an image using gaussian blurring in-place.
This operation might change the input image in-place.
This operation *may* change the input image in-place.
dtype support::
Expand Down Expand Up @@ -156,8 +156,9 @@ def blur_gaussian_(image, sigma, ksize=None, backend="auto", eps=1e-3):
Returns
-------
image : numpy.ndarray
numpy.ndarray
The blurred image. Same shape and dtype as the input.
(Input image *might* have been altered in-place.)
"""
has_zero_sized_axes = (image.size == 0)
Expand Down Expand Up @@ -274,6 +275,107 @@ def blur_gaussian_(image, sigma, ksize=None, backend="auto", eps=1e-3):
return image


def blur_mean_shift_(image, spatial_window_radius, color_window_radius):
    """Blur an image with OpenCV's pyramidic mean shift filter.

    The result resembles the output of a bilateral filter. Note that this
    is *not* mean shift *segmentation*, which would average the colors
    within segments found by mean shift clustering.

    This function is a thin wrapper around ``cv2.pyrMeanShiftFiltering``.

    .. note ::

        The image's colorspace is *not* converted to ``RGB`` before the
        filter is applied, so a non-``RGB`` colorspace will influence the
        results.

    .. note ::

        This function is quite slow.

    dtype support::

        * ``uint8``: yes; fully tested
        * ``uint16``: no (1)
        * ``uint32``: no (1)
        * ``uint64``: no (1)
        * ``int8``: no (1)
        * ``int16``: no (1)
        * ``int32``: no (1)
        * ``int64``: no (1)
        * ``float16``: no (1)
        * ``float32``: no (1)
        * ``float64``: no (1)
        * ``float128``: no (1)
        * ``bool``: no (1)

        - (1) Not supported by ``cv2.pyrMeanShiftFiltering``.

    Parameters
    ----------
    image : ndarray
        ``(H,W)`` or ``(H,W,1)`` or ``(H,W,3)`` image to blur.
        Inputs without a channel axis or with a single channel are
        temporarily tiled to three channels for the OpenCV call.

    spatial_window_radius : number
        Spatial radius for pixels that are assumed to be similar.
        Values below ``0`` are clamped to ``0``.

    color_window_radius : number
        Color radius for pixels that are assumed to be similar.
        Values below ``0`` are clamped to ``0``.

    Returns
    -------
    ndarray
        Blurred input image. Same shape and dtype as the input.
        (Input image *might* have been altered in-place.)

    Raises
    ------
    AssertionError
        If the image has a non-zero height and width and either its dtype
        is not ``uint8`` or its shape is none of ``(H,W)``, ``(H,W,1)``,
        ``(H,W,3)``.

    """
    # Images without any pixels are handed back untouched, before any
    # dtype or shape validation takes place.
    height_width = image.shape[0:2]
    if 0 in height_width:
        return image

    # The OpenCV routine only works on uint8 data.
    dtype_name = image.dtype.name
    assert dtype_name == "uint8", (
        "Expected image with dtype \"uint8\", "
        "got \"%s\"." % (dtype_name,))

    nb_dims = image.ndim
    nb_channels = image.shape[-1] if nb_dims == 3 else None
    is_2d = (nb_dims == 2)
    is_single_channel = (nb_channels == 1)
    is_three_channel = (nb_channels == 3)

    assert is_2d or is_single_channel or is_three_channel, (
        "Expected (H,W) or (H,W,1) or (H,W,3) image, "
        "got shape %s." % (image.shape,))

    # The OpenCV routine expects (H,W,3): give (H,W) inputs a channel axis
    # and then replicate the single channel three times.
    if is_2d:
        image = image[..., np.newaxis]
    if is_2d or is_single_channel:
        image = np.tile(image, (1, 1, 3))

    # A contiguous buffer prevents the image from becoming a cv2.UMat.
    if not image.flags["C_CONTIGUOUS"]:
        image = np.ascontiguousarray(image)

    # Negative radii are clamped to zero.
    sp = max(spatial_window_radius, 0)
    sr = max(color_window_radius, 0)

    # The three-channel array is also passed as the destination buffer.
    filtered = cv2.pyrMeanShiftFiltering(image, sp=sp, sr=sr, dst=image)

    # Undo the temporary tiling so that the output shape equals the input
    # shape.
    if is_2d:
        return filtered[..., 0]
    if is_single_channel:
        return filtered[..., 0:1]
    return filtered


def _compute_gaussian_blur_ksize(sigma):
if sigma < 3.0:
ksize = 3.3 * sigma # 99% of weight
Expand Down Expand Up @@ -956,3 +1058,106 @@ def _create_matrices(_image, nb_channels, random_state_func):
super(MotionBlur, self).__init__(
_create_matrices, name=name, deterministic=deterministic,
random_state=random_state)


# TODO add a per_channel flag?
# TODO make spatial_radius a fraction of the input image size?
class MeanShiftBlur(meta.Augmenter):
    """Blur each image with a pyramidic mean shift filter.

    The per-image work is delegated to :func:`blur_mean_shift_`; see that
    function for details.

    Input images are expected to have shape ``(H,W)`` or ``(H,W,1)``
    or ``(H,W,3)``.

    .. note ::

        This augmenter is quite slow.

    dtype support::

        See :func:`imgaug.augmenters.blur.blur_mean_shift_`.

    Parameters
    ----------
    spatial_radius : number or tuple of number or list of number or imgaug.parameters.StochasticParameter, optional
        Spatial radius for pixels that are assumed to be similar.

            * If ``number``: Exactly that value will be used for all images.
            * If ``tuple`` ``(a, b)``: A random value will be uniformly
              sampled per image from the interval ``[a, b)``.
            * If ``list``: A random value will be sampled from that ``list``
              per image.
            * If ``StochasticParameter``: The parameter will be queried once
              per batch for ``(N,)`` values with ``N`` denoting the number of
              images.

    color_radius : number or tuple of number or list of number or imgaug.parameters.StochasticParameter, optional
        Color radius for pixels that are assumed to be similar.

            * If ``number``: Exactly that value will be used for all images.
            * If ``tuple`` ``(a, b)``: A random value will be uniformly
              sampled per image from the interval ``[a, b)``.
            * If ``list``: A random value will be sampled from that ``list``
              per image.
            * If ``StochasticParameter``: The parameter will be queried once
              per batch for ``(N,)`` values with ``N`` denoting the number of
              images.

    name : None or str, optional
        See :func:`imgaug.augmenters.meta.Augmenter.__init__`.

    deterministic : bool, optional
        See :func:`imgaug.augmenters.meta.Augmenter.__init__`.

    random_state : None or int or imgaug.random.RNG or numpy.random.Generator or numpy.random.bit_generator.BitGenerator or numpy.random.SeedSequence or numpy.random.RandomState, optional
        See :func:`imgaug.augmenters.meta.Augmenter.__init__`.

    Examples
    --------
    >>> import imgaug.augmenters as iaa
    >>> import numpy as np
    >>> image = np.arange(5*5*3).astype(np.uint8).reshape((5, 5, 3))
    >>> aug = iaa.MeanShiftBlur()
    >>> image_aug = aug(image=image)

    Create a mean shift blur augmenter and apply it to a simple ``5x5x3``
    example image.

    """

    def __init__(self, spatial_radius=(5.0, 40.0), color_radius=(5.0, 40.0),
                 name=None, deterministic=False, random_state=None):
        super(MeanShiftBlur, self).__init__(
            name=name, deterministic=deterministic, random_state=random_state)

        # Both radii are normalized with identical settings.
        param_settings = dict(
            value_range=(0.01, None),
            tuple_to_uniform=True,
            list_to_choice=True)
        self.spatial_window_radius = iap.handle_continuous_param(
            spatial_radius, "spatial_radius", **param_settings)
        self.color_window_radius = iap.handle_continuous_param(
            color_radius, "color_radius", **param_settings)

    def _augment_batch(self, batch, random_state, parents, hooks):
        """Blur every image of the batch; other batch content is untouched."""
        if batch.images is None:
            return batch

        spatial_radii, color_radii = self._draw_samples(batch, random_state)
        for idx, image in enumerate(batch.images):
            # Write the result back into the same slot of the batch.
            batch.images[idx] = blur_mean_shift_(
                image,
                spatial_window_radius=spatial_radii[idx],
                color_window_radius=color_radii[idx])
        return batch

    def _draw_samples(self, batch, random_state):
        """Sample one spatial and one color radius per row of the batch.

        Returns a tuple ``(spatial_radii, color_radii)``. The spatial radii
        are drawn before the color radii.
        """
        sample_shape = (batch.nb_rows,)
        spatial_radii = self.spatial_window_radius.draw_samples(
            sample_shape, random_state=random_state)
        color_radii = self.color_window_radius.draw_samples(
            sample_shape, random_state=random_state)
        return spatial_radii, color_radii

    def get_parameters(self):
        """Return ``[spatial_window_radius, color_window_radius]``."""
        parameters = [self.spatial_window_radius, self.color_window_radius]
        return parameters
Loading

0 comments on commit fbfc77e

Please sign in to comment.