Add AffineTransformation #793

Merged
34 commits merged on Mar 25, 2019
Changes from 5 commits
154cadb
Merge pull request #1 from pytorch/master
ekagra-ranjan Feb 14, 2019
e964780
Merge pull request #2 from pytorch/master
ekagra-ranjan Mar 9, 2019
103b25a
Merge pull request #3 from pytorch/master
ekagra-ranjan Mar 11, 2019
c99eb98
Add Affinetransformation
ekagra-ranjan Mar 11, 2019
48d9002
Add test
ekagra-ranjan Mar 11, 2019
1d057de
Add zero mean_vector in LinearTransformation and improved docs
ekagra-ranjan Mar 11, 2019
f538d46
Merge pull request #4 from pytorch/master
ekagra-ranjan Mar 11, 2019
4d24d94
update
ekagra-ranjan Mar 11, 2019
be21da5
minor fix
ekagra-ranjan Mar 11, 2019
2c57a56
minor fix2
ekagra-ranjan Mar 11, 2019
c8d5711
fixed flake8
ekagra-ranjan Mar 11, 2019
9462789
fix flake8
ekagra-ranjan Mar 11, 2019
c257605
fixed transpose syntax
ekagra-ranjan Mar 11, 2019
defa1da
fixed shape of mean_vector in test
ekagra-ranjan Mar 11, 2019
9b8d07b
fixed test
ekagra-ranjan Mar 11, 2019
cf7afcd
print est cov and mean
ekagra-ranjan Mar 11, 2019
a1ee3d1
fixed flake8
ekagra-ranjan Mar 11, 2019
881eb41
debug
ekagra-ranjan Mar 11, 2019
f498aa5
reduce num_samples
ekagra-ranjan Mar 11, 2019
606f426
debug
ekagra-ranjan Mar 11, 2019
6b5bab0
fixed num_features
ekagra-ranjan Mar 11, 2019
0238549
fixed rtol for cov
ekagra-ranjan Mar 11, 2019
12e6685
fix __repr__
ekagra-ranjan Mar 11, 2019
121c1cf
Update transforms.py
ekagra-ranjan Mar 11, 2019
5b7a445
Update test_transforms.py
ekagra-ranjan Mar 11, 2019
c2d7627
Update transforms.py
ekagra-ranjan Mar 11, 2019
f524dfa
fix flake8
ekagra-ranjan Mar 12, 2019
18bd5fc
Update transforms.py
ekagra-ranjan Mar 12, 2019
e3735dc
Update transforms.py
ekagra-ranjan Mar 12, 2019
d5dcf4a
Update transforms.py
ekagra-ranjan Mar 12, 2019
c727945
Update transforms.py
ekagra-ranjan Mar 12, 2019
9907dee
Changed dim of mean_vector to 1D, doc and removed .numpy () from form…
ekagra-ranjan Mar 25, 2019
04c70ca
Restore test_linear_transformation()
ekagra-ranjan Mar 25, 2019
5883ef0
Update test_transforms.py
ekagra-ranjan Mar 25, 2019
13 changes: 8 additions & 5 deletions test/test_transforms.py
@@ -952,7 +952,7 @@ def test_color_jitter(self):
# Checking if ColorJitter can be printed as string
color_jitter.__repr__()

def test_linear_transformation(self):
def test_affine_transformation(self):
x = torch.randn(250, 10, 10, 3)
flat_x = x.view(x.size(0), x.size(1) * x.size(2) * x.size(3))
# compute principal components
@@ -961,17 +961,20 @@ def test_linear_transformation(self):
zca_epsilon = 1e-10 # avoid division by 0
d = torch.Tensor(np.diag(1. / np.sqrt(s + zca_epsilon)))
u = torch.Tensor(u)
principal_components = torch.mm(torch.mm(u, d), u.t())
mean_vector = torch.sum(flat_x, dim=0) / flat_x.size(0)
# initialize whitening matrix
whitening = transforms.LinearTransformation(principal_components)
whitening = transforms.AffineTransformation(principal_components, mean_vector)
# pass first vector
xwhite = whitening(x[0].view(10, 10, 3))
# estimate covariance
xwhite = xwhite.view(1, 300).numpy()
cov = np.dot(xwhite, xwhite.T) / x.size(0)
assert np.allclose(cov, np.identity(1), rtol=1e-3)
mean = np.sum(xwhite) / x.size(0)
assert np.allclose(cov, np.identity(1), rtol=1e-3), "cov not close to 1"
assert np.allclose(mean, 0, rtol=1e-3), "mean not close to 0"

# Checking if LinearTransformation can be printed as string
# Checking if AffineTransformation can be printed as string
whitening.__repr__()

def test_rotate(self):
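As an aside (not part of the diff), the ZCA whitening recipe that the test above relies on can be sketched in plain NumPy. The names `mean_vector`, `zca_epsilon`, and `principal_components` mirror the test; the sample count and feature dimension here are illustrative only:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((500, 30))  # 500 flattened samples, D = 30 features

# zero-center the data and estimate the covariance matrix, as in the test above
mean_vector = X.mean(axis=0)
Xc = X - mean_vector
cov = Xc.T @ Xc / Xc.shape[0]

# SVD of the covariance gives the ZCA matrix: u @ diag(1/sqrt(s)) @ u.T
u, s, _ = np.linalg.svd(cov)
zca_epsilon = 1e-10  # avoid division by zero, as in the test
principal_components = u @ np.diag(1.0 / np.sqrt(s + zca_epsilon)) @ u.T

# the affine transform: subtract the mean, multiply by the matrix
Xw = Xc @ principal_components

# whitened data has approximately identity covariance and zero mean
cov_w = Xw.T @ Xw / Xw.shape[0]
```

With more samples than features the covariance is full rank, so `cov_w` comes out close to the identity; the test in the diff checks the same property on torch tensors.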
40 changes: 28 additions & 12 deletions torchvision/transforms/transforms.py
@@ -27,7 +27,7 @@
__all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad",
"Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip",
"RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation",
"ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale"]
"AffineTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale"]

_pil_interpolation_to_str = {
Image.NEAREST: 'PIL.Image.NEAREST',
@@ -702,52 +702,68 @@ def __repr__(self):
return self.__class__.__name__ + '(size={0}, vertical_flip={1})'.format(self.size, self.vertical_flip)


class LinearTransformation(object):
"""Transform a tensor image with a square transformation matrix computed
class AffineTransformation(object):
"""Transform a tensor image with a square transformation matrix and a mean_vector computed
offline.

Given transformation_matrix, will flatten the torch.*Tensor, compute the dot
product with the transformation matrix and reshape the tensor to its
Given transformation_matrix and mean_vector, will flatten the torch.*Tensor and
subtract mean_vector from it which is then followed by computing the dot
product with the transformation matrix and then reshaping the tensor to its
original shape.

Applications:
- whitening: zero-center the data, compute the data covariance matrix
[D x D] with np.dot(X.T, X), perform SVD on this matrix and
pass it as transformation_matrix.

Args:
transformation_matrix (Tensor): tensor [D x D], D = C x H x W
mean_vector (Tensor): tensor [D], D = C x H x W
"""

def __init__(self, transformation_matrix):
def __init__(self, transformation_matrix, mean_vector):
if transformation_matrix.size(0) != transformation_matrix.size(1):
raise ValueError("transformation_matrix should be square. Got " +
"[{} x {}] rectangular matrix.".format(*transformation_matrix.size()))

if mean_vector.size(0) != transformation_matrix.size(0):
raise ValueError("mean_vector should have the same length {}".format(mean_vector.size(0)) +
" as the transformation_matrix")

self.transformation_matrix = transformation_matrix
self.mean_vector = mean_vector

def __call__(self, tensor):
"""
Args:
tensor (Tensor): Tensor image of size (C, H, W) to be whitened.

Returns:
Tensor: Transformed image.
"""
if tensor.size(0) * tensor.size(1) * tensor.size(2) != self.transformation_matrix.size(0):
raise ValueError("tensor and transformation matrix have incompatible shape. " +
"[{} x {} x {}] != ".format(*tensor.size()) +
"{}".format(self.transformation_matrix.size(0)))
flat_tensor = tensor.view(1, -1)
flat_tensor = tensor.view(1, -1) - self.mean_vector
transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix)
tensor = transformed_tensor.view(tensor.size())
return tensor

def __repr__(self):
format_string = self.__class__.__name__ + '('
format_string = self.__class__.__name__ + '(transformation_matrix='
format_string += str(self.transformation_matrix.numpy().tolist())
format_string += (', mean_vector=' + str(self.mean_vector.numpy().tolist()) + ')')
return format_string


class LinearTransformation(AffineTransformation):
"""
Note: This transform is deprecated in favor of AffineTransformation.
"""

def __init__(self, transformation_matrix):
warnings.warn("The use of the transforms.LinearTransformation transform is deprecated, " +
"please use transforms.AffineTransformation instead.")
super(LinearTransformation, self).__init__(transformation_matrix)
Member:
I suppose you want to pass a mean_vector filled with zeros here?

Contributor Author (ekagra-ranjan):
What would you prefer for the dimension of mean_vector, 1xD or Dx1, given that transformation_matrix is DxD?

Contributor Author (ekagra-ranjan):
According to the implementation of LinearTransformation, I think it should be 1 x D.

I also feel that the doc for the whitening application of LinearTransformation is broken: only the first line of the paragraph renders inside the application block, and the rest renders outside it. How do we write a multi-line bullet point? Do we replace - with ---?

Member:
@ekagra-ranjan the dimension of the tensor doesn't matter: you can pass a D tensor and it will be fine, given that broadcasting will happen and the tensor is filled with all zeros anyway.

About the doc, I'm not sure how to properly format things. I'd try compiling the documentation locally and checking it visually.
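The broadcasting behaviour described in the discussion above can be checked with a short NumPy sketch (torch follows the same broadcasting rules for this case); all names below are illustrative, not from the PR:

```python
import numpy as np

D = 12
transformation_matrix = np.eye(D)   # identity, so only the centering matters
mean_vector = np.zeros(D)           # a 1-D vector of length D, as suggested

flat = np.arange(D, dtype=np.float64).reshape(1, D)  # like tensor.view(1, -1)

# (1, D) - (D,) broadcasts to (1, D); an all-zero mean_vector is a no-op
centered = flat - mean_vector
out = centered @ transformation_matrix

assert out.shape == (1, D)
```

This is why a deprecated LinearTransformation wrapper could simply forward a zero vector of length D: broadcasting makes the subtraction shape-compatible, and subtracting zeros leaves the data unchanged.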

class ColorJitter(object):
"""Randomly change the brightness, contrast and saturation of an image.
