Skip to content

Commit

Permalink
Add torchscript test for io image stuff (#8313)
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug authored Mar 13, 2024
1 parent eb815ae commit c7bcfad
Showing 1 changed file with 46 additions and 17 deletions.
63 changes: 46 additions & 17 deletions test/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,9 @@ def normalize_dimensions(img_pil):
("RGB", ImageReadMode.RGB),
],
)
def test_decode_jpeg(img_path, pil_mode, mode):
@pytest.mark.parametrize("scripted", (False, True))
@pytest.mark.parametrize("decode_fun", (decode_jpeg, decode_image))
def test_decode_jpeg(img_path, pil_mode, mode, scripted, decode_fun):

with Image.open(img_path) as img:
is_cmyk = img.mode == "CMYK"
Expand All @@ -92,7 +94,9 @@ def test_decode_jpeg(img_path, pil_mode, mode):

img_pil = normalize_dimensions(img_pil)
data = read_file(img_path)
img_ljpeg = decode_image(data, mode=mode)
if scripted:
decode_fun = torch.jit.script(decode_fun)
img_ljpeg = decode_fun(data, mode=mode)

# Permit a small variation on pixel values to account for implementation
# differences between Pillow and LibJPEG.
Expand Down Expand Up @@ -188,7 +192,12 @@ def test_damaged_corrupt_images(img_path):
("RGBA", ImageReadMode.RGB_ALPHA),
],
)
def test_decode_png(img_path, pil_mode, mode):
@pytest.mark.parametrize("scripted", (False, True))
@pytest.mark.parametrize("decode_fun", (decode_png, decode_image))
def test_decode_png(img_path, pil_mode, mode, scripted, decode_fun):

if scripted:
decode_fun = torch.jit.script(decode_fun)

with Image.open(img_path) as img:
if pil_mode is not None:
Expand All @@ -202,15 +211,15 @@ def test_decode_png(img_path, pil_mode, mode):
# FIXME: see https://github.com/pytorch/vision/issues/4731 for potential solutions to making it public
with pytest.raises(RuntimeError, match="At most 8-bit PNG images are supported"):
data = read_file(img_path)
img_lpng = decode_image(data, mode=mode)
img_lpng = decode_fun(data, mode=mode)

img_lpng = _read_png_16(img_path, mode=mode)
assert img_lpng.dtype == torch.int32
# PIL converts 16 bits pngs in uint8
img_lpng = torch.round(img_lpng / (2**16 - 1) * 255).to(torch.uint8)
else:
data = read_file(img_path)
img_lpng = decode_image(data, mode=mode)
img_lpng = decode_fun(data, mode=mode)

tol = 0 if pil_mode is None else 1

Expand Down Expand Up @@ -239,11 +248,13 @@ def test_decode_png_errors():
"img_path",
[pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(IMAGE_DIR, ".png")],
)
def test_encode_png(img_path):
@pytest.mark.parametrize("scripted", (True, False))
def test_encode_png(img_path, scripted):
pil_image = Image.open(img_path)
img_pil = torch.from_numpy(np.array(pil_image))
img_pil = img_pil.permute(2, 0, 1)
png_buf = encode_png(img_pil, compression_level=6)
encode = torch.jit.script(encode_png) if scripted else encode_png
png_buf = encode(img_pil, compression_level=6)

rec_img = Image.open(io.BytesIO(bytes(png_buf.tolist())))
rec_img = torch.from_numpy(np.array(rec_img))
Expand All @@ -270,27 +281,39 @@ def test_encode_png_errors():
"img_path",
[pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(IMAGE_DIR, ".png")],
)
def test_write_png(img_path, tmpdir):
@pytest.mark.parametrize("scripted", (True, False))
def test_write_png(img_path, tmpdir, scripted):
pil_image = Image.open(img_path)
img_pil = torch.from_numpy(np.array(pil_image))
img_pil = img_pil.permute(2, 0, 1)

filename, _ = os.path.splitext(os.path.basename(img_path))
torch_png = os.path.join(tmpdir, f"{filename}_torch.png")
write_png(img_pil, torch_png, compression_level=6)
write = torch.jit.script(write_png) if scripted else write_png
write(img_pil, torch_png, compression_level=6)
saved_image = torch.from_numpy(np.array(Image.open(torch_png)))
saved_image = saved_image.permute(2, 0, 1)

assert_equal(img_pil, saved_image)


def test_read_file(tmpdir):
def test_read_image():
# Just testing torchcsript, the functionality is somewhat tested already in other tests.
path = next(get_images(IMAGE_ROOT, ".jpg"))
out = read_image(path)
out_scripted = torch.jit.script(read_image)(path)
torch.testing.assert_close(out, out_scripted, atol=0, rtol=0)


@pytest.mark.parametrize("scripted", (True, False))
def test_read_file(tmpdir, scripted):
fname, content = "test1.bin", b"TorchVision\211\n"
fpath = os.path.join(tmpdir, fname)
with open(fpath, "wb") as f:
f.write(content)

data = read_file(fpath)
fun = torch.jit.script(read_file) if scripted else read_file
data = fun(fpath)
expected = torch.tensor(list(content), dtype=torch.uint8)
os.unlink(fpath)
assert_equal(data, expected)
Expand All @@ -311,11 +334,13 @@ def test_read_file_non_ascii(tmpdir):
assert_equal(data, expected)


def test_write_file(tmpdir):
@pytest.mark.parametrize("scripted", (True, False))
def test_write_file(tmpdir, scripted):
fname, content = "test1.bin", b"TorchVision\211\n"
fpath = os.path.join(tmpdir, fname)
content_tensor = torch.tensor(list(content), dtype=torch.uint8)
write_file(fpath, content_tensor)
write = torch.jit.script(write_file) if scripted else write_file
write(fpath, content_tensor)

with open(fpath, "rb") as f:
saved_content = f.read()
Expand Down Expand Up @@ -464,7 +489,8 @@ def test_encode_jpeg_errors():
"img_path",
[pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],
)
def test_encode_jpeg(img_path):
@pytest.mark.parametrize("scripted", (True, False))
def test_encode_jpeg(img_path, scripted):
img = read_image(img_path)

pil_img = F.to_pil_image(img)
Expand All @@ -473,8 +499,9 @@ def test_encode_jpeg(img_path):

encoded_jpeg_pil = torch.frombuffer(buf.getvalue(), dtype=torch.uint8)

encode = torch.jit.script(encode_jpeg) if scripted else encode_jpeg
for src_img in [img, img.contiguous()]:
encoded_jpeg_torch = encode_jpeg(src_img, quality=75)
encoded_jpeg_torch = encode(src_img, quality=75)
assert_equal(encoded_jpeg_torch, encoded_jpeg_pil)


Expand All @@ -483,15 +510,17 @@ def test_encode_jpeg(img_path):
"img_path",
[pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],
)
def test_write_jpeg(img_path, tmpdir):
@pytest.mark.parametrize("scripted", (True, False))
def test_write_jpeg(img_path, tmpdir, scripted):
tmpdir = Path(tmpdir)
img = read_image(img_path)
pil_img = F.to_pil_image(img)

torch_jpeg = str(tmpdir / "torch.jpg")
pil_jpeg = str(tmpdir / "pil.jpg")

write_jpeg(img, torch_jpeg, quality=75)
write = torch.jit.script(write_jpeg) if scripted else write_jpeg
write(img, torch_jpeg, quality=75)
pil_img.save(pil_jpeg, quality=75)

with open(torch_jpeg, "rb") as f:
Expand Down

0 comments on commit c7bcfad

Please sign in to comment.