From 2e07bf8977cb7c2c7d03813e763aa662bcb0962c Mon Sep 17 00:00:00 2001
From: vfdev
Date: Fri, 24 Sep 2021 14:55:42 +0200
Subject: [PATCH] WIP on adding gray images support for adjust_contrast (#4477)

---
Review notes (text between the "---" line and the first "diff --git" is
not part of the commit message):

* test/common_utils.py: _create_data now builds the PIL image with an
  explicit mode; for channels == 1 it drops the trailing channel axis and
  uses mode "L", otherwise mode "RGB".
* test/test_functional_tensor.py: check_functional_vs_PIL_vs_scripted
  gains a `channels=3` keyword that is forwarded to _create_data and
  _create_data_batch; test_adjust_contrast is parametrized over
  channels in [1, 3].
* torchvision/transforms/functional_tensor.py: adjust_contrast accepts
  1- or 3-channel tensors; the blend mean goes through rgb_to_grayscale
  only when the image has 3 channels.
* NOTE(review): line breaks and indentation were restored from a
  whitespace-collapsed copy; hunk line counts match the @@ headers and
  the diffstat, but leading whitespace on context lines should be
  confirmed against the upstream commit.

 test/common_utils.py                        |  7 ++++++-
 test/test_functional_tensor.py              | 14 +++++++++-----
 torchvision/transforms/functional_tensor.py |  9 ++++++---
 3 files changed, 21 insertions(+), 9 deletions(-)

diff --git a/test/common_utils.py b/test/common_utils.py
index 5aad4a6dd24..4bbf748d3ba 100644
--- a/test/common_utils.py
+++ b/test/common_utils.py
@@ -128,7 +128,12 @@ def needs_cuda(test_func):
 def _create_data(height=3, width=3, channels=3, device="cpu"):
     # TODO: When all relevant tests are ported to pytest, turn this into a module-level fixture
     tensor = torch.randint(0, 256, (channels, height, width), dtype=torch.uint8, device=device)
-    pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().cpu().numpy())
+    data = tensor.permute(1, 2, 0).contiguous().cpu().numpy()
+    mode = "RGB"
+    if channels == 1:
+        mode = "L"
+        data = data[..., 0]
+    pil_img = Image.fromarray(data, mode=mode)
     return tensor, pil_img
 
 
diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py
index 717e2a7cb33..1a7b994d864 100644
--- a/test/test_functional_tensor.py
+++ b/test/test_functional_tensor.py
@@ -641,12 +641,14 @@ def backward(ctx, grad_output):
         assert torch.autograd.gradcheck(F.apply, x, eps=1e-8, atol=1e-6, rtol=1e-6, fast_mode=False)
 
 
-def check_functional_vs_PIL_vs_scripted(fn, fn_pil, fn_t, config, device, dtype, tol=2.0 + 1e-10, agg_method="max"):
+def check_functional_vs_PIL_vs_scripted(
+    fn, fn_pil, fn_t, config, device, dtype, channels=3, tol=2.0 + 1e-10, agg_method="max"
+):
 
     script_fn = torch.jit.script(fn)
     torch.manual_seed(15)
-    tensor, pil_img = _create_data(26, 34, device=device)
-    batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)
+    tensor, pil_img = _create_data(26, 34, channels=channels, device=device)
+    batch_tensors = _create_data_batch(16, 18, num_samples=4, channels=channels, device=device)
 
     if dtype is not None:
         tensor = F.convert_image_dtype(tensor, dtype)
@@ -798,14 +800,16 @@ def test_equalize(device):
 @pytest.mark.parametrize('device', cpu_and_gpu())
 @pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64))
 @pytest.mark.parametrize('config', [{"contrast_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]])
-def test_adjust_contrast(device, dtype, config):
+@pytest.mark.parametrize('channels', [1, 3])
+def test_adjust_contrast(device, dtype, config, channels):
     check_functional_vs_PIL_vs_scripted(
         F.adjust_contrast,
         F_pil.adjust_contrast,
         F_t.adjust_contrast,
         config,
         device,
-        dtype
+        dtype,
+        channels=channels
     )
 
 
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index 61c07433cb6..67676f58dbc 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -169,10 +169,13 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
 
     _assert_image_tensor(img)
 
-    _assert_channels(img, [3])
-
+    _assert_channels(img, [3, 1])
+    c = get_image_num_channels(img)
     dtype = img.dtype if torch.is_floating_point(img) else torch.float32
-    mean = torch.mean(rgb_to_grayscale(img).to(dtype), dim=(-3, -2, -1), keepdim=True)
+    if c == 3:
+        mean = torch.mean(rgb_to_grayscale(img).to(dtype), dim=(-3, -2, -1), keepdim=True)
+    else:
+        mean = torch.mean(img.to(dtype), dim=(-3, -2, -1), keepdim=True)
 
     return _blend(img, mean, contrast_factor)
 