Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add random affine transformation #411

Merged
merged 4 commits into from
Feb 20, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
264 changes: 198 additions & 66 deletions test/sanity_checks.ipynb

Large diffs are not rendered by default.

128 changes: 126 additions & 2 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
import unittest
import math
import random
import numpy as np
from PIL import Image
Expand Down Expand Up @@ -846,6 +847,88 @@ def test_rotate(self):

assert np.all(np.array(result_a) == np.array(result_b))

def test_affine(self):
input_img = np.zeros((200, 200, 3), dtype=np.uint8)
pts = []
cnt = [100, 100]
for pt in [(80, 80), (100, 80), (100, 100)]:
for i in range(-5, 5):
for j in range(-5, 5):
input_img[pt[0] + i, pt[1] + j, :] = [255, 155, 55]
pts.append((pt[0] + i, pt[1] + j))
pts = list(set(pts))

with self.assertRaises(TypeError):
F.affine(input_img, 10)

pil_img = F.to_pil_image(input_img)

def _to_3x3_inv(inv_result_matrix):
result_matrix = np.zeros((3, 3))
result_matrix[:2, :] = np.array(inv_result_matrix).reshape((2, 3))
result_matrix[2, 2] = 1
return np.linalg.inv(result_matrix)

def _test_transformation(a, t, s, sh):
a_rad = math.radians(a)
s_rad = math.radians(sh)
# 1) Check transformation matrix:
c_matrix = np.array([[1.0, 0.0, cnt[0]], [0.0, 1.0, cnt[1]], [0.0, 0.0, 1.0]])
c_inv_matrix = np.linalg.inv(c_matrix)
t_matrix = np.array([[1.0, 0.0, t[0]],
[0.0, 1.0, t[1]],
[0.0, 0.0, 1.0]])
r_matrix = np.array([[s * math.cos(a_rad), -s * math.sin(a_rad + s_rad), 0.0],
[s * math.sin(a_rad), s * math.cos(a_rad + s_rad), 0.0],
[0.0, 0.0, 1.0]])
true_matrix = np.dot(t_matrix, np.dot(c_matrix, np.dot(r_matrix, c_inv_matrix)))
result_matrix = _to_3x3_inv(F._get_inverse_affine_matrix(center=cnt, angle=a,
translate=t, scale=s, shear=sh))
assert np.sum(np.abs(true_matrix - result_matrix)) < 1e-10
# 2) Perform inverse mapping:
true_result = np.zeros((200, 200, 3), dtype=np.uint8)
inv_true_matrix = np.linalg.inv(true_matrix)
for y in range(true_result.shape[0]):
for x in range(true_result.shape[1]):
res = np.dot(inv_true_matrix, [x, y, 1])
_x = int(res[0] + 0.5)
_y = int(res[1] + 0.5)
if 0 <= _x < input_img.shape[1] and 0 <= _y < input_img.shape[0]:
true_result[y, x, :] = input_img[_y, _x, :]

result = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh)
assert result.size == pil_img.size
# Compute number of different pixels:
np_result = np.array(result)
n_diff_pixels = np.sum(np_result != true_result) / 3
# Accept 3 wrong pixels
assert n_diff_pixels < 3, \
"a={}, t={}, s={}, sh={}\n".format(a, t, s, sh) +\
"n diff pixels={}\n".format(np.sum(np.array(result)[:, :, 0] != true_result[:, :, 0]))

# Test rotation
a = 45
_test_transformation(a=a, t=(0, 0), s=1.0, sh=0.0)

# Test translation
t = [10, 15]
_test_transformation(a=0.0, t=t, s=1.0, sh=0.0)

# Test scale
s = 1.2
_test_transformation(a=0.0, t=(0.0, 0.0), s=s, sh=0.0)

# Test shear
sh = 45.0
_test_transformation(a=0.0, t=(0.0, 0.0), s=1.0, sh=sh)

# Test rotation, scale, translation, shear
for a in range(-90, 90, 25):
for t1 in range(-10, 10, 5):
for s in [0.75, 0.98, 1.0, 1.1, 1.2]:
for sh in range(-15, 15, 5):
_test_transformation(a=a, t=(t1, t1), s=s, sh=sh)

def test_random_rotation(self):

with self.assertRaises(ValueError):
Expand All @@ -864,6 +947,47 @@ def test_random_rotation(self):
# Checking if RandomRotation can be printed as string
t.__repr__()

def test_random_affine(self):

with self.assertRaises(ValueError):
transforms.RandomAffine(-0.7)
transforms.RandomAffine([-0.7])
transforms.RandomAffine([-0.7, 0, 0.7])

transforms.RandomAffine([-90, 90], translate=2.0)
transforms.RandomAffine([-90, 90], translate=[-1.0, 1.0])
transforms.RandomAffine([-90, 90], translate=[-1.0, 0.0, 1.0])

transforms.RandomAffine([-90, 90], translate=[0.2, 0.2], scale=[0.0])
transforms.RandomAffine([-90, 90], translate=[0.2, 0.2], scale=[-1.0, 1.0])
transforms.RandomAffine([-90, 90], translate=[0.2, 0.2], scale=[0.5, -0.5])
transforms.RandomAffine([-90, 90], translate=[0.2, 0.2], scale=[0.5, 3.0, -0.5])

transforms.RandomAffine([-90, 90], translate=[0.2, 0.2], scale=[0.5, 0.5], shear=-7)
transforms.RandomAffine([-90, 90], translate=[0.2, 0.2], scale=[0.5, 0.5], shear=[-10])
transforms.RandomAffine([-90, 90], translate=[0.2, 0.2], scale=[0.5, 0.5], shear=[-10, 0, 10])

x = np.zeros((100, 100, 3), dtype=np.uint8)
img = F.to_pil_image(x)

t = transforms.RandomAffine(10, translate=[0.5, 0.3], scale=[0.7, 1.3], shear=[-10, 10])
for _ in range(100):
angle, translations, scale, shear = t.get_params(t.degrees, t.translate, t.scale, t.shear,
img_size=img.size)
assert -10 < angle < 10
assert -img.size[0] * 0.5 <= translations[0] <= img.size[0] * 0.5, \
"{} vs {}".format(translations[0], img.size[0] * 0.5)
assert -img.size[1] * 0.5 <= translations[1] <= img.size[1] * 0.5, \
"{} vs {}".format(translations[1], img.size[1] * 0.5)
assert 0.7 < scale < 1.3
assert -10 < shear < 10

# Checking if RandomAffine can be printed as string
t.__repr__()

t = transforms.RandomAffine(10, resample=Image.BILINEAR)
assert "Image.BILINEAR" in t.__repr__()

def test_to_grayscale(self):
"""Unit tests for grayscale transform"""

Expand Down Expand Up @@ -933,8 +1057,8 @@ def test_random_grayscale(self):
gray_pil_2 = transforms.RandomGrayscale(p=0.5)(x_pil)
gray_np_2 = np.array(gray_pil_2)
if np.array_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1]) and \
np.array_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2]) and \
np.array_equal(gray_np, gray_np_2[:, :, 0]):
np.array_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2]) and \
np.array_equal(gray_np, gray_np_2[:, :, 0]):
num_gray = num_gray + 1

p_value = stats.binom_test(num_gray, num_samples, p=0.5)
Expand Down
64 changes: 64 additions & 0 deletions torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,70 @@ def rotate(img, angle, resample=False, expand=False, center=None):
return img.rotate(angle, resample, expand, center)


def _get_inverse_affine_matrix(center, angle, translate, scale, shear):
# Helper method to compute inverse matrix for affine transformation

# As it is explained in PIL.Image.rotate
# We need compute INVERSE of affine transformation matrix: M = T * C * RSS * C^-1
# where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
# C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
# RSS is rotation with scale and shear matrix
# RSS(a, scale, shear) = [ cos(a)*scale -sin(a + shear)*scale 0]
# [ sin(a)*scale cos(a + shear)*scale 0]
# [ 0 0 1]
# Thus, the inverse is M^-1 = C * RSS^-1 * C^-1 * T^-1

angle = math.radians(angle)
shear = math.radians(shear)
scale = 1.0 / scale

# Inverted rotation matrix with scale and shear
d = math.cos(angle + shear) * math.cos(angle) + math.sin(angle + shear) * math.sin(angle)
matrix = [
math.cos(angle + shear), math.sin(angle + shear), 0,
-math.sin(angle), math.cos(angle), 0
]
matrix = [scale / d * m for m in matrix]

# Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
matrix[2] += matrix[0] * (-center[0] - translate[0]) + matrix[1] * (-center[1] - translate[1])
matrix[5] += matrix[3] * (-center[0] - translate[0]) + matrix[4] * (-center[1] - translate[1])

# Apply center translation: C * RSS^-1 * C^-1 * T^-1
matrix[2] += center[0]
matrix[5] += center[1]
return matrix


def affine(img, angle, translate, scale, shear, resample=0, fillcolor=None):
"""Apply affine transformation on the image keeping image center invariant

Args:
img (PIL Image): PIL Image to be rotated.
angle ({float, int}): rotation angle in degrees between -180 and 180, clockwise direction.
translate (list or tuple of integers): horizontal and vertical translations (post-rotation translation)
scale (float): overall scale
shear (float): shear angle value in degrees between -180 to 180, clockwise direction.
resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional):
An optional resampling filter.
See http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#filters
If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST.
fillcolor (int): Optional fill color for the area outside the transform in the output image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
"Argument translate should be a list or tuple of length 2"

assert scale > 0.0, "Argument scale should be positive"

output_size = img.size
center = (img.size[0] * 0.5 + 0.5, img.size[1] * 0.5 + 0.5)
matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
return img.transform(output_size, Image.AFFINE, matrix, resample, fillcolor=fillcolor)


def to_grayscale(img, num_output_channels=1):
"""Convert image to grayscale version of image.

Expand Down
122 changes: 121 additions & 1 deletion torchvision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
__all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad",
"Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip",
"RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation",
"ColorJitter", "RandomRotation", "Grayscale", "RandomGrayscale"]
"ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale"]

_pil_interpolation_to_str = {
Image.NEAREST: 'PIL.Image.NEAREST',
Expand Down Expand Up @@ -808,6 +808,126 @@ def __repr__(self):
return format_string


class RandomAffine(object):
"""Random affine transformation of the image keeping center invariant

Args:
degrees (sequence or float or int): Range of degrees to select from.
If degrees is a number instead of sequence like (min, max), the range of degrees
will be (-degrees, +degrees). Set to 0 to desactivate rotations.
translate (tuple, optional): tuple of maximum absolute fraction for horizontal
and vertical translations. For example translate=(a, b), then horizontal shift
is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is
randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default.
scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is
randomly sampled from the range a <= scale <= b. Will keep original scale by default.
shear (sequence or float or int, optional): Range of degrees to select from.
If degrees is a number instead of sequence like (min, max), the range of degrees
will be (-degrees, +degrees). Will not apply shear by default
resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional):
An optional resampling filter.
See http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#filters
If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST.
fillcolor (int): Optional fill color for the area outside the transform in the output image.
"""

def __init__(self, degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0):
if isinstance(degrees, numbers.Number):
if degrees < 0:
raise ValueError("If degrees is a single number, it must be positive.")
self.degrees = (-degrees, degrees)
else:
assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \
"degrees should be a list or tuple and it must be of length 2."
self.degrees = degrees

if translate is not None:
assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
"translate should be a list or tuple and it must be of length 2."
for t in translate:
if not (0.0 <= t <= 1.0):
raise ValueError("translation values should be between 0 and 1")
self.translate = translate

if scale is not None:
assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
"scale should be a list or tuple and it must be of length 2."
for s in scale:
if s <= 0:
raise ValueError("scale values should be positive")
self.scale = scale

if shear is not None:
if isinstance(shear, numbers.Number):
if shear < 0:
raise ValueError("If shear is a single number, it must be positive.")
self.shear = (-shear, shear)
else:
assert isinstance(shear, (tuple, list)) and len(shear) == 2, \
"shear should be a list or tuple and it must be of length 2."
self.shear = shear
else:
self.shear = shear

self.resample = resample
self.fillcolor = fillcolor

@staticmethod
def get_params(degrees, translate, scale_ranges, shears, img_size):
"""Get parameters for affine transformation

Returns:
sequence: params to be passed to the affine transformation
"""
angle = random.uniform(degrees[0], degrees[1])
if translate is not None:
max_dx = translate[0] * img_size[0]
max_dy = translate[1] * img_size[1]
translations = (np.round(random.uniform(-max_dx, max_dx)),
np.round(random.uniform(-max_dy, max_dy)))
else:
translations = (0, 0)

if scale_ranges is not None:
scale = random.uniform(scale_ranges[0], scale_ranges[1])
else:
scale = 1.0

if shears is not None:
shear = random.uniform(shears[0], shears[1])
else:
shear = 0.0

return angle, translations, scale, shear

def __call__(self, img):
"""
img (PIL Image): Image to be transformed.

Returns:
PIL Image: Affine transformed image.
"""
ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img.size)
return F.affine(img, *ret, resample=self.resample, fillcolor=self.fillcolor)

def __repr__(self):
s = '{name}(degrees={degrees}'
if self.translate is not None:
s += ', translate={translate}'
if self.scale is not None:
s += ', scale={scale}'
if self.shear is not None:
s += ', shear={shear}'
if self.resample > 0:
s += ', resample={resample}'

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

if self.fillcolor != 0:
s += ', fillcolor={fillcolor}'
s += ')'
d = dict(self.__dict__)
d['resample'] = _pil_interpolation_to_str[d['resample']]
return s.format(name=self.__class__.__name__, **d)


class Grayscale(object):
"""Convert image to grayscale.

Expand Down