Skip to content

Commit

Permalink
fix: robustify displacements/fieldmap conversions
Browse files Browse the repository at this point in the history
  • Loading branch information
oesteban committed Oct 5, 2021
1 parent dbec0d9 commit 816c894
Show file tree
Hide file tree
Showing 4 changed files with 167 additions and 89 deletions.
36 changes: 36 additions & 0 deletions sdcflows/interfaces/fmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,39 @@ def _run_interface(self, runtime):
self._results["out_file"]
)
return runtime


class _DisplacementsField2FieldmapInputSpec(BaseInterfaceInputSpec):
transform = File(exists=True, mandatory=True, desc="input displacements field")
ro_time = traits.Float(mandatory=True, desc="total readout time")
pe_dir = traits.Enum(
"j-", "j", "i", "i-", "k", "k-", mandatory=True, desc="phase encoding direction"
)
itk_transform = traits.Bool(
True, usedefault=True, desc="whether this is an ITK/ANTs transform"
)


class _DisplacementsField2FieldmapOutputSpec(TraitedSpec):
out_file = File(exists=True, desc="output fieldmap in Hz")


class DisplacementsField2Fieldmap(SimpleInterface):
"""Convert from a transform to a B0 fieldmap in Hz."""

input_spec = _DisplacementsField2FieldmapInputSpec
output_spec = _DisplacementsField2FieldmapOutputSpec

def _run_interface(self, runtime):
from sdcflows.transform import disp_to_fmap

self._results["out_file"] = fname_presuffix(
self.inputs.in_file, suffix="_Hz", newpath=runtime.cwd
)
disp_to_fmap(
nb.load(self.inputs.transform),
ro_time=self.inputs.ro_time,
pe_dir=self.inputs.pe_dir,
itk_format=self.inputs.itk_transform,
).to_filename(self._results["out_file"])
return runtime
23 changes: 23 additions & 0 deletions sdcflows/tests/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,26 @@ def test_displacements_field(tmpdir, testdata_dir, outdir, pe_dir, rotation, fli
f"_y-{rotation[1] or 0}_z-{rotation[2] or 0}.svg"
),
).run()


@pytest.mark.parametrize("pe_dir", ["j", "j-", "i", "i-", "k", "k-"])
def test_conversions(tmpdir, testdata_dir, pe_dir):
"""Check idempotency."""
tmpdir.chdir()

fmap_nii = nb.load(testdata_dir / "topup-field.nii.gz")
new_nii = tf.disp_to_fmap(
tf.fmap_to_disp(
fmap_nii,
ro_time=0.2,
pe_dir=pe_dir,
),
ro_time=0.2,
pe_dir=pe_dir,
)

new_nii.to_filename("test.nii.gz")
assert np.allclose(
fmap_nii.get_fdata(dtype="float32"),
new_nii.get_fdata(dtype="float32"),
)
133 changes: 102 additions & 31 deletions sdcflows/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def fit(self, spatialimage):

# Interpolate the VSM (voxel-shift map)
vsm = np.zeros(spatialimage.shape[:3], dtype="float32")
vsm = (np.squeeze(np.vstack(coeffs).T) @ sparse_vstack(weights)).reshape(
vsm = (np.squeeze(np.hstack(coeffs).T) @ sparse_vstack(weights)).reshape(
vsm.shape
)

Expand Down Expand Up @@ -215,36 +215,107 @@ def to_displacements(self, ro_time, pe_dir, itk_format=True):
A NIfTI 1.0 object containing the distortion.
"""
# Set polarity & scale VSM (voxel-shift-map) by readout time
vsm = self.shifts.get_fdata().copy()
pe_axis = "ijk".index(pe_dir[0])
vsm *= -1.0 if pe_dir.endswith("-") else 1.0
vsm *= ro_time

# Shape of displacements field
# Note that ITK NIfTI fields are 5D (have an empty 4th dimension)
fieldshape = tuple(list(vsm.shape[:3]) + [1, 3])

# Convert VSM to voxel displacements
ijk_deltas = np.zeros((vsm.size, 3), dtype="float32")
ijk_deltas[:, pe_axis] = vsm.reshape(-1)

# To convert from VSM to RAS field we just apply the affine
aff = self.shifts.affine.copy()
aff[:3, 3] = 0 # Translations MUST NOT be applied, though.
xyz_deltas = nb.affines.apply_affine(aff, ijk_deltas)
if itk_format:
# ITK displacement vectors are in LPS orientation
xyz_deltas[..., (0, 1)] *= -1.0

xyz_nii = nb.Nifti1Image(
xyz_deltas.reshape(fieldshape),
self.shifts.affine,
None,
)
xyz_nii.header.set_intent("vector", (), "")
xyz_nii.header.set_xyzt_units("mm")
return xyz_nii
return fmap_to_disp(self.shifts, ro_time, pe_dir, itk_format=itk_format)


def fmap_to_disp(fmap_nii, ro_time, pe_dir, itk_format=True):
"""
Convert a fieldmap in Hz into an ITK/ANTs-compatible displacements field.
The displacements field can be calculated following
`Eq. (2) in the fieldmap fitting section
<sdcflows.workflows.fit.fieldmap.html#mjx-eqn-eq%3Afieldmap-2>`__.
Parameters
----------
fmap_nii : :obj:`os.pathlike`
Path to a voxel-shift-map (VSM) in NIfTI format
ro_time : :obj:`float`
The total readout time in seconds (only if ``vsm=False``).
pe_dir : :obj:`str`
The ``PhaseEncodingDirection`` metadata value (only if ``vsm=False``).
Returns
-------
spatialimage : :obj:`nibabel.nifti.Nifti1Image`
A NIfTI 1.0 object containing the distortion.
"""
# Set polarity & scale VSM (voxel-shift-map) by readout time
vsm = fmap_nii.get_fdata().copy() * (-ro_time if pe_dir.endswith("-") else ro_time)

# Shape of displacements field
# Note that ITK NIfTI fields are 5D (have an empty 4th dimension)
fieldshape = tuple(list(vsm.shape[:3]) + [1, 3])

# Convert VSM to voxel displacements
ijk_deltas = np.zeros((vsm.size, 3), dtype="float32")
ijk_deltas[:, "ijk".index(pe_dir[0])] = vsm.reshape(-1)

# To convert from VSM to RAS field we just apply the affine
aff = fmap_nii.affine.copy()
aff[:3, 3] = 0 # Translations MUST NOT be applied, though.
xyz_deltas = nb.affines.apply_affine(aff, ijk_deltas)
if itk_format:
# ITK displacement vectors are in LPS orientation
xyz_deltas[..., (0, 1)] *= -1.0

xyz_nii = nb.Nifti1Image(
xyz_deltas.reshape(fieldshape),
fmap_nii.affine,
None,
)
xyz_nii.header.set_intent("vector", (), "")
xyz_nii.header.set_xyzt_units("mm")
return xyz_nii


def disp_to_fmap(xyz_nii, ro_time, pe_dir, itk_format=True):
"""
Convert a displacements field into a fieldmap in Hz.
This is the dual operation to the previous function.
Parameters
----------
xyz_nii : :obj:`os.pathlike`
Path to a displacements field in NIfTI format.
ro_time : :obj:`float`
The total readout time in seconds (only if ``vsm=False``).
pe_dir : :obj:`str`
The ``PhaseEncodingDirection`` metadata value (only if ``vsm=False``).
Returns
-------
spatialimage : :obj:`nibabel.nifti.Nifti1Image`
A NIfTI 1.0 object containing the field in Hz.
"""
xyz_deltas = np.squeeze(xyz_nii.get_fdata(dtype="float32")).reshape((-1, 3))

if itk_format:
# ITK displacement vectors are in LPS orientation
xyz_deltas[:, (0, 1)] *= -1

inv_aff = np.linalg.inv(xyz_nii.affine)
inv_aff[:3, 3] = 0 # Translations MUST NOT be applied.

# Convert displacements from mm to voxel units
# Using the inverse affine accounts for reordering of axes, etc.
ijk_deltas = nb.affines.apply_affine(inv_aff, xyz_deltas).astype("float32")
ijk_deltas = (
ijk_deltas[:, "ijk".index(pe_dir[0])]
* (-1.0 if pe_dir.endswith("-") else 1.0)
/ ro_time
)

ijk_nii = nb.Nifti1Image(
ijk_deltas.reshape(xyz_nii.shape[:3]),
xyz_nii.affine,
None,
)
ijk_nii.header.set_xyzt_units("mm")
return ijk_nii


def _cubic_bspline(d):
Expand Down
64 changes: 6 additions & 58 deletions sdcflows/workflows/fit/syn.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ def init_syn_sdc_wf(
)
from ...utils.misc import front as _pop, last as _pull
from ...interfaces.epi import GetReadoutTime
from ...interfaces.fmap import DisplacementsField2Fieldmap
from ...interfaces.bspline import (
ApplyCoeffsField,
BSplineApprox,
Expand Down Expand Up @@ -288,7 +289,7 @@ def init_syn_sdc_wf(
unwarp = pe.Node(ApplyCoeffsField(), name="unwarp")

# Extract nonzero component
extract_field = pe.Node(niu.Function(function=_extract_field), name="extract_field")
extract_field = pe.Node(DisplacementsField2Fieldmap(), name="extract_field")

# Check zooms (avoid very expensive B-Splines fitting)
zooms_field = pe.Node(
Expand Down Expand Up @@ -316,7 +317,6 @@ def init_syn_sdc_wf(
workflow.connect([
(inputnode, readout_time, [(("epi_ref", _pop), "in_file"),
(("epi_ref", _pull), "metadata")]),
(inputnode, extract_field, [("epi_ref", "epi_meta")]),
(inputnode, atlas_msk, [("sd_prior", "in_file")]),
(inputnode, clip_epi, [(("epi_ref", _pop), "in_file")]),
(inputnode, unwarp, [(("epi_ref", _pop), "in_data")]),
Expand Down Expand Up @@ -351,8 +351,10 @@ def init_syn_sdc_wf(
(fixed_masks, syn, [("out", "fixed_image_masks")]),
(epi_merge, syn, [("out", "moving_image")]),
(moving_masks, syn, [("out", "moving_image_masks")]),
(syn, extract_field, [("forward_transforms", "in_file")]),
(extract_field, zooms_field, [("out", "input_image")]),
(syn, extract_field, [(("forward_transforms", _pop), "transform")]),
(readout_time, extract_field, [("readout_time", "ro_time"),
("pe_direction", "pe_dir")]),
(extract_field, zooms_field, [("out_file", "input_image")]),
(zooms_field, zooms_bmask, [("output_image", "reference_image")]),
(zooms_field, bs_filter, [("output_image", "in_data")]),
(zooms_bmask, bs_filter, [("output_image", "in_mask")]),
Expand Down Expand Up @@ -631,60 +633,6 @@ def _warp_dir(intuple, nlevels=3):
return nlevels * [[1 if pe == ax else 0.1 for ax in "ijk"]]


def _extract_field(in_file, epi_meta, in_mask=None, demean=True):
"""
Extract the nonzero component of the deformation field estimated by ANTs.
Examples
--------
>>> nii = nb.load(
... _extract_field(
... ["field.nii.gz"],
... ("epi.nii.gz", {"PhaseEncodingDirection": "j-", "TotalReadoutTime": 0.005}),
... demean=False,
... )
... )
>>> nii.shape
(10, 10, 10)
>>> np.allclose(nii.get_fdata(), -200)
True
"""
from pathlib import Path
from nipype.utils.filemanip import fname_presuffix
import numpy as np
import nibabel as nb
from sdcflows.utils.epimanip import get_trt

fieldnii = nb.load(in_file[0])
trt = get_trt(epi_meta[1], in_file=epi_meta[0])
data = (
np.squeeze(fieldnii.get_fdata(dtype="float32"))[
..., "ijk".index(epi_meta[1]["PhaseEncodingDirection"][0])
]
/ trt
* (-1.0 if epi_meta[1]["PhaseEncodingDirection"].endswith("-") else 1.0)
)

if ["PhaseEncodingDirection"][0] in "ij":
data *= -1.0 # ITK/ANTs is an LPS system, flip direction

if demean:
mask = (
np.ones_like(data, dtype=bool) if in_mask is None
else np.asanyarray(nb.load(in_mask).dataobj, dtype=bool)
)
# De-mean the result
data -= np.median(data[mask])

out_file = Path(fname_presuffix(Path(in_file[0]).name, suffix="_fieldmap"))
nii = nb.Nifti1Image(data, fieldnii.affine, None)
nii.header.set_xyzt_units(fieldnii.header.get_xyzt_units()[0])
nii.to_filename(out_file)
return str(out_file.absolute())


def _merge_meta(epi_ref, meta_list):
"""Prepare a tuple of EPI reference and metadata."""
return (epi_ref, meta_list[0])
Expand Down

0 comments on commit 816c894

Please sign in to comment.