From e7e82b75bb8bec20d59f2c9d07bd19ee85dd97ed Mon Sep 17 00:00:00 2001 From: Jon Clucas Date: Fri, 1 Nov 2024 13:04:59 -0400 Subject: [PATCH] :recycle: SSOT `warp_to_template` --- .../longitudinal_workflow.py | 2 +- CPAC/pipeline/cpac_pipeline.py | 7 +- CPAC/registration/registration.py | 146 +++++++----------- CPAC/registration/utils.py | 33 ++++ 4 files changed, 97 insertions(+), 91 deletions(-) diff --git a/CPAC/longitudinal_pipeline/longitudinal_workflow.py b/CPAC/longitudinal_pipeline/longitudinal_workflow.py index 081374f168..67b52c975c 100644 --- a/CPAC/longitudinal_pipeline/longitudinal_workflow.py +++ b/CPAC/longitudinal_pipeline/longitudinal_workflow.py @@ -499,7 +499,7 @@ def anat_longitudinal_wf(subject_id: str, sub_list: list[dict], config: Configur # Rename nodes to include session name to avoid duplicates for key in strats_dct: for i, resource in enumerate(strats_dct[key]): - resource = ( + strats_dct[key][i] = ( resource[0].clone(f"{resource[0].name}_{session_id_list[i]}"), resource[1]) diff --git a/CPAC/pipeline/cpac_pipeline.py b/CPAC/pipeline/cpac_pipeline.py index 89d49b8d6f..a9df746179 100644 --- a/CPAC/pipeline/cpac_pipeline.py +++ b/CPAC/pipeline/cpac_pipeline.py @@ -98,8 +98,7 @@ coregistration, create_func_to_T1template_xfm, create_func_to_T1template_symmetric_xfm, - warp_wholeheadT1_to_template, - warp_mask_to_template, + warp_to_template, apply_phasediff_to_timeseries_separately, apply_blip_to_timeseries_separately, warp_timeseries_to_T1template, @@ -1045,8 +1044,8 @@ def build_T1w_registration_stack(rpool, cfg, pipeline_blocks=None, reg_blocks = [ [register_ANTs_anat_to_template, register_FSL_anat_to_template], overwrite_transform_anat_to_template, - warp_wholeheadT1_to_template, - warp_mask_to_template(space) + warp_to_template("wholehead", space), + warp_to_template("mask", space) ] if not rpool.check_rpool('desc-restore-brain_T1w'): diff --git a/CPAC/registration/registration.py b/CPAC/registration/registration.py index 6bed6b0b14..a73ff840b2 100644 --- a/CPAC/registration/registration.py +++ b/CPAC/registration/registration.py @@ -33,7 +33,8 @@ hardcoded_reg, \ one_d_to_mat, \ run_c3d, \ - run_c4d + run_c4d, \ + prepend_space from CPAC.utils.interfaces.fsl import Merge as fslMerge from CPAC.utils.typing import LIST_OR_STR, TUPLE from CPAC.utils.utils import check_prov_for_motion_tool, check_prov_for_regtool @@ -3512,83 +3513,53 @@ def apply_blip_to_timeseries_separately(wf, cfg, strat_pool, pipe_num, return (wf, outputs) -@nodeblock( - name="transform_whole_head_T1w_to_T1template", - config=["registration_workflows", "anatomical_registration"], - switch=["run"], - inputs=[ - ( - ["desc-head_T1w", "space-longitudinal_desc-reorient_T1w"], - ["from-T1w_to-template_mode-image_xfm", - "from-longitudinal_to-template_mode-image_xfm"], - "space-template_desc-head_T1w", - ), - "T1w-template", - ], - outputs={"space-template_desc-head_T1w": {"Template": "T1w-template"}}, -) -def warp_wholeheadT1_to_template(wf, cfg, strat_pool, pipe_num, opt=None): - xfm: list[str] = ["from-T1w_to-template_mode-image_xfm", - "from-longitudinal_to-template_mode-image_xfm"] - xfm_prov = strat_pool.get_cpac_provenance(xfm) - reg_tool = check_prov_for_regtool(xfm_prov) - - num_cpus = cfg.pipeline_setup['system_config'][ - 'max_cores_per_participant'] - - num_ants_cores = cfg.pipeline_setup['system_config']['num_ants_threads'] - - apply_xfm = apply_transform(f'warp_wholehead_T1w_to_T1template_{pipe_num}', - reg_tool, time_series=False, num_cpus=num_cpus, - num_ants_cores=num_ants_cores) - - if reg_tool == 'ants': - apply_xfm.inputs.inputspec.interpolation = cfg.registration_workflows[ - 'functional_registration']['func_registration_to_template'][ - 'ANTs_pipelines']['interpolation'] - elif reg_tool == 'fsl': - apply_xfm.inputs.inputspec.interpolation = cfg.registration_workflows[ - 'functional_registration']['func_registration_to_template'][ - 'FNIRT_pipelines']['interpolation'] - - connect = strat_pool.get_data(["desc-head_T1w", - "space-longitudinal_desc-reorient_T1w"]) - node, out = connect - wf.connect(node, out, apply_xfm, 'inputspec.input_image') - - node, out = strat_pool.get_data("T1w-template") - wf.connect(node, out, apply_xfm, 'inputspec.reference') - - node, out = strat_pool.get_data("from-T1w_to-template_mode-image_xfm") - wf.connect(node, out, apply_xfm, 'inputspec.transform') - - outputs = { - 'space-template_desc-head_T1w': (apply_xfm, 'outputspec.output_image') - } +def warp_to_template(warp_what: Literal["mask", "wholehead"], + space_from: Literal["longitudinal", "T1w"]) -> NodeBlockFunction: + """Get a NodeBlockFunction to transform a resource from ``space`` to template. - return (wf, outputs) - -def warp_mask_to_template(space: Literal["longitudinal", "T1w"]) -> NodeBlockFunction: - """Get a NodeBlockFunction to transform a mask from ``space`` to template.""" - @nodeblock( - name=f"transform_{space}-mask_to_T1-template", - switch=[ + The resource being warped needs to be the first list or string in the tuple + in the first position of the decorator's "inputs". + """ + _decorators = {"mask": { + "name": f"transform_{space_from}-mask_to_T1-template", + "switch": [ ["registration_workflows", "anatomical_registration", "run"], ["anatomical_preproc", "run"], ["anatomical_preproc", "brain_extraction", "run"], ], - inputs=[ - (f"space-{space}_desc-brain_mask", - f"from-{space}_to-template_mode-image_xfm"), + "inputs": [ + (f"space-{space_from}_desc-brain_mask", + f"from-{space_from}_to-template_mode-image_xfm"), "T1w-template", ], - outputs={"space-template_desc-brain_mask": {"Template": "T1w-template"}}, - ) - def warp_mask_to_template_fxn(wf, cfg, strat_pool, pipe_num, opt=None): - """Transform a mask to template space.""" + "outputs": {"space-template_desc-brain_mask": {"Template": "T1w-template"}}, + }, "wholehead": { + "name": f"transform_wholehead_{space_from}_to_T1template", + "config": ["registration_workflows", "anatomical_registration"], + "switch": ["run"], + "inputs": [ + ( + ["desc-head_T1w", "desc-reorient_T1w"], + [f"from-{space_from}_to-template_mode-image_xfm", + f"from-{space_from}_to-template_mode-image_xfm"], + "space-template_desc-head_T1w", + ), + "T1w-template", + ], + "outputs": {"space-template_desc-head_T1w": {"Template": "T1w-template"}}, + }} + if space_from != "T1w": + _decorators[warp_what]["inputs"][0] = tuple((prepend_space( + _decorators[warp_what]["inputs"][0][0], space_from), + *_decorators[warp_what]["inputs"][0][1:] + )) + + @nodeblock(**_decorators[warp_what]) + def warp_to_template_fxn(wf, cfg, strat_pool, pipe_num, opt=None): + """Transform a resource to template space.""" xfm_prov = strat_pool.get_cpac_provenance( - f'from-{space}_to-template_mode-image_xfm') + f'from-{space_from}_to-template_mode-image_xfm') reg_tool = check_prov_for_regtool(xfm_prov) num_cpus = cfg.pipeline_setup['system_config'][ @@ -3596,37 +3567,40 @@ def warp_mask_to_template_fxn(wf, cfg, strat_pool, pipe_num, opt=None): num_ants_cores = cfg.pipeline_setup['system_config']['num_ants_threads'] - apply_xfm = apply_transform(f'warp_T1mask_to_T1template_{pipe_num}', - reg_tool, time_series=False, num_cpus=num_cpus, - num_ants_cores=num_ants_cores) + apply_xfm = apply_transform( + f'warp_{space_from}{warp_what}_to_T1template_{pipe_num}', + reg_tool, time_series=False, num_cpus=num_cpus, + num_ants_cores=num_ants_cores) - apply_xfm.inputs.inputspec.interpolation = "NearestNeighbor" - ''' - if reg_tool == 'ants': - apply_xfm.inputs.inputspec.interpolation = cfg.registration_workflows[ - 'functional_registration']['func_registration_to_template'][ - 'ANTs_pipelines']['interpolation'] - elif reg_tool == 'fsl': + if warp_what == "mask": + apply_xfm.inputs.inputspec.interpolation = "NearestNeighbor" + else: + tool = "ANTs" if reg_tool == 'ants' else 'FNIRT' if reg_tool == 'fsl' else None + if not tool: + msg = f"Warp {warp_what} to template not implemented for {reg_tool}." + raise NotImplementedError(msg) apply_xfm.inputs.inputspec.interpolation = cfg.registration_workflows[ 'functional_registration']['func_registration_to_template'][ - 'FNIRT_pipelines']['interpolation'] - ''' - connect = strat_pool.get_data(f"space-{space}_desc-brain_mask") - node, out = connect + f'{tool}_pipelines']['interpolation'] + + # the resource being warped needs to be inputs[0][0] for this + node, out = strat_pool.get_data(_decorators[warp_what]["inputs"][0][0]) wf.connect(node, out, apply_xfm, 'inputspec.input_image') node, out = strat_pool.get_data("T1w-template") wf.connect(node, out, apply_xfm, 'inputspec.reference') - node, out = strat_pool.get_data(f"from-{space}_to-template_mode-image_xfm") + node, out = strat_pool.get_data(f"from-{space_from}_to-template_mode-image_xfm") wf.connect(node, out, apply_xfm, 'inputspec.transform') outputs = { - 'space-template_desc-brain_mask': (apply_xfm, 'outputspec.output_image') + # there's only one output, so that's what we give here + list(_decorators[warp_what]["outputs"].keys())[0]: ( + apply_xfm, 'outputspec.output_image') } return wf, outputs - return warp_mask_to_template_fxn + return warp_to_template_fxn @nodeblock( diff --git a/CPAC/registration/utils.py b/CPAC/registration/utils.py index 1185f0190b..2bf3d62850 100644 --- a/CPAC/registration/utils.py +++ b/CPAC/registration/utils.py @@ -1,4 +1,22 @@ +# Copyright (C) 2014-2024 C-PAC Developers + +# This file is part of C-PAC. + +# C-PAC is free software: you can redistribute it and/or modify it under +# the terms of the GNU Lesser General Public License as published by the +# Free Software Foundation, either version 3 of the License, or (at your +# option) any later version. + +# C-PAC is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public +# License for more details. + +# You should have received a copy of the GNU Lesser General Public +# License along with C-PAC. If not, see . +# pylint: disable=too-many-lines,ungrouped-imports,wrong-import-order import os +from typing import overload import numpy as np @@ -638,3 +656,18 @@ def run_c4d(input, output_name): os.system(cmd) return output1, output2, output3 + + +@overload +def prepend_space(resource: list[str], space: str) -> list[str]: ... +@overload +def prepend_space(resource: str, space: str) -> str: ... +def prepend_space(resource: str | list[str], space: str) -> str | list[str]: + """Given a resource or list of resources, return same but with updated space.""" + if isinstance(resource, list): + return [prepend_space(_, space) for _ in resource] + if "space" not in resource: + return f"space-{space}_{resource}" + pre, post = resource.split("space-") + _old_space, post = post.split("_", 1) + return f"space-{space}_".join([pre, post])