-
Notifications
You must be signed in to change notification settings - Fork 7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add encoding and writing JPEG ops (#2696)
* Add decode and write JPEG ops * Fix styling issues * Use int64_t instead of int * Use std::string * Use jpegcommon.h for read_jpeg * Minor updates to error handling in read * Include header only once * Reverse header inclusion * Update common header * Add common definitions * Include string * Include header? * Include header? * Add Python frontend calls * Use unsigned long directly * Fix style issues * Include cstddef * Ignore clang-format on cstddef * Also include stdio * Add JPEG and PNG include dirs * Use C10_EXPORT * Add JPEG encoding test * Set quality to 75 by default and add write jpeg test * Minor error correction * Use assertEquals by assertEqual * Remove test results * Use pre-saved PIL output * Remove extra PIL call * Use read_jpeg instead of PIL * Add error tests * Address review comments * Fix style issues * Set test case to uint8 * Update test error check * Apply suggestions from code review * Fix clang-format * Fix lint * Fix test * Remove unused file * Fix regex error message * Fix tests Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
- Loading branch information
Showing
13 changed files
with
294 additions
and
34 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,3 +6,4 @@ | |
#include <torch/torch.h> | ||
#include "readjpeg_cpu.h" | ||
#include "readpng_cpu.h" | ||
#include "writejpeg_cpu.h" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
#include "jpegcommon.h" | ||
#include <string> | ||
|
||
void torch_jpeg_error_exit(j_common_ptr cinfo) { | ||
/* cinfo->err really points to a torch_jpeg_error_mgr struct, so coerce | ||
* pointer */ | ||
torch_jpeg_error_ptr myerr = (torch_jpeg_error_ptr)cinfo->err; | ||
|
||
/* Always display the message. */ | ||
/* We could postpone this until after returning, if we chose. */ | ||
// (*cinfo->err->output_message)(cinfo); | ||
/* Create the message */ | ||
(*(cinfo->err->format_message))(cinfo, myerr->jpegLastErrorMsg); | ||
|
||
/* Return control to the setjmp point */ | ||
longjmp(myerr->setjmp_buffer, 1); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
#pragma once | ||
|
||
// clang-format off | ||
#include <cstdio> | ||
#include <cstddef> | ||
// clang-format on | ||
#include <jpeglib.h> | ||
#include <setjmp.h> | ||
#include <string> | ||
|
||
static const JOCTET EOI_BUFFER[1] = {JPEG_EOI}; | ||
struct torch_jpeg_error_mgr { | ||
struct jpeg_error_mgr pub; /* "public" fields */ | ||
char jpegLastErrorMsg[JMSG_LENGTH_MAX]; /* error messages */ | ||
jmp_buf setjmp_buffer; /* for return to caller */ | ||
}; | ||
|
||
typedef struct torch_jpeg_error_mgr* torch_jpeg_error_ptr; | ||
void torch_jpeg_error_exit(j_common_ptr cinfo); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
#include "writejpeg_cpu.h" | ||
|
||
#include <setjmp.h> | ||
#include <string> | ||
|
||
#if !JPEG_FOUND | ||
|
||
torch::Tensor encodeJPEG(const torch::Tensor& data, int64_t quality) { | ||
TORCH_CHECK( | ||
false, "encodeJPEG: torchvision not compiled with libjpeg support"); | ||
} | ||
|
||
void writeJPEG( | ||
const torch::Tensor& data, | ||
std::string filename, | ||
int64_t quality) { | ||
TORCH_CHECK( | ||
false, "writeJPEG: torchvision not compiled with libjpeg support"); | ||
} | ||
|
||
#else | ||
|
||
#include <jpeglib.h> | ||
#include "jpegcommon.h" | ||
|
||
torch::Tensor encodeJPEG(const torch::Tensor& data, int64_t quality) { | ||
// Define compression structures and error handling | ||
struct jpeg_compress_struct cinfo; | ||
struct torch_jpeg_error_mgr jerr; | ||
|
||
// Define buffer to write JPEG information to and its size | ||
unsigned long jpegSize = 0; | ||
uint8_t* jpegBuf = NULL; | ||
|
||
cinfo.err = jpeg_std_error(&jerr.pub); | ||
jerr.pub.error_exit = torch_jpeg_error_exit; | ||
|
||
/* Establish the setjmp return context for my_error_exit to use. */ | ||
if (setjmp(jerr.setjmp_buffer)) { | ||
/* If we get here, the JPEG code has signaled an error. | ||
* We need to clean up the JPEG object and the buffer. | ||
*/ | ||
jpeg_destroy_compress(&cinfo); | ||
if (jpegBuf != NULL) { | ||
free(jpegBuf); | ||
} | ||
|
||
TORCH_CHECK(false, (const char*)jerr.jpegLastErrorMsg); | ||
} | ||
|
||
// Check that the input tensor is on CPU | ||
TORCH_CHECK(data.device() == torch::kCPU, "Input tensor should be on CPU"); | ||
|
||
// Check that the input tensor dtype is uint8 | ||
TORCH_CHECK(data.dtype() == torch::kU8, "Input tensor dtype should be uint8"); | ||
|
||
// Check that the input tensor is 3-dimensional | ||
TORCH_CHECK(data.dim() == 3, "Input data should be a 3-dimensional tensor"); | ||
|
||
// Get image info | ||
int channels = data.size(0); | ||
int height = data.size(1); | ||
int width = data.size(2); | ||
auto input = data.permute({1, 2, 0}).contiguous(); | ||
|
||
TORCH_CHECK( | ||
channels == 1 || channels == 3, | ||
"The number of channels should be 1 or 3, got: ", | ||
channels); | ||
|
||
// Initialize JPEG structure | ||
jpeg_create_compress(&cinfo); | ||
|
||
// Set output image information | ||
cinfo.image_width = width; | ||
cinfo.image_height = height; | ||
cinfo.input_components = channels; | ||
cinfo.in_color_space = channels == 1 ? JCS_GRAYSCALE : JCS_RGB; | ||
|
||
jpeg_set_defaults(&cinfo); | ||
jpeg_set_quality(&cinfo, quality, TRUE); | ||
|
||
// Save JPEG output to a buffer | ||
jpeg_mem_dest(&cinfo, &jpegBuf, &jpegSize); | ||
|
||
// Start JPEG compression | ||
jpeg_start_compress(&cinfo, TRUE); | ||
|
||
auto stride = width * channels; | ||
auto ptr = input.data_ptr<uint8_t>(); | ||
|
||
// Encode JPEG file | ||
while (cinfo.next_scanline < cinfo.image_height) { | ||
jpeg_write_scanlines(&cinfo, &ptr, 1); | ||
ptr += stride; | ||
} | ||
|
||
jpeg_finish_compress(&cinfo); | ||
jpeg_destroy_compress(&cinfo); | ||
|
||
torch::TensorOptions options = torch::TensorOptions{torch::kU8}; | ||
auto outTensor = torch::empty({(long)jpegSize}, options); | ||
|
||
// Copy memory from jpeg buffer, since torch cannot get ownership of it via | ||
// `from_blob` | ||
auto outPtr = outTensor.data_ptr<uint8_t>(); | ||
std::memcpy(outPtr, jpegBuf, sizeof(uint8_t) * outTensor.numel()); | ||
|
||
free(jpegBuf); | ||
|
||
return outTensor; | ||
} | ||
|
||
void writeJPEG( | ||
const torch::Tensor& data, | ||
std::string filename, | ||
int64_t quality) { | ||
auto jpegBuf = encodeJPEG(data, quality); | ||
auto fileBytes = jpegBuf.data_ptr<uint8_t>(); | ||
auto fileCStr = filename.c_str(); | ||
FILE* outfile = fopen(fileCStr, "wb"); | ||
|
||
TORCH_CHECK(outfile != NULL, "Error opening output jpeg file"); | ||
|
||
fwrite(fileBytes, sizeof(uint8_t), jpegBuf.numel(), outfile); | ||
fclose(outfile); | ||
} | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
#pragma once | ||
|
||
#include <torch/torch.h> | ||
|
||
C10_EXPORT torch::Tensor encodeJPEG(const torch::Tensor& data, int64_t quality); | ||
C10_EXPORT void writeJPEG( | ||
const torch::Tensor& data, | ||
std::string filename, | ||
int64_t quality); |
Oops, something went wrong.