diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py
index 2675226d3b7..f4edc0f7f07 100644
--- a/test/test_functional_tensor.py
+++ b/test/test_functional_tensor.py
@@ -573,10 +573,10 @@ def test_perspective(self):
                     num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
                     ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
-                    # Tolerance : less than 3% of different pixels
+                    # Tolerance : less than 5% of different pixels
                     self.assertLess(
                         ratio_diff_pixels,
-                        0.03,
+                        0.05,
                         msg="{}: {}\n{} vs \n{}".format(
                             (r, spoints, epoints),
                             ratio_diff_pixels,
diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py
index 25a008295cd..4f7b8a8094a 100644
--- a/test/test_transforms_tensor.py
+++ b/test/test_transforms_tensor.py
@@ -301,6 +301,23 @@ def test_random_rotate(self):
         out2 = s_transform(tensor)
         self.assertTrue(out1.equal(out2))
 
+    def test_random_perspective(self):
+        tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8)
+
+        for distortion_scale in np.linspace(0.1, 1.0, num=20):
+            for interpolation in [NEAREST, BILINEAR]:
+                transform = T.RandomPerspective(
+                    distortion_scale=distortion_scale,
+                    interpolation=interpolation
+                )
+                s_transform = torch.jit.script(transform)
+
+                torch.manual_seed(12)
+                out1 = transform(tensor)
+                torch.manual_seed(12)
+                out2 = s_transform(tensor)
+                self.assertTrue(out1.equal(out2))
+
 
 if __name__ == '__main__':
     unittest.main()
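
The new test_random_perspective above relies on the transform drawing all of its randomness from torch's global generator, so seeding identically before the eager and the scripted call must yield bit-identical outputs. A minimal standalone sketch of that check, assuming a torchvision build that already includes this patch:

import torch
import torchvision.transforms as T

tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8)

transform = T.RandomPerspective(distortion_scale=0.5, p=0.5)
s_transform = torch.jit.script(transform)  # scriptable now that it is an nn.Module

torch.manual_seed(12)
out1 = transform(tensor)   # eager call
torch.manual_seed(12)
out2 = s_transform(tensor)  # scripted call, same RNG state
assert out1.equal(out2)  # same seed => same sampled corners => identical output
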
""" - def __init__(self, distortion_scale=0.5, p=0.5, interpolation=Image.BICUBIC, fill=0): + def __init__(self, distortion_scale=0.5, p=0.5, interpolation=Image.BILINEAR, fill=0): + super().__init__() self.p = p self.interpolation = interpolation self.distortion_scale = distortion_scale self.fill = fill - def __call__(self, img): + def forward(self, img): """ Args: - img (PIL Image): Image to be Perspectively transformed. + img (PIL Image or Tensor): Image to be Perspectively transformed. Returns: - PIL Image: Random perspectivley transformed image. + PIL Image or Tensor: Randomly transformed image. """ - if not F._is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) - - if random.random() < self.p: - width, height = img.size + if torch.rand(1) < self.p: + width, height = F._get_image_size(img) startpoints, endpoints = self.get_params(width, height, self.distortion_scale) return F.perspective(img, startpoints, endpoints, self.interpolation, self.fill) return img @staticmethod - def get_params(width, height, distortion_scale): + def get_params(width: int, height: int, distortion_scale: float) -> Tuple[List[List[int]], List[List[int]]]: """Get parameters for ``perspective`` for a random perspective transform. Args: - width : width of the image. - height : height of the image. + width (int): width of the image. + height (int): height of the image. + distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1. Returns: List containing [top-left, top-right, bottom-right, bottom-left] of the original image, List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image. """ - half_height = int(height / 2) - half_width = int(width / 2) - topleft = (random.randint(0, int(distortion_scale * half_width)), - random.randint(0, int(distortion_scale * half_height))) - topright = (random.randint(width - int(distortion_scale * half_width) - 1, width - 1), - random.randint(0, int(distortion_scale * half_height))) - botright = (random.randint(width - int(distortion_scale * half_width) - 1, width - 1), - random.randint(height - int(distortion_scale * half_height) - 1, height - 1)) - botleft = (random.randint(0, int(distortion_scale * half_width)), - random.randint(height - int(distortion_scale * half_height) - 1, height - 1)) - startpoints = [(0, 0), (width - 1, 0), (width - 1, height - 1), (0, height - 1)] + half_height = height // 2 + half_width = width // 2 + topleft = [ + int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1, )).item()), + int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1, )).item()) + ] + topright = [ + int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1, )).item()), + int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1, )).item()) + ] + botright = [ + int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1, )).item()), + int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1, )).item()) + ] + botleft = [ + int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1, )).item()), + int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1, )).item()) + ] + startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]] endpoints = [topleft, topright, botright, botleft] return startpoints, endpoints