From 324e858eb05f2518550a72e401d9d9203568e247 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Mon, 28 Sep 2020 14:47:42 +0200 Subject: [PATCH 1/4] Add decode_image op --- test/test_image.py | 20 ++++-- torchvision/csrc/cpu/image/image.cpp | 3 +- torchvision/csrc/cpu/image/image.h | 1 + torchvision/csrc/cpu/image/read_image_cpu.cpp | 29 ++++++++ torchvision/csrc/cpu/image/read_image_cpu.h | 6 ++ torchvision/csrc/cpu/image/readjpeg_cpu.cpp | 7 ++ torchvision/csrc/cpu/image/readpng_cpu.cpp | 7 ++ torchvision/io/image.py | 72 ++++++++++++------- 8 files changed, 113 insertions(+), 32 deletions(-) create mode 100644 torchvision/csrc/cpu/image/read_image_cpu.cpp create mode 100644 torchvision/csrc/cpu/image/read_image_cpu.h diff --git a/test/test_image.py b/test/test_image.py index a87f45f2d70..b22dc5f6201 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -8,7 +8,7 @@ import torchvision from PIL import Image from torchvision.io.image import ( - read_png, decode_png, read_jpeg, decode_jpeg, encode_jpeg, write_jpeg) + read_png, decode_png, read_jpeg, decode_jpeg, encode_jpeg, write_jpeg, decode_image, _read_file) import numpy as np IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets") @@ -44,10 +44,10 @@ def test_decode_jpeg(self): img_ljpeg = decode_jpeg(torch.from_file(img_path, dtype=torch.uint8, size=size)) self.assertTrue(img_ljpeg.equal(img_pil)) - with self.assertRaisesRegex(ValueError, "Expected a non empty 1-dimensional tensor."): + with self.assertRaisesRegex(RuntimeError, "Expected a non empty 1-dimensional tensor"): decode_jpeg(torch.empty((100, 1), dtype=torch.uint8)) - with self.assertRaisesRegex(ValueError, "Expected a torch.uint8 tensor."): + with self.assertRaisesRegex(RuntimeError, "Expected a torch.uint8 tensor"): decode_jpeg(torch.empty((100, ), dtype=torch.float16)) with self.assertRaises(RuntimeError): @@ -149,11 +149,23 @@ def test_decode_png(self): img_lpng = decode_png(torch.from_file(img_path, dtype=torch.uint8, size=size)) self.assertTrue(img_lpng.equal(img_pil)) - with self.assertRaises(ValueError): + with self.assertRaises(RuntimeError): decode_png(torch.empty((), dtype=torch.uint8)) with self.assertRaises(RuntimeError): decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8)) + def test_decode_image(self): + for img_path in get_images(IMAGE_ROOT, ".jpg"): + img_pil = torch.load(img_path.replace('jpg', 'pth')) + img_pil = img_pil.permute(2, 0, 1) + img_ljpeg = decode_image(_read_file(img_path)) + self.assertTrue(img_ljpeg.equal(img_pil)) + + for img_path in get_images(IMAGE_DIR, ".png"): + img_pil = torch.from_numpy(np.array(Image.open(img_path))) + img_pil = img_pil.permute(2, 0, 1) + img_lpng = decode_image(_read_file(img_path)) + self.assertTrue(img_lpng.equal(img_pil)) if __name__ == '__main__': unittest.main() diff --git a/torchvision/csrc/cpu/image/image.cpp b/torchvision/csrc/cpu/image/image.cpp index a10511c2953..9ce0b31ca44 100644 --- a/torchvision/csrc/cpu/image/image.cpp +++ b/torchvision/csrc/cpu/image/image.cpp @@ -16,4 +16,5 @@ static auto registry = torch::RegisterOperators() .op("image::decode_png", &decodePNG) .op("image::decode_jpeg", &decodeJPEG) .op("image::encode_jpeg", &encodeJPEG) - .op("image::write_jpeg", &writeJPEG); + .op("image::write_jpeg", &writeJPEG) + .op("image::decode_image", &decode_image); diff --git a/torchvision/csrc/cpu/image/image.h b/torchvision/csrc/cpu/image/image.h index 0e2c23cfc9e..0dae19bb72b 100644 --- a/torchvision/csrc/cpu/image/image.h +++ b/torchvision/csrc/cpu/image/image.h @@ -7,3 +7,4 @@ #include "readjpeg_cpu.h" #include "readpng_cpu.h" #include "writejpeg_cpu.h" +#include "read_image_cpu.h" diff --git a/torchvision/csrc/cpu/image/read_image_cpu.cpp b/torchvision/csrc/cpu/image/read_image_cpu.cpp new file mode 100644 index 00000000000..bd9d95f844b --- /dev/null +++ b/torchvision/csrc/cpu/image/read_image_cpu.cpp @@ -0,0 +1,29 @@ +#include "read_image_cpu.h" +#include + + +torch::Tensor decode_image(const torch::Tensor& data) { + // Check that the input tensor dtype is uint8 + TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); + // Check that the input tensor is 1-dimensional + TORCH_CHECK( + data.dim() == 1 && data.numel() > 0, + "Expected a non empty 1-dimensional tensor"); + + auto datap = data.data_ptr(); + + const uint8_t jpeg_signature[3] = {255, 216, 255}; // == "\xFF\xD8\xFF" + const uint8_t png_signature[4] = {137, 80, 78, 71}; // == "\211PNG" + + if (memcmp(jpeg_signature, datap, 3) == 0) { + return decodeJPEG(data); + } else if (memcmp(png_signature, datap, 4) == 0) { + return decodePNG(data); + } else { + TORCH_CHECK( + false, + "Unsupported image file. Only jpeg and png ", + "are currently supported."); + } + +} diff --git a/torchvision/csrc/cpu/image/read_image_cpu.h b/torchvision/csrc/cpu/image/read_image_cpu.h new file mode 100644 index 00000000000..ab181a8ed0f --- /dev/null +++ b/torchvision/csrc/cpu/image/read_image_cpu.h @@ -0,0 +1,6 @@ +#pragma once + +#include "readjpeg_cpu.h" +#include "readpng_cpu.h" + +torch::Tensor decode_image(const torch::Tensor& data); diff --git a/torchvision/csrc/cpu/image/readjpeg_cpu.cpp b/torchvision/csrc/cpu/image/readjpeg_cpu.cpp index 5059efbaeed..dd2354e4467 100644 --- a/torchvision/csrc/cpu/image/readjpeg_cpu.cpp +++ b/torchvision/csrc/cpu/image/readjpeg_cpu.cpp @@ -72,6 +72,13 @@ static void torch_jpeg_set_source_mgr( } torch::Tensor decodeJPEG(const torch::Tensor& data) { + // Check that the input tensor dtype is uint8 + TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); + // Check that the input tensor is 1-dimensional + TORCH_CHECK( + data.dim() == 1 && data.numel() > 0, + "Expected a non empty 1-dimensional tensor"); + struct jpeg_decompress_struct cinfo; struct torch_jpeg_error_mgr jerr; diff --git a/torchvision/csrc/cpu/image/readpng_cpu.cpp b/torchvision/csrc/cpu/image/readpng_cpu.cpp index 1438e3ee1c4..e91d6058d4b 100644 --- a/torchvision/csrc/cpu/image/readpng_cpu.cpp +++ b/torchvision/csrc/cpu/image/readpng_cpu.cpp @@ -13,6 +13,13 @@ torch::Tensor decodePNG(const torch::Tensor& data) { #include torch::Tensor decodePNG(const torch::Tensor& data) { + // Check that the input tensor dtype is uint8 + TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); + // Check that the input tensor is 1-dimensional + TORCH_CHECK( + data.dim() == 1 && data.numel() > 0, + "Expected a non empty 1-dimensional tensor"); + auto png_ptr = png_create_read_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr); TORCH_CHECK(png_ptr, "libpng read structure allocation failed!") diff --git a/torchvision/io/image.py b/torchvision/io/image.py index 203c8e4f4ef..43ca1f86e87 100644 --- a/torchvision/io/image.py +++ b/torchvision/io/image.py @@ -23,23 +23,29 @@ pass +def _read_file(path: str) -> torch.Tensor: + if not os.path.isfile(path): + raise ValueError("Expected a valid file path.") + + size = os.path.getsize(path) + if size == 0: + raise ValueError("Expected a non empty file.") + data = torch.from_file(path, dtype=torch.uint8, size=size) + return data + + def decode_png(input: torch.Tensor) -> torch.Tensor: """ Decodes a PNG image into a 3 dimensional RGB Tensor. The values of the output tensor are uint8 between 0 and 255. Arguments: - input (Tensor[1]): a one dimensional int8 tensor containing + input (Tensor[1]): a one dimensional uint8 tensor containing the raw bytes of the PNG image. Returns: output (Tensor[3, image_height, image_width]) """ - if not isinstance(input, torch.Tensor) or input.numel() == 0 or input.ndim != 1: # type: ignore[attr-defined] - raise ValueError("Expected a non empty 1-dimensional tensor.") - - if not input.dtype == torch.uint8: - raise ValueError("Expected a torch.uint8 tensor.") output = torch.ops.image.decode_png(input) return output @@ -55,13 +61,7 @@ def read_png(path: str) -> torch.Tensor: Returns: output (Tensor[3, image_height, image_width]) """ - if not os.path.isfile(path): - raise ValueError("Expected a valid file path.") - - size = os.path.getsize(path) - if size == 0: - raise ValueError("Expected a non empty file.") - data = torch.from_file(path, dtype=torch.uint8, size=size) + data = _read_file(path) return decode_png(data) @@ -70,17 +70,11 @@ def decode_jpeg(input: torch.Tensor) -> torch.Tensor: Decodes a JPEG image into a 3 dimensional RGB Tensor. The values of the output tensor are uint8 between 0 and 255. Arguments: - input (Tensor[1]): a one dimensional int8 tensor containing + input (Tensor[1]): a one dimensional uint8 tensor containing the raw bytes of the JPEG image. Returns: output (Tensor[3, image_height, image_width]) """ - if not isinstance(input, torch.Tensor) or len(input) == 0 or input.ndim != 1: # type: ignore[attr-defined] - raise ValueError("Expected a non empty 1-dimensional tensor.") - - if not input.dtype == torch.uint8: - raise ValueError("Expected a torch.uint8 tensor.") - output = torch.ops.image.decode_jpeg(input) return output @@ -94,13 +88,7 @@ def read_jpeg(path: str) -> torch.Tensor: Returns: output (Tensor[3, image_height, image_width]) """ - if not os.path.isfile(path): - raise ValueError("Expected a valid file path.") - - size = os.path.getsize(path) - if size == 0: - raise ValueError("Expected a non empty file.") - data = torch.from_file(path, dtype=torch.uint8, size=size) + data = _read_file(path) return decode_jpeg(data) @@ -141,3 +129,33 @@ def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75): 'between 1 and 100') torch.ops.image.write_jpeg(input, filename, quality) + + +def decode_image(input: torch.Tensor) -> torch.Tensor: + """ + Detects whether an image is a JPEG or PNG and performs the appropriate + operation to decode the image into a 3 dimensional RGB Tensor. + + The values of the output tensor are uint8 between 0 and 255. + + Arguments: + input (Tensor): a one dimensional uint8 tensor containing + the raw bytes of the PNG or JPEG image. + Returns: + output (Tensor[3, image_height, image_width]) + """ + output = torch.ops.image.decode_image(input) + return output + + +def read_image(path: str) -> torch.Tensor: + """ + Reads a JPEG or PNG image into a 3 dimensional RGB Tensor. + The values of the output tensor are uint8 between 0 and 255. + Arguments: + path (str): path of the JPEG or PNG image. + Returns: + output (Tensor[3, image_height, image_width]) + """ + data = _read_file(path) + return decode_image(data) From d0f2200e48e6e6892a6d4a1584fcd502778d42c5 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Mon, 28 Sep 2020 14:59:03 +0200 Subject: [PATCH 2/4] Fix lint --- test/test_image.py | 1 + torchvision/csrc/cpu/image/read_image_cpu.cpp | 13 ++++++------- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/test/test_image.py b/test/test_image.py index b22dc5f6201..618480f8765 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -167,5 +167,6 @@ def test_decode_image(self): img_lpng = decode_image(_read_file(img_path)) self.assertTrue(img_lpng.equal(img_pil)) + if __name__ == '__main__': unittest.main() diff --git a/torchvision/csrc/cpu/image/read_image_cpu.cpp b/torchvision/csrc/cpu/image/read_image_cpu.cpp index bd9d95f844b..a45ada8456b 100644 --- a/torchvision/csrc/cpu/image/read_image_cpu.cpp +++ b/torchvision/csrc/cpu/image/read_image_cpu.cpp @@ -12,18 +12,17 @@ torch::Tensor decode_image(const torch::Tensor& data) { auto datap = data.data_ptr(); - const uint8_t jpeg_signature[3] = {255, 216, 255}; // == "\xFF\xD8\xFF" - const uint8_t png_signature[4] = {137, 80, 78, 71}; // == "\211PNG" + const uint8_t jpeg_signature[3] = {255, 216, 255}; // == "\xFF\xD8\xFF" + const uint8_t png_signature[4] = {137, 80, 78, 71}; // == "\211PNG" if (memcmp(jpeg_signature, datap, 3) == 0) { return decodeJPEG(data); } else if (memcmp(png_signature, datap, 4) == 0) { return decodePNG(data); } else { - TORCH_CHECK( - false, - "Unsupported image file. Only jpeg and png ", - "are currently supported."); + TORCH_CHECK( + false, + "Unsupported image file. Only jpeg and png ", + "are currently supported."); } - } From f38fbce653d3fdd03a7ddd40475c4acf9255b529 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Mon, 28 Sep 2020 15:01:02 +0200 Subject: [PATCH 3/4] More lint --- torchvision/csrc/cpu/image/image.h | 3 +-- torchvision/csrc/cpu/image/read_image_cpu.cpp | 1 - 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/torchvision/csrc/cpu/image/image.h b/torchvision/csrc/cpu/image/image.h index 0dae19bb72b..077f4e16c18 100644 --- a/torchvision/csrc/cpu/image/image.h +++ b/torchvision/csrc/cpu/image/image.h @@ -1,10 +1,9 @@ - #pragma once // Comment #include #include +#include "read_image_cpu.h" #include "readjpeg_cpu.h" #include "readpng_cpu.h" #include "writejpeg_cpu.h" -#include "read_image_cpu.h" diff --git a/torchvision/csrc/cpu/image/read_image_cpu.cpp b/torchvision/csrc/cpu/image/read_image_cpu.cpp index a45ada8456b..0bd12d9c2e5 100644 --- a/torchvision/csrc/cpu/image/read_image_cpu.cpp +++ b/torchvision/csrc/cpu/image/read_image_cpu.cpp @@ -1,7 +1,6 @@ #include "read_image_cpu.h" #include - torch::Tensor decode_image(const torch::Tensor& data) { // Check that the input tensor dtype is uint8 TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); From 5f9acc4f8e93cc1954e08fd9284fddbc754b6c79 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Mon, 28 Sep 2020 16:42:25 +0200 Subject: [PATCH 4/4] Add C10_EXPORT --- torchvision/csrc/cpu/image/read_image_cpu.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/csrc/cpu/image/read_image_cpu.h b/torchvision/csrc/cpu/image/read_image_cpu.h index ab181a8ed0f..c8538cc88c6 100644 --- a/torchvision/csrc/cpu/image/read_image_cpu.h +++ b/torchvision/csrc/cpu/image/read_image_cpu.h @@ -3,4 +3,4 @@ #include "readjpeg_cpu.h" #include "readpng_cpu.h" -torch::Tensor decode_image(const torch::Tensor& data); +C10_EXPORT torch::Tensor decode_image(const torch::Tensor& data);