From 8580871bfe215e22c80ce67574b602cfae79aaaa Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Wed, 15 Sep 2021 14:57:43 +0200 Subject: [PATCH 01/11] enh: rotate fieldmap to each specific frame of the target through the head-motion correction matrices --- sdcflows/interfaces/bspline.py | 92 ++++++++++++++++++++-------------- sdcflows/transform.py | 24 ++++++++- setup.cfg | 1 + 3 files changed, 78 insertions(+), 39 deletions(-) diff --git a/sdcflows/interfaces/bspline.py b/sdcflows/interfaces/bspline.py index ac3625856a..f6669ae9e9 100644 --- a/sdcflows/interfaces/bspline.py +++ b/sdcflows/interfaces/bspline.py @@ -34,6 +34,7 @@ TraitedSpec, File, traits, + isdefined, SimpleInterface, InputMultiObject, OutputMultiObject, @@ -218,6 +219,9 @@ class _ApplyCoeffsFieldInputSpec(BaseInterfaceInputSpec): mandatory=True, desc="input coefficients, after alignment to the EPI data", ) + in_xfms = InputMultiObject( + File(exists=True), desc="list of head-motion correction matrices" + ) ro_time = InputMultiObject( traits.Float(mandatory=True, desc="EPI readout time (s).") ) @@ -237,7 +241,7 @@ class _ApplyCoeffsFieldInputSpec(BaseInterfaceInputSpec): class _ApplyCoeffsFieldOutputSpec(TraitedSpec): out_corrected = OutputMultiObject(File(exists=True)) - out_field = File(exists=True) + out_field = OutputMultiObject(File(exists=True)) out_warp = OutputMultiObject(File(exists=True)) @@ -248,20 +252,37 @@ class ApplyCoeffsField(SimpleInterface): output_spec = _ApplyCoeffsFieldOutputSpec def _run_interface(self, runtime): + from nitransforms.linear import Affine + from nitransforms.io.itk import ITKLinearTransform as XFMLoader + # Prepare output names filename = partial(fname_presuffix, newpath=runtime.cwd) - self._results["out_field"] = filename(self.inputs.in_coeff[0], suffix="_field") + self._results["out_field"] = [] self._results["out_warp"] = [] self._results["out_corrected"] = [] - xfm = B0FieldTransform( + # Prepare a transform object + unwarp = B0FieldTransform( coeffs=[nb.load(cname) for cname in self.inputs.in_coeff] ) - xfm.fit(self.inputs.in_target[0]) - xfm.shifts.to_filename(self._results["out_field"]) - + # Retrieve the number of target 3D EPIs n_inputs = len(self.inputs.in_target) + + # Load head-motion correction matrices + hmc_mats = None + if isdefined(self.inputs.in_xfms): + hmc_mats = self.inputs.in_xfms + else: + unwarp.fit(self.inputs.in_target[0]) + hmc_mats = [None] * n_inputs + + # Displacements field is constant through time + self._results["out_field"] = filename( + self.inputs.in_target[0], suffix="_field" + ) + unwarp.shifts.to_filename(self._results["out_field"]) + ro_time = self.inputs.ro_time if len(ro_time) == 1: ro_time = [ro_time[0]] * n_inputs @@ -270,15 +291,24 @@ def _run_interface(self, runtime): if len(pe_dir) == 1: pe_dir = [pe_dir[0]] * n_inputs - for fname, pe, ro in zip(self.inputs.in_target, pe_dir, ro_time): + for fname, pe, ro, hmc in zip(self.inputs.in_target, pe_dir, ro_time, hmc_mats): + # Apply hmc + if hmc is not None: + unwarp.xfm = Affine(XFMLoader.from_filename(hmc).to_ras()) + unwarp.fit(fname) + + # Write out a new field for this particular frame + self._results["out_field"].append(filename(fname, suffix="_field")) + unwarp.shifts.to_filename(self._results["out_field"][-1]) + # Generate warpfield warp_name = filename(fname, suffix="_xfm") - xfm.to_displacements(ro_time=ro, pe_dir=pe).to_filename(warp_name) + unwarp.to_displacements(ro_time=ro, pe_dir=pe).to_filename(warp_name) self._results["out_warp"].append(warp_name) # Generate resampled out_name = filename(fname, suffix="_unwarped") - xfm.apply(nb.load(fname), ro_time=ro, pe_dir=pe).to_filename(out_name) + unwarp.apply(nb.load(fname), ro_time=ro, pe_dir=pe).to_filename(out_name) self._results["out_corrected"].append(out_name) return runtime @@ -303,11 +333,21 @@ class TransformCoefficients(SimpleInterface): output_spec = _TransformCoefficientsOutputSpec def _run_interface(self, runtime): - self._results["out_coeff"] = _move_coeff( - self.inputs.in_coeff, - self.inputs.fmap_ref, - self.inputs.transform, - ) + from sdcflows.transform import _move_coeff + + self._results["out_coeff"] = [] + + for level in self.inputs.in_coeff: + movednii = _move_coeff( + level, + self.inputs.fmap_ref, + self.inputs.transform, + ) + out_file = fname_presuffix( + level, suffix="_space-target", newpath=runtime.cwd + ) + movednii.to_filename(out_file) + self._results["out_coeff"].append(out_file) return runtime @@ -408,30 +448,6 @@ def bspline_grid(img, control_zooms_mm=DEFAULT_ZOOMS_MM): return img.__class__(np.zeros(bs_shape, dtype="float32"), bs_affine) -def _move_coeff(in_coeff, fmap_ref, transform): - """Read in a rigid transform from ANTs, and update the coefficients field affine.""" - from pathlib import Path - import nibabel as nb - import nitransforms as nt - - if isinstance(in_coeff, str): - in_coeff = [in_coeff] - - xfm = nt.linear.Affine( - nt.io.itk.ITKLinearTransform.from_filename(transform).to_ras(), - reference=fmap_ref, - ) - - out = [] - for i, c in enumerate(in_coeff): - out.append(str(Path(f"moved_coeff_{i:03d}.nii.gz").absolute())) - img = nb.load(c) - newaff = xfm.matrix @ img.affine - img.__class__(img.dataobj, newaff, img.header).to_filename(out[-1]) - - return out - - def _fix_topup_fieldcoeff(in_coeff, fmap_ref, refpe_reversed=False, out_file=None): """Read in a coefficients file generated by TOPUP and fix x-form headers.""" from pathlib import Path diff --git a/sdcflows/transform.py b/sdcflows/transform.py index d857d9812d..f20bea4b4b 100644 --- a/sdcflows/transform.py +++ b/sdcflows/transform.py @@ -29,14 +29,21 @@ from scipy.sparse import vstack as sparse_vstack, csr_matrix, kron import nibabel as nb +import nitransforms as nt from bids.utils import listify +def _clear_shifts(instance, attribute, value): + instance.shifts = None + return value + + @attr.s(slots=True) class B0FieldTransform: """Represents and applies the transform to correct for susceptibility distortions.""" coeffs = attr.ib(default=None) + xfm = attr.ib(default=nt.linear.Affine(), on_setattr=_clear_shifts) shifts = attr.ib(default=None, init=False) def fit(self, spatialimage): @@ -65,7 +72,11 @@ def fit(self, spatialimage): # Generate tensor-product B-Spline weights for level in listify(self.coeffs): - wmat = grid_bspline_weights(spatialimage, level) + self.xfm.reference = spatialimage + moved_cs = level.__class__( + level.dataobj, self.xfm.matrix @ level.affine, level.header + ) + wmat = grid_bspline_weights(spatialimage, moved_cs) weights.append(wmat) coeffs.append(level.get_fdata(dtype="float32").reshape(-1)) @@ -294,3 +305,14 @@ def grid_bspline_weights(target_nii, ctrl_nii): wd.append(csr_matrix(weights)) return kron(kron(wd[0], wd[1]), wd[2]) + + +def _move_coeff(in_coeff, fmap_ref, transform): + """Read in a rigid transform from ANTs, and update the coefficients field affine.""" + xfm = nt.linear.Affine( + nt.io.itk.ITKLinearTransform.from_filename(transform).to_ras(), + reference=fmap_ref, + ) + coeff = nb.load(in_coeff) + newaff = xfm.matrix @ coeff.affine + return coeff.__class__(coeff.dataobj, newaff, coeff.header) diff --git a/setup.cfg b/setup.cfg index 852811be82..7ea28a6446 100644 --- a/setup.cfg +++ b/setup.cfg @@ -28,6 +28,7 @@ setup_requires = setuptools_scm_git_archive toml install_requires = + attrs >= 20.1.0 nibabel >=3.0.1 nipype >=1.5.1,<2.0 niworkflows >= 1.4.0rc5 From 5cbd9c0845133c5d9ca0dc1f41932c3fc78f754b Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Wed, 15 Sep 2021 15:51:08 +0200 Subject: [PATCH 02/11] enh: update unwarp workflow to propagate hmc transforms --- sdcflows/workflows/apply/correction.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/sdcflows/workflows/apply/correction.py b/sdcflows/workflows/apply/correction.py index 5687e103cd..65e361d6ba 100644 --- a/sdcflows/workflows/apply/correction.py +++ b/sdcflows/workflows/apply/correction.py @@ -55,6 +55,8 @@ def init_unwarp_wf(omp_nthreads=1, debug=False, name="unwarp_wf"): dictionary of metadata corresponding to the target EPI image fmap_coeff fieldmap coefficients in distorted EPI space. + hmc_xforms + list of head-motion correction matrices (in ITK format) Outputs ------- @@ -78,7 +80,9 @@ def init_unwarp_wf(omp_nthreads=1, debug=False, name="unwarp_wf"): workflow = Workflow(name=name) inputnode = pe.Node( - niu.IdentityInterface(fields=["distorted", "metadata", "fmap_coeff"]), + niu.IdentityInterface( + fields=["distorted", "metadata", "fmap_coeff", "hmc_xforms"] + ), name="inputnode", ) outputnode = pe.Node( @@ -99,7 +103,8 @@ def init_unwarp_wf(omp_nthreads=1, debug=False, name="unwarp_wf"): (inputnode, rotime, [("distorted", "in_file"), ("metadata", "metadata")]), (inputnode, resample, [("distorted", "in_target"), - ("fmap_coeff", "in_coeff")]), + ("fmap_coeff", "in_coeff"), + ("hmc_xforms", "in_xfms")]), (rotime, resample, [("readout_time", "ro_time"), ("pe_direction", "pe_dir")]), (resample, outputnode, [("out_field", "fieldmap"), From dd864a9e0079ff246613e0f8f4a65fff79009968 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Wed, 15 Sep 2021 17:40:41 +0200 Subject: [PATCH 03/11] fix: interface expecting only one file / imports --- sdcflows/workflows/apply/correction.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/sdcflows/workflows/apply/correction.py b/sdcflows/workflows/apply/correction.py index 65e361d6ba..9cedfcc7f6 100644 --- a/sdcflows/workflows/apply/correction.py +++ b/sdcflows/workflows/apply/correction.py @@ -74,9 +74,10 @@ def init_unwarp_wf(omp_nthreads=1, debug=False, name="unwarp_wf"): a fast mask calculated from the corrected EPI reference. """ - from ...interfaces.epi import GetReadoutTime - from ...interfaces.bspline import ApplyCoeffsField - from ..ancillary import init_brainextraction_wf + from sdcflows.interfaces.epi import GetReadoutTime + from sdcflows.interfaces.bspline import ApplyCoeffsField + from sdcflows.ancillary import init_brainextraction_wf + from sdcflows.utils.misc import front as _pop workflow = Workflow(name=name) inputnode = pe.Node( @@ -100,7 +101,7 @@ def init_unwarp_wf(omp_nthreads=1, debug=False, name="unwarp_wf"): # fmt:off workflow.connect([ - (inputnode, rotime, [("distorted", "in_file"), + (inputnode, rotime, [(("distorted", _pop), "in_file"), ("metadata", "metadata")]), (inputnode, resample, [("distorted", "in_target"), ("fmap_coeff", "in_coeff"), From 01abc29bb4e5874f2662d4aa0ba844974a79647c Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Wed, 15 Sep 2021 18:00:43 +0200 Subject: [PATCH 04/11] fix: allow transform array itk file --- sdcflows/interfaces/bspline.py | 14 +++++++++++--- sdcflows/workflows/apply/correction.py | 2 +- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/sdcflows/interfaces/bspline.py b/sdcflows/interfaces/bspline.py index f6669ae9e9..59bd9427a3 100644 --- a/sdcflows/interfaces/bspline.py +++ b/sdcflows/interfaces/bspline.py @@ -253,7 +253,7 @@ class ApplyCoeffsField(SimpleInterface): def _run_interface(self, runtime): from nitransforms.linear import Affine - from nitransforms.io.itk import ITKLinearTransform as XFMLoader + from nitransforms.io.itk import ITKLinearTransformArray as XFMLoader # Prepare output names filename = partial(fname_presuffix, newpath=runtime.cwd) @@ -272,7 +272,15 @@ def _run_interface(self, runtime): # Load head-motion correction matrices hmc_mats = None if isdefined(self.inputs.in_xfms): - hmc_mats = self.inputs.in_xfms + hmc_mats = [] + + for in_xfm in self.inputs.in_xfms: + xfm = XFMLoader.from_filename(in_xfm) + + if hasattr(xfm, "xforms"): + hmc_mats += [Affine(x.to_ras()) for x in xfm.xforms] + else: + hmc_mats.append(Affine(xfm.to_ras())) else: unwarp.fit(self.inputs.in_target[0]) hmc_mats = [None] * n_inputs @@ -294,7 +302,7 @@ def _run_interface(self, runtime): for fname, pe, ro, hmc in zip(self.inputs.in_target, pe_dir, ro_time, hmc_mats): # Apply hmc if hmc is not None: - unwarp.xfm = Affine(XFMLoader.from_filename(hmc).to_ras()) + unwarp.xfm = hmc unwarp.fit(fname) # Write out a new field for this particular frame diff --git a/sdcflows/workflows/apply/correction.py b/sdcflows/workflows/apply/correction.py index 9cedfcc7f6..65b7b32fff 100644 --- a/sdcflows/workflows/apply/correction.py +++ b/sdcflows/workflows/apply/correction.py @@ -76,7 +76,7 @@ def init_unwarp_wf(omp_nthreads=1, debug=False, name="unwarp_wf"): """ from sdcflows.interfaces.epi import GetReadoutTime from sdcflows.interfaces.bspline import ApplyCoeffsField - from sdcflows.ancillary import init_brainextraction_wf + from sdcflows.workflows.ancillary import init_brainextraction_wf from sdcflows.utils.misc import front as _pop workflow = Workflow(name=name) From bb57834498317565f2eaa62bd920f08178fb23bb Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Thu, 16 Sep 2021 11:10:05 +0200 Subject: [PATCH 05/11] enh: apply head-motion correction of coordinates first --- sdcflows/transform.py | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/sdcflows/transform.py b/sdcflows/transform.py index f20bea4b4b..c43caefd50 100644 --- a/sdcflows/transform.py +++ b/sdcflows/transform.py @@ -30,6 +30,7 @@ import nibabel as nb import nitransforms as nt +from nitransforms.base import _as_homogeneous from bids.utils import listify @@ -135,6 +136,9 @@ def apply( """ # Ensure the vsm has been computed + if isinstance(spatialimage, (str, bytes, Path)): + spatialimage = nb.load(spatialimage) + self.fit(spatialimage) vsm = self.shifts.get_fdata().copy() @@ -146,9 +150,22 @@ def apply( pe_axis = "ijk".index(pe_dir[0]) # Map voxel coordinates applying the VSM - ijk_axis = tuple([np.arange(s) for s in vsm.shape]) - voxcoords = np.array(np.meshgrid(*ijk_axis, indexing="ij"), dtype="float32") - voxcoords[pe_axis, ...] += vsm * ro_time + if self.xfm is None: + ijk_axis = tuple([np.arange(s) for s in vsm.shape]) + voxcoords = np.array( + np.meshgrid(*ijk_axis, indexing="ij"), + dtype="float32" + ).reshape(3, -1) + else: + # Map coordinates from reference to time-step + hmc_xyz = self.xfm.map(self.xfm.reference.ndcoords.T) + # Convert from RAS to voxel coordinates + voxcoords = ( + np.linalg.inv(self.xfm.reference.affine) + @ _as_homogeneous(np.vstack(hmc_xyz), dim=self.xfm.reference.ndim).T + )[:3, ...] + + voxcoords[pe_axis, ...] += vsm.reshape(-1) * ro_time # Prepare data data = np.squeeze(np.asanyarray(spatialimage.dataobj)) @@ -157,7 +174,7 @@ def apply( # Resample resampled = ndi.map_coordinates( data, - voxcoords.reshape(3, -1), + voxcoords, output=output_dtype, order=order, mode=mode, From 83e33b1b4a8ad9366ad99412425b3c4d4d788969 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Thu, 16 Sep 2021 11:26:51 +0200 Subject: [PATCH 06/11] enh: improve unwarp workflow to allow 4D files --- sdcflows/workflows/apply/correction.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/sdcflows/workflows/apply/correction.py b/sdcflows/workflows/apply/correction.py index 65b7b32fff..e9b8d0e012 100644 --- a/sdcflows/workflows/apply/correction.py +++ b/sdcflows/workflows/apply/correction.py @@ -74,6 +74,8 @@ def init_unwarp_wf(omp_nthreads=1, debug=False, name="unwarp_wf"): a fast mask calculated from the corrected EPI reference. """ + from niworkflows.interfaces.images import RobustAverage + from niworkflows.interfaces.nibabel import MergeSeries from sdcflows.interfaces.epi import GetReadoutTime from sdcflows.interfaces.bspline import ApplyCoeffsField from sdcflows.workflows.ancillary import init_brainextraction_wf @@ -88,7 +90,7 @@ def init_unwarp_wf(omp_nthreads=1, debug=False, name="unwarp_wf"): ) outputnode = pe.Node( niu.IdentityInterface( - fields=["fieldmap", "fieldwarp", "corrected", "corrected_mask"] + fields=["fieldmap", "fieldwarp", "corrected", "corrected_ref", "corrected_mask"] ), name="outputnode", ) @@ -96,6 +98,8 @@ def init_unwarp_wf(omp_nthreads=1, debug=False, name="unwarp_wf"): rotime = pe.Node(GetReadoutTime(), name="rotime") rotime.interface._always_run = debug resample = pe.Node(ApplyCoeffsField(), name="resample") + merge = pe.Node(MergeSeries(), name="merge") + average = pe.Node(RobustAverage(mc_method=None), name="average") brainextraction_wf = init_brainextraction_wf() @@ -108,11 +112,14 @@ def init_unwarp_wf(omp_nthreads=1, debug=False, name="unwarp_wf"): ("hmc_xforms", "in_xfms")]), (rotime, resample, [("readout_time", "ro_time"), ("pe_direction", "pe_dir")]), + (resample, merge, [("out_corrected", "in_files")]), + (merge, average, [("out_file", "in_file")]), + (average, brainextraction_wf, [("out_file", "inputnode.in_file")]), + (merge, outputnode, [("out_file", "corrected")]), (resample, outputnode, [("out_field", "fieldmap"), ("out_warp", "fieldwarp")]), - (resample, brainextraction_wf, [("out_corrected", "inputnode.in_file")]), (brainextraction_wf, outputnode, [ - ("outputnode.out_file", "corrected"), + ("outputnode.out_file", "corrected_ref"), ("outputnode.out_mask", "corrected_mask"), ]), ]) From 3dda866ca27446620f03948f22305c7da5ea3ec8 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Fri, 17 Sep 2021 00:22:08 +0200 Subject: [PATCH 07/11] fix: memory issues --- sdcflows/interfaces/bspline.py | 62 +++++++++++++++++++++++----------- 1 file changed, 43 insertions(+), 19 deletions(-) diff --git a/sdcflows/interfaces/bspline.py b/sdcflows/interfaces/bspline.py index 59bd9427a3..3f9a602fd8 100644 --- a/sdcflows/interfaces/bspline.py +++ b/sdcflows/interfaces/bspline.py @@ -42,7 +42,6 @@ from sdcflows.transform import grid_bspline_weights as gbsw, B0FieldTransform - LOW_MEM_BLOCK_SIZE = 1000 DEFAULT_ZOOMS_MM = (40.0, 40.0, 20.0) # For human adults (mid-frequency), in mm DEFAULT_LF_ZOOMS_MM = (100.0, 100.0, 40.0) # For human adults (low-frequency), in mm @@ -253,7 +252,8 @@ class ApplyCoeffsField(SimpleInterface): def _run_interface(self, runtime): from nitransforms.linear import Affine - from nitransforms.io.itk import ITKLinearTransformArray as XFMLoader + from nitransforms.io.itk import ITKLinearTransform as XFMLoader + import gc # Prepare output names filename = partial(fname_presuffix, newpath=runtime.cwd) @@ -262,28 +262,25 @@ def _run_interface(self, runtime): self._results["out_warp"] = [] self._results["out_corrected"] = [] - # Prepare a transform object - unwarp = B0FieldTransform( - coeffs=[nb.load(cname) for cname in self.inputs.in_coeff] - ) # Retrieve the number of target 3D EPIs n_inputs = len(self.inputs.in_target) # Load head-motion correction matrices hmc_mats = None - if isdefined(self.inputs.in_xfms): - hmc_mats = [] - - for in_xfm in self.inputs.in_xfms: - xfm = XFMLoader.from_filename(in_xfm) + unwarp = None - if hasattr(xfm, "xforms"): - hmc_mats += [Affine(x.to_ras()) for x in xfm.xforms] - else: - hmc_mats.append(Affine(xfm.to_ras())) + if isdefined(self.inputs.in_xfms): + hmc_mats = ( + list(_split_itk_file(self.inputs.in_xfms[0])) + if len(self.inputs.in_xfms) == 1 + else self.inputs.in_xfms + ) else: + # Prepare a transform object + unwarp = B0FieldTransform( + coeffs=[nb.load(cname) for cname in self.inputs.in_coeff] + ) unwarp.fit(self.inputs.in_target[0]) - hmc_mats = [None] * n_inputs # Displacements field is constant through time self._results["out_field"] = filename( @@ -299,10 +296,17 @@ def _run_interface(self, runtime): if len(pe_dir) == 1: pe_dir = [pe_dir[0]] * n_inputs - for fname, pe, ro, hmc in zip(self.inputs.in_target, pe_dir, ro_time, hmc_mats): + for i, fname in enumerate(self.inputs.in_target): + pe = pe_dir[i] + ro = ro_time[i] + # Apply hmc - if hmc is not None: - unwarp.xfm = hmc + if hmc_mats: + # Create a new unwarp object + unwarp = B0FieldTransform( + coeffs=[nb.load(cname) for cname in self.inputs.in_coeff], + xfm=Affine(XFMLoader.from_filename(hmc_mats[i]).to_ras()), + ) unwarp.fit(fname) # Write out a new field for this particular frame @@ -319,6 +323,10 @@ def _run_interface(self, runtime): unwarp.apply(nb.load(fname), ro_time=ro, pe_dir=pe).to_filename(out_name) self._results["out_corrected"].append(out_name) + if hmc_mats: + unwarp = None + gc.collect() + return runtime @@ -487,3 +495,19 @@ def _fix_topup_fieldcoeff(in_coeff, fmap_ref, refpe_reversed=False, out_file=Non coeffnii.__class__(coeffnii.dataobj, newaff, header).to_filename(out_file) return out_file + + +def _split_itk_file(in_file): + from pathlib import Path + + lines = Path(in_file).read_text().splitlines() + header = lines.pop(0) + + def _chunks(inlist, chunksize): + for i in range(0, len(inlist), chunksize): + yield "\n".join([header] + inlist[i : i + chunksize]) + + for i, xfm in enumerate(_chunks(lines, 4)): + p = Path(f"{i:05}") + p.write_text(xfm) + yield str(p) From 234d141eb47d764459d04250162d3d982b974cc9 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Fri, 17 Sep 2021 10:13:04 +0200 Subject: [PATCH 08/11] enh: new defaultlist object to cleanup a little --- sdcflows/interfaces/bspline.py | 14 ++++---------- sdcflows/utils/misc.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 10 deletions(-) diff --git a/sdcflows/interfaces/bspline.py b/sdcflows/interfaces/bspline.py index 3f9a602fd8..ea55b7bb8e 100644 --- a/sdcflows/interfaces/bspline.py +++ b/sdcflows/interfaces/bspline.py @@ -41,6 +41,8 @@ ) from sdcflows.transform import grid_bspline_weights as gbsw, B0FieldTransform +from sdcflows.utils.misc import defaultlist + LOW_MEM_BLOCK_SIZE = 1000 DEFAULT_ZOOMS_MM = (40.0, 40.0, 20.0) # For human adults (mid-frequency), in mm @@ -262,9 +264,6 @@ def _run_interface(self, runtime): self._results["out_warp"] = [] self._results["out_corrected"] = [] - # Retrieve the number of target 3D EPIs - n_inputs = len(self.inputs.in_target) - # Load head-motion correction matrices hmc_mats = None unwarp = None @@ -288,13 +287,8 @@ def _run_interface(self, runtime): ) unwarp.shifts.to_filename(self._results["out_field"]) - ro_time = self.inputs.ro_time - if len(ro_time) == 1: - ro_time = [ro_time[0]] * n_inputs - - pe_dir = self.inputs.pe_dir - if len(pe_dir) == 1: - pe_dir = [pe_dir[0]] * n_inputs + ro_time = defaultlist(self.inputs.ro_time) + pe_dir = defaultlist(self.inputs.pe_dir) for i, fname in enumerate(self.inputs.in_target): pe = pe_dir[i] diff --git a/sdcflows/utils/misc.py b/sdcflows/utils/misc.py index b00c4cc641..38220b2e6d 100644 --- a/sdcflows/utils/misc.py +++ b/sdcflows/utils/misc.py @@ -67,3 +67,35 @@ def get_free_mem(): return round(virtual_memory().free, 1) except Exception: return None + + +class defaultlist(list): + """ + A sort of default dict for lists. + + Examples + -------- + >>> defaultlist(range(3)) + [0, 1, 2] + + >>> defaultlist(["abc"])[100] + 'abc' + + >>> defaultlist(range(3))[1] + 1 + + >>> l = defaultlist(reversed(range(3))) + >>> l[0] + 2 + + >>> _ = l.pop(0) + >>> _ = l.pop(0) + >>> l[4] + 0 + + """ + + def __getitem__(self, i): + if len(self) == 1: + i = 0 + return super().__getitem__(i) From c9443087970f4c754bd5459a0e11e9efe9096418 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Fri, 17 Sep 2021 10:39:00 +0200 Subject: [PATCH 09/11] enh: rename in_target input to in_data, preparing for more sophisticated transforms --- sdcflows/interfaces/bspline.py | 10 +++++----- sdcflows/interfaces/tests/test_bspline.py | 4 ++-- sdcflows/workflows/apply/correction.py | 2 +- sdcflows/workflows/fit/pepolar.py | 2 +- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/sdcflows/interfaces/bspline.py b/sdcflows/interfaces/bspline.py index ea55b7bb8e..9c5b168dde 100644 --- a/sdcflows/interfaces/bspline.py +++ b/sdcflows/interfaces/bspline.py @@ -212,7 +212,7 @@ def _run_interface(self, runtime): class _ApplyCoeffsFieldInputSpec(BaseInterfaceInputSpec): - in_target = InputMultiObject( + in_data = InputMultiObject( File(exist=True, mandatory=True, desc="input EPI data to be corrected") ) in_coeff = InputMultiObject( @@ -235,7 +235,7 @@ class _ApplyCoeffsFieldInputSpec(BaseInterfaceInputSpec): "k", "k-", mandatory=True, - desc="the phase-encoding direction corresponding to in_target", + desc="the phase-encoding direction corresponding to in_data", ) ) @@ -279,18 +279,18 @@ def _run_interface(self, runtime): unwarp = B0FieldTransform( coeffs=[nb.load(cname) for cname in self.inputs.in_coeff] ) - unwarp.fit(self.inputs.in_target[0]) + unwarp.fit(self.inputs.in_data[0]) # Displacements field is constant through time self._results["out_field"] = filename( - self.inputs.in_target[0], suffix="_field" + self.inputs.in_data[0], suffix="_field" ) unwarp.shifts.to_filename(self._results["out_field"]) ro_time = defaultlist(self.inputs.ro_time) pe_dir = defaultlist(self.inputs.pe_dir) - for i, fname in enumerate(self.inputs.in_target): + for i, fname in enumerate(self.inputs.in_data): pe = pe_dir[i] ro = ro_time[i] diff --git a/sdcflows/interfaces/tests/test_bspline.py b/sdcflows/interfaces/tests/test_bspline.py index e49eb5f263..66b4de6b37 100644 --- a/sdcflows/interfaces/tests/test_bspline.py +++ b/sdcflows/interfaces/tests/test_bspline.py @@ -63,7 +63,7 @@ def test_bsplines(tmp_path, testnum): os.chdir(tmp_path) # Check that we can interpolate the coefficients on a target test1 = ApplyCoeffsField( - in_target=str(tmp_path / "target.nii.gz"), + in_data=str(tmp_path / "target.nii.gz"), in_coeff=str(tmp_path / "coeffs.nii.gz"), pe_dir="j-", ro_time=1.0, @@ -114,7 +114,7 @@ def test_topup_coeffs_interpolation(tmpdir, testdata_dir): """Check that our interpolation is not far away from TOPUP's.""" tmpdir.chdir() result = ApplyCoeffsField( - in_target=[str(testdata_dir / "epi.nii.gz")] * 2, + in_data=[str(testdata_dir / "epi.nii.gz")] * 2, in_coeff=str(testdata_dir / "topup-coeff-fixed.nii.gz"), pe_dir="j-", ro_time=1.0, diff --git a/sdcflows/workflows/apply/correction.py b/sdcflows/workflows/apply/correction.py index e9b8d0e012..7e48d736c9 100644 --- a/sdcflows/workflows/apply/correction.py +++ b/sdcflows/workflows/apply/correction.py @@ -107,7 +107,7 @@ def init_unwarp_wf(omp_nthreads=1, debug=False, name="unwarp_wf"): workflow.connect([ (inputnode, rotime, [(("distorted", _pop), "in_file"), ("metadata", "metadata")]), - (inputnode, resample, [("distorted", "in_target"), + (inputnode, resample, [("distorted", "in_data"), ("fmap_coeff", "in_coeff"), ("hmc_xforms", "in_xfms")]), (rotime, resample, [("readout_time", "ro_time"), diff --git a/sdcflows/workflows/fit/pepolar.py b/sdcflows/workflows/fit/pepolar.py index d6e7927ca5..7bf8f14415 100644 --- a/sdcflows/workflows/fit/pepolar.py +++ b/sdcflows/workflows/fit/pepolar.py @@ -212,7 +212,7 @@ def init_topup_wf( (pad_blip_slices, topup, [("out_file", "in_file")]), (fix_coeff, unwarp, [("out_coeff", "in_coeff")]), (realign, split_blips, [("out_file", "in_file")]), - (split_blips, unwarp, [("out_files", "in_target")]), + (split_blips, unwarp, [("out_files", "in_data")]), (readout_time, unwarp, [("readout_time", "ro_time"), ("pe_direction", "pe_dir")]), (unwarp, outputnode, [("out_warp", "out_warps"), From 33aa51950bff2514628ab80b6d635ef31ff65766 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Fri, 17 Sep 2021 12:08:04 +0200 Subject: [PATCH 10/11] enh: encapsulate code in function allowing list comprehension --- sdcflows/interfaces/bspline.py | 110 +++++++++++++++++---------------- sdcflows/transform.py | 12 +++- 2 files changed, 66 insertions(+), 56 deletions(-) diff --git a/sdcflows/interfaces/bspline.py b/sdcflows/interfaces/bspline.py index 9c5b168dde..2c21a08632 100644 --- a/sdcflows/interfaces/bspline.py +++ b/sdcflows/interfaces/bspline.py @@ -40,7 +40,7 @@ OutputMultiObject, ) -from sdcflows.transform import grid_bspline_weights as gbsw, B0FieldTransform +from sdcflows.transform import grid_bspline_weights as gbsw from sdcflows.utils.misc import defaultlist @@ -253,21 +253,12 @@ class ApplyCoeffsField(SimpleInterface): output_spec = _ApplyCoeffsFieldOutputSpec def _run_interface(self, runtime): - from nitransforms.linear import Affine - from nitransforms.io.itk import ITKLinearTransform as XFMLoader - import gc - - # Prepare output names - filename = partial(fname_presuffix, newpath=runtime.cwd) - - self._results["out_field"] = [] - self._results["out_warp"] = [] - self._results["out_corrected"] = [] - # Load head-motion correction matrices - hmc_mats = None - unwarp = None + ro_time = defaultlist(self.inputs.ro_time) + pe_dir = defaultlist(self.inputs.pe_dir) + unwarp = None + hmc_mats = defaultlist([None]) if isdefined(self.inputs.in_xfms): hmc_mats = ( list(_split_itk_file(self.inputs.in_xfms[0])) @@ -275,51 +266,34 @@ def _run_interface(self, runtime): else self.inputs.in_xfms ) else: - # Prepare a transform object + from sdcflows.transform import B0FieldTransform + unwarp = B0FieldTransform( - coeffs=[nb.load(cname) for cname in self.inputs.in_coeff] + coeffs=[nb.load(cname) for cname in self.inputs.in_coeff], ) - unwarp.fit(self.inputs.in_data[0]) - # Displacements field is constant through time - self._results["out_field"] = filename( - self.inputs.in_data[0], suffix="_field" + outputs = [ + _b0_resampler( + fname, + self.inputs.in_coeff, + pe_dir[i], + ro_time[i], + hmc_mats[i], + unwarp, + runtime.cwd, ) - unwarp.shifts.to_filename(self._results["out_field"]) - - ro_time = defaultlist(self.inputs.ro_time) - pe_dir = defaultlist(self.inputs.pe_dir) - - for i, fname in enumerate(self.inputs.in_data): - pe = pe_dir[i] - ro = ro_time[i] - - # Apply hmc - if hmc_mats: - # Create a new unwarp object - unwarp = B0FieldTransform( - coeffs=[nb.load(cname) for cname in self.inputs.in_coeff], - xfm=Affine(XFMLoader.from_filename(hmc_mats[i]).to_ras()), - ) - unwarp.fit(fname) - - # Write out a new field for this particular frame - self._results["out_field"].append(filename(fname, suffix="_field")) - unwarp.shifts.to_filename(self._results["out_field"][-1]) - - # Generate warpfield - warp_name = filename(fname, suffix="_xfm") - unwarp.to_displacements(ro_time=ro, pe_dir=pe).to_filename(warp_name) - self._results["out_warp"].append(warp_name) + for i, fname in enumerate(self.inputs.in_data) + ] - # Generate resampled - out_name = filename(fname, suffix="_unwarped") - unwarp.apply(nb.load(fname), ro_time=ro, pe_dir=pe).to_filename(out_name) - self._results["out_corrected"].append(out_name) + ( + self._results["out_corrected"], + self._results["out_warp"], + self._results["out_field"], + ) = zip(*outputs) - if hmc_mats: - unwarp = None - gc.collect() + out_fields = set(self._results["out_field"]) - set([None]) + if len() == 1: + self._results["out_field"] = out_fields.pop() return runtime @@ -505,3 +479,33 @@ def _chunks(inlist, chunksize): p = Path(f"{i:05}") p.write_text(xfm) yield str(p) + + +def _b0_resampler(data, coeffs, pe, ro, hmc_xfm=None, unwarp=None, newpath=None): + # Prepare output names + filename = partial(fname_presuffix, newpath=newpath) + retval = tuple([filename(data, suffix=s) for s in ("_unwarped", "_xfm", "_field")]) + + if unwarp is None: + from sdcflows.transform import B0FieldTransform + + # Create a new unwarp object + unwarp = B0FieldTransform( + coeffs=[nb.load(cname) for cname in coeffs], + ) + + if hmc_xfm is not None: + from nitransforms.linear import Affine + from nitransforms.io.itk import ITKLinearTransform as XFMLoader + + unwarp.xfm = Affine(XFMLoader.from_filename(hmc_xfm).to_ras()) + + if unwarp.fit(data): + unwarp.shifts.to_filename(retval[2]) + else: + retval[2] = None + + unwarp.apply(nb.load(data), ro_time=ro, pe_dir=pe).to_filename(retval[0]) + unwarp.to_displacements(ro_time=ro, pe_dir=pe).to_filename(retval[1]) + + return retval diff --git a/sdcflows/transform.py b/sdcflows/transform.py index c43caefd50..dc5a9f00c4 100644 --- a/sdcflows/transform.py +++ b/sdcflows/transform.py @@ -54,6 +54,12 @@ def fit(self, spatialimage): Implements Eq. :math:`\eqref{eq:1}`, interpolating :math:`f(\mathbf{s})` for all voxels in the target-image's extent. + Returns + ------- + updated : :obj:`bool` + ``True`` if the internal field representation was fit, + ``False`` if cache was valid and will be reused. + """ # Calculate the physical coordinates of target grid if isinstance(spatialimage, (str, bytes, Path)): @@ -66,7 +72,7 @@ def fit(self, spatialimage): if np.all(newshape == self.shifts.shape) and np.allclose( newaff, self.shifts.affine ): - return + return False weights = [] coeffs = [] @@ -89,6 +95,7 @@ def fit(self, spatialimage): # Cache self.shifts = nb.Nifti1Image(vsm, spatialimage.affine, None) + return True def apply( self, @@ -153,8 +160,7 @@ def apply( if self.xfm is None: ijk_axis = tuple([np.arange(s) for s in vsm.shape]) voxcoords = np.array( - np.meshgrid(*ijk_axis, indexing="ij"), - dtype="float32" + np.meshgrid(*ijk_axis, indexing="ij"), dtype="float32" ).reshape(3, -1) else: # Map coordinates from reference to time-step From 8ca1e79e69554c1feb4a39687705b565749ad3ba Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Fri, 17 Sep 2021 12:57:32 +0200 Subject: [PATCH 11/11] enh: parallelization --- sdcflows/interfaces/bspline.py | 68 ++++++++++++++++++-------- sdcflows/utils/misc.py | 32 ------------ sdcflows/workflows/apply/correction.py | 10 +++- 3 files changed, 56 insertions(+), 54 deletions(-) diff --git a/sdcflows/interfaces/bspline.py b/sdcflows/interfaces/bspline.py index 2c21a08632..ed3abe71f8 100644 --- a/sdcflows/interfaces/bspline.py +++ b/sdcflows/interfaces/bspline.py @@ -22,7 +22,6 @@ # """Filtering of :math:`B_0` field mappings with B-Splines.""" from pathlib import Path -from functools import partial import numpy as np import nibabel as nb from nibabel.affines import apply_affine @@ -41,7 +40,6 @@ ) from sdcflows.transform import grid_bspline_weights as gbsw -from sdcflows.utils.misc import defaultlist LOW_MEM_BLOCK_SIZE = 1000 @@ -238,6 +236,7 @@ class _ApplyCoeffsFieldInputSpec(BaseInterfaceInputSpec): desc="the phase-encoding direction corresponding to in_data", ) ) + num_threads = traits.Int(nohash=True, desc="number of threads") class _ApplyCoeffsFieldOutputSpec(TraitedSpec): @@ -253,13 +252,20 @@ class ApplyCoeffsField(SimpleInterface): output_spec = _ApplyCoeffsFieldOutputSpec def _run_interface(self, runtime): - # Load head-motion correction matrices - ro_time = defaultlist(self.inputs.ro_time) - pe_dir = defaultlist(self.inputs.pe_dir) + n = len(self.inputs.in_data) + + ro_time = self.inputs.ro_time + if len(ro_time) == 1: + ro_time *= n + + pe_dir = self.inputs.pe_dir + if len(pe_dir) == 1: + pe_dir *= n unwarp = None - hmc_mats = defaultlist([None]) + hmc_mats = [None] * n if isdefined(self.inputs.in_xfms): + # Split ITK matrices in separate files if they come collated hmc_mats = ( list(_split_itk_file(self.inputs.in_xfms[0])) if len(self.inputs.in_xfms) == 1 @@ -268,22 +274,41 @@ def _run_interface(self, runtime): else: from sdcflows.transform import B0FieldTransform + # Pre-cached interpolator object unwarp = B0FieldTransform( coeffs=[nb.load(cname) for cname in self.inputs.in_coeff], ) - outputs = [ - _b0_resampler( - fname, - self.inputs.in_coeff, - pe_dir[i], - ro_time[i], - hmc_mats[i], - unwarp, - runtime.cwd, - ) - for i, fname in enumerate(self.inputs.in_data) - ] + if not isdefined(self.inputs.num_threads) or self.inputs.num_threads < 2: + # Linear execution (1 core) + outputs = [ + _b0_resampler( + fname, + self.inputs.in_coeff, + pe_dir[i], + ro_time[i], + hmc_mats[i], + unwarp, # if no HMC matrices, interpolator can be shared + runtime.cwd, + ) + for i, fname in enumerate(self.inputs.in_data) + ] + else: + # Embarrasingly parallel execution + from concurrent.futures import ProcessPoolExecutor + + outputs = [None] * len(self.inputs.in_data) + with ProcessPoolExecutor(max_workers=self.inputs.num_threads) as ex: + outputs = ex.map( + _b0_resampler, + self.inputs.in_data, + [self.inputs.in_coeff] * n, + pe_dir, + ro_time, + hmc_mats, + [None] * n, # force a new interpolator for each process + [runtime.cwd] * n, + ) ( self._results["out_corrected"], @@ -292,7 +317,7 @@ def _run_interface(self, runtime): ) = zip(*outputs) out_fields = set(self._results["out_field"]) - set([None]) - if len() == 1: + if len(out_fields) == 1: self._results["out_field"] = out_fields.pop() return runtime @@ -482,9 +507,12 @@ def _chunks(inlist, chunksize): def _b0_resampler(data, coeffs, pe, ro, hmc_xfm=None, unwarp=None, newpath=None): + """Outsource the resampler into a separate callable function to allow parallelization.""" + from functools import partial + # Prepare output names filename = partial(fname_presuffix, newpath=newpath) - retval = tuple([filename(data, suffix=s) for s in ("_unwarped", "_xfm", "_field")]) + retval = [filename(data, suffix=s) for s in ("_unwarped", "_xfm", "_field")] if unwarp is None: from sdcflows.transform import B0FieldTransform diff --git a/sdcflows/utils/misc.py b/sdcflows/utils/misc.py index 38220b2e6d..b00c4cc641 100644 --- a/sdcflows/utils/misc.py +++ b/sdcflows/utils/misc.py @@ -67,35 +67,3 @@ def get_free_mem(): return round(virtual_memory().free, 1) except Exception: return None - - -class defaultlist(list): - """ - A sort of default dict for lists. - - Examples - -------- - >>> defaultlist(range(3)) - [0, 1, 2] - - >>> defaultlist(["abc"])[100] - 'abc' - - >>> defaultlist(range(3))[1] - 1 - - >>> l = defaultlist(reversed(range(3))) - >>> l[0] - 2 - - >>> _ = l.pop(0) - >>> _ = l.pop(0) - >>> l[4] - 0 - - """ - - def __getitem__(self, i): - if len(self) == 1: - i = 0 - return super().__getitem__(i) diff --git a/sdcflows/workflows/apply/correction.py b/sdcflows/workflows/apply/correction.py index 7e48d736c9..91380b6492 100644 --- a/sdcflows/workflows/apply/correction.py +++ b/sdcflows/workflows/apply/correction.py @@ -90,14 +90,20 @@ def init_unwarp_wf(omp_nthreads=1, debug=False, name="unwarp_wf"): ) outputnode = pe.Node( niu.IdentityInterface( - fields=["fieldmap", "fieldwarp", "corrected", "corrected_ref", "corrected_mask"] + fields=[ + "fieldmap", + "fieldwarp", + "corrected", + "corrected_ref", + "corrected_mask", + ] ), name="outputnode", ) rotime = pe.Node(GetReadoutTime(), name="rotime") rotime.interface._always_run = debug - resample = pe.Node(ApplyCoeffsField(), name="resample") + resample = pe.Node(ApplyCoeffsField(num_threads=omp_nthreads), name="resample") merge = pe.Node(MergeSeries(), name="merge") average = pe.Node(RobustAverage(mc_method=None), name="average")