Skip to content

Commit

Permalink
Fix write and encode jpeg tests (#3908)
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug authored May 25, 2021
1 parent c58d5d1 commit eaddb90
Showing 1 changed file with 156 additions and 73 deletions.
229 changes: 156 additions & 73 deletions test/test_image.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,36 @@
import glob
import io
import os
import sys
import unittest
from pathlib import Path

import pytest
import numpy as np
import torch
from PIL import Image
from common_utils import get_tmp_dir, needs_cuda
import torchvision.transforms.functional as F
from common_utils import get_tmp_dir, needs_cuda, cpu_only
from _assert_utils import assert_equal

from torchvision.io.image import (
decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file,
encode_png, write_png, write_file, ImageReadMode)
encode_png, write_png, write_file, ImageReadMode, read_image)

IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata")
IMAGE_DIR = os.path.join(FAKEDATA_DIR, "imagefolder")
DAMAGED_JPEG = os.path.join(IMAGE_ROOT, 'damaged_jpeg')
ENCODE_JPEG = os.path.join(IMAGE_ROOT, "encode_jpeg")
IS_WINDOWS = sys.platform in ('win32', 'cygwin')


def _get_safe_image_name(name):
# Used when we need to change the pytest "id" for an "image path" parameter.
# If we don't, the test id (i.e. its name) will contain the whole path to the image, which is machine-specific,
# and this creates issues when the test is running in a different machine than where it was collected
# (typically, in fb internal infra)
return name.split(os.path.sep)[-1]


def get_images(directory, img_ext):
Expand Down Expand Up @@ -93,72 +105,6 @@ def test_damaged_images(self):
with self.assertRaises(RuntimeError):
decode_jpeg(data)

def test_encode_jpeg(self):
for img_path in get_images(ENCODE_JPEG, ".jpg"):
dirname = os.path.dirname(img_path)
filename, _ = os.path.splitext(os.path.basename(img_path))
write_folder = os.path.join(dirname, 'jpeg_write')
expected_file = os.path.join(
write_folder, '{0}_pil.jpg'.format(filename))
img = decode_jpeg(read_file(img_path))

with open(expected_file, 'rb') as f:
pil_bytes = f.read()
pil_bytes = torch.as_tensor(list(pil_bytes), dtype=torch.uint8)
for src_img in [img, img.contiguous()]:
# PIL sets jpeg quality to 75 by default
jpeg_bytes = encode_jpeg(src_img, quality=75)
assert_equal(jpeg_bytes, pil_bytes)

with self.assertRaisesRegex(
RuntimeError, "Input tensor dtype should be uint8"):
encode_jpeg(torch.empty((3, 100, 100), dtype=torch.float32))

with self.assertRaisesRegex(
ValueError, "Image quality should be a positive number "
"between 1 and 100"):
encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=-1)

with self.assertRaisesRegex(
ValueError, "Image quality should be a positive number "
"between 1 and 100"):
encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=101)

with self.assertRaisesRegex(
RuntimeError, "The number of channels should be 1 or 3, got: 5"):
encode_jpeg(torch.empty((5, 100, 100), dtype=torch.uint8))

with self.assertRaisesRegex(
RuntimeError, "Input data should be a 3-dimensional tensor"):
encode_jpeg(torch.empty((1, 3, 100, 100), dtype=torch.uint8))

with self.assertRaisesRegex(
RuntimeError, "Input data should be a 3-dimensional tensor"):
encode_jpeg(torch.empty((100, 100), dtype=torch.uint8))

def test_write_jpeg(self):
with get_tmp_dir() as d:
for img_path in get_images(ENCODE_JPEG, ".jpg"):
data = read_file(img_path)
img = decode_jpeg(data)

basedir = os.path.dirname(img_path)
filename, _ = os.path.splitext(os.path.basename(img_path))
torch_jpeg = os.path.join(
d, '{0}_torch.jpg'.format(filename))
pil_jpeg = os.path.join(
basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename))

write_jpeg(img, torch_jpeg, quality=75)

with open(torch_jpeg, 'rb') as f:
torch_bytes = f.read()

with open(pil_jpeg, 'rb') as f:
pil_bytes = f.read()

self.assertEqual(torch_bytes, pil_bytes)

def test_decode_png(self):
conversion = [(None, ImageReadMode.UNCHANGED), ("L", ImageReadMode.GRAY), ("LA", ImageReadMode.GRAY_ALPHA),
("RGB", ImageReadMode.RGB), ("RGBA", ImageReadMode.RGB_ALPHA)]
Expand Down Expand Up @@ -282,11 +228,7 @@ def test_write_file_non_ascii(self):

@needs_cuda
@pytest.mark.parametrize('img_path', [
# We need to change the "id" for that parameter.
# If we don't, the test id (i.e. its name) will contain the whole path to the image which is machine-specific,
# and this creates issues when the test is running in a different machine than where it was collected
# (typically, in fb internal infra)
pytest.param(jpeg_path, id=jpeg_path.split('/')[-1])
pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path))
for jpeg_path in get_images(IMAGE_ROOT, ".jpg")
])
@pytest.mark.parametrize('mode', [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB])
Expand Down Expand Up @@ -325,5 +267,146 @@ def test_decode_jpeg_cuda_errors():
torch.ops.image.decode_jpeg_cuda(data, ImageReadMode.UNCHANGED.value, 'cpu')


@cpu_only
def test_encode_jpeg_errors():

with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"):
encode_jpeg(torch.empty((3, 100, 100), dtype=torch.float32))

with pytest.raises(ValueError, match="Image quality should be a positive number "
"between 1 and 100"):
encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=-1)

with pytest.raises(ValueError, match="Image quality should be a positive number "
"between 1 and 100"):
encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=101)

with pytest.raises(RuntimeError, match="The number of channels should be 1 or 3, got: 5"):
encode_jpeg(torch.empty((5, 100, 100), dtype=torch.uint8))

with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"):
encode_jpeg(torch.empty((1, 3, 100, 100), dtype=torch.uint8))

with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"):
encode_jpeg(torch.empty((100, 100), dtype=torch.uint8))


def _collect_if(cond):
# TODO: remove this once test_encode_jpeg_windows and test_write_jpeg_windows
# are removed
def _inner(test_func):
if cond:
return test_func
else:
return pytest.mark.dont_collect(test_func)
return _inner


@cpu_only
@_collect_if(cond=IS_WINDOWS)
def test_encode_jpeg_windows():
# This test is *wrong*.
# It compares a torchvision-encoded jpeg with a PIL-encoded jpeg, but it
# starts encoding the torchvision version from an image that comes from
# decode_jpeg, which can yield different results from pil.decode (see
# test_decode... which uses a high tolerance).
# Instead, we should start encoding from the exact same decoded image, for a
# valid comparison. This is done in test_encode_jpeg, but unfortunately
# these more correct tests fail on windows (probably because of a difference
# in libjpeg) between torchvision and PIL.
# FIXME: make the correct tests pass on windows and remove this.
for img_path in get_images(ENCODE_JPEG, ".jpg"):
dirname = os.path.dirname(img_path)
filename, _ = os.path.splitext(os.path.basename(img_path))
write_folder = os.path.join(dirname, 'jpeg_write')
expected_file = os.path.join(
write_folder, '{0}_pil.jpg'.format(filename))
img = decode_jpeg(read_file(img_path))

with open(expected_file, 'rb') as f:
pil_bytes = f.read()
pil_bytes = torch.as_tensor(list(pil_bytes), dtype=torch.uint8)
for src_img in [img, img.contiguous()]:
# PIL sets jpeg quality to 75 by default
jpeg_bytes = encode_jpeg(src_img, quality=75)
assert_equal(jpeg_bytes, pil_bytes)


@cpu_only
@_collect_if(cond=IS_WINDOWS)
def test_write_jpeg_windows():
# FIXME: Remove this eventually, see test_encode_jpeg_windows
with get_tmp_dir() as d:
for img_path in get_images(ENCODE_JPEG, ".jpg"):
data = read_file(img_path)
img = decode_jpeg(data)

basedir = os.path.dirname(img_path)
filename, _ = os.path.splitext(os.path.basename(img_path))
torch_jpeg = os.path.join(
d, '{0}_torch.jpg'.format(filename))
pil_jpeg = os.path.join(
basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename))

write_jpeg(img, torch_jpeg, quality=75)

with open(torch_jpeg, 'rb') as f:
torch_bytes = f.read()

with open(pil_jpeg, 'rb') as f:
pil_bytes = f.read()

assert_equal(torch_bytes, pil_bytes)


@cpu_only
@_collect_if(cond=not IS_WINDOWS)
@pytest.mark.parametrize('img_path', [
pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path))
for jpeg_path in get_images(ENCODE_JPEG, ".jpg")
])
def test_encode_jpeg(img_path):
img = read_image(img_path)

pil_img = F.to_pil_image(img)
buf = io.BytesIO()
pil_img.save(buf, format='JPEG', quality=75)

# pytorch can't read from raw bytes so we go through numpy
pil_bytes = np.frombuffer(buf.getvalue(), dtype=np.uint8)
encoded_jpeg_pil = torch.as_tensor(pil_bytes)

for src_img in [img, img.contiguous()]:
encoded_jpeg_torch = encode_jpeg(src_img, quality=75)
assert_equal(encoded_jpeg_torch, encoded_jpeg_pil)


@cpu_only
@_collect_if(cond=not IS_WINDOWS)
@pytest.mark.parametrize('img_path', [
pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path))
for jpeg_path in get_images(ENCODE_JPEG, ".jpg")
])
def test_write_jpeg(img_path):
with get_tmp_dir() as d:
d = Path(d)
img = read_image(img_path)
pil_img = F.to_pil_image(img)

torch_jpeg = str(d / 'torch.jpg')
pil_jpeg = str(d / 'pil.jpg')

write_jpeg(img, torch_jpeg, quality=75)
pil_img.save(pil_jpeg, quality=75)

with open(torch_jpeg, 'rb') as f:
torch_bytes = f.read()

with open(pil_jpeg, 'rb') as f:
pil_bytes = f.read()

assert_equal(torch_bytes, pil_bytes)


if __name__ == '__main__':
unittest.main()

0 comments on commit eaddb90

Please sign in to comment.