diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index c070c5c1d61..0c4997f1499 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -862,6 +862,21 @@ def test_gaussian_blur(self): msg="{}, {}".format(ksize, sigma) ) + def test_invert(self): + script_invert = torch.jit.script(F.invert) + + img_tensor, pil_img = self._create_data(16, 18, device=self.device) + inverted_img = F.invert(img_tensor) + inverted_pil_img = F.invert(pil_img) + self.compareTensorToPIL(inverted_img, inverted_pil_img) + + # scriptable function test + inverted_img_script = script_invert(img_tensor) + self.assertTrue(inverted_img.equal(inverted_img_script)) + + batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device) + self._test_fn_on_batch(batch_tensors, F.invert) + @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device") class CUDATester(Tester): diff --git a/test/test_transforms.py b/test/test_transforms.py index 30749772d6a..d6b8f48959c 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1749,6 +1749,38 @@ def test_gaussian_blur_asserts(self): with self.assertRaisesRegex(ValueError, r"sigma should be a single number or a list/tuple with length 2"): transforms.GaussianBlur(3, "sigma_string") + @unittest.skipIf(stats is None, 'scipy.stats not available') + def test_random_invert(self): + rng_state = torch.get_rng_state() + torch.manual_seed(42)  # RandomInvert samples via torch.rand, so seed torch's RNG + img = transforms.ToPILImage()(torch.rand(3, 10, 10)) + inv_img = F.invert(img) + + num_samples = 250 + num_inverts = 0 + for _ in range(num_samples): + out = transforms.RandomInvert()(img) + if out == inv_img: + num_inverts += 1 + + p_value = 
stats.binom_test(num_inverts, num_samples, p=0.5) + torch.set_rng_state(rng_state) + self.assertGreater(p_value, 0.0001) + + num_samples = 250 + num_inverts = 0 + for _ in range(num_samples): + out = transforms.RandomInvert(p=0.7)(img) + if out == inv_img: + num_inverts += 1 + + p_value = stats.binom_test(num_inverts, num_samples, p=0.7) + torch.set_rng_state(rng_state) + self.assertGreater(p_value, 0.0001) + + # Checking if RandomInvert can be printed as string + transforms.RandomInvert().__repr__() + if __name__ == '__main__': unittest.main() diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index aff492b41d6..142b3af847d 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -89,6 +89,9 @@ def test_random_horizontal_flip(self): def test_random_vertical_flip(self): self._test_op('vflip', 'RandomVerticalFlip') + def test_random_invert(self): + self._test_op('invert', 'RandomInvert') + def test_color_jitter(self): tol = 1.0 + 1e-10 diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 72baf021f9d..b64d00138dd 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1178,3 +1178,21 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[floa if not isinstance(img, torch.Tensor): output = to_pil_image(output) return output + + +def invert(img: Tensor) -> Tensor: + """Invert the colors of a PIL Image or torch Tensor. + + Args: + img (PIL Image or Tensor): Image to have its colors inverted. + If img is a Tensor, it is expected to be in [..., H, W] format, + where ... means it can have an arbitrary number of leading + dimensions. + + Returns: + PIL Image or Tensor: Color inverted image. + """ + if not isinstance(img, torch.Tensor): + return F_pil.invert(img) + + return F_t.invert(img) diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py index 51d83f0fd63..17c67355535 100644 --- a/torchvision/transforms/functional_pil.py +++ b/torchvision/transforms/functional_pil.py @@ -606,3 +606,23 @@ def to_grayscale(img, num_output_channels): raise ValueError('num_output_channels should be either 1 or 3') return img + + +@torch.jit.unused +def invert(img): + """PRIVATE METHOD. 
Invert the colors of an image. + + .. warning:: + + Module ``transforms.functional_pil`` is private and should not be used in user application. + Please, consider instead using methods from `transforms.functional` module. + + Args: + img (PIL Image): Image to have its colors inverted. + + Returns: + PIL Image: Color inverted image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + return ImageOps.invert(img) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 0c72a745bba..ce899efbabf 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -1179,3 +1179,30 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Te img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype) return img + + +def invert(img: Tensor) -> Tensor: + """PRIVATE METHOD. Invert the colors of a grayscale or RGB image. + + .. warning:: + + Module ``transforms.functional_tensor`` is private and should not be used in user application. + Please, consider instead using methods from `transforms.functional` module. + + Args: + img (Tensor): Image to have its colors inverted in the form [..., C, H, W]. + + Returns: + Tensor: Color inverted image Tensor. 
+ """ + if not _is_tensor_a_torch_image(img): + raise TypeError('tensor is not a torch image.') + + if img.ndim < 3: + raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim)) + + _assert_channels(img, [1, 3]) + + bound = 1.0 if img.is_floating_point() else 255.0 + dtype = img.dtype if torch.is_floating_point(img) else torch.float32 + return (bound - img.to(dtype)).to(img.dtype) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 3b159fd3f22..9295004e4a6 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -21,7 +21,7 @@ "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale", - "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode"] + "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert"] class Compose: @@ -1699,3 +1699,43 @@ def _setup_angle(x, name, req_sizes=(2, )): _check_sequence_input(x, name, req_sizes) return [float(d) for d in x] + + +class RandomInvert(torch.nn.Module): + """Inverts the colors of the given image randomly with a given probability. + The image can be a PIL Image or a torch Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading + dimensions. + + Args: + p (float): probability of the image being color inverted. Default value is 0.5 + """ + + def __init__(self, p=0.5): + super().__init__() + self.p = p + + @staticmethod + def get_params() -> float: + """Choose value for random color inversion. + + Returns: + float: Random value which is used to determine whether the random color inversion + should occur. 
+ """ + return torch.rand(1).item() + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be inverted. + + Returns: + PIL Image or Tensor: Randomly color inverted image. + """ + if self.get_params() < self.p: + return F.invert(img) + return img + + def __repr__(self): + return self.__class__.__name__ + '(p={})'.format(self.p)