Skip to content

Commit

Permalink
fix: refactoring the B0FieldTransform implementation
Browse files Browse the repository at this point in the history
Addresses the issue of properly applying fieldmaps on distorted data.

Resolves: #345.
  • Loading branch information
oesteban committed Apr 7, 2023
1 parent 891e81b commit bfa147e
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 71 deletions.
130 changes: 73 additions & 57 deletions sdcflows/interfaces/bspline.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,11 @@ class _ApplyCoeffsFieldInputSpec(BaseInterfaceInputSpec):
in_coeff = InputMultiObject(
File(exists=True),
mandatory=True,
desc="input coefficients, after alignment to the EPI data",
desc="input coefficients as calculated in the estimation stage",
)
fmap2data_xfm = File(
exists=True,
desc="the transform by which the fieldmap can be resampled on the target EPI's grid.",
)
in_xfms = InputMultiObject(
File(exists=True), desc="list of head-motion correction matrices"
Expand All @@ -294,21 +298,60 @@ class _ApplyCoeffsFieldInputSpec(BaseInterfaceInputSpec):
)
)
num_threads = traits.Int(nohash=True, desc="number of threads")
approx = traits.Bool(
True,
usedefault=True,
desc=(
"reconstruct the fieldmap on it's original grid and then interpolate on the "
"rotated grid, rather than reconstructing directly on the rotated grid."
),
)


class _ApplyCoeffsFieldOutputSpec(TraitedSpec):
out_corrected = OutputMultiObject(File(exists=True))
out_field = OutputMultiObject(File(exists=True))
out_warp = OutputMultiObject(File(exists=True))


class ApplyCoeffsField(SimpleInterface):
"""Convert a set of B-Spline coefficients to a full displacements map."""
"""
Undistort a target, distorted image with a fieldmap, formalized by its B-Spline coefficients.
Preconditions:
* We have a "target EPI" - a BOLD or DWI dataset (or even MPRAGE, same principle),
without having gone through HMC or SDC.
* We have also the list of HMC matrices that *accounts for* head-motion, so after resampling
the dataset through this list of transforms *the head does not move anymore*.
* We have estimated the fieldmap's coefficients
* We have the "fieldmap-to-data" affine transform that aligns the target dataset (e.g., EPI)
and the fieldmap's "magnitude" (phasediff et al.) or "reference" (pepolar, syn).
The algorithm is implemented in the :obj:`~sdcflows.transform.B0FieldTransform` object.
First, we will call :obj:`~sdcflows.transform.B0FieldTransform.fit`, which
results in:
1. The reference grid of the target dataset is projected onto the fieldmap space
2. The B-Spline coefficients are applied to reconstruct the field on the grid resulting
above.
After which, we can then call :obj:`~sdcflows.transform.B0FieldTransform.apply`.
This second step will:
3. Find the location of every voxel on each timepoint (meaning, after the head moved)
and progress (or recede) along the phase-encoding axis to find the actual (voxel)
coordinates of each voxel.
With those coordinates known, interpolation is trivial.
4. Generate a spatial image with the new data.
"""

input_spec = _ApplyCoeffsFieldInputSpec
output_spec = _ApplyCoeffsFieldOutputSpec

def _run_interface(self, runtime):
from sdcflows.transform import B0FieldTransform

n = len(self.inputs.in_data)

ro_time = self.inputs.ro_time
Expand All @@ -320,63 +363,36 @@ def _run_interface(self, runtime):
pe_dir *= n

unwarp = None
hmc_mats = [None] * n
if isdefined(self.inputs.in_xfms):
# Split ITK matrices in separate files if they come collated
hmc_mats = (
list(_split_itk_file(self.inputs.in_xfms[0]))
if len(self.inputs.in_xfms) == 1
else self.inputs.in_xfms
)
else:
from sdcflows.transform import B0FieldTransform

# Pre-cached interpolator object
unwarp = B0FieldTransform(
coeffs=[nb.load(cname) for cname in self.inputs.in_coeff],
)

if not isdefined(self.inputs.num_threads) or self.inputs.num_threads < 2:
# Linear execution (1 core)
outputs = [
_b0_resampler(
fname,
self.inputs.in_coeff,
pe_dir[i],
ro_time[i],
hmc_mats[i],
unwarp, # if no HMC matrices, interpolator can be shared
runtime.cwd,
)
for i, fname in enumerate(self.inputs.in_data)
]
else:
# Embarrasingly parallel execution
from concurrent.futures import ProcessPoolExecutor

outputs = [None] * len(self.inputs.in_data)
with ProcessPoolExecutor(max_workers=self.inputs.num_threads) as ex:
outputs = ex.map(
_b0_resampler,
self.inputs.in_data,
[self.inputs.in_coeff] * n,
pe_dir,
ro_time,
hmc_mats,
[None] * n, # force a new interpolator for each process
[runtime.cwd] * n,
)

(
self._results["out_corrected"],
self._results["out_warp"],
self._results["out_field"],
) = zip(*outputs)
# Pre-cached interpolator object
unwarp = B0FieldTransform(
coeffs=[nb.load(cname) for cname in self.inputs.in_coeff],
num_threads=(
None if not isdefined(self.inputs.num_threads) else self.inputs.num_threads
),
)

out_fields = set(self._results["out_field"]) - set([None])
if len(out_fields) == 1:
self._results["out_field"] = out_fields.pop()
# Reconstruct the field from the coefficients, on the target dataset's grid.
unwarp.fit(
self.inputs.data,
affine=(
None if not isdefined(self.inputs.fmap2data_xfm) else self.inputs.fmap2data_xfm
),
approx=self.inputs.approx,
)

# We can now write out the fieldmap
# unwarp.mapped.to_filename(out_field)
# self._results["out_field"] = out_field

# HMC matrices are only necessary when reslicing the data (i.e., apply())
hmc_mats = None
self._results["out_corrected"] = unwarp.apply(
self.inputs.data,
pe_dir,
ro_time,
xfms=hmc_mats,
)
return runtime


Expand Down
34 changes: 21 additions & 13 deletions sdcflows/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,32 +38,39 @@
from niworkflows.interfaces.nibabel import reorient_image


def _clear_mapped(instance, attribute, value):
instance.mapped = None
return value


@attr.s(slots=True)
class B0FieldTransform:
"""Represents and applies the transform to correct for susceptibility distortions."""

coeffs = attr.ib(default=None)
"""B-Spline coefficients (one value per control point)."""
xfm = attr.ib(default=None, on_setattr=_clear_mapped)
"""A rigid-body transform to prepend to the unwarping displacements field."""
mapped = attr.ib(default=None, init=False)
"""
A cache of the interpolated field in Hz (i.e., the fieldmap *mapped* on to the
target image we want to correct).
"""

def fit(self, spatialimage):
def fit(self, target_reference, affine=None, approx=True):
r"""
Generate the interpolation matrix (and the VSM with it).
Implements Eq. :math:`\eqref{eq:1}`, interpolating :math:`f(\mathbf{s})`
for all voxels in the target-image's extent.
Parameters
----------
target_reference : `spatialimage`
The image object containing a reference grid (same as that of the data
to be resampled). If a 4D dataset is provided, then the fourth dimension
will be dropped.
affine : :obj:`nitransforms.linear.Affine`
Transform that maps coordinates on the target_reference on to the
fieldmap reference.
approx : :obj:`bool`
If ``True``, do not reconstruct the B-Spline field directly on the target
(which will likely not be aligned with the fieldmap's grid), but rather use
the fieldmap's grid and then use just regular interpolation.
Returns
-------
updated : :obj:`bool`
Expand Down Expand Up @@ -117,9 +124,10 @@ def fit(self, spatialimage):

def apply(
self,
spatialimage,
data,
pe_dir,
ro_time,
xfms=None,
order=3,
mode="constant",
cval=0.0,
Expand All @@ -131,12 +139,12 @@ def apply(
Parameters
----------
spatialimage : `spatialimage`
data : `spatialimage`
The image object containing the data to be resampled in reference
space
reference : spatial object, optional
The image, surface, or combination thereof containing the coordinates
of samples that will be sampled.
xfms : `None` or :obj:`list`
A list of rigid-body transformations previously estimated that will
realign the dataset (that is, compensate for head motion) after resampling.
order : int, optional
The order of the spline interpolation, default is 3.
The order has to be in the range 0-5.
Expand Down
2 changes: 1 addition & 1 deletion sdcflows/workflows/apply/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def init_coeff2epi_wf(
if not write_coeff:
return workflow

# Map the coefficients into the EPI space
# Resample the coefficients into the EPI grid
map_coeff = pe.Node(TransformCoefficients(), name="map_coeff")
map_coeff.interface._always_run = debug

Expand Down

0 comments on commit bfa147e

Please sign in to comment.