Skip to content

Commit

Permalink
Unified inputs for grayscale ops and transforms (#2586)
Browse files Browse the repository at this point in the history
* [WIP] Unify ops Grayscale and RandomGrayscale

* Unified inputs for grayscale op and transforms
- deprecated F.to_grayscale in favor of F.rgb_to_grayscale

* Fixes bug with fp input

* [WIP] Updated code according to review

* Removed unused import
  • Loading branch information
vfdev-5 authored Aug 28, 2020
1 parent 279fca5 commit 2eba1f0
Show file tree
Hide file tree
Showing 7 changed files with 206 additions and 74 deletions.
9 changes: 6 additions & 3 deletions test/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,9 +350,12 @@ def compareTensorToPIL(self, tensor, pil_image, msg=None):
msg = "tensor:\n{} \ndid not equal PIL tensor:\n{}".format(tensor, pil_tensor)
self.assertTrue(tensor.cpu().equal(pil_tensor), msg)

def approxEqualTensorToPIL(self, tensor, pil_image, tol=1e-5, msg=None, method="mean"):
pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1))).to(tensor)
err = getattr(torch, method)(tensor - pil_tensor).item()
def approxEqualTensorToPIL(self, tensor, pil_image, tol=1e-5, msg=None, agg_method="mean"):
np_pil_image = np.array(pil_image)
if np_pil_image.ndim == 2:
np_pil_image = np_pil_image[:, :, None]
pil_tensor = torch.as_tensor(np_pil_image.transpose((2, 0, 1))).to(tensor)
err = getattr(torch, agg_method)(tensor - pil_tensor).item()
self.assertTrue(
err < tol,
msg="{}: err={}, tol={}: \n{}\nvs\n{}".format(msg, err, tol, tensor[0, :10, :10], pil_tensor[0, :10, :10])
Expand Down
33 changes: 22 additions & 11 deletions test/test_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,18 +194,29 @@ def test_adjustments(self):
def test_adjustments_cuda(self):
self._test_adjustments("cuda")

def _test_rgb_to_grayscale(self, device):
script_rgb_to_grayscale = torch.jit.script(F.rgb_to_grayscale)

img_tensor, pil_img = self._create_data(32, 34, device=device)

for num_output_channels in (3, 1):
gray_pil_image = F.rgb_to_grayscale(pil_img, num_output_channels=num_output_channels)
gray_tensor = F.rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)

if num_output_channels == 1:
print(gray_tensor.shape)

self.approxEqualTensorToPIL(gray_tensor.float(), gray_pil_image, tol=1.0 + 1e-10, agg_method="max")

s_gray_tensor = script_rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)
self.assertTrue(s_gray_tensor.equal(gray_tensor))

def test_rgb_to_grayscale(self):
script_rgb_to_grayscale = torch.jit.script(F_t.rgb_to_grayscale)
img_tensor = torch.randint(0, 255, (3, 16, 16), dtype=torch.uint8)
img_tensor_clone = img_tensor.clone()
grayscale_tensor = F_t.rgb_to_grayscale(img_tensor).to(int)
grayscale_pil_img = torch.tensor(np.array(F.to_grayscale(F.to_pil_image(img_tensor)))).to(int)
max_diff = (grayscale_tensor - grayscale_pil_img).abs().max()
self.assertLess(max_diff, 1.0001)
self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
# scriptable function test
grayscale_script = script_rgb_to_grayscale(img_tensor).to(int)
self.assertTrue(torch.equal(grayscale_script, grayscale_tensor))
self._test_rgb_to_grayscale("cpu")

@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
def test_rgb_to_grayscale_cuda(self):
self._test_rgb_to_grayscale("cuda")

def _test_center_crop(self, device):
script_center_crop = torch.jit.script(F.center_crop)
Expand Down
71 changes: 46 additions & 25 deletions test/test_transforms_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@

class Tester(TransformsTester):

def _test_functional_geom_op(self, func, fn_kwargs):
def _test_functional_op(self, func, fn_kwargs):
if fn_kwargs is None:
fn_kwargs = {}
tensor, pil_img = self._create_data(height=10, width=10)
transformed_tensor = getattr(F, func)(tensor, **fn_kwargs)
transformed_pil_img = getattr(F, func)(pil_img, **fn_kwargs)
self.compareTensorToPIL(transformed_tensor, transformed_pil_img)

def _test_class_geom_op(self, method, meth_kwargs=None):
def _test_class_op(self, method, meth_kwargs=None, test_exact_match=True, **match_kwargs):
if meth_kwargs is None:
meth_kwargs = {}

Expand All @@ -35,21 +35,24 @@ def _test_class_geom_op(self, method, meth_kwargs=None):
transformed_tensor = f(tensor)
torch.manual_seed(12)
transformed_pil_img = f(pil_img)
self.compareTensorToPIL(transformed_tensor, transformed_pil_img)
if test_exact_match:
self.compareTensorToPIL(transformed_tensor, transformed_pil_img, **match_kwargs)
else:
self.approxEqualTensorToPIL(transformed_tensor.float(), transformed_pil_img, **match_kwargs)

torch.manual_seed(12)
transformed_tensor_script = scripted_fn(tensor)
self.assertTrue(transformed_tensor.equal(transformed_tensor_script))

def _test_geom_op(self, func, method, fn_kwargs=None, meth_kwargs=None):
self._test_functional_geom_op(func, fn_kwargs)
self._test_class_geom_op(method, meth_kwargs)
def _test_op(self, func, method, fn_kwargs=None, meth_kwargs=None):
self._test_functional_op(func, fn_kwargs)
self._test_class_op(method, meth_kwargs)

def test_random_horizontal_flip(self):
self._test_geom_op('hflip', 'RandomHorizontalFlip')
self._test_op('hflip', 'RandomHorizontalFlip')

def test_random_vertical_flip(self):
self._test_geom_op('vflip', 'RandomVerticalFlip')
self._test_op('vflip', 'RandomVerticalFlip')

def test_adjustments(self):
fns = ['adjust_brightness', 'adjust_contrast', 'adjust_saturation']
Expand Down Expand Up @@ -80,30 +83,30 @@ def test_adjustments(self):
def test_pad(self):

# Test functional.pad (PIL and Tensor) with padding as single int
self._test_functional_geom_op(
self._test_functional_op(
"pad", fn_kwargs={"padding": 2, "fill": 0, "padding_mode": "constant"}
)
# Test functional.pad and transforms.Pad with padding as [int, ]
fn_kwargs = meth_kwargs = {"padding": [2, ], "fill": 0, "padding_mode": "constant"}
self._test_geom_op(
self._test_op(
"pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
# Test functional.pad and transforms.Pad with padding as list
fn_kwargs = meth_kwargs = {"padding": [4, 4], "fill": 0, "padding_mode": "constant"}
self._test_geom_op(
self._test_op(
"pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
# Test functional.pad and transforms.Pad with padding as tuple
fn_kwargs = meth_kwargs = {"padding": (2, 2, 2, 2), "fill": 127, "padding_mode": "constant"}
self._test_geom_op(
self._test_op(
"pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)

def test_crop(self):
fn_kwargs = {"top": 2, "left": 3, "height": 4, "width": 5}
# Test transforms.RandomCrop with size and padding as tuple
meth_kwargs = {"size": (4, 5), "padding": (4, 4), "pad_if_needed": True, }
self._test_geom_op(
self._test_op(
'crop', 'RandomCrop', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)

Expand All @@ -120,17 +123,17 @@ def test_crop(self):
for padding_config in padding_configs:
config = dict(padding_config)
config["size"] = size
self._test_class_geom_op("RandomCrop", config)
self._test_class_op("RandomCrop", config)

def test_center_crop(self):
fn_kwargs = {"output_size": (4, 5)}
meth_kwargs = {"size": (4, 5), }
self._test_geom_op(
self._test_op(
"center_crop", "CenterCrop", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
fn_kwargs = {"output_size": (5,)}
meth_kwargs = {"size": (5, )}
self._test_geom_op(
self._test_op(
"center_crop", "CenterCrop", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
tensor = torch.randint(0, 255, (3, 10, 10), dtype=torch.uint8)
Expand All @@ -149,7 +152,7 @@ def test_center_crop(self):
scripted_fn = torch.jit.script(f)
scripted_fn(tensor)

def _test_geom_op_list_output(self, func, method, out_length, fn_kwargs=None, meth_kwargs=None):
def _test_op_list_output(self, func, method, out_length, fn_kwargs=None, meth_kwargs=None):
if fn_kwargs is None:
fn_kwargs = {}
if meth_kwargs is None:
Expand Down Expand Up @@ -178,37 +181,37 @@ def _test_geom_op_list_output(self, func, method, out_length, fn_kwargs=None, me

def test_five_crop(self):
fn_kwargs = meth_kwargs = {"size": (5,)}
self._test_geom_op_list_output(
self._test_op_list_output(
"five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
fn_kwargs = meth_kwargs = {"size": [5, ]}
self._test_geom_op_list_output(
self._test_op_list_output(
"five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
fn_kwargs = meth_kwargs = {"size": (4, 5)}
self._test_geom_op_list_output(
self._test_op_list_output(
"five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
fn_kwargs = meth_kwargs = {"size": [4, 5]}
self._test_geom_op_list_output(
self._test_op_list_output(
"five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)

def test_ten_crop(self):
fn_kwargs = meth_kwargs = {"size": (5,)}
self._test_geom_op_list_output(
self._test_op_list_output(
"ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
fn_kwargs = meth_kwargs = {"size": [5, ]}
self._test_geom_op_list_output(
self._test_op_list_output(
"ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
fn_kwargs = meth_kwargs = {"size": (4, 5)}
self._test_geom_op_list_output(
self._test_op_list_output(
"ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
fn_kwargs = meth_kwargs = {"size": [4, 5]}
self._test_geom_op_list_output(
self._test_op_list_output(
"ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)

Expand Down Expand Up @@ -312,6 +315,24 @@ def test_random_perspective(self):
out2 = s_transform(tensor)
self.assertTrue(out1.equal(out2))

def test_to_grayscale(self):

meth_kwargs = {"num_output_channels": 1}
tol = 1.0 + 1e-10
self._test_class_op(
"Grayscale", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
)

meth_kwargs = {"num_output_channels": 3}
self._test_class_op(
"Grayscale", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
)

meth_kwargs = {}
self._test_class_op(
"RandomGrayscale", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
)


if __name__ == '__main__':
unittest.main()
54 changes: 39 additions & 15 deletions torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ def _get_image_size(img: Tensor) -> List[int]:
return F_pil._get_image_size(img)


def _get_image_num_channels(img: Tensor) -> int:
if isinstance(img, torch.Tensor):
return F_t._get_image_num_channels(img)

return F_pil._get_image_num_channels(img)


@torch.jit.unused
def _is_numpy(img: Any) -> bool:
return isinstance(img, np.ndarray)
Expand Down Expand Up @@ -951,32 +958,49 @@ def affine(
return F_t.affine(img, matrix=matrix, resample=resample, fillcolor=fillcolor)


@torch.jit.unused
def to_grayscale(img, num_output_channels=1):
"""Convert image to grayscale version of image.
"""Convert PIL image of any mode (RGB, HSV, LAB, etc) to grayscale version of image.
Args:
img (PIL Image): Image to be converted to grayscale.
img (PIL Image): PIL Image to be converted to grayscale.
num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default, 1.
Returns:
PIL Image: Grayscale version of the image.
if num_output_channels = 1 : returned image is single channel
if num_output_channels = 3 : returned image is 3 channel with r = g = b
"""
if not F_pil._is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

if num_output_channels == 1:
img = img.convert('L')
elif num_output_channels == 3:
img = img.convert('L')
np_img = np.array(img, dtype=np.uint8)
np_img = np.dstack([np_img, np_img, np_img])
img = Image.fromarray(np_img, 'RGB')
else:
raise ValueError('num_output_channels should be either 1 or 3')
if isinstance(img, Image.Image):
return F_pil.to_grayscale(img, num_output_channels)

return img
raise TypeError("Input should be PIL Image")


def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
"""Convert RGB image to grayscale version of image.
The image can be a PIL Image or a Tensor, in which case it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
Note:
Please, note that this method supports only RGB images as input. For inputs in other color spaces,
please, consider using meth:`~torchvision.transforms.functional.to_grayscale` with PIL Image.
Args:
img (PIL Image or Tensor): RGB Image to be converted to grayscale.
num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default, 1.
Returns:
PIL Image or Tensor: Grayscale version of the image.
if num_output_channels = 1 : returned image is single channel
if num_output_channels = 3 : returned image is 3 channel with r = g = b
"""
if not isinstance(img, torch.Tensor):
return F_pil.to_grayscale(img, num_output_channels)

return F_t.rgb_to_grayscale(img, num_output_channels)


def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor:
Expand Down
37 changes: 37 additions & 0 deletions torchvision/transforms/functional_pil.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,13 @@ def _get_image_size(img: Any) -> List[int]:
raise TypeError("Unexpected type {}".format(type(img)))


@torch.jit.unused
def _get_image_num_channels(img: Any) -> int:
if _is_pil_image(img):
return 1 if img.mode == 'L' else 3
raise TypeError("Unexpected type {}".format(type(img)))


@torch.jit.unused
def hflip(img):
"""Horizontally flip the given PIL Image.
Expand Down Expand Up @@ -480,3 +487,33 @@ def perspective(img, perspective_coeffs, interpolation=Image.BICUBIC, fill=None)
opts = _parse_fill(fill, img, '5.0.0')

return img.transform(img.size, Image.PERSPECTIVE, perspective_coeffs, interpolation, **opts)


@torch.jit.unused
def to_grayscale(img, num_output_channels):
"""Convert PIL image of any mode (RGB, HSV, LAB, etc) to grayscale version of image.
Args:
img (PIL Image): Image to be converted to grayscale.
num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default, 1.
Returns:
PIL Image: Grayscale version of the image.
if num_output_channels = 1 : returned image is single channel
if num_output_channels = 3 : returned image is 3 channel with r = g = b
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

if num_output_channels == 1:
img = img.convert('L')
elif num_output_channels == 3:
img = img.convert('L')
np_img = np.array(img, dtype=np.uint8)
np_img = np.dstack([np_img, np_img, np_img])
img = Image.fromarray(np_img, 'RGB')
else:
raise ValueError('num_output_channels should be either 1 or 3')

return img
Loading

0 comments on commit 2eba1f0

Please sign in to comment.