-
Notifications
You must be signed in to change notification settings - Fork 0
/
transforms.py
259 lines (213 loc) · 7.56 KB
/
transforms.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
"""
Copyright (c) Facebook, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
import numpy as np
import torch
def to_tensor(data):
    """
    Convert a numpy array to a PyTorch tensor.
    Complex arrays are first turned into a real array with a new trailing
    dimension of size 2 holding the real and imaginary parts.
    Args:
        data (np.ndarray): Input numpy array
    Returns:
        torch.Tensor: Tensor created with torch.from_numpy; for real input it
        shares memory with ``data``.
    """
    if not np.iscomplexobj(data):
        return torch.from_numpy(data)
    real_imag = np.stack((data.real, data.imag), axis=-1)
    return torch.from_numpy(real_imag)
def rfft2(data):
    """
    Centered, unnormalized 2D FFT of a real batch, with real/imag as a channel dim.
    Uses the ``torch.fft`` module in place of ``torch.rfft(..., onesided=False)``.
    The old call was removed in PyTorch 1.8 and computed the same full spectrum.
    Args:
        data (torch.Tensor): Real floating-point tensor of rank 4, (B, C, H, W).
            Dim 1 is squeezed at the end, so C is presumably 1 -- TODO confirm.
    Returns:
        torch.Tensor: (B, 2, H, W) when C == 1 (otherwise (B, C, 2, H, W)).
        Index 0 of the size-2 dim is the real part, index 1 the imaginary part.
    """
    data = torch.fft.ifftshift(data, dim=(-2, -1))
    # normalized=False in the legacy API == "backward" (no scaling on forward).
    data = torch.fft.fft2(data, dim=(-2, -1), norm="backward")
    data = torch.fft.fftshift(data, dim=(-2, -1))
    data = torch.view_as_real(data)  # (B, C, H, W, 2)
    data = data.permute(0, 1, 4, 2, 3)  # (B, C, 2, H, W)
    return data.squeeze(1)
def rfft2_regular(data):
    """
    Centered, orthonormal 2D FFT of a real tensor over its last two dims.
    Uses the ``torch.fft`` module in place of
    ``torch.rfft(..., normalized=True, onesided=False)``, which was removed in PyTorch 1.8.
    Args:
        data (torch.Tensor): Real floating-point tensor of shape (..., H, W)
    Returns:
        torch.Tensor: (..., H, W, 2) with the real part at [..., 0] and the
        imaginary part at [..., 1].
    """
    data = torch.fft.ifftshift(data, dim=(-2, -1))
    # normalized=True in the legacy API == "ortho" (1/sqrt(H*W) scaling).
    data = torch.fft.fft2(data, dim=(-2, -1), norm="ortho")
    data = torch.fft.fftshift(data, dim=(-2, -1))
    return torch.view_as_real(data)
def irfft2(data):
    """
    Inverse of ``rfft2``: centered, unnormalized inverse 2D FFT to a real image.
    Uses ``torch.fft.irfft2`` in place of ``torch.irfft(..., onesided=False)``,
    which was removed in PyTorch 1.8.
    The spectrum is expected to be Hermitian-symmetric, i.e. to come from a real image.
    Only the first W//2+1 columns of the un-shifted spectrum are used. The
    legacy API documented non-symmetric input as undefined behavior.
    Args:
        data (torch.Tensor): Floating-point tensor of shape (B, 2, H, W); index 0
            of dim 1 is the real part and index 1 the imaginary part.
    Returns:
        torch.Tensor: Real tensor of shape (B, 1, H, W).
    """
    data = data.unsqueeze(1)
    data = data.permute(0, 1, 3, 4, 2)  # (B, 1, H, W, 2)
    assert data.size(-1) == 2
    # view_as_complex needs a stride-1 last dim, which the permute breaks.
    data = torch.view_as_complex(data.contiguous())  # (B, 1, H, W) complex
    signal_size = tuple(data.shape[-2:])
    data = torch.fft.ifftshift(data, dim=(-2, -1))
    data = torch.fft.irfft2(data, s=signal_size, dim=(-2, -1), norm="backward")
    return torch.fft.fftshift(data, dim=(-2, -1))
def irfft2_regular(data):
    """
    Centered, unnormalized inverse 2D FFT of a (..., H, W, 2) spectrum to a real image.
    Uses ``torch.fft.irfft2`` in place of ``torch.irfft(..., onesided=False)``,
    which was removed in PyTorch 1.8.
    The spectrum is expected to be Hermitian-symmetric (from a real image); only
    the first W//2+1 columns of the un-shifted spectrum are used.
    Args:
        data (torch.Tensor): Floating-point tensor (..., H, W, 2) with the real part
            at [..., 0] and the imaginary part at [..., 1].
    Returns:
        torch.Tensor: Real tensor of shape (..., H, W).
    """
    assert data.size(-1) == 2
    data = torch.view_as_complex(data.contiguous())
    signal_size = tuple(data.shape[-2:])
    data = torch.fft.ifftshift(data, dim=(-2, -1))
    data = torch.fft.irfft2(data, s=signal_size, dim=(-2, -1), norm="backward")
    return torch.fft.fftshift(data, dim=(-2, -1))
def fft2(data):
    """
    Centered, orthonormal 2D FFT of a complex image stored as (..., H, W, 2).
    Args:
        data (torch.Tensor): Floating-point tensor whose last dim has size 2 (real
            part at [..., 0], imaginary part at [..., 1]); the FFT is taken over
            the H and W dims (-3, -2).
    Returns:
        torch.Tensor: Real tensor of the same (..., H, W, 2) layout holding the
        spectrum.
    """
    assert data.size(-1) == 2
    # Pair the two trailing channels into one complex value. Transforming the
    # real tensor directly would FFT the real and imaginary planes separately
    # instead of transforming the complex image.
    data = torch.view_as_complex(data.contiguous())
    data = torch.fft.ifftshift(data, dim=(-2, -1))
    data = torch.fft.fft2(data, dim=(-2, -1), norm="ortho")
    data = torch.fft.fftshift(data, dim=(-2, -1))
    return torch.view_as_real(data)
def fft2_cplx(data):
    """
    Centered 2D FFT of a tensor over its last two dimensions.
    Applies ifftshift, an unnormalized forward FFT (norm=None), then fftshift,
    all over dims (-2, -1).
    Args:
        data (torch.Tensor): Input tensor (typically native complex)
    Returns:
        torch.Tensor: Complex spectrum with the same shape as data
    """
    axes = (-2, -1)
    centered_input = torch.fft.ifftshift(data, dim=axes)
    spectrum = torch.fft.fft2(centered_input, dim=axes, norm=None)
    return torch.fft.fftshift(spectrum, dim=axes)
def ifft2_cplx(data):
    """
    Centered 2D inverse FFT of a tensor over its last two dimensions.
    Applies ifftshift, an inverse FFT with default (1/N) scaling (norm=None),
    then fftshift, all over dims (-2, -1).
    Args:
        data (torch.Tensor): Input tensor (typically native complex)
    Returns:
        torch.Tensor: Complex result with the same shape as data
    """
    axes = (-2, -1)
    unshifted = torch.fft.ifftshift(data, dim=axes)
    image = torch.fft.ifft2(unshifted, dim=axes, norm=None)
    return torch.fft.fftshift(image, dim=axes)
def ifft2(data):
    """
    Centered, unnormalized-inverse (1/N) 2D IFFT of a channel-first complex batch.
    Uses ``torch.fft.ifft2`` in place of ``torch.ifft(..., normalized=False)``,
    which was removed in PyTorch 1.8.
    Args:
        data (torch.Tensor): Floating-point tensor of shape (B, 2, H, W); index 0
            of dim 1 is the real part and index 1 the imaginary part.
    Returns:
        torch.Tensor: (B, 2, H, W) tensor in the same real/imag channel layout.
    """
    data = data.unsqueeze(1)
    data = data.permute(0, 1, 3, 4, 2)  # (B, 1, H, W, 2)
    assert data.size(-1) == 2
    # view_as_complex needs a stride-1 last dim, which the permute breaks.
    data = torch.view_as_complex(data.contiguous())  # (B, 1, H, W) complex
    data = torch.fft.ifftshift(data, dim=(-2, -1))
    data = torch.fft.ifft2(data, dim=(-2, -1), norm="backward")
    data = torch.fft.fftshift(data, dim=(-2, -1))
    data = torch.view_as_real(data)  # (B, 1, H, W, 2)
    data = data.permute(0, 1, 4, 2, 3)  # (B, 1, 2, H, W)
    return data.squeeze(1)
def ifft2_regular(data):
    """
    Centered, orthonormal 2D inverse FFT of a complex image stored as (..., H, W, 2).
    Uses ``torch.fft.ifft2`` in place of ``torch.ifft(..., normalized=True)``,
    which was removed in PyTorch 1.8.
    Args:
        data (torch.Tensor): Floating-point tensor whose last dim has size 2 (real
            part at [..., 0], imaginary part at [..., 1]).
    Returns:
        torch.Tensor: Real tensor in the same (..., H, W, 2) layout.
    """
    assert data.size(-1) == 2
    data = torch.view_as_complex(data.contiguous())
    data = torch.fft.ifftshift(data, dim=(-2, -1))
    # normalized=True in the legacy API == "ortho" (1/sqrt(H*W) scaling).
    data = torch.fft.ifft2(data, dim=(-2, -1), norm="ortho")
    data = torch.fft.fftshift(data, dim=(-2, -1))
    return torch.view_as_real(data)
def complex_abs(data):
    """
    Compute the magnitude of a complex tensor stored with real/imag in the last dim.
    Args:
        data (torch.Tensor): Tensor whose final dimension has size 2
            (real and imaginary parts).
    Returns:
        torch.Tensor: sqrt(real**2 + imag**2), with the final dimension removed.
    """
    assert data.size(-1) == 2
    squared_magnitude = torch.sum(data ** 2, dim=-1)
    return torch.sqrt(squared_magnitude)
def root_sum_of_squares(data, dim=0):
    """
    Compute the Root Sum of Squares (RSS) of a tensor along one dimension.
    Args:
        data (torch.Tensor): The input tensor
        dim (int): Dimension reduced by the RSS (default 0)
    Returns:
        torch.Tensor: sqrt(sum(data ** 2, dim))
    """
    sum_sq = data.pow(2).sum(dim)
    return sum_sq.sqrt()
def center_crop(data, shape):
    """
    Crop the central region of a real image (or batch of images).
    The crop is taken over the last two dimensions; any leading dimensions are kept.
    When the size difference is odd, the extra row/column is dropped from the end.
    Args:
        data (torch.Tensor): Input with at least 2 dimensions.
        shape (int, int): Output (rows, cols); each must be positive and no
            larger than the matching dimension of data.
    Returns:
        torch.Tensor: The center cropped image
    """
    rows, cols = data.shape[-2], data.shape[-1]
    assert 0 < shape[0] <= rows
    assert 0 < shape[1] <= cols
    top = (rows - shape[0]) // 2
    left = (cols - shape[1]) // 2
    return data[..., top:top + shape[0], left:left + shape[1]]
def complex_center_crop(data, shape):
    """
    Crop the central region of a complex image (or batch) stored as (..., H, W, 2).
    The crop is taken over dims -3 and -2; the trailing real/imag dim is kept whole.
    When the size difference is odd, the extra row/column is dropped from the end.
    Args:
        data (torch.Tensor): Input with at least 3 dimensions, last of size 2.
        shape (int, int): Output (rows, cols); each must be positive and no
            larger than the matching dimension of data.
    Returns:
        torch.Tensor: The center cropped image
    """
    rows, cols = data.shape[-3], data.shape[-2]
    assert 0 < shape[0] <= rows
    assert 0 < shape[1] <= cols
    top = (rows - shape[0]) // 2
    left = (cols - shape[1]) // 2
    row_window = slice(top, top + shape[0])
    col_window = slice(left, left + shape[1])
    return data[..., row_window, col_window, :]
def normalize(data, mean, stddev, eps=0.):
    """
    Standardize data with the supplied statistics:
        (data - mean) / (stddev + eps)
    Args:
        data (torch.Tensor): Input data to be normalized
        mean (float): Value subtracted from data
        stddev (float): Value data is divided by (after adding eps)
        eps (float): Added to stddev to prevent dividing by zero
    Returns:
        torch.Tensor: Normalized tensor
    """
    centered = data - mean
    return centered / (stddev + eps)
def normalize_instance(data, eps=0.):
    """
    Normalize a tensor by its own global mean and standard deviation:
        (data - mean) / (stddev + eps)
    Args:
        data (torch.Tensor): Input data to be normalized
        eps (float): Added to stddev to prevent dividing by zero
    Returns:
        tuple: (normalized tensor, mean, std), where mean and std are the
        input statistics from ``data.mean()`` and ``data.std()``.
    """
    mu, sigma = data.mean(), data.std()
    return (data - mu) / (sigma + eps), mu, sigma
def normalize_instance_per_channel(data, eps=0.):
    """
    Normalize every slice ``data[i, :, :, j]`` independently, in place, using:
        (slice - mean) / (stddev + eps)
    where mean and stddev are computed from that slice alone.
    Args:
        data (torch.Tensor): Rank-4 tensor indexed as data[i, :, :, j]. Only
            j = 0 and j = 1 are normalized (the trailing dim presumably holds
            real/imag parts -- TODO confirm). The tensor is modified in place.
        eps (float): Added to stddev to prevent dividing by zero
    Returns:
        tuple: (data, mean, std). ``data`` is the same tensor object that was
        passed in. ``mean``/``std`` are the statistics of the LAST slice processed
        (data[-1, :, :, 1]) only; this is kept for backward compatibility and is not
        enough to undo the normalization of the other slices.
    Raises:
        ValueError: if the first dimension of ``data`` is empty (previously this
        surfaced as an UnboundLocalError on ``mean``).
    """
    if data.shape[0] == 0:
        raise ValueError("normalize_instance_per_channel requires a non-empty first dimension")
    for i in range(data.shape[0]):
        for j in range(2):
            channel = data[i, :, :, j]
            # Statistics are taken before the slice is overwritten below.
            mean = channel.mean()
            std = channel.std()
            data[i, :, :, j] = normalize(channel, mean, std, eps)
    return data, mean, std
# Helper functions
def roll(x, shift, dim):
    """
    Similar to np.roll but applies to PyTorch Tensors.
    Args:
        x (torch.Tensor): Tensor to roll
        shift (int or sequence of int): Number of places to shift; negative
            values shift backwards. A sequence is paired element-wise with ``dim``.
        dim (int or sequence of int): Dimension(s) to roll along
    Returns:
        torch.Tensor: Rolled tensor; ``x`` itself when the effective shift is 0.
    """
    if not isinstance(shift, (tuple, list)):
        size = x.size(dim)
        offset = shift % size
        if offset == 0:
            return x
        # The last `offset` elements wrap around to the front.
        head = x.narrow(dim, size - offset, offset)
        tail = x.narrow(dim, 0, size - offset)
        return torch.cat((head, tail), dim=dim)
    assert len(shift) == len(dim)
    result = x
    for step, axis in zip(shift, dim):
        result = roll(result, step, axis)
    return result
def fftshift(x, dim=None):
    """
    Similar to np.fft.fftshift but applies to PyTorch Tensors.
    Rolls each selected dimension forward by ``size // 2`` so the zero-frequency
    entry moves to the center.
    Args:
        x (torch.Tensor): Input tensor
        dim (None, int or sequence of int): Dimension(s) to shift; None shifts
            every dimension.
    Returns:
        torch.Tensor: Shifted tensor
    """
    if dim is None:
        all_axes = tuple(range(x.dim()))
        offsets = [size // 2 for size in x.shape]
        return roll(x, offsets, all_axes)
    if isinstance(dim, int):
        return roll(x, x.shape[dim] // 2, dim)
    return roll(x, [x.shape[axis] // 2 for axis in dim], dim)
def ifftshift(x, dim=None):
    """
    Similar to np.fft.ifftshift but applies to PyTorch Tensors.
    Rolls each selected dimension forward by ``(size + 1) // 2``, which undoes
    ``fftshift`` for both even and odd sizes.
    Args:
        x (torch.Tensor): Input tensor
        dim (None, int or sequence of int): Dimension(s) to shift; None shifts
            every dimension.
    Returns:
        torch.Tensor: Shifted tensor
    """
    if dim is None:
        all_axes = tuple(range(x.dim()))
        offsets = [(size + 1) // 2 for size in x.shape]
        return roll(x, offsets, all_axes)
    if isinstance(dim, int):
        return roll(x, (x.shape[dim] + 1) // 2, dim)
    return roll(x, [(x.shape[axis] + 1) // 2 for axis in dim], dim)