From 0063b25b69a8b3119d00c8a8bf842674f279fe62 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Tue, 6 Oct 2020 19:37:50 +0200 Subject: [PATCH] Remove read_jpeg/read_png in favor of read_image (#2764) --- test/test_image.py | 37 ++++++++++++------------------------- torchvision/io/image.py | 28 ---------------------------- 2 files changed, 12 insertions(+), 53 deletions(-) diff --git a/test/test_image.py b/test/test_image.py index 7a0317cae83..8eb3930d139 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -8,7 +8,7 @@ import torchvision from PIL import Image from torchvision.io.image import ( - read_png, decode_png, read_jpeg, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file, + decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file, encode_png, write_png) import numpy as np @@ -33,19 +33,12 @@ def get_images(directory, img_ext): class ImageTester(unittest.TestCase): - def test_read_jpeg(self): - for img_path in get_images(IMAGE_ROOT, ".jpg"): - img_pil = torch.load(img_path.replace('jpg', 'pth')) - img_pil = img_pil.permute(2, 0, 1) - img_ljpeg = read_jpeg(img_path) - self.assertTrue(img_ljpeg.equal(img_pil)) - def test_decode_jpeg(self): for img_path in get_images(IMAGE_ROOT, ".jpg"): img_pil = torch.load(img_path.replace('jpg', 'pth')) img_pil = img_pil.permute(2, 0, 1) - size = os.path.getsize(img_path) - img_ljpeg = decode_jpeg(torch.from_file(img_path, dtype=torch.uint8, size=size)) + data = read_file(img_path) + img_ljpeg = decode_jpeg(data) self.assertTrue(img_ljpeg.equal(img_pil)) with self.assertRaisesRegex(RuntimeError, "Expected a non empty 1-dimensional tensor"): @@ -59,9 +52,9 @@ def test_decode_jpeg(self): def test_damaged_images(self): # Test image with bad Huffman encoding (should not raise) - bad_huff = os.path.join(DAMAGED_JPEG, 'bad_huffman.jpg') + bad_huff = read_file(os.path.join(DAMAGED_JPEG, 'bad_huffman.jpg')) try: - _ = read_jpeg(bad_huff) + _ = decode_jpeg(bad_huff) except RuntimeError: self.assertTrue(False) @@ -69,8 +62,9 @@ def test_damaged_images(self): truncated_images = glob.glob( os.path.join(DAMAGED_JPEG, 'corrupt*.jpg')) for image_path in truncated_images: + data = read_file(image_path) with self.assertRaises(RuntimeError): - read_jpeg(image_path) + decode_jpeg(data) def test_encode_jpeg(self): for img_path in get_images(IMAGE_ROOT, ".jpg"): @@ -79,7 +73,7 @@ def test_encode_jpeg(self): write_folder = os.path.join(dirname, 'jpeg_write') expected_file = os.path.join( write_folder, '{0}_pil.jpg'.format(filename)) - img = read_jpeg(img_path) + img = decode_jpeg(read_file(img_path)) with open(expected_file, 'rb') as f: pil_bytes = f.read() @@ -117,7 +111,8 @@ def test_encode_jpeg(self): def test_write_jpeg(self): for img_path in get_images(IMAGE_ROOT, ".jpg"): - img = read_jpeg(img_path) + data = read_file(img_path) + img = decode_jpeg(data) basedir = os.path.dirname(img_path) filename, _ = os.path.splitext(os.path.basename(img_path)) @@ -137,20 +132,12 @@ def test_write_jpeg(self): os.remove(torch_jpeg) self.assertEqual(torch_bytes, pil_bytes) - def test_read_png(self): - # Check across .png - for img_path in get_images(IMAGE_DIR, ".png"): - img_pil = torch.from_numpy(np.array(Image.open(img_path))) - img_pil = img_pil.permute(2, 0, 1) - img_lpng = read_png(img_path) - self.assertTrue(img_lpng.equal(img_pil)) - def test_decode_png(self): for img_path in get_images(IMAGE_DIR, ".png"): img_pil = torch.from_numpy(np.array(Image.open(img_path))) img_pil = img_pil.permute(2, 0, 1) - size = os.path.getsize(img_path) - img_lpng = decode_png(torch.from_file(img_path, dtype=torch.uint8, size=size)) + data = read_file(img_path) + img_lpng = decode_png(data) self.assertTrue(img_lpng.equal(img_pil)) with self.assertRaises(RuntimeError): diff --git a/torchvision/io/image.py b/torchvision/io/image.py index 398d682689e..a79b39fabc3 100644 --- a/torchvision/io/image.py +++ b/torchvision/io/image.py @@ -54,21 +54,6 @@ def decode_png(input: torch.Tensor) -> torch.Tensor: return output -def read_png(path: str) -> torch.Tensor: - """ - Reads a PNG image into a 3 dimensional RGB Tensor. - The values of the output tensor are uint8 between 0 and 255. - - Arguments: - path (str): path of the PNG image. - - Returns: - output (Tensor[3, image_height, image_width]) - """ - data = read_file(path) - return decode_png(data) - - def encode_png(input: torch.Tensor, compression_level: int = 6) -> torch.Tensor: """ Takes an input tensor in CHW layout and returns a buffer with the contents @@ -124,19 +109,6 @@ def decode_jpeg(input: torch.Tensor) -> torch.Tensor: return output -def read_jpeg(path: str) -> torch.Tensor: - """ - Reads a JPEG image into a 3 dimensional RGB Tensor. - The values of the output tensor are uint8 between 0 and 255. - Arguments: - path (str): path of the JPEG image. - Returns: - output (Tensor[3, image_height, image_width]) - """ - data = read_file(path) - return decode_jpeg(data) - - def encode_jpeg(input: torch.Tensor, quality: int = 75) -> torch.Tensor: """ Takes an input tensor in CHW layout and returns a buffer with the contents