Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Runwise bold reference generation #268

Merged
merged 10 commits into from
Jan 19, 2023
99 changes: 32 additions & 67 deletions nibabies/workflows/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,80 +408,44 @@ def init_single_subject_wf(subject_id, session_id=None):

# Append the functional section to the existing anatomical exerpt
# That way we do not need to stream down the number of bold datasets
anat_preproc_wf.__postdesc__ = (
(anat_preproc_wf.__postdesc__ if hasattr(anat_preproc_wf, "__postdesc__") else "")
+ f"""
anat_preproc_wf.__postdesc__ = getattr(anat_preproc_wf, '__postdesc__') or ''
func_pre_desc = f"""

Functional data preprocessing

: For each of the {len(subject_data['bold'])} BOLD runs found per subject (across all
tasks and sessions), the following preprocessing was performed.
"""
)

# calculate reference image(s) for BOLD images
# group all BOLD files based on same:
# 1) session
# 2) PE direction
# 3) total readout time
from niworkflows.workflows.epi.refmap import init_epi_reference_wf

bold_groupings = group_bolds_ref(
layout=config.execution.layout,
subject=subject_id,
sessions=[session_id],
)
tasks and sessions), the following preprocessing was performed."""

func_preproc_wfs = []
has_fieldmap = bool(fmap_estimators)
for idx, grouping in enumerate(bold_groupings.values()):
bold_ref_wf = init_epi_reference_wf(
auto_bold_nss=True,
name=f"bold_reference_wf{idx}",
omp_nthreads=config.nipype.omp_nthreads,
)
bold_files = grouping.files
bold_ref_wf.inputs.inputnode.in_files = grouping.files

if grouping.multiecho_id is not None:
bold_files = [bold_files]
for idx, bold_file in enumerate(bold_files):
func_preproc_wf = init_func_preproc_wf(
bold_file,
has_fieldmap=has_fieldmap,
existing_derivatives=derivatives,
)
# fmt: off
workflow.connect([
(bold_ref_wf, func_preproc_wf, [
('outputnode.epi_ref_file', 'inputnode.bold_ref'),
(
('outputnode.xfm_files', _select_iter_idx, idx),
'inputnode.bold_ref_xfm'),
(
('outputnode.n_dummy', _select_iter_idx, idx),
'inputnode.n_dummy_scans'),
]),
(anat_preproc_wf, func_preproc_wf, [
('outputnode.anat_preproc', 'inputnode.anat_preproc'),
('outputnode.anat_mask', 'inputnode.anat_mask'),
('outputnode.anat_brain', 'inputnode.anat_brain'),
('outputnode.anat_dseg', 'inputnode.anat_dseg'),
('outputnode.anat_aseg', 'inputnode.anat_aseg'),
('outputnode.anat_aparc', 'inputnode.anat_aparc'),
('outputnode.anat_tpms', 'inputnode.anat_tpms'),
('outputnode.template', 'inputnode.template'),
('outputnode.anat2std_xfm', 'inputnode.anat2std_xfm'),
('outputnode.std2anat_xfm', 'inputnode.std2anat_xfm'),
# Undefined if --fs-no-reconall, but this is safe
('outputnode.subjects_dir', 'inputnode.subjects_dir'),
('outputnode.subject_id', 'inputnode.subject_id'),
('outputnode.t1w2fsnative_xfm', 'inputnode.t1w2fsnative_xfm'),
('outputnode.fsnative2t1w_xfm', 'inputnode.fsnative2t1w_xfm'),
]),
])
# fmt: on
func_preproc_wfs.append(func_preproc_wf)
for bold_file in subject_data['bold']:
mgxd marked this conversation as resolved.
Show resolved Hide resolved
func_preproc_wf = init_func_preproc_wf(bold_file, has_fieldmap=has_fieldmap)
if func_preproc_wf is None:
continue

func_preproc_wf.__desc__ = func_pre_desc + (getattr(func_preproc_wf, '__desc__') or '')
# fmt:off
workflow.connect([
(anat_preproc_wf, func_preproc_wf, [
('outputnode.anat_preproc', 'inputnode.anat_preproc'),
('outputnode.anat_mask', 'inputnode.anat_mask'),
('outputnode.anat_brain', 'inputnode.anat_brain'),
('outputnode.anat_dseg', 'inputnode.anat_dseg'),
('outputnode.anat_aseg', 'inputnode.anat_aseg'),
('outputnode.anat_aparc', 'inputnode.anat_aparc'),
('outputnode.anat_tpms', 'inputnode.anat_tpms'),
('outputnode.template', 'inputnode.template'),
('outputnode.anat2std_xfm', 'inputnode.anat2std_xfm'),
('outputnode.std2anat_xfm', 'inputnode.std2anat_xfm'),
# Undefined if --fs-no-reconall, but this is safe
('outputnode.subjects_dir', 'inputnode.subjects_dir'),
('outputnode.subject_id', 'inputnode.subject_id'),
('outputnode.t1w2fsnative_xfm', 'inputnode.t1w2fsnative_xfm'),
('outputnode.fsnative2t1w_xfm', 'inputnode.fsnative2t1w_xfm'),
]),
])
# fmt:on
func_preproc_wfs.append(func_preproc_wf)

if not has_fieldmap:
config.loggers.workflow.warning(
Expand All @@ -506,6 +470,7 @@ def init_single_subject_wf(subject_id, session_id=None):
subject=subject_id,
)
fmap_wf.__desc__ = f"""

Preprocessing of B<sub>0</sub> inhomogeneity mappings

: A total of {len(fmap_estimators)} fieldmaps were found available within the input
Expand Down
71 changes: 39 additions & 32 deletions nibabies/workflows/bold/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
from ...interfaces.reports import FunctionalSummary
from ...utils.bids import extract_entities
from ...utils.misc import combine_meepi_source
from .boldref import init_infant_epi_reference_wf

# BOLD workflows
from .confounds import init_bold_confs_wf, init_carpetplot_wf
Expand Down Expand Up @@ -127,12 +128,6 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False, existing_derivatives=Non
LTA-style affine matrix translating from T1w to FreeSurfer-conformed subject space
fsnative2t1w_xfm
LTA-style affine matrix translating from FreeSurfer-conformed subject space to T1w
bold_ref
BOLD reference file
bold_ref_xfm
Transform file in LTA format from bold to reference
n_dummy_scans
Number of nonsteady states at the beginning of the BOLD run

Outputs
-------
Expand Down Expand Up @@ -177,6 +172,7 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False, existing_derivatives=Non

"""
from niworkflows.engine.workflows import LiterateWorkflow as Workflow
from niworkflows.interfaces.bold import NonsteadyStatesDetector
from niworkflows.interfaces.nibabel import ApplyMask
from niworkflows.interfaces.utility import DictMerge, KeySelect
from niworkflows.workflows.epi.refmap import init_epi_reference_wf
Expand Down Expand Up @@ -244,9 +240,14 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False, existing_derivatives=Non
)

# Find associated sbref, if possible
entities["suffix"] = "sbref"
entities["extension"] = [".nii", ".nii.gz"] # Overwrite extensions
sbref_files = layout.get(scope="raw", return_type="file", **entities)
overrides = {
"suffix": "sbref",
"extension": [".nii", ".nii.gz"],
}
if config.execution.bids_filters:
overrides.update(config.execution.bids_filters.get('sbref', {}))
sb_ents = {**entities, **overrides}
sbref_files = layout.get(return_type="file", **sb_ents)

sbref_msg = f"No single-band-reference found for {os.path.basename(ref_file)}."
if sbref_files and "sbref" in config.workflow.ignore:
Expand Down Expand Up @@ -319,10 +320,6 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False, existing_derivatives=Non
"anat2std_xfm",
"std2anat_xfm",
"template",
# from bold reference workflow
"bold_ref",
"bold_ref_xfm",
"n_dummy_scans",
# from sdcflows (optional)
"fmap",
"fmap_ref",
Expand Down Expand Up @@ -514,12 +511,21 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False, existing_derivatives=Non
)
bold_confounds_wf.get_node("inputnode").inputs.t1_transform_flags = [False]

dummy_buffer = pe.Node(niu.IdentityInterface(fields=['n_dummy']), name='dummy_buffer')
if (dummy := config.workflow.dummy_scans) is not None:
dummy_buffer.inputs.n_dummy = dummy
else:
# Detect dummy scans
nss_detector = pe.Node(NonsteadyStatesDetector(), name='nss_detector')
nss_detector.inputs.in_file = ref_file
workflow.connect(nss_detector, 'n_dummy', dummy_buffer, 'n_dummy')

# SLICE-TIME CORRECTION (or bypass) #############################################
if run_stc:
bold_stc_wf = init_bold_stc_wf(name="bold_stc_wf", metadata=metadata)
# fmt:off
workflow.connect([
(inputnode, bold_stc_wf, [('n_dummy_scans', 'inputnode.skip_vols')]),
(dummy_buffer, bold_stc_wf, [('n_dummy', 'inputnode.skip_vols')]),
(select_bold, bold_stc_wf, [("out", 'inputnode.bold_file')]),
(bold_stc_wf, boldbuffer, [('outputnode.stc_file', 'bold_file')]),
])
Expand Down Expand Up @@ -577,8 +583,11 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False, existing_derivatives=Non
name="bold_final",
)

# Mask input BOLD reference image
initial_boldref_mask = pe.Node(BrainExtraction(), name="initial_boldref_mask")
# Create a reference image for the bold run
initial_boldref_wf = init_infant_epi_reference_wf(omp_nthreads, is_sbref=bool(sbref_files))
initial_boldref_wf.inputs.inputnode.epi_file = (
pop_file(sbref_files) if sbref_files else ref_file
)

# This final boldref will be calculated after bold_bold_trans_wf, which includes one or more:
# HMC (head motion correction)
Expand All @@ -602,8 +611,8 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False, existing_derivatives=Non
# BOLD buffer has slice-time corrected if it was run, original otherwise
(boldbuffer, bold_split, [('bold_file', 'in_file')]),
# HMC
(inputnode, bold_hmc_wf, [
('bold_ref', 'inputnode.raw_ref_image')]),
(initial_boldref_wf, bold_hmc_wf, [
('outputnode.boldref_file', 'inputnode.raw_ref_image')]),
(validate_bolds, bold_hmc_wf, [
(("out_file", pop_file), 'inputnode.bold_file')]),
(bold_hmc_wf, outputnode, [
Expand Down Expand Up @@ -659,8 +668,8 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False, existing_derivatives=Non
('outputnode.rmsd_file', 'inputnode.rmsd_file')]),
(bold_reg_wf, bold_confounds_wf, [
('outputnode.itk_t1_to_bold', 'inputnode.t1_bold_xform')]),
(inputnode, bold_confounds_wf, [
('n_dummy_scans', 'inputnode.skip_vols')]),
(dummy_buffer, bold_confounds_wf, [
('n_dummy', 'inputnode.skip_vols')]),
(bold_final, bold_confounds_wf, [
('bold', 'inputnode.bold'),
('mask', 'inputnode.bold_mask'),
Expand All @@ -672,7 +681,7 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False, existing_derivatives=Non
('outputnode.tcompcor_mask', 'tcompcor_mask'),
]),
# Summary
(inputnode, summary, [('n_dummy_scans', 'algo_dummy_scans')]),
(dummy_buffer, summary, [('n_dummy', 'algo_dummy_scans')]),
(bold_reg_wf, summary, [('outputnode.fallback', 'fallback')]),
(outputnode, summary, [('confounds', 'confounds_file')]),
# Select echo indices for original/validated BOLD files
Expand Down Expand Up @@ -874,8 +883,8 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False, existing_derivatives=Non
('bold_file', 'inputnode.name_source')]),
(bold_hmc_wf, ica_aroma_wf, [
('outputnode.movpar_file', 'inputnode.movpar_file')]),
(inputnode, ica_aroma_wf, [
('n_dummy_scans', 'inputnode.skip_vols')]),
(dummy_buffer, ica_aroma_wf, [
('n_dummy', 'inputnode.skip_vols')]),
(bold_confounds_wf, join, [
('outputnode.confounds_file', 'in_file')]),
(bold_confounds_wf, mrg_conf_metadata,
Expand Down Expand Up @@ -1051,9 +1060,8 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False, existing_derivatives=Non
("outputnode.bold", "inputnode.in_files"),
]),
] if not multiecho else [
(inputnode, initial_boldref_mask, [('bold_ref', 'in_file')]),
(initial_boldref_mask, bold_t2s_wf, [
("out_mask", "inputnode.bold_mask"),
(initial_boldref_wf, bold_t2s_wf, [
("outputnode.boldref_mask", "inputnode.bold_mask"),
]),
(bold_bold_trans_wf, join_echos, [
("outputnode.bold", "bold_files"),
Expand Down Expand Up @@ -1125,14 +1133,13 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False, existing_derivatives=Non
("fmap_coeff", "inputnode.fmap_coeff"),
("fmap_mask", "inputnode.fmap_mask")]),
(output_select, summary, [("sdc_method", "distortion_correction")]),
(inputnode, initial_boldref_mask, [('bold_ref', 'in_file')]),
(inputnode, coeff2epi_wf, [
("bold_ref", "inputnode.target_ref")]),
(initial_boldref_mask, coeff2epi_wf, [
("out_mask", "inputnode.target_mask")]), # skull-stripped brain
(initial_boldref_wf, coeff2epi_wf, [
("outputnode.boldref_file", "inputnode.target_ref")]),
(initial_boldref_wf, coeff2epi_wf, [
("outputnode.boldref_mask", "inputnode.target_mask")]), # skull-stripped brain
(coeff2epi_wf, unwarp_wf, [
("outputnode.fmap_coeff", "inputnode.fmap_coeff")]),
(inputnode, sdc_report, [("bold_ref", "before")]),
(initial_boldref_wf, sdc_report, [("outputnode.boldref_file", "before")]),
(bold_hmc_wf, unwarp_wf, [
("outputnode.xforms", "inputnode.hmc_xforms")]),
(bold_split, unwarp_wf, [
Expand Down
97 changes: 97 additions & 0 deletions nibabies/workflows/bold/boldref.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import nipype.interfaces.utility as niu
import nipype.pipeline.engine as pe


def init_infant_epi_reference_wf(
omp_nthreads: int,
is_sbref: bool = False,
start_frame: int = 17,
name: str = 'infant_epi_reference_wf',
) -> pe.Workflow:
"""
Workflow to generate a reference map from one or more infant EPI images.

If any single-band references are provided, the reference map will be calculated from those.

If no single-band references are provided, the BOLD files are used.
To account for potential increased motion on the start of image acquisition, this
workflow discards a bigger chunk of the initial frames.

Parameters
----------
omp_nthreads
Maximum number of threads an individual process may use
has_sbref
A single-band reference is provided.
start_frame
BOLD frame to start creating the reference map from. Any earlier frames are discarded.

Inputs
------
bold_file
BOLD EPI file
sbref_file
single-band reference EPI

Outputs
-------
boldref_file
The generated reference map
boldref_mask
Binary brain mask of the ``boldref_file``
boldref_xfm
Rigid-body transforms in LTA format

"""
from niworkflows.workflows.epi.refmap import init_epi_reference_wf
from sdcflows.interfaces.brainmask import BrainExtraction

wf = pe.Workflow(name=name)

inputnode = pe.Node(
niu.IdentityInterface(fields=['epi_file']),
name='inputnode',
)
outputnode = pe.Node(
niu.IdentityInterface(fields=['boldref_file', 'boldref_mask']),
name='outputnode',
)

epi_reference_wf = init_epi_reference_wf(omp_nthreads)

boldref_mask = pe.Node(BrainExtraction(), name='boldref_mask')

# fmt:off
wf.connect([
(inputnode, epi_reference_wf, [('epi_file', 'inputnode.in_files')]),
(epi_reference_wf, boldref_mask, [('outputnode.epi_ref_file', 'in_file')]),
(epi_reference_wf, outputnode, [('outputnode.epi_ref_file', 'boldref_file')]),
(boldref_mask, outputnode, [('out_mask', 'boldref_mask')]),
])
# fmt:on
if not is_sbref:
select_frames = pe.Node(
niu.Function(function=_select_frames, output_names=['t_masks']),
name='select_frames',
)
select_frames.inputs.start_frame = start_frame
# fmt:off
wf.connect([
(inputnode, select_frames, [('epi_file', 'in_file')]),
(select_frames, epi_reference_wf, [('t_masks', 'inputnode.t_masks')]),
])
# fmt:on
return wf


def _select_frames(in_file: str, start_frame: int) -> list:
import nibabel as nb
import numpy as np

img = nb.load(in_file)
img_len = img.shape[3]
if start_frame >= img_len:
start_frame = img_len - 1
t_mask = np.array([False] * img_len, dtype=bool)
t_mask[start_frame:] = True
return list(t_mask)
Loading