Skip to content

Commit

Permalink
Add encoding and writing JPEG ops (#2696)
Browse files Browse the repository at this point in the history
* Add decode and write JPEG ops

* Fix styling issues

* Use int64_t instead of int

* Use std::string

* Use jpegcommon.h for read_jpeg

* Minor updates to error handling in read

* Include header only once

* Reverse header inclusion

* Update common header

* Add common definitions

* Include string

* Include header?

* Include header?

* Add Python frontend calls

* Use unsigned long directly

* Fix style issues

* Include cstddef

* Ignore clang-format on cstddef

* Also include stdio

* Add JPEG and PNG include dirs

* Use C10_EXPORT

* Add JPEG encoding test

* Set quality to 75 by default and add write jpeg test

* Minor error correction

* Use assertEquals by assertEqual

* Remove test results

* Use pre-saved PIL output

* Remove extra PIL call

* Use read_jpeg instead of PIL

* Add error tests

* Address review comments

* Fix style issues

* Set test case to uint8

* Update test error check

* Apply suggestions from code review

* Fix clang-format

* Fix lint

* Fix test

* Remove unused file

* Fix regex error message

* Fix tests

Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
  • Loading branch information
andfoy and fmassa authored Sep 25, 2020
1 parent 6e10e3f commit 662373f
Show file tree
Hide file tree
Showing 13 changed files with 294 additions and 34 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ add_library(${PROJECT_NAME} SHARED ${MODELS_SOURCES} ${OPERATOR_SOURCES} ${IMAGE
target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES} ${PNG_LIBRARY} ${JPEG_LIBRARIES} Python3::Python)
set_target_properties(${PROJECT_NAME} PROPERTIES EXPORT_NAME TorchVision)

include_directories(torchvision/csrc)
include_directories(torchvision/csrc ${JPEG_INCLUDE_DIRS} ${PNG_INCLUDE_DIRS})
include(GNUInstallDirs)
include(CMakePackageConfigHelpers)

Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
71 changes: 69 additions & 2 deletions test/test_image.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import os
import io
import glob
import unittest
import sys

import torch
import torchvision
from PIL import Image
from torchvision.io.image import read_png, decode_png, read_jpeg, decode_jpeg
from torchvision.io.image import (
read_png, decode_png, read_jpeg, decode_jpeg, encode_jpeg, write_jpeg)
import numpy as np

IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
Expand All @@ -17,7 +19,7 @@
def get_images(directory, img_ext):
assert os.path.isdir(directory)
for root, _, files in os.walk(directory):
if os.path.basename(root) == 'damaged_jpeg':
if os.path.basename(root) in {'damaged_jpeg', 'jpeg_write'}:
continue

for fl in files:
Expand Down Expand Up @@ -66,6 +68,71 @@ def test_damaged_images(self):
with self.assertRaises(RuntimeError):
read_jpeg(image_path)

def test_encode_jpeg(self):
for img_path in get_images(IMAGE_ROOT, ".jpg"):
dirname = os.path.dirname(img_path)
filename, _ = os.path.splitext(os.path.basename(img_path))
write_folder = os.path.join(dirname, 'jpeg_write')
expected_file = os.path.join(
write_folder, '{0}_pil.jpg'.format(filename))
img = read_jpeg(img_path)

with open(expected_file, 'rb') as f:
pil_bytes = f.read()
pil_bytes = torch.as_tensor(list(pil_bytes), dtype=torch.uint8)
for src_img in [img, img.contiguous()]:
# PIL sets jpeg quality to 75 by default
jpeg_bytes = encode_jpeg(src_img, quality=75)
self.assertTrue(jpeg_bytes.equal(pil_bytes))

with self.assertRaisesRegex(
RuntimeError, "Input tensor dtype should be uint8"):
encode_jpeg(torch.empty((3, 100, 100), dtype=torch.float32))

with self.assertRaisesRegex(
ValueError, "Image quality should be a positive number "
"between 1 and 100"):
encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=-1)

with self.assertRaisesRegex(
ValueError, "Image quality should be a positive number "
"between 1 and 100"):
encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=101)

with self.assertRaisesRegex(
RuntimeError, "The number of channels should be 1 or 3, got: 5"):
encode_jpeg(torch.empty((5, 100, 100), dtype=torch.uint8))

with self.assertRaisesRegex(
RuntimeError, "Input data should be a 3-dimensional tensor"):
encode_jpeg(torch.empty((1, 3, 100, 100), dtype=torch.uint8))

with self.assertRaisesRegex(
RuntimeError, "Input data should be a 3-dimensional tensor"):
encode_jpeg(torch.empty((100, 100), dtype=torch.uint8))

def test_write_jpeg(self):
for img_path in get_images(IMAGE_ROOT, ".jpg"):
img = read_jpeg(img_path)

basedir = os.path.dirname(img_path)
filename, _ = os.path.splitext(os.path.basename(img_path))
torch_jpeg = os.path.join(
basedir, '{0}_torch.jpg'.format(filename))
pil_jpeg = os.path.join(
basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename))

write_jpeg(img, torch_jpeg, quality=75)

with open(torch_jpeg, 'rb') as f:
torch_bytes = f.read()

with open(pil_jpeg, 'rb') as f:
pil_bytes = f.read()

os.remove(torch_jpeg)
self.assertEqual(torch_bytes, pil_bytes)

def test_read_png(self):
# Check across .png
for img_path in get_images(IMAGE_DIR, ".png"):
Expand Down
4 changes: 3 additions & 1 deletion torchvision/csrc/cpu/image/image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,6 @@ PyMODINIT_FUNC PyInit_image(void) {

static auto registry = torch::RegisterOperators()
.op("image::decode_png", &decodePNG)
.op("image::decode_jpeg", &decodeJPEG);
.op("image::decode_jpeg", &decodeJPEG)
.op("image::encode_jpeg", &encodeJPEG)
.op("image::write_jpeg", &writeJPEG);
1 change: 1 addition & 0 deletions torchvision/csrc/cpu/image/image.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
#include <torch/torch.h>
#include "readjpeg_cpu.h"
#include "readpng_cpu.h"
#include "writejpeg_cpu.h"
17 changes: 17 additions & 0 deletions torchvision/csrc/cpu/image/jpegcommon.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#include "jpegcommon.h"
#include <string>

void torch_jpeg_error_exit(j_common_ptr cinfo) {
/* cinfo->err really points to a torch_jpeg_error_mgr struct, so coerce
* pointer */
torch_jpeg_error_ptr myerr = (torch_jpeg_error_ptr)cinfo->err;

/* Always display the message. */
/* We could postpone this until after returning, if we chose. */
// (*cinfo->err->output_message)(cinfo);
/* Create the message */
(*(cinfo->err->format_message))(cinfo, myerr->jpegLastErrorMsg);

/* Return control to the setjmp point */
longjmp(myerr->setjmp_buffer, 1);
}
19 changes: 19 additions & 0 deletions torchvision/csrc/cpu/image/jpegcommon.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#pragma once

// clang-format off
#include <cstdio>
#include <cstddef>
// clang-format on
#include <jpeglib.h>
#include <setjmp.h>
#include <string>

static const JOCTET EOI_BUFFER[1] = {JPEG_EOI};
struct torch_jpeg_error_mgr {
struct jpeg_error_mgr pub; /* "public" fields */
char jpegLastErrorMsg[JMSG_LENGTH_MAX]; /* error messages */
jmp_buf setjmp_buffer; /* for return to caller */
};

typedef struct torch_jpeg_error_mgr* torch_jpeg_error_ptr;
void torch_jpeg_error_exit(j_common_ptr cinfo);
33 changes: 5 additions & 28 deletions torchvision/csrc/cpu/image/readjpeg_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,36 +7,13 @@
#if !JPEG_FOUND

torch::Tensor decodeJPEG(const torch::Tensor& data) {
AT_ERROR("decodeJPEG: torchvision not compiled with libjpeg support");
TORCH_CHECK(
false, "decodeJPEG: torchvision not compiled with libjpeg support");
}

#else
#include <jpeglib.h>

const static JOCTET EOI_BUFFER[1] = {JPEG_EOI};
char jpegLastErrorMsg[JMSG_LENGTH_MAX];

struct torch_jpeg_error_mgr {
struct jpeg_error_mgr pub; /* "public" fields */
jmp_buf setjmp_buffer; /* for return to caller */
};

typedef struct torch_jpeg_error_mgr* torch_jpeg_error_ptr;

void torch_jpeg_error_exit(j_common_ptr cinfo) {
/* cinfo->err really points to a torch_jpeg_error_mgr struct, so coerce
* pointer */
torch_jpeg_error_ptr myerr = (torch_jpeg_error_ptr)cinfo->err;

/* Always display the message. */
/* We could postpone this until after returning, if we chose. */
// (*cinfo->err->output_message)(cinfo);
/* Create the message */
(*(cinfo->err->format_message))(cinfo, jpegLastErrorMsg);

/* Return control to the setjmp point */
longjmp(myerr->setjmp_buffer, 1);
}
#include "jpegcommon.h"

struct torch_jpeg_mgr {
struct jpeg_source_mgr pub;
Expand All @@ -50,7 +27,7 @@ static boolean torch_jpeg_fill_input_buffer(j_decompress_ptr cinfo) {
torch_jpeg_mgr* src = (torch_jpeg_mgr*)cinfo->src;
// No more data. Probably an incomplete image; Raise exception.
torch_jpeg_error_ptr myerr = (torch_jpeg_error_ptr)cinfo->err;
strcpy(jpegLastErrorMsg, "Image is incomplete or truncated");
strcpy(myerr->jpegLastErrorMsg, "Image is incomplete or truncated");
longjmp(myerr->setjmp_buffer, 1);
src->pub.next_input_byte = EOI_BUFFER;
src->pub.bytes_in_buffer = 1;
Expand Down Expand Up @@ -108,7 +85,7 @@ torch::Tensor decodeJPEG(const torch::Tensor& data) {
* We need to clean up the JPEG object.
*/
jpeg_destroy_decompress(&cinfo);
AT_ERROR(jpegLastErrorMsg);
TORCH_CHECK(false, jerr.jpegLastErrorMsg);
}

jpeg_create_decompress(&cinfo);
Expand Down
2 changes: 1 addition & 1 deletion torchvision/csrc/cpu/image/readjpeg_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

#include <torch/torch.h>

torch::Tensor decodeJPEG(const torch::Tensor& data);
C10_EXPORT torch::Tensor decodeJPEG(const torch::Tensor& data);
2 changes: 1 addition & 1 deletion torchvision/csrc/cpu/image/readpng_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
#include <torch/torch.h>
#include <string>

torch::Tensor decodePNG(const torch::Tensor& data);
C10_EXPORT torch::Tensor decodePNG(const torch::Tensor& data);
129 changes: 129 additions & 0 deletions torchvision/csrc/cpu/image/writejpeg_cpu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
#include "writejpeg_cpu.h"

#include <setjmp.h>
#include <string>

#if !JPEG_FOUND

torch::Tensor encodeJPEG(const torch::Tensor& data, int64_t quality) {
TORCH_CHECK(
false, "encodeJPEG: torchvision not compiled with libjpeg support");
}

void writeJPEG(
const torch::Tensor& data,
std::string filename,
int64_t quality) {
TORCH_CHECK(
false, "writeJPEG: torchvision not compiled with libjpeg support");
}

#else

#include <jpeglib.h>
#include "jpegcommon.h"

torch::Tensor encodeJPEG(const torch::Tensor& data, int64_t quality) {
// Define compression structures and error handling
struct jpeg_compress_struct cinfo;
struct torch_jpeg_error_mgr jerr;

// Define buffer to write JPEG information to and its size
unsigned long jpegSize = 0;
uint8_t* jpegBuf = NULL;

cinfo.err = jpeg_std_error(&jerr.pub);
jerr.pub.error_exit = torch_jpeg_error_exit;

/* Establish the setjmp return context for my_error_exit to use. */
if (setjmp(jerr.setjmp_buffer)) {
/* If we get here, the JPEG code has signaled an error.
* We need to clean up the JPEG object and the buffer.
*/
jpeg_destroy_compress(&cinfo);
if (jpegBuf != NULL) {
free(jpegBuf);
}

TORCH_CHECK(false, (const char*)jerr.jpegLastErrorMsg);
}

// Check that the input tensor is on CPU
TORCH_CHECK(data.device() == torch::kCPU, "Input tensor should be on CPU");

// Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Input tensor dtype should be uint8");

// Check that the input tensor is 3-dimensional
TORCH_CHECK(data.dim() == 3, "Input data should be a 3-dimensional tensor");

// Get image info
int channels = data.size(0);
int height = data.size(1);
int width = data.size(2);
auto input = data.permute({1, 2, 0}).contiguous();

TORCH_CHECK(
channels == 1 || channels == 3,
"The number of channels should be 1 or 3, got: ",
channels);

// Initialize JPEG structure
jpeg_create_compress(&cinfo);

// Set output image information
cinfo.image_width = width;
cinfo.image_height = height;
cinfo.input_components = channels;
cinfo.in_color_space = channels == 1 ? JCS_GRAYSCALE : JCS_RGB;

jpeg_set_defaults(&cinfo);
jpeg_set_quality(&cinfo, quality, TRUE);

// Save JPEG output to a buffer
jpeg_mem_dest(&cinfo, &jpegBuf, &jpegSize);

// Start JPEG compression
jpeg_start_compress(&cinfo, TRUE);

auto stride = width * channels;
auto ptr = input.data_ptr<uint8_t>();

// Encode JPEG file
while (cinfo.next_scanline < cinfo.image_height) {
jpeg_write_scanlines(&cinfo, &ptr, 1);
ptr += stride;
}

jpeg_finish_compress(&cinfo);
jpeg_destroy_compress(&cinfo);

torch::TensorOptions options = torch::TensorOptions{torch::kU8};
auto outTensor = torch::empty({(long)jpegSize}, options);

// Copy memory from jpeg buffer, since torch cannot get ownership of it via
// `from_blob`
auto outPtr = outTensor.data_ptr<uint8_t>();
std::memcpy(outPtr, jpegBuf, sizeof(uint8_t) * outTensor.numel());

free(jpegBuf);

return outTensor;
}

void writeJPEG(
const torch::Tensor& data,
std::string filename,
int64_t quality) {
auto jpegBuf = encodeJPEG(data, quality);
auto fileBytes = jpegBuf.data_ptr<uint8_t>();
auto fileCStr = filename.c_str();
FILE* outfile = fopen(fileCStr, "wb");

TORCH_CHECK(outfile != NULL, "Error opening output jpeg file");

fwrite(fileBytes, sizeof(uint8_t), jpegBuf.numel(), outfile);
fclose(outfile);
}

#endif
9 changes: 9 additions & 0 deletions torchvision/csrc/cpu/image/writejpeg_cpu.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#pragma once

#include <torch/torch.h>

C10_EXPORT torch::Tensor encodeJPEG(const torch::Tensor& data, int64_t quality);
C10_EXPORT void writeJPEG(
const torch::Tensor& data,
std::string filename,
int64_t quality);
Loading

0 comments on commit 662373f

Please sign in to comment.