[BC-breaking] Unified inputs for grayscale ops and transforms #2586

Merged
merged 10 commits on Aug 28, 2020
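
This PR unifies the grayscale API so that a single `F.rgb_to_grayscale` accepts either a PIL Image or a Tensor and is scriptable with `torch.jit.script`, while `F.to_grayscale` is kept as a deprecated alias. A minimal usage sketch of the unified entry point, mirroring the tests in this diff:

```python
import torch
import torchvision.transforms.functional as F

# The same functional call now accepts both a uint8 tensor and a PIL image.
img_tensor = torch.randint(0, 255, (3, 16, 16), dtype=torch.uint8)
pil_img = F.to_pil_image(img_tensor)

gray_tensor = F.rgb_to_grayscale(img_tensor, num_output_channels=1)  # Tensor -> Tensor
gray_pil = F.rgb_to_grayscale(pil_img, num_output_channels=1)        # PIL -> PIL

# The tensor path is also scriptable, and scripted output matches eager output.
scripted = torch.jit.script(F.rgb_to_grayscale)
assert scripted(img_tensor, num_output_channels=1).equal(gray_tensor)
```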
32 changes: 19 additions & 13 deletions test/test_functional_tensor.py
@@ -1,5 +1,4 @@
import unittest
import random
import colorsys
import math

@@ -23,7 +22,10 @@ def _create_data(self, height=3, width=3, channels=3):
return tensor, pil_img

def compareTensorToPIL(self, tensor, pil_image, msg=None):
pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1)))
np_pil_image = np.array(pil_image)
if np_pil_image.ndim == 2:
np_pil_image = np_pil_image[:, :, None]
pil_tensor = torch.as_tensor(np_pil_image.transpose((2, 0, 1)))
if msg is None:
msg = "tensor:\n{} \ndid not equal PIL tensor:\n{}".format(tensor, pil_tensor)
self.assertTrue(tensor.equal(pil_tensor), msg)
@@ -187,17 +189,21 @@ def test_adjustments(self):
scripted_fn(img)

def test_rgb_to_grayscale(self):
script_rgb_to_grayscale = torch.jit.script(F_t.rgb_to_grayscale)
img_tensor = torch.randint(0, 255, (3, 16, 16), dtype=torch.uint8)
img_tensor_clone = img_tensor.clone()
grayscale_tensor = F_t.rgb_to_grayscale(img_tensor).to(int)
grayscale_pil_img = torch.tensor(np.array(F.to_grayscale(F.to_pil_image(img_tensor)))).to(int)
max_diff = (grayscale_tensor - grayscale_pil_img).abs().max()
self.assertLess(max_diff, 1.0001)
self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
# scriptable function test
grayscale_script = script_rgb_to_grayscale(img_tensor).to(int)
self.assertTrue(torch.equal(grayscale_script, grayscale_tensor))
script_rgb_to_grayscale = torch.jit.script(F.rgb_to_grayscale)

img_tensor, pil_img = self._create_data(32, 34)

for num_output_channels in (3, 1):
gray_pil_image = F.rgb_to_grayscale(pil_img, num_output_channels=num_output_channels)
gray_tensor = F.rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)

self.compareTensorToPIL(gray_tensor, gray_pil_image)

s_gray_tensor = script_rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)
self.assertTrue(s_gray_tensor.equal(gray_tensor))

def test_center_crop(self):
script_center_crop = torch.jit.script(F.center_crop)
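
The updated `compareTensorToPIL` helper has to account for the fact that `np.array` on a mode-'L' PIL image yields a 2-D `(H, W)` array, while the tensor op returns a channel-first result; the new `ndim == 2` branch restores the channel axis before comparison. A standalone sketch of that normalization (`pil_to_chw_tensor` is a name local to this example):

```python
import numpy as np
import torch
from PIL import Image

def pil_to_chw_tensor(pil_image):
    """Convert a PIL image to a CHW tensor, adding a channel axis
    for single-channel ('L' mode) images."""
    np_img = np.array(pil_image)
    if np_img.ndim == 2:             # 'L' mode images come back as (H, W)
        np_img = np_img[:, :, None]  # -> (H, W, 1)
    return torch.as_tensor(np_img.transpose((2, 0, 1)))  # -> (C, H, W)

gray = Image.new('L', (4, 3))        # width=4, height=3
assert pil_to_chw_tensor(gray).shape == (1, 3, 4)
```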
64 changes: 39 additions & 25 deletions test/test_transforms_tensor.py
@@ -16,18 +16,21 @@ def _create_data(self, height=3, width=3, channels=3):
return tensor, pil_img

def compareTensorToPIL(self, tensor, pil_image):
pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1)))
np_pil_image = np.array(pil_image)
if np_pil_image.ndim == 2:
np_pil_image = np_pil_image[:, :, None]
pil_tensor = torch.as_tensor(np_pil_image.transpose((2, 0, 1)))
self.assertTrue(tensor.equal(pil_tensor))

def _test_functional_geom_op(self, func, fn_kwargs):
def _test_functional_op(self, func, fn_kwargs):
if fn_kwargs is None:
fn_kwargs = {}
tensor, pil_img = self._create_data(height=10, width=10)
transformed_tensor = getattr(F, func)(tensor, **fn_kwargs)
transformed_pil_img = getattr(F, func)(pil_img, **fn_kwargs)
self.compareTensorToPIL(transformed_tensor, transformed_pil_img)

def _test_class_geom_op(self, method, meth_kwargs=None):
def _test_class_op(self, method, meth_kwargs=None):
if meth_kwargs is None:
meth_kwargs = {}

@@ -47,15 +50,15 @@ def _test_class_geom_op(self, method, meth_kwargs=None):
transformed_tensor_script = scripted_fn(tensor)
self.assertTrue(transformed_tensor.equal(transformed_tensor_script))

def _test_geom_op(self, func, method, fn_kwargs=None, meth_kwargs=None):
self._test_functional_geom_op(func, fn_kwargs)
self._test_class_geom_op(method, meth_kwargs)
def _test_op(self, func, method, fn_kwargs=None, meth_kwargs=None):
self._test_functional_op(func, fn_kwargs)
self._test_class_op(method, meth_kwargs)

def test_random_horizontal_flip(self):
self._test_geom_op('hflip', 'RandomHorizontalFlip')
self._test_op('hflip', 'RandomHorizontalFlip')

def test_random_vertical_flip(self):
self._test_geom_op('vflip', 'RandomVerticalFlip')
self._test_op('vflip', 'RandomVerticalFlip')

def test_adjustments(self):
fns = ['adjust_brightness', 'adjust_contrast', 'adjust_saturation']
@@ -86,30 +89,30 @@ def test_adjustments(self):
def test_pad(self):

# Test functional.pad (PIL and Tensor) with padding as single int
self._test_functional_geom_op(
self._test_functional_op(
"pad", fn_kwargs={"padding": 2, "fill": 0, "padding_mode": "constant"}
)
# Test functional.pad and transforms.Pad with padding as [int, ]
fn_kwargs = meth_kwargs = {"padding": [2, ], "fill": 0, "padding_mode": "constant"}
self._test_geom_op(
self._test_op(
"pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
# Test functional.pad and transforms.Pad with padding as list
fn_kwargs = meth_kwargs = {"padding": [4, 4], "fill": 0, "padding_mode": "constant"}
self._test_geom_op(
self._test_op(
"pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
# Test functional.pad and transforms.Pad with padding as tuple
fn_kwargs = meth_kwargs = {"padding": (2, 2, 2, 2), "fill": 127, "padding_mode": "constant"}
self._test_geom_op(
self._test_op(
"pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)

def test_crop(self):
fn_kwargs = {"top": 2, "left": 3, "height": 4, "width": 5}
# Test transforms.RandomCrop with size and padding as tuple
meth_kwargs = {"size": (4, 5), "padding": (4, 4), "pad_if_needed": True, }
self._test_geom_op(
self._test_op(
'crop', 'RandomCrop', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)

@@ -126,17 +129,17 @@ def test_crop(self):
for padding_config in padding_configs:
config = dict(padding_config)
config["size"] = size
self._test_class_geom_op("RandomCrop", config)
self._test_class_op("RandomCrop", config)

def test_center_crop(self):
fn_kwargs = {"output_size": (4, 5)}
meth_kwargs = {"size": (4, 5), }
self._test_geom_op(
self._test_op(
"center_crop", "CenterCrop", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
fn_kwargs = {"output_size": (5,)}
meth_kwargs = {"size": (5, )}
self._test_geom_op(
self._test_op(
"center_crop", "CenterCrop", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
tensor = torch.randint(0, 255, (3, 10, 10), dtype=torch.uint8)
@@ -155,7 +158,7 @@ def test_center_crop(self):
scripted_fn = torch.jit.script(f)
scripted_fn(tensor)

def _test_geom_op_list_output(self, func, method, out_length, fn_kwargs=None, meth_kwargs=None):
def _test_op_list_output(self, func, method, out_length, fn_kwargs=None, meth_kwargs=None):
if fn_kwargs is None:
fn_kwargs = {}
if meth_kwargs is None:
@@ -184,37 +187,37 @@ def _test_geom_op_list_output(self, func, method, out_length, fn_kwargs=None, meth_kwargs=None):

def test_five_crop(self):
fn_kwargs = meth_kwargs = {"size": (5,)}
self._test_geom_op_list_output(
self._test_op_list_output(
"five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
fn_kwargs = meth_kwargs = {"size": [5, ]}
self._test_geom_op_list_output(
self._test_op_list_output(
"five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
fn_kwargs = meth_kwargs = {"size": (4, 5)}
self._test_geom_op_list_output(
self._test_op_list_output(
"five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
fn_kwargs = meth_kwargs = {"size": [4, 5]}
self._test_geom_op_list_output(
self._test_op_list_output(
"five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)

def test_ten_crop(self):
fn_kwargs = meth_kwargs = {"size": (5,)}
self._test_geom_op_list_output(
self._test_op_list_output(
"ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
fn_kwargs = meth_kwargs = {"size": [5, ]}
self._test_geom_op_list_output(
self._test_op_list_output(
"ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
fn_kwargs = meth_kwargs = {"size": (4, 5)}
self._test_geom_op_list_output(
self._test_op_list_output(
"ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
fn_kwargs = meth_kwargs = {"size": [4, 5]}
self._test_geom_op_list_output(
self._test_op_list_output(
"ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)

@@ -318,6 +321,17 @@ def test_random_perspective(self):
out2 = s_transform(tensor)
self.assertTrue(out1.equal(out2))

def test_to_grayscale(self):

fn_kwargs = meth_kwargs = {"num_output_channels": 1}
self._test_op("rgb_to_grayscale", "Grayscale", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs)

fn_kwargs = meth_kwargs = {"num_output_channels": 3}
self._test_op("rgb_to_grayscale", "Grayscale", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs)

meth_kwargs = {}
self._test_class_op("RandomGrayscale", meth_kwargs=meth_kwargs)


if __name__ == '__main__':
unittest.main()
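
The renamed helpers (`_test_op`, `_test_class_op`) now cover color ops as well as geometric ones; in particular, `Grayscale` and `RandomGrayscale` are exercised on tensor input and scripted. A sketch of what the grayscale class-op test checks, assuming the scriptable transforms this PR targets:

```python
import torch
import torchvision.transforms as T

transform = T.Grayscale(num_output_channels=3)
scripted = torch.jit.script(transform)

img = torch.randint(0, 255, (3, 10, 10), dtype=torch.uint8)
out = transform(img)
assert out.equal(scripted(img))  # eager and scripted agree
assert out.shape == (3, 10, 10)
assert out[0].equal(out[1]) and out[1].equal(out[2])  # r = g = b
```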
61 changes: 44 additions & 17 deletions torchvision/transforms/functional.py
@@ -32,6 +32,13 @@ def _get_image_size(img: Tensor) -> List[int]:
return F_pil._get_image_size(img)


def _get_image_num_channels(img: Tensor) -> int:
if isinstance(img, torch.Tensor):
return F_t._get_image_num_channels(img)

return F_pil._get_image_num_channels(img)


@torch.jit.unused
def _is_numpy(img: Any) -> bool:
return isinstance(img, np.ndarray)
@@ -951,32 +958,52 @@ def affine(
return F_t.affine(img, matrix=matrix, resample=resample, fillcolor=fillcolor)


def to_grayscale(img, num_output_channels=1):
"""Convert image to grayscale version of image.
def to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
"""DEPRECATED. Convert RGB image to grayscale version of image.
The image can be a PIL Image or a Tensor, in which case it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

.. warning::

This method is deprecated and will be removed in future releases.
Please, use ``F.rgb_to_grayscale`` instead.


Args:
img (PIL Image): Image to be converted to grayscale.
img (PIL Image or Tensor): RGB Image to be converted to grayscale.
num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default is 1.

Returns:
PIL Image: Grayscale version of the image.
PIL Image or Tensor: Grayscale version of the image.
if num_output_channels = 1 : returned image is single channel

if num_output_channels = 3 : returned image is 3 channel with r = g = b
"""
if not F_pil._is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

if num_output_channels == 1:
img = img.convert('L')
elif num_output_channels == 3:
img = img.convert('L')
np_img = np.array(img, dtype=np.uint8)
np_img = np.dstack([np_img, np_img, np_img])
img = Image.fromarray(np_img, 'RGB')
else:
raise ValueError('num_output_channels should be either 1 or 3')
warnings.warn("The use of the F.to_grayscale transform is deprecated, " +
"please use F.rgb_to_grayscale instead.")

return img
return rgb_to_grayscale(img, num_output_channels)


def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
"""Convert RGB image to grayscale version of image.
The image can be a PIL Image or a Tensor, in which case it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

Args:
img (PIL Image or Tensor): RGB Image to be converted to grayscale.
num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default is 1.

Returns:
PIL Image or Tensor: Grayscale version of the image.
if num_output_channels = 1 : returned image is single channel

if num_output_channels = 3 : returned image is 3 channel with r = g = b
"""
if not isinstance(img, torch.Tensor):
return F_pil.rgb_to_grayscale(img, num_output_channels)

return F_t.rgb_to_grayscale(img, num_output_channels)


def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor:
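
`rgb_to_grayscale` only dispatches by input type: PIL Images go to `F_pil.rgb_to_grayscale`, tensors to `F_t.rgb_to_grayscale` (the tensor kernel is not shown in this diff). A minimal sketch of what such a tensor kernel computes, assuming the ITU-R BT.601 luma weights that PIL's 'L' conversion also uses:

```python
import torch
from torch import Tensor

def rgb_to_grayscale_sketch(img: Tensor, num_output_channels: int = 1) -> Tensor:
    # Weighted channel sum: L = 0.2989 R + 0.587 G + 0.114 B (BT.601)
    r, g, b = img.unbind(dim=-3)
    gray = (0.2989 * r + 0.587 * g + 0.114 * b).to(img.dtype)
    gray = gray.unsqueeze(dim=-3)      # keep a channel axis: [..., 1, H, W]
    if num_output_channels == 3:
        gray = gray.expand(img.shape)  # repeat the luma channel: r = g = b
    return gray

img = torch.randint(0, 255, (3, 8, 8), dtype=torch.uint8)
assert rgb_to_grayscale_sketch(img, num_output_channels=3).shape == (3, 8, 8)
```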
57 changes: 57 additions & 0 deletions torchvision/transforms/functional_pil.py
@@ -1,4 +1,5 @@
import numbers
import warnings
from typing import Any, List, Sequence

import numpy as np
@@ -26,6 +27,13 @@ def _get_image_size(img: Any) -> List[int]:
raise TypeError("Unexpected type {}".format(type(img)))


@torch.jit.unused
def _get_image_num_channels(img: Any) -> int:
if _is_pil_image(img):
return 1 if img.mode == 'L' else 3
raise TypeError("Unexpected type {}".format(type(img)))


@torch.jit.unused
def hflip(img):
"""Horizontally flip the given PIL Image.
@@ -480,3 +488,52 @@ def perspective(img, perspective_coeffs, interpolation=Image.BICUBIC, fill=None):
opts = _parse_fill(fill, img, '5.0.0')

return img.transform(img.size, Image.PERSPECTIVE, perspective_coeffs, interpolation, **opts)


@torch.jit.unused
def to_grayscale(img, num_output_channels):
"""DEPRECATED. Convert RGB image to grayscale version of image.

Args:
img (PIL Image): Image to be converted to grayscale.
num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default is 1.

Returns:
PIL Image: Grayscale version of the image.
if num_output_channels = 1 : returned image is single channel

if num_output_channels = 3 : returned image is 3 channel with r = g = b
"""
warnings.warn("The use of the F_pil.to_grayscale transform is deprecated, " +
"please use F.rgb_to_grayscale instead.")
return rgb_to_grayscale(img, num_output_channels)


@torch.jit.unused
def rgb_to_grayscale(img, num_output_channels):
"""Convert RGB image to grayscale version of image.

Args:
img (PIL Image): RGB Image to be converted to grayscale.
num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default is 1.

Returns:
PIL Image: Grayscale version of the image.
if num_output_channels = 1 : returned image is single channel

if num_output_channels = 3 : returned image is 3 channel with r = g = b
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

if num_output_channels == 1:
img = img.convert('L')
elif num_output_channels == 3:
img = img.convert('L')
np_img = np.array(img, dtype=np.uint8)
np_img = np.dstack([np_img, np_img, np_img])
img = Image.fromarray(np_img, 'RGB')
else:
raise ValueError('num_output_channels should be either 1 or 3')

return img
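
Both deprecated wrappers only warn and forward, so existing call sites keep working while pointing users at `rgb_to_grayscale`. A quick behavioral check under that assumption:

```python
import warnings
from PIL import Image
import torchvision.transforms.functional as F

img = Image.new('RGB', (8, 8), color=(10, 120, 200))

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    out_old = F.to_grayscale(img, num_output_channels=3)

assert any("deprecated" in str(w.message) for w in caught)
assert out_old.mode == 'RGB'  # 3-channel grayscale with r = g = b
out_new = F.rgb_to_grayscale(img, num_output_channels=3)
assert list(out_old.getdata()) == list(out_new.getdata())
```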