-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtransforms.py
349 lines (267 loc) · 10.5 KB
/
transforms.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
import numpy as np
import skimage.transform
import scipy
import scipy.signal
import cv2 as cv
import torch
import torch.nn.functional as F
# ==================================================
# Padding, cropping
# ==================================================
def pad_base(x, padding, **kwargs):
    """Pad a numpy array or torch tensor from a numpy-style padding spec.

    Args:
        x: ``np.ndarray`` or ``torch.Tensor`` to pad.
        padding: array-like of shape ``(x.ndim, 2)`` holding ``(before, after)``
            pad widths per axis, in axis order (the ``np.pad`` convention).
            Values are truncated to ``int``, so float arrays such as
            ``np.zeros(...)`` are accepted.
        **kwargs: forwarded unchanged to ``np.pad`` (ndarray input) or
            ``torch.nn.functional.pad`` (tensor input). The two backends use
            different keyword names, e.g. ``constant_values`` vs ``value``.

    Returns:
        The padded array/tensor, of the same container type as ``x``.

    Raises:
        NotImplementedError: if ``x`` is neither an ndarray nor a tensor.
    """
    if not isinstance(x, (np.ndarray, torch.Tensor)):
        raise NotImplementedError(
            f"pad_base does not support inputs of type {type(x).__name__}")
    # np.asarray lets callers pass nested lists/tuples as well as ndarrays
    padding = np.asarray(padding).astype(int)
    if isinstance(x, np.ndarray):
        return np.pad(x, padding.tolist(), **kwargs)
    # F.pad expects one flat sequence of (before, after) pairs starting from
    # the LAST axis, so reverse the axis order before flattening; .tolist()
    # hands torch plain Python ints rather than numpy integer scalars.
    torch_padding = tuple(np.flip(padding, axis=0).flatten().tolist())
    return F.pad(x, torch_padding, **kwargs)
class PadTo2Power(object):
    """
    Pad selected axes so that each of their sizes is a multiple of 2**k.

    Useful for autoencoder / U-Net style networks, where k down-sampling
    stages must be undone exactly by k up-sampling stages. Padding on each
    axis is split in half, with any odd element placed after the data.
    Extra keyword arguments are forwarded to ``pad_base``.
    """

    def __init__(self, k, axes=(0, 1), **kwargs):
        self.axes = axes  # which axes get padded
        self.k = k  # sizes are rounded up to a multiple of 2**k
        self.kwargs = kwargs  # forwarded to pad_base()

    def __call__(self, x):
        multiple = 2 ** self.k
        padding = np.zeros((len(x.shape), 2))
        for axis in self.axes:
            # distance from the current size up to the next multiple (0 if already aligned)
            missing = -x.shape[axis] % multiple
            padding[axis] = [missing // 2, missing - missing // 2]
        return pad_base(x, padding, **self.kwargs)
class PadToSquare(object):
    """
    Pad the shorter of two axes so both have the same length.

    The padding is split in half around the data, with any odd element placed
    after it. Inputs that are already square are returned unchanged. Extra
    keyword arguments are forwarded to ``pad_base``.
    """

    def __init__(self, axes=(0, 1), **kwargs):
        self.axes = axes  # the pair of axes to equalize
        self.kwargs = kwargs  # forwarded to pad_base()

    def __call__(self, x):
        first, second = self.axes[0], self.axes[1]
        gap = x.shape[first] - x.shape[second]
        if not gap:
            return x
        shorter = second if gap > 0 else first
        amount = abs(gap)
        padding = np.zeros((len(x.shape), 2))
        padding[shorter] = [amount // 2, amount - amount // 2]
        return pad_base(x, padding, **self.kwargs)
class PadToSize(object):
    """
    Pad two axes of ``x`` up to a target size.

    ``size[0]`` is the target length of ``axes[0]`` and ``size[1]`` that of
    ``axes[1]``. Neither axis may already exceed its target. The padding on
    each axis is split in half around the data, with any odd element placed
    after it. Inputs that already have the target size are returned
    unchanged. Extra keyword arguments are forwarded to ``pad_base``.
    """

    def __init__(self, size, axes=(0, 1), **kwargs):
        self.axes = axes  # the two axes to pad, matched to size[0] / size[1]
        self.size = size  # target lengths for those axes
        self.kwargs = kwargs  # forwarded to pad_base()

    def __call__(self, x):
        axes = tuple(self.axes[:2])
        shortfalls = [self.size[i] - x.shape[axis] for i, axis in enumerate(axes)]
        assert all(d >= 0 for d in shortfalls), \
            f"input shape {tuple(x.shape)} exceeds target size {tuple(self.size)} on axes {axes}"
        if not any(shortfalls):
            return x
        padding = np.zeros((len(x.shape), 2))
        # pad the axes that were actually measured; previously rows 0 and 1 were
        # always written, so non-default `axes` padded the wrong dimensions
        for axis, d in zip(axes, shortfalls):
            padding[axis] = [d // 2, d // 2 + d % 2]
        return pad_base(x, padding, **self.kwargs)
class CenterCrop(object):
    """
    Center crop the first two axes of an ndarray image (h, w, ...).

    ``size`` is ``(height, width)`` and must not exceed the input's. When the
    excess on an axis is odd, the extra row/column is removed from the end.
    Trailing axes (e.g. channels) are kept whole. Returns ``x`` itself when no
    crop is needed.
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, x):
        assert isinstance(x, np.ndarray)
        out_h, out_w = self.size[0], self.size[1]
        extra_h = x.shape[0] - out_h
        extra_w = x.shape[1] - out_w
        assert extra_h >= 0 and extra_w >= 0
        if extra_h == 0 and extra_w == 0:
            return x
        top = extra_h // 2
        left = extra_w // 2
        return x[top:top + out_h, left:left + out_w]
class PadOrCenterCrop(object):
    """
    Bring the first two axes of an ndarray (h, w, ...) to exactly ``size``.

    Any axis shorter than its target is padded first. The padding is split
    in half, with any odd element placed after the data, and extra keyword
    arguments are forwarded to ``pad_base``. Any axis longer than its target
    is then center cropped. One axis may be padded while the other is
    cropped. Returns ``x`` itself when it already has the target size.
    """

    def __init__(self, size, **kwargs):
        self.size = size
        self.kwargs = kwargs  # forwarded to pad_base()

    def __call__(self, x):
        assert isinstance(x, np.ndarray)
        short_h = max(self.size[0] - x.shape[0], 0)
        short_w = max(self.size[1] - x.shape[1], 0)
        if short_h or short_w:
            padding = np.zeros((len(x.shape), 2))
            padding[0] = [short_h // 2, short_h - short_h // 2]
            padding[1] = [short_w // 2, short_w - short_w // 2]
            x = pad_base(x, padding, **self.kwargs)
        # after padding no axis is below target; CenterCrop trims any excess
        # and returns its input untouched when nothing is left to trim
        return CenterCrop(self.size)(x)
# ==================================================
# Resizing
# ==================================================
class MultiChannelTransform(object):
    """
    Base class that lifts a single-channel (2-D) transform to HWC input.

    A 2-D input is transformed directly. A 3-D input is assumed to be HWC:
    each channel ``x[:, :, c]`` is transformed on its own and the results are
    re-stacked along axis 2. Any other rank raises NotImplementedError.
    Subclasses must provide ``__init__`` and ``_transform``.
    """

    def __init__(self, *args, **kwargs):
        raise NotImplementedError

    def __call__(self, x):
        ndim = len(x.shape)
        if ndim == 2:
            return self._transform(x)
        if ndim != 3:
            raise NotImplementedError
        channels = [self._transform(x[:, :, c]) for c in range(x.shape[2])]
        return np.stack(channels, axis=2)

    def _transform(self, x):
        raise NotImplementedError
class DownsampleShortAxis(MultiChannelTransform):
    """
    Downsample each channel so its shorter axis becomes ``size`` pixels.

    The aspect ratio is preserved, and the new longer-axis length is
    truncated to an int. Channels whose shorter axis is already below
    ``size`` are returned unchanged, so this never upsamples. Extra keyword
    arguments are forwarded to ``skimage.transform.resize``.
    """

    def __init__(self, size, **kwargs):
        self.size = size
        self.kwargs = kwargs  # forwarded to skimage.transform.resize()

    def _transform(self, x):
        short = np.min(x.shape)
        # don't do anything if any axis is smaller than size
        if short < self.size:
            return x
        # Scale with a single division (dim * size / short) instead of
        # dim / (short / size): the double rounding of the latter can land just
        # below an integer (58 / (58 / 7) == 6.999999999999999), and int() then
        # shrinks the short axis to size - 1.
        new_shape = (int(x.shape[0] * self.size / short),
                     int(x.shape[1] * self.size / short))
        return skimage.transform.resize(x, new_shape, **self.kwargs)
class Resize(MultiChannelTransform):
    """
    Resize every channel of an ndarray image to ``size``.

    Each channel is passed to ``skimage.transform.resize`` together with any
    extra keyword arguments given at construction.
    """

    def __init__(self, size, **kwargs):
        self.kwargs = kwargs  # forwarded to skimage.transform.resize()
        self.size = size  # output shape of each channel

    def _transform(self, x):
        assert isinstance(x, np.ndarray)
        resize = skimage.transform.resize
        return resize(x, self.size, **self.kwargs)
# ==================================================
# Filtering
# ==================================================
class GaussianSmooth(MultiChannelTransform):
    """
    Smooth each channel with a normalized ``size x size`` Gaussian kernel.

    Args:
        size: odd kernel width/height in pixels.
        sigma: standard deviation of the Gaussian, in pixels.

    Each channel is convolved with ``scipy.signal.convolve2d(mode='same')``,
    so the output keeps the input's spatial shape. ``convolve2d``'s default
    boundary handling zero-fills outside the image, so values near the
    borders are pulled toward 0.
    """

    def __init__(self, size, sigma):
        assert size % 2 == 1
        self.size = size
        half = size // 2
        x, y = np.mgrid[-half:half + 1, -half:half + 1]
        g = np.exp(-(x ** 2 + y ** 2) / (2.0 * sigma ** 2))
        self.kernel = g / g.sum()  # weights sum to 1, preserving mean intensity

    def _transform(self, x):
        assert isinstance(x, np.ndarray)
        # relies on `import scipy.signal` at module level: a bare `import scipy`
        # is not guaranteed to load the signal submodule on older SciPy releases
        return scipy.signal.convolve2d(x, self.kernel, mode='same')
class CLAHE(MultiChannelTransform):
    """
    Apply OpenCV contrast-limited adaptive histogram equalization per channel.

    ``clipLimit`` and ``tileGridSize`` are passed straight to
    ``cv.createCLAHE``. Each channel is cast to ``uint16`` before
    equalization. The cast does not clip first, so fractional parts are
    dropped, and values outside the uint16 range are not handled specially.
    """

    def __init__(self, clipLimit, tileGridSize):
        self.clahe = cv.createCLAHE(clipLimit=clipLimit, tileGridSize=tileGridSize)

    def _transform(self, x):
        assert isinstance(x, np.ndarray)
        as_uint16 = x.astype(np.uint16)
        return self.clahe.apply(as_uint16)
# ==================================================
# Scaling
# ==================================================
class MinMax(object):
    """
    Rescale a numpy array to [0, 1] using its min and max over ``axes``.

    Statistics are taken separately for every index of the remaining axes,
    e.g. per channel for an HWC image with the default ``axes=(0, 1)``. They
    are kept with ``keepdims=True``, so they broadcast back onto ``x`` for
    any choice of axes (e.g. ``axes=(1, 2)`` on a CHW array). A slice whose
    values are all equal gives 0/0 (NaN).
    """

    def __init__(self, axes=(0, 1)):
        self.axes = axes

    def __call__(self, x):
        lo = x.min(axis=self.axes, keepdims=True)
        hi = x.max(axis=self.axes, keepdims=True)
        return (x - lo) / (hi - lo)
class ZScore(object):
    """
    Standardize a numpy array to zero mean / unit std over ``axes``.

    The mean and the population standard deviation (numpy's default
    ``ddof=0``) are taken separately for every index of the remaining axes.
    They are kept with ``keepdims=True``, so they broadcast back onto ``x``
    for any choice of axes (e.g. ``axes=(1, 2)`` on a CHW array). A constant
    slice divides by zero.
    """

    def __init__(self, axes=(0, 1)):
        self.axes = axes

    def __call__(self, x):
        mean = x.mean(axis=self.axes, keepdims=True)
        std = x.std(axis=self.axes, keepdims=True)
        return (x - mean) / std
# ==================================================
# Misc
# ==================================================
class SelectChannel(object):
    """
    Pick one channel out of the input.

    ``x_format`` tells where the channel axis is: ``"CHW"`` indexes the first
    axis, ``"HWC"`` the third. Any other format raises NotImplementedError.
    """

    def __init__(self, label_id, x_format="CHW"):
        self.label_id = label_id
        self.x_format = x_format

    def __call__(self, x):
        layout = self.x_format
        if layout not in ("CHW", "HWC"):
            raise NotImplementedError
        return x[self.label_id] if layout == "CHW" else x[:, :, self.label_id]
class SelectClass(object):
    """
    Turn a multi-class label map into a binary mask for ``class_id``.

    Returns an int array (numpy) or int32 tensor (torch) that is 1 where the
    input equals ``class_id`` and 0 elsewhere. Other input types raise
    NotImplementedError.
    """

    def __init__(self, class_id):
        self.class_id = class_id

    def __call__(self, x):
        if isinstance(x, torch.Tensor):
            return (x == self.class_id).int()
        if isinstance(x, np.ndarray):
            return (x == self.class_id).astype(int)
        raise NotImplementedError
class AssertWidthMajor(object):
    """
    Ensure a tensor's last axis is at least as long as its second-to-last.

    When height (``shape[-2]``) exceeds width (``shape[-1]``), the two
    trailing axes are swapped; otherwise the tensor is returned unchanged.
    Written for BSDS500, which mixes portrait and landscape images.
    """

    def __call__(self, x):
        assert isinstance(x, torch.Tensor)
        height, width = x.shape[-2], x.shape[-1]
        return x.transpose(-2, -1) if height > width else x
class ExpandDims(object):
    """
    Insert a length-1 axis at position ``dim``.

    Works for numpy arrays and torch tensors; other types raise
    NotImplementedError.
    """

    def __init__(self, dim):
        self.dim = dim

    def __call__(self, x):
        if isinstance(x, torch.Tensor):
            return x.unsqueeze(self.dim)
        if isinstance(x, np.ndarray):
            return np.expand_dims(x, self.dim)
        raise NotImplementedError
class ToTensor(object):
    """
    Convert a numpy image into a contiguous torch tensor.

    With ``make_CHW`` set, the array is first rearranged to channel-first.
    ``"HW"`` gains a leading length-1 channel axis, and ``"HWC"`` is
    transposed to CHW; other formats raise NotImplementedError. The data is
    cast with ``astype(out_type)`` before conversion, and the default
    ``float`` yields a float64 tensor.
    """

    def __init__(self, make_CHW=True, input_format="HW", out_type=float):
        self.make_CHW = make_CHW
        self.input_format = input_format
        self.out_type = out_type

    def __call__(self, x):
        assert isinstance(x, np.ndarray), "Expected numpy.ndarray"
        if self.make_CHW:
            x = self._to_chw(x)
        return torch.from_numpy(x.astype(self.out_type)).contiguous()

    def _to_chw(self, x):
        # Reorder a HW / HWC array to CHW according to self.input_format.
        if self.input_format == "HW":
            return x[np.newaxis]
        if self.input_format == "HWC":
            return x.transpose(2, 0, 1)
        raise NotImplementedError
class TimeSeriesToTensor(object):
    """
    Convert a numpy image time series into a contiguous torch tensor.

    With ``make_TCHW`` set, the output is time-major TCHW:

    - ``"HWT"``: time moves to the front and a length-1 channel axis is added.
    - ``"THWC"`` and ``"HWCT"``: the axes are transposed into place.

    Without it, only ``"HWT"`` is supported, and the result is THW. Any other
    combination raises NotImplementedError. The data is cast with
    ``astype(out_type)`` before conversion.
    """

    def __init__(self, make_TCHW=True, input_format="HWT", out_type=float):
        self.make_TCHW = make_TCHW
        self.input_format = input_format
        self.out_type = out_type

    def __call__(self, x):
        assert isinstance(x, np.ndarray), "Expected numpy.ndarray"
        layout = self.input_format
        if layout == "HWT":
            x = np.transpose(x, axes=(2, 0, 1))  # -> T, H, W
            if self.make_TCHW:
                x = x[:, np.newaxis]  # -> T, 1, H, W
        elif self.make_TCHW and layout == "THWC":
            x = np.transpose(x, axes=(0, 3, 1, 2))
        elif self.make_TCHW and layout == "HWCT":
            x = np.transpose(x, axes=(3, 2, 0, 1))
        else:
            raise NotImplementedError
        return torch.from_numpy(x.astype(self.out_type)).contiguous()