Skip to content

Commit

Permalink
Unified inputs for T.RandomAffine transformation (2292) (pytorch#2478)
Browse files Browse the repository at this point in the history
* [WIP] Unified input for T.RandomAffine

* Unified inputs for T.RandomAffine transformation

* Update transforms.py

* Updated docs of F.affine fillcolor

* Update transforms.py
  • Loading branch information
vfdev-5 authored and bryant1410 committed Nov 22, 2020
1 parent 4f9ff20 commit aaefcc5
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 69 deletions.
49 changes: 34 additions & 15 deletions test/test_transforms_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,21 +248,40 @@ def test_resize(self):
def test_resized_crop(self):
tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8)

scale = (0.7, 1.2)
ratio = (0.75, 1.333)

for size in [(32, ), [32, ], [32, 32], (32, 32)]:
for interpolation in [NEAREST, BILINEAR, BICUBIC]:
transform = T.RandomResizedCrop(
size=size, scale=scale, ratio=ratio, interpolation=interpolation
)
s_transform = torch.jit.script(transform)

torch.manual_seed(12)
out1 = transform(tensor)
torch.manual_seed(12)
out2 = s_transform(tensor)
self.assertTrue(out1.equal(out2))
for scale in [(0.7, 1.2), [0.7, 1.2]]:
for ratio in [(0.75, 1.333), [0.75, 1.333]]:
for size in [(32, ), [32, ], [32, 32], (32, 32)]:
for interpolation in [NEAREST, BILINEAR, BICUBIC]:
transform = T.RandomResizedCrop(
size=size, scale=scale, ratio=ratio, interpolation=interpolation
)
s_transform = torch.jit.script(transform)

torch.manual_seed(12)
out1 = transform(tensor)
torch.manual_seed(12)
out2 = s_transform(tensor)
self.assertTrue(out1.equal(out2))

def test_random_affine(self):
tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8)

for shear in [15, 10.0, (5.0, 10.0), [-15, 15], [-10.0, 10.0, -11.0, 11.0]]:
for scale in [(0.7, 1.2), [0.7, 1.2]]:
for translate in [(0.1, 0.2), [0.2, 0.1]]:
for degrees in [45, 35.0, (-45, 45), [-90.0, 90.0]]:
for interpolation in [NEAREST, BILINEAR]:
transform = T.RandomAffine(
degrees=degrees, translate=translate,
scale=scale, shear=shear, resample=interpolation
)
s_transform = torch.jit.script(transform)

torch.manual_seed(12)
out1 = transform(tensor)
torch.manual_seed(12)
out2 = s_transform(tensor)
self.assertTrue(out1.equal(out2))


if __name__ == '__main__':
Expand Down
4 changes: 3 additions & 1 deletion torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,7 +858,9 @@ def affine(
An optional resampling filter. See `filters`_ for more information.
If omitted, or if the image is PIL Image and has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
If input is Tensor, only ``PIL.Image.NEAREST`` and ``PIL.Image.BILINEAR`` are supported.
fillcolor (int): Optional fill color for the area outside the transform in the output image. (Pillow>=5.0.0)
fillcolor (int): Optional fill color for the area outside the transform in the output image (Pillow>=5.0.0).
This option is not supported for Tensor input. Fill value for the area outside the transform in the output
image is always 0.
Returns:
PIL Image or Tensor: Transformed image.
Expand Down
124 changes: 71 additions & 53 deletions torchvision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from collections.abc import Sequence
from typing import Tuple, List, Optional

import numpy as np
import torch
from PIL import Image
from torch import Tensor
Expand Down Expand Up @@ -721,9 +720,9 @@ def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolat
raise ValueError("Please provide only two dimensions (h, w) for size.")
self.size = size

if not isinstance(scale, (tuple, list)):
if not isinstance(scale, Sequence):
raise TypeError("Scale should be a sequence")
if not isinstance(ratio, (tuple, list)):
if not isinstance(ratio, Sequence):
raise TypeError("Ratio should be a sequence")
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
warnings.warn("Scale and ratio should be of kind (min, max)")
Expand All @@ -734,14 +733,14 @@ def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolat

@staticmethod
def get_params(
img: Tensor, scale: Tuple[float, float], ratio: Tuple[float, float]
img: Tensor, scale: List[float], ratio: List[float]
) -> Tuple[int, int, int, int]:
"""Get parameters for ``crop`` for a random sized crop.
Args:
img (PIL Image or Tensor): Input image.
scale (tuple): range of scale of the origin size cropped
ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
scale (list): range of scale of the origin size cropped
ratio (list): range of aspect ratio of the origin aspect ratio cropped
Returns:
tuple: params (i, j, h, w) to be passed to ``crop`` for a random
Expand All @@ -751,7 +750,7 @@ def get_params(
area = height * width

for _ in range(10):
target_area = area * torch.empty(1).uniform_(*scale).item()
target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
log_ratio = torch.log(torch.tensor(ratio))
aspect_ratio = torch.exp(
torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
Expand Down Expand Up @@ -1173,8 +1172,10 @@ def __repr__(self):
return format_string


class RandomAffine(object):
"""Random affine transformation of the image keeping center invariant
class RandomAffine(torch.nn.Module):
"""Random affine transformation of the image keeping center invariant.
The image can be a PIL Image or a Tensor, in which case it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
Args:
degrees (sequence or float or int): Range of degrees to select from.
Expand All @@ -1188,41 +1189,51 @@ class RandomAffine(object):
randomly sampled from the range a <= scale <= b. Will keep original scale by default.
shear (sequence or float or int, optional): Range of degrees to select from.
If shear is a number, a shear parallel to the x axis in the range (-shear, +shear)
will be apllied. Else if shear is a tuple or list of 2 values a shear parallel to the x axis in the
will be applied. Else if shear is a tuple or list of 2 values a shear parallel to the x axis in the
range (shear[0], shear[1]) will be applied. Else if shear is a tuple or list of 4 values,
a x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied.
Will not apply shear by default
resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional):
An optional resampling filter. See `filters`_ for more information.
If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST.
fillcolor (tuple or int): Optional fill color (Tuple for RGB Image And int for grayscale) for the area
outside the transform in the output image.(Pillow>=5.0.0)
Will not apply shear by default.
resample (int, optional): An optional resampling filter. See `filters`_ for more information.
If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
If input is Tensor, only ``PIL.Image.NEAREST`` and ``PIL.Image.BILINEAR`` are supported.
fillcolor (tuple or int): Optional fill color (Tuple for RGB Image and int for grayscale) for the area
outside the transform in the output image (Pillow>=5.0.0). This option is not supported for Tensor
input. Fill value for the area outside the transform in the output image is always 0.
.. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
"""

def __init__(self, degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0):
def __init__(self, degrees, translate=None, scale=None, shear=None, resample=0, fillcolor=0):
super().__init__()
if isinstance(degrees, numbers.Number):
if degrees < 0:
raise ValueError("If degrees is a single number, it must be positive.")
self.degrees = (-degrees, degrees)
degrees = [-degrees, degrees]
else:
assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \
"degrees should be a list or tuple and it must be of length 2."
self.degrees = degrees
if not isinstance(degrees, Sequence):
raise TypeError("degrees should be a sequence of length 2.")
if len(degrees) != 2:
raise ValueError("degrees should be sequence of length 2.")

self.degrees = [float(d) for d in degrees]

if translate is not None:
assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
"translate should be a list or tuple and it must be of length 2."
if not isinstance(translate, Sequence):
raise TypeError("translate should be a sequence of length 2.")
if len(translate) != 2:
raise ValueError("translate should be sequence of length 2.")
for t in translate:
if not (0.0 <= t <= 1.0):
raise ValueError("translation values should be between 0 and 1")
self.translate = translate

if scale is not None:
assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
"scale should be a list or tuple and it must be of length 2."
if not isinstance(scale, Sequence):
raise TypeError("scale should be a sequence of length 2.")
if len(scale) != 2:
raise ValueError("scale should be sequence of length 2.")

for s in scale:
if s <= 0:
raise ValueError("scale values should be positive")
Expand All @@ -1232,62 +1243,69 @@ def __init__(self, degrees, translate=None, scale=None, shear=None, resample=Fal
if isinstance(shear, numbers.Number):
if shear < 0:
raise ValueError("If shear is a single number, it must be positive.")
self.shear = (-shear, shear)
shear = [-shear, shear]
else:
assert isinstance(shear, (tuple, list)) and \
(len(shear) == 2 or len(shear) == 4), \
"shear should be a list or tuple and it must be of length 2 or 4."
# X-Axis shear with [min, max]
if len(shear) == 2:
self.shear = [shear[0], shear[1], 0., 0.]
elif len(shear) == 4:
self.shear = [s for s in shear]
if not isinstance(shear, Sequence):
raise TypeError("shear should be a sequence of length 2 or 4.")
if len(shear) not in (2, 4):
raise ValueError("shear should be sequence of length 2 or 4.")

self.shear = [float(s) for s in shear]
else:
self.shear = shear

self.resample = resample
self.fillcolor = fillcolor

@staticmethod
def get_params(degrees, translate, scale_ranges, shears, img_size):
def get_params(
degrees: List[float],
translate: Optional[List[float]],
scale_ranges: Optional[List[float]],
shears: Optional[List[float]],
img_size: List[int]
) -> Tuple[float, Tuple[int, int], float, Tuple[float, float]]:
"""Get parameters for affine transformation
Returns:
sequence: params to be passed to the affine transformation
params to be passed to the affine transformation
"""
angle = random.uniform(degrees[0], degrees[1])
angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item())
if translate is not None:
max_dx = translate[0] * img_size[0]
max_dy = translate[1] * img_size[1]
translations = (np.round(random.uniform(-max_dx, max_dx)),
np.round(random.uniform(-max_dy, max_dy)))
max_dx = float(translate[0] * img_size[0])
max_dy = float(translate[1] * img_size[1])
tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item()))
ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item()))
translations = (tx, ty)
else:
translations = (0, 0)

if scale_ranges is not None:
scale = random.uniform(scale_ranges[0], scale_ranges[1])
scale = float(torch.empty(1).uniform_(scale_ranges[0], scale_ranges[1]).item())
else:
scale = 1.0

shear_x = shear_y = 0.0
if shears is not None:
if len(shears) == 2:
shear = [random.uniform(shears[0], shears[1]), 0.]
elif len(shears) == 4:
shear = [random.uniform(shears[0], shears[1]),
random.uniform(shears[2], shears[3])]
else:
shear = 0.0
shear_x = float(torch.empty(1).uniform_(shears[0], shears[1]).item())
if len(shears) == 4:
shear_y = float(torch.empty(1).uniform_(shears[2], shears[3]).item())

shear = (shear_x, shear_y)

return angle, translations, scale, shear

def __call__(self, img):
def forward(self, img):
"""
img (PIL Image): Image to be transformed.
img (PIL Image or Tensor): Image to be transformed.
Returns:
PIL Image: Affine transformed image.
PIL Image or Tensor: Affine transformed image.
"""
ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img.size)

img_size = F._get_image_size(img)

ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img_size)
return F.affine(img, *ret, resample=self.resample, fillcolor=self.fillcolor)

def __repr__(self):
Expand Down

0 comments on commit aaefcc5

Please sign in to comment.