From c7bcfada86e3e84a1aca2fbbe18ce427284108db Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 13 Mar 2024 16:23:27 +0000 Subject: [PATCH] Add torchscript test for io image stuff (#8313) --- test/test_image.py | 63 +++++++++++++++++++++++++++++++++------------- 1 file changed, 46 insertions(+), 17 deletions(-) diff --git a/test/test_image.py b/test/test_image.py index 4acf281e380..53714342d46 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -79,7 +79,9 @@ def normalize_dimensions(img_pil): ("RGB", ImageReadMode.RGB), ], ) -def test_decode_jpeg(img_path, pil_mode, mode): +@pytest.mark.parametrize("scripted", (False, True)) +@pytest.mark.parametrize("decode_fun", (decode_jpeg, decode_image)) +def test_decode_jpeg(img_path, pil_mode, mode, scripted, decode_fun): with Image.open(img_path) as img: is_cmyk = img.mode == "CMYK" @@ -92,7 +94,9 @@ def test_decode_jpeg(img_path, pil_mode, mode): img_pil = normalize_dimensions(img_pil) data = read_file(img_path) - img_ljpeg = decode_image(data, mode=mode) + if scripted: + decode_fun = torch.jit.script(decode_fun) + img_ljpeg = decode_fun(data, mode=mode) # Permit a small variation on pixel values to account for implementation # differences between Pillow and LibJPEG. @@ -188,7 +192,12 @@ def test_damaged_corrupt_images(img_path): ("RGBA", ImageReadMode.RGB_ALPHA), ], ) -def test_decode_png(img_path, pil_mode, mode): +@pytest.mark.parametrize("scripted", (False, True)) +@pytest.mark.parametrize("decode_fun", (decode_png, decode_image)) +def test_decode_png(img_path, pil_mode, mode, scripted, decode_fun): + + if scripted: + decode_fun = torch.jit.script(decode_fun) with Image.open(img_path) as img: if pil_mode is not None: @@ -202,7 +211,7 @@ def test_decode_png(img_path, pil_mode, mode): # FIXME: see https://github.com/pytorch/vision/issues/4731 for potential solutions to making it public with pytest.raises(RuntimeError, match="At most 8-bit PNG images are supported"): data = read_file(img_path) - img_lpng = decode_image(data, mode=mode) + img_lpng = decode_fun(data, mode=mode) img_lpng = _read_png_16(img_path, mode=mode) assert img_lpng.dtype == torch.int32 @@ -210,7 +219,7 @@ def test_decode_png(img_path, pil_mode, mode): img_lpng = torch.round(img_lpng / (2**16 - 1) * 255).to(torch.uint8) else: data = read_file(img_path) - img_lpng = decode_image(data, mode=mode) + img_lpng = decode_fun(data, mode=mode) tol = 0 if pil_mode is None else 1 @@ -239,11 +248,13 @@ def test_decode_png_errors(): "img_path", [pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(IMAGE_DIR, ".png")], ) -def test_encode_png(img_path): +@pytest.mark.parametrize("scripted", (True, False)) +def test_encode_png(img_path, scripted): pil_image = Image.open(img_path) img_pil = torch.from_numpy(np.array(pil_image)) img_pil = img_pil.permute(2, 0, 1) - png_buf = encode_png(img_pil, compression_level=6) + encode = torch.jit.script(encode_png) if scripted else encode_png + png_buf = encode(img_pil, compression_level=6) rec_img = Image.open(io.BytesIO(bytes(png_buf.tolist()))) rec_img = torch.from_numpy(np.array(rec_img)) @@ -270,27 +281,39 @@ def test_encode_png_errors(): "img_path", [pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(IMAGE_DIR, ".png")], ) -def test_write_png(img_path, tmpdir): +@pytest.mark.parametrize("scripted", (True, False)) +def test_write_png(img_path, tmpdir, scripted): pil_image = Image.open(img_path) img_pil = torch.from_numpy(np.array(pil_image)) img_pil = img_pil.permute(2, 0, 1) filename, _ = os.path.splitext(os.path.basename(img_path)) torch_png = os.path.join(tmpdir, f"{filename}_torch.png") - write_png(img_pil, torch_png, compression_level=6) + write = torch.jit.script(write_png) if scripted else write_png + write(img_pil, torch_png, compression_level=6) saved_image = torch.from_numpy(np.array(Image.open(torch_png))) saved_image = saved_image.permute(2, 0, 1) assert_equal(img_pil, saved_image) -def test_read_file(tmpdir): +def test_read_image(): + # Just testing torchcsript, the functionality is somewhat tested already in other tests. + path = next(get_images(IMAGE_ROOT, ".jpg")) + out = read_image(path) + out_scripted = torch.jit.script(read_image)(path) + torch.testing.assert_close(out, out_scripted, atol=0, rtol=0) + + +@pytest.mark.parametrize("scripted", (True, False)) +def test_read_file(tmpdir, scripted): fname, content = "test1.bin", b"TorchVision\211\n" fpath = os.path.join(tmpdir, fname) with open(fpath, "wb") as f: f.write(content) - data = read_file(fpath) + fun = torch.jit.script(read_file) if scripted else read_file + data = fun(fpath) expected = torch.tensor(list(content), dtype=torch.uint8) os.unlink(fpath) assert_equal(data, expected) @@ -311,11 +334,13 @@ def test_read_file_non_ascii(tmpdir): assert_equal(data, expected) -def test_write_file(tmpdir): +@pytest.mark.parametrize("scripted", (True, False)) +def test_write_file(tmpdir, scripted): fname, content = "test1.bin", b"TorchVision\211\n" fpath = os.path.join(tmpdir, fname) content_tensor = torch.tensor(list(content), dtype=torch.uint8) - write_file(fpath, content_tensor) + write = torch.jit.script(write_file) if scripted else write_file + write(fpath, content_tensor) with open(fpath, "rb") as f: saved_content = f.read() @@ -464,7 +489,8 @@ def test_encode_jpeg_errors(): "img_path", [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")], ) -def test_encode_jpeg(img_path): +@pytest.mark.parametrize("scripted", (True, False)) +def test_encode_jpeg(img_path, scripted): img = read_image(img_path) pil_img = F.to_pil_image(img) @@ -473,8 +499,9 @@ def test_encode_jpeg(img_path): encoded_jpeg_pil = torch.frombuffer(buf.getvalue(), dtype=torch.uint8) + encode = torch.jit.script(encode_jpeg) if scripted else encode_jpeg for src_img in [img, img.contiguous()]: - encoded_jpeg_torch = encode_jpeg(src_img, quality=75) + encoded_jpeg_torch = encode(src_img, quality=75) assert_equal(encoded_jpeg_torch, encoded_jpeg_pil) @@ -483,7 +510,8 @@ def test_encode_jpeg(img_path): "img_path", [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")], ) -def test_write_jpeg(img_path, tmpdir): +@pytest.mark.parametrize("scripted", (True, False)) +def test_write_jpeg(img_path, tmpdir, scripted): tmpdir = Path(tmpdir) img = read_image(img_path) pil_img = F.to_pil_image(img) @@ -491,7 +519,8 @@ def test_write_jpeg(img_path, tmpdir): torch_jpeg = str(tmpdir / "torch.jpg") pil_jpeg = str(tmpdir / "pil.jpg") - write_jpeg(img, torch_jpeg, quality=75) + write = torch.jit.script(write_jpeg) if scripted else write_jpeg + write(img, torch_jpeg, quality=75) pil_img.save(pil_jpeg, quality=75) with open(torch_jpeg, "rb") as f: