Skip to content

Commit

Permalink
Merge pull request #234 from oesteban/enh/4d-resampling
Browse files Browse the repository at this point in the history
ENH: Improve support of 4D in ``sdcflows.interfaces.bspline.ApplyCoeffsField``
  • Loading branch information
oesteban authored Sep 21, 2021
2 parents bd61a37 + 8ca1e79 commit 4b464fa
Show file tree
Hide file tree
Showing 6 changed files with 217 additions and 78 deletions.
190 changes: 132 additions & 58 deletions sdcflows/interfaces/bspline.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
#
"""Filtering of :math:`B_0` field mappings with B-Splines."""
from pathlib import Path
from functools import partial
import numpy as np
import nibabel as nb
from nibabel.affines import apply_affine
Expand All @@ -34,12 +33,13 @@
TraitedSpec,
File,
traits,
isdefined,
SimpleInterface,
InputMultiObject,
OutputMultiObject,
)

from sdcflows.transform import grid_bspline_weights as gbsw, B0FieldTransform
from sdcflows.transform import grid_bspline_weights as gbsw


LOW_MEM_BLOCK_SIZE = 1000
Expand Down Expand Up @@ -210,14 +210,17 @@ def _run_interface(self, runtime):


class _ApplyCoeffsFieldInputSpec(BaseInterfaceInputSpec):
in_target = InputMultiObject(
in_data = InputMultiObject(
File(exist=True, mandatory=True, desc="input EPI data to be corrected")
)
in_coeff = InputMultiObject(
File(exists=True),
mandatory=True,
desc="input coefficients, after alignment to the EPI data",
)
in_xfms = InputMultiObject(
File(exists=True), desc="list of head-motion correction matrices"
)
ro_time = InputMultiObject(
traits.Float(mandatory=True, desc="EPI readout time (s).")
)
Expand All @@ -230,14 +233,15 @@ class _ApplyCoeffsFieldInputSpec(BaseInterfaceInputSpec):
"k",
"k-",
mandatory=True,
desc="the phase-encoding direction corresponding to in_target",
desc="the phase-encoding direction corresponding to in_data",
)
)
num_threads = traits.Int(nohash=True, desc="number of threads")


class _ApplyCoeffsFieldOutputSpec(TraitedSpec):
out_corrected = OutputMultiObject(File(exists=True))
out_field = File(exists=True)
out_field = OutputMultiObject(File(exists=True))
out_warp = OutputMultiObject(File(exists=True))


Expand All @@ -248,38 +252,73 @@ class ApplyCoeffsField(SimpleInterface):
output_spec = _ApplyCoeffsFieldOutputSpec

def _run_interface(self, runtime):
# Prepare output names
filename = partial(fname_presuffix, newpath=runtime.cwd)

self._results["out_field"] = filename(self.inputs.in_coeff[0], suffix="_field")
self._results["out_warp"] = []
self._results["out_corrected"] = []
n = len(self.inputs.in_data)

xfm = B0FieldTransform(
coeffs=[nb.load(cname) for cname in self.inputs.in_coeff]
)
xfm.fit(self.inputs.in_target[0])
xfm.shifts.to_filename(self._results["out_field"])

n_inputs = len(self.inputs.in_target)
ro_time = self.inputs.ro_time
if len(ro_time) == 1:
ro_time = [ro_time[0]] * n_inputs
ro_time *= n

pe_dir = self.inputs.pe_dir
if len(pe_dir) == 1:
pe_dir = [pe_dir[0]] * n_inputs
pe_dir *= n

unwarp = None
hmc_mats = [None] * n
if isdefined(self.inputs.in_xfms):
# Split ITK matrices in separate files if they come collated
hmc_mats = (
list(_split_itk_file(self.inputs.in_xfms[0]))
if len(self.inputs.in_xfms) == 1
else self.inputs.in_xfms
)
else:
from sdcflows.transform import B0FieldTransform

# Pre-cached interpolator object
unwarp = B0FieldTransform(
coeffs=[nb.load(cname) for cname in self.inputs.in_coeff],
)

if not isdefined(self.inputs.num_threads) or self.inputs.num_threads < 2:
# Linear execution (1 core)
outputs = [
_b0_resampler(
fname,
self.inputs.in_coeff,
pe_dir[i],
ro_time[i],
hmc_mats[i],
unwarp, # if no HMC matrices, interpolator can be shared
runtime.cwd,
)
for i, fname in enumerate(self.inputs.in_data)
]
else:
# Embarrasingly parallel execution
from concurrent.futures import ProcessPoolExecutor

outputs = [None] * len(self.inputs.in_data)
with ProcessPoolExecutor(max_workers=self.inputs.num_threads) as ex:
outputs = ex.map(
_b0_resampler,
self.inputs.in_data,
[self.inputs.in_coeff] * n,
pe_dir,
ro_time,
hmc_mats,
[None] * n, # force a new interpolator for each process
[runtime.cwd] * n,
)

for fname, pe, ro in zip(self.inputs.in_target, pe_dir, ro_time):
# Generate warpfield
warp_name = filename(fname, suffix="_xfm")
xfm.to_displacements(ro_time=ro, pe_dir=pe).to_filename(warp_name)
self._results["out_warp"].append(warp_name)
(
self._results["out_corrected"],
self._results["out_warp"],
self._results["out_field"],
) = zip(*outputs)

# Generate resampled
out_name = filename(fname, suffix="_unwarped")
xfm.apply(nb.load(fname), ro_time=ro, pe_dir=pe).to_filename(out_name)
self._results["out_corrected"].append(out_name)
out_fields = set(self._results["out_field"]) - set([None])
if len(out_fields) == 1:
self._results["out_field"] = out_fields.pop()

return runtime

Expand All @@ -303,11 +342,21 @@ class TransformCoefficients(SimpleInterface):
output_spec = _TransformCoefficientsOutputSpec

def _run_interface(self, runtime):
self._results["out_coeff"] = _move_coeff(
self.inputs.in_coeff,
self.inputs.fmap_ref,
self.inputs.transform,
)
from sdcflows.transform import _move_coeff

self._results["out_coeff"] = []

for level in self.inputs.in_coeff:
movednii = _move_coeff(
level,
self.inputs.fmap_ref,
self.inputs.transform,
)
out_file = fname_presuffix(
level, suffix="_space-target", newpath=runtime.cwd
)
movednii.to_filename(out_file)
self._results["out_coeff"].append(out_file)
return runtime


Expand Down Expand Up @@ -408,30 +457,6 @@ def bspline_grid(img, control_zooms_mm=DEFAULT_ZOOMS_MM):
return img.__class__(np.zeros(bs_shape, dtype="float32"), bs_affine)


def _move_coeff(in_coeff, fmap_ref, transform):
"""Read in a rigid transform from ANTs, and update the coefficients field affine."""
from pathlib import Path
import nibabel as nb
import nitransforms as nt

if isinstance(in_coeff, str):
in_coeff = [in_coeff]

xfm = nt.linear.Affine(
nt.io.itk.ITKLinearTransform.from_filename(transform).to_ras(),
reference=fmap_ref,
)

out = []
for i, c in enumerate(in_coeff):
out.append(str(Path(f"moved_coeff_{i:03d}.nii.gz").absolute()))
img = nb.load(c)
newaff = xfm.matrix @ img.affine
img.__class__(img.dataobj, newaff, img.header).to_filename(out[-1])

return out


def _fix_topup_fieldcoeff(in_coeff, fmap_ref, refpe_reversed=False, out_file=None):
"""Read in a coefficients file generated by TOPUP and fix x-form headers."""
from pathlib import Path
Expand Down Expand Up @@ -463,3 +488,52 @@ def _fix_topup_fieldcoeff(in_coeff, fmap_ref, refpe_reversed=False, out_file=Non

coeffnii.__class__(coeffnii.dataobj, newaff, header).to_filename(out_file)
return out_file


def _split_itk_file(in_file):
from pathlib import Path

lines = Path(in_file).read_text().splitlines()
header = lines.pop(0)

def _chunks(inlist, chunksize):
for i in range(0, len(inlist), chunksize):
yield "\n".join([header] + inlist[i : i + chunksize])

for i, xfm in enumerate(_chunks(lines, 4)):
p = Path(f"{i:05}")
p.write_text(xfm)
yield str(p)


def _b0_resampler(data, coeffs, pe, ro, hmc_xfm=None, unwarp=None, newpath=None):
"""Outsource the resampler into a separate callable function to allow parallelization."""
from functools import partial

# Prepare output names
filename = partial(fname_presuffix, newpath=newpath)
retval = [filename(data, suffix=s) for s in ("_unwarped", "_xfm", "_field")]

if unwarp is None:
from sdcflows.transform import B0FieldTransform

# Create a new unwarp object
unwarp = B0FieldTransform(
coeffs=[nb.load(cname) for cname in coeffs],
)

if hmc_xfm is not None:
from nitransforms.linear import Affine
from nitransforms.io.itk import ITKLinearTransform as XFMLoader

unwarp.xfm = Affine(XFMLoader.from_filename(hmc_xfm).to_ras())

if unwarp.fit(data):
unwarp.shifts.to_filename(retval[2])
else:
retval[2] = None

unwarp.apply(nb.load(data), ro_time=ro, pe_dir=pe).to_filename(retval[0])
unwarp.to_displacements(ro_time=ro, pe_dir=pe).to_filename(retval[1])

return retval
4 changes: 2 additions & 2 deletions sdcflows/interfaces/tests/test_bspline.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def test_bsplines(tmp_path, testnum):
os.chdir(tmp_path)
# Check that we can interpolate the coefficients on a target
test1 = ApplyCoeffsField(
in_target=str(tmp_path / "target.nii.gz"),
in_data=str(tmp_path / "target.nii.gz"),
in_coeff=str(tmp_path / "coeffs.nii.gz"),
pe_dir="j-",
ro_time=1.0,
Expand Down Expand Up @@ -114,7 +114,7 @@ def test_topup_coeffs_interpolation(tmpdir, testdata_dir):
"""Check that our interpolation is not far away from TOPUP's."""
tmpdir.chdir()
result = ApplyCoeffsField(
in_target=[str(testdata_dir / "epi.nii.gz")] * 2,
in_data=[str(testdata_dir / "epi.nii.gz")] * 2,
in_coeff=str(testdata_dir / "topup-coeff-fixed.nii.gz"),
pe_dir="j-",
ro_time=1.0,
Expand Down
Loading

0 comments on commit 4b464fa

Please sign in to comment.