Commit 5fc7b6a (parent 4aa6a67)

* Add RandomGrayscale Layer
* Fix torch tests
* format
* fix
* fix
* Fix torch tests

Showing 5 changed files with 206 additions and 0 deletions.
keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py (103 additions, 0 deletions)

```python
from keras.src import backend
from keras.src.api_export import keras_export
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501
    BaseImagePreprocessingLayer,
)


@keras_export("keras.layers.RandomGrayscale")
class RandomGrayscale(BaseImagePreprocessingLayer):
    """Preprocessing layer for random conversion of RGB images to grayscale.

    This layer randomly converts input images to grayscale with a specified
    factor. When applied, it maintains the original number of channels
    but sets all channels to the same grayscale value. This can be useful
    for data augmentation and training models to be robust to color
    variations.

    The conversion preserves the perceived luminance of the original color
    image using standard RGB to grayscale conversion coefficients. Images
    that are not selected for conversion remain unchanged.

    **Note:** This layer is safe to use inside a `tf.data` pipeline
    (independently of which backend you're using).

    Args:
        factor: Float between 0 and 1, specifying the factor of
            converting each image to grayscale. Defaults to 0.5. A value of
            1.0 means all images will be converted, while 0.0 means no images
            will be converted.
        data_format: String, one of `"channels_last"` (default) or
            `"channels_first"`. The ordering of the dimensions in the inputs.
            `"channels_last"` corresponds to inputs with shape
            `(batch, height, width, channels)` while `"channels_first"`
            corresponds to inputs with shape
            `(batch, channels, height, width)`.

    Input shape:
        3D (unbatched) or 4D (batched) tensor with shape:
        `(..., height, width, channels)`, in `"channels_last"` format,
        or `(..., channels, height, width)`, in `"channels_first"` format.

    Output shape:
        Same as input shape. The output maintains the same number of channels
        as the input, even for grayscale-converted images where all channels
        will have the same value.
    """

    def __init__(self, factor=0.5, data_format=None, **kwargs):
        super().__init__(**kwargs)
        if factor < 0 or factor > 1:
            raise ValueError(
                "`factor` should be between 0 and 1. "
                f"Received: factor={factor}"
            )
        self.factor = factor
        self.data_format = backend.standardize_data_format(data_format)
        self.random_generator = self.backend.random.SeedGenerator()

    def get_random_transformation(self, images, training=True, seed=None):
        random_values = self.backend.random.uniform(
            shape=(self.backend.core.shape(images)[0],),
            minval=0,
            maxval=1,
            seed=self.random_generator,
        )
        should_apply = self.backend.numpy.expand_dims(
            random_values < self.factor, axis=[1, 2, 3]
        )
        return should_apply

    def transform_images(self, images, transformations=None, **kwargs):
        should_apply = (
            transformations
            if transformations is not None
            else self.get_random_transformation(images)
        )

        grayscale_images = self.backend.image.rgb_to_grayscale(
            images, data_format=self.data_format
        )
        return self.backend.numpy.where(should_apply, grayscale_images, images)

    def compute_output_shape(self, input_shape):
        return input_shape

    def compute_output_spec(self, inputs, **kwargs):
        return inputs

    def transform_bounding_boxes(self, bounding_boxes, **kwargs):
        return bounding_boxes

    def transform_labels(self, labels, transformations=None, **kwargs):
        return labels

    def transform_segmentation_masks(
        self, segmentation_masks, transformations=None, **kwargs
    ):
        return segmentation_masks

    def get_config(self):
        config = super().get_config()
        config.update({"factor": self.factor})
        return config
```
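As a quick orientation for reviewers, here is a minimal usage sketch of the new layer. It assumes the layer is importable as `keras.layers.RandomGrayscale` (per the `keras_export` decorator above) and uses random NumPy data as a stand-in; it is not part of the commit.

```python
import numpy as np
import tensorflow as tf

import keras

# A batch of 8 RGB images in "channels_last" format, values in [0, 255].
images = np.random.uniform(0, 255, size=(8, 32, 32, 3)).astype("float32")

# With factor=0.5, each image independently has a 50% chance of being
# converted. Converted images keep all 3 channels, each holding the same
# grayscale value, so the output shape matches the input shape.
layer = keras.layers.RandomGrayscale(factor=0.5)
augmented = layer(images)
print(augmented.shape)  # (8, 32, 32, 3)

# The layer can also be mapped over a tf.data pipeline, regardless of the
# Keras backend in use.
ds = (
    tf.data.Dataset.from_tensor_slices(images)
    .batch(4)
    .map(keras.layers.RandomGrayscale(factor=1.0))
)
for batch in ds.take(1):
    print(batch.shape)  # (4, 32, 32, 3)
```

With `factor=1.0` every image is converted, which is what the tests below rely on when asserting that all channels end up equal.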
keras/src/layers/preprocessing/image_preprocessing/random_grayscale_test.py (94 additions, 0 deletions)

```python
import numpy as np
import pytest
from absl.testing import parameterized
from tensorflow import data as tf_data

from keras.src import backend
from keras.src import layers
from keras.src import ops
from keras.src import testing


class RandomGrayscaleTest(testing.TestCase):
    @pytest.mark.requires_trainable_backend
    def test_layer(self):
        self.run_layer_test(
            layers.RandomGrayscale,
            init_kwargs={
                "factor": 0.5,
                "data_format": "channels_last",
            },
            input_shape=(1, 2, 2, 3),
            supports_masking=False,
            expected_output_shape=(1, 2, 2, 3),
        )

        self.run_layer_test(
            layers.RandomGrayscale,
            init_kwargs={
                "factor": 0.5,
                "data_format": "channels_first",
            },
            input_shape=(1, 3, 2, 2),
            supports_masking=False,
            expected_output_shape=(1, 3, 2, 2),
        )

    @parameterized.named_parameters(
        ("channels_last", "channels_last"), ("channels_first", "channels_first")
    )
    def test_grayscale_conversion(self, data_format):
        if data_format == "channels_last":
            xs = np.random.uniform(0, 255, size=(2, 4, 4, 3)).astype(np.float32)
            layer = layers.RandomGrayscale(factor=1.0, data_format=data_format)
            transformed = ops.convert_to_numpy(layer(xs))
            self.assertEqual(transformed.shape[-1], 3)
            for img in transformed:
                r, g, b = img[:, :, 0], img[:, :, 1], img[:, :, 2]
                self.assertTrue(np.allclose(r, g) and np.allclose(g, b))
        else:
            xs = np.random.uniform(0, 255, size=(2, 3, 4, 4)).astype(np.float32)
            layer = layers.RandomGrayscale(factor=1.0, data_format=data_format)
            transformed = ops.convert_to_numpy(layer(xs))
            self.assertEqual(transformed.shape[1], 3)
            for img in transformed:
                r, g, b = img[0], img[1], img[2]
                self.assertTrue(np.allclose(r, g) and np.allclose(g, b))

    def test_invalid_factor(self):
        with self.assertRaises(ValueError):
            layers.RandomGrayscale(factor=-0.1)

        with self.assertRaises(ValueError):
            layers.RandomGrayscale(factor=1.1)

    def test_tf_data_compatibility(self):
        data_format = backend.config.image_data_format()
        if data_format == "channels_last":
            input_data = np.random.random((2, 8, 8, 3)) * 255
        else:
            input_data = np.random.random((2, 3, 8, 8)) * 255

        layer = layers.RandomGrayscale(factor=0.5, data_format=data_format)
        ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)

        for output in ds.take(1):
            output_array = output.numpy()
            self.assertEqual(output_array.shape, input_data.shape)

    def test_grayscale_with_single_color_image(self):
        test_cases = [
            (np.full((1, 4, 4, 3), 128, dtype=np.float32), "channels_last"),
            (np.full((1, 3, 4, 4), 128, dtype=np.float32), "channels_first"),
        ]

        for xs, data_format in test_cases:
            layer = layers.RandomGrayscale(factor=1.0, data_format=data_format)
            transformed = ops.convert_to_numpy(layer(xs))

            if data_format == "channels_last":
                unique_vals = np.unique(transformed[0, :, :, 0])
                self.assertEqual(len(unique_vals), 1)
            else:
                unique_vals = np.unique(transformed[0, 0, :, :])
                self.assertEqual(len(unique_vals), 1)
```
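Beyond the committed tests, one property-based check could be worth adding: since any standard RGB-to-grayscale conversion is a convex combination of the channels, each converted pixel value should lie between that pixel's channel minimum and maximum, and `factor=0.0` should be a no-op. The following is only a sketch of a hypothetical extra method for `RandomGrayscaleTest`, not part of the commit.

```python
    def test_grayscale_value_range_and_zero_factor(self):
        # Hypothetical extra check (not in this commit). Whatever the exact
        # conversion coefficients are, the grayscale value of each pixel is a
        # convex combination of its R, G, B values, so it must lie between
        # the per-pixel channel min and max.
        xs = np.random.uniform(0, 255, size=(2, 4, 4, 3)).astype(np.float32)
        layer = layers.RandomGrayscale(factor=1.0, data_format="channels_last")
        transformed = ops.convert_to_numpy(layer(xs))
        self.assertTrue(np.all(transformed[..., 0] >= xs.min(axis=-1) - 1e-3))
        self.assertTrue(np.all(transformed[..., 0] <= xs.max(axis=-1) + 1e-3))

        # factor=0.0 should leave every image untouched.
        identity = layers.RandomGrayscale(factor=0.0, data_format="channels_last")
        self.assertTrue(np.allclose(ops.convert_to_numpy(identity(xs)), xs))
```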