Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BC-breaking] Introduced InterpolationModes and deprecated arguments: resample and fillcolor #2952

Merged
merged 16 commits into from
Nov 27, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions test/test_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,12 +438,12 @@ def test_resized_crop(self):

def _test_affine_identity_map(self, tensor, scripted_affine):
# 1) identity map
out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=0)

self.assertTrue(
tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
)
out_tensor = scripted_affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
out_tensor = scripted_affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=0)
self.assertTrue(
tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
)
Expand All @@ -461,13 +461,13 @@ def _test_affine_square_rotations(self, tensor, pil_img, scripted_affine):
]
for a, true_tensor in test_configs:
out_pil_img = F.affine(
pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0
pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=0
)
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))).to(self.device)

for fn in [F.affine, scripted_affine]:
out_tensor = fn(
tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0
tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=0
)
if true_tensor is not None:
self.assertTrue(
Expand Down Expand Up @@ -496,13 +496,13 @@ def _test_affine_rect_rotations(self, tensor, pil_img, scripted_affine):
for a in test_configs:

out_pil_img = F.affine(
pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0
pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=0
)
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))

for fn in [F.affine, scripted_affine]:
out_tensor = fn(
tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0
tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=0
).cpu()

if out_tensor.dtype != torch.uint8:
Expand All @@ -526,10 +526,10 @@ def _test_affine_translations(self, tensor, pil_img, scripted_affine):
]
for t in test_configs:

out_pil_img = F.affine(pil_img, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0)
out_pil_img = F.affine(pil_img, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], interpolation=0)

for fn in [F.affine, scripted_affine]:
out_tensor = fn(tensor, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0)
out_tensor = fn(tensor, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], interpolation=0)

if out_tensor.dtype != torch.uint8:
out_tensor = out_tensor.to(torch.uint8)
Expand All @@ -552,11 +552,11 @@ def _test_affine_all_ops(self, tensor, pil_img, scripted_affine):
]
for r in [0, ]:
for a, t, s, sh in test_configs:
out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, resample=r)
out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, interpolation=r)
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))

for fn in [F.affine, scripted_affine]:
out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, resample=r).cpu()
out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, interpolation=r).cpu()

if out_tensor.dtype != torch.uint8:
out_tensor = out_tensor.to(torch.uint8)
Expand Down Expand Up @@ -613,10 +613,10 @@ def _test_rotate_all_options(self, tensor, pil_img, scripted_rotate, centers):
for e in [True, False]:
for c in centers:

out_pil_img = F.rotate(pil_img, angle=a, resample=r, expand=e, center=c)
out_pil_img = F.rotate(pil_img, angle=a, interpolation=r, expand=e, center=c)
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
for fn in [F.rotate, scripted_rotate]:
out_tensor = fn(tensor, angle=a, resample=r, expand=e, center=c).cpu()
out_tensor = fn(tensor, angle=a, interpolation=r, expand=e, center=c).cpu()

if out_tensor.dtype != torch.uint8:
out_tensor = out_tensor.to(torch.uint8)
Expand Down Expand Up @@ -673,7 +673,7 @@ def test_rotate(self):

center = (20, 22)
self._test_fn_on_batch(
batch_tensors, F.rotate, angle=32, resample=0, expand=True, center=center
batch_tensors, F.rotate, angle=32, interpolation=0, expand=True, center=center
)

def _test_perspective(self, tensor, pil_img, scripted_transform, test_configs):
Expand Down
2 changes: 1 addition & 1 deletion test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1529,7 +1529,7 @@ def test_random_affine(self):
# Checking if RandomAffine can be printed as string
t.__repr__()

t = transforms.RandomAffine(10, resample=Image.BILINEAR)
t = transforms.RandomAffine(10, interpolation=Image.BILINEAR)
self.assertIn("Image.BILINEAR", t.__repr__())

def test_to_grayscale(self):
Expand Down
4 changes: 2 additions & 2 deletions test/test_transforms_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def test_random_affine(self):
for interpolation in [NEAREST, BILINEAR]:
transform = T.RandomAffine(
degrees=degrees, translate=translate,
scale=scale, shear=shear, resample=interpolation
scale=scale, shear=shear, interpolation=interpolation
)
s_transform = torch.jit.script(transform)

Expand All @@ -368,7 +368,7 @@ def test_random_rotate(self):
for degrees in [45, 35.0, (-45, 45), [-90.0, 90.0]]:
for interpolation in [NEAREST, BILINEAR]:
transform = T.RandomRotation(
degrees=degrees, resample=interpolation, expand=expand, center=center
degrees=degrees, interpolation=interpolation, expand=expand, center=center
)
s_transform = torch.jit.script(transform)

Expand Down
33 changes: 23 additions & 10 deletions torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,8 +793,8 @@ def _get_inverse_affine_matrix(


def rotate(
img: Tensor, angle: float, resample: int = 0, expand: bool = False,
center: Optional[List[int]] = None, fill: Optional[int] = None
img: Tensor, angle: float, interpolation: int = 0, expand: bool = False,
center: Optional[List[int]] = None, fill: Optional[int] = None, resample: Optional[int] = None
) -> Tensor:
"""Rotate the image by angle.
The image can be a PIL Image or a Tensor, in which case it is expected
Expand All @@ -803,7 +803,7 @@ def rotate(
Args:
img (PIL Image or Tensor): image to be rotated.
angle (float or int): rotation angle value in degrees, counter-clockwise.
resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional):
interpolation (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional):
An optional resampling filter. See `filters`_ for more information.
If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
expand (bool, optional): Optional expansion flag.
Expand All @@ -817,21 +817,25 @@ def rotate(
Defaults to 0 for all bands. This option is only available for ``pillow>=5.2.0``.
This option is not supported for Tensor input. Fill value for the area outside the transform in the output
image is always 0.
resample (int, optional): deprecated argument, please use `arg`:interpolation: instead.

Returns:
PIL Image or Tensor: Rotated image.

.. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters

"""
if resample is not None:
warnings.warn("Argument resample is deprecated. Please, use interpolation instead")
vfdev-5 marked this conversation as resolved.
Show resolved Hide resolved

if not isinstance(angle, (int, float)):
raise TypeError("Argument angle should be int or float")

if center is not None and not isinstance(center, (list, tuple)):
raise TypeError("Argument center should be a sequence")

if not isinstance(img, torch.Tensor):
return F_pil.rotate(img, angle=angle, resample=resample, expand=expand, center=center, fill=fill)
return F_pil.rotate(img, angle=angle, interpolation=interpolation, expand=expand, center=center, fill=fill)

center_f = [0.0, 0.0]
if center is not None:
Expand All @@ -842,12 +846,13 @@ def rotate(
# due to current incoherence of rotation angle direction between affine and rotate implementations
# we need to set -angle.
matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])
return F_t.rotate(img, matrix=matrix, resample=resample, expand=expand, fill=fill)
return F_t.rotate(img, matrix=matrix, interpolation=interpolation, expand=expand, fill=fill)


def affine(
img: Tensor, angle: float, translate: List[int], scale: float, shear: List[float],
resample: int = 0, fillcolor: Optional[int] = None
interpolation: int = 0, fill: Optional[int] = None, resample: Optional[int] = None,
fillcolor: Optional[int] = None
) -> Tensor:
"""Apply affine transformation on the image keeping image center invariant.
The image can be a PIL Image or a Tensor, in which case it is expected
Expand All @@ -861,17 +866,25 @@ def affine(
shear (float or tuple or list): shear angle value in degrees between -180 to 180, clockwise direction.
If a tuple of list is specified, the first value corresponds to a shear parallel to the x axis, while
the second value corresponds to a shear parallel to the y axis.
resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional):
interpolation (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional):
An optional resampling filter. See `filters`_ for more information.
If omitted, or if the image is PIL Image and has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
If input is Tensor, only ``PIL.Image.NEAREST`` and ``PIL.Image.BILINEAR`` are supported.
fillcolor (int): Optional fill color for the area outside the transform in the output image (Pillow>=5.0.0).
fill (int): Optional fill color for the area outside the transform in the output image (Pillow>=5.0.0).
This option is not supported for Tensor input. Fill value for the area outside the transform in the output
image is always 0.
fillcolor (tuple or int, optional): deprecated argument, please use `arg`:fill: instead.
resample (int, optional): deprecated argument, please use `arg`:interpolation: instead.

Returns:
PIL Image or Tensor: Transformed image.
"""
if resample is not None:
warnings.warn("Argument resample is deprecated. Please, use interpolation instead")

if fillcolor is not None:
warnings.warn("Argument fillcolor is deprecated. Please, use fill instead")
vfdev-5 marked this conversation as resolved.
Show resolved Hide resolved

if not isinstance(angle, (int, float)):
raise TypeError("Argument angle should be int or float")

Expand Down Expand Up @@ -913,11 +926,11 @@ def affine(
center = [img_size[0] * 0.5, img_size[1] * 0.5]
matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)

return F_pil.affine(img, matrix=matrix, resample=resample, fillcolor=fillcolor)
return F_pil.affine(img, matrix=matrix, interpolation=interpolation, fill=fill)

translate_f = [1.0 * t for t in translate]
matrix = _get_inverse_affine_matrix([0.0, 0.0], angle, translate_f, scale, shear)
return F_t.affine(img, matrix=matrix, resample=resample, fillcolor=fillcolor)
return F_t.affine(img, matrix=matrix, interpolation=interpolation, fill=fill)


@torch.jit.unused
Expand Down
16 changes: 8 additions & 8 deletions torchvision/transforms/functional_pil.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ def _parse_fill(fill, img, min_pil_version, name="fillcolor"):


@torch.jit.unused
def affine(img, matrix, resample=0, fillcolor=None):
def affine(img, matrix, interpolation=0, fill=None):
vfdev-5 marked this conversation as resolved.
Show resolved Hide resolved
"""PRIVATE METHOD. Apply affine transformation on the PIL Image keeping image center invariant.

.. warning::
Expand All @@ -485,11 +485,11 @@ def affine(img, matrix, resample=0, fillcolor=None):
Args:
img (PIL Image): image to be rotated.
matrix (list of floats): list of 6 float values representing inverse matrix for affine transformation.
resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional):
interpolation (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional):
An optional resampling filter.
See `filters`_ for more information.
If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
fillcolor (int): Optional fill color for the area outside the transform in the output image. (Pillow>=5.0.0)
fill (int): Optional fill color for the area outside the transform in the output image. (Pillow>=5.0.0)

Returns:
PIL Image: Transformed image.
Expand All @@ -498,12 +498,12 @@ def affine(img, matrix, resample=0, fillcolor=None):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

output_size = img.size
opts = _parse_fill(fillcolor, img, '5.0.0')
return img.transform(output_size, Image.AFFINE, matrix, resample, **opts)
opts = _parse_fill(fill, img, '5.0.0')
return img.transform(output_size, Image.AFFINE, matrix, interpolation, **opts)


@torch.jit.unused
def rotate(img, angle, resample=0, expand=False, center=None, fill=None):
def rotate(img, angle, interpolation=0, expand=False, center=None, fill=None):
"""PRIVATE METHOD. Rotate PIL image by angle.

.. warning::
Expand All @@ -514,7 +514,7 @@ def rotate(img, angle, resample=0, expand=False, center=None, fill=None):
Args:
img (PIL Image): image to be rotated.
angle (float or int): rotation angle value in degrees, counter-clockwise.
resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional):
interpolation (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional):
An optional resampling filter. See `filters`_ for more information.
If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
expand (bool, optional): Optional expansion flag.
Expand All @@ -538,7 +538,7 @@ def rotate(img, angle, resample=0, expand=False, center=None, fill=None):
raise TypeError("img should be PIL Image. Got {}".format(type(img)))

opts = _parse_fill(fill, img, '5.2.0')
return img.rotate(angle, resample, expand, center, **opts)
return img.rotate(angle, interpolation, expand, center, **opts)


@torch.jit.unused
Expand Down
Loading