Skip to content

Commit

Permalink
[BC-breaking] Unified input for RandomPerspective (#2561)
Browse files Browse the repository at this point in the history
* Unified input for RandomPerspective

* Updated docs

* Fixed failing test and bug with torch.randint

* Update test_functional_tensor.py
  • Loading branch information
vfdev-5 authored Aug 8, 2020
1 parent 8c7e7bb commit a75fdd4
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 34 deletions.
4 changes: 2 additions & 2 deletions test/test_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,10 +573,10 @@ def test_perspective(self):

num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
# Tolerance : less than 3% of different pixels
# Tolerance : less than 5% of different pixels
self.assertLess(
ratio_diff_pixels,
0.03,
0.05,
msg="{}: {}\n{} vs \n{}".format(
(r, spoints, epoints),
ratio_diff_pixels,
Expand Down
17 changes: 17 additions & 0 deletions test/test_transforms_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,23 @@ def test_random_rotate(self):
out2 = s_transform(tensor)
self.assertTrue(out1.equal(out2))

def test_random_perspective(self):
    """Scripted and eager RandomPerspective must agree when run under the same RNG seed."""
    img = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8)

    for mode in [NEAREST, BILINEAR]:
        for scale in np.linspace(0.1, 1.0, num=20):
            xform = T.RandomPerspective(
                distortion_scale=scale,
                interpolation=mode
            )
            scripted = torch.jit.script(xform)

            # Reseed before each call so both paths draw identical random parameters.
            torch.manual_seed(12)
            expected = xform(img)
            torch.manual_seed(12)
            actual = scripted(img)
            self.assertTrue(expected.equal(actual))


if __name__ == '__main__':
unittest.main()
75 changes: 43 additions & 32 deletions torchvision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,66 +627,77 @@ def __repr__(self):
return self.__class__.__name__ + '(p={})'.format(self.p)


class RandomPerspective(torch.nn.Module):
    """Performs a random perspective transformation of the given image with a given probability.

    The image can be a PIL Image or a Tensor, in which case it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        distortion_scale (float): argument to control the degree of distortion and ranges
            from 0 to 1. Default is 0.5.
        p (float): probability of the image being transformed. Default is 0.5.
        interpolation (int): Interpolation type. If input is Tensor, only ``PIL.Image.NEAREST``
            and ``PIL.Image.BILINEAR`` are supported. Default, ``PIL.Image.BILINEAR`` for
            PIL images and Tensors.
        fill (n-tuple or int or float): Pixel fill value for area outside the rotated
            image. If int or float, the value is used for all bands respectively. Default is 0.
            This option is only available for ``pillow>=5.0.0``. This option is not supported
            for Tensor input. Fill value for the area outside the transform in the output
            image is always 0.
    """

    def __init__(self, distortion_scale=0.5, p=0.5, interpolation=Image.BILINEAR, fill=0):
        super().__init__()
        self.p = p
        self.interpolation = interpolation
        self.distortion_scale = distortion_scale
        self.fill = fill

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be perspectively transformed.

        Returns:
            PIL Image or Tensor: Randomly transformed image.
        """
        # torch.rand (not random.random) keeps this method scriptable with torch.jit.script.
        if torch.rand(1) < self.p:
            width, height = F._get_image_size(img)
            startpoints, endpoints = self.get_params(width, height, self.distortion_scale)
            return F.perspective(img, startpoints, endpoints, self.interpolation, self.fill)
        return img

    @staticmethod
    def get_params(width: int, height: int, distortion_scale: float) -> Tuple[List[List[int]], List[List[int]]]:
        """Get parameters for ``perspective`` for a random perspective transform.

        Args:
            width (int): width of the image.
            height (int): height of the image.
            distortion_scale (float): argument to control the degree of distortion and
                ranges from 0 to 1.

        Returns:
            List containing [top-left, top-right, bottom-right, bottom-left] of the
            original image,
            List containing [top-left, top-right, bottom-right, bottom-left] of the
            transformed image.
        """
        half_height = height // 2
        half_width = width // 2
        # torch.randint's upper bound is exclusive, hence the "+ 1" on the low-side
        # bounds and the use of `width` / `height` (not `width - 1` / `height - 1`)
        # on the high-side bounds; the sampled ranges match the old random.randint code.
        topleft = [
            int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1, )).item()),
            int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1, )).item())
        ]
        topright = [
            int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1, )).item()),
            int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1, )).item())
        ]
        botright = [
            int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1, )).item()),
            int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1, )).item())
        ]
        botleft = [
            int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1, )).item()),
            int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1, )).item())
        ]
        # Corner lists (not tuples) so the return type is expressible in TorchScript.
        startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]
        endpoints = [topleft, topright, botright, botleft]
        return startpoints, endpoints

Expand Down

0 comments on commit a75fdd4

Please sign in to comment.