Add RandomGrayscale Layer (#20639)
* Add RandomGrayscale Layer

* Fix torch tests

* format

* fix

* fix

* Fix torch tests
IMvision12 authored Dec 14, 2024
1 parent 4aa6a67 commit 5fc7b6a
Showing 5 changed files with 206 additions and 0 deletions.
3 changes: 3 additions & 0 deletions keras/api/_tf_keras/keras/layers/__init__.py
@@ -164,6 +164,9 @@
from keras.src.layers.preprocessing.image_preprocessing.random_flip import (
RandomFlip,
)
from keras.src.layers.preprocessing.image_preprocessing.random_grayscale import (
RandomGrayscale,
)
from keras.src.layers.preprocessing.image_preprocessing.random_hue import (
RandomHue,
)
3 changes: 3 additions & 0 deletions keras/api/layers/__init__.py
@@ -164,6 +164,9 @@
from keras.src.layers.preprocessing.image_preprocessing.random_flip import (
RandomFlip,
)
from keras.src.layers.preprocessing.image_preprocessing.random_grayscale import (
RandomGrayscale,
)
from keras.src.layers.preprocessing.image_preprocessing.random_hue import (
RandomHue,
)
3 changes: 3 additions & 0 deletions keras/src/layers/__init__.py
@@ -108,6 +108,9 @@
from keras.src.layers.preprocessing.image_preprocessing.random_flip import (
RandomFlip,
)
from keras.src.layers.preprocessing.image_preprocessing.random_grayscale import (
RandomGrayscale,
)
from keras.src.layers.preprocessing.image_preprocessing.random_hue import (
RandomHue,
)
103 changes: 103 additions & 0 deletions keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py
@@ -0,0 +1,103 @@
from keras.src import backend
from keras.src.api_export import keras_export
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501
BaseImagePreprocessingLayer,
)


@keras_export("keras.layers.RandomGrayscale")
class RandomGrayscale(BaseImagePreprocessingLayer):
"""Preprocessing layer for random conversion of RGB images to grayscale.
This layer randomly converts input images to grayscale with a specified
factor. When applied, it maintains the original number of channels
but sets all channels to the same grayscale value. This can be useful
for data augmentation and training models to be robust to color
variations.
The conversion preserves the perceived luminance of the original color
image using standard RGB to grayscale conversion coefficients. Images
that are not selected for conversion remain unchanged.
**Note:** This layer is safe to use inside a `tf.data` pipeline
(independently of which backend you're using).
Args:
factor: Float between 0 and 1, specifying the factor of
converting each image to grayscale. Defaults to 0.5. A value of
1.0 means all images will be converted, while 0.0 means no images
will be converted.
data_format: String, one of `"channels_last"` (default) or
`"channels_first"`. The ordering of the dimensions in the inputs.
`"channels_last"` corresponds to inputs with shape
`(batch, height, width, channels)` while `"channels_first"`
corresponds to inputs with shape
`(batch, channels, height, width)`.
Input shape:
3D (unbatched) or 4D (batched) tensor with shape:
`(..., height, width, channels)`, in `"channels_last"` format,
or `(..., channels, height, width)`, in `"channels_first"` format.
Output shape:
Same as input shape. The output maintains the same number of channels
as the input, even for grayscale-converted images where all channels
will have the same value.
"""

def __init__(self, factor=0.5, data_format=None, **kwargs):
super().__init__(**kwargs)
if factor < 0 or factor > 1:
raise ValueError(
"`factor` should be between 0 and 1. "
f"Received: factor={factor}"
)
self.factor = factor
self.data_format = backend.standardize_data_format(data_format)
self.random_generator = self.backend.random.SeedGenerator()

def get_random_transformation(self, images, training=True, seed=None):
random_values = self.backend.random.uniform(
shape=(self.backend.core.shape(images)[0],),
minval=0,
maxval=1,
seed=self.random_generator,
)
# Boolean mask with shape (batch, 1, 1, 1) so that it broadcasts against
# the full image tensor in `transform_images`.
should_apply = self.backend.numpy.expand_dims(
random_values < self.factor, axis=[1, 2, 3]
)
return should_apply

def transform_images(self, images, transformations=None, **kwargs):
should_apply = (
transformations
if transformations is not None
else self.get_random_transformation(images)
)

grayscale_images = self.backend.image.rgb_to_grayscale(
images, data_format=self.data_format
)
# `rgb_to_grayscale` yields a single channel; `where` broadcasts it back
# to the input's channel count for the images selected by `should_apply`.
return self.backend.numpy.where(should_apply, grayscale_images, images)

def compute_output_shape(self, input_shape):
return input_shape

def compute_output_spec(self, inputs, **kwargs):
return inputs

def transform_bounding_boxes(self, bounding_boxes, **kwargs):
return bounding_boxes

def transform_labels(self, labels, transformations=None, **kwargs):
return labels

def transform_segmentation_masks(
self, segmentation_masks, transformations=None, **kwargs
):
return segmentation_masks

def get_config(self):
config = super().get_config()
config.update({"factor": self.factor})
return config
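
For context, a minimal usage sketch of the new layer (not part of the commit): it assumes a Keras 3 environment with this change applied and calls the layer directly, mirroring the tests below where the transformation is applied on call. The array sizes are illustrative.

import numpy as np
import keras

# A small batch of RGB images in channels_last format (illustrative values).
images = np.random.uniform(0, 255, size=(4, 32, 32, 3)).astype("float32")

# factor=1.0 converts every image, so each output image has identical
# R, G, and B channels; the output shape matches the input shape.
layer = keras.layers.RandomGrayscale(factor=1.0, data_format="channels_last")
augmented = layer(images)
print(augmented.shape)  # (4, 32, 32, 3)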
94 changes: 94 additions & 0 deletions keras/src/layers/preprocessing/image_preprocessing/random_grayscale_test.py
@@ -0,0 +1,94 @@
import numpy as np
import pytest
from absl.testing import parameterized
from tensorflow import data as tf_data

from keras.src import backend
from keras.src import layers
from keras.src import ops
from keras.src import testing


class RandomGrayscaleTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_layer(self):
self.run_layer_test(
layers.RandomGrayscale,
init_kwargs={
"factor": 0.5,
"data_format": "channels_last",
},
input_shape=(1, 2, 2, 3),
supports_masking=False,
expected_output_shape=(1, 2, 2, 3),
)

self.run_layer_test(
layers.RandomGrayscale,
init_kwargs={
"factor": 0.5,
"data_format": "channels_first",
},
input_shape=(1, 3, 2, 2),
supports_masking=False,
expected_output_shape=(1, 3, 2, 2),
)

@parameterized.named_parameters(
("channels_last", "channels_last"), ("channels_first", "channels_first")
)
def test_grayscale_conversion(self, data_format):
if data_format == "channels_last":
xs = np.random.uniform(0, 255, size=(2, 4, 4, 3)).astype(np.float32)
layer = layers.RandomGrayscale(factor=1.0, data_format=data_format)
transformed = ops.convert_to_numpy(layer(xs))
self.assertEqual(transformed.shape[-1], 3)
for img in transformed:
r, g, b = img[:, :, 0], img[:, :, 1], img[:, :, 2]
self.assertTrue(np.allclose(r, g) and np.allclose(g, b))
else:
xs = np.random.uniform(0, 255, size=(2, 3, 4, 4)).astype(np.float32)
layer = layers.RandomGrayscale(factor=1.0, data_format=data_format)
transformed = ops.convert_to_numpy(layer(xs))
self.assertEqual(transformed.shape[1], 3)
for img in transformed:
r, g, b = img[0], img[1], img[2]
self.assertTrue(np.allclose(r, g) and np.allclose(g, b))

def test_invalid_factor(self):
with self.assertRaises(ValueError):
layers.RandomGrayscale(factor=-0.1)

with self.assertRaises(ValueError):
layers.RandomGrayscale(factor=1.1)

def test_tf_data_compatibility(self):
data_format = backend.config.image_data_format()
if data_format == "channels_last":
input_data = np.random.random((2, 8, 8, 3)) * 255
else:
input_data = np.random.random((2, 3, 8, 8)) * 255

layer = layers.RandomGrayscale(factor=0.5, data_format=data_format)
ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)

for output in ds.take(1):
output_array = output.numpy()
self.assertEqual(output_array.shape, input_data.shape)

def test_grayscale_with_single_color_image(self):
test_cases = [
(np.full((1, 4, 4, 3), 128, dtype=np.float32), "channels_last"),
(np.full((1, 3, 4, 4), 128, dtype=np.float32), "channels_first"),
]

for xs, data_format in test_cases:
layer = layers.RandomGrayscale(factor=1.0, data_format=data_format)
transformed = ops.convert_to_numpy(layer(xs))

if data_format == "channels_last":
unique_vals = np.unique(transformed[0, :, :, 0])
self.assertEqual(len(unique_vals), 1)
else:
unique_vals = np.unique(transformed[0, 0, :, :])
self.assertEqual(len(unique_vals), 1)
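
The docstring's note that the layer is safe inside a `tf.data` pipeline corresponds to `test_tf_data_compatibility` above; a sketch along the same lines (not part of the commit, assumes TensorFlow is installed alongside Keras):

import numpy as np
from tensorflow import data as tf_data
import keras

images = (np.random.random((8, 32, 32, 3)) * 255).astype("float32")
layer = keras.layers.RandomGrayscale(factor=0.5, data_format="channels_last")

# Batch first, then map the layer over the dataset; per the docstring this
# works regardless of the active Keras backend.
ds = tf_data.Dataset.from_tensor_slices(images).batch(4).map(layer)
for batch in ds.take(1):
    print(batch.shape)  # (4, 32, 32, 3)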
