In transforms.Resize, tensor interpolate is not the same as PIL resize. #2950
I think the cause is the different interpolation method used in between.
@hjinlee88 thanks for the report. Currently, the interpolation is only visually similar for the "nearest" option (`interpolation = 0`):

```python
tensor_interpolate = transforms.Compose([transforms.ToTensor(), transforms.Resize(size, interpolation=interpolation), transforms.ToPILImage()])
pillow_resize = transforms.Compose([transforms.Resize(size, interpolation=interpolation)])
```

In general, as @zhiqwang said, …
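To quantify that claim, here is a minimal sketch (not from the original thread): it assumes `img` is a PIL RGB image as in the examples below, and a torchvision of this era (~0.8) that still accepts integer PIL interpolation codes.

```python
import numpy as np
from torchvision import transforms

size = 112
for interpolation in (0, 2):  # PIL codes: 0 = NEAREST, 2 = BILINEAR
    via_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(size, interpolation=interpolation),
        transforms.ToPILImage(),
    ])
    via_pil = transforms.Resize(size, interpolation=interpolation)
    # Mean absolute difference between the tensor pipeline and the PIL pipeline
    diff = np.abs(
        np.asarray(via_tensor(img), dtype=float) - np.asarray(via_pil(img), dtype=float)
    ).mean()
    print(interpolation, diff)  # expect nearest to be far closer to 0 than bilinear
```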
@vfdev-5 I did a little research on this. The easiest solution for downscaling is to downsample as much as possible and then interpolate. Here is an example. The code is a bit dirty, but you can see what I am doing:

```python
import urllib

from PIL import Image
from torchvision import transforms
from matplotlib import pyplot as plt

import torch
from torch import nn
from torch.nn import functional as F


class ResizeModify(nn.Module):
    def __init__(self, size, interpolation):
        super().__init__()
        self.size = size
        self.interpolation = interpolation
        self.Resize = transforms.Resize(size, interpolation)

    def forward(self, img):
        img = img.unsqueeze(0)
        h, w = img.shape[2:]

        # Compute the target size, mirroring transforms.Resize semantics.
        if isinstance(self.size, int):
            if h > w:
                h2, w2 = int(self.size / w * h), self.size
            else:
                h2, w2 = self.size, int(self.size / h * w)
        else:
            h2, w2 = self.size

        if h2 > h:  # upscale h
            strides = int(h2 / h)
            if strides > 1:
                weights = torch.full((img.shape[1], 1, strides, 1), 1.0)
                img = F.conv_transpose2d(img, weights, stride=(strides, 1), groups=img.shape[1])
        else:  # downscale h
            strides = int(h / h2)  # floor and int
            if strides > 1:
                # Tested with uniform weights; normal (Gaussian) weights would be better.
                weights = torch.full((img.shape[1], 1, strides, 1), 1 / strides)
                img = F.conv2d(img, weights, stride=(strides, 1), groups=img.shape[1])

        if w2 > w:  # upscale w
            strides = int(w2 / w)
            if strides > 1:
                weights = torch.full((img.shape[1], 1, 1, strides), 1.0)
                img = F.conv_transpose2d(img, weights, stride=(1, strides), groups=img.shape[1])
        else:  # downscale w
            strides = int(w / w2)
            if strides > 1:
                weights = torch.full((img.shape[1], 1, 1, strides), 1 / strides)
                img = F.conv2d(img, weights, stride=(1, strides), groups=img.shape[1])

        img = img.squeeze(0)
        return self.Resize(img)


def main():
    size = 112
    img = Image.open(urllib.request.urlopen("https://pytorch.org/tutorials/_static/img/tv_tutorial/tv_image01.png"))
    # img = Image.open("images/tv_image01.png")
    interpolation = [Image.NEAREST, Image.BILINEAR, Image.BICUBIC][1]

    tensor_interpolate = transforms.Compose(
        [transforms.ToTensor(), transforms.Resize(size, interpolation=interpolation), transforms.ToPILImage()])
    tensor_interpolate_modify = transforms.Compose(
        [transforms.ToTensor(), ResizeModify(size, interpolation=interpolation), transforms.ToPILImage()])
    pillow_resize = transforms.Compose([transforms.Resize(size, interpolation=interpolation)])

    plt.subplot(221)
    plt.imshow(img)
    plt.title("original")
    plt.subplot(222)
    plt.imshow(pillow_resize(img))
    plt.title("pillow resize")
    plt.subplot(223)
    plt.imshow(tensor_interpolate(img))
    plt.title("tensor interpolate")
    plt.subplot(224)
    plt.imshow(tensor_interpolate_modify(img))
    plt.title("tensor interpolate modify")
    plt.show()


if __name__ == "__main__":
    main()
```
@hjinlee88 thanks for investigating this. Results of resampling with convolutions look neat. Let me check and decide what can be done with this issue.

EDIT: As in the code we are working on tensor input with dtype float32, we do not perform any clamping:

```python
transforms.Compose([transforms.ToTensor(), transforms.Resize(size, interpolation=interpolation), transforms.ToPILImage()])
```

EDIT, EDIT: clamping float32 may lead to other unexpected results, as we have no predefined range for float32 input vs 0-255 for uint8. The issue is also with:

```diff
- tensor_interpolate = transforms.Compose([transforms.ToTensor(), transforms.Resize(size, interpolation=interpolation), transforms.ToPILImage()])
+ tensor_interpolate = transforms.Compose([
+     lambda x: torch.from_numpy(np.asarray(x)).permute(2, 0, 1),
+     transforms.Resize(size, interpolation=interpolation),
+     lambda x: Image.fromarray(x.permute(1, 2, 0).numpy()),
+ ])
```
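To make the distinction concrete, a minimal sketch (the tiny constant image is an arbitrary choice): `ToTensor` rescales to float32 in [0, 1], while the numpy-based conversion in the diff keeps the original uint8 values, so the two pipelines feed `Resize` different dtypes.

```python
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

img = Image.fromarray(np.full((4, 4, 3), 200, dtype=np.uint8))

t = transforms.ToTensor()(img)
print(t.dtype, t.max())  # torch.float32, ~0.7843 (i.e. 200 / 255)

u = torch.from_numpy(np.asarray(img)).permute(2, 0, 1)
print(u.dtype, u.max())  # torch.uint8, 200
```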
@vfdev-5 The below also works:

```python
interpolation = Image.BICUBIC
tensor_interpolate = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(size, interpolation=interpolation),
    lambda x: x.clamp(0, 1),
    transforms.ToPILImage()
])
```

The problem could be in this line (vision/torchvision/transforms/functional.py, lines 196 to 197 in 8088cc9), for the following reason:

```python
print(torch.tensor([255, 256, 257]).byte())
# > tensor([255,   0,   1], dtype=torch.uint8)  -- out-of-range values wrap around
```

I suggest the following, because mul(255) assumes the pic is float and in range [0, 1]:

```python
pic = pic.mul(255).clamp(0, 255).byte()
```

EDIT: added some explanation and suggestions.
@vfdev-5 I investigated the code I wrote earlier in #2950 (comment). I studied transpose convolution and found it useless here, because it just copies the pixels closer together; therefore, it should be removed. conv2d with weights and strides looks good, because it is essentially the same as blurring (mean blur or Gaussian blur) followed by downsampling. However, Gaussian blur weights can be tricky to implement, because the kernel size must be odd and you have to deal with sigma. But I don't understand the behavior of interpolate. Is it intended to find the location of the target pixel in the source image, and compute the pixel value using only the 4 points around that location?
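For the Gaussian variant mentioned above, a minimal sketch of how the uniform weights in `ResizeModify` could be swapped out (the `gaussian_weights` helper, kernel size, and sigma choice are all assumptions, not part of the original code):

```python
import torch

def gaussian_weights(channels, ksize, sigma, horizontal=False):
    # Build a normalized 1-D Gaussian kernel; ksize should be odd so it has a center.
    x = torch.arange(ksize, dtype=torch.float32) - (ksize - 1) / 2
    kernel = torch.exp(-x ** 2 / (2 * sigma ** 2))
    kernel = kernel / kernel.sum()
    # Shape it for depthwise conv2d: (channels, 1, ksize, 1) or (channels, 1, 1, ksize).
    if horizontal:
        kernel = kernel.view(1, 1, 1, ksize)
    else:
        kernel = kernel.view(1, 1, ksize, 1)
    return kernel.repeat(channels, 1, 1, 1)

print(gaussian_weights(3, 5, 1.0).shape)  # torch.Size([3, 1, 5, 1])

# e.g. in the downscale-h branch of ResizeModify, something like:
# weights = gaussian_weights(img.shape[1], 2 * strides + 1, sigma=strides / 2)
# img = F.conv2d(img, weights, stride=(strides, 1), padding=(strides, 0), groups=img.shape[1])
```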
@hjinlee88 interpolate in PyTorch implements interpolation following the standard approaches from OpenCV (for float values). For a Python-based implementation of interpolate that gives the exact same result as torchvision, see https://gist.github.com/fmassa/cb2d0dff7731f6459d8ca5b5c9ea15d9. In particular, I took your example image and used OpenCV instead to perform bilinear interpolation, and the results from torchvision and OpenCV matched almost exactly, with just rounding differences leading to 1 (out of 255) pixel differences.

```python
# same code as before should be added here
import cv2
import numpy as np

imt = np.array(img)
tt = transforms.Resize(size, interpolation=interpolation)

# convert the image to a tensor without casting to [0, 1],
# and then resize it
res_tv = tt(torch.as_tensor(imt).permute(2, 0, 1)).permute(1, 2, 0).contiguous().numpy()

# apply bilinear resize from OpenCV
res_cv = cv2.resize(imt, (231, 112), interpolation=cv2.INTER_LINEAR)

# compute the difference
np.abs(res_tv.astype(float) - res_cv.astype(float)).max()
# > returns 1.0
```
@fmassa Thank you for the explanation. I find that OpenCV already knows that CV resize is not the same as Pillow resize: opencv/opencv#17068 (comment). I also tested bilinear resize in GIMP, the famous GNU graphics editor, and the result differs from both CV resize and Pillow resize. This means that resize can differ from program to program. However, I think the current situation, where the output of the same Resize class for the same input differs depending on the type of the input (torch Tensor or Pillow Image), should be fixed. I compared the results of #2950 (comment) with the results of Pillow resize, but they were not the same, probably because the weights are different.
I agree that having a slightly different output for the function depending on whether it's given PIL Images or torch Tensors is not ideal, but fixing this would be complicated. The trade-off here is that in C++, most users would rely on OpenCV or PyTorch to perform the resizing, so it makes sense for torchvision to be compatible with both. Plus, we get the benefit that those ops are already implemented in C++ / CUDA with native batch support.
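For illustration, a small sketch of the batch support mentioned above (the shapes are arbitrary): `F.interpolate` resizes a whole NCHW batch in one call, which the PIL path cannot do.

```python
import torch
import torch.nn.functional as F

batch = torch.rand(8, 3, 128, 128)  # N, C, H, W
resized = F.interpolate(batch, size=(64, 64), mode='bilinear', align_corners=False)
print(resized.shape)  # torch.Size([8, 3, 64, 64])
```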
Cross-linking some discussion in https://twitter.com/ajlavin/status/1336131931314954240. It might be worth considering adding a new interpolation mode, akin to …
Resuscitating this thread: I just lost a few days chasing down a bug because we assumed the output of resize would be the same for a PIL image and a tensor.

Could we add a note in the documentation specifying that users should not expect the same behavior when downsampling, depending on whether they pass an image or a tensor?

```python
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import PIL.Image as pil

img_pil = TF.to_pil_image(torch.randint(size=[3, 128, 128], low=0, high=255, dtype=torch.uint8))
img = TF.to_tensor(img_pil)

img_small_pil = TF.resize(img_pil, 64, interpolation=pil.BILINEAR)
img_small = TF.resize(img, 64, interpolation=pil.BILINEAR)

img_big_pil = TF.resize(img_pil, 256, interpolation=pil.BILINEAR)
img_big = TF.resize(img, 256, interpolation=pil.BILINEAR)

upsample_avg_error = torch.mean(torch.abs(TF.to_tensor(img_big_pil) - img_big)) * 255
downsample_avg_error = torch.mean(torch.abs(TF.to_tensor(img_small_pil) - img_small)) * 255

print(f"upsample_avg_error: {upsample_avg_error:0.2f}")
print(f"downsample_avg_error: {downsample_avg_error:0.2f}")
```

Output:

```
upsample_avg_error: 0.35
downsample_avg_error: 15.28
```
Hi @mrharicot, @tcapelle. Very sorry about the situation. We are working on adding support for anti-aliasing to the tensor transforms, so that they more closely match PIL.
cc @vfdev-5 who will be looking into addressing this
I wrote a summary of the issue here: https://tcapelle.github.io/pytorch/fastai/2021/02/26/image_resizing.html
Your article saved my day!
@iynaur since version 0.10.0 we added an antialias option to resize, which makes the tensor output much closer to PIL:

```python
import torch
import torchvision
print(torch.__version__, torchvision.__version__)
# > 1.9.0 0.10.0

import matplotlib.pyplot as plt
import urllib

import torchvision.transforms.functional as f
from PIL import Image as Image

url = "https://user-images.githubusercontent.com/3275025/123925242-4c795b00-d9bd-11eb-9f0c-3c09a5204190.jpg"
img = Image.open(
    urllib.request.urlopen(url)
)
t_img = f.to_tensor(img)

img_small_pil = f.resize(img, 128, interpolation=Image.BILINEAR)
img_small_aa = f.to_pil_image(f.resize(t_img, 128, interpolation=Image.BILINEAR, antialias=True))
img_small = f.to_pil_image(f.resize(t_img, 128, interpolation=Image.BILINEAR, antialias=False))

plt.figure(figsize=(3 * 8, 8))
plt.subplot(131)
plt.title("PIL")
plt.imshow(img_small_pil)
plt.subplot(132)
plt.title("Tensor with antialias")
plt.imshow(img_small_aa)
plt.subplot(133)
plt.title("Tensor without antialias")
plt.imshow(img_small)
```
Thanks! Looks great. I will try it when I update to that version.
I will have to update my blog post with torchvision 0.10 then =)
@tcapelle that would be nice! Note that the …
Done! https://tcapelle.github.io/pytorch/fastai/2021/02/26/image_resizing.html
How can we use the vision transforms resize in C++?
@bnascimento there is no C++ API for the vision transforms, but you can use the PyTorch C++ API, which can do a similar resizing: https://pytorch.org/cppdocs/api/function_namespacetorch_1_1nn_1_1functional_1afb8b9cd051ced01899b6d3142ac2f47c.html#exhale-function-namespacetorch-1-1nn-1-1functional-1afb8b9cd051ced01899b6d3142ac2f47c HTH
Has this bug been fixed in 0.11 or 0.12?
It will be impossible to get exactly the same result from torch interpolate and PIL resize for all interpolation modes and scales; the results are compatible and almost equal. For example, here is a test that checks the outputs: vision/test/test_functional_tensor.py, line 548 in 59c4de9.

We set the tolerance to 8.0 while computing the mean absolute error; the data is RGB uint8 in the range [0, 255]. I think we can close this issue.
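For reference, a standalone sketch of that kind of tolerance check (not the actual torchvision test; the random image, sizes, and threshold mirroring the 8.0 tolerance above are assumptions):

```python
import numpy as np
import torch
import torchvision.transforms.functional as TF
from PIL import Image

img_pil = Image.fromarray(np.random.randint(0, 256, (128, 160, 3), dtype=np.uint8))
img_t = torch.as_tensor(np.array(img_pil)).permute(2, 0, 1)

out_pil = TF.resize(img_pil, 64, interpolation=TF.InterpolationMode.BILINEAR)
out_t = TF.resize(img_t, 64, interpolation=TF.InterpolationMode.BILINEAR, antialias=True)

# Mean absolute error between the PIL output and the antialiased tensor output
mae = (
    torch.as_tensor(np.array(out_pil)).permute(2, 0, 1).float() - out_t.float()
).abs().mean()
assert mae < 8.0, mae.item()
```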
@vfdev-5 I re-read this issue and found out that I accidentally put two bugs (anti-alias and bicubic overshoot) in one issue.
@hjinlee88 I think with the newest and recommended way of doing things, both issues are gone (use transforms.PILToTensor instead of ToTensor, and pass antialias=True to Resize).

So, your initial example would look like:

```python
import urllib

from PIL import Image
from torchvision import transforms
from matplotlib import pyplot as plt

size = 112
img = Image.open(urllib.request.urlopen("https://pytorch.org/tutorials/_static/img/tv_tutorial/tv_image01.png"))

tensor_interpolate = transforms.Compose([
    transforms.PILToTensor(),
    transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC, antialias=True),
    transforms.ToPILImage()
])
pillow_resize = transforms.Compose([
    transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC)
])

plt.figure(figsize=(18, 18))
plt.subplot(121)
plt.imshow(tensor_interpolate(img))
plt.title("tensor interpolate")
plt.subplot(122)
plt.imshow(pillow_resize(img))
plt.title("pillow resize")
```

Please let me know if this is sufficient or there is still a problem. Thanks!
@vfdev-5 Thank you for the update. Has the performance issue with antialias=True in Resize been resolved?
No, but it does make a small difference
By chance I discovered that even a newer version of torch (I am on 2.1.2) still has this issue in "bicubic" mode: the pixels of the resized image show significant unevenness. So are there any bugs that still need to be fixed in "bicubic" mode? My test code is here; just change "img_dir" to a directory containing '.jpg' or '.png' files, and you can test this idea quickly.
@lzcchl this is expected behaviour, as you resize a float tensor and cast it to uint8 without clamping the range into [0, 255]. See the first note in the docs: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html#torch.nn.functional.interpolate

Here are two options that could remove the black dots. Either clamp the float result before casting back:

```python
resized_tensor = torch.nn.functional.interpolate(
    tensor.float(),
    size=out_size,
    mode='bicubic',
    antialias=True
)
resized_tensor = resized_tensor.clamp(0, 255)
resized_tensor = resized_tensor.to(tensor.dtype)
```

or resize the uint8 tensor directly:

```python
resized_tensor = torch.nn.functional.interpolate(
    tensor,
    size=out_size,
    mode='bicubic',
    antialias=True
)
```

Let me know if I misunderstood something from your comment.
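A quick side-by-side of the two options (a sketch: the random uint8 image and sizes are made up, and the second call assumes a torch version whose interpolate accepts uint8 input with antialias, per the suggestion above):

```python
import torch
import torch.nn.functional as F

tensor = torch.randint(0, 256, (1, 3, 256, 256), dtype=torch.uint8)
out_size = (128, 128)

# Option 1: resize as float, clamp, then cast back to uint8.
a = F.interpolate(tensor.float(), size=out_size, mode='bicubic', antialias=True)
a = a.clamp(0, 255).to(tensor.dtype)

# Option 2: resize the uint8 tensor directly.
b = F.interpolate(tensor, size=out_size, mode='bicubic', antialias=True)

print((a.float() - b.float()).abs().max())  # the two options should roughly agree
```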
Thank you, this is great work and you have basically resolved my doubts. In my code I also compared the results of PIL, torch, and torchvision, but that's not the focus of my question, because I know there are slight differences in their implementations, which can lead to differences in the numerical results; that is understandable. In my future work I will use your first suggestion, because handling floats is very common.
Given what I tested, the difference between PIL's resize and torchvision's tensor resize is still visible, while converting the tensor back to a PIL image and resizing gives an exact match:

```python
In [45]: from PIL import Image
    ...: import numpy as np
    ...: import torch as th
    ...: from torchvision.transforms.v2 import functional as thvF
    ...:
    ...: im = Image.open("im.png")
    ...: x = th.from_numpy(np.array(im)).permute(2, 0, 1)
    ...: print(x.dtype)
    ...:
    ...: a = im.resize((200, 300), resample=Image.BICUBIC)
    ...: b = thvF.resize(x, (300, 200), interpolation=thvF.InterpolationMode.BICUBIC, antialias=True)
    ...: c = Image.fromarray(x.permute(1, 2, 0).numpy()).resize((200, 300), resample=Image.BICUBIC)
    ...: print("shapes", im.size, x.shape, a.size, b.shape, c.size)
    ...: print("diff between Image.resize and torchvision.resize", (th.from_numpy(np.array(a)).permute(2, 0, 1) - b).abs().float().mean())
    ...: print("diff between Image.resize and tensor-to-Image.resize", (th.from_numpy(np.array(a) - np.array(c))).abs().float().mean())
torch.uint8
shapes (800, 1000) torch.Size([3, 1000, 800]) (200, 300) torch.Size([3, 300, 200]) (200, 300)
diff between Image.resize and torchvision.resize tensor(0.4633)
diff between Image.resize and tensor-to-Image.resize tensor(0.)
```
🐛 Bug
Resize supports tensors via F.interpolate, but the behavior is not the same as Pillow resize.
vision/torchvision/transforms/functional.py, lines 309 to 312 in f95b053
To Reproduce
Steps to reproduce the behavior:
Expected behavior
Both should have the same or nearly identical output.
Perhaps it needs a blur before the interpolation.
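A minimal sketch of that idea (the helper name, kernel size, and uniform kernel are assumptions; a Gaussian kernel would be the more principled choice):

```python
import torch
import torch.nn.functional as F

def blur_then_resize(img, out_hw, k=5):
    # img: (N, C, H, W) float tensor; depthwise k x k mean blur, then interpolate.
    c = img.shape[1]
    weight = img.new_full((c, 1, k, k), 1.0 / (k * k))
    img = F.conv2d(img, weight, padding=k // 2, groups=c)
    return F.interpolate(img, size=out_hw, mode='bilinear', align_corners=False)

out = blur_then_resize(torch.rand(1, 3, 256, 256), (112, 112))
print(out.shape)  # torch.Size([1, 3, 112, 112])
```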
Environment

I installed PyTorch using the following command:

```
conda install pytorch torchvision -c pytorch
```

Output of `python collect_env.py`:

```
Collecting environment information...
PyTorch version: 1.7.0
Is debug build: True
CUDA used to build PyTorch: 11.0
ROCM used to build PyTorch: N/A
OS: Microsoft Windows 10 Home
GCC version: (MinGW.org GCC-8.2.0-3) 8.2.0
Clang version: Could not collect
CMake version: version 3.18.2
Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 10.0.130
GPU models and configuration: GPU 0: GeForce RTX 2060
Nvidia driver version: 456.38
cuDNN version: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.0\bin\cudnn64_7.dll
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.19.2
[pip3] torch==1.7.0
[pip3] torchvision==0.8.1
[conda] blas 1.0 mkl
[conda] cudatoolkit 11.0.221 h74a9793_0
[conda] mkl 2020.2 256
[conda] mkl-service 2.3.0 py38hb782905_0
[conda] mkl_fft 1.2.0 py38h45dec08_0
[conda] mkl_random 1.1.1 py38h47e9c7a_0
[conda] numpy 1.19.2 py38hadc3359_0
[conda] numpy-base 1.19.2 py38ha3acd2a_0
[conda] pytorch 1.7.0 py3.8_cuda110_cudnn8_0 pytorch
[conda] torchvision 0.8.1 py38_cu110 pytorch
```
cc @vfdev-5