Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use from_blob to avoid memcpy #4118

Merged
merged 3 commits into from
Jun 28, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 8 additions & 14 deletions torchvision/csrc/io/image/cpu/encode_jpeg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ using namespace detail;

torch::Tensor encode_jpeg(const torch::Tensor& data, int64_t quality) {
// Define compression structures and error handling
struct jpeg_compress_struct cinfo;
struct torch_jpeg_error_mgr jerr;
struct jpeg_compress_struct cinfo {};
struct torch_jpeg_error_mgr jerr {};

// Define buffer to write JPEG information to and its size
unsigned long jpegSize = 0;
uint8_t* jpegBuf = NULL;
uint8_t* jpegBuf = nullptr;

cinfo.err = jpeg_std_error(&jerr.pub);
jerr.pub.error_exit = torch_jpeg_error_exit;
Expand All @@ -34,7 +34,7 @@ torch::Tensor encode_jpeg(const torch::Tensor& data, int64_t quality) {
* We need to clean up the JPEG object and the buffer.
*/
jpeg_destroy_compress(&cinfo);
if (jpegBuf != NULL) {
if (jpegBuf != nullptr) {
free(jpegBuf);
}

Expand Down Expand Up @@ -92,16 +92,10 @@ torch::Tensor encode_jpeg(const torch::Tensor& data, int64_t quality) {
jpeg_destroy_compress(&cinfo);

torch::TensorOptions options = torch::TensorOptions{torch::kU8};
auto outTensor = torch::empty({(long)jpegSize}, options);

// Copy memory from jpeg buffer, since torch cannot get ownership of it via
// `from_blob`
auto outPtr = outTensor.data_ptr<uint8_t>();
std::memcpy(outPtr, jpegBuf, sizeof(uint8_t) * outTensor.numel());

free(jpegBuf);

return outTensor;
auto out_tensor =
torch::from_blob(jpegBuf, {(long)jpegSize}, ::free, options);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line is raising the following clang-tidy warning in our internal CI:

There's an unchecked dereference of an object that is nullable. Make sure you null-check the object before you call methods on it, pass it as non-nullable parameter, or dereference it. If you are sure the the object cannot be nullptr, make it explicit with CHECK_NOTNULL() or assert(). jpegBuf is nullable.

Looks like jpegBuf is set above in jpeg_mem_dest, but I couldn't find any relevant docs about this function to know what happens if the allocation fails.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think if we have errors the setjmp that we do in https://github.com/pytorch/vision/pull/4118/files#diff-dfb505641c632421ae5cc17a4398fe0e3a47c436f62825e88ba2b0aebf497d5eR37 will kick in and the check that jpegBuf != nullptr, so maybe this clang-tidy error is a false-positive.

Still let's keep an eye on this

jpegBuf = nullptr;
return out_tensor;
}
#endif

Expand Down