diff --git a/test/common_utils.py b/test/common_utils.py
index 6c987cf9348..c5b39c451a3 100644
--- a/test/common_utils.py
+++ b/test/common_utils.py
@@ -402,3 +402,27 @@ def call_args_to_kwargs_only(call_args, *callable_or_arg_names):
     kwargs_only = kwargs.copy()
     kwargs_only.update(dict(zip(arg_names, args)))
     return kwargs_only
+
+
+def cpu_and_gpu():
+    import pytest  # noqa
+    # ignore CPU tests in RE as they're already covered by another contbuild
+    IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None
+    IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1"
+    CUDA_NOT_AVAILABLE_MSG = 'CUDA device not available'
+
+    devices = [] if IN_RE_WORKER else ['cpu']
+
+    if torch.cuda.is_available():
+        cuda_marks = ()
+    elif IN_FBCODE:
+        # Don't collect CUDA tests on fbcode if the machine doesn't have a GPU.
+        # This avoids skipping the tests. A more robust check might be to detect
+        # whether we're on Sandcastle rather than just in fbcode.
+        cuda_marks = pytest.mark.dont_collect()
+    else:
+        cuda_marks = pytest.mark.skip(reason=CUDA_NOT_AVAILABLE_MSG)
+
+    devices.append(pytest.param('cuda', marks=cuda_marks))
+
+    return devices
diff --git a/test/conftest.py b/test/conftest.py
new file mode 100644
index 00000000000..6e10e4ef071
--- /dev/null
+++ b/test/conftest.py
@@ -0,0 +1,14 @@
+def pytest_configure(config):
+    # register an additional marker (see pytest_collection_modifyitems)
+    config.addinivalue_line(
+        "markers", "dont_collect: marks a test that should not be collected (avoids skipping it)"
+    )
+
+
+def pytest_collection_modifyitems(items):
+    # This hook is called by pytest after it has collected the tests (google its name!).
+    # We can ignore some tests as we see fit here. In particular we ignore the tests that
+    # we have marked with the custom 'dont_collect' mark. This avoids skipping the tests,
+    # since the internal fb infra doesn't like skipping tests.
+    to_keep = [item for item in items if item.get_closest_marker('dont_collect') is None]
+    items[:] = to_keep
diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py
index d2bc4c8a7bc..1964e3134ec 100644
--- a/test/test_functional_tensor.py
+++ b/test/test_functional_tensor.py
@@ -1,17 +1,20 @@
+import itertools
 import os
 import unittest
 import colorsys
 import math
 
 import numpy as np
+import pytest
 import torch
 import torchvision.transforms.functional_tensor as F_t
 import torchvision.transforms.functional_pil as F_pil
 import torchvision.transforms.functional as F
+import torchvision.transforms as T
 from torchvision.transforms import InterpolationMode
 
-from common_utils import TransformsTester
+from common_utils import TransformsTester, cpu_and_gpu
 
 from typing import Dict, List, Sequence, Tuple
 
 
@@ -19,6 +22,13 @@
 NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC
 
 
+@pytest.fixture(scope='module')
+def tester():
+    # instantiation of the Tester class, used for equality assertions and other utilities
+    # TODO: remove this eventually when we don't need the class anymore
+    return Tester()
+
+
 class Tester(TransformsTester):
 
     def setUp(self):
@@ -759,88 +769,6 @@ def test_rotate(self):
             res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
         self.assertTrue(res1.equal(res2))
 
-    def _test_perspective(self, tensor, pil_img, scripted_transform, test_configs):
-        dt = tensor.dtype
-        for f in [None, [0, 0, 0], [1, 2, 3], [255, 255, 255], [1, ], (2.0, )]:
-            for r in [NEAREST, ]:
-                for spoints, epoints in test_configs:
-                    f_pil = int(f[0]) if f is not None and len(f) == 1 else f
-                    out_pil_img = F.perspective(pil_img, startpoints=spoints, endpoints=epoints, interpolation=r,
-                                                fill=f_pil)
-                    out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
-
-                    for fn in [F.perspective, scripted_transform]:
-                        out_tensor = fn(tensor, startpoints=spoints, endpoints=epoints, interpolation=r, fill=f).cpu()
-
-                        if out_tensor.dtype != torch.uint8:
-                            out_tensor = out_tensor.to(torch.uint8)
-
-                        num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
-                        ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
-                        # Tolerance : less than 5% of different pixels
-                        self.assertLess(
-                            ratio_diff_pixels,
-                            0.05,
-                            msg="{}: {}\n{} vs \n{}".format(
-                                (f, r, dt, spoints, epoints),
-                                ratio_diff_pixels,
-                                out_tensor[0, :7, :7],
-                                out_pil_tensor[0, :7, :7]
-                            )
-                        )
-
-    def test_perspective(self):
-
-        from torchvision.transforms import RandomPerspective
-
-        data = [self._create_data(26, 34, device=self.device), self._create_data(26, 26, device=self.device)]
-        scripted_transform = torch.jit.script(F.perspective)
-
-        for tensor, pil_img in data:
-
-            test_configs = [
-                [[[0, 0], [33, 0], [33, 25], [0, 25]], [[3, 2], [32, 3], [30, 24], [2, 25]]],
-                [[[3, 2], [32, 3], [30, 24], [2, 25]], [[0, 0], [33, 0], [33, 25], [0, 25]]],
-                [[[3, 2], [32, 3], [30, 24], [2, 25]], [[5, 5], [30, 3], [33, 19], [4, 25]]],
-            ]
-            n = 10
-            test_configs += [
-                RandomPerspective.get_params(pil_img.size[0], pil_img.size[1], i / n) for i in range(n)
-            ]
-
-            for dt in [None, torch.float32, torch.float64, torch.float16]:
-
-                if dt == torch.float16 and torch.device(self.device).type == "cpu":
-                    # skip float16 on CPU case
-                    continue
-
-                if dt is not None:
-                    tensor = tensor.to(dtype=dt)
-
-                self._test_perspective(tensor, pil_img, scripted_transform, test_configs)
-
-                batch_tensors = self._create_data_batch(26, 36, num_samples=4, device=self.device)
-                if dt is not None:
-                    batch_tensors = batch_tensors.to(dtype=dt)
-
-                # Ignore the equivalence between scripted and regular function on float16 cuda. The pixels at
-                # the border may be entirely different due to small rounding errors.
-                scripted_fn_atol = -1 if (dt == torch.float16 and self.device == "cuda") else 1e-8
-
-                for spoints, epoints in test_configs:
-                    self._test_fn_on_batch(
-                        batch_tensors, F.perspective, scripted_fn_atol=scripted_fn_atol,
-                        startpoints=spoints, endpoints=epoints, interpolation=NEAREST
-                    )
-
-        # assert changed type warning
-        spoints = [[0, 0], [33, 0], [33, 25], [0, 25]]
-        epoints = [[3, 2], [32, 3], [30, 24], [2, 25]]
-        with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
-            res1 = F.perspective(tensor, startpoints=spoints, endpoints=epoints, interpolation=2)
-            res2 = F.perspective(tensor, startpoints=spoints, endpoints=epoints, interpolation=BILINEAR)
-        self.assertTrue(res1.equal(res2))
-
     def test_gaussian_blur(self):
         small_image_tensor = torch.from_numpy(
             np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))
@@ -996,5 +924,99 @@ def test_scale_channel(self):
         self.assertTrue(scaled_cpu.equal(scaled_cuda.to('cpu')))
 
 
+def _get_data_dims_and_points_for_perspective():
+    # Ideally we would parametrize independently over data dims and points, but
+    # we want to test some points that also depend on the data dims.
+    # Pytest doesn't support covariant parametrization, so we do it somewhat manually here.
+
+    data_dims = [(26, 34), (26, 26)]
+    points = [
+        [[[0, 0], [33, 0], [33, 25], [0, 25]], [[3, 2], [32, 3], [30, 24], [2, 25]]],
+        [[[3, 2], [32, 3], [30, 24], [2, 25]], [[0, 0], [33, 0], [33, 25], [0, 25]]],
+        [[[3, 2], [32, 3], [30, 24], [2, 25]], [[5, 5], [30, 3], [33, 19], [4, 25]]],
+    ]
+
+    dims_and_points = list(itertools.product(data_dims, points))
+
+    # Up to here, we could just have used two @parametrize decorators.
+    # Below is the covariant part, as the points depend on the data dims.
+
+    n = 10
+    for dim in data_dims:
+        dims_and_points += [
+            (dim, T.RandomPerspective.get_params(dim[1], dim[0], i / n))
+            for i in range(n)
+        ]
+    return dims_and_points
+
+
+@pytest.mark.parametrize('device', cpu_and_gpu())
+@pytest.mark.parametrize('dims_and_points', _get_data_dims_and_points_for_perspective())
+@pytest.mark.parametrize('dt', [None, torch.float32, torch.float64, torch.float16])
+@pytest.mark.parametrize('fill', (None, [0, 0, 0], [1, 2, 3], [255, 255, 255], [1, ], (2.0, )))
+@pytest.mark.parametrize('fn', [F.perspective, torch.jit.script(F.perspective)])
+def test_perspective_pil_vs_tensor(device, dims_and_points, dt, fill, fn, tester):
+
+    if dt == torch.float16 and device == "cpu":
+        # skip float16 on CPU case
+        return
+
+    data_dims, (spoints, epoints) = dims_and_points
+
+    tensor, pil_img = tester._create_data(*data_dims, device=device)
+    if dt is not None:
+        tensor = tensor.to(dtype=dt)
+
+    interpolation = NEAREST
+    fill_pil = int(fill[0]) if fill is not None and len(fill) == 1 else fill
+    out_pil_img = F.perspective(pil_img, startpoints=spoints, endpoints=epoints, interpolation=interpolation,
+                                fill=fill_pil)
+    out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
+    out_tensor = fn(tensor, startpoints=spoints, endpoints=epoints, interpolation=interpolation, fill=fill).cpu()
+
+    if out_tensor.dtype != torch.uint8:
+        out_tensor = out_tensor.to(torch.uint8)
+
+    num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
+    ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
+    # Tolerance: less than 5% of different pixels
+    assert ratio_diff_pixels < 0.05
+
+
+@pytest.mark.parametrize('device', cpu_and_gpu())
+@pytest.mark.parametrize('dims_and_points', _get_data_dims_and_points_for_perspective())
+@pytest.mark.parametrize('dt', [None, torch.float32, torch.float64, torch.float16])
+def test_perspective_batch(device, dims_and_points, dt, tester):
+
+    if dt == torch.float16 and device == "cpu":
+        # skip float16 on CPU case
+        return
+
+    data_dims, (spoints, epoints) = dims_and_points
+
+    batch_tensors = tester._create_data_batch(*data_dims, num_samples=4, device=device)
+    if dt is not None:
+        batch_tensors = batch_tensors.to(dtype=dt)
+
+    # Ignore the equivalence between scripted and regular function on float16 cuda. The pixels at
+    # the border may be entirely different due to small rounding errors.
+    scripted_fn_atol = -1 if (dt == torch.float16 and device == "cuda") else 1e-8
+    tester._test_fn_on_batch(
+        batch_tensors, F.perspective, scripted_fn_atol=scripted_fn_atol,
+        startpoints=spoints, endpoints=epoints, interpolation=NEAREST
+    )
+
+
+def test_perspective_interpolation_warning(tester):
+    # assert changed type warning
+    spoints = [[0, 0], [33, 0], [33, 25], [0, 25]]
+    epoints = [[3, 2], [32, 3], [30, 24], [2, 25]]
+    tensor = torch.randint(0, 256, (3, 26, 26))
+    with tester.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
+        res1 = F.perspective(tensor, startpoints=spoints, endpoints=epoints, interpolation=2)
+        res2 = F.perspective(tensor, startpoints=spoints, endpoints=epoints, interpolation=BILINEAR)
+    tester.assertTrue(res1.equal(res2))
+
+
 if __name__ == '__main__':
     unittest.main()
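
Usage illustration (not part of the patch above): a test module that imports cpu_and_gpu() would typically parametrize over devices as sketched below. On an fbcode machine without a GPU, the 'cuda' entry carries the dont_collect mark, so the pytest_collection_modifyitems hook in conftest.py drops it at collection time instead of reporting it as skipped; elsewhere it is either run or skipped with the CUDA-not-available reason. The test name test_example_device is hypothetical.

    import pytest
    import torch

    from common_utils import cpu_and_gpu


    @pytest.mark.parametrize('device', cpu_and_gpu())
    def test_example_device(device):
        # 'device' is 'cpu' or 'cuda'; the marks attached by cpu_and_gpu() decide
        # whether the cuda case is run, skipped, or dropped from collection.
        t = torch.ones(3, 4, device=device)
        assert t.device.type == device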