diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py
index 89af6dce5d7..cd3ae5a0a82 100644
--- a/test/test_functional_tensor.py
+++ b/test/test_functional_tensor.py
@@ -1,14 +1,17 @@
-import torch
-import torchvision.transforms as transforms
-import torchvision.transforms.functional_tensor as F_t
-import torchvision.transforms.functional_pil as F_pil
-import torchvision.transforms.functional as F
-import numpy as np
 import unittest
 import random
 import colorsys
 
 from PIL import Image
+from PIL.Image import NEAREST, BILINEAR, BICUBIC
+
+import numpy as np
+
+import torch
+import torchvision.transforms as transforms
+import torchvision.transforms.functional_tensor as F_t
+import torchvision.transforms.functional_pil as F_pil
+import torchvision.transforms.functional as F
 
 
 class Tester(unittest.TestCase):
@@ -22,6 +25,14 @@ def compareTensorToPIL(self, tensor, pil_image, msg=None):
         pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1)))
         self.assertTrue(tensor.equal(pil_tensor), msg)
 
+    def approxEqualTensorToPIL(self, tensor, pil_image, tol=1e-5, msg=None):
+        pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1))).to(tensor)
+        mae = torch.abs(tensor - pil_tensor).mean().item()
+        self.assertTrue(
+            mae < tol,
+            msg="{}: mae={}, tol={}: \n{}\nvs\n{}".format(msg, mae, tol, tensor[0, :10, :10], pil_tensor[0, :10, :10])
+        )
+
     def test_vflip(self):
         script_vflip = torch.jit.script(F_t.vflip)
         img_tensor = torch.randn(3, 16, 16)
@@ -282,6 +293,44 @@ def test_pad(self):
         with self.assertRaises(ValueError, msg="Padding can not be negative for symmetric padding_mode"):
             F_t.pad(tensor, (-2, -3), padding_mode="symmetric")
 
+    def test_resize(self):
+        script_fn = torch.jit.script(F_t.resize)
+        tensor, pil_img = self._create_data(26, 36)
+
+        for dt in [None, torch.float32, torch.float64]:
+            if dt is not None:
+                # This is a trivial cast of uint8 data to float, to test all cases
+                tensor = tensor.to(dt)
+            for size in [32, [32, ], [32, 32], (32, 32), ]:
+                for interpolation in [BILINEAR, BICUBIC, NEAREST]:
+                    resized_tensor = F_t.resize(tensor, size=size, interpolation=interpolation)
+                    resized_pil_img = F_pil.resize(pil_img, size=size, interpolation=interpolation)
+
+                    self.assertEqual(
+                        resized_tensor.size()[1:], resized_pil_img.size[::-1], msg="{}, {}".format(size, interpolation)
+                    )
+
+                    if interpolation != NEAREST:
+                        # We cannot check values if mode = NEAREST, as results are different
+                        # E.g. resized_tensor = [[a, a, b, c, d, d, e, ...]]
+                        # E.g. resized_pil_img = [[a, b, c, c, d, e, f, ...]]
+                        resized_tensor_f = resized_tensor
+                        # we need to cast uint8 to float to compare with the PIL image
+                        if resized_tensor_f.dtype == torch.uint8:
+                            resized_tensor_f = resized_tensor_f.to(torch.float)
+
+                        # Note the high MAE tolerance: tensor and PIL interpolation results are close, not identical
+                        self.approxEqualTensorToPIL(
+                            resized_tensor_f, resized_pil_img, tol=8.0, msg="{}, {}".format(size, interpolation)
+                        )
+
+                    if isinstance(size, int):
+                        script_size = [size, ]
+                    else:
+                        script_size = size
+                    resized_tensor_script = script_fn(tensor, size=script_size, interpolation=interpolation)
+                    self.assertTrue(resized_tensor.equal(resized_tensor_script), msg="{}, {}".format(size, interpolation))
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py
index 6a8d9930754..9d70744dfc1 100644
--- a/test/test_transforms_tensor.py
+++ b/test/test_transforms_tensor.py
@@ -2,6 +2,7 @@
 from torchvision import transforms as T
 from torchvision.transforms import functional as F
 from PIL import Image
+from PIL.Image import NEAREST, BILINEAR, BICUBIC
 
 import numpy as np
 
@@ -217,6 +218,33 @@ def test_ten_crop(self):
             "ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
         )
 
+    def test_resize(self):
+        tensor, _ = self._create_data(height=34, width=36)
+        script_fn = torch.jit.script(F.resize)
+
+        for dt in [None, torch.float32, torch.float64]:
+            if dt is not None:
+                # This is a trivial cast of uint8 data to float, to test all cases
+                tensor = tensor.to(dt)
+            for size in [32, [32, ], [32, 32], (32, 32), ]:
+                for interpolation in [BILINEAR, BICUBIC, NEAREST]:
+
+                    resized_tensor = F.resize(tensor, size=size, interpolation=interpolation)
+
+                    if isinstance(size, int):
+                        script_size = [size, ]
+                    else:
+                        script_size = size
+
+                    s_resized_tensor = script_fn(tensor, size=script_size, interpolation=interpolation)
+                    self.assertTrue(s_resized_tensor.equal(resized_tensor))
+
+                    transform = T.Resize(size=script_size, interpolation=interpolation)
+                    resized_tensor = transform(tensor)
+                    script_transform = torch.jit.script(transform)
+                    s_resized_tensor = script_transform(tensor)
+                    self.assertTrue(s_resized_tensor.equal(resized_tensor))
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index 9c7efe0ef53..72ca54d7260 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -311,41 +311,29 @@ def normalize(tensor, mean, std, inplace=False):
     return tensor
 
 
-def resize(img, size, interpolation=Image.BILINEAR):
-    r"""Resize the input PIL Image to the given size.
+def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor:
+    r"""Resize the input image to the given size.
+    The image can be a PIL Image or a torch Tensor, in which case it is expected
+    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
 
     Args:
-        img (PIL Image): Image to be resized.
+        img (PIL Image or Tensor): Image to be resized.
         size (sequence or int): Desired output size. If size is a sequence like
             (h, w), the output size will be matched to this. If size is an
             int, the smaller edge of the image will be matched to this number maintaining
             the aspect ratio. i.e, if height > width, then image will be rescaled to
-            :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`
-        interpolation (int, optional): Desired interpolation. Default is
-            ``PIL.Image.BILINEAR``
+            :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`.
+            In torchscript mode, size as a single int is not supported; use a tuple or
+            list of length 1: ``[size, ]``.
+        interpolation (int, optional): Desired interpolation. Default is bilinear.
 
     Returns:
-        PIL Image: Resized image.
+        PIL Image or Tensor: Resized image.
     """
-    if not F_pil._is_pil_image(img):
-        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
-    if not (isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)):
-        raise TypeError('Got inappropriate size arg: {}'.format(size))
+    if not isinstance(img, torch.Tensor):
+        return F_pil.resize(img, size=size, interpolation=interpolation)
 
-    if isinstance(size, int):
-        w, h = img.size
-        if (w <= h and w == size) or (h <= w and h == size):
-            return img
-        if w < h:
-            ow = size
-            oh = int(size * h / w)
-            return img.resize((ow, oh), interpolation)
-        else:
-            oh = size
-            ow = int(size * w / h)
-            return img.resize((ow, oh), interpolation)
-    else:
-        return img.resize(size[::-1], interpolation)
+    return F_t.resize(img, size=size, interpolation=interpolation)
 
 
 def scale(*args, **kwargs):
diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py
index f1bcda113aa..994988ce1f6 100644
--- a/torchvision/transforms/functional_pil.py
+++ b/torchvision/transforms/functional_pil.py
@@ -1,5 +1,5 @@
 import numbers
-from typing import Any, List
+from typing import Any, List, Sequence
 
 import torch
 try:
@@ -286,3 +286,44 @@ def crop(img: Image.Image, top: int, left: int, height: int, width: int) -> Imag
         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
 
     return img.crop((left, top, left + width, top + height))
+
+
+@torch.jit.unused
+def resize(img, size, interpolation=Image.BILINEAR):
+    r"""Resize the input PIL Image to the given size.
+
+    Args:
+        img (PIL Image): Image to be resized.
+        size (sequence or int): Desired output size. If size is a sequence like
+            (h, w), the output size will be matched to this. If size is an int,
+            the smaller edge of the image will be matched to this number maintaining
+            the aspect ratio. i.e., if height > width, then image will be rescaled to
+            :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`.
+            For compatibility with ``functional_tensor.resize``, if a tuple or list of length 1 is provided,
+            it is interpreted as a single int.
+        interpolation (int, optional): Desired interpolation. Default is ``PIL.Image.BILINEAR``.
+
+    Returns:
+        PIL Image: Resized image.
+    """
+    if not _is_pil_image(img):
+        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+    if not (isinstance(size, int) or (isinstance(size, Sequence) and len(size) in (1, 2))):
+        raise TypeError('Got inappropriate size arg: {}'.format(size))
+
+    if isinstance(size, int) or len(size) == 1:
+        if isinstance(size, Sequence):
+            size = size[0]
+        w, h = img.size
+        if (w <= h and w == size) or (h <= w and h == size):
+            return img
+        if w < h:
+            ow = size
+            oh = int(size * h / w)
+            return img.resize((ow, oh), interpolation)
+        else:
+            oh = size
+            ow = int(size * w / h)
+            return img.resize((ow, oh), interpolation)
+    else:
+        return img.resize(size[::-1], interpolation)
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index 8b64abe9f9c..be0b7b3a622 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -8,6 +8,7 @@ def _is_tensor_a_torch_image(x: Tensor) -> bool:
 
 
 def _get_image_size(img: Tensor) -> List[int]:
+    """Returns (w, h) of a tensor image"""
     if _is_tensor_a_torch_image(img):
         return [img.shape[-1], img.shape[-2]]
     raise TypeError("Unexpected type {}".format(type(img)))
@@ -433,6 +434,7 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
 
     if isinstance(padding, int):
         if torch.jit.is_scripting():
+            # This may be unreachable
             raise ValueError("padding can't be an int while torchscripting, set it as a list [value, ]")
         pad_left = pad_right = pad_top = pad_bottom = padding
     elif len(padding) == 1:
@@ -480,3 +482,92 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
         img = img.to(out_dtype)
 
     return img
+
+
+def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor:
+    r"""Resize the input Tensor to the given size.
+
+    Args:
+        img (Tensor): Image to be resized.
+        size (int or tuple or list): Desired output size. If size is a sequence like
+            (h, w), the output size will be matched to this. If size is an int,
+            the smaller edge of the image will be matched to this number maintaining
+            the aspect ratio. i.e., if height > width, then image will be rescaled to
+            :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`.
+            In torchscript mode, size as a single int is not supported; use a tuple or
+            list of length 1: ``[size, ]``.
+        interpolation (int, optional): Desired interpolation. Default is bilinear.
+
+    Returns:
+        Tensor: Resized image.
+    """
+    if not _is_tensor_a_torch_image(img):
+        raise TypeError("tensor is not a torch image.")
+
+    if not isinstance(size, (int, tuple, list)):
+        raise TypeError("Got inappropriate size arg")
+    if not isinstance(interpolation, int):
+        raise TypeError("Got inappropriate interpolation arg")
+
+    _interpolation_modes = {
+        0: "nearest",
+        2: "bilinear",
+        3: "bicubic",
+    }
+
+    if interpolation not in _interpolation_modes:
+        raise ValueError("This interpolation mode is unsupported with Tensor input")
+
+    if isinstance(size, tuple):
+        size = list(size)
+
+    if isinstance(size, list) and len(size) not in [1, 2]:
+        raise ValueError("Size must be an int or a 1 or 2 element tuple/list, not a "
+                         "{} element tuple/list".format(len(size)))
+
+    w, h = _get_image_size(img)
+
+    if isinstance(size, int):
+        size_w, size_h = size, size
+    elif len(size) < 2:
+        size_w, size_h = size[0], size[0]
+    else:
+        size_w, size_h = size[1], size[0]  # size is expected to be (h, w)
+
+    if isinstance(size, int) or len(size) < 2:
+        if w < h:
+            size_h = int(size_w * h / w)
+        else:
+            size_w = int(size_h * w / h)
+
+        if (w <= h and w == size_w) or (h <= w and h == size_h):
+            return img
+
+    # make image NCHW
+    need_squeeze = False
+    if img.ndim < 4:
+        img = img.unsqueeze(dim=0)
+        need_squeeze = True
+
+    mode = _interpolation_modes[interpolation]
+
+    out_dtype = img.dtype
+    need_cast = False
+    if img.dtype not in (torch.float32, torch.float64):
+        need_cast = True
+        img = img.to(torch.float32)
+
+    # Define align_corners to avoid warnings
+    align_corners = False if mode in ["bilinear", "bicubic"] else None
+
+    img = torch.nn.functional.interpolate(img, size=(size_h, size_w), mode=mode, align_corners=align_corners)
+
+    if need_squeeze:
+        img = img.squeeze(dim=0)
+
+    if need_cast:
+        if mode == "bicubic":
+            img = img.clamp(min=0, max=255)
+        img = img.to(out_dtype)
+
+    return img
diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py
index 6bc9e7cbc4d..9f4ad8175c6 100644
--- a/torchvision/transforms/transforms.py
+++ b/torchvision/transforms/transforms.py
@@ -2,7 +2,7 @@
 import numbers
 import random
 import warnings
-from collections.abc import Sequence, Iterable
+from collections.abc import Sequence
 from typing import Tuple, List, Optional
 
 import numpy as np
@@ -209,31 +209,38 @@ def __repr__(self):
         return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
 
 
-class Resize(object):
-    """Resize the input PIL Image to the given size.
+class Resize(torch.nn.Module):
+    """Resize the input image to the given size.
+    The image can be a PIL Image or a torch Tensor, in which case it is expected
+    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
 
     Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
-            (size * height / width, size)
-        interpolation (int, optional): Desired interpolation. Default is
-            ``PIL.Image.BILINEAR``
+            (size * height / width, size).
+            In torchscript mode, size as a single int is not supported; use a tuple or
+            list of length 1: ``[size, ]``.
+        interpolation (int, optional): Desired interpolation. Default is ``PIL.Image.BILINEAR``
     """
 
     def __init__(self, size, interpolation=Image.BILINEAR):
-        assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
+        super().__init__()
+        if not isinstance(size, (int, Sequence)):
+            raise TypeError("Size should be int or sequence. Got {}".format(type(size)))
+        if isinstance(size, Sequence) and len(size) not in (1, 2):
+            raise ValueError("If size is a sequence, it should have 1 or 2 values")
         self.size = size
         self.interpolation = interpolation
 
-    def __call__(self, img):
+    def forward(self, img):
         """
         Args:
-            img (PIL Image): Image to be scaled.
+            img (PIL Image or Tensor): Image to be scaled.
 
         Returns:
-            PIL Image or Tensor: Rescaled image.
+            PIL Image or Tensor: Rescaled image.
         """
         return F.resize(img, self.size, self.interpolation)
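
Not part of the patch: a minimal usage sketch of the resize path added above, assuming the diff is applied to torchvision.transforms. The image tensor, the size values and the printed shapes are illustrative; interpolation is passed as the integer codes accepted by the tensor path (0 = nearest, 2 = bilinear, 3 = bicubic).

import torch
from torchvision.transforms import functional as F

# A fake uint8 image batch with [..., H, W] shape (2 RGB images, 64x48).
img = torch.randint(0, 256, (2, 3, 64, 48), dtype=torch.uint8)

# Eager mode: size may be an int (smaller edge, aspect ratio kept) or an (h, w) sequence.
out = F.resize(img, size=32, interpolation=2)         # bilinear
print(out.shape)                                      # torch.Size([2, 3, 42, 32])

out = F.resize(img, size=[32, 32], interpolation=0)   # nearest
print(out.shape)                                      # torch.Size([2, 3, 32, 32])

# Torchscript mode: size as a single int is not supported, pass a list of length 1.
scripted_resize = torch.jit.script(F.resize)
out = scripted_resize(img, size=[32, ], interpolation=3)  # bicubic
print(out.shape)                                      # torch.Size([2, 3, 42, 32])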
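
Also not part of the patch: since Resize now derives from torch.nn.Module with a forward method, the transform object itself can be scripted, mirroring what test_transforms_tensor.test_resize checks. The interpolation constant and the tensor are illustrative.

import torch
from PIL.Image import BILINEAR
from torchvision import transforms as T

transform = T.Resize(size=[32, 32], interpolation=BILINEAR)
scripted_transform = torch.jit.script(transform)

img = torch.randint(0, 256, (3, 64, 48), dtype=torch.uint8)
# The scripted module and the eager module should produce identical tensors.
assert transform(img).equal(scripted_transform(img))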