From 0cf793fa2cf983e7f8386a3b5ed15c5aa39bb666 Mon Sep 17 00:00:00 2001 From: Jon Clucas Date: Fri, 8 Nov 2024 17:15:48 -0500 Subject: [PATCH] :necktie: Handle connections from unconnected graphs that run separately --- .../longitudinal_workflow.py | 31 +++++++++++++------ .../pipeline/nipype_pipeline_engine/engine.py | 18 ++++++++++- 2 files changed, 38 insertions(+), 11 deletions(-) diff --git a/CPAC/longitudinal_pipeline/longitudinal_workflow.py b/CPAC/longitudinal_pipeline/longitudinal_workflow.py index f8dd8f183..4cd9b61da 100644 --- a/CPAC/longitudinal_pipeline/longitudinal_workflow.py +++ b/CPAC/longitudinal_pipeline/longitudinal_workflow.py @@ -19,7 +19,7 @@ import os import shutil import time -from typing import Optional +from typing import cast, Optional from CPAC.pipeline.nodeblock import nodeblock @@ -455,8 +455,8 @@ def anat_longitudinal_wf(subject_id: str, sub_list: list[dict], config: Configur # Loop over the sessions to create the input for the longitudinal # algorithm - strats_dct: dict[str, list[tuple[pe.Node, str]]] = {"desc-brain_T1w": [], - "desc-head_T1w": []} + strats_dct: dict[str, list[tuple[pe.Node, str] | str]] = {"desc-brain_T1w": [], + "desc-head_T1w": []} for i, session in enumerate(sub_list): unique_id: str = session['unique_id'] @@ -489,13 +489,16 @@ def anat_longitudinal_wf(subject_id: str, sub_list: list[dict], config: Configur session_wfs[unique_id] = rpool rpool.gather_pipes(workflow, config) - for key in strats_dct.keys(): - _resource: tuple[pe.Node, str] = rpool.get_data(key) - clone = _resource[0].clone(f"{_resource[0].name}_{session_id_list[i]}") - workflow.copy_input_connections(_resource[0], clone) - strats_dct[key].append((clone, _resource[1])) + if dry_run: # build tbe graphs with connections that may be in other graphs + for key in strats_dct.keys(): + _resource = cast(tuple[pe.Node, str], rpool.get_data(key)) + clone = _resource[0].clone(f"{_resource[0].name}_{session_id_list[i]}") + workflow.copy_input_connections(_resource[0], clone) + strats_dct[key].append((clone, _resource[1])) if not dry_run: workflow.run() + for key in strats_dct.keys(): # get the outputs from run-nodes + strats_dct[key].append(workflow.get_output_path(key, rpool)) wf = initialize_nipype_wf(config, sub_list[0], # just grab the first one for the name @@ -533,8 +536,8 @@ def anat_longitudinal_wf(subject_id: str, sub_list: list[dict], config: Configur merge_skulls = pe.Node(Merge(num_sessions), name="merge_skulls") for i in list(range(0, num_sessions)): - wf.connect(*strats_dct["desc-brain_T1w"][i], merge_brains, f"in{i + 1}") - wf.connect(*strats_dct["desc-head_T1w"][i], merge_skulls, f"in{i + 1}") + _connect_node_or_path(wf, merge_brains, strats_dct, "desc-brain_T1w", i) + _connect_node_or_path(wf, merge_skulls, strats_dct, "desc-head_T1w", i) wf.connect(merge_brains, "out", template_node, "input_brain_list") wf.connect(merge_skulls, "out", template_node, "input_skull_list") @@ -1198,3 +1201,11 @@ def func_longitudinal_template_wf(subject_id, strat_list, config): workflow.run() return + +def _connect_node_or_path(wf: pe.Workflow, node: pe.Node, strats_dct: dict[str, list[tuple[pe.Node, str] | str]], key: str, index: int) -> None: + """Set input appropriately for either a Node or a path string.""" + input: str = f"in{index + 1}" + if isinstance(strats_dct[key][index], str): + setattr(node.inputs, input, strats_dct[key][index]) + else: + wf.connect(*strats_dct[key][index], node, input) diff --git a/CPAC/pipeline/nipype_pipeline_engine/engine.py b/CPAC/pipeline/nipype_pipeline_engine/engine.py index c899fbc12..ca2fefa73 100644 --- a/CPAC/pipeline/nipype_pipeline_engine/engine.py +++ b/CPAC/pipeline/nipype_pipeline_engine/engine.py @@ -51,8 +51,9 @@ import re from copy import deepcopy from inspect import Parameter, Signature, signature -from typing import ClassVar, Optional, Union +from typing import ClassVar, Optional, TYPE_CHECKING, Union from nibabel import load +from nipype.interfaces.base.support import InterfaceResult from nipype.interfaces.utility import Function from nipype.pipeline import engine as pe from nipype.pipeline.engine.utils import ( @@ -71,6 +72,8 @@ from traits.trait_handlers import TraitListObject from CPAC.utils.monitoring.custom_logging import getLogger from CPAC.utils.typing import DICT +if TYPE_CHECKING: + from CPAC.pipeline.engine import ResourcePool # set global default mem_gb DEFAULT_MEM_GB = 2.0 @@ -664,6 +667,19 @@ def _get_dot( logger.debug("cross connection: %s", dotlist[-1]) return ("\n" + prefix).join(dotlist) + def get_output_path(self, key: str, rpool: "ResourcePool") -> str: + """Get an output path from an already-run Node.""" + _node, _out = rpool.get_data(key) + assert isinstance(_node, pe.Node) + assert isinstance(_out, str) + try: + _run_node: pe.Node = [_ for _ in self.run(updatehash=True).nodes if _.fullname == _node.fullname][0] + except IndexError as index_error: + msg = f"Could not find {key} in {self}'s run Nodes." + raise LookupError(msg) from index_error + _res: InterfaceResult = _run_node.run() + return getattr(_res.outputs, _out) + def _handle_just_in_time_exception(self, node): # pylint: disable=protected-access if hasattr(self, '_local_func_scans'):