-
Notifications
You must be signed in to change notification settings - Fork 7k
/
image.py
230 lines (182 loc) · 6.61 KB
/
image.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
import torch
import os
import os.path as osp
import importlib.machinery
# Set to True only when the native ``image`` extension was located and
# registered with ``torch.ops`` by the code below.
_HAS_IMAGE_OPT = False

try:
    # The compiled extension is looked up in the parent directory of the
    # directory that contains this file.
    lib_dir = osp.abspath(osp.join(osp.dirname(__file__), ".."))

    loader_details = (
        importlib.machinery.ExtensionFileLoader,
        importlib.machinery.EXTENSION_SUFFIXES
    )

    extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)  # type: ignore[arg-type]
    ext_specs = extfinder.find_spec("image")

    if os.name == 'nt':
        # Load the image extension using LoadLibraryExW
        import ctypes
        import sys

        kernel32 = ctypes.WinDLL('kernel32.dll', use_last_error=True)
        with_load_library_flags = hasattr(kernel32, 'AddDllDirectory')
        # 0x0001 is SEM_FAILCRITICALERRORS (see the Win32 SetErrorMode docs).
        # The previous mode is remembered so it can be put back afterwards.
        prev_error_mode = kernel32.SetErrorMode(0x0001)

        try:
            kernel32.LoadLibraryW.restype = ctypes.c_void_p
            if with_load_library_flags:
                kernel32.LoadLibraryExW.restype = ctypes.c_void_p

            if ext_specs is not None:
                # 0x00001100 combines LOAD_LIBRARY_SEARCH_DEFAULT_DIRS (0x1000)
                # and LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR (0x0100).
                res = kernel32.LoadLibraryExW(ext_specs.origin, None, 0x00001100)
                if res is None:
                    err = ctypes.WinError(ctypes.get_last_error())
                    err.strerror += (f' Error loading "{ext_specs.origin}" or any of '
                                     'its dependencies.')
                    raise err
        finally:
            # Always restore the process error mode, including when the load
            # above failed and an exception is propagating.
            kernel32.SetErrorMode(prev_error_mode)

    if ext_specs is not None:
        torch.ops.load_library(ext_specs.origin)
        _HAS_IMAGE_OPT = True
except (ImportError, OSError):
    # Deliberate best-effort: if the extension cannot be found or loaded the
    # module still imports, and _HAS_IMAGE_OPT simply stays False.
    pass
def read_file(path: str) -> torch.Tensor:
    """
    Load the raw byte contents of a file into a one dimensional
    uint8 Tensor.

    The work is delegated to the native ``image.read_file`` operator.

    Arguments:
        path (str): location of the file to read

    Returns:
        data (Tensor)
    """
    return torch.ops.image.read_file(path)
def write_file(filename: str, data: torch.Tensor) -> None:
    """
    Dump a one dimensional uint8 tensor to a file.

    The work is delegated to the native ``image.write_file`` operator.

    Arguments:
        filename (str): location of the file to write
        data (Tensor): the bytes to store in the output file
    """
    native_write = torch.ops.image.write_file
    native_write(filename, data)
def decode_png(input: torch.Tensor) -> torch.Tensor:
    """
    Turn the raw bytes of a PNG image into a 3 dimensional RGB Tensor.
    Values in the returned tensor are uint8, ranging from 0 to 255.

    Arguments:
        input (Tensor[1]): one dimensional uint8 tensor holding the
            raw bytes of the PNG image.

    Returns:
        output (Tensor[3, image_height, image_width])
    """
    return torch.ops.image.decode_png(input)
def encode_png(input: torch.Tensor, compression_level: int = 6) -> torch.Tensor:
    """
    Takes an input tensor in CHW layout and returns a buffer with the contents
    of its corresponding PNG file.

    Parameters
    ----------
    input: Tensor[channels, image_height, image_width]
        int8 image tensor of `c` channels, where `c` must be 3 or 1.
        (NOTE(review): the decode_* docstrings in this file say uint8;
        confirm which dtype the native op requires.)
    compression_level: int
        Compression factor for the resulting file, it must be a number
        between 0 and 9. Default: 6

    Returns
    -------
    output: Tensor[1]
        A one dimensional int8 tensor that contains the raw bytes of the
        PNG file.

    Raises
    ------
    ValueError
        If `compression_level` is outside the documented 0-9 range.
    """
    # Reject out-of-range levels up front, mirroring the quality check that
    # encode_jpeg performs before calling into the native operator.
    if compression_level < 0 or compression_level > 9:
        raise ValueError('Compression level should be between 0 and 9')

    output = torch.ops.image.encode_png(input, compression_level)
    return output
def write_png(input: torch.Tensor, filename: str, compression_level: int = 6):
    """
    Encode a tensor in CHW layout (or HW for grayscale images) as PNG and
    store the result in a file.

    Parameters
    ----------
    input: Tensor[channels, image_height, image_width]
        int8 image tensor of `c` channels, where `c` must be 1 or 3.
        (NOTE(review): the decode_* docstrings in this file say uint8;
        confirm which dtype the native op requires.)
    filename: str
        Destination path of the image.
    compression_level: int
        Compression factor for the resulting file, it must be a number
        between 0 and 9. Default: 6
    """
    # Encode first, then hand the resulting byte buffer to the file writer.
    write_file(filename, encode_png(input, compression_level))
def decode_jpeg(input: torch.Tensor) -> torch.Tensor:
    """
    Turn the raw bytes of a JPEG image into a 3 dimensional RGB Tensor.
    Values in the returned tensor are uint8, ranging from 0 to 255.

    Arguments:
        input (Tensor[1]): one dimensional uint8 tensor holding the
            raw bytes of the JPEG image.

    Returns:
        output (Tensor[3, image_height, image_width])
    """
    return torch.ops.image.decode_jpeg(input)
def encode_jpeg(input: torch.Tensor, quality: int = 75) -> torch.Tensor:
    """
    Encode a tensor in CHW layout as JPEG and return the bytes of the
    resulting file as a buffer.

    Parameters
    ----------
    input: Tensor[channels, image_height, image_width]
        int8 image tensor of `c` channels, where `c` must be 1 or 3.
        (NOTE(review): the decode_* docstrings in this file say uint8;
        confirm which dtype the native op requires.)
    quality: int
        Quality of the resulting JPEG file, it must be a number between
        1 and 100. Default: 75

    Returns
    -------
    output: Tensor[1]
        A one dimensional int8 tensor that contains the raw bytes of the
        JPEG file.

    Raises
    ------
    ValueError
        If `quality` is below 1 or above 100.
    """
    # Validate before touching the native operator.
    if quality > 100 or quality < 1:
        raise ValueError(
            'Image quality should be a positive number between 1 and 100'
        )

    return torch.ops.image.encode_jpeg(input, quality)
def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75):
    """
    Encode a tensor in CHW layout as JPEG and store the result in a file.

    Parameters
    ----------
    input: Tensor[channels, image_height, image_width]
        int8 image tensor of `c` channels, where `c` must be 1 or 3.
        (NOTE(review): the decode_* docstrings in this file say uint8;
        confirm which dtype the native op requires.)
    filename: str
        Destination path of the image.
    quality: int
        Quality of the resulting JPEG file, it must be a number
        between 1 and 100. Default: 75
    """
    # Encode first, then hand the resulting byte buffer to the file writer.
    write_file(filename, encode_jpeg(input, quality))
def decode_image(input: torch.Tensor) -> torch.Tensor:
    """
    Decode a JPEG or PNG image into a 3 dimensional RGB Tensor, detecting
    which of the two formats the data is in and applying the matching
    decoder. Values in the returned tensor are uint8, ranging from 0 to 255.

    Parameters
    ----------
    input: Tensor
        a one dimensional uint8 tensor containing the raw bytes of the
        PNG or JPEG image.

    Returns
    -------
    output: Tensor[3, image_height, image_width]
    """
    return torch.ops.image.decode_image(input)
def read_image(path: str) -> torch.Tensor:
    """
    Load a JPEG or PNG image from disk as a 3 dimensional RGB Tensor.
    Values in the returned tensor are uint8, ranging from 0 to 255.

    Parameters
    ----------
    path: str
        location of the JPEG or PNG image.

    Returns
    -------
    output: Tensor[3, image_height, image_width]
    """
    # Read the raw bytes, then let decode_image pick the right decoder.
    return decode_image(read_file(path))