From d16f784e7709d59e8fb2e87c368a864c57aefa05 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Tue, 6 Oct 2020 18:22:24 +0200 Subject: [PATCH 1/2] Add write_file --- test/test_image.py | 14 +++++++++++- torchvision/csrc/cpu/image/image.cpp | 1 + .../csrc/cpu/image/read_write_file_cpu.cpp | 22 +++++++++++++++++++ .../csrc/cpu/image/read_write_file_cpu.h | 2 ++ torchvision/io/image.py | 12 ++++++++++ 5 files changed, 50 insertions(+), 1 deletion(-) diff --git a/test/test_image.py b/test/test_image.py index 7a0317cae83..ae23b4db5f9 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -9,7 +9,7 @@ from PIL import Image from torchvision.io.image import ( read_png, decode_png, read_jpeg, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file, - encode_png, write_png) + encode_png, write_png, write_file) import numpy as np from common_utils import get_tmp_dir @@ -238,6 +238,18 @@ def test_read_file(self): RuntimeError, "No such file or directory: 'tst'"): read_file('tst') + def test_write_file(self): + with get_tmp_dir() as d: + fname, content = 'test1.bin', b'TorchVision\211\n' + fpath = os.path.join(d, fname) + content_tensor = torch.tensor(list(content), dtype=torch.uint8) + write_file(fpath, content_tensor) + + with open(fpath, 'rb') as f: + saved_content = f.read() + self.assertEqual(content, saved_content) + os.unlink(fpath) + if __name__ == '__main__': unittest.main() diff --git a/torchvision/csrc/cpu/image/image.cpp b/torchvision/csrc/cpu/image/image.cpp index 6b8022f3014..f0e4ce0e177 100644 --- a/torchvision/csrc/cpu/image/image.cpp +++ b/torchvision/csrc/cpu/image/image.cpp @@ -20,4 +20,5 @@ static auto registry = torch::RegisterOperators() .op("image::encode_jpeg", &encodeJPEG) .op("image::write_jpeg", &writeJPEG) .op("image::read_file", &read_file) + .op("image::write_file", &write_file) .op("image::decode_image", &decode_image); diff --git a/torchvision/csrc/cpu/image/read_write_file_cpu.cpp b/torchvision/csrc/cpu/image/read_write_file_cpu.cpp index e4213ca851e..861fb10d1a9 100644 --- a/torchvision/csrc/cpu/image/read_write_file_cpu.cpp +++ b/torchvision/csrc/cpu/image/read_write_file_cpu.cpp @@ -18,3 +18,25 @@ torch::Tensor read_file(std::string filename) { return data; } + +void write_file( + std::string filename, + torch::Tensor& data) { + // Check that the input tensor is on CPU + TORCH_CHECK(data.device() == torch::kCPU, "Input tensor should be on CPU"); + + // Check that the input tensor dtype is uint8 + TORCH_CHECK(data.dtype() == torch::kU8, "Input tensor dtype should be uint8"); + + // Check that the input tensor is 3-dimensional + TORCH_CHECK(data.dim() == 1, "Input data should be a 1-dimensional tensor"); + + auto fileBytes = data.data_ptr(); + auto fileCStr = filename.c_str(); + FILE* outfile = fopen(fileCStr, "wb"); + + TORCH_CHECK(outfile != NULL, "Error opening output file"); + + fwrite(fileBytes, sizeof(uint8_t), data.numel(), outfile); + fclose(outfile); +} diff --git a/torchvision/csrc/cpu/image/read_write_file_cpu.h b/torchvision/csrc/cpu/image/read_write_file_cpu.h index 42b5e7cc36b..a6dbfb700a7 100644 --- a/torchvision/csrc/cpu/image/read_write_file_cpu.h +++ b/torchvision/csrc/cpu/image/read_write_file_cpu.h @@ -5,3 +5,5 @@ #include C10_EXPORT torch::Tensor read_file(std::string filename); + +C10_EXPORT void write_file(std::string filename, torch::Tensor& data); diff --git a/torchvision/io/image.py b/torchvision/io/image.py index 398d682689e..30d67aa27a5 100644 --- a/torchvision/io/image.py +++ b/torchvision/io/image.py @@ -38,6 +38,18 @@ def read_file(path: str) -> torch.Tensor: return data +def write_file(filename: str, data: torch.Tensor) -> None: + """ + Writes the contents of a uint8 tensor with one dimension to a + file. + + Arguments: + filename (str): the path to the file to be written + data (Tensor): the contents to be written to the output file + """ + torch.ops.image.write_file(filename, data) + + def decode_png(input: torch.Tensor) -> torch.Tensor: """ Decodes a PNG image into a 3 dimensional RGB Tensor. From e6a83756a14ae12bf32058c724d7996631b4dab7 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Tue, 6 Oct 2020 19:37:03 +0200 Subject: [PATCH 2/2] Fix lint --- torchvision/csrc/cpu/image/read_write_file_cpu.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torchvision/csrc/cpu/image/read_write_file_cpu.cpp b/torchvision/csrc/cpu/image/read_write_file_cpu.cpp index 861fb10d1a9..a1164938b70 100644 --- a/torchvision/csrc/cpu/image/read_write_file_cpu.cpp +++ b/torchvision/csrc/cpu/image/read_write_file_cpu.cpp @@ -19,9 +19,7 @@ torch::Tensor read_file(std::string filename) { return data; } -void write_file( - std::string filename, - torch::Tensor& data) { +void write_file(std::string filename, torch::Tensor& data) { // Check that the input tensor is on CPU TORCH_CHECK(data.device() == torch::kCPU, "Input tensor should be on CPU");