From 6afb3496cc27cebc938ac02880a35d6ddf9796a0 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Mon, 28 Sep 2020 21:20:12 +0200 Subject: [PATCH] Add decode_image op (#2718) * Add decode_image op * Fix lint * More lint * Add C10_EXPORT --- test/test_image.py | 21 ++++-- torchvision/csrc/cpu/image/image.cpp | 3 +- torchvision/csrc/cpu/image/image.h | 2 +- torchvision/csrc/cpu/image/read_image_cpu.cpp | 27 +++++++ torchvision/csrc/cpu/image/read_image_cpu.h | 6 ++ torchvision/csrc/cpu/image/readjpeg_cpu.cpp | 7 ++ torchvision/csrc/cpu/image/readpng_cpu.cpp | 7 ++ torchvision/io/image.py | 72 ++++++++++++------- 8 files changed, 112 insertions(+), 33 deletions(-) create mode 100644 torchvision/csrc/cpu/image/read_image_cpu.cpp create mode 100644 torchvision/csrc/cpu/image/read_image_cpu.h diff --git a/test/test_image.py b/test/test_image.py index a87f45f2d70..618480f8765 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -8,7 +8,7 @@ import torchvision from PIL import Image from torchvision.io.image import ( - read_png, decode_png, read_jpeg, decode_jpeg, encode_jpeg, write_jpeg) + read_png, decode_png, read_jpeg, decode_jpeg, encode_jpeg, write_jpeg, decode_image, _read_file) import numpy as np IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets") @@ -44,10 +44,10 @@ def test_decode_jpeg(self): img_ljpeg = decode_jpeg(torch.from_file(img_path, dtype=torch.uint8, size=size)) self.assertTrue(img_ljpeg.equal(img_pil)) - with self.assertRaisesRegex(ValueError, "Expected a non empty 1-dimensional tensor."): + with self.assertRaisesRegex(RuntimeError, "Expected a non empty 1-dimensional tensor"): decode_jpeg(torch.empty((100, 1), dtype=torch.uint8)) - with self.assertRaisesRegex(ValueError, "Expected a torch.uint8 tensor."): + with self.assertRaisesRegex(RuntimeError, "Expected a torch.uint8 tensor"): decode_jpeg(torch.empty((100, ), dtype=torch.float16)) with self.assertRaises(RuntimeError): @@ -149,11 +149,24 @@ def 
test_decode_png(self): img_lpng = decode_png(torch.from_file(img_path, dtype=torch.uint8, size=size)) self.assertTrue(img_lpng.equal(img_pil)) - with self.assertRaises(ValueError): + with self.assertRaises(RuntimeError): decode_png(torch.empty((), dtype=torch.uint8)) with self.assertRaises(RuntimeError): decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8)) + def test_decode_image(self): + for img_path in get_images(IMAGE_ROOT, ".jpg"): + img_pil = torch.load(img_path.replace('jpg', 'pth')) + img_pil = img_pil.permute(2, 0, 1) + img_ljpeg = decode_image(_read_file(img_path)) + self.assertTrue(img_ljpeg.equal(img_pil)) + + for img_path in get_images(IMAGE_DIR, ".png"): + img_pil = torch.from_numpy(np.array(Image.open(img_path))) + img_pil = img_pil.permute(2, 0, 1) + img_lpng = decode_image(_read_file(img_path)) + self.assertTrue(img_lpng.equal(img_pil)) + if __name__ == '__main__': unittest.main() diff --git a/torchvision/csrc/cpu/image/image.cpp b/torchvision/csrc/cpu/image/image.cpp index a10511c2953..9ce0b31ca44 100644 --- a/torchvision/csrc/cpu/image/image.cpp +++ b/torchvision/csrc/cpu/image/image.cpp @@ -16,4 +16,5 @@ static auto registry = torch::RegisterOperators() .op("image::decode_png", &decodePNG) .op("image::decode_jpeg", &decodeJPEG) .op("image::encode_jpeg", &encodeJPEG) - .op("image::write_jpeg", &writeJPEG); + .op("image::write_jpeg", &writeJPEG) + .op("image::decode_image", &decode_image); diff --git a/torchvision/csrc/cpu/image/image.h b/torchvision/csrc/cpu/image/image.h index 0e2c23cfc9e..077f4e16c18 100644 --- a/torchvision/csrc/cpu/image/image.h +++ b/torchvision/csrc/cpu/image/image.h @@ -1,9 +1,9 @@ - #pragma once // Comment #include #include +#include "read_image_cpu.h" #include "readjpeg_cpu.h" #include "readpng_cpu.h" #include "writejpeg_cpu.h" diff --git a/torchvision/csrc/cpu/image/read_image_cpu.cpp b/torchvision/csrc/cpu/image/read_image_cpu.cpp new file mode 100644 index 00000000000..0bd12d9c2e5 --- /dev/null +++ 
b/torchvision/csrc/cpu/image/read_image_cpu.cpp
@@ -0,0 +1,27 @@
+#include "read_image_cpu.h"
+#include <cstring>
+
+torch::Tensor decode_image(const torch::Tensor& data) {
+  // Check that the input tensor dtype is uint8
+  TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
+  // Check that the input tensor is 1-dimensional and non empty
+  TORCH_CHECK(
+      data.dim() == 1 && data.numel() > 0,
+      "Expected a non empty 1-dimensional tensor");
+
+  auto datap = data.data_ptr();
+
+  const uint8_t jpeg_signature[3] = {255, 216, 255}; // == "\xFF\xD8\xFF"
+  const uint8_t png_signature[4] = {137, 80, 78, 71}; // == "\211PNG"
+  // Guard each memcmp so inputs shorter than a signature are not over-read
+  if (data.numel() >= 3 && memcmp(jpeg_signature, datap, 3) == 0) {
+    return decodeJPEG(data);
+  } else if (data.numel() >= 4 && memcmp(png_signature, datap, 4) == 0) {
+    return decodePNG(data);
+  } else {
+    TORCH_CHECK(
+        false,
+        "Unsupported image file. Only jpeg and png ",
+        "are currently supported.");
+  }
+}
diff --git a/torchvision/csrc/cpu/image/read_image_cpu.h b/torchvision/csrc/cpu/image/read_image_cpu.h
new file mode 100644
index 00000000000..c8538cc88c6
--- /dev/null
+++ b/torchvision/csrc/cpu/image/read_image_cpu.h
@@ -0,0 +1,6 @@
+#pragma once
+
+#include "readjpeg_cpu.h"
+#include "readpng_cpu.h"
+
+C10_EXPORT torch::Tensor decode_image(const torch::Tensor& data);
diff --git a/torchvision/csrc/cpu/image/readjpeg_cpu.cpp b/torchvision/csrc/cpu/image/readjpeg_cpu.cpp
index 5059efbaeed..dd2354e4467 100644
--- a/torchvision/csrc/cpu/image/readjpeg_cpu.cpp
+++ b/torchvision/csrc/cpu/image/readjpeg_cpu.cpp
@@ -72,6 +72,13 @@ static void torch_jpeg_set_source_mgr(
 }
 
 torch::Tensor decodeJPEG(const torch::Tensor& data) {
+  // Check that the input tensor dtype is uint8
+  TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
+  // Check that the input tensor is 1-dimensional
+  TORCH_CHECK(
+      data.dim() == 1 && data.numel() > 0,
+      "Expected a non empty 1-dimensional tensor");
+
   struct jpeg_decompress_struct cinfo;
   struct torch_jpeg_error_mgr jerr;
 
diff --git
a/torchvision/csrc/cpu/image/readpng_cpu.cpp b/torchvision/csrc/cpu/image/readpng_cpu.cpp index 1438e3ee1c4..e91d6058d4b 100644 --- a/torchvision/csrc/cpu/image/readpng_cpu.cpp +++ b/torchvision/csrc/cpu/image/readpng_cpu.cpp @@ -13,6 +13,13 @@ torch::Tensor decodePNG(const torch::Tensor& data) { #include torch::Tensor decodePNG(const torch::Tensor& data) { + // Check that the input tensor dtype is uint8 + TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); + // Check that the input tensor is 1-dimensional + TORCH_CHECK( + data.dim() == 1 && data.numel() > 0, + "Expected a non empty 1-dimensional tensor"); + auto png_ptr = png_create_read_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr); TORCH_CHECK(png_ptr, "libpng read structure allocation failed!") diff --git a/torchvision/io/image.py b/torchvision/io/image.py index 203c8e4f4ef..43ca1f86e87 100644 --- a/torchvision/io/image.py +++ b/torchvision/io/image.py @@ -23,23 +23,29 @@ pass +def _read_file(path: str) -> torch.Tensor: + if not os.path.isfile(path): + raise ValueError("Expected a valid file path.") + + size = os.path.getsize(path) + if size == 0: + raise ValueError("Expected a non empty file.") + data = torch.from_file(path, dtype=torch.uint8, size=size) + return data + + def decode_png(input: torch.Tensor) -> torch.Tensor: """ Decodes a PNG image into a 3 dimensional RGB Tensor. The values of the output tensor are uint8 between 0 and 255. Arguments: - input (Tensor[1]): a one dimensional int8 tensor containing + input (Tensor[1]): a one dimensional uint8 tensor containing the raw bytes of the PNG image. 
Returns: output (Tensor[3, image_height, image_width]) """ - if not isinstance(input, torch.Tensor) or input.numel() == 0 or input.ndim != 1: # type: ignore[attr-defined] - raise ValueError("Expected a non empty 1-dimensional tensor.") - - if not input.dtype == torch.uint8: - raise ValueError("Expected a torch.uint8 tensor.") output = torch.ops.image.decode_png(input) return output @@ -55,13 +61,7 @@ def read_png(path: str) -> torch.Tensor: Returns: output (Tensor[3, image_height, image_width]) """ - if not os.path.isfile(path): - raise ValueError("Expected a valid file path.") - - size = os.path.getsize(path) - if size == 0: - raise ValueError("Expected a non empty file.") - data = torch.from_file(path, dtype=torch.uint8, size=size) + data = _read_file(path) return decode_png(data) @@ -70,17 +70,11 @@ def decode_jpeg(input: torch.Tensor) -> torch.Tensor: Decodes a JPEG image into a 3 dimensional RGB Tensor. The values of the output tensor are uint8 between 0 and 255. Arguments: - input (Tensor[1]): a one dimensional int8 tensor containing + input (Tensor[1]): a one dimensional uint8 tensor containing the raw bytes of the JPEG image. 
Returns: output (Tensor[3, image_height, image_width]) """ - if not isinstance(input, torch.Tensor) or len(input) == 0 or input.ndim != 1: # type: ignore[attr-defined] - raise ValueError("Expected a non empty 1-dimensional tensor.") - - if not input.dtype == torch.uint8: - raise ValueError("Expected a torch.uint8 tensor.") - output = torch.ops.image.decode_jpeg(input) return output @@ -94,13 +88,7 @@ def read_jpeg(path: str) -> torch.Tensor: Returns: output (Tensor[3, image_height, image_width]) """ - if not os.path.isfile(path): - raise ValueError("Expected a valid file path.") - - size = os.path.getsize(path) - if size == 0: - raise ValueError("Expected a non empty file.") - data = torch.from_file(path, dtype=torch.uint8, size=size) + data = _read_file(path) return decode_jpeg(data) @@ -141,3 +129,33 @@ def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75): 'between 1 and 100') torch.ops.image.write_jpeg(input, filename, quality) + + +def decode_image(input: torch.Tensor) -> torch.Tensor: + """ + Detects whether an image is a JPEG or PNG and performs the appropriate + operation to decode the image into a 3 dimensional RGB Tensor. + + The values of the output tensor are uint8 between 0 and 255. + + Arguments: + input (Tensor): a one dimensional uint8 tensor containing + the raw bytes of the PNG or JPEG image. + Returns: + output (Tensor[3, image_height, image_width]) + """ + output = torch.ops.image.decode_image(input) + return output + + +def read_image(path: str) -> torch.Tensor: + """ + Reads a JPEG or PNG image into a 3 dimensional RGB Tensor. + The values of the output tensor are uint8 between 0 and 255. + Arguments: + path (str): path of the JPEG or PNG image. + Returns: + output (Tensor[3, image_height, image_width]) + """ + data = _read_file(path) + return decode_image(data)