Skip to content

Commit

Permalink
Remove read_jpeg/read_png in favor of read_image (pytorch#2764)
Browse files Browse the repository at this point in the history
  • Loading branch information
fmassa authored and vfdev-5 committed Dec 4, 2020
1 parent c69afbb commit 0063b25
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 53 deletions.
37 changes: 12 additions & 25 deletions test/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torchvision
from PIL import Image
from torchvision.io.image import (
read_png, decode_png, read_jpeg, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file,
decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file,
encode_png, write_png)
import numpy as np

Expand All @@ -33,19 +33,12 @@ def get_images(directory, img_ext):


class ImageTester(unittest.TestCase):
def test_read_jpeg(self):
for img_path in get_images(IMAGE_ROOT, ".jpg"):
img_pil = torch.load(img_path.replace('jpg', 'pth'))
img_pil = img_pil.permute(2, 0, 1)
img_ljpeg = read_jpeg(img_path)
self.assertTrue(img_ljpeg.equal(img_pil))

def test_decode_jpeg(self):
for img_path in get_images(IMAGE_ROOT, ".jpg"):
img_pil = torch.load(img_path.replace('jpg', 'pth'))
img_pil = img_pil.permute(2, 0, 1)
size = os.path.getsize(img_path)
img_ljpeg = decode_jpeg(torch.from_file(img_path, dtype=torch.uint8, size=size))
data = read_file(img_path)
img_ljpeg = decode_jpeg(data)
self.assertTrue(img_ljpeg.equal(img_pil))

with self.assertRaisesRegex(RuntimeError, "Expected a non empty 1-dimensional tensor"):
Expand All @@ -59,18 +52,19 @@ def test_decode_jpeg(self):

def test_damaged_images(self):
# Test image with bad Huffman encoding (should not raise)
bad_huff = os.path.join(DAMAGED_JPEG, 'bad_huffman.jpg')
bad_huff = read_file(os.path.join(DAMAGED_JPEG, 'bad_huffman.jpg'))
try:
_ = read_jpeg(bad_huff)
_ = decode_jpeg(bad_huff)
except RuntimeError:
self.assertTrue(False)

# Truncated images should raise an exception
truncated_images = glob.glob(
os.path.join(DAMAGED_JPEG, 'corrupt*.jpg'))
for image_path in truncated_images:
data = read_file(image_path)
with self.assertRaises(RuntimeError):
read_jpeg(image_path)
decode_jpeg(data)

def test_encode_jpeg(self):
for img_path in get_images(IMAGE_ROOT, ".jpg"):
Expand All @@ -79,7 +73,7 @@ def test_encode_jpeg(self):
write_folder = os.path.join(dirname, 'jpeg_write')
expected_file = os.path.join(
write_folder, '{0}_pil.jpg'.format(filename))
img = read_jpeg(img_path)
img = decode_jpeg(read_file(img_path))

with open(expected_file, 'rb') as f:
pil_bytes = f.read()
Expand Down Expand Up @@ -117,7 +111,8 @@ def test_encode_jpeg(self):

def test_write_jpeg(self):
for img_path in get_images(IMAGE_ROOT, ".jpg"):
img = read_jpeg(img_path)
data = read_file(img_path)
img = decode_jpeg(data)

basedir = os.path.dirname(img_path)
filename, _ = os.path.splitext(os.path.basename(img_path))
Expand All @@ -137,20 +132,12 @@ def test_write_jpeg(self):
os.remove(torch_jpeg)
self.assertEqual(torch_bytes, pil_bytes)

def test_read_png(self):
# Check across .png
for img_path in get_images(IMAGE_DIR, ".png"):
img_pil = torch.from_numpy(np.array(Image.open(img_path)))
img_pil = img_pil.permute(2, 0, 1)
img_lpng = read_png(img_path)
self.assertTrue(img_lpng.equal(img_pil))

def test_decode_png(self):
for img_path in get_images(IMAGE_DIR, ".png"):
img_pil = torch.from_numpy(np.array(Image.open(img_path)))
img_pil = img_pil.permute(2, 0, 1)
size = os.path.getsize(img_path)
img_lpng = decode_png(torch.from_file(img_path, dtype=torch.uint8, size=size))
data = read_file(img_path)
img_lpng = decode_png(data)
self.assertTrue(img_lpng.equal(img_pil))

with self.assertRaises(RuntimeError):
Expand Down
28 changes: 0 additions & 28 deletions torchvision/io/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,21 +54,6 @@ def decode_png(input: torch.Tensor) -> torch.Tensor:
return output


def read_png(path: str) -> torch.Tensor:
"""
Reads a PNG image into a 3 dimensional RGB Tensor.
The values of the output tensor are uint8 between 0 and 255.
Arguments:
path (str): path of the PNG image.
Returns:
output (Tensor[3, image_height, image_width])
"""
data = read_file(path)
return decode_png(data)


def encode_png(input: torch.Tensor, compression_level: int = 6) -> torch.Tensor:
"""
Takes an input tensor in CHW layout and returns a buffer with the contents
Expand Down Expand Up @@ -124,19 +109,6 @@ def decode_jpeg(input: torch.Tensor) -> torch.Tensor:
return output


def read_jpeg(path: str) -> torch.Tensor:
"""
Reads a JPEG image into a 3 dimensional RGB Tensor.
The values of the output tensor are uint8 between 0 and 255.
Arguments:
path (str): path of the JPEG image.
Returns:
output (Tensor[3, image_height, image_width])
"""
data = read_file(path)
return decode_jpeg(data)


def encode_jpeg(input: torch.Tensor, quality: int = 75) -> torch.Tensor:
"""
Takes an input tensor in CHW layout and returns a buffer with the contents
Expand Down

0 comments on commit 0063b25

Please sign in to comment.