Skip to content

Commit

Permalink
Add write_file (pytorch#2765)
Browse files Browse the repository at this point in the history
* Add write_file

* Fix lint
  • Loading branch information
fmassa authored and vfdev-5 committed Dec 4, 2020
1 parent 5a0fc91 commit 3839527
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 1 deletion.
14 changes: 13 additions & 1 deletion test/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from PIL import Image
from torchvision.io.image import (
decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file,
encode_png, write_png)
encode_png, write_png, write_file)
import numpy as np

from common_utils import get_tmp_dir
Expand Down Expand Up @@ -225,6 +225,18 @@ def test_read_file(self):
RuntimeError, "No such file or directory: 'tst'"):
read_file('tst')

def test_write_file(self):
with get_tmp_dir() as d:
fname, content = 'test1.bin', b'TorchVision\211\n'
fpath = os.path.join(d, fname)
content_tensor = torch.tensor(list(content), dtype=torch.uint8)
write_file(fpath, content_tensor)

with open(fpath, 'rb') as f:
saved_content = f.read()
self.assertEqual(content, saved_content)
os.unlink(fpath)


if __name__ == '__main__':
unittest.main()
1 change: 1 addition & 0 deletions torchvision/csrc/cpu/image/image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,5 @@ static auto registry = torch::RegisterOperators()
.op("image::encode_jpeg", &encodeJPEG)
.op("image::write_jpeg", &writeJPEG)
.op("image::read_file", &read_file)
.op("image::write_file", &write_file)
.op("image::decode_image", &decode_image);
20 changes: 20 additions & 0 deletions torchvision/csrc/cpu/image/read_write_file_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,23 @@ torch::Tensor read_file(std::string filename) {

return data;
}

void write_file(std::string filename, torch::Tensor& data) {
// Check that the input tensor is on CPU
TORCH_CHECK(data.device() == torch::kCPU, "Input tensor should be on CPU");

// Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Input tensor dtype should be uint8");

// Check that the input tensor is 3-dimensional
TORCH_CHECK(data.dim() == 1, "Input data should be a 1-dimensional tensor");

auto fileBytes = data.data_ptr<uint8_t>();
auto fileCStr = filename.c_str();
FILE* outfile = fopen(fileCStr, "wb");

TORCH_CHECK(outfile != NULL, "Error opening output file");

fwrite(fileBytes, sizeof(uint8_t), data.numel(), outfile);
fclose(outfile);
}
2 changes: 2 additions & 0 deletions torchvision/csrc/cpu/image/read_write_file_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@
#include <torch/torch.h>

C10_EXPORT torch::Tensor read_file(std::string filename);

C10_EXPORT void write_file(std::string filename, torch::Tensor& data);
12 changes: 12 additions & 0 deletions torchvision/io/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,18 @@ def read_file(path: str) -> torch.Tensor:
return data


def write_file(filename: str, data: torch.Tensor) -> None:
"""
Writes the contents of a uint8 tensor with one dimension to a
file.
Arguments:
filename (str): the path to the file to be written
data (Tensor): the contents to be written to the output file
"""
torch.ops.image.write_file(filename, data)


def decode_png(input: torch.Tensor) -> torch.Tensor:
"""
Decodes a PNG image into a 3 dimensional RGB Tensor.
Expand Down

0 comments on commit 3839527

Please sign in to comment.