diff --git a/test/test_image.py b/test/test_image.py index ebc47fde9e4..eae4a1473c5 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -1,24 +1,36 @@ import glob import io import os +import sys import unittest +from pathlib import Path import pytest import numpy as np import torch from PIL import Image -from common_utils import get_tmp_dir, needs_cuda +import torchvision.transforms.functional as F +from common_utils import get_tmp_dir, needs_cuda, cpu_only from _assert_utils import assert_equal from torchvision.io.image import ( decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file, - encode_png, write_png, write_file, ImageReadMode) + encode_png, write_png, write_file, ImageReadMode, read_image) IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets") FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata") IMAGE_DIR = os.path.join(FAKEDATA_DIR, "imagefolder") DAMAGED_JPEG = os.path.join(IMAGE_ROOT, 'damaged_jpeg') ENCODE_JPEG = os.path.join(IMAGE_ROOT, "encode_jpeg") +IS_WINDOWS = sys.platform in ('win32', 'cygwin') + + +def _get_safe_image_name(name): + # Used when we need to change the pytest "id" for an "image path" parameter. + # If we don't, the test id (i.e. its name) will contain the whole path to the image, which is machine-specific, + # and this creates issues when the test is running in a different machine than where it was collected + # (typically, in fb internal infra) + return name.split(os.path.sep)[-1] def get_images(directory, img_ext): @@ -93,72 +105,6 @@ def test_damaged_images(self): with self.assertRaises(RuntimeError): decode_jpeg(data) - def test_encode_jpeg(self): - for img_path in get_images(ENCODE_JPEG, ".jpg"): - dirname = os.path.dirname(img_path) - filename, _ = os.path.splitext(os.path.basename(img_path)) - write_folder = os.path.join(dirname, 'jpeg_write') - expected_file = os.path.join( - write_folder, '{0}_pil.jpg'.format(filename)) - img = decode_jpeg(read_file(img_path)) - - with open(expected_file, 'rb') as f: - pil_bytes = f.read() - pil_bytes = torch.as_tensor(list(pil_bytes), dtype=torch.uint8) - for src_img in [img, img.contiguous()]: - # PIL sets jpeg quality to 75 by default - jpeg_bytes = encode_jpeg(src_img, quality=75) - assert_equal(jpeg_bytes, pil_bytes) - - with self.assertRaisesRegex( - RuntimeError, "Input tensor dtype should be uint8"): - encode_jpeg(torch.empty((3, 100, 100), dtype=torch.float32)) - - with self.assertRaisesRegex( - ValueError, "Image quality should be a positive number " - "between 1 and 100"): - encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=-1) - - with self.assertRaisesRegex( - ValueError, "Image quality should be a positive number " - "between 1 and 100"): - encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=101) - - with self.assertRaisesRegex( - RuntimeError, "The number of channels should be 1 or 3, got: 5"): - encode_jpeg(torch.empty((5, 100, 100), dtype=torch.uint8)) - - with self.assertRaisesRegex( - RuntimeError, "Input data should be a 3-dimensional tensor"): - encode_jpeg(torch.empty((1, 3, 100, 100), dtype=torch.uint8)) - - with self.assertRaisesRegex( - RuntimeError, "Input data should be a 3-dimensional tensor"): - encode_jpeg(torch.empty((100, 100), dtype=torch.uint8)) - - def test_write_jpeg(self): - with get_tmp_dir() as d: - for img_path in get_images(ENCODE_JPEG, ".jpg"): - data = read_file(img_path) - img = decode_jpeg(data) - - basedir = os.path.dirname(img_path) - filename, _ = os.path.splitext(os.path.basename(img_path)) - torch_jpeg = os.path.join( - d, '{0}_torch.jpg'.format(filename)) - pil_jpeg = os.path.join( - basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename)) - - write_jpeg(img, torch_jpeg, quality=75) - - with open(torch_jpeg, 'rb') as f: - torch_bytes = f.read() - - with open(pil_jpeg, 'rb') as f: - pil_bytes = f.read() - - self.assertEqual(torch_bytes, pil_bytes) - def test_decode_png(self): conversion = [(None, ImageReadMode.UNCHANGED), ("L", ImageReadMode.GRAY), ("LA", ImageReadMode.GRAY_ALPHA), ("RGB", ImageReadMode.RGB), ("RGBA", ImageReadMode.RGB_ALPHA)] @@ -282,11 +228,7 @@ def test_write_file_non_ascii(self): @needs_cuda @pytest.mark.parametrize('img_path', [ - # We need to change the "id" for that parameter. - # If we don't, the test id (i.e. its name) will contain the whole path to the image which is machine-specific, - # and this creates issues when the test is running in a different machine than where it was collected - # (typically, in fb internal infra) - pytest.param(jpeg_path, id=jpeg_path.split('/')[-1]) + pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(IMAGE_ROOT, ".jpg") ]) @pytest.mark.parametrize('mode', [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB]) @@ -325,5 +267,146 @@ def test_decode_jpeg_cuda_errors(): torch.ops.image.decode_jpeg_cuda(data, ImageReadMode.UNCHANGED.value, 'cpu') +@cpu_only +def test_encode_jpeg_errors(): + + with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"): + encode_jpeg(torch.empty((3, 100, 100), dtype=torch.float32)) + + with pytest.raises(ValueError, match="Image quality should be a positive number " + "between 1 and 100"): + encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=-1) + + with pytest.raises(ValueError, match="Image quality should be a positive number " + "between 1 and 100"): + encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=101) + + with pytest.raises(RuntimeError, match="The number of channels should be 1 or 3, got: 5"): + encode_jpeg(torch.empty((5, 100, 100), dtype=torch.uint8)) + + with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"): + encode_jpeg(torch.empty((1, 3, 100, 100), dtype=torch.uint8)) + + with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"): + encode_jpeg(torch.empty((100, 100), dtype=torch.uint8)) + + +def _collect_if(cond): + # TODO: remove this once test_encode_jpeg_windows and test_write_jpeg_windows + # are removed + def _inner(test_func): + if cond: + return test_func + else: + return pytest.mark.dont_collect(test_func) + return _inner + + +@cpu_only +@_collect_if(cond=IS_WINDOWS) +def test_encode_jpeg_windows(): + # This test is *wrong*. + # It compares a torchvision-encoded jpeg with a PIL-encoded jpeg, but it + # starts encoding the torchvision version from an image that comes from + # decode_jpeg, which can yield different results from pil.decode (see + # test_decode... which uses a high tolerance). + # Instead, we should start encoding from the exact same decoded image, for a + # valid comparison. This is done in test_encode_jpeg, but unfortunately + # these more correct tests fail on windows (probably because of a difference + # in libjpeg) between torchvision and PIL. + # FIXME: make the correct tests pass on windows and remove this. + for img_path in get_images(ENCODE_JPEG, ".jpg"): + dirname = os.path.dirname(img_path) + filename, _ = os.path.splitext(os.path.basename(img_path)) + write_folder = os.path.join(dirname, 'jpeg_write') + expected_file = os.path.join( + write_folder, '{0}_pil.jpg'.format(filename)) + img = decode_jpeg(read_file(img_path)) + + with open(expected_file, 'rb') as f: + pil_bytes = f.read() + pil_bytes = torch.as_tensor(list(pil_bytes), dtype=torch.uint8) + for src_img in [img, img.contiguous()]: + # PIL sets jpeg quality to 75 by default + jpeg_bytes = encode_jpeg(src_img, quality=75) + assert_equal(jpeg_bytes, pil_bytes) + + +@cpu_only +@_collect_if(cond=IS_WINDOWS) +def test_write_jpeg_windows(): + # FIXME: Remove this eventually, see test_encode_jpeg_windows + with get_tmp_dir() as d: + for img_path in get_images(ENCODE_JPEG, ".jpg"): + data = read_file(img_path) + img = decode_jpeg(data) + + basedir = os.path.dirname(img_path) + filename, _ = os.path.splitext(os.path.basename(img_path)) + torch_jpeg = os.path.join( + d, '{0}_torch.jpg'.format(filename)) + pil_jpeg = os.path.join( + basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename)) + + write_jpeg(img, torch_jpeg, quality=75) + + with open(torch_jpeg, 'rb') as f: + torch_bytes = f.read() + + with open(pil_jpeg, 'rb') as f: + pil_bytes = f.read() + + assert_equal(torch_bytes, pil_bytes) + + +@cpu_only +@_collect_if(cond=not IS_WINDOWS) +@pytest.mark.parametrize('img_path', [ + pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) + for jpeg_path in get_images(ENCODE_JPEG, ".jpg") +]) +def test_encode_jpeg(img_path): + img = read_image(img_path) + + pil_img = F.to_pil_image(img) + buf = io.BytesIO() + pil_img.save(buf, format='JPEG', quality=75) + + # pytorch can't read from raw bytes so we go through numpy + pil_bytes = np.frombuffer(buf.getvalue(), dtype=np.uint8) + encoded_jpeg_pil = torch.as_tensor(pil_bytes) + + for src_img in [img, img.contiguous()]: + encoded_jpeg_torch = encode_jpeg(src_img, quality=75) + assert_equal(encoded_jpeg_torch, encoded_jpeg_pil) + + +@cpu_only +@_collect_if(cond=not IS_WINDOWS) +@pytest.mark.parametrize('img_path', [ + pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) + for jpeg_path in get_images(ENCODE_JPEG, ".jpg") +]) +def test_write_jpeg(img_path): + with get_tmp_dir() as d: + d = Path(d) + img = read_image(img_path) + pil_img = F.to_pil_image(img) + + torch_jpeg = str(d / 'torch.jpg') + pil_jpeg = str(d / 'pil.jpg') + + write_jpeg(img, torch_jpeg, quality=75) + pil_img.save(pil_jpeg, quality=75) + + with open(torch_jpeg, 'rb') as f: + torch_bytes = f.read() + + with open(pil_jpeg, 'rb') as f: + pil_bytes = f.read() + + assert_equal(torch_bytes, pil_bytes) + + if __name__ == '__main__': unittest.main()