Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename InterpolationModes to InterpolationMode #3055

Merged
merged 1 commit into from
Nov 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions test/test_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
import torchvision.transforms.functional_tensor as F_t
import torchvision.transforms.functional_pil as F_pil
import torchvision.transforms.functional as F
from torchvision.transforms import InterpolationModes
from torchvision.transforms import InterpolationMode

from common_utils import TransformsTester


NEAREST, BILINEAR, BICUBIC = InterpolationModes.NEAREST, InterpolationModes.BILINEAR, InterpolationModes.BICUBIC
NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC


class Tester(TransformsTester):
Expand Down Expand Up @@ -419,7 +419,7 @@ def test_resize(self):
)

# assert changed type warning
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationModes"):
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
res1 = F.resize(tensor, size=32, interpolation=2)
res2 = F.resize(tensor, size=32, interpolation=BILINEAR)
self.assertTrue(res1.equal(res2))
Expand Down Expand Up @@ -626,7 +626,7 @@ def test_affine(self):
self.assertTrue(res1.equal(res2))

# assert changed type warning
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationModes"):
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
res1 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=2)
res2 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=BILINEAR)
self.assertTrue(res1.equal(res2))
Expand Down Expand Up @@ -714,7 +714,7 @@ def test_rotate(self):
self.assertTrue(res1.equal(res2))

# assert changed type warning
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationModes"):
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
res1 = F.rotate(tensor, 45, interpolation=2)
res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
self.assertTrue(res1.equal(res2))
Expand Down Expand Up @@ -788,7 +788,7 @@ def test_perspective(self):
# assert changed type warning
spoints = [[0, 0], [33, 0], [33, 25], [0, 25]]
epoints = [[3, 2], [32, 3], [30, 24], [2, 25]]
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationModes"):
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
res1 = F.perspective(tensor, startpoints=spoints, endpoints=epoints, interpolation=2)
res2 = F.perspective(tensor, startpoints=spoints, endpoints=epoints, interpolation=BILINEAR)
self.assertTrue(res1.equal(res2))
Expand Down
14 changes: 7 additions & 7 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1500,12 +1500,12 @@ def test_random_rotation(self):
# assert deprecation warning and non-BC
with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"):
t = transforms.RandomRotation((-10, 10), resample=2)
self.assertEqual(t.interpolation, transforms.InterpolationModes.BILINEAR)
self.assertEqual(t.interpolation, transforms.InterpolationMode.BILINEAR)

# assert changed type warning
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationModes"):
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
t = transforms.RandomRotation((-10, 10), interpolation=2)
self.assertEqual(t.interpolation, transforms.InterpolationModes.BILINEAR)
self.assertEqual(t.interpolation, transforms.InterpolationMode.BILINEAR)

def test_random_affine(self):

Expand Down Expand Up @@ -1547,22 +1547,22 @@ def test_random_affine(self):
# Checking if RandomAffine can be printed as string
t.__repr__()

t = transforms.RandomAffine(10, interpolation=transforms.InterpolationModes.BILINEAR)
t = transforms.RandomAffine(10, interpolation=transforms.InterpolationMode.BILINEAR)
self.assertIn("bilinear", t.__repr__())

# assert deprecation warning and non-BC
with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"):
t = transforms.RandomAffine(10, resample=2)
self.assertEqual(t.interpolation, transforms.InterpolationModes.BILINEAR)
self.assertEqual(t.interpolation, transforms.InterpolationMode.BILINEAR)

with self.assertWarnsRegex(UserWarning, r"Argument fillcolor is deprecated and will be removed"):
t = transforms.RandomAffine(10, fillcolor=10)
self.assertEqual(t.fill, 10)

# assert changed type warning
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationModes"):
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
t = transforms.RandomAffine(10, interpolation=2)
self.assertEqual(t.interpolation, transforms.InterpolationModes.BILINEAR)
self.assertEqual(t.interpolation, transforms.InterpolationMode.BILINEAR)

def test_to_grayscale(self):
"""Unit tests for grayscale transform"""
Expand Down
4 changes: 2 additions & 2 deletions test/test_transforms_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch
from torchvision import transforms as T
from torchvision.transforms import functional as F
from torchvision.transforms import InterpolationModes
from torchvision.transforms import InterpolationMode

import numpy as np

Expand All @@ -11,7 +11,7 @@
from common_utils import TransformsTester, get_tmp_dir, int_dtypes, float_dtypes


NEAREST, BILINEAR, BICUBIC = InterpolationModes.NEAREST, InterpolationModes.BILINEAR, InterpolationModes.BICUBIC
NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC


class Tester(TransformsTester):
Expand Down
106 changes: 53 additions & 53 deletions torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from . import functional_tensor as F_t


class InterpolationModes(Enum):
class InterpolationMode(Enum):
"""Interpolation modes
"""
NEAREST = "nearest"
Expand All @@ -33,26 +33,26 @@ class InterpolationModes(Enum):


# TODO: Once torchscript supports Enums with staticmethod
# this can be put into InterpolationModes as staticmethod
def _interpolation_modes_from_int(i: int) -> InterpolationModes:
# this can be put into InterpolationMode as staticmethod
def _interpolation_modes_from_int(i: int) -> InterpolationMode:
inverse_modes_mapping = {
0: InterpolationModes.NEAREST,
2: InterpolationModes.BILINEAR,
3: InterpolationModes.BICUBIC,
4: InterpolationModes.BOX,
5: InterpolationModes.HAMMING,
1: InterpolationModes.LANCZOS,
0: InterpolationMode.NEAREST,
2: InterpolationMode.BILINEAR,
3: InterpolationMode.BICUBIC,
4: InterpolationMode.BOX,
5: InterpolationMode.HAMMING,
1: InterpolationMode.LANCZOS,
}
return inverse_modes_mapping[i]


pil_modes_mapping = {
InterpolationModes.NEAREST: 0,
InterpolationModes.BILINEAR: 2,
InterpolationModes.BICUBIC: 3,
InterpolationModes.BOX: 4,
InterpolationModes.HAMMING: 5,
InterpolationModes.LANCZOS: 1,
InterpolationMode.NEAREST: 0,
InterpolationMode.BILINEAR: 2,
InterpolationMode.BICUBIC: 3,
InterpolationMode.BOX: 4,
InterpolationMode.HAMMING: 5,
InterpolationMode.LANCZOS: 1,
}

_is_pil_image = F_pil._is_pil_image
Expand Down Expand Up @@ -329,7 +329,7 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
return tensor


def resize(img: Tensor, size: List[int], interpolation: InterpolationModes = InterpolationModes.BILINEAR) -> Tensor:
def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR) -> Tensor:
r"""Resize the input image to the given size.
The image can be a PIL Image or a torch Tensor, in which case it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
Expand All @@ -343,10 +343,10 @@ def resize(img: Tensor, size: List[int], interpolation: InterpolationModes = Int
:math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`.
In torchscript mode size as single int is not supported, use a tuple or
list of length 1: ``[size, ]``.
interpolation (InterpolationModes): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationModes`.
Default is ``InterpolationModes.BILINEAR``. If input is Tensor, only ``InterpolationModes.NEAREST``,
``InterpolationModes.BILINEAR`` and ``InterpolationModes.BICUBIC`` are supported.
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`.
Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``,
``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.

Returns:
Expand All @@ -355,13 +355,13 @@ def resize(img: Tensor, size: List[int], interpolation: InterpolationModes = Int
# Backward compatibility with integer value
if isinstance(interpolation, int):
warnings.warn(
"Argument interpolation should be of type InterpolationModes instead of int. "
"Please, use InterpolationModes enum."
"Argument interpolation should be of type InterpolationMode instead of int. "
"Please, use InterpolationMode enum."
)
interpolation = _interpolation_modes_from_int(interpolation)

if not isinstance(interpolation, InterpolationModes):
raise TypeError("Argument interpolation should be a InterpolationModes")
if not isinstance(interpolation, InterpolationMode):
raise TypeError("Argument interpolation should be a InterpolationMode")

if not isinstance(img, torch.Tensor):
pil_interpolation = pil_modes_mapping[interpolation]
Expand Down Expand Up @@ -475,7 +475,7 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor:

def resized_crop(
img: Tensor, top: int, left: int, height: int, width: int, size: List[int],
interpolation: InterpolationModes = InterpolationModes.BILINEAR
interpolation: InterpolationMode = InterpolationMode.BILINEAR
) -> Tensor:
"""Crop the given image and resize it to desired size.
The image can be a PIL Image or a Tensor, in which case it is expected
Expand All @@ -490,10 +490,10 @@ def resized_crop(
height (int): Height of the crop box.
width (int): Width of the crop box.
size (sequence or int): Desired output size. Same semantics as ``resize``.
interpolation (InterpolationModes): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationModes`.
Default is ``InterpolationModes.BILINEAR``. If input is Tensor, only ``InterpolationModes.NEAREST``,
``InterpolationModes.BILINEAR`` and ``InterpolationModes.BICUBIC`` are supported.
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`.
Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``,
``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.

Returns:
Expand Down Expand Up @@ -556,7 +556,7 @@ def perspective(
img: Tensor,
startpoints: List[List[int]],
endpoints: List[List[int]],
interpolation: InterpolationModes = InterpolationModes.BILINEAR,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[int] = None
) -> Tensor:
"""Perform perspective transform of the given image.
Expand All @@ -569,9 +569,9 @@ def perspective(
``[top-left, top-right, bottom-right, bottom-left]`` of the original image.
endpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
``[top-left, top-right, bottom-right, bottom-left]`` of the transformed image.
interpolation (InterpolationModes): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationModes`. Default is ``InterpolationModes.BILINEAR``.
If input is Tensor, only ``InterpolationModes.NEAREST``, ``InterpolationModes.BILINEAR`` are supported.
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
fill (n-tuple or int or float): Pixel fill value for area outside the rotated
image. If int or float, the value is used for all bands respectively.
Expand All @@ -587,13 +587,13 @@ def perspective(
# Backward compatibility with integer value
if isinstance(interpolation, int):
warnings.warn(
"Argument interpolation should be of type InterpolationModes instead of int. "
"Please, use InterpolationModes enum."
"Argument interpolation should be of type InterpolationMode instead of int. "
"Please, use InterpolationMode enum."
)
interpolation = _interpolation_modes_from_int(interpolation)

if not isinstance(interpolation, InterpolationModes):
raise TypeError("Argument interpolation should be a InterpolationModes")
if not isinstance(interpolation, InterpolationMode):
raise TypeError("Argument interpolation should be a InterpolationMode")

if not isinstance(img, torch.Tensor):
pil_interpolation = pil_modes_mapping[interpolation]
Expand Down Expand Up @@ -869,7 +869,7 @@ def _get_inverse_affine_matrix(


def rotate(
img: Tensor, angle: float, interpolation: InterpolationModes = InterpolationModes.NEAREST,
img: Tensor, angle: float, interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False, center: Optional[List[int]] = None,
fill: Optional[int] = None, resample: Optional[int] = None
) -> Tensor:
Expand All @@ -880,9 +880,9 @@ def rotate(
Args:
img (PIL Image or Tensor): image to be rotated.
angle (float or int): rotation angle value in degrees, counter-clockwise.
interpolation (InterpolationModes): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationModes`. Default is ``InterpolationModes.NEAREST``.
If input is Tensor, only ``InterpolationModes.NEAREST``, ``InterpolationModes.BILINEAR`` are supported.
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
expand (bool, optional): Optional expansion flag.
If true, expands the output image to make it large enough to hold the entire rotated image.
Expand Down Expand Up @@ -913,8 +913,8 @@ def rotate(
# Backward compatibility with integer value
if isinstance(interpolation, int):
warnings.warn(
"Argument interpolation should be of type InterpolationModes instead of int. "
"Please, use InterpolationModes enum."
"Argument interpolation should be of type InterpolationMode instead of int. "
"Please, use InterpolationMode enum."
)
interpolation = _interpolation_modes_from_int(interpolation)

Expand All @@ -924,8 +924,8 @@ def rotate(
if center is not None and not isinstance(center, (list, tuple)):
raise TypeError("Argument center should be a sequence")

if not isinstance(interpolation, InterpolationModes):
raise TypeError("Argument interpolation should be a InterpolationModes")
if not isinstance(interpolation, InterpolationMode):
raise TypeError("Argument interpolation should be a InterpolationMode")

if not isinstance(img, torch.Tensor):
pil_interpolation = pil_modes_mapping[interpolation]
Expand All @@ -945,7 +945,7 @@ def rotate(

def affine(
img: Tensor, angle: float, translate: List[int], scale: float, shear: List[float],
interpolation: InterpolationModes = InterpolationModes.NEAREST, fill: Optional[int] = None,
interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[int] = None,
resample: Optional[int] = None, fillcolor: Optional[int] = None
) -> Tensor:
"""Apply affine transformation on the image keeping image center invariant.
Expand All @@ -960,9 +960,9 @@ def affine(
shear (float or tuple or list): shear angle value in degrees between -180 to 180, clockwise direction.
If a tuple of list is specified, the first value corresponds to a shear parallel to the x axis, while
the second value corresponds to a shear parallel to the y axis.
interpolation (InterpolationModes): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationModes`. Default is ``InterpolationModes.NEAREST``.
If input is Tensor, only ``InterpolationModes.NEAREST``, ``InterpolationModes.BILINEAR`` are supported.
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
fill (int): Optional fill color for the area outside the transform in the output image (Pillow>=5.0.0).
This option is not supported for Tensor input. Fill value for the area outside the transform in the output
Expand All @@ -984,8 +984,8 @@ def affine(
# Backward compatibility with integer value
if isinstance(interpolation, int):
warnings.warn(
"Argument interpolation should be of type InterpolationModes instead of int. "
"Please, use InterpolationModes enum."
"Argument interpolation should be of type InterpolationMode instead of int. "
"Please, use InterpolationMode enum."
)
interpolation = _interpolation_modes_from_int(interpolation)

Expand All @@ -1010,8 +1010,8 @@ def affine(
if not isinstance(shear, (numbers.Number, (list, tuple))):
raise TypeError("Shear should be either a single value or a sequence of two values")

if not isinstance(interpolation, InterpolationModes):
raise TypeError("Argument interpolation should be a InterpolationModes")
if not isinstance(interpolation, InterpolationMode):
raise TypeError("Argument interpolation should be a InterpolationMode")

if isinstance(angle, int):
angle = float(angle)
Expand Down
Loading