diff --git a/gallery/plot_optical_flow.py b/gallery/plot_optical_flow.py index b0a93209877..835ce330180 100644 --- a/gallery/plot_optical_flow.py +++ b/gallery/plot_optical_flow.py @@ -81,9 +81,10 @@ def plot(imgs, **imshow_kwargs): ######################### # The RAFT model accepts RGB images. We first get the frames from -# :func:`~torchvision.io.read_video` and resize them to ensure their -# dimensions are divisible by 8. Then we use the transforms bundled into the -# weights in order to preprocess the input and rescale its values to the +# :func:`~torchvision.io.read_video` and resize them to ensure their dimensions +# are divisible by 8. Note that we explicitly use ``antialias=False``, because +# this is how those models were trained. Then we use the transforms bundled into +# the weights in order to preprocess the input and rescale its values to the # required ``[-1, 1]`` interval. from torchvision.models.optical_flow import Raft_Large_Weights @@ -93,8 +94,8 @@ def plot(imgs, **imshow_kwargs): def preprocess(img1_batch, img2_batch): - img1_batch = F.resize(img1_batch, size=[520, 960]) - img2_batch = F.resize(img2_batch, size=[520, 960]) + img1_batch = F.resize(img1_batch, size=[520, 960], antialias=False) + img2_batch = F.resize(img2_batch, size=[520, 960], antialias=False) return transforms(img1_batch, img2_batch) diff --git a/references/depth/stereo/transforms.py b/references/depth/stereo/transforms.py index f9e05febabd..9c4a6bab6d3 100644 --- a/references/depth/stereo/transforms.py +++ b/references/depth/stereo/transforms.py @@ -455,7 +455,11 @@ def forward( INTERP_MODE = self._interpolation_mode_strategy() for img in images: - resized_images += (F.resize(img, self.resize_size, interpolation=INTERP_MODE),) + # We hard-code antialias=False to preserve results after we changed + # its default from None to True (see + # https://github.com/pytorch/vision/pull/7160) + # TODO: we could re-train the stereo models with antialias=True? + resized_images += (F.resize(img, self.resize_size, interpolation=INTERP_MODE, antialias=False),) for dsp in disparities: if dsp is not None: diff --git a/references/optical_flow/transforms.py b/references/optical_flow/transforms.py index 1ca3ca2a872..bc831a2ee52 100644 --- a/references/optical_flow/transforms.py +++ b/references/optical_flow/transforms.py @@ -196,8 +196,12 @@ def forward(self, img1, img2, flow, valid_flow_mask): if torch.rand(1).item() < self.resize_prob: # rescale the images - img1 = F.resize(img1, size=(new_h, new_w)) - img2 = F.resize(img2, size=(new_h, new_w)) + # We hard-code antialias=False to preserve results after we changed + # its default from None to True (see + # https://github.com/pytorch/vision/pull/7160) + # TODO: we could re-train the OF models with antialias=True? 
+ img1 = F.resize(img1, size=(new_h, new_w), antialias=False) + img2 = F.resize(img2, size=(new_h, new_w), antialias=False) if valid_flow_mask is None: flow = F.resize(flow, size=(new_h, new_w)) flow = flow * torch.tensor([scale_x, scale_y])[:, None, None] diff --git a/references/video_classification/presets.py b/references/video_classification/presets.py index ef774052257..f73802c9666 100644 --- a/references/video_classification/presets.py +++ b/references/video_classification/presets.py @@ -15,7 +15,11 @@ def __init__( ): trans = [ transforms.ConvertImageDtype(torch.float32), - transforms.Resize(resize_size), + # We hard-code antialias=False to preserve results after we changed + # its default from None to True (see + # https://github.com/pytorch/vision/pull/7160) + # TODO: we could re-train the video models with antialias=True? + transforms.Resize(resize_size, antialias=False), ] if hflip_prob > 0: trans.append(transforms.RandomHorizontalFlip(hflip_prob)) @@ -31,7 +35,11 @@ def __init__(self, *, crop_size, resize_size, mean=(0.43216, 0.394666, 0.37645), self.transforms = transforms.Compose( [ transforms.ConvertImageDtype(torch.float32), - transforms.Resize(resize_size), + # We hard-code antialias=False to preserve results after we changed + # its default from None to True (see + # https://github.com/pytorch/vision/pull/7160) + # TODO: we could re-train the video models with antialias=True? + transforms.Resize(resize_size, antialias=False), transforms.Normalize(mean=mean, std=std), transforms.CenterCrop(crop_size), ConvertBCHWtoCBHW(), diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index fb9838ec2e5..58ba98bdf74 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -2,6 +2,7 @@ import itertools import math import os +import warnings from functools import partial from typing import Sequence @@ -483,8 +484,8 @@ def test_resize(device, dt, size, max_size, interpolation): tensor = tensor.to(dt) batch_tensors = batch_tensors.to(dt) - resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, max_size=max_size) - resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation, max_size=max_size) + resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, max_size=max_size, antialias=True) + resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation, max_size=max_size, antialias=True) assert resized_tensor.size()[1:] == resized_pil_img.size[::-1] @@ -509,10 +510,12 @@ def test_resize(device, dt, size, max_size, interpolation): else: script_size = size - resize_result = script_fn(tensor, size=script_size, interpolation=interpolation, max_size=max_size) + resize_result = script_fn(tensor, size=script_size, interpolation=interpolation, max_size=max_size, antialias=True) assert_equal(resized_tensor, resize_result) - _test_fn_on_batch(batch_tensors, F.resize, size=script_size, interpolation=interpolation, max_size=max_size) + _test_fn_on_batch( + batch_tensors, F.resize, size=script_size, interpolation=interpolation, max_size=max_size, antialias=True + ) @pytest.mark.parametrize("device", cpu_and_gpu()) @@ -547,7 +550,7 @@ def test_resize_antialias(device, dt, size, interpolation): tensor = tensor.to(dt) resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, antialias=True) - resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation) + resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation, antialias=True) assert 
resized_tensor.size()[1:] == resized_pil_img.size[::-1] @@ -596,6 +599,23 @@ def test_assert_resize_antialias(interpolation): F.resize(tensor, size=(5, 5), interpolation=interpolation, antialias=True) +def test_resize_antialias_default_warning(): + + img = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8) + + match = "The default value of the antialias" + with pytest.warns(UserWarning, match=match): + F.resize(img, size=(20, 20)) + with pytest.warns(UserWarning, match=match): + F.resized_crop(img, 0, 0, 10, 10, size=(20, 20)) + + # For modes that aren't bicubic or bilinear, don't throw a warning + with warnings.catch_warnings(): + warnings.simplefilter("error") + F.resize(img, size=(20, 20), interpolation=NEAREST) + F.resized_crop(img, 0, 0, 10, 10, size=(20, 20), interpolation=NEAREST) + + @pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("dt", [torch.float32, torch.float64, torch.float16]) @pytest.mark.parametrize("size", [[10, 7], [10, 42], [42, 7]]) @@ -924,7 +944,9 @@ def test_resized_crop(device, mode): # 1) resize to the same size, crop to the same size => should be identity tensor, _ = _create_data(26, 36, device=device) - out_tensor = F.resized_crop(tensor, top=0, left=0, height=26, width=36, size=[26, 36], interpolation=mode) + out_tensor = F.resized_crop( + tensor, top=0, left=0, height=26, width=36, size=[26, 36], interpolation=mode, antialias=True + ) assert_equal(tensor, out_tensor, msg=f"{out_tensor[0, :5, :5]} vs {tensor[0, :5, :5]}") # 2) resize by half and crop a TL corner @@ -939,7 +961,14 @@ def test_resized_crop(device, mode): batch_tensors = _create_data_batch(26, 36, num_samples=4, device=device) _test_fn_on_batch( - batch_tensors, F.resized_crop, top=1, left=2, height=20, width=30, size=[10, 15], interpolation=NEAREST + batch_tensors, + F.resized_crop, + top=1, + left=2, + height=20, + width=30, + size=[10, 15], + interpolation=NEAREST, ) diff --git a/test/test_models.py b/test/test_models.py index 97494d64971..e1a288f4eb5 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -1050,5 +1050,25 @@ def test_raft(model_fn, scripted): _assert_expected(flow_pred.cpu(), name=model_fn.__name__, atol=1e-2, rtol=1) +def test_presets_antialias(): + + img = torch.randint(0, 256, size=(1, 3, 224, 224), dtype=torch.uint8) + + match = "The default value of the antialias parameter" + with pytest.warns(UserWarning, match=match): + models.ResNet18_Weights.DEFAULT.transforms()(img) + with pytest.warns(UserWarning, match=match): + models.segmentation.DeepLabV3_ResNet50_Weights.DEFAULT.transforms()(img) + + with warnings.catch_warnings(): + warnings.simplefilter("error") + models.ResNet18_Weights.DEFAULT.transforms(antialias=True)(img) + models.segmentation.DeepLabV3_ResNet50_Weights.DEFAULT.transforms(antialias=True)(img) + + models.detection.FasterRCNN_ResNet50_FPN_Weights.DEFAULT.transforms()(img) + models.video.R3D_18_Weights.DEFAULT.transforms()(img) + models.optical_flow.Raft_Small_Weights.DEFAULT.transforms()(img, img) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 046550209b0..0ed51c44d77 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -1,5 +1,6 @@ import itertools import re +import warnings from collections import defaultdict import numpy as np @@ -94,7 +95,7 @@ def parametrize_from_transforms(*transforms): class TestSmoke: @parametrize_from_transforms( transforms.RandomErasing(p=1.0), - 
transforms.Resize([16, 16]), + transforms.Resize([16, 16], antialias=True), transforms.CenterCrop([16, 16]), transforms.ConvertDtype(), transforms.RandomHorizontalFlip(), @@ -210,7 +211,7 @@ def test_normalize(self, transform, input): @parametrize( [ ( - transforms.RandomResizedCrop([16, 16]), + transforms.RandomResizedCrop([16, 16], antialias=True), itertools.chain( make_images(extra_dims=[(4,)]), make_vanilla_tensor_images(), @@ -1991,6 +1992,70 @@ def test__transform(self, inpt): assert output.dtype == inpt.dtype +# TODO: remove this test in 0.17 when the default of antialias changes to True +def test_antialias_warning(): + pil_img = PIL.Image.new("RGB", size=(10, 10), color=127) + tensor_img = torch.randint(0, 256, size=(3, 10, 10), dtype=torch.uint8) + tensor_video = torch.randint(0, 256, size=(2, 3, 10, 10), dtype=torch.uint8) + + match = "The default value of the antialias parameter" + with pytest.warns(UserWarning, match=match): + transforms.Resize((20, 20))(tensor_img) + with pytest.warns(UserWarning, match=match): + transforms.RandomResizedCrop((20, 20))(tensor_img) + with pytest.warns(UserWarning, match=match): + transforms.ScaleJitter((20, 20))(tensor_img) + with pytest.warns(UserWarning, match=match): + transforms.RandomShortestSize((20, 20))(tensor_img) + with pytest.warns(UserWarning, match=match): + transforms.RandomResize(10, 20)(tensor_img) + + with pytest.warns(UserWarning, match=match): + transforms.functional.resize(tensor_img, (20, 20)) + with pytest.warns(UserWarning, match=match): + transforms.functional.resize_image_tensor(tensor_img, (20, 20)) + + with pytest.warns(UserWarning, match=match): + transforms.functional.resize(tensor_video, (20, 20)) + with pytest.warns(UserWarning, match=match): + transforms.functional.resize_video(tensor_video, (20, 20)) + + with pytest.warns(UserWarning, match=match): + datapoints.Image(tensor_img).resize((20, 20)) + with pytest.warns(UserWarning, match=match): + datapoints.Image(tensor_img).resized_crop(0, 0, 10, 10, (20, 20)) + + with pytest.warns(UserWarning, match=match): + datapoints.Video(tensor_video).resize((20, 20)) + with pytest.warns(UserWarning, match=match): + datapoints.Video(tensor_video).resized_crop(0, 0, 10, 10, (20, 20)) + + with warnings.catch_warnings(): + warnings.simplefilter("error") + transforms.Resize((20, 20))(pil_img) + transforms.RandomResizedCrop((20, 20))(pil_img) + transforms.ScaleJitter((20, 20))(pil_img) + transforms.RandomShortestSize((20, 20))(pil_img) + transforms.RandomResize(10, 20)(pil_img) + transforms.functional.resize(pil_img, (20, 20)) + + transforms.Resize((20, 20), antialias=True)(tensor_img) + transforms.RandomResizedCrop((20, 20), antialias=True)(tensor_img) + transforms.ScaleJitter((20, 20), antialias=True)(tensor_img) + transforms.RandomShortestSize((20, 20), antialias=True)(tensor_img) + transforms.RandomResize(10, 20, antialias=True)(tensor_img) + + transforms.functional.resize(tensor_img, (20, 20), antialias=True) + transforms.functional.resize_image_tensor(tensor_img, (20, 20), antialias=True) + transforms.functional.resize(tensor_video, (20, 20), antialias=True) + transforms.functional.resize_video(tensor_video, (20, 20), antialias=True) + + datapoints.Image(tensor_img).resize((20, 20), antialias=True) + datapoints.Image(tensor_img).resized_crop(0, 0, 10, 10, (20, 20), antialias=True) + datapoints.Video(tensor_video).resize((20, 20), antialias=True) + datapoints.Video(tensor_video).resized_crop(0, 0, 10, 10, (20, 20), antialias=True) + + @pytest.mark.parametrize("image_type", 
(PIL.Image, torch.Tensor, datapoints.Image)) @pytest.mark.parametrize("label_type", (torch.Tensor, int)) @pytest.mark.parametrize("dataset_return_type", (dict, tuple)) diff --git a/test/test_transforms.py b/test/test_transforms.py index 214f2963bfe..a9074909cf0 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -2,6 +2,7 @@ import os import random import re +import warnings from functools import partial import numpy as np @@ -319,7 +320,7 @@ def test_randomresized_params(): scale_range = (scale_min, scale_min + round(random.random(), 2)) aspect_min = max(round(random.random(), 2), epsilon) aspect_ratio_range = (aspect_min, aspect_min + round(random.random(), 2)) - randresizecrop = transforms.RandomResizedCrop(size, scale_range, aspect_ratio_range) + randresizecrop = transforms.RandomResizedCrop(size, scale_range, aspect_ratio_range, antialias=True) i, j, h, w = randresizecrop.get_params(img, scale_range, aspect_ratio_range) aspect_ratio_obtained = w / h assert ( @@ -366,7 +367,7 @@ def test_randomresized_params(): def test_resize(height, width, osize, max_size): img = Image.new("RGB", size=(width, height), color=127) - t = transforms.Resize(osize, max_size=max_size) + t = transforms.Resize(osize, max_size=max_size, antialias=True) result = t(img) msg = f"{height}, {width} - {osize} - {max_size}" @@ -424,7 +425,7 @@ def test_resize_sequence_output(height, width, osize): img = Image.new("RGB", size=(width, height), color=127) oheight, owidth = osize - t = transforms.Resize(osize) + t = transforms.Resize(osize, antialias=True) result = t(img) assert (owidth, oheight) == result.size @@ -439,6 +440,16 @@ def test_resize_antialias_error(): t(img) +def test_resize_antialias_default_warning(): + + img = Image.new("RGB", size=(10, 10), color=127) + # We make sure we don't warn for PIL images since the default behaviour doesn't change + with warnings.catch_warnings(): + warnings.simplefilter("error") + transforms.Resize((20, 20))(img) + transforms.RandomResizedCrop((20, 20))(img) + + @pytest.mark.parametrize("height, width", ((32, 64), (64, 32))) def test_resize_size_equals_small_edge_size(height, width): # Non-regression test for https://github.com/pytorch/vision/issues/5405 @@ -447,7 +458,7 @@ def test_resize_size_equals_small_edge_size(height, width): img = Image.new("RGB", size=(width, height), color=127) small_edge = min(height, width) - t = transforms.Resize(small_edge, max_size=max_size) + t = transforms.Resize(small_edge, max_size=max_size, antialias=True) result = t(img) assert max(result.size) == max_size @@ -1424,11 +1435,11 @@ def test_random_choice(proba_passthrough, seed): def test_random_order(): random_state = random.getstate() random.seed(42) - random_order_transform = transforms.RandomOrder([transforms.Resize(20), transforms.CenterCrop(10)]) + random_order_transform = transforms.RandomOrder([transforms.Resize(20, antialias=True), transforms.CenterCrop(10)]) img = transforms.ToPILImage()(torch.rand(3, 25, 25)) num_samples = 250 num_normal_order = 0 - resize_crop_out = transforms.CenterCrop(10)(transforms.Resize(20)(img)) + resize_crop_out = transforms.CenterCrop(10)(transforms.Resize(20, antialias=True)(img)) for _ in range(num_samples): out = random_order_transform(img) if out == resize_crop_out: diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index 1a1de659a76..b58e2420338 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -1,5 +1,6 @@ import os import sys +import warnings import numpy as np import 
pytest @@ -371,7 +372,7 @@ class TestResize: def test_resize_int(self, size): # TODO: Minimal check for bug-fix, improve this later x = torch.rand(3, 32, 46) - t = T.Resize(size=size) + t = T.Resize(size=size, antialias=True) y = t(x) # If size is an int, smaller edge of the image will be matched to this number. # i.e, if height > width, then image will be rescaled to (size * height / width, size). @@ -394,13 +395,13 @@ def test_resize_scripted(self, dt, size, max_size, interpolation, device): if max_size is not None and len(size) != 1: pytest.skip("Size should be an int or a sequence of length 1 if max_size is specified") - transform = T.Resize(size=size, interpolation=interpolation, max_size=max_size) + transform = T.Resize(size=size, interpolation=interpolation, max_size=max_size, antialias=True) s_transform = torch.jit.script(transform) _test_transform_vs_scripted(transform, s_transform, tensor) _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) def test_resize_save_load(self, tmpdir): - fn = T.Resize(size=[32]) + fn = T.Resize(size=[32], antialias=True) _test_fn_save_load(fn, tmpdir) @pytest.mark.parametrize("device", cpu_and_gpu()) @@ -424,9 +425,25 @@ def test_resized_crop(self, scale, ratio, size, interpolation, antialias, device _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) def test_resized_crop_save_load(self, tmpdir): - fn = T.RandomResizedCrop(size=[32]) + fn = T.RandomResizedCrop(size=[32], antialias=True) _test_fn_save_load(fn, tmpdir) + def test_antialias_default_warning(self): + + img = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8) + + match = "The default value of the antialias" + with pytest.warns(UserWarning, match=match): + T.Resize((20, 20))(img) + with pytest.warns(UserWarning, match=match): + T.RandomResizedCrop((20, 20))(img) + + # For modes that aren't bicubic or bilinear, don't throw a warning + with warnings.catch_warnings(): + warnings.simplefilter("error") + T.Resize((20, 20), interpolation=NEAREST)(img) + T.RandomResizedCrop((20, 20), interpolation=NEAREST)(img) + def _test_random_affine_helper(device, **kwargs): tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device) diff --git a/torchvision/prototype/datapoints/_bounding_box.py b/torchvision/prototype/datapoints/_bounding_box.py index 1abcb395945..718c3c2ade8 100644 --- a/torchvision/prototype/datapoints/_bounding_box.py +++ b/torchvision/prototype/datapoints/_bounding_box.py @@ -78,7 +78,7 @@ def resize( # type: ignore[override] size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR, max_size: Optional[int] = None, - antialias: Optional[bool] = None, + antialias: Optional[Union[str, bool]] = "warn", ) -> BoundingBox: output, spatial_size = self._F.resize_bounding_box( self.as_subclass(torch.Tensor), spatial_size=self.spatial_size, size=size, max_size=max_size @@ -105,7 +105,7 @@ def resized_crop( width: int, size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR, - antialias: Optional[bool] = None, + antialias: Optional[Union[str, bool]] = "warn", ) -> BoundingBox: output, spatial_size = self._F.resized_crop_bounding_box( self.as_subclass(torch.Tensor), self.format, top, left, height, width, size=size diff --git a/torchvision/prototype/datapoints/_datapoint.py b/torchvision/prototype/datapoints/_datapoint.py index 89c08a86477..d75a2211071 100644 --- a/torchvision/prototype/datapoints/_datapoint.py +++ b/torchvision/prototype/datapoints/_datapoint.py @@ -145,7 +145,7 
@@ def resize( # type: ignore[override] size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR, max_size: Optional[int] = None, - antialias: Optional[bool] = None, + antialias: Optional[Union[str, bool]] = "warn", ) -> Datapoint: return self @@ -163,7 +163,7 @@ def resized_crop( width: int, size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR, - antialias: Optional[bool] = None, + antialias: Optional[Union[str, bool]] = "warn", ) -> Datapoint: return self diff --git a/torchvision/prototype/datapoints/_image.py b/torchvision/prototype/datapoints/_image.py index e999d8243e3..bbd06de707a 100644 --- a/torchvision/prototype/datapoints/_image.py +++ b/torchvision/prototype/datapoints/_image.py @@ -64,7 +64,7 @@ def resize( # type: ignore[override] size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR, max_size: Optional[int] = None, - antialias: Optional[bool] = None, + antialias: Optional[Union[str, bool]] = "warn", ) -> Image: output = self._F.resize_image_tensor( self.as_subclass(torch.Tensor), size, interpolation=interpolation, max_size=max_size, antialias=antialias @@ -87,7 +87,7 @@ def resized_crop( width: int, size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR, - antialias: Optional[bool] = None, + antialias: Optional[Union[str, bool]] = "warn", ) -> Image: output = self._F.resized_crop_image_tensor( self.as_subclass(torch.Tensor), diff --git a/torchvision/prototype/datapoints/_mask.py b/torchvision/prototype/datapoints/_mask.py index 55476cd503d..dec26f80af1 100644 --- a/torchvision/prototype/datapoints/_mask.py +++ b/torchvision/prototype/datapoints/_mask.py @@ -55,7 +55,7 @@ def resize( # type: ignore[override] size: List[int], interpolation: InterpolationMode = InterpolationMode.NEAREST, max_size: Optional[int] = None, - antialias: Optional[bool] = None, + antialias: Optional[Union[str, bool]] = "warn", ) -> Mask: output = self._F.resize_mask(self.as_subclass(torch.Tensor), size, max_size=max_size) return Mask.wrap_like(self, output) @@ -76,7 +76,7 @@ def resized_crop( width: int, size: List[int], interpolation: InterpolationMode = InterpolationMode.NEAREST, - antialias: Optional[bool] = None, + antialias: Optional[Union[str, bool]] = "warn", ) -> Mask: output = self._F.resized_crop_mask(self.as_subclass(torch.Tensor), top, left, height, width, size=size) return Mask.wrap_like(self, output) diff --git a/torchvision/prototype/datapoints/_video.py b/torchvision/prototype/datapoints/_video.py index 5cc8370cd7b..2f628f2efc4 100644 --- a/torchvision/prototype/datapoints/_video.py +++ b/torchvision/prototype/datapoints/_video.py @@ -59,7 +59,7 @@ def resize( # type: ignore[override] size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR, max_size: Optional[int] = None, - antialias: Optional[bool] = None, + antialias: Optional[Union[str, bool]] = "warn", ) -> Video: output = self._F.resize_video( self.as_subclass(torch.Tensor), @@ -86,7 +86,7 @@ def resized_crop( width: int, size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR, - antialias: Optional[bool] = None, + antialias: Optional[Union[str, bool]] = "warn", ) -> Video: output = self._F.resized_crop_video( self.as_subclass(torch.Tensor), diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 70ae972d9e2..b8c8d10ae1d 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ 
-47,7 +47,7 @@ def __init__( size: Union[int, Sequence[int]], interpolation: InterpolationMode = InterpolationMode.BILINEAR, max_size: Optional[int] = None, - antialias: Optional[bool] = None, + antialias: Optional[Union[str, bool]] = "warn", ) -> None: super().__init__() @@ -95,7 +95,7 @@ def __init__( scale: Tuple[float, float] = (0.08, 1.0), ratio: Tuple[float, float] = (3.0 / 4.0, 4.0 / 3.0), interpolation: InterpolationMode = InterpolationMode.BILINEAR, - antialias: Optional[bool] = None, + antialias: Optional[Union[str, bool]] = "warn", ) -> None: super().__init__() self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") @@ -761,7 +761,7 @@ def __init__( target_size: Tuple[int, int], scale_range: Tuple[float, float] = (0.1, 2.0), interpolation: InterpolationMode = InterpolationMode.BILINEAR, - antialias: Optional[bool] = None, + antialias: Optional[Union[str, bool]] = "warn", ): super().__init__() self.target_size = target_size @@ -789,7 +789,7 @@ def __init__( min_size: Union[List[int], Tuple[int], int], max_size: Optional[int] = None, interpolation: InterpolationMode = InterpolationMode.BILINEAR, - antialias: Optional[bool] = None, + antialias: Optional[Union[str, bool]] = "warn", ): super().__init__() self.min_size = [min_size] if isinstance(min_size, int) else list(min_size) @@ -936,7 +936,7 @@ def __init__( min_size: int, max_size: int, interpolation: InterpolationMode = InterpolationMode.BILINEAR, - antialias: Optional[bool] = None, + antialias: Optional[Union[str, bool]] = "warn", ) -> None: super().__init__() self.min_size = min_size diff --git a/torchvision/prototype/transforms/_presets.py b/torchvision/prototype/transforms/_presets.py index a6980f3e135..86300b0494b 100644 --- a/torchvision/prototype/transforms/_presets.py +++ b/torchvision/prototype/transforms/_presets.py @@ -41,10 +41,14 @@ def __init__( def forward(self, left_image: Tensor, right_image: Tensor) -> Tuple[Tensor, Tensor]: def _process_image(img: PIL.Image.Image) -> Tensor: - if self.resize_size is not None: - img = F.resize(img, self.resize_size, interpolation=self.interpolation) if not isinstance(img, Tensor): img = F.pil_to_tensor(img) + if self.resize_size is not None: + # We hard-code antialias=False to preserve results after we changed + # its default from None to True (see + # https://github.com/pytorch/vision/pull/7160) + # TODO: we could re-train the stereo models with antialias=True? 
+ img = F.resize(img, self.resize_size, interpolation=self.interpolation, antialias=False) if self.use_gray_scale is True: img = F.rgb_to_grayscale(img) img = F.convert_image_dtype(img, torch.float) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index aa16dc0afed..c7e80cb417f 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -10,6 +10,7 @@ from torchvision.prototype import datapoints from torchvision.transforms import functional_pil as _FP from torchvision.transforms.functional import ( + _check_antialias, _compute_resized_output_size as __compute_resized_output_size, _get_perspective_coeffs, InterpolationMode, @@ -143,14 +144,18 @@ def resize_image_tensor( size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR, max_size: Optional[int] = None, - antialias: Optional[bool] = None, + antialias: Optional[Union[str, bool]] = "warn", ) -> torch.Tensor: + antialias = _check_antialias(img=image, antialias=antialias, interpolation=interpolation) + assert not isinstance(antialias, str) antialias = False if antialias is None else antialias align_corners: Optional[bool] = None if interpolation == InterpolationMode.BILINEAR or interpolation == InterpolationMode.BICUBIC: align_corners = False - elif antialias: - raise ValueError("Antialias option is supported for bilinear and bicubic interpolation modes only") + else: + # The default of antialias should be True from 0.17, so we don't warn or + # error if other interpolation modes are used. This is documented. + antialias = False shape = image.shape num_channels, old_height, old_width = shape[-3:] @@ -225,7 +230,7 @@ def resize_video( size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR, max_size: Optional[int] = None, - antialias: Optional[bool] = None, + antialias: Optional[Union[str, bool]] = "warn", ) -> torch.Tensor: return resize_image_tensor(video, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias) @@ -235,7 +240,7 @@ def resize( size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR, max_size: Optional[int] = None, - antialias: Optional[bool] = None, + antialias: Optional[Union[str, bool]] = "warn", ) -> datapoints.InputTypeJIT: if not torch.jit.is_scripting(): _log_api_usage_once(resize) @@ -1761,7 +1766,7 @@ def resized_crop_image_tensor( width: int, size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR, - antialias: Optional[bool] = None, + antialias: Optional[Union[str, bool]] = "warn", ) -> torch.Tensor: image = crop_image_tensor(image, top, left, height, width) return resize_image_tensor(image, size, interpolation=interpolation, antialias=antialias) @@ -1814,7 +1819,7 @@ def resized_crop_video( width: int, size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR, - antialias: Optional[bool] = None, + antialias: Optional[Union[str, bool]] = "warn", ) -> torch.Tensor: return resized_crop_image_tensor( video, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation @@ -1829,7 +1834,7 @@ def resized_crop( width: int, size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR, - antialias: Optional[bool] = None, + antialias: Optional[Union[str, bool]] = "warn", ) -> datapoints.InputTypeJIT: if not torch.jit.is_scripting(): _log_api_usage_once(resized_crop) diff --git 
a/torchvision/transforms/_presets.py b/torchvision/transforms/_presets.py index 33b94d01c9d..ccbe425f2ac 100644 --- a/torchvision/transforms/_presets.py +++ b/torchvision/transforms/_presets.py @@ -2,7 +2,7 @@ This file is part of the private API. Please do not use directly these classes as they will be modified on future versions without warning. The classes should be accessed only via the transforms argument of Weights. """ -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import torch from torch import nn, Tensor @@ -44,6 +44,7 @@ def __init__( mean: Tuple[float, ...] = (0.485, 0.456, 0.406), std: Tuple[float, ...] = (0.229, 0.224, 0.225), interpolation: InterpolationMode = InterpolationMode.BILINEAR, + antialias: Optional[Union[str, bool]] = "warn", ) -> None: super().__init__() self.crop_size = [crop_size] @@ -51,9 +52,10 @@ def __init__( self.mean = list(mean) self.std = list(std) self.interpolation = interpolation + self.antialias = antialias def forward(self, img: Tensor) -> Tensor: - img = F.resize(img, self.resize_size, interpolation=self.interpolation) + img = F.resize(img, self.resize_size, interpolation=self.interpolation, antialias=self.antialias) img = F.center_crop(img, self.crop_size) if not isinstance(img, Tensor): img = F.pil_to_tensor(img) @@ -105,7 +107,11 @@ def forward(self, vid: Tensor) -> Tensor: N, T, C, H, W = vid.shape vid = vid.view(-1, C, H, W) - vid = F.resize(vid, self.resize_size, interpolation=self.interpolation) + # We hard-code antialias=False to preserve results after we changed + # its default from None to True (see + # https://github.com/pytorch/vision/pull/7160) + # TODO: we could re-train the video models with antialias=True? + vid = F.resize(vid, self.resize_size, interpolation=self.interpolation, antialias=False) vid = F.center_crop(vid, self.crop_size) vid = F.convert_image_dtype(vid, torch.float) vid = F.normalize(vid, mean=self.mean, std=self.std) @@ -145,16 +151,18 @@ def __init__( mean: Tuple[float, ...] = (0.485, 0.456, 0.406), std: Tuple[float, ...] = (0.229, 0.224, 0.225), interpolation: InterpolationMode = InterpolationMode.BILINEAR, + antialias: Optional[Union[str, bool]] = "warn", ) -> None: super().__init__() self.resize_size = [resize_size] if resize_size is not None else None self.mean = list(mean) self.std = list(std) self.interpolation = interpolation + self.antialias = antialias def forward(self, img: Tensor) -> Tensor: if isinstance(self.resize_size, list): - img = F.resize(img, self.resize_size, interpolation=self.interpolation) + img = F.resize(img, self.resize_size, interpolation=self.interpolation, antialias=self.antialias) if not isinstance(img, Tensor): img = F.pil_to_tensor(img) img = F.convert_image_dtype(img, torch.float) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index abf827a08c7..76c79df93d1 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -393,7 +393,7 @@ def resize( size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR, max_size: Optional[int] = None, - antialias: Optional[bool] = None, + antialias: Optional[Union[str, bool]] = "warn", ) -> Tensor: r"""Resize the input image to the given size. If the image is torch Tensor, it is expected @@ -429,10 +429,24 @@ def resize( smaller edge may be shorter than ``size``. This is only supported if ``size`` is an int (or a sequence of length 1 in torchscript mode). - antialias (bool, optional): antialias flag. 
If ``img`` is PIL Image, the flag is ignored and anti-alias
-            is always used. If ``img`` is Tensor, the flag is False by default and can be set to True for
-            ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` modes.
-            This can help making the output for PIL images and tensors closer.
+        antialias (bool, optional): Whether to apply antialiasing.
+            It only affects **tensors** with bilinear or bicubic modes and it is
+            ignored otherwise: on PIL images, antialiasing is always applied on
+            bilinear or bicubic modes; on other modes (for PIL images and
+            tensors), antialiasing makes no sense and this parameter is ignored.
+            Possible values are:
+
+            - ``True``: will apply antialiasing for bilinear or bicubic modes.
+              Other modes aren't affected. This is probably what you want to use.
+            - ``False``: will not apply antialiasing for tensors on any mode. PIL
+              images are still antialiased on bilinear or bicubic modes, because
+              PIL doesn't support disabling antialiasing.
+            - ``None``: equivalent to ``False`` for tensors and ``True`` for
+              PIL images. This value exists for legacy reasons and you probably
+              don't want to use it unless you really know what you are doing.
+
+            The current default is ``None`` **but will change to** ``True`` **in
+            v0.17** for the PIL and Tensor backends to be consistent.
 
     Returns:
         PIL Image or Tensor: Resized image.
@@ -462,6 +476,8 @@ def resize(
     if (image_height, image_width) == output_size:
         return img
 
+    antialias = _check_antialias(img, antialias, interpolation)
+
     if not isinstance(img, torch.Tensor):
         if antialias is not None and not antialias:
             warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
@@ -594,7 +610,7 @@ def resized_crop(
     width: int,
     size: List[int],
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
-    antialias: Optional[bool] = None,
+    antialias: Optional[Union[str, bool]] = "warn",
 ) -> Tensor:
     """Crop the given image and resize it to desired size.
     If the image is torch Tensor, it is expected
@@ -614,10 +630,24 @@ def resized_crop(
            Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``,
            ``InterpolationMode.NEAREST_EXACT``, ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are
            supported.
-        antialias (bool, optional): antialias flag. If ``img`` is PIL Image, the flag is ignored and anti-alias
-            is always used. If ``img`` is Tensor, the flag is False by default and can be set to True for
-            ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` modes.
-            This can help making the output for PIL images and tensors closer.
+        antialias (bool, optional): Whether to apply antialiasing.
+            It only affects **tensors** with bilinear or bicubic modes and it is
+            ignored otherwise: on PIL images, antialiasing is always applied on
+            bilinear or bicubic modes; on other modes (for PIL images and
+            tensors), antialiasing makes no sense and this parameter is ignored.
+            Possible values are:
+
+            - ``True``: will apply antialiasing for bilinear or bicubic modes.
+              Other modes aren't affected. This is probably what you want to use.
+            - ``False``: will not apply antialiasing for tensors on any mode. PIL
+              images are still antialiased on bilinear or bicubic modes, because
+              PIL doesn't support disabling antialiasing.
+            - ``None``: equivalent to ``False`` for tensors and ``True`` for
+              PIL images. This value exists for legacy reasons and you probably
+              don't want to use it unless you really know what you are doing.
+
+            The current default is ``None`` **but will change to** ``True`` **in
+            v0.17** for the PIL and Tensor backends to be consistent.
 
     Returns:
         PIL Image or Tensor: Cropped image.
     """
@@ -1537,3 +1567,28 @@ def elastic_transform(
     if not isinstance(img, torch.Tensor):
         output = to_pil_image(output, mode=img.mode)
     return output
+
+
+# TODO in v0.17: remove this helper and change default of antialias to True everywhere
+def _check_antialias(
+    img: Tensor, antialias: Optional[Union[str, bool]], interpolation: InterpolationMode
+) -> Optional[bool]:
+    if isinstance(antialias, str):  # it should be "warn", but we don't bother checking against that
+        if isinstance(img, Tensor) and (
+            interpolation == InterpolationMode.BILINEAR or interpolation == InterpolationMode.BICUBIC
+        ):
+            warnings.warn(
+                "The default value of the antialias parameter of all the resizing transforms "
+                "(Resize(), RandomResizedCrop(), etc.) "
+                "will change from None to True in v0.17, "
+                "in order to be consistent across the PIL and Tensor backends. "
+                "To suppress this warning, directly pass "
+                "antialias=True (recommended, future default), antialias=None (current default, "
+                "which means False for Tensors and True for PIL), "
+                "or antialias=False (only works on Tensors - PIL will still use antialiasing). "
+                "This also applies if you are using the inference transforms from the models weights: "
+                "update the call to weights.transforms(antialias=True)."
+            )
+        antialias = None
+
+    return antialias
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index 30414bf1cd6..d0e7c17882b 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -440,6 +440,8 @@ def resize(
     img: Tensor,
     size: List[int],
     interpolation: str = "bilinear",
+    # TODO: in v0.17, change the default to True. This will be a private function
+    # by then, so we don't care about warning here.
     antialias: Optional[bool] = None,
 ) -> Tensor:
     _assert_image_tensor(img)
@@ -451,7 +453,11 @@ def resize(
         antialias = False
 
     if antialias and interpolation not in ["bilinear", "bicubic"]:
-        raise ValueError("Antialias option is supported for bilinear and bicubic interpolation modes only")
+        # We manually set it to False to avoid an error downstream in interpolate()
+        # This behaviour is documented: the parameter is irrelevant for modes
+        # that are not bilinear or bicubic. We used to raise an error here, but
+        # we don't anymore, since True will soon be the default for all modes.
+        antialias = False
 
     img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [torch.float32, torch.float64])
diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py
index e39e04c3478..88cc1c0d978 100644
--- a/torchvision/transforms/transforms.py
+++ b/torchvision/transforms/transforms.py
@@ -306,13 +306,27 @@ class Resize(torch.nn.Module):
             smaller edge may be shorter than ``size``. This is only supported
             if ``size`` is an int (or a sequence of length 1 in torchscript
             mode).
-        antialias (bool, optional): antialias flag. If ``img`` is PIL Image, the flag is ignored and anti-alias
-            is always used. If ``img`` is Tensor, the flag is False by default and can be set to True for
-            ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` modes.
-            This can help making the output for PIL images and tensors closer.
+        antialias (bool, optional): Whether to apply antialiasing.
+            It only affects **tensors** with bilinear or bicubic modes and it is
+            ignored otherwise: on PIL images, antialiasing is always applied on
+            bilinear or bicubic modes; on other modes (for PIL images and
+            tensors), antialiasing makes no sense and this parameter is ignored.
+            Possible values are:
+
+            - ``True``: will apply antialiasing for bilinear or bicubic modes.
+              Other modes aren't affected. This is probably what you want to use.
+            - ``False``: will not apply antialiasing for tensors on any mode. PIL
+              images are still antialiased on bilinear or bicubic modes, because
+              PIL doesn't support disabling antialiasing.
+            - ``None``: equivalent to ``False`` for tensors and ``True`` for
+              PIL images. This value exists for legacy reasons and you probably
+              don't want to use it unless you really know what you are doing.
+
+            The current default is ``None`` **but will change to** ``True`` **in
+            v0.17** for the PIL and Tensor backends to be consistent.
     """
 
-    def __init__(self, size, interpolation=InterpolationMode.BILINEAR, max_size=None, antialias=None):
+    def __init__(self, size, interpolation=InterpolationMode.BILINEAR, max_size=None, antialias="warn"):
         super().__init__()
         _log_api_usage_once(self)
         if not isinstance(size, (int, Sequence)):
@@ -847,10 +861,24 @@ class RandomResizedCrop(torch.nn.Module):
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
            ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
-        antialias (bool, optional): antialias flag. If ``img`` is PIL Image, the flag is ignored and anti-alias
-            is always used. If ``img`` is Tensor, the flag is False by default and can be set to True for
-            ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` modes.
-            This can help making the output for PIL images and tensors closer.
+        antialias (bool, optional): Whether to apply antialiasing.
+            It only affects **tensors** with bilinear or bicubic modes and it is
+            ignored otherwise: on PIL images, antialiasing is always applied on
+            bilinear or bicubic modes; on other modes (for PIL images and
+            tensors), antialiasing makes no sense and this parameter is ignored.
+            Possible values are:
+
+            - ``True``: will apply antialiasing for bilinear or bicubic modes.
+              Other modes aren't affected. This is probably what you want to use.
+            - ``False``: will not apply antialiasing for tensors on any mode. PIL
+              images are still antialiased on bilinear or bicubic modes, because
+              PIL doesn't support disabling antialiasing.
+            - ``None``: equivalent to ``False`` for tensors and ``True`` for
+              PIL images. This value exists for legacy reasons and you probably
+              don't want to use it unless you really know what you are doing.
+
+            The current default is ``None`` **but will change to** ``True`` **in
+            v0.17** for the PIL and Tensor backends to be consistent.
     """
 
     def __init__(
@@ -859,7 +887,7 @@ def __init__(
         scale=(0.08, 1.0),
         ratio=(3.0 / 4.0, 4.0 / 3.0),
         interpolation=InterpolationMode.BILINEAR,
-        antialias: Optional[bool] = None,
+        antialias: Optional[Union[str, bool]] = "warn",
     ):
         super().__init__()
         _log_api_usage_once(self)
@@ -874,7 +902,6 @@ def __init__(
 
         self.interpolation = interpolation
         self.antialias = antialias
-        self.interpolation = interpolation
         self.scale = scale
         self.ratio = ratio
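Usage sketch (not part of the patch): a minimal example of how downstream code can opt in to the new behaviour ahead of v0.17. It mirrors the calls exercised in the tests above; ResNet18_Weights is just one example of a weights enum, and any resizing transform accepts the same parameter.

import torch
from torchvision.models import ResNet18_Weights
from torchvision.transforms import InterpolationMode
from torchvision.transforms import functional as F

img = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8)

# antialias=True is the recommended value and the future (v0.17) default: for
# tensors it matches PIL's bilinear/bicubic behaviour and silences the warning.
out = F.resize(img, size=[20, 20], antialias=True)

# The parameter is irrelevant for nearest interpolation, so no warning is emitted.
out = F.resize(img, size=[20, 20], interpolation=InterpolationMode.NEAREST)

# The bundled inference presets accept antialias too, as the warning message suggests.
preprocess = ResNet18_Weights.DEFAULT.transforms(antialias=True)
batch = preprocess(torch.randint(0, 256, size=(1, 3, 224, 224), dtype=torch.uint8))

Passing antialias=False instead preserves the legacy tensor behaviour, but as documented above, PIL inputs are still antialiased on bilinear and bicubic modes.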