From bc6ef1dd49744364258a462645aad935a2133098 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Mon, 2 Oct 2023 21:52:57 +0200
Subject: [PATCH 1/2] port container transforms

---
 test/test_transforms_v2.py              |  29 ------
 test/test_transforms_v2_consistency.py  |  95 --------------------
 test/test_transforms_v2_refactored.py   | 113 ++++++++++++++++++++++--
 torchvision/transforms/v2/_container.py |  16 ++--
 4 files changed, 116 insertions(+), 137 deletions(-)

diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 03f9e906675..aa11b83b61a 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -122,35 +122,6 @@ def test_check_transformed_types(self, inpt_type, mocker):
         t(inpt)
 
 
-class TestContainers:
-    @pytest.mark.parametrize("transform_cls", [transforms.Compose, transforms.RandomChoice, transforms.RandomOrder])
-    def test_assertions(self, transform_cls):
-        with pytest.raises(TypeError, match="Argument transforms should be a sequence of callables"):
-            transform_cls(transforms.RandomCrop(28))
-
-    @pytest.mark.parametrize("transform_cls", [transforms.Compose, transforms.RandomChoice, transforms.RandomOrder])
-    @pytest.mark.parametrize(
-        "trfms",
-        [
-            [transforms.Pad(2), transforms.RandomCrop(28)],
-            [lambda x: 2.0 * x, transforms.Pad(2), transforms.RandomCrop(28)],
-            [transforms.Pad(2), lambda x: 2.0 * x, transforms.RandomCrop(28)],
-        ],
-    )
-    def test_ctor(self, transform_cls, trfms):
-        c = transform_cls(trfms)
-        inpt = torch.rand(1, 3, 32, 32)
-        output = c(inpt)
-        assert isinstance(output, torch.Tensor)
-        assert output.ndim == 4
-
-
-class TestRandomChoice:
-    def test_assertions(self):
-        with pytest.raises(ValueError, match="Length of p doesn't match the number of transforms"):
-            transforms.RandomChoice([transforms.Pad(2), transforms.RandomCrop(28)], p=[1])
-
-
 class TestRandomIoUCrop:
     @pytest.mark.parametrize("device", cpu_and_cuda())
     @pytest.mark.parametrize("options", [[0.5, 0.9], [2.0]])
diff --git a/test/test_transforms_v2_consistency.py b/test/test_transforms_v2_consistency.py
index 397d42101ce..12e76e89f43 100644
--- a/test/test_transforms_v2_consistency.py
+++ b/test/test_transforms_v2_consistency.py
@@ -11,7 +11,6 @@
 import torch
 import torchvision.transforms.v2 as v2_transforms
 from common_utils import assert_close, assert_equal, set_rng_seed
-from torch import nn
 from torchvision import transforms as legacy_transforms, tv_tensors
 from torchvision._utils import sequence_to_str
 
@@ -82,22 +81,6 @@ def __init__(
         # images given that the transform does nothing but call it anyway.
         supports_pil=False,
     ),
-    ConsistencyConfig(
-        v2_transforms.Compose,
-        legacy_transforms.Compose,
-    ),
-    ConsistencyConfig(
-        v2_transforms.RandomApply,
-        legacy_transforms.RandomApply,
-    ),
-    ConsistencyConfig(
-        v2_transforms.RandomChoice,
-        legacy_transforms.RandomChoice,
-    ),
-    ConsistencyConfig(
-        v2_transforms.RandomOrder,
-        legacy_transforms.RandomOrder,
-    ),
 ]
 
 
@@ -298,84 +281,6 @@ def test_jit_consistency(config, args_kwargs):
     assert_close(output_prototype_scripted, output_legacy_scripted, **config.closeness_kwargs)
 
 
-class TestContainerTransforms:
-    """
-    Since we are testing containers here, we also need some transforms to wrap. Thus, testing a container transform for
-    consistency automatically tests the wrapped transforms consistency.
-
-    Instead of complicated mocking or creating custom transforms just for these tests, here we use deterministic ones
-    that were already tested for consistency above.
-    """
-
-    def test_compose(self):
-        prototype_transform = v2_transforms.Compose(
-            [
-                v2_transforms.Resize(256),
-                v2_transforms.CenterCrop(224),
-            ]
-        )
-        legacy_transform = legacy_transforms.Compose(
-            [
-                legacy_transforms.Resize(256),
-                legacy_transforms.CenterCrop(224),
-            ]
-        )
-
-        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
-        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))
-
-    @pytest.mark.parametrize("p", [0, 0.1, 0.5, 0.9, 1])
-    @pytest.mark.parametrize("sequence_type", [list, nn.ModuleList])
-    def test_random_apply(self, p, sequence_type):
-        prototype_transform = v2_transforms.RandomApply(
-            sequence_type(
-                [
-                    v2_transforms.Resize(256),
-                    v2_transforms.CenterCrop(224),
-                ]
-            ),
-            p=p,
-        )
-        legacy_transform = legacy_transforms.RandomApply(
-            sequence_type(
-                [
-                    legacy_transforms.Resize(256),
-                    legacy_transforms.CenterCrop(224),
-                ]
-            ),
-            p=p,
-        )
-
-        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
-        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))
-
-        if sequence_type is nn.ModuleList:
-            # quick and dirty test that it is jit-scriptable
-            scripted = torch.jit.script(prototype_transform)
-            scripted(torch.rand(1, 3, 300, 300))
-
-    # We can't test other values for `p` since the random parameter generation is different
-    @pytest.mark.parametrize("probabilities", [(0, 1), (1, 0)])
-    def test_random_choice(self, probabilities):
-        prototype_transform = v2_transforms.RandomChoice(
-            [
-                v2_transforms.Resize(256),
-                legacy_transforms.CenterCrop(224),
-            ],
-            p=probabilities,
-        )
-        legacy_transform = legacy_transforms.RandomChoice(
-            [
-                legacy_transforms.Resize(256),
-                legacy_transforms.CenterCrop(224),
-            ],
-            p=probabilities,
-        )
-
-        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
-        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))
-
-
 class TestToTensorTransforms:
     def test_pil_to_tensor(self):
         prototype_transform = v2_transforms.PILToTensor()
diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py
index b700b159ec5..0eafddf590f 100644
--- a/test/test_transforms_v2_refactored.py
+++ b/test/test_transforms_v2_refactored.py
@@ -396,6 +396,8 @@ def check_transform(transform, input, check_v1_compatibility=True, check_sample_
     if check_v1_compatibility:
         _check_transform_v1_compatibility(transform, input, **_to_tolerances(check_v1_compatibility))
 
+    return output
+
 
 def transform_cls_to_functional(transform_cls, **transform_specific_kwargs):
     def wrapper(input, *args, **kwargs):
@@ -1773,7 +1775,7 @@ def test_transform_unknown_fill_error(self):
         transforms.RandomAffine(degrees=0, fill="fill")
 
 
-class TestCompose:
+class TestContainerTransforms:
     class BuiltinTransform(transforms.Transform):
         def _transform(self, inpt, params):
             return inpt
@@ -1788,7 +1790,10 @@ def forward(self, image, label):
             return image, label
 
     @pytest.mark.parametrize(
-        "transform_clss",
+        "transform_cls", [transforms.Compose, functools.partial(transforms.RandomApply, p=1), transforms.RandomOrder]
+    )
+    @pytest.mark.parametrize(
+        "wrapped_transform_clss",
         [
             [BuiltinTransform],
             [PackedInputTransform],
@@ -1803,12 +1808,12 @@ def forward(self, image, label):
         ],
     )
     @pytest.mark.parametrize("unpack", [True, False])
-    def test_packed_unpacked(self, transform_clss, unpack):
-        needs_packed_inputs = any(issubclass(cls, self.PackedInputTransform) for cls in transform_clss)
-        needs_unpacked_inputs = any(issubclass(cls, self.UnpackedInputTransform) for cls in transform_clss)
+    def test_packed_unpacked(self, transform_cls, wrapped_transform_clss, unpack):
+        needs_packed_inputs = any(issubclass(cls, self.PackedInputTransform) for cls in wrapped_transform_clss)
+        needs_unpacked_inputs = any(issubclass(cls, self.UnpackedInputTransform) for cls in wrapped_transform_clss)
         assert not (needs_packed_inputs and needs_unpacked_inputs)
 
-        transform = transforms.Compose([cls() for cls in transform_clss])
+        transform = transform_cls([cls() for cls in wrapped_transform_clss])
 
         image = make_image()
         label = 3
@@ -1833,6 +1838,102 @@ def call_transform():
         assert output[0] is image
         assert output[1] is label
 
+    def test_compose(self):
+        transform = transforms.Compose(
+            [
+                transforms.RandomHorizontalFlip(p=1),
+                transforms.RandomVerticalFlip(p=1),
+            ]
+        )
+
+        input = make_image()
+
+        actual = check_transform(transform, input)
+        expected = F.vertical_flip(F.horizontal_flip(input))
+
+        assert_equal(actual, expected)
+
+    @pytest.mark.parametrize("p", [0.0, 1.0])
+    @pytest.mark.parametrize("sequence_type", [list, nn.ModuleList])
+    def test_random_apply(self, p, sequence_type):
+        transform = transforms.RandomApply(
+            sequence_type(
+                [
+                    transforms.RandomHorizontalFlip(p=1),
+                    transforms.RandomVerticalFlip(p=1),
+                ]
+            ),
+            p=p,
+        )
+
+        # This needs to be a pure tensor (or a PIL image), because otherwise check_transform skips the v1 compatibility
+        # check
+        input = make_image_tensor()
+        output = check_transform(transform, input, check_v1_compatibility=issubclass(sequence_type, nn.ModuleList))
+
+        if p == 1:
+            assert_equal(output, F.vertical_flip(F.horizontal_flip(input)))
+        else:
+            assert output is input
+
+    @pytest.mark.parametrize("p", [(0, 1), (1, 0)])
+    def test_random_choice(self, p):
+        transform = transforms.RandomChoice(
+            [
+                transforms.RandomHorizontalFlip(p=1),
+                transforms.RandomVerticalFlip(p=1),
+            ],
+            p=p,
+        )
+
+        input = make_image()
+        output = check_transform(transform, input)
+
+        p_horz, p_vert = p
+        if p_horz:
+            assert_equal(output, F.horizontal_flip(input))
+        else:
+            assert_equal(output, F.vertical_flip(input))
+
+    def test_random_order(self):
+        transform = transforms.RandomOrder(
+            [
+                transforms.RandomHorizontalFlip(p=1),
+                transforms.RandomVerticalFlip(p=1),
+            ]
+        )
+
+        input = make_image()
+
+        actual = check_transform(transform, input)
+        # Horizontal and vertical flip are commutative, so although the transforms are applied in random order,
+        # the expected result is the same either way.
+        expected = F.vertical_flip(F.horizontal_flip(input))
+
+        assert_equal(actual, expected)
+
+    def test_errors(self):
+        for cls in [transforms.Compose, transforms.RandomChoice, transforms.RandomOrder]:
+            with pytest.raises(TypeError, match="Argument transforms should be a sequence of callables"):
+                cls(lambda x: x)
+
+        with pytest.raises(ValueError, match="at least one transform"):
+            transforms.Compose([])
+
+        for p in [-1, 2]:
+            with pytest.raises(ValueError, match=re.escape("value in the interval [0.0, 1.0]")):
+                transforms.RandomApply([lambda x: x], p=p)
+
+        for transforms_, p in [
+            (
+                [lambda x: x],
+                [],
+            ),
+            ([], [1.0]),
+        ]:
+            with pytest.raises(ValueError, match="Length of p doesn't match the number of transforms"):
+                transforms.RandomChoice(transforms_, p=p)
+
 
 class TestToDtype:
     @pytest.mark.parametrize(
diff --git a/torchvision/transforms/v2/_container.py b/torchvision/transforms/v2/_container.py
index 8f591c49707..d57c2a72009 100644
--- a/torchvision/transforms/v2/_container.py
+++ b/torchvision/transforms/v2/_container.py
@@ -100,14 +100,15 @@ def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
         return {"transforms": self.transforms, "p": self.p}
 
     def forward(self, *inputs: Any) -> Any:
-        sample = inputs if len(inputs) > 1 else inputs[0]
+        needs_unpacking = len(inputs) > 1
 
         if torch.rand(1) >= self.p:
-            return sample
+            return inputs if needs_unpacking else inputs[0]
 
         for transform in self.transforms:
-            sample = transform(sample)
-        return sample
+            outputs = transform(*inputs)
+            inputs = outputs if needs_unpacking else (outputs,)
+        return outputs
 
     def extra_repr(self) -> str:
         format_string = []
@@ -173,8 +174,9 @@ def __init__(self, transforms: Sequence[Callable]) -> None:
         self.transforms = transforms
 
     def forward(self, *inputs: Any) -> Any:
-        sample = inputs if len(inputs) > 1 else inputs[0]
+        needs_unpacking = len(inputs) > 1
         for idx in torch.randperm(len(self.transforms)):
             transform = self.transforms[idx]
-            sample = transform(sample)
-        return sample
+            outputs = transform(*inputs)
+            inputs = outputs if needs_unpacking else (outputs,)
+        return outputs

From 195a458947f20bff89d127f1b41b6d3643e56670 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Mon, 2 Oct 2023 21:58:10 +0200
Subject: [PATCH 2/2] cleanup

---
 test/test_transforms_v2_refactored.py | 8 +-------
 1 file changed, 1 insertion(+), 7 deletions(-)

diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py
index 0eafddf590f..d9e271ce7b9 100644
--- a/test/test_transforms_v2_refactored.py
+++ b/test/test_transforms_v2_refactored.py
@@ -1924,13 +1924,7 @@ def test_errors(self):
            with pytest.raises(ValueError, match=re.escape("value in the interval [0.0, 1.0]")):
                transforms.RandomApply([lambda x: x], p=p)
 
-        for transforms_, p in [
-            (
-                [lambda x: x],
-                [],
-            ),
-            ([], [1.0]),
-        ]:
+        for transforms_, p in [([lambda x: x], []), ([], [1.0])]:
             with pytest.raises(ValueError, match="Length of p doesn't match the number of transforms"):
                 transforms.RandomChoice(transforms_, p=p)
 
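For illustration only, here is a minimal usage sketch of the behavior the reworked RandomApply/RandomOrder forward() methods above enable: container transforms now accept either a single packed sample or multiple unpacked positional inputs, mirroring what test_packed_unpacked exercises. This sketch is not part of either patch and only assumes the public torchvision.transforms.v2 API.

    import torch
    import torchvision.transforms.v2 as transforms

    image = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
    label = 3

    transform = transforms.RandomApply(
        [
            transforms.RandomHorizontalFlip(p=1),
            transforms.RandomVerticalFlip(p=1),
        ],
        p=1.0,
    )

    # Unpacked call: every positional argument is passed through the wrapped
    # transforms and comes back as a tuple; the plain-int label passes through.
    out_image, out_label = transform(image, label)
    assert out_label == 3

    # Packed call: a single input comes back as a single object, not a 1-tuple.
    out_image = transform(image)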