Unified input for resize op #2394

Merged
merged 4 commits on Jul 6, 2020
Changes from 3 commits
61 changes: 55 additions & 6 deletions test/test_functional_tensor.py
@@ -1,14 +1,17 @@
import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional_tensor as F_t
import torchvision.transforms.functional_pil as F_pil
import torchvision.transforms.functional as F
import numpy as np
import unittest
import random
import colorsys

from PIL import Image
from PIL.Image import NEAREST, BILINEAR, BICUBIC

import numpy as np

import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional_tensor as F_t
import torchvision.transforms.functional_pil as F_pil
import torchvision.transforms.functional as F


class Tester(unittest.TestCase):
@@ -22,6 +25,14 @@ def compareTensorToPIL(self, tensor, pil_image, msg=None):
pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1)))
self.assertTrue(tensor.equal(pil_tensor), msg)

def approxEqualTensorToPIL(self, tensor, pil_image, tol=1e-5, msg=None):
pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1))).to(tensor)
mae = torch.abs(tensor - pil_tensor).mean().item()
self.assertTrue(
mae < tol,
msg="{}: mae={}, tol={}: \n{}\nvs\n{}".format(msg, mae, tol, tensor[0, :10, :10], pil_tensor[0, :10, :10])
)

def test_vflip(self):
script_vflip = torch.jit.script(F_t.vflip)
img_tensor = torch.randn(3, 16, 16)
@@ -282,6 +293,44 @@ def test_pad(self):
with self.assertRaises(ValueError, msg="Padding can not be negative for symmetric padding_mode"):
F_t.pad(tensor, (-2, -3), padding_mode="symmetric")

def test_resize(self):
script_fn = torch.jit.script(F_t.resize)
tensor, pil_img = self._create_data(26, 36)

for dt in [None, torch.float32, torch.float64]:
if dt is not None:
# This is a trivial cast to float of uint8 data to test all cases
tensor = tensor.to(dt)
for size in [32, [32, ], [32, 32], (32, 32), ]:
for interpolation in [BILINEAR, BICUBIC, NEAREST]:
resized_tensor = F_t.resize(tensor, size=size, interpolation=interpolation)
resized_pil_img = F_pil.resize(pil_img, size=size, interpolation=interpolation)

self.assertEqual(
resized_tensor.size()[1:], resized_pil_img.size[::-1], msg="{}, {}".format(size, interpolation)
)

if interpolation != NEAREST:
# We can not check values if mode = NEAREST, as results are different
# E.g. resized_tensor = [[a, a, b, c, d, d, e, ...]]
# E.g. resized_pil_img = [[a, b, c, c, d, e, f, ...]]
Comment on lines +315 to +316

Member:

It's a pity that the behavior of nearest interpolation differs between implementations; I would say it could be worth opening an issue in PyTorch to mention this. I also believe that PIL and OpenCV are consistent, which would make a case for changing the implementation in PyTorch to be more consistent.

Collaborator Author:

Let me check that between PIL and OpenCV and then we decide about PyTorch.

(A minimal PIL/OpenCV/PyTorch consistency check is sketched after this file's diff.)

resized_tensor_f = resized_tensor
# we need to cast to uint8 to compare with PIL image
if resized_tensor_f.dtype == torch.uint8:
resized_tensor_f = resized_tensor_f.to(torch.float)

# Pay attention to high tolerance for MAE
self.approxEqualTensorToPIL(
resized_tensor_f, resized_pil_img, tol=8.0, msg="{}, {}".format(size, interpolation)
)

if isinstance(size, int):
script_size = [size, ]
else:
script_size = size
pad_tensor_script = script_fn(tensor, size=script_size, interpolation=interpolation)
self.assertTrue(resized_tensor.equal(pad_tensor_script), msg="{}, {}".format(size, interpolation))


if __name__ == '__main__':
unittest.main()
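
Regarding the nearest-interpolation discrepancy discussed in the review thread above, a minimal consistency check could look like the following sketch; it is not part of the PR and assumes numpy, Pillow, opencv-python, and torch are installed.

import numpy as np
import torch
import cv2
from PIL import Image

img = np.random.randint(0, 256, size=(26, 36, 3), dtype=np.uint8)

# PIL and OpenCV both take the target size as (width, height).
pil_out = np.array(Image.fromarray(img).resize((32, 32), Image.NEAREST))
cv_out = cv2.resize(img, (32, 32), interpolation=cv2.INTER_NEAREST)

# torch.nn.functional.interpolate expects NCHW input and size=(height, width).
t = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).float()
torch_out = torch.nn.functional.interpolate(t, size=(32, 32), mode="nearest")
torch_out = torch_out.squeeze(0).permute(1, 2, 0).to(torch.uint8).numpy()

print("PIL vs OpenCV MAE:", np.abs(pil_out.astype(int) - cv_out.astype(int)).mean())
print("PIL vs torch  MAE:", np.abs(pil_out.astype(int) - torch_out.astype(int)).mean())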
28 changes: 28 additions & 0 deletions test/test_transforms_tensor.py
@@ -2,6 +2,7 @@
from torchvision import transforms as T
from torchvision.transforms import functional as F
from PIL import Image
from PIL.Image import NEAREST, BILINEAR, BICUBIC

import numpy as np

@@ -217,6 +218,33 @@ def test_ten_crop(self):
"ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)

def test_resize(self):
tensor, _ = self._create_data(height=34, width=36)
script_fn = torch.jit.script(F.resize)

for dt in [None, torch.float32, torch.float64]:
if dt is not None:
# This is a trivial cast to float of uint8 data to test all cases
tensor = tensor.to(dt)
for size in [32, [32, ], [32, 32], (32, 32), ]:
for interpolation in [BILINEAR, BICUBIC, NEAREST]:

resized_tensor = F.resize(tensor, size=size, interpolation=interpolation)

if isinstance(size, int):
script_size = [size, ]
else:
script_size = size

s_resized_tensor = script_fn(tensor, size=script_size, interpolation=interpolation)
self.assertTrue(s_resized_tensor.equal(resized_tensor))

transform = T.Resize(size=script_size, interpolation=interpolation)
resized_tensor = transform(tensor)
script_transform = torch.jit.script(transform)
s_resized_tensor = script_transform(tensor)
self.assertTrue(s_resized_tensor.equal(resized_tensor))


if __name__ == '__main__':
unittest.main()
38 changes: 13 additions & 25 deletions torchvision/transforms/functional.py
@@ -311,41 +311,29 @@ def normalize(tensor, mean, std, inplace=False):
return tensor


def resize(img, size, interpolation=Image.BILINEAR):
r"""Resize the input PIL Image to the given size.
def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor:
r"""Resize the input image to the given size.
The image can be a PIL Image or a torch Tensor, in which case it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

Args:
img (PIL Image): Image to be resized.
img (PIL Image or Tensor): Image to be resized.
size (sequence or int): Desired output size. If size is a sequence like
(h, w), the output size will be matched to this. If size is an int,
the smaller edge of the image will be matched to this number maintaining
the aspect ratio. i.e, if height > width, then image will be rescaled to
:math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`
interpolation (int, optional): Desired interpolation. Default is
``PIL.Image.BILINEAR``
:math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`.
In torchscript mode size as a single int is not supported, use a tuple or
list of length 1: ``[size, ]``.
interpolation (int, optional): Desired interpolation. Default is bilinear.

Returns:
PIL Image: Resized image.
PIL Image or Tensor: Resized image.
"""
if not F_pil._is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
if not (isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)):
raise TypeError('Got inappropriate size arg: {}'.format(size))
if not isinstance(img, torch.Tensor):
return F_pil.resize(img, size=size, interpolation=interpolation)

if isinstance(size, int):
w, h = img.size
if (w <= h and w == size) or (h <= w and h == size):
return img
if w < h:
ow = size
oh = int(size * h / w)
return img.resize((ow, oh), interpolation)
else:
oh = size
ow = int(size * w / h)
return img.resize((ow, oh), interpolation)
else:
return img.resize(size[::-1], interpolation)
return F_t.resize(img, size=size, interpolation=interpolation)


def scale(*args, **kwargs):
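
To illustrate the unified entry point above, here is a short usage sketch; it is not part of the diff and assumes a torchvision build that includes this PR. The same functional resize call accepts either a PIL Image or a tensor, and in torchscript mode size is passed as a list.

import torch
from PIL import Image
import torchvision.transforms.functional as F

pil_img = Image.new("RGB", (36, 26))
tensor_img = torch.randint(0, 256, (3, 26, 36), dtype=torch.uint8)

out_pil = F.resize(pil_img, size=32)        # dispatches to functional_pil.resize
out_tensor = F.resize(tensor_img, size=32)  # dispatches to functional_tensor.resize

# In torchscript mode, size must be a list; a plain int is not supported.
scripted_resize = torch.jit.script(F.resize)
out_scripted = scripted_resize(tensor_img, size=[32], interpolation=2)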
43 changes: 42 additions & 1 deletion torchvision/transforms/functional_pil.py
@@ -1,5 +1,5 @@
import numbers
from typing import Any, List
from typing import Any, List, Sequence

import torch
try:
@@ -286,3 +286,44 @@ def crop(img: Image.Image, top: int, left: int, height: int, width: int) -> Imag
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

return img.crop((left, top, left + width, top + height))


@torch.jit.unused
def resize(img, size, interpolation=Image.BILINEAR):
r"""Resize the input PIL Image to the given size.

Args:
img (PIL Image): Image to be resized.
size (sequence or int): Desired output size. If size is a sequence like
(h, w), the output size will be matched to this. If size is an int,
the smaller edge of the image will be matched to this number maintaining
the aspect ratio. i.e, if height > width, then image will be rescaled to
:math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`.
For compatibility reasons with ``functional_tensor.resize``, if a tuple or list of length 1 is provided,
it is interpreted as a single int.
interpolation (int, optional): Desired interpolation. Default is ``PIL.Image.BILINEAR``.

Returns:
PIL Image: Resized image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
if not (isinstance(size, int) or (isinstance(size, Sequence) and len(size) in (1, 2))):
raise TypeError('Got inappropriate size arg: {}'.format(size))

if isinstance(size, int) or len(size) == 1:
if isinstance(size, Sequence):
size = size[0]
w, h = img.size
if (w <= h and w == size) or (h <= w and h == size):
return img
if w < h:
ow = size
oh = int(size * h / w)
return img.resize((ow, oh), interpolation)
else:
oh = size
ow = int(size * w / h)
return img.resize((ow, oh), interpolation)
else:
return img.resize(size[::-1], interpolation)
96 changes: 96 additions & 0 deletions torchvision/transforms/functional_tensor.py
@@ -1,3 +1,5 @@
from PIL.Image import NEAREST, BILINEAR, BICUBIC
Member:

nit: this doesn't seem to be used?

Collaborator Author:

Correct. Will remove it.

import torch
from torch import Tensor
from torch.jit.annotations import List, BroadcastingList2
@@ -8,6 +10,7 @@ def _is_tensor_a_torch_image(x: Tensor) -> bool:


def _get_image_size(img: Tensor) -> List[int]:
"""Returns (w, h) of tensor image"""
if _is_tensor_a_torch_image(img):
return [img.shape[-1], img.shape[-2]]
raise TypeError("Unexpected type {}".format(type(img)))
@@ -433,6 +436,7 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con

if isinstance(padding, int):
if torch.jit.is_scripting():
# This may be unreachable
raise ValueError("padding can't be an int while torchscripting, set it as a list [value, ]")
pad_left = pad_right = pad_top = pad_bottom = padding
elif len(padding) == 1:
@@ -480,3 +484,95 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
img = img.to(out_dtype)

return img


def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor:
r"""Resize the input Tensor to the given size.

Args:
img (Tensor): Image to be resized.
size (int or tuple or list): Desired output size. If size is a sequence like
(h, w), the output size will be matched to this. If size is an int,
the smaller edge of the image will be matched to this number maintaining
the aspect ratio. i.e, if height > width, then image will be rescaled to
:math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`.
In torchscript mode size as a single int is not supported, use a tuple or
list of length 1: ``[size, ]``.
interpolation (int, optional): Desired interpolation. Default is bilinear.

Returns:
Tensor: Resized image.
"""
if not _is_tensor_a_torch_image(img):
raise TypeError("tensor is not a torch image.")

if not isinstance(size, (int, tuple, list)):
raise TypeError("Got inappropriate size arg")
if not isinstance(interpolation, int):
raise TypeError("Got inappropriate interpolation arg")

_interpolation_modes = {
0: "nearest",
2: "bilinear",
3: "bicubic",
}

if interpolation not in _interpolation_modes:
raise ValueError("This interpolation mode is unsupported with Tensor input")

if isinstance(size, tuple):
size = list(size)

if isinstance(size, list) and len(size) not in [1, 2]:
raise ValueError("Padding must be an int or a 1 or 2 element tuple/list, not a " +
Member:

error message seems off?

Collaborator Author:

Thanks!
"{} element tuple/list".format(len(size)))

if interpolation not in [0, 1, 2, 3, 4]:
Member:

Looks like we don't support mode 1 in the _interpolation_modes?
raise ValueError("Interpolation mode should be either constant, edge, reflect or symmetric")
Member:

error message seems off?

w, h = _get_image_size(img)

if isinstance(size, int):
size_w, size_h = size, size
elif len(size) < 2:
size_w, size_h = size[0], size[0]
else:
size_w, size_h = size[0], size[1]

if isinstance(size, int) or len(size) < 2:
if w < h:
size_h = int(size_w * h / w)
else:
size_w = int(size_h * w / h)

if (w <= h and w == size_w) or (h <= w and h == size_h):
return img

# make image NCHW
need_squeeze = False
if img.ndim < 4:
img = img.unsqueeze(dim=0)
need_squeeze = True

mode = _interpolation_modes[interpolation]

out_dtype = img.dtype
need_cast = False
if img.dtype not in (torch.float32, torch.float64):
need_cast = True
img = img.to(torch.float32)

# Define align_corners to avoid warnings
align_corners = False if mode in ["bilinear", "bicubic"] else None

img = torch.nn.functional.interpolate(img, size=(size_h, size_w), mode=mode, align_corners=align_corners)

if need_squeeze:
img = img.squeeze(dim=0)

if need_cast:
if mode == "bicubic":
img = img.clamp(min=0, max=255)
Member:

nit: for the future, we might want to change the 255 value to be dtype-dependent. But this can be done in the future, maybe using convert_image_dtype before and after interpolate is called.

(A dtype-dependent clamp is sketched after the end of this diff.)
img = img.to(out_dtype)

return img
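
Following up on the dtype comment above, a dtype-dependent clamp could look like the sketch below; this is an assumption rather than part of the PR, using torch.iinfo to obtain the maximum value of integer dtypes.

import torch

def _clamp_for_dtype(img: torch.Tensor, out_dtype: torch.dtype) -> torch.Tensor:
    # Bicubic interpolation can overshoot the input range, so integer output
    # dtypes are clamped to their representable range before the cast back.
    if not out_dtype.is_floating_point:
        bound = float(torch.iinfo(out_dtype).max)
        img = img.clamp(min=0.0, max=bound)
    return img

With out_dtype=torch.uint8 this reproduces the current clamp(0, 255); torch.int16, for example, would be clamped to 32767 instead.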