Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added antialias option to transforms.functional.resize #3761

Merged
merged 11 commits into from
May 10, 2021
7 changes: 7 additions & 0 deletions android/ops/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@ file(GLOB VISION_SRCS
../../torchvision/csrc/ops/*.h
../../torchvision/csrc/ops/*.cpp)

# Remove interpolate_aa sources as they are temporary code
# see https://github.com/pytorch/vision/pull/3761
# and IndexingUtils.h is unavailable on Android build
list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../../torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp")
list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../../torchvision/csrc/ops/interpolate_aa.cpp")
list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../../torchvision/csrc/ops/interpolate_aa.h")

add_library(${TARGET} SHARED
${VISION_SRCS}
)
Expand Down
7 changes: 7 additions & 0 deletions ios/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@ file(GLOB VISION_SRCS
../torchvision/csrc/ops/*.h
../torchvision/csrc/ops/*.cpp)

# Remove interpolate_aa sources as they are temporary code
# see https://github.com/pytorch/vision/pull/3761
# and using TensorIterator unavailable with iOS
list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp")
list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../torchvision/csrc/ops/interpolate_aa.cpp")
list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../torchvision/csrc/ops/interpolate_aa.h")

add_library(${TARGET} STATIC
${VISION_SRCS}
)
Expand Down
47 changes: 47 additions & 0 deletions test/test_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1018,5 +1018,52 @@ def test_perspective_interpolation_warning(tester):
tester.assertTrue(res1.equal(res2))


@pytest.mark.parametrize('device', ["cpu", ])
@pytest.mark.parametrize('dt', [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize('size', [[96, 72], [96, 420], [420, 72]])
@pytest.mark.parametrize('interpolation', [BILINEAR, ])
Comment on lines +1021 to +1024
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the port! Just 2 nits: we probably don't need the device and interpolation parametrizations since they only have one element, so I'd suggest to remove them. Should we need them in the future, it will be very easy to add back

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True, I agree but as mentioned in the description, in the follow-up PRs we aim to add more interpolation modes and CUDA support, that's why I kept them.

def test_resize_antialias(device, dt, size, interpolation, tester):

if dt == torch.float16 and device == "cpu":
# skip float16 on CPU case
return

script_fn = torch.jit.script(F.resize)
tensor, pil_img = tester._create_data(320, 290, device=device)

if dt is not None:
# This is a trivial cast to float of uint8 data to test all cases
tensor = tensor.to(dt)

resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, antialias=True)
resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation)

tester.assertEqual(
resized_tensor.size()[1:], resized_pil_img.size[::-1],
msg=f"{size}, {interpolation}, {dt}"
)

resized_tensor_f = resized_tensor
# we need to cast to uint8 to compare with PIL image
if resized_tensor_f.dtype == torch.uint8:
resized_tensor_f = resized_tensor_f.to(torch.float)

tester.approxEqualTensorToPIL(
resized_tensor_f, resized_pil_img, tol=0.5, msg=f"{size}, {interpolation}, {dt}"
)
tester.approxEqualTensorToPIL(
resized_tensor_f, resized_pil_img, tol=1.0 + 1e-5, agg_method="max",
msg=f"{size}, {interpolation}, {dt}"
)

if isinstance(size, int):
script_size = [size, ]
else:
script_size = size

resize_result = script_fn(tensor, size=script_size, interpolation=interpolation, antialias=True)
tester.assertTrue(resized_tensor.equal(resize_result), msg=f"{size}, {interpolation}, {dt}")


if __name__ == '__main__':
unittest.main()
4 changes: 4 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,10 @@ def test_resize(self):

self.assertEqual((owidth, oheight), result.size)

with self.assertWarnsRegex(UserWarning, r"Anti-alias option is always applied for PIL Image input"):
t = transforms.Resize(osize, antialias=False)
t(img)

def test_random_crop(self):
height = random.randint(10, 32) * 2
width = random.randint(10, 32) * 2
Expand Down
Loading