-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy path_FourierOp.py
229 lines (193 loc) · 8.9 KB
/
_FourierOp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
"""Fourier Operator."""
# Copyright 2023 Physikalisch-Technische Bundesanstalt
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Sequence
import numpy as np
import torch
from torchkbnufft import KbNufft
from torchkbnufft import KbNufftAdjoint
from mrpro.data._kdata._KData import KData
from mrpro.data._KTrajectory import KTrajectory
from mrpro.data._SpatialDimension import SpatialDimension
from mrpro.data.enums import TrajType
from mrpro.operators._FastFourierOp import FastFourierOp
from mrpro.operators._LinearOperator import LinearOperator
class FourierOp(LinearOperator):
    """Fourier operator combining a Cartesian FFT with a non-uniform FFT.

    For each of the directions z, y, x the transform is chosen from the trajectory type:
    directions flagged ``TrajType.SINGLEVALUE`` are not transformed, directions flagged
    ``TrajType.ONGRID`` are handled by a ``FastFourierOp`` and all remaining directions are
    handled by ``KbNufft`` / ``KbNufftAdjoint``.
    """

    def __init__(
        self,
        recon_matrix: SpatialDimension[int],
        encoding_matrix: SpatialDimension[int],
        traj: KTrajectory,
        nufft_oversampling: float = 2.0,
        nufft_numpoints: int = 6,
        nufft_kbwidth: float = 2.34,
    ) -> None:
        """Initialize the Fourier operator.

        Parameters
        ----------
        recon_matrix
            dimension of the reconstructed image
        encoding_matrix
            dimension of the encoded k-space
        traj
            the k-space trajectories where the frequencies are sampled
        nufft_oversampling
            oversampling used for interpolation in non-uniform FFTs
        nufft_numpoints
            number of neighbors for interpolation in non-uniform FFTs
        nufft_kbwidth
            size of the Kaiser-Bessel kernel interpolation in non-uniform FFTs

        Raises
        ------
        NotImplementedError
            If NUFFT dimensions are present and the FFT dimensions found along kz, ky, kx differ
            from the on-grid, non-single-value dimensions found along k2, k1, k0.
        """
        super().__init__()

        def get_spatial_dims(spatial_dims: SpatialDimension, dims: Sequence[int]) -> list[int]:
            """Return the z, y, x entries of spatial_dims whose index (-3, -2, -1) is listed in dims."""
            return [
                s
                for s, i in zip((spatial_dims.z, spatial_dims.y, spatial_dims.x), (-3, -2, -1), strict=True)
                if i in dims
            ]

        def get_traj(traj: KTrajectory, dims: Sequence[int]) -> list[torch.Tensor]:
            """Return the kz, ky, kx trajectories whose index (-3, -2, -1) is listed in dims."""
            return [k for k, i in zip((traj.kz, traj.ky, traj.kx), (-3, -2, -1), strict=True) if i in dims]

        # Sort each of the directions z, y, x (indexed -3, -2, -1) into exactly one category.
        self._ignore_dims: list[int] = []
        self._fft_dims: list[int] = []
        self._nufft_dims: list[int] = []
        for dim, type_ in zip((-3, -2, -1), traj.type_along_kzyx, strict=True):
            if type_ & TrajType.SINGLEVALUE:
                # dimension which do not require any transform
                self._ignore_dims.append(dim)
            elif type_ & TrajType.ONGRID:
                self._fft_dims.append(dim)
            else:
                self._nufft_dims.append(dim)

        if self._fft_dims:
            self._fast_fourier_op = FastFourierOp(
                dim=tuple(self._fft_dims),
                recon_matrix=get_spatial_dims(recon_matrix, self._fft_dims),
                encoding_matrix=get_spatial_dims(encoding_matrix, self._fft_dims),
            )

        # Find dimensions which require NUFFT
        if self._nufft_dims:
            fft_dims_k210 = [
                dim
                for dim in (-3, -2, -1)
                if (traj.type_along_k210[dim] & TrajType.ONGRID)
                and not (traj.type_along_k210[dim] & TrajType.SINGLEVALUE)
            ]
            if self._fft_dims != fft_dims_k210:
                raise NotImplementedError(
                    'If both FFT and NUFFT dims are present, Cartesian FFT dims need to be aligned with the '
                    'k-space dimension, i.e. kx along k0, ky along k1 and kz along k2',
                )

            self._nufft_im_size = get_spatial_dims(recon_matrix, self._nufft_dims)
            grid_size = [int(size * nufft_oversampling) for size in self._nufft_im_size]

            # Scale each NUFFT trajectory by 2 * pi / (encoding size along the same direction).
            omega = [
                k * 2 * torch.pi / ks
                for k, ks in zip(
                    get_traj(traj, self._nufft_dims),
                    get_spatial_dims(encoding_matrix, self._nufft_dims),
                    strict=True,
                )
            ]
            # Broadcast shapes not always needed but also does not hurt
            omega = [k.expand(*np.broadcast_shapes(*[k.shape for k in omega])) for k in omega]
            self.register_buffer('_omega', torch.stack(omega, dim=-4))  # use the 'coil' dim for the direction

            self._fwd_nufft_op = KbNufft(
                im_size=self._nufft_im_size,
                grid_size=grid_size,
                numpoints=nufft_numpoints,
                kbwidth=nufft_kbwidth,
            )
            self._adj_nufft_op = KbNufftAdjoint(
                im_size=self._nufft_im_size,
                grid_size=grid_size,
                numpoints=nufft_numpoints,
                kbwidth=nufft_kbwidth,
            )

            # Only read in the NUFFT branch of forward() to restore the k-space shape.
            self._kshape = traj.broadcasted_shape

    @classmethod
    def from_kdata(cls, kdata: KData, recon_shape: SpatialDimension[int] | None = None) -> FourierOp:
        """Create an instance of FourierOp from kdata with default settings.

        Parameters
        ----------
        kdata
            k-space data
        recon_shape
            dimension of the reconstructed image. Defaults to KData.header.recon_matrix

        Returns
        -------
            FourierOp built from the header's encoding matrix and the trajectory of kdata
        """
        return cls(
            recon_matrix=kdata.header.recon_matrix if recon_shape is None else recon_shape,
            encoding_matrix=kdata.header.encoding_matrix,
            traj=kdata.traj,
        )

    def _broadcast_omega_for_nufft(
        self,
        permute: Sequence[int],
        batch_shape: Sequence[int],
        n_keep: int,
    ) -> torch.Tensor:
        """Permute the stored trajectory like the data, broadcast it over the batch dims and flatten it.

        Parameters
        ----------
        permute
            permutation that moves the coil dim and the NUFFT dims to the end
        batch_shape
            shape of the permuted data without its last n_keep dims
        n_keep
            number of trailing dims kept, i.e. 1 (coil) + number of NUFFT dims

        Returns
        -------
            trajectory with all leading dims flattened into one, followed by the direction dim
            (stored at the coil position) and one dim holding all flattened NUFFT samples
        """
        # TODO: consider moving the broadcast along fft dimensions to __init__ (independent of x shape).
        omega = self._omega.permute(*permute)
        omega = omega.broadcast_to(*batch_shape, *omega.shape[-n_keep:])
        return omega.flatten(end_dim=-n_keep - 1).flatten(start_dim=-n_keep + 1)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
        """Forward operator mapping the coil-images to the coil k-space data.

        Parameters
        ----------
        x
            coil image data with shape: (... coils z y x)

        Returns
        -------
            coil k-space data with shape: (... coils k2 k1 k0)
        """
        if self._fft_dims:
            # FFT
            (x,) = self._fast_fourier_op.forward(x)

        if self._nufft_dims:
            # we need to move the nufft-dimensions to the end and flatten all other dimensions
            # so the new shape will be (... non_nufft_dims) coils nufft_dims
            # we could move the permute to __init__ but then we still would need to prepend if len(other)>1
            keep_dims = [-4, *self._nufft_dims]  # -4 is always coil
            n_keep = len(keep_dims)
            permute = [i for i in range(-x.ndim, 0) if i not in keep_dims] + keep_dims
            # permute holds each negative index exactly once, so argsort yields the inverse permutation
            unpermute = np.argsort(permute)

            x = x.permute(*permute)
            permuted_x_shape = x.shape
            x = x.flatten(end_dim=-n_keep - 1)

            # omega should be (... non_nufft_dims) n_nufft_dims (nufft_dims)
            omega = self._broadcast_omega_for_nufft(permute, permuted_x_shape[:-n_keep], n_keep)
            x = self._fwd_nufft_op(x, omega, norm='ortho')

            # undo the flattening, then restore the original dim order
            shape_nufft_dims = [self._kshape[i] for i in self._nufft_dims]
            x = x.reshape(*permuted_x_shape[:-n_keep], -1, *shape_nufft_dims)  # -1 is coils
            x = x.permute(*unpermute)
        return (x,)

    def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
        """Adjoint operator mapping the coil k-space data to the coil images.

        Parameters
        ----------
        x
            coil k-space data with shape: (... coils k2 k1 k0)

        Returns
        -------
            coil image data with shape: (... coils z y x)
        """
        if self._fft_dims:
            # IFFT
            (x,) = self._fast_fourier_op.adjoint(x)

        if self._nufft_dims:
            # we need to move the nufft-dimensions to the end, flatten them and flatten all other dimensions
            # so the new shape will be (... non_nufft_dims) coils (nufft_dims)
            keep_dims = [-4, *self._nufft_dims]  # -4 is coil
            n_keep = len(keep_dims)
            permute = [i for i in range(-x.ndim, 0) if i not in keep_dims] + keep_dims
            # permute holds each negative index exactly once, so argsort yields the inverse permutation
            unpermute = np.argsort(permute)

            x = x.permute(*permute)
            permuted_x_shape = x.shape
            x = x.flatten(end_dim=-n_keep - 1).flatten(start_dim=-n_keep + 1)

            omega = self._broadcast_omega_for_nufft(permute, permuted_x_shape[:-n_keep], n_keep)
            x = self._adj_nufft_op(x, omega, norm='ortho')

            # undo the flattening of the leading dims, then restore the original dim order
            x = x.reshape(*permuted_x_shape[:-n_keep], *x.shape[-n_keep:])
            x = x.permute(*unpermute)
        return (x,)