Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix write and encode jpeg tests #3908

Merged
merged 4 commits into from
May 25, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 156 additions & 73 deletions test/test_image.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,36 @@
import glob
import io
import os
import sys
import unittest
from pathlib import Path

import pytest
import numpy as np
import torch
from PIL import Image
from common_utils import get_tmp_dir, needs_cuda
import torchvision.transforms.functional as F
from common_utils import get_tmp_dir, needs_cuda, cpu_only
from _assert_utils import assert_equal

from torchvision.io.image import (
decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file,
encode_png, write_png, write_file, ImageReadMode)
encode_png, write_png, write_file, ImageReadMode, read_image)

IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata")
IMAGE_DIR = os.path.join(FAKEDATA_DIR, "imagefolder")
DAMAGED_JPEG = os.path.join(IMAGE_ROOT, 'damaged_jpeg')
ENCODE_JPEG = os.path.join(IMAGE_ROOT, "encode_jpeg")
IS_WINDOWS = sys.platform in ('win32', 'cygwin')


def _get_safe_image_name(name):
# Used when we need to change the pytest "id" for an "image path" parameter.
# If we don't, the test id (i.e. its name) will contain the whole path to the image, which is machine-specific,
# and this creates issues when the test is running in a different machine than where it was collected
# (typically, in fb internal infra)
return name.split(os.path.sep)[-1]


def get_images(directory, img_ext):
Expand Down Expand Up @@ -93,72 +105,6 @@ def test_damaged_images(self):
with self.assertRaises(RuntimeError):
decode_jpeg(data)

def test_encode_jpeg(self):
for img_path in get_images(ENCODE_JPEG, ".jpg"):
dirname = os.path.dirname(img_path)
filename, _ = os.path.splitext(os.path.basename(img_path))
write_folder = os.path.join(dirname, 'jpeg_write')
expected_file = os.path.join(
write_folder, '{0}_pil.jpg'.format(filename))
img = decode_jpeg(read_file(img_path))

with open(expected_file, 'rb') as f:
pil_bytes = f.read()
pil_bytes = torch.as_tensor(list(pil_bytes), dtype=torch.uint8)
for src_img in [img, img.contiguous()]:
# PIL sets jpeg quality to 75 by default
jpeg_bytes = encode_jpeg(src_img, quality=75)
assert_equal(jpeg_bytes, pil_bytes)

with self.assertRaisesRegex(
RuntimeError, "Input tensor dtype should be uint8"):
encode_jpeg(torch.empty((3, 100, 100), dtype=torch.float32))

with self.assertRaisesRegex(
ValueError, "Image quality should be a positive number "
"between 1 and 100"):
encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=-1)

with self.assertRaisesRegex(
ValueError, "Image quality should be a positive number "
"between 1 and 100"):
encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=101)

with self.assertRaisesRegex(
RuntimeError, "The number of channels should be 1 or 3, got: 5"):
encode_jpeg(torch.empty((5, 100, 100), dtype=torch.uint8))

with self.assertRaisesRegex(
RuntimeError, "Input data should be a 3-dimensional tensor"):
encode_jpeg(torch.empty((1, 3, 100, 100), dtype=torch.uint8))

with self.assertRaisesRegex(
RuntimeError, "Input data should be a 3-dimensional tensor"):
encode_jpeg(torch.empty((100, 100), dtype=torch.uint8))

def test_write_jpeg(self):
with get_tmp_dir() as d:
for img_path in get_images(ENCODE_JPEG, ".jpg"):
data = read_file(img_path)
img = decode_jpeg(data)

basedir = os.path.dirname(img_path)
filename, _ = os.path.splitext(os.path.basename(img_path))
torch_jpeg = os.path.join(
d, '{0}_torch.jpg'.format(filename))
pil_jpeg = os.path.join(
basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename))

write_jpeg(img, torch_jpeg, quality=75)

with open(torch_jpeg, 'rb') as f:
torch_bytes = f.read()

with open(pil_jpeg, 'rb') as f:
pil_bytes = f.read()

self.assertEqual(torch_bytes, pil_bytes)

def test_decode_png(self):
conversion = [(None, ImageReadMode.UNCHANGED), ("L", ImageReadMode.GRAY), ("LA", ImageReadMode.GRAY_ALPHA),
("RGB", ImageReadMode.RGB), ("RGBA", ImageReadMode.RGB_ALPHA)]
Expand Down Expand Up @@ -282,11 +228,7 @@ def test_write_file_non_ascii(self):

@needs_cuda
@pytest.mark.parametrize('img_path', [
# We need to change the "id" for that parameter.
# If we don't, the test id (i.e. its name) will contain the whole path to the image which is machine-specific,
# and this creates issues when the test is running in a different machine than where it was collected
# (typically, in fb internal infra)
pytest.param(jpeg_path, id=jpeg_path.split('/')[-1])
pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path))
for jpeg_path in get_images(IMAGE_ROOT, ".jpg")
])
@pytest.mark.parametrize('mode', [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB])
Expand Down Expand Up @@ -325,5 +267,146 @@ def test_decode_jpeg_cuda_errors():
torch.ops.image.decode_jpeg_cuda(data, ImageReadMode.UNCHANGED.value, 'cpu')


@cpu_only
def test_encode_jpeg_errors():

with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"):
encode_jpeg(torch.empty((3, 100, 100), dtype=torch.float32))

with pytest.raises(ValueError, match="Image quality should be a positive number "
"between 1 and 100"):
encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=-1)

with pytest.raises(ValueError, match="Image quality should be a positive number "
"between 1 and 100"):
encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=101)

with pytest.raises(RuntimeError, match="The number of channels should be 1 or 3, got: 5"):
encode_jpeg(torch.empty((5, 100, 100), dtype=torch.uint8))

with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"):
encode_jpeg(torch.empty((1, 3, 100, 100), dtype=torch.uint8))

with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"):
encode_jpeg(torch.empty((100, 100), dtype=torch.uint8))


def _collect_if(cond):
# TODO: remove this once test_encode_jpeg_windows and test_write_jpeg_windows
# are removed
def _inner(test_func):
if cond:
return test_func
else:
return pytest.mark.dont_collect(test_func)
return _inner


@cpu_only
@_collect_if(cond=IS_WINDOWS)
def test_encode_jpeg_windows():
# This test is *wrong*.
# It compares a torchvision-encoded jpeg with a PIL-encoded jpeg, but it
# starts encoding the torchvision version from an image that comes from
# decode_jpeg, which can yield different results from pil.decode (see
# test_decode... which uses a high tolerance).
# Instead, we should start encoding from the exact same decoded image, for a
# valid comparison. This is done in test_encode_jpeg, but unfortunately
# these more correct tests fail on windows (probably because of a difference
# in libjpeg) between torchvision and PIL.
# FIXME: make the correct tests pass on windows and remove this.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with this comment. The right comparison would be to encode from the same decoded image and compare the results. Unfortunately as Nicolas explains, that's going to fail on Windows. I would not be surprised if the expected image is created on Linux or macOS to make this work. Hence this test is misleading and does not test that the encoding on our side is the same as the encoding on PIL side on the same platform.

Having said that, it's good that Nicolas found a way around to maintain the test until we decide whether we want to keep it or drop it. I'm on the fence on keeping it and Nicolas briefly considered dropping it. I don't have strong opinions on this. @fmassa thoughts?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the underlying issue might be that we are using different libjpeg for different OSes, which is less than ideal.

Fixing this would probably fix the issue.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the key issue comes from the fact that apparently on windows, the libjpeg version of PIL is different from that of torchvision.

Having different libjpeg versions across OSes wouldn't be a problem if for a given OS, the libjpeg version was the same for both PIL and torchvision I think.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@NicolasHug it would be good to log the libraries installed in

conda env update --file "${this_dir}/environment.yml" --prune
and
conda env update --file "${this_dir}/environment.yml" --prune
, our current constraint on libjpeg is <=9b, which in principle allows for different versions to be installed on different OSes, if there is one version missing in conda

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But if the problem lies within PIL, this means that if we were to add a test comparing the encoding of PIL on the CI itself it should fail?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it would be good to log the libraries installed

I'm not sure how to best do that yet but I started looking into it in #3968

for img_path in get_images(ENCODE_JPEG, ".jpg"):
dirname = os.path.dirname(img_path)
filename, _ = os.path.splitext(os.path.basename(img_path))
write_folder = os.path.join(dirname, 'jpeg_write')
expected_file = os.path.join(
write_folder, '{0}_pil.jpg'.format(filename))
img = decode_jpeg(read_file(img_path))

with open(expected_file, 'rb') as f:
pil_bytes = f.read()
pil_bytes = torch.as_tensor(list(pil_bytes), dtype=torch.uint8)
for src_img in [img, img.contiguous()]:
# PIL sets jpeg quality to 75 by default
jpeg_bytes = encode_jpeg(src_img, quality=75)
assert_equal(jpeg_bytes, pil_bytes)


@cpu_only
@_collect_if(cond=IS_WINDOWS)
def test_write_jpeg_windows():
# FIXME: Remove this eventually, see test_encode_jpeg_windows
with get_tmp_dir() as d:
for img_path in get_images(ENCODE_JPEG, ".jpg"):
data = read_file(img_path)
img = decode_jpeg(data)

basedir = os.path.dirname(img_path)
filename, _ = os.path.splitext(os.path.basename(img_path))
torch_jpeg = os.path.join(
d, '{0}_torch.jpg'.format(filename))
pil_jpeg = os.path.join(
basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename))

write_jpeg(img, torch_jpeg, quality=75)

with open(torch_jpeg, 'rb') as f:
torch_bytes = f.read()

with open(pil_jpeg, 'rb') as f:
pil_bytes = f.read()

assert_equal(torch_bytes, pil_bytes)


@cpu_only
@_collect_if(cond=not IS_WINDOWS)
@pytest.mark.parametrize('img_path', [
pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path))
for jpeg_path in get_images(ENCODE_JPEG, ".jpg")
])
def test_encode_jpeg(img_path):
img = read_image(img_path)

pil_img = F.to_pil_image(img)
buf = io.BytesIO()
pil_img.save(buf, format='JPEG', quality=75)

# pytorch can't read from raw bytes so we go through numpy
pil_bytes = np.frombuffer(buf.getvalue(), dtype=np.uint8)
encoded_jpeg_pil = torch.as_tensor(pil_bytes)

for src_img in [img, img.contiguous()]:
encoded_jpeg_torch = encode_jpeg(src_img, quality=75)
assert_equal(encoded_jpeg_torch, encoded_jpeg_pil)


@cpu_only
@_collect_if(cond=not IS_WINDOWS)
@pytest.mark.parametrize('img_path', [
pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path))
for jpeg_path in get_images(ENCODE_JPEG, ".jpg")
])
def test_write_jpeg(img_path):
with get_tmp_dir() as d:
d = Path(d)
img = read_image(img_path)
pil_img = F.to_pil_image(img)

torch_jpeg = str(d / 'torch.jpg')
pil_jpeg = str(d / 'pil.jpg')

write_jpeg(img, torch_jpeg, quality=75)
pil_img.save(pil_jpeg, quality=75)

with open(torch_jpeg, 'rb') as f:
torch_bytes = f.read()

with open(pil_jpeg, 'rb') as f:
pil_bytes = f.read()

assert_equal(torch_bytes, pil_bytes)


if __name__ == '__main__':
unittest.main()