From ba64d65bc6811f2b173792a640cb4cbe5a750840 Mon Sep 17 00:00:00 2001
From: Thien Tran
Date: Thu, 14 Mar 2024 01:17:21 +0800
Subject: [PATCH] Fast rotation for right angles (#8295)

Co-authored-by: Thien Tran
---
 test/test_transforms_v2.py                        | 11 +++++++++++
 torchvision/transforms/v2/functional/_geometry.py | 15 +++++++++++++++
 2 files changed, 26 insertions(+)

diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index e04c77f9b80..b469a630b4a 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -1782,6 +1782,17 @@ def test_transform_unknown_fill_error(self):
         with pytest.raises(TypeError, match="Got inappropriate fill arg"):
             transforms.RandomAffine(degrees=0, fill="fill")
 
+    @pytest.mark.parametrize("size", [(11, 17), (16, 16)])
+    @pytest.mark.parametrize("angle", [0, 90, 180, 270])
+    @pytest.mark.parametrize("expand", [False, True])
+    def test_functional_image_fast_path_correctness(self, size, angle, expand):
+        image = make_image(size, dtype=torch.uint8, device="cpu")
+
+        actual = F.rotate(image, angle=angle, expand=expand)
+        expected = F.to_image(F.rotate(F.to_pil_image(image), angle=angle, expand=expand))
+
+        torch.testing.assert_close(actual, expected)
+
 
 class TestContainerTransforms:
     class BuiltinTransform(transforms.Transform):
diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py
index b681346ab09..2a1250ddf6c 100644
--- a/torchvision/transforms/v2/functional/_geometry.py
+++ b/torchvision/transforms/v2/functional/_geometry.py
@@ -997,6 +997,21 @@ def rotate_image(
     center: Optional[List[float]] = None,
     fill: _FillTypeJIT = None,
 ) -> torch.Tensor:
+    angle = angle % 360  # shift angle to [0, 360) range
+
+    # fast path: transpose without affine transform
+    if center is None:
+        if angle == 0:
+            return image.clone()
+        if angle == 180:
+            return torch.rot90(image, k=2, dims=(-2, -1))
+
+        if expand or image.shape[-1] == image.shape[-2]:
+            if angle == 90:
+                return torch.rot90(image, k=1, dims=(-2, -1))
+            if angle == 270:
+                return torch.rot90(image, k=3, dims=(-2, -1))
+
     interpolation = _check_interpolation(interpolation)
 
     input_height, input_width = image.shape[-2:]