Skip to content

Commit

Permalink
WIP on adding gray images support for adjust_contrast (pytorch#4477)
Browse files Browse the repository at this point in the history
  • Loading branch information
vfdev-5 authored and cyyever committed Nov 16, 2021
1 parent 14933ac commit 2e07bf8
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 9 deletions.
7 changes: 6 additions & 1 deletion test/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,12 @@ def needs_cuda(test_func):
def _create_data(height=3, width=3, channels=3, device="cpu"):
# TODO: When all relevant tests are ported to pytest, turn this into a module-level fixture
tensor = torch.randint(0, 256, (channels, height, width), dtype=torch.uint8, device=device)
pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().cpu().numpy())
data = tensor.permute(1, 2, 0).contiguous().cpu().numpy()
mode = "RGB"
if channels == 1:
mode = "L"
data = data[..., 0]
pil_img = Image.fromarray(data, mode=mode)
return tensor, pil_img


Expand Down
14 changes: 9 additions & 5 deletions test/test_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,12 +641,14 @@ def backward(ctx, grad_output):
assert torch.autograd.gradcheck(F.apply, x, eps=1e-8, atol=1e-6, rtol=1e-6, fast_mode=False)


def check_functional_vs_PIL_vs_scripted(fn, fn_pil, fn_t, config, device, dtype, tol=2.0 + 1e-10, agg_method="max"):
def check_functional_vs_PIL_vs_scripted(
fn, fn_pil, fn_t, config, device, dtype, channels=3, tol=2.0 + 1e-10, agg_method="max"
):

script_fn = torch.jit.script(fn)
torch.manual_seed(15)
tensor, pil_img = _create_data(26, 34, device=device)
batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)
tensor, pil_img = _create_data(26, 34, channels=channels, device=device)
batch_tensors = _create_data_batch(16, 18, num_samples=4, channels=channels, device=device)

if dtype is not None:
tensor = F.convert_image_dtype(tensor, dtype)
Expand Down Expand Up @@ -798,14 +800,16 @@ def test_equalize(device):
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64))
@pytest.mark.parametrize('config', [{"contrast_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]])
def test_adjust_contrast(device, dtype, config):
@pytest.mark.parametrize('channels', [1, 3])
def test_adjust_contrast(device, dtype, config, channels):
check_functional_vs_PIL_vs_scripted(
F.adjust_contrast,
F_pil.adjust_contrast,
F_t.adjust_contrast,
config,
device,
dtype
dtype,
channels=channels
)


Expand Down
9 changes: 6 additions & 3 deletions torchvision/transforms/functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,10 +169,13 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:

_assert_image_tensor(img)

_assert_channels(img, [3])

_assert_channels(img, [3, 1])
c = get_image_num_channels(img)
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
mean = torch.mean(rgb_to_grayscale(img).to(dtype), dim=(-3, -2, -1), keepdim=True)
if c == 3:
mean = torch.mean(rgb_to_grayscale(img).to(dtype), dim=(-3, -2, -1), keepdim=True)
else:
mean = torch.mean(img.to(dtype), dim=(-3, -2, -1), keepdim=True)

return _blend(img, mean, contrast_factor)

Expand Down

0 comments on commit 2e07bf8

Please sign in to comment.