From 2eba1f045af9afa646b9d0ec937d4eaa614b5c47 Mon Sep 17 00:00:00 2001 From: vfdev Date: Fri, 28 Aug 2020 10:46:30 +0200 Subject: [PATCH] Unified inputs for grayscale ops and transforms (#2586) * [WIP] Unify ops Grayscale and RandomGrayscale * Unified inputs for grayscale op and transforms - deprecated F.to_grayscale in favor of F.rgb_to_grayscale * Fixes bug with fp input * [WIP] Updated code according to review * Removed unused import --- test/common_utils.py | 9 ++- test/test_functional_tensor.py | 33 ++++++---- test/test_transforms_tensor.py | 71 +++++++++++++-------- torchvision/transforms/functional.py | 54 +++++++++++----- torchvision/transforms/functional_pil.py | 37 +++++++++++ torchvision/transforms/functional_tensor.py | 42 ++++++++++-- torchvision/transforms/transforms.py | 34 ++++++---- 7 files changed, 206 insertions(+), 74 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index 90583f23921..13e3561f19b 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -350,9 +350,12 @@ def compareTensorToPIL(self, tensor, pil_image, msg=None): msg = "tensor:\n{} \ndid not equal PIL tensor:\n{}".format(tensor, pil_tensor) self.assertTrue(tensor.cpu().equal(pil_tensor), msg) - def approxEqualTensorToPIL(self, tensor, pil_image, tol=1e-5, msg=None, method="mean"): - pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1))).to(tensor) - err = getattr(torch, method)(tensor - pil_tensor).item() + def approxEqualTensorToPIL(self, tensor, pil_image, tol=1e-5, msg=None, agg_method="mean"): + np_pil_image = np.array(pil_image) + if np_pil_image.ndim == 2: + np_pil_image = np_pil_image[:, :, None] + pil_tensor = torch.as_tensor(np_pil_image.transpose((2, 0, 1))).to(tensor) + err = getattr(torch, agg_method)(tensor - pil_tensor).item() self.assertTrue( err < tol, msg="{}: err={}, tol={}: \n{}\nvs\n{}".format(msg, err, tol, tensor[0, :10, :10], pil_tensor[0, :10, :10]) diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index 17834838c17..68359bc0437 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -194,18 +194,29 @@ def test_adjustments(self): def test_adjustments_cuda(self): self._test_adjustments("cuda") + def _test_rgb_to_grayscale(self, device): + script_rgb_to_grayscale = torch.jit.script(F.rgb_to_grayscale) + + img_tensor, pil_img = self._create_data(32, 34, device=device) + + for num_output_channels in (3, 1): + gray_pil_image = F.rgb_to_grayscale(pil_img, num_output_channels=num_output_channels) + gray_tensor = F.rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels) + + if num_output_channels == 1: + print(gray_tensor.shape) + + self.approxEqualTensorToPIL(gray_tensor.float(), gray_pil_image, tol=1.0 + 1e-10, agg_method="max") + + s_gray_tensor = script_rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels) + self.assertTrue(s_gray_tensor.equal(gray_tensor)) + def test_rgb_to_grayscale(self): - script_rgb_to_grayscale = torch.jit.script(F_t.rgb_to_grayscale) - img_tensor = torch.randint(0, 255, (3, 16, 16), dtype=torch.uint8) - img_tensor_clone = img_tensor.clone() - grayscale_tensor = F_t.rgb_to_grayscale(img_tensor).to(int) - grayscale_pil_img = torch.tensor(np.array(F.to_grayscale(F.to_pil_image(img_tensor)))).to(int) - max_diff = (grayscale_tensor - grayscale_pil_img).abs().max() - self.assertLess(max_diff, 1.0001) - self.assertTrue(torch.equal(img_tensor, img_tensor_clone)) - # scriptable function test - grayscale_script = script_rgb_to_grayscale(img_tensor).to(int) - self.assertTrue(torch.equal(grayscale_script, grayscale_tensor)) + self._test_rgb_to_grayscale("cpu") + + @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device") + def test_rgb_to_grayscale_cuda(self): + self._test_rgb_to_grayscale("cuda") def _test_center_crop(self, device): script_center_crop = torch.jit.script(F.center_crop) diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index b31e1424525..245a27954f9 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -13,7 +13,7 @@ class Tester(TransformsTester): - def _test_functional_geom_op(self, func, fn_kwargs): + def _test_functional_op(self, func, fn_kwargs): if fn_kwargs is None: fn_kwargs = {} tensor, pil_img = self._create_data(height=10, width=10) @@ -21,7 +21,7 @@ def _test_functional_geom_op(self, func, fn_kwargs): transformed_pil_img = getattr(F, func)(pil_img, **fn_kwargs) self.compareTensorToPIL(transformed_tensor, transformed_pil_img) - def _test_class_geom_op(self, method, meth_kwargs=None): + def _test_class_op(self, method, meth_kwargs=None, test_exact_match=True, **match_kwargs): if meth_kwargs is None: meth_kwargs = {} @@ -35,21 +35,24 @@ def _test_class_geom_op(self, method, meth_kwargs=None): transformed_tensor = f(tensor) torch.manual_seed(12) transformed_pil_img = f(pil_img) - self.compareTensorToPIL(transformed_tensor, transformed_pil_img) + if test_exact_match: + self.compareTensorToPIL(transformed_tensor, transformed_pil_img, **match_kwargs) + else: + self.approxEqualTensorToPIL(transformed_tensor.float(), transformed_pil_img, **match_kwargs) torch.manual_seed(12) transformed_tensor_script = scripted_fn(tensor) self.assertTrue(transformed_tensor.equal(transformed_tensor_script)) - def _test_geom_op(self, func, method, fn_kwargs=None, meth_kwargs=None): - self._test_functional_geom_op(func, fn_kwargs) - self._test_class_geom_op(method, meth_kwargs) + def _test_op(self, func, method, fn_kwargs=None, meth_kwargs=None): + self._test_functional_op(func, fn_kwargs) + self._test_class_op(method, meth_kwargs) def test_random_horizontal_flip(self): - self._test_geom_op('hflip', 'RandomHorizontalFlip') + self._test_op('hflip', 'RandomHorizontalFlip') def test_random_vertical_flip(self): - self._test_geom_op('vflip', 'RandomVerticalFlip') + self._test_op('vflip', 'RandomVerticalFlip') def test_adjustments(self): fns = ['adjust_brightness', 'adjust_contrast', 'adjust_saturation'] @@ -80,22 +83,22 @@ def test_adjustments(self): def test_pad(self): # Test functional.pad (PIL and Tensor) with padding as single int - self._test_functional_geom_op( + self._test_functional_op( "pad", fn_kwargs={"padding": 2, "fill": 0, "padding_mode": "constant"} ) # Test functional.pad and transforms.Pad with padding as [int, ] fn_kwargs = meth_kwargs = {"padding": [2, ], "fill": 0, "padding_mode": "constant"} - self._test_geom_op( + self._test_op( "pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs ) # Test functional.pad and transforms.Pad with padding as list fn_kwargs = meth_kwargs = {"padding": [4, 4], "fill": 0, "padding_mode": "constant"} - self._test_geom_op( + self._test_op( "pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs ) # Test functional.pad and transforms.Pad with padding as tuple fn_kwargs = meth_kwargs = {"padding": (2, 2, 2, 2), "fill": 127, "padding_mode": "constant"} - self._test_geom_op( + self._test_op( "pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs ) @@ -103,7 +106,7 @@ def test_crop(self): fn_kwargs = {"top": 2, "left": 3, "height": 4, "width": 5} # Test transforms.RandomCrop with size and padding as tuple meth_kwargs = {"size": (4, 5), "padding": (4, 4), "pad_if_needed": True, } - self._test_geom_op( + self._test_op( 'crop', 'RandomCrop', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs ) @@ -120,17 +123,17 @@ def test_crop(self): for padding_config in padding_configs: config = dict(padding_config) config["size"] = size - self._test_class_geom_op("RandomCrop", config) + self._test_class_op("RandomCrop", config) def test_center_crop(self): fn_kwargs = {"output_size": (4, 5)} meth_kwargs = {"size": (4, 5), } - self._test_geom_op( + self._test_op( "center_crop", "CenterCrop", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs ) fn_kwargs = {"output_size": (5,)} meth_kwargs = {"size": (5, )} - self._test_geom_op( + self._test_op( "center_crop", "CenterCrop", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs ) tensor = torch.randint(0, 255, (3, 10, 10), dtype=torch.uint8) @@ -149,7 +152,7 @@ def test_center_crop(self): scripted_fn = torch.jit.script(f) scripted_fn(tensor) - def _test_geom_op_list_output(self, func, method, out_length, fn_kwargs=None, meth_kwargs=None): + def _test_op_list_output(self, func, method, out_length, fn_kwargs=None, meth_kwargs=None): if fn_kwargs is None: fn_kwargs = {} if meth_kwargs is None: @@ -178,37 +181,37 @@ def _test_geom_op_list_output(self, func, method, out_length, fn_kwargs=None, me def test_five_crop(self): fn_kwargs = meth_kwargs = {"size": (5,)} - self._test_geom_op_list_output( + self._test_op_list_output( "five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs ) fn_kwargs = meth_kwargs = {"size": [5, ]} - self._test_geom_op_list_output( + self._test_op_list_output( "five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs ) fn_kwargs = meth_kwargs = {"size": (4, 5)} - self._test_geom_op_list_output( + self._test_op_list_output( "five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs ) fn_kwargs = meth_kwargs = {"size": [4, 5]} - self._test_geom_op_list_output( + self._test_op_list_output( "five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs ) def test_ten_crop(self): fn_kwargs = meth_kwargs = {"size": (5,)} - self._test_geom_op_list_output( + self._test_op_list_output( "ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs ) fn_kwargs = meth_kwargs = {"size": [5, ]} - self._test_geom_op_list_output( + self._test_op_list_output( "ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs ) fn_kwargs = meth_kwargs = {"size": (4, 5)} - self._test_geom_op_list_output( + self._test_op_list_output( "ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs ) fn_kwargs = meth_kwargs = {"size": [4, 5]} - self._test_geom_op_list_output( + self._test_op_list_output( "ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs ) @@ -312,6 +315,24 @@ def test_random_perspective(self): out2 = s_transform(tensor) self.assertTrue(out1.equal(out2)) + def test_to_grayscale(self): + + meth_kwargs = {"num_output_channels": 1} + tol = 1.0 + 1e-10 + self._test_class_op( + "Grayscale", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max" + ) + + meth_kwargs = {"num_output_channels": 3} + self._test_class_op( + "Grayscale", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max" + ) + + meth_kwargs = {} + self._test_class_op( + "RandomGrayscale", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max" + ) + if __name__ == '__main__': unittest.main() diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 06b2a0e1f80..4a36e0b05e6 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -32,6 +32,13 @@ def _get_image_size(img: Tensor) -> List[int]: return F_pil._get_image_size(img) +def _get_image_num_channels(img: Tensor) -> int: + if isinstance(img, torch.Tensor): + return F_t._get_image_num_channels(img) + + return F_pil._get_image_num_channels(img) + + @torch.jit.unused def _is_numpy(img: Any) -> bool: return isinstance(img, np.ndarray) @@ -951,11 +958,13 @@ def affine( return F_t.affine(img, matrix=matrix, resample=resample, fillcolor=fillcolor) +@torch.jit.unused def to_grayscale(img, num_output_channels=1): - """Convert image to grayscale version of image. + """Convert PIL image of any mode (RGB, HSV, LAB, etc) to grayscale version of image. Args: - img (PIL Image): Image to be converted to grayscale. + img (PIL Image): PIL Image to be converted to grayscale. + num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default, 1. Returns: PIL Image: Grayscale version of the image. @@ -963,20 +972,35 @@ def to_grayscale(img, num_output_channels=1): if num_output_channels = 3 : returned image is 3 channel with r = g = b """ - if not F_pil._is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) - - if num_output_channels == 1: - img = img.convert('L') - elif num_output_channels == 3: - img = img.convert('L') - np_img = np.array(img, dtype=np.uint8) - np_img = np.dstack([np_img, np_img, np_img]) - img = Image.fromarray(np_img, 'RGB') - else: - raise ValueError('num_output_channels should be either 1 or 3') + if isinstance(img, Image.Image): + return F_pil.to_grayscale(img, num_output_channels) - return img + raise TypeError("Input should be PIL Image") + + +def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor: + """Convert RGB image to grayscale version of image. + The image can be a PIL Image or a Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions + + Note: + Please, note that this method supports only RGB images as input. For inputs in other color spaces, + please, consider using meth:`~torchvision.transforms.functional.to_grayscale` with PIL Image. + + Args: + img (PIL Image or Tensor): RGB Image to be converted to grayscale. + num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default, 1. + + Returns: + PIL Image or Tensor: Grayscale version of the image. + if num_output_channels = 1 : returned image is single channel + + if num_output_channels = 3 : returned image is 3 channel with r = g = b + """ + if not isinstance(img, torch.Tensor): + return F_pil.to_grayscale(img, num_output_channels) + + return F_t.rgb_to_grayscale(img, num_output_channels) def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor: diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py index f1e8504f874..ba620ab9d9c 100644 --- a/torchvision/transforms/functional_pil.py +++ b/torchvision/transforms/functional_pil.py @@ -26,6 +26,13 @@ def _get_image_size(img: Any) -> List[int]: raise TypeError("Unexpected type {}".format(type(img))) +@torch.jit.unused +def _get_image_num_channels(img: Any) -> int: + if _is_pil_image(img): + return 1 if img.mode == 'L' else 3 + raise TypeError("Unexpected type {}".format(type(img))) + + @torch.jit.unused def hflip(img): """Horizontally flip the given PIL Image. @@ -480,3 +487,33 @@ def perspective(img, perspective_coeffs, interpolation=Image.BICUBIC, fill=None) opts = _parse_fill(fill, img, '5.0.0') return img.transform(img.size, Image.PERSPECTIVE, perspective_coeffs, interpolation, **opts) + + +@torch.jit.unused +def to_grayscale(img, num_output_channels): + """Convert PIL image of any mode (RGB, HSV, LAB, etc) to grayscale version of image. + + Args: + img (PIL Image): Image to be converted to grayscale. + num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default, 1. + + Returns: + PIL Image: Grayscale version of the image. + if num_output_channels = 1 : returned image is single channel + + if num_output_channels = 3 : returned image is 3 channel with r = g = b + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + if num_output_channels == 1: + img = img.convert('L') + elif num_output_channels == 3: + img = img.convert('L') + np_img = np.array(img, dtype=np.uint8) + np_img = np.dstack([np_img, np_img, np_img]) + img = Image.fromarray(np_img, 'RGB') + else: + raise ValueError('num_output_channels should be either 1 or 3') + + return img diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 36a12280310..6b581abd8d9 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -18,6 +18,15 @@ def _get_image_size(img: Tensor) -> List[int]: raise TypeError("Unexpected type {}".format(type(img))) +def _get_image_num_channels(img: Tensor) -> int: + if img.ndim == 2: + return 1 + elif img.ndim > 2: + return img.shape[-3] + + raise TypeError("Unexpected type {}".format(type(img))) + + def vflip(img: Tensor) -> Tensor: """Vertically flip the given the Image Tensor. @@ -67,22 +76,41 @@ def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor: return img[..., top:top + height, left:left + width] -def rgb_to_grayscale(img: Tensor) -> Tensor: +def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor: """Convert the given RGB Image Tensor to Grayscale. For RGB to Grayscale conversion, ITU-R 601-2 luma transform is performed which is L = R * 0.2989 + G * 0.5870 + B * 0.1140 Args: img (Tensor): Image to be converted to Grayscale in the form [C, H, W]. + num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default, 1. Returns: - Tensor: Grayscale image. + Tensor: Grayscale version of the image. + if num_output_channels = 1 : returned image is single channel + + if num_output_channels = 3 : returned image is 3 channel with r = g = b """ - if img.shape[0] != 3: - raise TypeError('Input Image does not contain 3 Channels') + if img.ndim < 3: + raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim)) + c = img.shape[-3] + if c != 3: + raise TypeError("Input image tensor should 3 channels, but found {}".format(c)) + + if num_output_channels not in (1, 3): + raise ValueError('num_output_channels should be either 1 or 3') + + r, g, b = img.unbind(dim=-3) + # This implementation closely follows the TF one: + # https://github.com/tensorflow/tensorflow/blob/v2.3.0/tensorflow/python/ops/image_ops_impl.py#L2105-L2138 + l_img = (0.2989 * r + 0.587 * g + 0.114 * b).to(img.dtype) + l_img = l_img.unsqueeze(dim=-3) + + if num_output_channels == 3: + return l_img.expand(img.shape) - return (0.2989 * img[0] + 0.5870 * img[1] + 0.1140 * img[2]).to(img.dtype) + return l_img def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor: @@ -373,8 +401,8 @@ def ten_crop(img: Tensor, size: BroadcastingList2[int], vertical_flip: bool = Fa def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor: - bound = 1 if img1.dtype in [torch.half, torch.float32, torch.float64] else 255 - return (ratio * img1 + (1 - ratio) * img2).clamp(0, bound).to(img1.dtype) + bound = 1.0 if img1.is_floating_point() else 255.0 + return (ratio * img1 + (1.0 - ratio) * img2).clamp(0, bound).to(img1.dtype) def _rgb2hsv(img): diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 7b8f9e9601b..b995101c3c7 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -1354,8 +1354,11 @@ def __repr__(self): return s.format(name=self.__class__.__name__, **d) -class Grayscale(object): +class Grayscale(torch.nn.Module): """Convert image to grayscale. + The image can be a PIL Image or a Tensor, in which case it is expected + to have [..., 3, H, W] shape, where ... means an arbitrary number of leading + dimensions Args: num_output_channels (int): (1 or 3) number of channels desired for output image @@ -1368,30 +1371,34 @@ class Grayscale(object): """ def __init__(self, num_output_channels=1): + super().__init__() self.num_output_channels = num_output_channels - def __call__(self, img): + def forward(self, img: Tensor) -> Tensor: """ Args: - img (PIL Image): Image to be converted to grayscale. + img (PIL Image or Tensor): Image to be converted to grayscale. Returns: - PIL Image: Randomly grayscaled image. + PIL Image or Tensor: Grayscaled image. """ - return F.to_grayscale(img, num_output_channels=self.num_output_channels) + return F.rgb_to_grayscale(img, num_output_channels=self.num_output_channels) def __repr__(self): return self.__class__.__name__ + '(num_output_channels={0})'.format(self.num_output_channels) -class RandomGrayscale(object): +class RandomGrayscale(torch.nn.Module): """Randomly convert image to grayscale with a probability of p (default 0.1). + The image can be a PIL Image or a Tensor, in which case it is expected + to have [..., 3, H, W] shape, where ... means an arbitrary number of leading + dimensions Args: p (float): probability that image should be converted to grayscale. Returns: - PIL Image: Grayscale version of the input image with probability p and unchanged + PIL Image or Tensor: Grayscale version of the input image with probability p and unchanged with probability (1-p). - If input image is 1 channel: grayscale version is 1 channel - If input image is 3 channel: grayscale version is 3 channel with r == g == b @@ -1399,19 +1406,20 @@ class RandomGrayscale(object): """ def __init__(self, p=0.1): + super().__init__() self.p = p - def __call__(self, img): + def forward(self, img: Tensor) -> Tensor: """ Args: - img (PIL Image): Image to be converted to grayscale. + img (PIL Image or Tensor): Image to be converted to grayscale. Returns: - PIL Image: Randomly grayscaled image. + PIL Image or Tensor: Randomly grayscaled image. """ - num_output_channels = 1 if img.mode == 'L' else 3 - if random.random() < self.p: - return F.to_grayscale(img, num_output_channels=num_output_channels) + num_output_channels = F._get_image_num_channels(img) + if torch.rand(1) < self.p: + return F.rgb_to_grayscale(img, num_output_channels=num_output_channels) return img def __repr__(self):