Skip to content

Commit

Permalink
Change default of antialias parameter from None to 'warn' (#7160)
Browse files Browse the repository at this point in the history
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: vfdev <vfdev.5@gmail.com>
  • Loading branch information
3 people authored Feb 13, 2023
1 parent 8fdaeb0 commit b030e93
Show file tree
Hide file tree
Showing 21 changed files with 345 additions and 79 deletions.
11 changes: 6 additions & 5 deletions gallery/plot_optical_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,10 @@ def plot(imgs, **imshow_kwargs):

#########################
# The RAFT model accepts RGB images. We first get the frames from
# :func:`~torchvision.io.read_video` and resize them to ensure their
# dimensions are divisible by 8. Then we use the transforms bundled into the
# weights in order to preprocess the input and rescale its values to the
# :func:`~torchvision.io.read_video` and resize them to ensure their dimensions
# are divisible by 8. Note that we explicitly use ``antialias=False``, because
# this is how those models were trained. Then we use the transforms bundled into
# the weights in order to preprocess the input and rescale its values to the
# required ``[-1, 1]`` interval.

from torchvision.models.optical_flow import Raft_Large_Weights
Expand All @@ -93,8 +94,8 @@ def plot(imgs, **imshow_kwargs):


def preprocess(img1_batch, img2_batch):
img1_batch = F.resize(img1_batch, size=[520, 960])
img2_batch = F.resize(img2_batch, size=[520, 960])
img1_batch = F.resize(img1_batch, size=[520, 960], antialias=False)
img2_batch = F.resize(img2_batch, size=[520, 960], antialias=False)
return transforms(img1_batch, img2_batch)


Expand Down
6 changes: 5 additions & 1 deletion references/depth/stereo/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,11 @@ def forward(
INTERP_MODE = self._interpolation_mode_strategy()

for img in images:
resized_images += (F.resize(img, self.resize_size, interpolation=INTERP_MODE),)
# We hard-code antialias=False to preserve results after we changed
# its default from None to True (see
# https://github.com/pytorch/vision/pull/7160)
# TODO: we could re-train the stereo models with antialias=True?
resized_images += (F.resize(img, self.resize_size, interpolation=INTERP_MODE, antialias=False),)

for dsp in disparities:
if dsp is not None:
Expand Down
8 changes: 6 additions & 2 deletions references/optical_flow/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,12 @@ def forward(self, img1, img2, flow, valid_flow_mask):

if torch.rand(1).item() < self.resize_prob:
# rescale the images
img1 = F.resize(img1, size=(new_h, new_w))
img2 = F.resize(img2, size=(new_h, new_w))
# We hard-code antialias=False to preserve results after we changed
# its default from None to True (see
# https://github.com/pytorch/vision/pull/7160)
# TODO: we could re-train the OF models with antialias=True?
img1 = F.resize(img1, size=(new_h, new_w), antialias=False)
img2 = F.resize(img2, size=(new_h, new_w), antialias=False)
if valid_flow_mask is None:
flow = F.resize(flow, size=(new_h, new_w))
flow = flow * torch.tensor([scale_x, scale_y])[:, None, None]
Expand Down
12 changes: 10 additions & 2 deletions references/video_classification/presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@ def __init__(
):
trans = [
transforms.ConvertImageDtype(torch.float32),
transforms.Resize(resize_size),
# We hard-code antialias=False to preserve results after we changed
# its default from None to True (see
# https://github.com/pytorch/vision/pull/7160)
# TODO: we could re-train the video models with antialias=True?
transforms.Resize(resize_size, antialias=False),
]
if hflip_prob > 0:
trans.append(transforms.RandomHorizontalFlip(hflip_prob))
Expand All @@ -31,7 +35,11 @@ def __init__(self, *, crop_size, resize_size, mean=(0.43216, 0.394666, 0.37645),
self.transforms = transforms.Compose(
[
transforms.ConvertImageDtype(torch.float32),
transforms.Resize(resize_size),
# We hard-code antialias=False to preserve results after we changed
# its default from None to True (see
# https://github.com/pytorch/vision/pull/7160)
# TODO: we could re-train the video models with antialias=True?
transforms.Resize(resize_size, antialias=False),
transforms.Normalize(mean=mean, std=std),
transforms.CenterCrop(crop_size),
ConvertBCHWtoCBHW(),
Expand Down
43 changes: 36 additions & 7 deletions test/test_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import itertools
import math
import os
import warnings
from functools import partial
from typing import Sequence

Expand Down Expand Up @@ -483,8 +484,8 @@ def test_resize(device, dt, size, max_size, interpolation):
tensor = tensor.to(dt)
batch_tensors = batch_tensors.to(dt)

resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, max_size=max_size)
resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation, max_size=max_size)
resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, max_size=max_size, antialias=True)
resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation, max_size=max_size, antialias=True)

assert resized_tensor.size()[1:] == resized_pil_img.size[::-1]

Expand All @@ -509,10 +510,12 @@ def test_resize(device, dt, size, max_size, interpolation):
else:
script_size = size

resize_result = script_fn(tensor, size=script_size, interpolation=interpolation, max_size=max_size)
resize_result = script_fn(tensor, size=script_size, interpolation=interpolation, max_size=max_size, antialias=True)
assert_equal(resized_tensor, resize_result)

_test_fn_on_batch(batch_tensors, F.resize, size=script_size, interpolation=interpolation, max_size=max_size)
_test_fn_on_batch(
batch_tensors, F.resize, size=script_size, interpolation=interpolation, max_size=max_size, antialias=True
)


@pytest.mark.parametrize("device", cpu_and_gpu())
Expand Down Expand Up @@ -547,7 +550,7 @@ def test_resize_antialias(device, dt, size, interpolation):
tensor = tensor.to(dt)

resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, antialias=True)
resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation)
resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation, antialias=True)

assert resized_tensor.size()[1:] == resized_pil_img.size[::-1]

Expand Down Expand Up @@ -596,6 +599,23 @@ def test_assert_resize_antialias(interpolation):
F.resize(tensor, size=(5, 5), interpolation=interpolation, antialias=True)


def test_resize_antialias_default_warning():

img = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8)

match = "The default value of the antialias"
with pytest.warns(UserWarning, match=match):
F.resize(img, size=(20, 20))
with pytest.warns(UserWarning, match=match):
F.resized_crop(img, 0, 0, 10, 10, size=(20, 20))

# For modes that aren't bicubic or bilinear, don't throw a warning
with warnings.catch_warnings():
warnings.simplefilter("error")
F.resize(img, size=(20, 20), interpolation=NEAREST)
F.resized_crop(img, 0, 0, 10, 10, size=(20, 20), interpolation=NEAREST)


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dt", [torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize("size", [[10, 7], [10, 42], [42, 7]])
Expand Down Expand Up @@ -924,7 +944,9 @@ def test_resized_crop(device, mode):
# 1) resize to the same size, crop to the same size => should be identity
tensor, _ = _create_data(26, 36, device=device)

out_tensor = F.resized_crop(tensor, top=0, left=0, height=26, width=36, size=[26, 36], interpolation=mode)
out_tensor = F.resized_crop(
tensor, top=0, left=0, height=26, width=36, size=[26, 36], interpolation=mode, antialias=True
)
assert_equal(tensor, out_tensor, msg=f"{out_tensor[0, :5, :5]} vs {tensor[0, :5, :5]}")

# 2) resize by half and crop a TL corner
Expand All @@ -939,7 +961,14 @@ def test_resized_crop(device, mode):

batch_tensors = _create_data_batch(26, 36, num_samples=4, device=device)
_test_fn_on_batch(
batch_tensors, F.resized_crop, top=1, left=2, height=20, width=30, size=[10, 15], interpolation=NEAREST
batch_tensors,
F.resized_crop,
top=1,
left=2,
height=20,
width=30,
size=[10, 15],
interpolation=NEAREST,
)


Expand Down
20 changes: 20 additions & 0 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1050,5 +1050,25 @@ def test_raft(model_fn, scripted):
_assert_expected(flow_pred.cpu(), name=model_fn.__name__, atol=1e-2, rtol=1)


def test_presets_antialias():

img = torch.randint(0, 256, size=(1, 3, 224, 224), dtype=torch.uint8)

match = "The default value of the antialias parameter"
with pytest.warns(UserWarning, match=match):
models.ResNet18_Weights.DEFAULT.transforms()(img)
with pytest.warns(UserWarning, match=match):
models.segmentation.DeepLabV3_ResNet50_Weights.DEFAULT.transforms()(img)

with warnings.catch_warnings():
warnings.simplefilter("error")
models.ResNet18_Weights.DEFAULT.transforms(antialias=True)(img)
models.segmentation.DeepLabV3_ResNet50_Weights.DEFAULT.transforms(antialias=True)(img)

models.detection.FasterRCNN_ResNet50_FPN_Weights.DEFAULT.transforms()(img)
models.video.R3D_18_Weights.DEFAULT.transforms()(img)
models.optical_flow.Raft_Small_Weights.DEFAULT.transforms()(img, img)


if __name__ == "__main__":
pytest.main([__file__])
69 changes: 67 additions & 2 deletions test/test_prototype_transforms.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import itertools
import re
import warnings
from collections import defaultdict

import numpy as np
Expand Down Expand Up @@ -94,7 +95,7 @@ def parametrize_from_transforms(*transforms):
class TestSmoke:
@parametrize_from_transforms(
transforms.RandomErasing(p=1.0),
transforms.Resize([16, 16]),
transforms.Resize([16, 16], antialias=True),
transforms.CenterCrop([16, 16]),
transforms.ConvertDtype(),
transforms.RandomHorizontalFlip(),
Expand Down Expand Up @@ -210,7 +211,7 @@ def test_normalize(self, transform, input):
@parametrize(
[
(
transforms.RandomResizedCrop([16, 16]),
transforms.RandomResizedCrop([16, 16], antialias=True),
itertools.chain(
make_images(extra_dims=[(4,)]),
make_vanilla_tensor_images(),
Expand Down Expand Up @@ -1991,6 +1992,70 @@ def test__transform(self, inpt):
assert output.dtype == inpt.dtype


# TODO: remove this test in 0.17 when the default of antialias changes to True
def test_antialias_warning():
pil_img = PIL.Image.new("RGB", size=(10, 10), color=127)
tensor_img = torch.randint(0, 256, size=(3, 10, 10), dtype=torch.uint8)
tensor_video = torch.randint(0, 256, size=(2, 3, 10, 10), dtype=torch.uint8)

match = "The default value of the antialias parameter"
with pytest.warns(UserWarning, match=match):
transforms.Resize((20, 20))(tensor_img)
with pytest.warns(UserWarning, match=match):
transforms.RandomResizedCrop((20, 20))(tensor_img)
with pytest.warns(UserWarning, match=match):
transforms.ScaleJitter((20, 20))(tensor_img)
with pytest.warns(UserWarning, match=match):
transforms.RandomShortestSize((20, 20))(tensor_img)
with pytest.warns(UserWarning, match=match):
transforms.RandomResize(10, 20)(tensor_img)

with pytest.warns(UserWarning, match=match):
transforms.functional.resize(tensor_img, (20, 20))
with pytest.warns(UserWarning, match=match):
transforms.functional.resize_image_tensor(tensor_img, (20, 20))

with pytest.warns(UserWarning, match=match):
transforms.functional.resize(tensor_video, (20, 20))
with pytest.warns(UserWarning, match=match):
transforms.functional.resize_video(tensor_video, (20, 20))

with pytest.warns(UserWarning, match=match):
datapoints.Image(tensor_img).resize((20, 20))
with pytest.warns(UserWarning, match=match):
datapoints.Image(tensor_img).resized_crop(0, 0, 10, 10, (20, 20))

with pytest.warns(UserWarning, match=match):
datapoints.Video(tensor_video).resize((20, 20))
with pytest.warns(UserWarning, match=match):
datapoints.Video(tensor_video).resized_crop(0, 0, 10, 10, (20, 20))

with warnings.catch_warnings():
warnings.simplefilter("error")
transforms.Resize((20, 20))(pil_img)
transforms.RandomResizedCrop((20, 20))(pil_img)
transforms.ScaleJitter((20, 20))(pil_img)
transforms.RandomShortestSize((20, 20))(pil_img)
transforms.RandomResize(10, 20)(pil_img)
transforms.functional.resize(pil_img, (20, 20))

transforms.Resize((20, 20), antialias=True)(tensor_img)
transforms.RandomResizedCrop((20, 20), antialias=True)(tensor_img)
transforms.ScaleJitter((20, 20), antialias=True)(tensor_img)
transforms.RandomShortestSize((20, 20), antialias=True)(tensor_img)
transforms.RandomResize(10, 20, antialias=True)(tensor_img)

transforms.functional.resize(tensor_img, (20, 20), antialias=True)
transforms.functional.resize_image_tensor(tensor_img, (20, 20), antialias=True)
transforms.functional.resize(tensor_video, (20, 20), antialias=True)
transforms.functional.resize_video(tensor_video, (20, 20), antialias=True)

datapoints.Image(tensor_img).resize((20, 20), antialias=True)
datapoints.Image(tensor_img).resized_crop(0, 0, 10, 10, (20, 20), antialias=True)
datapoints.Video(tensor_video).resize((20, 20), antialias=True)
datapoints.Video(tensor_video).resized_crop(0, 0, 10, 10, (20, 20), antialias=True)


@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image))
@pytest.mark.parametrize("label_type", (torch.Tensor, int))
@pytest.mark.parametrize("dataset_return_type", (dict, tuple))
Expand Down
23 changes: 17 additions & 6 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import random
import re
import warnings
from functools import partial

import numpy as np
Expand Down Expand Up @@ -319,7 +320,7 @@ def test_randomresized_params():
scale_range = (scale_min, scale_min + round(random.random(), 2))
aspect_min = max(round(random.random(), 2), epsilon)
aspect_ratio_range = (aspect_min, aspect_min + round(random.random(), 2))
randresizecrop = transforms.RandomResizedCrop(size, scale_range, aspect_ratio_range)
randresizecrop = transforms.RandomResizedCrop(size, scale_range, aspect_ratio_range, antialias=True)
i, j, h, w = randresizecrop.get_params(img, scale_range, aspect_ratio_range)
aspect_ratio_obtained = w / h
assert (
Expand Down Expand Up @@ -366,7 +367,7 @@ def test_randomresized_params():
def test_resize(height, width, osize, max_size):
img = Image.new("RGB", size=(width, height), color=127)

t = transforms.Resize(osize, max_size=max_size)
t = transforms.Resize(osize, max_size=max_size, antialias=True)
result = t(img)

msg = f"{height}, {width} - {osize} - {max_size}"
Expand Down Expand Up @@ -424,7 +425,7 @@ def test_resize_sequence_output(height, width, osize):
img = Image.new("RGB", size=(width, height), color=127)
oheight, owidth = osize

t = transforms.Resize(osize)
t = transforms.Resize(osize, antialias=True)
result = t(img)

assert (owidth, oheight) == result.size
Expand All @@ -439,6 +440,16 @@ def test_resize_antialias_error():
t(img)


def test_resize_antialias_default_warning():

img = Image.new("RGB", size=(10, 10), color=127)
# We make sure we don't warn for PIL images since the default behaviour doesn't change
with warnings.catch_warnings():
warnings.simplefilter("error")
transforms.Resize((20, 20))(img)
transforms.RandomResizedCrop((20, 20))(img)


@pytest.mark.parametrize("height, width", ((32, 64), (64, 32)))
def test_resize_size_equals_small_edge_size(height, width):
# Non-regression test for https://github.com/pytorch/vision/issues/5405
Expand All @@ -447,7 +458,7 @@ def test_resize_size_equals_small_edge_size(height, width):
img = Image.new("RGB", size=(width, height), color=127)

small_edge = min(height, width)
t = transforms.Resize(small_edge, max_size=max_size)
t = transforms.Resize(small_edge, max_size=max_size, antialias=True)
result = t(img)
assert max(result.size) == max_size

Expand Down Expand Up @@ -1424,11 +1435,11 @@ def test_random_choice(proba_passthrough, seed):
def test_random_order():
random_state = random.getstate()
random.seed(42)
random_order_transform = transforms.RandomOrder([transforms.Resize(20), transforms.CenterCrop(10)])
random_order_transform = transforms.RandomOrder([transforms.Resize(20, antialias=True), transforms.CenterCrop(10)])
img = transforms.ToPILImage()(torch.rand(3, 25, 25))
num_samples = 250
num_normal_order = 0
resize_crop_out = transforms.CenterCrop(10)(transforms.Resize(20)(img))
resize_crop_out = transforms.CenterCrop(10)(transforms.Resize(20, antialias=True)(img))
for _ in range(num_samples):
out = random_order_transform(img)
if out == resize_crop_out:
Expand Down
Loading

0 comments on commit b030e93

Please sign in to comment.