Skip to content

Commit

Permalink
Check num of channels on adjust_* transformations (pytorch#3069)
Browse files Browse the repository at this point in the history
* Fixing upperbound value on tests and documentation.

* Limit the number of channels on adjust_* transoforms.
  • Loading branch information
datumbox authored and vfdev-5 committed Dec 4, 2020
1 parent f3de020 commit b24ed2e
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 9 deletions.
4 changes: 2 additions & 2 deletions test/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,13 +339,13 @@ def freeze_rng_state():
class TransformsTester(unittest.TestCase):

def _create_data(self, height=3, width=3, channels=3, device="cpu"):
tensor = torch.randint(0, 255, (channels, height, width), dtype=torch.uint8, device=device)
tensor = torch.randint(0, 256, (channels, height, width), dtype=torch.uint8, device=device)
pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().cpu().numpy())
return tensor, pil_img

def _create_data_batch(self, height=3, width=3, channels=3, num_samples=4, device="cpu"):
batch_tensor = torch.randint(
0, 255,
0, 256,
(num_samples, channels, height, width),
dtype=torch.uint8,
device=device
Expand Down
28 changes: 21 additions & 7 deletions torchvision/transforms/functional_tensor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import Optional, Dict, Tuple
from typing import Optional, Tuple

import torch
from torch import Tensor
Expand Down Expand Up @@ -45,6 +45,12 @@ def _max_value(dtype: torch.dtype) -> float:
return max_value.item()


def _assert_channels(img: Tensor, permitted: List[int]) -> None:
c = _get_image_num_channels(img)
if c not in permitted:
raise TypeError("Input image tensor permitted channel values are {}, but found {}".format(permitted, c))


def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
"""PRIVATE METHOD. Convert a tensor image to the given ``dtype`` and scale the values accordingly
Expand Down Expand Up @@ -210,9 +216,7 @@ def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
"""
if img.ndim < 3:
raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim))
c = img.shape[-3]
if c != 3:
raise TypeError("Input image tensor should 3 channels, but found {}".format(c))
_assert_channels(img, [3])

if num_output_channels not in (1, 3):
raise ValueError('num_output_channels should be either 1 or 3')
Expand All @@ -230,7 +234,7 @@ def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:


def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
"""PRIVATE METHOD. Adjust brightness of an RGB image.
"""PRIVATE METHOD. Adjust brightness of a Grayscale or RGB image.
.. warning::
Expand All @@ -252,6 +256,8 @@ def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')

_assert_channels(img, [1, 3])

return _blend(img, torch.zeros_like(img), brightness_factor)


Expand All @@ -278,14 +284,16 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')

_assert_channels(img, [3])

dtype = img.dtype if torch.is_floating_point(img) else torch.float32
mean = torch.mean(rgb_to_grayscale(img).to(dtype), dim=(-3, -2, -1), keepdim=True)

return _blend(img, mean, contrast_factor)


def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
"""PRIVATE METHOD. Adjust hue of an image.
"""PRIVATE METHOD. Adjust hue of an RGB image.
.. warning::
Expand Down Expand Up @@ -320,6 +328,8 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)):
raise TypeError('Input img should be Tensor image')

_assert_channels(img, [3])

orig_dtype = img.dtype
if img.dtype == torch.uint8:
img = img.to(dtype=torch.float32) / 255.0
Expand Down Expand Up @@ -359,11 +369,13 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')

_assert_channels(img, [3])

return _blend(img, rgb_to_grayscale(img), saturation_factor)


def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
r"""PRIVATE METHOD. Adjust gamma of an RGB image.
r"""PRIVATE METHOD. Adjust gamma of a Grayscale or RGB image.
.. warning::
Expand Down Expand Up @@ -391,6 +403,8 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
if not isinstance(img, torch.Tensor):
raise TypeError('Input img should be a Tensor.')

_assert_channels(img, [1, 3])

if gamma < 0:
raise ValueError('Gamma should be a non-negative real number')

Expand Down

0 comments on commit b24ed2e

Please sign in to comment.