Skip to content

Commit

Permalink
Refactor augmenters to use new LUT functions
Browse files Browse the repository at this point in the history
This patch switches all augmenters from cv2.LUT() to
imgaug's apply_lut() and apply_lut_() functions.
This decreases code duplication and likely fixes
some bugs that were not yet discovered.
  • Loading branch information
aleju committed Dec 30, 2019
1 parent 2326e4d commit dc87059
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 85 deletions.
3 changes: 3 additions & 0 deletions changelogs/master/added/20191230_standardized_lut.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,6 @@

* Added `imgaug.imgaug.apply_lut()`, which applies a lookup table to an image.
* Added `imgaug.imgaug.apply_lut_()`. In-place version of `apply_lut()`.
* Refactored all augmenters to use these new LUT functions.
  This likely fixed some previously undiscovered bugs in augmenters that
  use lookup tables.
54 changes: 13 additions & 41 deletions imgaug/augmenters/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,23 +146,13 @@ def _add_scalar_to_uint8(image, value):
result = []
# TODO check whether tile() is actually needed here
tables = np.tile(
value_range[np.newaxis, :],
(nb_channels, 1)
) + value[:, np.newaxis]
tables = np.clip(tables, 0, 255).astype(image.dtype)

for c, table in enumerate(tables):
result.append(cv2.LUT(image[..., c], table))

return np.stack(result, axis=-1)
value_range[:, np.newaxis],
(1, nb_channels)
) + value[np.newaxis, :]
else:
table = value_range + value
image_aug = cv2.LUT(
image,
iadt.clip_(table, 0, 255).astype(image.dtype))
if image_aug.ndim == 2 and image.ndim == 3:
image_aug = image_aug[..., np.newaxis]
return image_aug
tables = value_range + value
tables = np.clip(tables, 0, 255).astype(image.dtype)
return ia.apply_lut(image, tables)


def _add_scalar_to_non_uint8(image, value):
Expand Down Expand Up @@ -436,23 +426,13 @@ def _multiply_scalar_to_uint8(image, multiplier):
result = []
# TODO check whether tile() is actually needed here
tables = np.tile(
value_range[np.newaxis, :],
(nb_channels, 1)
) * multiplier[:, np.newaxis]
tables = np.clip(tables, 0, 255).astype(image.dtype)

for c, table in enumerate(tables):
arr_aug = cv2.LUT(image[..., c], table)
result.append(arr_aug)

return np.stack(result, axis=-1)
value_range[:, np.newaxis],
(1, nb_channels)
) * multiplier[np.newaxis, :]
else:
table = value_range * multiplier
image_aug = cv2.LUT(
image, np.clip(table, 0, 255).astype(image.dtype))
if image_aug.ndim == 2 and image.ndim == 3:
image_aug = image_aug[..., np.newaxis]
return image_aug
tables = value_range * multiplier
tables = np.clip(tables, 0, 255).astype(image.dtype)
return ia.apply_lut(image, tables)


def _multiply_scalar_to_non_uint8(image, multiplier):
Expand Down Expand Up @@ -939,17 +919,9 @@ def _invert_bool(arr, min_value, max_value):

def _invert_uint8_(arr, min_value, max_value, threshold,
invert_above_threshold):
if 0 in arr.shape:
return np.copy(arr)

if arr.flags["OWNDATA"] is False:
arr = np.copy(arr)
if arr.flags["C_CONTIGUOUS"] is False:
arr = np.ascontiguousarray(arr)

table = _generate_table_for_invert_uint8(
min_value, max_value, threshold, invert_above_threshold)
arr = cv2.LUT(arr, table, dst=arr)
arr = ia.apply_lut_(arr, table)
return arr


Expand Down
1 change: 1 addition & 0 deletions imgaug/augmenters/blur.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ def blur_mean_shift_(image, spatial_window_radius, color_window_radius):
image = np.tile(image, (1, 1, 3))

# prevent image from becoming cv2.UMat
# TODO merge this with apply_lut() normalization/validation
if image.flags["C_CONTIGUOUS"] is False:
image = np.ascontiguousarray(image)

Expand Down
19 changes: 10 additions & 9 deletions imgaug/augmenters/color.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ def _get_dst(image, from_to_cspace):
# images that are views (e.g. image[..., 0:3]) and returns a
# cv2.UMat instance instead of an array. So we check here first
# if the array looks like it is non-contiguous or a view.
# TODO merge this with apply_lut() normalization/validation
if image.flags["C_CONTIGUOUS"]:
return image
return None
Expand Down Expand Up @@ -2326,11 +2327,13 @@ def _transform_image_cv2(cls, image_hsv, hue, saturation):
# code with using cache (at best maybe 10% faster for 64x64):
table_hue = cls._LUT_CACHE[0]
table_saturation = cls._LUT_CACHE[1]
tables = [
table_hue[255+int(hue)],
table_saturation[255+int(saturation)]
]

image_hsv[..., 0] = cv2.LUT(
image_hsv[..., 0], table_hue[255+int(hue)])
image_hsv[..., 1] = cv2.LUT(
image_hsv[..., 1], table_saturation[255+int(saturation)])
image_hsv[..., [0, 1]] = ia.apply_lut(image_hsv[..., [0, 1]],
tables)

return image_hsv

Expand Down Expand Up @@ -3095,7 +3098,7 @@ def _generate_pixelwise_alpha_mask(cls, image_hsv, hue_to_alpha):
hue = image_hsv[:, :, 0]
table = hue_to_alpha * 255
table = np.clip(np.round(table), 0, 255).astype(np.uint8)
mask = cv2.LUT(hue, table)
mask = ia.apply_lut(hue, table)
return mask.astype(np.float32) / 255.0

def get_parameters(self):
Expand Down Expand Up @@ -4060,22 +4063,20 @@ def quantize_uniform_(arr, nb_bins, to_bin_centers=True):
if nb_bins == 256 or 0 in arr.shape:
return arr

# TODO remove dtype check here? apply_lut_() does that already
assert arr.dtype.name == "uint8", "Expected uint8 image, got %s." % (
arr.dtype.name,)
assert 2 <= nb_bins <= 256, (
"Expected nb_bins to be in the discrete interval [2..256]. "
"Got a value of %d instead." % (nb_bins,))

if arr.flags["C_CONTIGUOUS"] is False:
arr = np.ascontiguousarray(arr)

table_class = (_QuantizeUniformCenterizedLUTTableSingleton
if to_bin_centers
else _QuantizeUniformNotCenterizedLUTTableSingleton)
table = (table_class
.get_instance()
.get_for_nb_bins(nb_bins))
arr = cv2.LUT(arr, table, dst=arr)
arr = ia.apply_lut_(arr, table)
return arr


Expand Down
57 changes: 22 additions & 35 deletions imgaug/augmenters/contrast.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ def adjust_contrast_gamma(arr, gamma):
return np.copy(arr)

# int8 is also possible according to docs
# https://docs.opencv.org/3.0-beta/modules/core/doc/operations_on_arrays.html#cv2.LUT , but here it seemed
# like `d` was 0 for CV_8S, causing that to fail
# https://docs.opencv.org/3.0-beta/modules/core/doc/operations_on_arrays.html#cv2.LUT ,
# but here it seemed like `d` was 0 for CV_8S, causing that to fail
if arr.dtype.name == "uint8":
min_value, _center_value, max_value = \
iadt.get_value_range_of_dtype(arr.dtype)
Expand All @@ -174,10 +174,8 @@ def adjust_contrast_gamma(arr, gamma):
table = (min_value
+ (value_range ** np.float32(gamma))
* dynamic_range)
arr_aug = cv2.LUT(
arr, np.clip(table, min_value, max_value).astype(arr.dtype))
if arr.ndim == 3 and arr_aug.ndim == 2:
return arr_aug[..., np.newaxis]
table = np.clip(table, min_value, max_value).astype(arr.dtype)
arr_aug = ia.apply_lut(arr, table)
return arr_aug
return ski_exposure.adjust_gamma(arr, gamma)

Expand Down Expand Up @@ -244,8 +242,8 @@ def adjust_contrast_sigmoid(arr, gain, cutoff):
return np.copy(arr)

# int8 is also possible according to docs
# https://docs.opencv.org/3.0-beta/modules/core/doc/operations_on_arrays.html#cv2.LUT , but here it seemed
# like `d` was 0 for CV_8S, causing that to fail
# https://docs.opencv.org/3.0-beta/modules/core/doc/operations_on_arrays.html#cv2.LUT ,
# but here it seemed like `d` was 0 for CV_8S, causing that to fail
if arr.dtype.name == "uint8":
min_value, _center_value, max_value = \
iadt.get_value_range_of_dtype(arr.dtype)
Expand All @@ -262,10 +260,8 @@ def adjust_contrast_sigmoid(arr, gain, cutoff):
table = (min_value
+ dynamic_range
* 1/(1 + np.exp(gain * (cutoff - value_range))))
arr_aug = cv2.LUT(
arr, np.clip(table, min_value, max_value).astype(arr.dtype))
if arr.ndim == 3 and arr_aug.ndim == 2:
return arr_aug[..., np.newaxis]
table = np.clip(table, min_value, max_value).astype(arr.dtype)
arr_aug = ia.apply_lut(arr, table)
return arr_aug
return ski_exposure.adjust_sigmoid(arr, cutoff=cutoff, gain=gain)

Expand Down Expand Up @@ -331,8 +327,8 @@ def adjust_contrast_log(arr, gain):
return np.copy(arr)

# int8 is also possible according to docs
# https://docs.opencv.org/3.0-beta/modules/core/doc/operations_on_arrays.html#cv2.LUT , but here it seemed
# like `d` was 0 for CV_8S, causing that to fail
# https://docs.opencv.org/3.0-beta/modules/core/doc/operations_on_arrays.html#cv2.LUT ,
# but here it seemed like `d` was 0 for CV_8S, causing that to fail
if arr.dtype.name == "uint8":
min_value, _center_value, max_value = \
iadt.get_value_range_of_dtype(arr.dtype)
Expand All @@ -346,10 +342,8 @@ def adjust_contrast_log(arr, gain):
# of size 1
gain = np.float32(gain)
table = min_value + dynamic_range * gain * np.log2(1 + value_range)
arr_aug = cv2.LUT(
arr, np.clip(table, min_value, max_value).astype(arr.dtype))
if arr.ndim == 3 and arr_aug.ndim == 2:
return arr_aug[..., np.newaxis]
table = np.clip(table, min_value, max_value).astype(arr.dtype)
arr_aug = ia.apply_lut(arr, table)
return arr_aug
return ski_exposure.adjust_log(arr, gain=gain)

Expand Down Expand Up @@ -403,8 +397,8 @@ def adjust_contrast_linear(arr, alpha):
return np.copy(arr)

# int8 is also possible according to docs
# https://docs.opencv.org/3.0-beta/modules/core/doc/operations_on_arrays.html#cv2.LUT , but here it seemed
# like `d` was 0 for CV_8S, causing that to fail
# https://docs.opencv.org/3.0-beta/modules/core/doc/operations_on_arrays.html#cv2.LUT ,
# but here it seemed like `d` was 0 for CV_8S, causing that to fail
if arr.dtype.name == "uint8":
min_value, center_value, max_value = \
iadt.get_value_range_of_dtype(arr.dtype)
Expand All @@ -418,10 +412,8 @@ def adjust_contrast_linear(arr, alpha):
# of size 1
alpha = np.float32(alpha)
table = center_value + alpha * (value_range - center_value)
arr_aug = cv2.LUT(
arr, np.clip(table, min_value, max_value).astype(arr.dtype))
if arr.ndim == 3 and arr_aug.ndim == 2:
return arr_aug[..., np.newaxis]
table = np.clip(table, min_value, max_value).astype(arr.dtype)
arr_aug = ia.apply_lut(arr, table)
return arr_aug
else:
input_dtype = arr.dtype
Expand Down Expand Up @@ -538,13 +530,8 @@ def equalize_(image, mask=None):
# note that this is supposed to be a non-PIL reimplementation of PIL's
# equalize, which produces slightly different results from cv2.equalizeHist()
def _equalize_no_pil_(image, mask=None):
flags = image.flags
if not flags["OWNDATA"]:
image = np.copy(image)
if not flags["C_CONTIGUOUS"]:
image = np.ascontiguousarray(image)

nb_channels = 1 if image.ndim == 2 else image.shape[-1]
# TODO remove the first axis, no longer needed
lut = np.empty((1, 256, nb_channels), dtype=np.int32)

for c_idx in range(nb_channels):
Expand All @@ -568,9 +555,7 @@ def _equalize_no_pil_(image, mask=None):
lut[0, 1:, c_idx] = n + cumsum[0:-1]
lut[0, :, c_idx] //= int(step)
lut = np.clip(lut, None, 255, out=lut).astype(np.uint8)
image = cv2.LUT(image, lut, dst=image)
if image.ndim == 2 and image.ndim == 3:
return image[..., np.newaxis]
image = ia.apply_lut_(image, lut)
return image


Expand Down Expand Up @@ -677,7 +662,7 @@ def _autocontrast(image, cutoff, ignore): # noqa: C901
# using [0] instead of [int(c_idx)] allows this to work with >4
# channels
if image.ndim == 2:
image_c = image
image_c = image[:, :, np.newaxis]
else:
image_c = image[:, :, c_idx:c_idx+1]
h = cv2.calcHist([image_c], [0], None, [256], [0, 256])
Expand Down Expand Up @@ -745,7 +730,9 @@ def _autocontrast(image, cutoff, ignore): # noqa: C901
# ix = np.clip(ix, 0, 255).astype(np.uint8)
# lut = ix

result[:, :, c_idx] = cv2.LUT(image_c, lut)
# TODO change to a single call instead of one per channel
image_c_aug = ia.apply_lut(image_c, lut)
result[:, :, c_idx:c_idx+1] = image_c_aug
if image.ndim == 2:
return result[..., 0]
return result
Expand Down

0 comments on commit dc87059

Please sign in to comment.