Skip to content

Commit

Permalink
Add random_color_degeneration processing layer (#20679)
Browse files Browse the repository at this point in the history
* Add random_color_degeneration processing layer

* Fix mistypo

* Correct failed test case
  • Loading branch information
shashaka authored Dec 25, 2024
1 parent df002a9 commit f54c127
Show file tree
Hide file tree
Showing 5 changed files with 218 additions and 0 deletions.
3 changes: 3 additions & 0 deletions keras/api/_tf_keras/keras/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,9 @@
from keras.src.layers.preprocessing.image_preprocessing.random_brightness import (
RandomBrightness,
)
from keras.src.layers.preprocessing.image_preprocessing.random_color_degeneration import (
RandomColorDegeneration,
)
from keras.src.layers.preprocessing.image_preprocessing.random_color_jitter import (
RandomColorJitter,
)
Expand Down
3 changes: 3 additions & 0 deletions keras/api/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,9 @@
from keras.src.layers.preprocessing.image_preprocessing.random_brightness import (
RandomBrightness,
)
from keras.src.layers.preprocessing.image_preprocessing.random_color_degeneration import (
RandomColorDegeneration,
)
from keras.src.layers.preprocessing.image_preprocessing.random_color_jitter import (
RandomColorJitter,
)
Expand Down
3 changes: 3 additions & 0 deletions keras/src/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@
from keras.src.layers.preprocessing.image_preprocessing.random_brightness import (
RandomBrightness,
)
from keras.src.layers.preprocessing.image_preprocessing.random_color_degeneration import (
RandomColorDegeneration,
)
from keras.src.layers.preprocessing.image_preprocessing.random_color_jitter import (
RandomColorJitter,
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
from keras.src.api_export import keras_export
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501
BaseImagePreprocessingLayer,
)
from keras.src.random import SeedGenerator


@keras_export("keras.layers.RandomColorDegeneration")
class RandomColorDegeneration(BaseImagePreprocessingLayer):
"""Randomly performs the color degeneration operation on given images.
The sharpness operation first converts an image to gray scale, then back to
color. It then takes a weighted average between original image and the
degenerated image. This makes colors appear more dull.
Args:
factor: A tuple of two floats or a single float.
`factor` controls the extent to which the
image sharpness is impacted. `factor=0.0` makes this layer perform a
no-op operation, while a value of 1.0 uses the degenerated result
entirely. Values between 0 and 1 result in linear interpolation
between the original image and the sharpened image.
Values should be between `0.0` and `1.0`. If a tuple is used, a
`factor` is sampled between the two values for every image
augmented. If a single float is used, a value between `0.0` and the
passed float is sampled. In order to ensure the value is always the
same, please pass a tuple with two identical floats: `(0.5, 0.5)`.
seed: Integer. Used to create a random seed.
"""

_VALUE_RANGE_VALIDATION_ERROR = (
"The `value_range` argument should be a list of two numbers. "
)

def __init__(
self,
factor,
value_range=(0, 255),
data_format=None,
seed=None,
**kwargs,
):
super().__init__(data_format=data_format, **kwargs)
self._set_factor(factor)
self._set_value_range(value_range)
self.seed = seed
self.generator = SeedGenerator(seed)

def _set_value_range(self, value_range):
if not isinstance(value_range, (tuple, list)):
raise ValueError(
self._VALUE_RANGE_VALIDATION_ERROR
+ f"Received: value_range={value_range}"
)
if len(value_range) != 2:
raise ValueError(
self._VALUE_RANGE_VALIDATION_ERROR
+ f"Received: value_range={value_range}"
)
self.value_range = sorted(value_range)

def get_random_transformation(self, data, training=True, seed=None):
if isinstance(data, dict):
images = data["images"]
else:
images = data
images_shape = self.backend.shape(images)
rank = len(images_shape)
if rank == 3:
batch_size = 1
elif rank == 4:
batch_size = images_shape[0]
else:
raise ValueError(
"Expected the input image to be rank 3 or 4. Received: "
f"inputs.shape={images_shape}"
)

if seed is None:
seed = self._get_seed_generator(self.backend._backend)

factor = self.backend.random.uniform(
(batch_size, 1, 1, 1),
minval=self.factor[0],
maxval=self.factor[1],
seed=seed,
)
factor = factor
return {"factor": factor}

def transform_images(self, images, transformation=None, training=True):
if training:
images = self.backend.cast(images, self.compute_dtype)
factor = self.backend.cast(
transformation["factor"], self.compute_dtype
)
degenerates = self.backend.image.rgb_to_grayscale(
images, data_format=self.data_format
)
images = images + factor * (degenerates - images)
images = self.backend.numpy.clip(
images, self.value_range[0], self.value_range[1]
)
images = self.backend.cast(images, self.compute_dtype)
return images

def transform_labels(self, labels, transformation, training=True):
return labels

def transform_segmentation_masks(
self, segmentation_masks, transformation, training=True
):
return segmentation_masks

def transform_bounding_boxes(
self, bounding_boxes, transformation, training=True
):
return bounding_boxes

def get_config(self):
config = super().get_config()
config.update(
{
"factor": self.factor,
"value_range": self.value_range,
"seed": self.seed,
}
)
return config

def compute_output_shape(self, input_shape):
return input_shape
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import numpy as np
import pytest
from tensorflow import data as tf_data

import keras
from keras.src import backend
from keras.src import layers
from keras.src import testing


class RandomColorDegenerationTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_layer(self):
self.run_layer_test(
layers.RandomColorDegeneration,
init_kwargs={
"factor": 0.75,
"value_range": (0, 1),
"seed": 1,
},
input_shape=(8, 3, 4, 3),
supports_masking=False,
expected_output_shape=(8, 3, 4, 3),
)

def test_random_color_degeneration_value_range(self):
image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1)

layer = layers.RandomColorDegeneration(0.2, value_range=(0, 1))
adjusted_image = layer(image)

self.assertTrue(keras.ops.numpy.all(adjusted_image >= 0))
self.assertTrue(keras.ops.numpy.all(adjusted_image <= 1))

def test_random_color_degeneration_no_op(self):
data_format = backend.config.image_data_format()
if data_format == "channels_last":
inputs = np.random.random((2, 8, 8, 3))
else:
inputs = np.random.random((2, 3, 8, 8))

layer = layers.RandomColorDegeneration((0.5, 0.5))
output = layer(inputs, training=False)
self.assertAllClose(inputs, output, atol=1e-3, rtol=1e-5)

def test_random_color_degeneration_factor_zero(self):
data_format = backend.config.image_data_format()
if data_format == "channels_last":
inputs = np.random.random((2, 8, 8, 3))
else:
inputs = np.random.random((2, 3, 8, 8))
layer = layers.RandomColorDegeneration(factor=(0.0, 0.0))
result = layer(inputs)

self.assertAllClose(inputs, result, atol=1e-3, rtol=1e-5)

def test_random_color_degeneration_randomness(self):
image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1)[:5]

layer = layers.RandomColorDegeneration(0.2)
adjusted_images = layer(image)

self.assertNotAllClose(adjusted_images, image)

def test_tf_data_compatibility(self):
data_format = backend.config.image_data_format()
if data_format == "channels_last":
input_data = np.random.random((2, 8, 8, 3))
else:
input_data = np.random.random((2, 3, 8, 8))
layer = layers.RandomColorDegeneration(
factor=0.5, data_format=data_format, seed=1337
)

ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)
for output in ds.take(1):
output.numpy()

0 comments on commit f54c127

Please sign in to comment.