Skip to content

Commit

Permalink
👔 Handle connections from unconnected graphs that run separately
Browse files Browse the repository at this point in the history
  • Loading branch information
shnizzedy committed Nov 8, 2024
1 parent 4ddb58f commit 0cf793f
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 11 deletions.
31 changes: 21 additions & 10 deletions CPAC/longitudinal_pipeline/longitudinal_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import os
import shutil
import time
from typing import Optional
from typing import cast, Optional

from CPAC.pipeline.nodeblock import nodeblock

Expand Down Expand Up @@ -455,8 +455,8 @@ def anat_longitudinal_wf(subject_id: str, sub_list: list[dict], config: Configur

# Loop over the sessions to create the input for the longitudinal
# algorithm
strats_dct: dict[str, list[tuple[pe.Node, str]]] = {"desc-brain_T1w": [],
"desc-head_T1w": []}
strats_dct: dict[str, list[tuple[pe.Node, str] | str]] = {"desc-brain_T1w": [],
"desc-head_T1w": []}
for i, session in enumerate(sub_list):

unique_id: str = session['unique_id']
Expand Down Expand Up @@ -489,13 +489,16 @@ def anat_longitudinal_wf(subject_id: str, sub_list: list[dict], config: Configur
session_wfs[unique_id] = rpool

rpool.gather_pipes(workflow, config)
for key in strats_dct.keys():
_resource: tuple[pe.Node, str] = rpool.get_data(key)
clone = _resource[0].clone(f"{_resource[0].name}_{session_id_list[i]}")
workflow.copy_input_connections(_resource[0], clone)
strats_dct[key].append((clone, _resource[1]))
if dry_run: # build the graphs with connections that may be in other graphs
for key in strats_dct.keys():
_resource = cast(tuple[pe.Node, str], rpool.get_data(key))
clone = _resource[0].clone(f"{_resource[0].name}_{session_id_list[i]}")
workflow.copy_input_connections(_resource[0], clone)
strats_dct[key].append((clone, _resource[1]))
if not dry_run:
workflow.run()
for key in strats_dct.keys(): # get the outputs from run-nodes
strats_dct[key].append(workflow.get_output_path(key, rpool))

wf = initialize_nipype_wf(config, sub_list[0],
# just grab the first one for the name
Expand Down Expand Up @@ -533,8 +536,8 @@ def anat_longitudinal_wf(subject_id: str, sub_list: list[dict], config: Configur
merge_skulls = pe.Node(Merge(num_sessions), name="merge_skulls")

for i in list(range(0, num_sessions)):
wf.connect(*strats_dct["desc-brain_T1w"][i], merge_brains, f"in{i + 1}")
wf.connect(*strats_dct["desc-head_T1w"][i], merge_skulls, f"in{i + 1}")
_connect_node_or_path(wf, merge_brains, strats_dct, "desc-brain_T1w", i)
_connect_node_or_path(wf, merge_skulls, strats_dct, "desc-head_T1w", i)
wf.connect(merge_brains, "out", template_node, "input_brain_list")
wf.connect(merge_skulls, "out", template_node, "input_skull_list")

Expand Down Expand Up @@ -1198,3 +1201,11 @@ def func_longitudinal_template_wf(subject_id, strat_list, config):
workflow.run()

return

def _connect_node_or_path(wf: pe.Workflow, node: pe.Node, strats_dct: dict[str, list[tuple[pe.Node, str] | str]], key: str, index: int) -> None:
    """Set a Merge-node input from either a (Node, output-name) pair or a path string.

    Parameters
    ----------
    wf : pe.Workflow
        Workflow in which to make a graph connection (unused when the
        resource is already a materialized path string).
    node : pe.Node
        Node (e.g., a ``Merge``) whose ``in{index + 1}`` input is being set.
    strats_dct : dict
        Mapping of resource keys to per-session entries: either
        ``(source_node, output_name)`` tuples or already-run output paths.
    key : str
        Resource key (e.g., ``"desc-brain_T1w"``) into ``strats_dct``.
    index : int
        Zero-based session index; ``Merge`` inputs are 1-based, so this
        sets input ``in{index + 1}``.
    """
    input_name: str = f"in{index + 1}"  # Merge inputs are 1-indexed
    resource = strats_dct[key][index]  # look up once, reuse below
    if isinstance(resource, str):
        # Already-run output: a filesystem path; set it as a static input.
        setattr(node.inputs, input_name, resource)
    else:
        # (Node, output-name) tuple: make a workflow graph connection.
        wf.connect(*resource, node, input_name)
18 changes: 17 additions & 1 deletion CPAC/pipeline/nipype_pipeline_engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,9 @@
import re
from copy import deepcopy
from inspect import Parameter, Signature, signature
from typing import ClassVar, Optional, Union
from typing import ClassVar, Optional, TYPE_CHECKING, Union
from nibabel import load
from nipype.interfaces.base.support import InterfaceResult
from nipype.interfaces.utility import Function
from nipype.pipeline import engine as pe
from nipype.pipeline.engine.utils import (
Expand All @@ -71,6 +72,8 @@
from traits.trait_handlers import TraitListObject
from CPAC.utils.monitoring.custom_logging import getLogger
from CPAC.utils.typing import DICT
if TYPE_CHECKING:
from CPAC.pipeline.engine import ResourcePool

# set global default mem_gb
DEFAULT_MEM_GB = 2.0
Expand Down Expand Up @@ -664,6 +667,19 @@ def _get_dot(
logger.debug("cross connection: %s", dotlist[-1])
return ("\n" + prefix).join(dotlist)

def get_output_path(self, key: str, rpool: "ResourcePool") -> str:
    """Get an output path from an already-run Node.

    Parameters
    ----------
    key : str
        Resource key to look up in the ResourcePool.
    rpool : ResourcePool
        Pool from which to fetch the ``(node, output_name)`` pair for ``key``.

    Returns
    -------
    str
        The value of the named output on the matching already-run Node.

    Raises
    ------
    LookupError
        If no node with a matching ``fullname`` is found among this
        workflow's run nodes.
    """
    _node, _out = rpool.get_data(key)
    # Narrow the untyped return of rpool.get_data for the calls below.
    assert isinstance(_node, pe.Node)
    assert isinstance(_out, str)
    # ``updatehash=True`` re-runs only hash checks and returns the
    # execution graph, whose nodes carry cached results.
    try:
        _run_node: pe.Node = next(
            run_node
            for run_node in self.run(updatehash=True).nodes
            if run_node.fullname == _node.fullname
        )
    except StopIteration as stop:
        msg = f"Could not find {key} in {self}'s run Nodes."
        raise LookupError(msg) from stop
    # The node has already run, so this loads the cached result.
    _res: InterfaceResult = _run_node.run()
    return getattr(_res.outputs, _out)

def _handle_just_in_time_exception(self, node):
# pylint: disable=protected-access
if hasattr(self, '_local_func_scans'):
Expand Down

0 comments on commit 0cf793f

Please sign in to comment.