Skip to content

Commit

Permalink
enh: do not register - trust inputs are not wonky
Browse files Browse the repository at this point in the history
  • Loading branch information
oesteban committed May 13, 2021
1 parent 96e300b commit 95d4157
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 119 deletions.
61 changes: 61 additions & 0 deletions sdcflows/interfaces/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,67 @@ def _run_interface(self, runtime):
return runtime


class _UniformGridInputSpec(BaseInterfaceInputSpec):
in_data = InputMultiObject(
File(exists=True),
mandatory=True,
desc="list of input data",
)
reference = traits.Int(0, usedefault=True, desc="reference index")


class _UniformGridOutputSpec(TraitedSpec):
out_data = OutputMultiObject(File(exists=True))
reference = File(exists=True)


class UniformGrid(SimpleInterface):
"""Ensure all images in input have the same spatial parameters."""

input_spec = _UniformGridInputSpec
output_spec = _UniformGridOutputSpec

def _run_interface(self, runtime):
import nibabel as nb
import numpy as np
from nitransforms.linear import Affine
from nipype.utils.filemanip import fname_presuffix

retval = [None] * len(self.inputs.in_data)
self._results["reference"] = self.inputs.in_data[self.inputs.reference]
retval[self.inputs.reference] = self._results["reference"]

refnii = nb.load(self._results["reference"])
refshape = refnii.shape[:3]
refaff = refnii.affine

resampler = Affine(reference=refnii)
for i, fname in enumerate(self.inputs.in_data):
if retval[i] is not None:
continue

nii = nb.load(fname)
retval[i] = fname_presuffix(
fname, suffix=f"_regrid{i:03d}", newpath=runtime.cwd
)

if np.allclose(nii.shape[:3], refshape) and np.allclose(nii.affine, refaff):
if np.all(nii.affine == refaff):
retval[i] = fname
else:
# Set reference's affine if difference is small
nii.__class__(nii.dataobj, refaff, nii.header).to_filename(
retval[i]
)
continue

resampler.apply(nii).to_filename(retval[i])

self._results["out_data"] = retval

return runtime


class _ConvertWarpInputSpec(BaseInterfaceInputSpec):
in_file = File(exists=True, mandatory=True, desc="output of 3dQwarp")

Expand Down
132 changes: 13 additions & 119 deletions sdcflows/workflows/fit/pepolar.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@


def init_topup_wf(
omp_nthreads=1, sloppy=False, debug=False, name="pepolar_estimate_wf"
grid_reference=0,
omp_nthreads=1,
sloppy=False,
debug=False,
name="pepolar_estimate_wf",
):
"""
Create the PEPOLAR field estimation workflow based on FSL's ``topup``.
Expand All @@ -39,6 +43,8 @@ def init_topup_wf(
Parameters
----------
grid_reference : :obj:`int`
Index of the volume (after flattening) that will be taken for gridding reference.
sloppy : :obj:`bool`
Whether a fast configuration of topup (less accurate) should be applied.
debug : :obj:`bool`
Expand Down Expand Up @@ -76,7 +82,7 @@ def init_topup_wf(

from ...utils.misc import front as _front
from ...interfaces.epi import GetReadoutTime
from ...interfaces.utils import Flatten
from ...interfaces.utils import Flatten, UniformGrid
from ...interfaces.bspline import TOPUPCoeffReorient
from ..ancillary import init_brainextraction_wf

Expand Down Expand Up @@ -104,6 +110,7 @@ def init_topup_wf(
outputnode.inputs.method = "PEB/PEPOLAR (phase-encoding based / PE-POLARity)"

flatten = pe.Node(Flatten(), name="flatten")
regrid = pe.Node(UniformGrid(reference=grid_reference), name="regrid")
concat_blips = pe.Node(MergeSeries(), name="concat_blips")
readout_time = pe.MapNode(
GetReadoutTime(),
Expand Down Expand Up @@ -134,12 +141,13 @@ def init_topup_wf(
("metadata", "in_meta")]),
(flatten, readout_time, [("out_data", "in_file"),
("out_meta", "metadata")]),
(flatten, concat_blips, [("out_data", "in_files")]),
(flatten, regrid, [("out_data", "in_data")]),
(regrid, concat_blips, [("out_data", "in_files")]),
(readout_time, topup, [("readout_time", "readout_times"),
("pe_dir_fsl", "encoding_direction")]),
(concat_blips, topup, [("out_file", "in_file")]),
(flatten, fix_coeff, [(("out_data", _front), "fmap_ref")]),
(readout_time, topup, [(("pe_direction", _front), "pe_dir")]),
(regrid, fix_coeff, [("reference", "fmap_ref")]),
(readout_time, fix_coeff, [(("pe_direction", _front), "pe_dir")]),
(topup, fix_coeff, [("out_fieldcoef", "in_coeff")]),
(topup, outputnode, [("out_jacs", "jacobians"),
("out_mats", "xfms")]),
Expand Down Expand Up @@ -332,112 +340,6 @@ def init_3dQwarp_wf(omp_nthreads=1, debug=False, name="pepolar_estimate_wf"):
return workflow


def init_prepare_blips_wf(*, omp_nthreads=1, name="prepare_blips_wf"):
"""
Prepare fieldmaps for PEPOLAR correction.
This workflow takes in two or four EPI files.
Parameters
----------
omp_nthreads : :obj:`int`
Parallelize internal tasks across the number of CPUs given by this option.
name : :obj:`str`
Name for this workflow
Inputs
------
epi_files : :obj:`list` of :obj:`str`
A list of two or four EPI files, the first of which will be taken as reference.
metadata : obj:`list` of :obj:`dict`
A list of dictionaries containing the metadata corresponding to each file
in ``epi_files``.
Outputs
-------
reg_blips : :obj:`str`
A 4D file containing one volume per phase-encoding direction
"""
import pkg_resources as pkgr
from nipype.interfaces.ants.segmentation import N4BiasFieldCorrection
from niworkflows.interfaces.fixes import FixHeaderRegistration as Registration
from niworkflows.interfaces.freesurfer import StructuralReference
from niworkflows.interfaces.nibabel import MergeSeries
from ...interfaces.utils import Flatten

inputnode = pe.Node(
niu.IdentityInterface(fields=["epi_files", "metadata"]), name="inputnode"
)
outputnode = pe.Node(
niu.IdentityInterface(fields=["reg_blips", "readout_times"]), name="outputnode"
)

flatten = pe.MapNode(Flatten(), name="flatten", iterfield=["in_data", "in_meta"])
gen_pe_refs = pe.MapNode(
StructuralReference(
auto_detect_sensitivity=True,
initial_timepoint=1,
fixed_timepoint=True, # Align to first image
intensity_scaling=True, # 7-DOF (rigid + intensity)
no_iteration=True,
subsample_threshold=200,
transform_outputs=True,
out_file="template.nii.gz",
),
iterfield=["in_files"],
name="gen_pe_refs",
)
n4_refs = pe.MapNode(
N4BiasFieldCorrection(
dimension=3,
copy_header=True,
n_iterations=[50] * 5,
convergence_threshold=1e-7,
shrink_factor=4,
),
n_procs=omp_nthreads,
name="n4_refs",
iterfield=["input_image"],
)

get_reg_files = pe.Node(
niu.Function(function=_separate_first, output_names=["ref_image", "blips"]),
name="get_reg_files",
)

reg_settings = pkgr.resource_filename("sdcflows", "data/translation_rigid.json")
reg_blips = pe.MapNode(
Registration(from_file=reg_settings, output_warped_image=True),
name="reg_blips",
iterfield=["moving_image"],
n_procs=omp_nthreads,
)

concat_blips = pe.Node(niu.Merge(2), name="concat_blips")
merge_blips = pe.Node(MergeSeries(), name="merge_blips")

workflow = Workflow(name=name)
# fmt: off
workflow.connect([
(inputnode, flatten, [
("epi_files", "in_data"),
("metadata", "in_meta")]),
(flatten, gen_pe_refs, [("out_data", "in_files")]),
(gen_pe_refs, n4_refs, [("out_file", "input_image")]),
(n4_refs, get_reg_files, [("output_image", "in_files")]),
(get_reg_files, reg_blips, [("ref_image", "fixed_image"),
("blips", "moving_image")]),
(get_reg_files, concat_blips, [("ref_image", "in1")]),
(reg_blips, concat_blips, [("warped_image", "in2")]),
(concat_blips, merge_blips, [("out", "in_files")]),
(merge_blips, outputnode, [("out_file", 'reg_blips')]),
])
# fmt: on
return workflow


def _sorted_pe(inlist):
"""
Generate suitable inputs to ``3dQwarp``.
Expand Down Expand Up @@ -488,11 +390,3 @@ def _sorted_pe(inlist):
ref_pe[0]
],
)


def _separate_first(in_files):
"""Take in a list of files and separate the first from the rest."""
# TODO: check for best resolution image?
if isinstance(in_files, (list, tuple)):
return in_files[0], in_files[1:]
raise RuntimeError(f"Expected an iterable but given {in_files}")

0 comments on commit 95d4157

Please sign in to comment.