From 62399885069091d8a1473b9076de35c1c317e91b Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Mon, 1 Jun 2020 14:57:28 +0200 Subject: [PATCH 1/2] Make RandomHorizontalFlip torchscriptable --- test/test_transforms_tensor.py | 38 +++++++++++++++++++++ torchvision/transforms/functional.py | 19 +++++++---- torchvision/transforms/functional_pil.py | 30 ++++++++++++++++ torchvision/transforms/functional_tensor.py | 1 - torchvision/transforms/transforms.py | 16 +++++---- 5 files changed, 91 insertions(+), 13 deletions(-) create mode 100644 test/test_transforms_tensor.py create mode 100644 torchvision/transforms/functional_pil.py diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py new file mode 100644 index 00000000000..eeaec1ac9fc --- /dev/null +++ b/test/test_transforms_tensor.py @@ -0,0 +1,38 @@ +import torch +from torchvision import transforms as T +from torchvision.transforms import functional as F +from PIL import Image + +import numpy as np + +import unittest + + +class Tester(unittest.TestCase): + def _create_data(self, height=3, width=3, channels=3): + tensor = torch.randint(0, 255, (channels, height, width), dtype=torch.uint8) + pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().numpy()) + return tensor, pil_img + + def compareTensorToPIL(self, tensor, pil_image): + pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1))) + self.assertTrue(tensor.equal(pil_tensor)) + + def test_random_horizontal_flip(self): + tensor, pil_img = self._create_data() + flip_tensor = F.hflip(tensor) + flip_pil_img = F.hflip(pil_img) + self.compareTensorToPIL(flip_tensor, flip_pil_img) + + scripted_fn = torch.jit.script(F.hflip) + flip_tensor_script = scripted_fn(tensor) + self.assertTrue(flip_tensor.equal(flip_tensor_script)) + + # test for class interface + f = T.RandomHorizontalFlip() + scripted_fn = torch.jit.script(f) + scripted_fn(tensor) + + +if __name__ == '__main__': + unittest.main() diff --git 
a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 7f22fc51391..fe2ac048fec 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1,4 +1,5 @@ import torch +from torch import Tensor import math from PIL import Image, ImageOps, ImageEnhance, __version__ as PILLOW_VERSION try: @@ -11,6 +12,9 @@ from collections.abc import Sequence, Iterable import warnings +from . import functional_pil as F_pil +from . import functional_tensor as F_t + def _is_pil_image(img): if accimage is not None: @@ -428,19 +432,22 @@ def resized_crop(img, top, left, height, width, size, interpolation=Image.BILINE return img -def hflip(img): - """Horizontally flip the given PIL Image. +def hflip(img: Tensor) -> Tensor: + """Horizontally flip the given PIL Image or torch Tensor. Args: - img (PIL Image): Image to be flipped. + img (PIL Image or Torch Tensor): Image to be flipped. If img + is a Tensor, it is expected to be in [..., H, W] format, + where ... means it can have an arbitrary number of leading + dimensions. Returns: PIL Image: Horizontally flipped image. """ - if not _is_pil_image(img): - raise TypeError('img should be PIL Image. 
Got {}'.format(type(img))) + if not isinstance(img, torch.Tensor): + return F_pil.hflip(img) - return img.transpose(Image.FLIP_LEFT_RIGHT) + return F_t.hflip(img) def _parse_fill(fill, img, min_pil_version): diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py new file mode 100644 index 00000000000..00200212e4d --- /dev/null +++ b/torchvision/transforms/functional_pil.py @@ -0,0 +1,30 @@ +import torch +try: + import accimage +except ImportError: + accimage = None +from PIL import Image, ImageOps, ImageEnhance, __version__ as PILLOW_VERSION + + +@torch.jit.unused +def _is_pil_image(img): + if accimage is not None: + return isinstance(img, (Image.Image, accimage.Image)) + else: + return isinstance(img, Image.Image) + + +@torch.jit.unused +def hflip(img): + """Horizontally flip the given PIL Image. + + Args: + img (PIL Image): Image to be flipped. + + Returns: + PIL Image: Horizontally flipped image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + return img.transpose(Image.FLIP_LEFT_RIGHT) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index b81deed6d43..9369e0fb562 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -1,5 +1,4 @@ import torch -import torchvision.transforms.functional as F from torch import Tensor from torch.jit.annotations import Optional, List, BroadcastingList2, Tuple diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 49fac26e395..bad2f9ab3f8 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -500,25 +500,29 @@ def __repr__(self): return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding) -class RandomHorizontalFlip(object): - """Horizontally flip the given PIL Image randomly with a given probability. 
+class RandomHorizontalFlip(torch.nn.Module): + """Horizontally flip the given image randomly with a given probability. + The image can be a PIL Image or a torch Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading + dimensions Args: p (float): probability of the image being flipped. Default value is 0.5 """ def __init__(self, p=0.5): + super().__init__() self.p = p - def __call__(self, img): + def forward(self, img): """ Args: - img (PIL Image): Image to be flipped. + img (PIL Image or Tensor): Image to be flipped. Returns: - PIL Image: Randomly flipped image. + PIL Image or Tensor: Randomly flipped image. """ - if random.random() < self.p: + if torch.rand(1) < self.p: return F.hflip(img) return img From 016784b15824ee806b6e09d5053c5f13428c6fb2 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Mon, 1 Jun 2020 15:01:12 +0200 Subject: [PATCH 2/2] Make _is_tensor_a_torch_image more generic --- torchvision/transforms/functional_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 9369e0fb562..c0815393c37 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -4,7 +4,7 @@ def _is_tensor_a_torch_image(input): - return len(input.shape) == 3 + return input.ndim >= 2 def vflip(img):