Skip to content

Commit

Permalink
Add a GrayscaleToRgb transform that can expand channels to 3 (#8247)
Browse files Browse the repository at this point in the history
  • Loading branch information
ahmadsharif1 authored Mar 15, 2024
1 parent fa82fd3 commit 2bababf
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 1 deletion.
4 changes: 3 additions & 1 deletion docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@ Color
v2.RandomChannelPermutation
v2.RandomPhotometricDistort
v2.Grayscale
v2.RGB
v2.RandomGrayscale
v2.GaussianBlur
v2.RandomInvert
Expand All @@ -364,6 +365,7 @@ Functionals

v2.functional.permute_channels
v2.functional.rgb_to_grayscale
v2.functional.grayscale_to_rgb
v2.functional.to_grayscale
v2.functional.gaussian_blur
v2.functional.invert
Expand Down Expand Up @@ -584,7 +586,7 @@ Conversion
while performing the conversion, while some may not do any scaling. By
scaling, we mean e.g. that a ``uint8`` -> ``float32`` would map the [0,
255] range into [0, 1] (and vice-versa). See :ref:`range_and_dtype`.

.. autosummary::
:toctree: generated/
:template: class.rst
Expand Down
48 changes: 48 additions & 0 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5005,6 +5005,54 @@ def test_random_transform_correctness(self, num_input_channels):
assert_equal(actual, expected, rtol=0, atol=1)


class TestGrayscaleToRgb:
@pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_image(self, dtype, device):
check_kernel(F.grayscale_to_rgb_image, make_image(dtype=dtype, device=device))

@pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image])
def test_functional(self, make_input):
check_functional(F.grayscale_to_rgb, make_input())

@pytest.mark.parametrize(
("kernel", "input_type"),
[
(F.rgb_to_grayscale_image, torch.Tensor),
(F._rgb_to_grayscale_image_pil, PIL.Image.Image),
(F.rgb_to_grayscale_image, tv_tensors.Image),
],
)
def test_functional_signature(self, kernel, input_type):
check_functional_kernel_signature_match(F.grayscale_to_rgb, kernel=kernel, input_type=input_type)

@pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image])
def test_transform(self, make_input):
check_transform(transforms.RGB(), make_input(color_space="GRAY"))

@pytest.mark.parametrize("fn", [F.grayscale_to_rgb, transform_cls_to_functional(transforms.RGB)])
def test_image_correctness(self, fn):
image = make_image(dtype=torch.uint8, device="cpu", color_space="GRAY")

actual = fn(image)
expected = F.to_image(F.grayscale_to_rgb(F.to_pil_image(image)))

assert_equal(actual, expected, rtol=0, atol=1)

def test_expanded_channels_are_not_views_into_the_same_underlying_tensor(self):
image = make_image(dtype=torch.uint8, device="cpu", color_space="GRAY")

output_image = F.grayscale_to_rgb(image)
assert_equal(output_image[0][0][0], output_image[1][0][0])
output_image[0][0][0] = output_image[0][0][0] + 1
assert output_image[0][0][0] != output_image[1][0][0]

def test_rgb_image_is_unchanged(self):
image = make_image(dtype=torch.uint8, device="cpu", color_space="RGB")
assert_equal(image.shape[-3], 3)
assert_equal(F.grayscale_to_rgb(image), image)


class TestRandomZoomOut:
# Tests are light because this largely relies on the already tested `pad` kernels.

Expand Down
1 change: 1 addition & 0 deletions torchvision/transforms/v2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
RandomPhotometricDistort,
RandomPosterize,
RandomSolarize,
RGB,
)
from ._container import Compose, RandomApply, RandomChoice, RandomOrder
from ._geometry import (
Expand Down
14 changes: 14 additions & 0 deletions torchvision/transforms/v2/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,20 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return self._call_kernel(F.rgb_to_grayscale, inpt, num_output_channels=params["num_input_channels"])


class RGB(Transform):
"""Convert images or videos to RGB (if they are already not RGB).
If the input is a :class:`torch.Tensor`, it is expected
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions
"""

def __init__(self):
super().__init__()

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return self._call_kernel(F.grayscale_to_rgb, inpt)


class ColorJitter(Transform):
"""Randomly change the brightness, contrast, saturation and hue of an image or video.
Expand Down
2 changes: 2 additions & 0 deletions torchvision/transforms/v2/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@
equalize,
equalize_image,
equalize_video,
grayscale_to_rgb,
grayscale_to_rgb_image,
invert,
invert_image,
invert_video,
Expand Down
26 changes: 26 additions & 0 deletions torchvision/transforms/v2/functional/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,32 @@ def _rgb_to_grayscale_image_pil(image: PIL.Image.Image, num_output_channels: int
return _FP.to_grayscale(image, num_output_channels=num_output_channels)


def grayscale_to_rgb(inpt: torch.Tensor) -> torch.Tensor:
"""See :class:`~torchvision.transforms.v2.GrayscaleToRgb` for details."""
if torch.jit.is_scripting():
return grayscale_to_rgb_image(inpt)

_log_api_usage_once(grayscale_to_rgb)

kernel = _get_kernel(grayscale_to_rgb, type(inpt))
return kernel(inpt)


@_register_kernel_internal(grayscale_to_rgb, torch.Tensor)
@_register_kernel_internal(grayscale_to_rgb, tv_tensors.Image)
def grayscale_to_rgb_image(image: torch.Tensor) -> torch.Tensor:
if image.shape[-3] >= 3:
# Image already has RGB channels. We don't need to do anything.
return image
# rgb_to_grayscale can be used to add channels so we reuse that function.
return _rgb_to_grayscale_image(image, num_output_channels=3, preserve_dtype=True)


@_register_kernel_internal(grayscale_to_rgb, PIL.Image.Image)
def grayscale_to_rgb_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
return image.convert(mode="RGB")


def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
ratio = float(ratio)
fp = image1.is_floating_point()
Expand Down

0 comments on commit 2bababf

Please sign in to comment.