Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AffineTransformation #793

Merged
merged 34 commits into from
Mar 25, 2019
Merged
Show file tree
Hide file tree
Changes from 31 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
154cadb
Merge pull request #1 from pytorch/master
ekagra-ranjan Feb 14, 2019
e964780
Merge pull request #2 from pytorch/master
ekagra-ranjan Mar 9, 2019
103b25a
Merge pull request #3 from pytorch/master
ekagra-ranjan Mar 11, 2019
c99eb98
Add Affinetransformation
ekagra-ranjan Mar 11, 2019
48d9002
Add test
ekagra-ranjan Mar 11, 2019
1d057de
Add zero mean_vector in LinearTransformation and improved docs
ekagra-ranjan Mar 11, 2019
f538d46
Merge pull request #4 from pytorch/master
ekagra-ranjan Mar 11, 2019
4d24d94
update
ekagra-ranjan Mar 11, 2019
be21da5
minor fix
ekagra-ranjan Mar 11, 2019
2c57a56
minor fix2
ekagra-ranjan Mar 11, 2019
c8d5711
fixed flake8
ekagra-ranjan Mar 11, 2019
9462789
fix flake8
ekagra-ranjan Mar 11, 2019
c257605
fixed transpose syntax
ekagra-ranjan Mar 11, 2019
defa1da
fixed shape of mean_vector in test
ekagra-ranjan Mar 11, 2019
9b8d07b
fixed test
ekagra-ranjan Mar 11, 2019
cf7afcd
print est cov and mean
ekagra-ranjan Mar 11, 2019
a1ee3d1
fixed flake8
ekagra-ranjan Mar 11, 2019
881eb41
debug
ekagra-ranjan Mar 11, 2019
f498aa5
reduce num_samples
ekagra-ranjan Mar 11, 2019
606f426
debug
ekagra-ranjan Mar 11, 2019
6b5bab0
fixed num_features
ekagra-ranjan Mar 11, 2019
0238549
fixed rtol for cov
ekagra-ranjan Mar 11, 2019
12e6685
fix __repr__
ekagra-ranjan Mar 11, 2019
121c1cf
Update transforms.py
ekagra-ranjan Mar 11, 2019
5b7a445
Update test_transforms.py
ekagra-ranjan Mar 11, 2019
c2d7627
Update transforms.py
ekagra-ranjan Mar 11, 2019
f524dfa
fix flake8
ekagra-ranjan Mar 12, 2019
18bd5fc
Update transforms.py
ekagra-ranjan Mar 12, 2019
e3735dc
Update transforms.py
ekagra-ranjan Mar 12, 2019
d5dcf4a
Update transforms.py
ekagra-ranjan Mar 12, 2019
c727945
Update transforms.py
ekagra-ranjan Mar 12, 2019
9907dee
Changed dim of mean_vector to 1D, doc and removed .numpy () from form…
ekagra-ranjan Mar 25, 2019
04c70ca
Restore test_linear_transformation()
ekagra-ranjan Mar 25, 2019
5883ef0
Update test_transforms.py
ekagra-ranjan Mar 25, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 19 additions & 11 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -952,8 +952,9 @@ def test_color_jitter(self):
# Checking if ColorJitter can be printed as string
color_jitter.__repr__()

def test_linear_transformation(self):
x = torch.randn(250, 10, 10, 3)
def test_affine_transformation(self):
num_samples = 1000
x = torch.randn(num_samples, 3, 10, 10)
flat_x = x.view(x.size(0), x.size(1) * x.size(2) * x.size(3))
# compute principal components
sigma = torch.mm(flat_x.t(), flat_x) / flat_x.size(0)
Expand All @@ -962,16 +963,23 @@ def test_linear_transformation(self):
d = torch.Tensor(np.diag(1. / np.sqrt(s + zca_epsilon)))
u = torch.Tensor(u)
principal_components = torch.mm(torch.mm(u, d), u.t())
mean_vector = (torch.sum(flat_x, dim=0) / flat_x.size(0)).view(1, -1)
# initialize whitening matrix
whitening = transforms.LinearTransformation(principal_components)
# pass first vector
xwhite = whitening(x[0].view(10, 10, 3))
# estimate covariance
xwhite = xwhite.view(1, 300).numpy()
cov = np.dot(xwhite, xwhite.T) / x.size(0)
assert np.allclose(cov, np.identity(1), rtol=1e-3)

# Checking if LinearTransformation can be printed as string
whitening = transforms.AffineTransformation(principal_components, mean_vector)
# estimate covariance and mean using weak law of large number
num_features = flat_x.size(1)
cov = 0.0
mean = 0.0
for i in x:
xwhite = whitening(i)
xwhite = xwhite.view(1, -1).numpy()
cov += np.dot(xwhite, xwhite.T) / num_features
mean += np.sum(xwhite) / num_features
# if rtol for std = 1e-3 then rtol for cov = 2e-3 as std**2 = cov
assert np.allclose(cov / num_samples, np.identity(1), rtol=2e-3), "cov not close to 1"
assert np.allclose(mean / num_samples, 0, rtol=1e-3), "mean not close to 0"

# Checking if AffineTransformation can be printed as string
whitening.__repr__()

def test_rotate(self):
Expand Down
42 changes: 30 additions & 12 deletions torchvision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
__all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad",
"Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip",
"RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation",
"ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale"]
"AffineTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale"]
ekagra-ranjan marked this conversation as resolved.
Show resolved Hide resolved

_pil_interpolation_to_str = {
Image.NEAREST: 'PIL.Image.NEAREST',
Expand Down Expand Up @@ -397,7 +397,7 @@ class RandomCrop(object):
respectively. If a sequence of length 2 is provided, it is used to
pad left/right, top/bottom borders, respectively.
pad_if_needed (boolean): It will pad the image if smaller than the
desired size to avoid raising an exception. Since cropping is done
desired size to avoid raising an exception. Since cropping is done
after padding, the padding seems to be done at a random offset.
fill: Pixel fill value for constant fill. Default is 0. If a tuple of
length 3, it is used to fill R, G, B channels respectively.
Expand Down Expand Up @@ -703,28 +703,34 @@ def __repr__(self):
return self.__class__.__name__ + '(size={0}, vertical_flip={1})'.format(self.size, self.vertical_flip)


class LinearTransformation(object):
"""Transform a tensor image with a square transformation matrix computed
class AffineTransformation(object):
"""Transform a tensor image with a square transformation matrix and a mean_vector computed
offline.

Given transformation_matrix, will flatten the torch.*Tensor, compute the dot
product with the transformation matrix and reshape the tensor to its
Given transformation_matrix and mean_vector, will flatten the torch.*Tensor and
subtract mean_vector from it which is then followed by computing the dot
product with the transformation matrix and then reshaping the tensor to its
original shape.

Applications:
- whitening: zero-center the data, compute the data covariance matrix
[D x D] with np.dot(X.T, X), perform SVD on this matrix and
pass it as transformation_matrix.

Args:
transformation_matrix (Tensor): tensor [D x D], D = C x H x W
mean_vector (Tensor): tensor [1 x D], D = C x H x W
"""

def __init__(self, transformation_matrix):
def __init__(self, transformation_matrix, mean_vector):
if transformation_matrix.size(0) != transformation_matrix.size(1):
raise ValueError("transformation_matrix should be square. Got " +
"[{} x {}] rectangular matrix.".format(*transformation_matrix.size()))

if mean_vector.size(1) != transformation_matrix.size(0):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is raise an error if the user tries to use LinearTransformation, because the mean vector is 1d.

I really think that you should just make mean_vector here be 1d as well, and change it in the documentation.

Also, can you keep the tests for the LinearTransformation?

raise ValueError("mean_vector should have the same length {}".format(mean_vector.size(1)) +
" as any one of the dimensions of the transformation_matrix [{} x {}]"
.format(transformation_matrix.size()))

self.transformation_matrix = transformation_matrix
self.mean_vector = mean_vector

def __call__(self, tensor):
"""
Expand All @@ -738,17 +744,29 @@ def __call__(self, tensor):
raise ValueError("tensor and transformation matrix have incompatible shape." +
"[{} x {} x {}] != ".format(*tensor.size()) +
"{}".format(self.transformation_matrix.size(0)))
flat_tensor = tensor.view(1, -1)
flat_tensor = tensor.view(1, -1) - self.mean_vector
transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix)
tensor = transformed_tensor.view(tensor.size())
return tensor

def __repr__(self):
format_string = self.__class__.__name__ + '('
format_string = self.__class__.__name__ + '(transformation_matrix='
format_string += (str(self.transformation_matrix.numpy().tolist()) + ')')
format_string += (", (mean_vector=" + str(self.mean_vector.numpy().tolist()) + ')')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need to convert to numpy here, tolist() also exists for torch tensors

return format_string


class LinearTransformation(AffineTransformation):
    """Deprecated alias of AffineTransformation with a zero mean vector.

    Note: This transform is deprecated in favor of AffineTransformation.

    Args:
        transformation_matrix (Tensor): square tensor [D x D], D = C x H x W.
    """

    def __init__(self, transformation_matrix):
        warnings.warn("The use of the transforms.LinearTransformation transform is deprecated, " +
                      "please use transforms.AffineTransformation instead.")
        # Pass the zero mean as a [1 x D] row vector: the parent __init__
        # validates mean_vector.size(1), so the original 1-D
        # torch.zeros_like(transformation_matrix[0]) raised IndexError
        # before the deprecation path could ever be used.  new_zeros keeps
        # the matrix's dtype and device.
        super(LinearTransformation, self).__init__(
            transformation_matrix,
            transformation_matrix.new_zeros(1, transformation_matrix.size(0)))


class ColorJitter(object):
"""Randomly change the brightness, contrast and saturation of an image.

Expand Down