From bbf3e97a239c8cf4402048a555cc73e0ae5614b4 Mon Sep 17 00:00:00 2001 From: Jon Clucas Date: Wed, 10 Jul 2024 17:01:03 -0400 Subject: [PATCH] :recycle: Fold `initiate_rpool` into `ResourcePool.__init__` --- .../longitudinal_workflow.py | 40 +- CPAC/pipeline/cpac_pipeline.py | 31 +- CPAC/pipeline/engine/__init__.py | 3 +- CPAC/pipeline/engine/engine.py | 4 +- CPAC/pipeline/engine/resource.py | 406 ++++++++++-------- CPAC/pipeline/test/test_engine.py | 46 +- CPAC/utils/configuration/configuration.py | 7 + 7 files changed, 263 insertions(+), 274 deletions(-) diff --git a/CPAC/longitudinal_pipeline/longitudinal_workflow.py b/CPAC/longitudinal_pipeline/longitudinal_workflow.py index 829e123de4..aacbea6b8d 100644 --- a/CPAC/longitudinal_pipeline/longitudinal_workflow.py +++ b/CPAC/longitudinal_pipeline/longitudinal_workflow.py @@ -29,9 +29,8 @@ build_segmentation_stack, build_T1w_registration_stack, connect_pipeline, - initialize_nipype_wf, ) -from CPAC.pipeline.engine import initiate_rpool +from CPAC.pipeline.engine import ResourcePool from CPAC.pipeline.nodeblock import nodeblock from CPAC.registration import ( create_fsl_flirt_linear_reg, @@ -429,16 +428,13 @@ def anat_longitudinal_wf(subject_id, sub_list, config): except KeyError: input_creds_path = None - workflow = initialize_nipype_wf( - config, - session, - # just grab the first one for the name - name="anat_longitudinal_pre-preproc", + rpool = ResourcePool( + cfg=config, + data_paths=session, + pipeline_name="anat_longitudinal_pre-preproc", ) - - rpool = initiate_rpool(workflow, config, session) pipeline_blocks = build_anat_preproc_stack(rpool, config) - workflow = connect_pipeline(workflow, config, rpool, pipeline_blocks) + workflow = connect_pipeline(rpool.wf, config, rpool, pipeline_blocks) session_wfs[unique_id] = rpool @@ -474,13 +470,6 @@ def anat_longitudinal_wf(subject_id, sub_list, config): ) for strat in strats_brain_dct.keys(): - wf = initialize_nipype_wf( - config, - sub_list[0], - # just grab the first one for the name - name=f"template_node_{strat}", - ) - config.pipeline_setup["pipeline_name"] = f"longitudinal_{orig_pipe_name}" template_node_name = f"longitudinal_anat_template_{strat}" @@ -508,7 +497,9 @@ def anat_longitudinal_wf(subject_id, sub_list, config): template_node.inputs.input_skull_list = strats_head_dct[strat] long_id = f"longitudinal_{subject_id}_strat-{strat}" - rpool = initiate_rpool(wf, config, part_id=long_id) + rpool = ResourcePool( + cfg=config, part_id=long_id, pipeline_name=f"template_node_{strat}" + ) rpool.set_data( "space-longitudinal_desc-brain_T1w", template_node, @@ -551,7 +542,7 @@ def anat_longitudinal_wf(subject_id, sub_list, config): pipeline_blocks = build_segmentation_stack(rpool, config, pipeline_blocks) - wf = connect_pipeline(wf, config, rpool, pipeline_blocks) + wf = connect_pipeline(rpool.wf, config, rpool, pipeline_blocks) excl = [ "space-longitudinal_desc-brain_T1w", @@ -586,10 +577,9 @@ def anat_longitudinal_wf(subject_id, sub_list, config): except KeyError: session["creds_path"] = None - wf = initialize_nipype_wf(config, session) - rpool = initiate_rpool(wf, config, session) - config.pipeline_setup["pipeline_name"] = f"longitudinal_{orig_pipe_name}" + rpool = ResourcePool(cfg=config, data_paths=session) + wf = rpool.wf rpool.ingress_output_dir() select_node_name = f"select_{unique_id}" @@ -651,15 +641,13 @@ def anat_longitudinal_wf(subject_id, sub_list, config): except KeyError: input_creds_path = None session["creds_path"] = input_creds_path - wf = initialize_nipype_wf(config, 
session) - rpool = initiate_rpool(wf, config, session) - + rpool = ResourcePool(cfg=config, data_paths=session) pipeline_blocks = [ warp_longitudinal_T1w_to_template, warp_longitudinal_seg_to_T1w, ] - wf = connect_pipeline(wf, config, rpool, pipeline_blocks) + wf = connect_pipeline(rpool.wf, config, rpool, pipeline_blocks) rpool.gather_pipes(wf, config) diff --git a/CPAC/pipeline/cpac_pipeline.py b/CPAC/pipeline/cpac_pipeline.py index f0baaa323c..9b5ed67141 100644 --- a/CPAC/pipeline/cpac_pipeline.py +++ b/CPAC/pipeline/cpac_pipeline.py @@ -128,9 +128,8 @@ ) # pylint: disable=wrong-import-order -from CPAC.pipeline import nipype_pipeline_engine as pe from CPAC.pipeline.check_outputs import check_outputs -from CPAC.pipeline.engine import initiate_rpool, NodeBlock +from CPAC.pipeline.engine import NodeBlock, ResourcePool from CPAC.pipeline.nipype_pipeline_engine.plugins import ( LegacyMultiProcPlugin, MultiProcPlugin, @@ -856,24 +855,6 @@ def remove_workdir(wdpath: str) -> None: FMLOGGER.warning("Could not remove working directory %s", wdpath) -def initialize_nipype_wf(cfg, sub_data_dct, name=""): - """Initialize a new nipype workflow.""" - if name: - name = f"_{name}" - - workflow_name = ( - f'cpac{name}_{sub_data_dct["subject_id"]}_{sub_data_dct["unique_id"]}' - ) - wf = pe.Workflow(name=workflow_name) - wf.base_dir = cfg.pipeline_setup["working_directory"]["path"] - wf.config["execution"] = { - "hash_method": "timestamp", - "crashdump_dir": os.path.abspath(cfg.pipeline_setup["log_directory"]["path"]), - } - - return wf - - def load_cpac_pipe_config(pipe_config): """Load in pipeline config file.""" config_file = os.path.realpath(pipe_config) @@ -1074,7 +1055,6 @@ def build_T1w_registration_stack(rpool, cfg, pipeline_blocks=None): warp_wholeheadT1_to_template, warp_T1mask_to_template, ] - if not rpool.check_rpool("desc-restore-brain_T1w"): reg_blocks.append(correct_restore_brain_intensity_abcd) @@ -1176,7 +1156,6 @@ def connect_pipeline(wf, cfg, rpool, pipeline_blocks): WFLOGGER.info( "Connecting pipeline blocks:\n%s", list_blocks(pipeline_blocks, indent=1) ) - previous_nb = None for block in pipeline_blocks: try: @@ -1221,9 +1200,6 @@ def build_workflow(subject_id, sub_dict, cfg, pipeline_name=None): """Build a C-PAC workflow for a single subject.""" from CPAC.utils.datasource import gather_extraction_maps - # Workflow setup - wf = initialize_nipype_wf(cfg, sub_dict, name=pipeline_name) - # Extract credentials path if it exists try: creds_path = sub_dict["creds_path"] @@ -1247,8 +1223,7 @@ def build_workflow(subject_id, sub_dict, cfg, pipeline_name=None): # PREPROCESSING # """"""""""""""""""""""""""""""""""""""""""""""""""" - rpool = initiate_rpool(wf, cfg, sub_dict) - + rpool = ResourcePool(cfg=cfg, data_paths=sub_dict, pipeline_name=pipeline_name) pipeline_blocks = build_anat_preproc_stack(rpool, cfg) # Anatomical to T1 template registration @@ -1615,7 +1590,7 @@ def build_workflow(subject_id, sub_dict, cfg, pipeline_name=None): # Connect the entire pipeline! 
try: - wf = connect_pipeline(wf, cfg, rpool, pipeline_blocks) + wf = connect_pipeline(rpool.wf, cfg, rpool, pipeline_blocks) except LookupError as lookup_error: missing_key = None errorstrings = [arg for arg in lookup_error.args[0].split("\n") if arg.strip()] diff --git a/CPAC/pipeline/engine/__init__.py b/CPAC/pipeline/engine/__init__.py index 1350e2bb36..dc1d077656 100644 --- a/CPAC/pipeline/engine/__init__.py +++ b/CPAC/pipeline/engine/__init__.py @@ -21,10 +21,9 @@ run_node_blocks, wrap_block, ) -from .resource import initiate_rpool, NodeData, ResourcePool +from .resource import NodeData, ResourcePool __all__ = [ - "initiate_rpool", "NodeBlock", "NodeData", "ResourcePool", diff --git a/CPAC/pipeline/engine/engine.py b/CPAC/pipeline/engine/engine.py index e6280ace5f..4187476bf7 100644 --- a/CPAC/pipeline/engine/engine.py +++ b/CPAC/pipeline/engine/engine.py @@ -527,7 +527,7 @@ def wrap_block(node_blocks, interface, wf, cfg, strat_pool, pipe_num, opt): def run_node_blocks(blocks, data_paths, cfg=None): from CPAC.pipeline.engine import NodeBlock - from CPAC.pipeline.engine.resource import initiate_rpool + from CPAC.pipeline.engine.resource import ResourcePool if not cfg: cfg = { @@ -540,7 +540,7 @@ def run_node_blocks(blocks, data_paths, cfg=None): # TODO: WE HAVE TO PARSE OVER UNIQUE ID'S!!! wf = pe.Workflow(name="node_blocks") - rpool = initiate_rpool(wf, cfg, data_paths) + rpool = ResourcePool(wf=wf, cfg=cfg, data_paths=data_paths) wf.base_dir = cfg.pipeline_setup["working_directory"]["path"] wf.config["execution"] = { "hash_method": "timestamp", diff --git a/CPAC/pipeline/engine/resource.py b/CPAC/pipeline/engine/resource.py index 5fc9add3db..72a3036fbf 100644 --- a/CPAC/pipeline/engine/resource.py +++ b/CPAC/pipeline/engine/resource.py @@ -17,12 +17,14 @@ """Resources and ResourcePools for C-PAC.""" import ast +from collections.abc import KeysView import copy from itertools import chain import os from pathlib import Path import re -from typing import Optional +from types import NoneType +from typing import Any, Optional import warnings from nipype.interfaces import utility as util @@ -39,7 +41,7 @@ from CPAC.registration.registration import transform_derivative from CPAC.resources.templates.lookup_table import lookup_identifier from CPAC.utils.bids_utils import res_in_filename -from CPAC.utils.configuration import Configuration +from CPAC.utils.configuration.configuration import Configuration, EmptyConfiguration from CPAC.utils.datasource import ( calc_delta_te_and_asym_ratio, check_for_s3, @@ -71,6 +73,78 @@ EXTS = [".nii", ".gz", ".mat", ".1D", ".txt", ".csv", ".rms", ".tsv"] +class DataPaths: + """Store subject-session specific data paths.""" + + def __init__(self, *, data_paths: Optional[dict] = None, part_id: str = "") -> None: + """Initialize a ``DataPaths`` instance.""" + if not data_paths: + data_paths = {} + if part_id and "part_id" in data_paths and part_id != data_paths["part_id"]: + WFLOGGER.warning( + "both 'part_id' (%s) and data_paths['part_id'] (%s) provided. 
" + "Using '%s'.", + part_id, + data_paths["part_id"], + part_id, + ) + anat: dict[str, str] | str = data_paths.get("anat", {}) + if isinstance(anat, str): + anat = {"T1": anat} + self.anat: dict[str, str] = anat + self.creds_path: Optional[str] = data_paths.get("creds_path") + self.fmap: Optional[dict] = data_paths.get("fmap") + self.func: dict[str, dict[str, str | dict]] = data_paths.get("func", {}) + self.part_id: str = data_paths.get("subject_id", "") + self.site_id: str = data_paths.get("site_id", "") + self.ses_id: str = data_paths.get("unique_id", "") + self.unique_id: str = "_".join([self.part_id, self.ses_id]) + self.derivatives_dir: Optional[str] = data_paths.get("derivatives_dir") + + def __repr__(self) -> str: + """Return reproducible string representation of ``DataPaths`` instance.""" + return f"DataPaths(data_paths={self.as_dict()})" + + def __str__(self) -> str: + """Return string representation of a ``DataPaths`` instance.""" + return f"" + + def as_dict(self) -> dict: + """Return ``data_paths`` dictionary. + + data_paths format:: + + {"anat": {"T1w": "{T1w path}", "T2w": "{T2w path}"}, + "creds_path": {None OR path to credentials CSV}, + "func": { + "{scan ID}": { + "scan": "{path to BOLD}", + "scan_parameters": {scan parameter dictionary}, + } + }, + "site_id": "site-ID", + "subject_id": "sub-01", + "unique_id": "ses-1", + "derivatives_dir": "{derivatives_dir path}",} + """ + return { + k: v + for k, v in { + key: getattr(self, key) + for key in [ + "anat", + "creds_path", + "func", + "site_id", + "subject_id", + "unique_id", + "derivatives_dir", + ] + }.items() + if v + } + + def generate_prov_string(prov: list[str]) -> tuple[str, str]: """Generate a string from a SINGLE RESOURCE'S dictionary of MULTIPLE PRECEDING RESOURCES (or single, if just one). 
@@ -253,30 +327,26 @@ class ResourcePool: def __init__( self, - rpool: Optional[dict] = None, name: str = "", cfg: Optional[Configuration] = None, pipe_list: Optional[list] = None, *, - creds_path: Optional[str] = None, - data_paths: Optional[dict] = None, - part_id: Optional[str] = None, - ses_id: Optional[str] = None, - unique_id: Optional[str] = None, + data_paths: Optional[DataPaths | dict] = None, + pipeline_name: str = "", wf: Optional[pe.Workflow] = None, - **kwargs, ): """Initialize a ResourcePool.""" - self.creds_path = creds_path + if isinstance(data_paths, dict): + data_paths = DataPaths(data_paths=data_paths) + elif not data_paths: + data_paths = DataPaths() self.data_paths = data_paths - self.part_id = part_id - self.ses_id = ses_id - self.unique_id = unique_id - self._init_wf = wf - if not rpool: - self.rpool = {} - else: - self.rpool = rpool + # pass-through for convenient access + self.creds_path = self.data_paths.creds_path + self.part_id = self.data_paths.part_id + self.ses_id = self.data_paths.ses_id + self.unique_id = self.data_paths.unique_id + self.rpool = {} if not pipe_list: self.pipe_list = [] @@ -288,36 +358,67 @@ def __init__( if cfg: self.cfg = cfg - self.logdir = cfg.pipeline_setup["log_directory"]["path"] + else: + self.cfg = EmptyConfiguration() - self.num_cpus = cfg.pipeline_setup["system_config"][ - "max_cores_per_participant" + self.logdir = self._config_lookup(["pipeline_setup", "log_directory", "path"]) + self.num_cpus = self._config_lookup( + ["pipeline_setup", "system_config", "max_cores_per_participant"] + ) + self.num_ants_cores = self._config_lookup( + ["pipeline_setup", "system_config", "num_ants_threads"] + ) + + self.ants_interp = self._config_lookup( + [ + "registration_workflows", + "functional_registration", + "func_registration_to_template", + "ANTs_pipelines", + "interpolation", + ] + ) + self.fsl_interp = self._config_lookup( + [ + "registration_workflows", + "functional_registration", + "func_registration_to_template", + "FNIRT_pipelines", + "interpolation", ] - self.num_ants_cores = cfg.pipeline_setup["system_config"][ - "num_ants_threads" + ) + self.func_reg = self._config_lookup( + [ + "registration_workflows", + "functional_registration", + "func_registration_to_template", + "run", ] + ) - self.ants_interp = cfg.registration_workflows["functional_registration"][ - "func_registration_to_template" - ]["ANTs_pipelines"]["interpolation"] - self.fsl_interp = cfg.registration_workflows["functional_registration"][ - "func_registration_to_template" - ]["FNIRT_pipelines"]["interpolation"] - - self.func_reg = cfg.registration_workflows["functional_registration"][ - "func_registration_to_template" - ]["run"] + self.run_smoothing = "smoothed" in self._config_lookup( + ["post_processing", "spatial_smoothing", "output"], list + ) + self.smoothing_bool = self._config_lookup( + ["post_processing", "spatial_smoothing", "run"] + ) + self.run_zscoring = "z-scored" in self._config_lookup( + ["post_processing", "z-scoring", "output"], list + ) + self.zscoring_bool = self._config_lookup( + ["post_processing", "z-scoring", "run"] + ) + self.fwhm = self._config_lookup( + ["post_processing", "spatial_smoothing", "fwhm"] + ) + self.smooth_opts = self._config_lookup( + ["post_processing", "spatial_smoothing", "smoothing_method"] + ) - self.run_smoothing = ( - "smoothed" in cfg.post_processing["spatial_smoothing"]["output"] - ) - self.smoothing_bool = cfg.post_processing["spatial_smoothing"]["run"] - self.run_zscoring = "z-scored" in 
cfg.post_processing["z-scoring"]["output"] - self.zscoring_bool = cfg.post_processing["z-scoring"]["run"] - self.fwhm = cfg.post_processing["spatial_smoothing"]["fwhm"] - self.smooth_opts = cfg.post_processing["spatial_smoothing"][ - "smoothing_method" - ] + if wf: + self.wf = wf + else: + self.initialize_nipype_wf(pipeline_name) self.xfm = [ "alff", @@ -333,6 +434,21 @@ def __init__( "desc-zstd_reho", "desc-sm-zstd_reho", ] + ingress_derivatives = False + try: + if self.data_paths.derivatives_dir and self._config_lookup( + ["pipeline_setup", "outdir_ingress", "run"], bool + ): + ingress_derivatives = True + except (AttributeError, KeyError, TypeError): + pass + if ingress_derivatives: + self.ingress_output_dir() + else: + self.ingress_raw_anat_data() + if data_paths.func: + self.ingress_raw_func_data() + self.ingress_pipeconfig_paths() def __repr__(self) -> str: """Return reproducible ResourcePool string.""" @@ -349,6 +465,27 @@ def __str__(self) -> str: return f"ResourcePool({self.name}): {list(self.rpool)}" return f"ResourcePool: {list(self.rpool)}" + def initialize_nipype_wf(self, name: str = "") -> None: + """Initialize a new nipype workflow.""" + if name: + name = f"_{name}" + workflow_name = f"cpac{name}_{self.unique_id}" + self.wf = pe.Workflow(name=workflow_name) + self.wf.base_dir = self.cfg.pipeline_setup["working_directory"]["path"] + self.wf.config["execution"] = { + "hash_method": "timestamp", + "crashdump_dir": os.path.abspath( + self.cfg.pipeline_setup["log_directory"]["path"] + ), + } + + def _config_lookup(self, keylist, fallback_type: type = NoneType) -> Any: + """Lookup a config key, return None if not found.""" + try: + return self.cfg[keylist] + except (AttributeError, KeyError): + return fallback_type() + def back_propogate_template_name( self, resource_idx: str, json_info: dict, id_string: "pe.Node" ) -> None: @@ -369,7 +506,7 @@ def back_propogate_template_name( if "template" in resource_idx and self.check_rpool("derivatives-dir"): if self.check_rpool("template"): node, out = self.get_data("template") - self._init_wf.connect(node, out, id_string, "template_desc") + self.wf.connect(node, out, id_string, "template_desc") elif "Template" in json_info: id_string.inputs.template_desc = json_info["Template"] elif ( @@ -536,16 +673,12 @@ def set_pool_info(self, info_dct): def get_entire_rpool(self): return self.rpool - def get_resources(self): + def keys(self) -> KeysView: + """Return rpool's keys.""" return self.rpool.keys() - def copy_rpool(self): - return ResourcePool( - rpool=copy.deepcopy(self.get_entire_rpool()), - name=self.name, - cfg=self.cfg, - pipe_list=copy.deepcopy(self.pipe_list), - ) + def get_resources(self): + return self.rpool.keys() @staticmethod def get_raw_label(resource: str) -> str: @@ -863,10 +996,9 @@ def flatten_prov(self, prov): return flat_prov return None - def get_strats(self, resources, debug=False): + def get_strats(self, resources, debug=False) -> dict[str | tuple, "StratPool"]: # TODO: NOTE: NOT COMPATIBLE WITH SUB-RPOOL/STRAT_POOLS # TODO: (and it doesn't have to be) - import itertools linked_resources = [] @@ -952,7 +1084,7 @@ def get_strats(self, resources, debug=False): # we now currently have "strats", the combined permutations of all the strategies, as a list of tuples, each tuple combining one version of input each, being one of the permutations. # OF ALL THE DIFFERENT INPUTS. and they are tagged by their fetched inputs with {name}:{strat}. 
# so, each tuple has ONE STRAT FOR EACH INPUT, so if there are three inputs, each tuple will have 3 items. - new_strats = {} + new_strats: dict[str | tuple, StratPool] = {} # get rid of duplicates - TODO: refactor .product strat_str_list = [] @@ -1055,7 +1187,7 @@ def get_strats(self, resources, debug=False): # make the merged strat label from the multiple inputs # strat_list is actually the merged CpacProvenance lists pipe_idx = str(strat_list) - new_strats[pipe_idx] = ResourcePool() + new_strats[pipe_idx] = StratPool() # new_strats is A DICTIONARY OF RESOURCEPOOL OBJECTS! # placing JSON info at one level higher only for copy convenience new_strats[pipe_idx].rpool["json"] = {} @@ -1098,7 +1230,7 @@ def get_strats(self, resources, debug=False): resource, pipe_idx = generate_prov_string(cpac_prov) resource_strat_dct = self.rpool[resource][pipe_idx] # remember, `resource_strat_dct` is the dct of 'data' and 'json'. - new_strats[pipe_idx] = ResourcePool( + new_strats[pipe_idx] = StratPool( rpool={resource: resource_strat_dct} ) # <----- again, new_strats is A DICTIONARY OF RESOURCEPOOL OBJECTS! # placing JSON info at one level higher only for copy convenience @@ -1429,9 +1561,9 @@ def gather_pipes(self, wf, cfg, all=False, add_incl=None, add_excl=None): # TODO: other stuff like acq- etc. for pipe_idx in self.rpool[resource]: - unique_id = self.get_name() - part_id = unique_id.split("_")[0] - ses_id = unique_id.split("_")[1] + unique_id = self.unique_id + part_id = self.part_id + ses_id = self.ses_id if "ses-" not in ses_id: ses_id = f"ses-{ses_id}" @@ -1819,7 +1951,7 @@ def ingress_freesurfer(self) -> None: def ingress_output_dir(self) -> None: """Ingress an output directory into a ResourcePool.""" - dir_path = self.data_paths["derivatives_dir"] + dir_path = self.data_paths.derivatives_dir WFLOGGER.info("\nPulling outputs from %s.\n", dir_path) @@ -1971,11 +2103,11 @@ def ingress_func_metadata( blip = False fmap_rp_list = [] fmap_TE_list = [] - if "fmap" in self.data_paths: + if self.data_paths.fmap: second = False - for orig_key in self.data_paths["fmap"]: + for orig_key in self.data_paths.fmap: gather_fmap = create_fmap_datasource( - self.data_paths["fmap"], f"fmap_gather_{orig_key}_{self.part_id}" + self.data_paths.fmap, f"fmap_gather_{orig_key}_{self.part_id}" ) gather_fmap.inputs.inputnode.set( subject=self.part_id, @@ -2023,7 +2155,7 @@ def ingress_func_metadata( name=f"{key}_get_metadata{name_suffix}", ) - self._init_wf.connect( + self.wf.connect( gather_fmap, "outputspec.scan_params", get_fmap_metadata, @@ -2140,13 +2272,13 @@ def ingress_func_metadata( node, out_file = self.get(fmap_file)[ f"['{fmap_file}:fmap_TE_ingress']" ]["data"] - self._init_wf.connect( + self.wf.connect( node, out_file, gather_echoes, f"echotime_{idx}" ) except KeyError: pass - self._init_wf.connect( + self.wf.connect( gather_echoes, "echotime_list", calc_delta_ratio, "echo_times" ) @@ -2185,7 +2317,7 @@ def ingress_func_metadata( ) node, out = self.get("scan")["['scan:func_ingress']"]["data"] - self._init_wf.connect(node, out, scan_params, "scan") + self.wf.connect(node, out, scan_params, "scan") # Workaround for extracting metadata with ingress if self.check_rpool("derivatives-dir"): @@ -2198,10 +2330,10 @@ def ingress_func_metadata( ), name="selectrest_json", ) - selectrest_json.inputs.rest_dict = self.data_paths + selectrest_json.inputs.rest_dict = self.data_paths.as_dict() selectrest_json.inputs.resource = "scan_parameters" - self._init_wf.connect(node, out, selectrest_json, "scan") - self._init_wf.connect( 
+ self.wf.connect(node, out, selectrest_json, "scan") + self.wf.connect( selectrest_json, "file_path", scan_params, "data_config_scan_params" ) @@ -2210,7 +2342,7 @@ def ingress_func_metadata( node, out = self.get("scan-params")["['scan-params:scan_params_ingress']"][ "data" ] - self._init_wf.connect(node, out, scan_params, "data_config_scan_params") + self.wf.connect(node, out, scan_params, "data_config_scan_params") self.set_data("TR", scan_params, "tr", {}, "", "func_metadata_ingress") self.set_data( @@ -2242,9 +2374,7 @@ def ingress_func_metadata( node, out_file = self.get("effectiveEchoSpacing")[ "['effectiveEchoSpacing:func_metadata_ingress']" ]["data"] - self._init_wf.connect( - node, out_file, calc_delta_ratio, "effective_echo_spacing" - ) + self.wf.connect(node, out_file, calc_delta_ratio, "effective_echo_spacing") self.set_data( "deltaTE", calc_delta_ratio, "deltaTE", {}, "", "deltaTE_ingress" ) @@ -2372,7 +2502,7 @@ def ingress_pipeconfig_paths(self): def ingress_raw_func_data(self): """Ingress raw functional data.""" - func_paths_dct = self.data_paths["func"] + func_paths_dct = self.data_paths.func func_wf = self.create_func_datasource( func_paths_dct, f"func_ingress_{self.part_id}_{self.ses_id}" @@ -2411,7 +2541,7 @@ def ingress_raw_func_data(self): ] if local_func_scans: # pylint: disable=protected-access - self._init_wf._local_func_scans = local_func_scans + self.wf._local_func_scans = local_func_scans if self.cfg.pipeline_setup["Debugging"]["verbose"]: verbose_logger = getLogger("CPAC.engine") verbose_logger.debug("local_func_scans: %s", local_func_scans) @@ -2464,7 +2594,7 @@ def func_outdir_ingress(self, func_dict: dict, key: str, func_paths: dict) -> No ) iterables.inputs.mask_paths = func_paths[mask_paths_key] iterables.inputs.ts_paths = func_paths[ts_paths_key] - self._init_wf.connect(ingress, "outputspec.scan", iterables, "scan") + self.wf.connect(ingress, "outputspec.scan", iterables, "scan") for key in func_paths: if key in (mask_paths_key, ts_paths_key): @@ -2474,13 +2604,9 @@ def func_outdir_ingress(self, func_dict: dict, key: str, func_paths: dict) -> No creds_path=self.creds_path, dl_dir=self.cfg.pipeline_setup["working_directory"]["path"], ) - self._init_wf.connect( - iterables, "out_scan", ingress_func, "inputnode.scan" - ) + self.wf.connect(iterables, "out_scan", ingress_func, "inputnode.scan") if key == mask_paths_key: - self._init_wf.connect( - iterables, "mask", ingress_func, "inputnode.data" - ) + self.wf.connect(iterables, "mask", ingress_func, "inputnode.data") self.set_data( key, ingress_func, @@ -2490,7 +2616,7 @@ def func_outdir_ingress(self, func_dict: dict, key: str, func_paths: dict) -> No f"outdir_{key}_ingress", ) elif key == ts_paths_key: - self._init_wf.connect( + self.wf.connect( iterables, "confounds", ingress_func, "inputnode.data" ) self.set_data( @@ -2504,19 +2630,15 @@ def func_outdir_ingress(self, func_dict: dict, key: str, func_paths: dict) -> No def ingress_raw_anat_data(self) -> None: """Ingress raw anatomical data.""" - if "anat" not in self.data_paths: + if not self.data_paths.anat: WFLOGGER.warning("No anatomical data present.") return - anat_flow = create_anat_datasource( - f"anat_T1w_gather_{self.part_id}_{self.ses_id}" - ) + anat_flow = create_anat_datasource(f"anat_T1w_gather_{self.unique_id}") anat = {} - if isinstance(self.data_paths["anat"], str): - anat["T1"] = self.data_paths["anat"] - elif "T1w" in self.data_paths["anat"]: - anat["T1"] = self.data_paths["anat"]["T1w"] + if "T1w" in self.data_paths.anat: + anat["T1"] = 
self.data_paths.anat["T1w"] if "T1" in anat: anat_flow.inputs.inputnode.set( @@ -2528,13 +2650,13 @@ def ingress_raw_anat_data(self) -> None: ) self.set_data("T1w", anat_flow, "outputspec.anat", {}, "", "anat_ingress") - if "T2w" in self.data_paths["anat"]: + if "T2w" in self.data_paths.anat: anat_flow_T2 = create_anat_datasource( f"anat_T2w_gather_{self.part_id}_{self.ses_id}" ) anat_flow_T2.inputs.inputnode.set( subject=self.part_id, - anat=self.data_paths["anat"]["T2w"], + anat=self.data_paths.anat["T2w"], creds_path=self.creds_path, dl_dir=self.cfg.pipeline_setup["working_directory"]["path"], img_type="anat", @@ -2547,91 +2669,19 @@ def ingress_raw_anat_data(self) -> None: self.ingress_freesurfer() -def initiate_rpool( - wf: pe.Workflow, - cfg: Configuration, - data_paths: Optional[dict] = None, - part_id: Optional[str] = None, -) -> ResourcePool: - """ - Initialize a new ResourcePool. - - data_paths format:: - - {'anat': { - 'T1w': '{T1w path}', - 'T2w': '{T2w path}' - }, - 'creds_path': {None OR path to credentials CSV}, - 'func': { - '{scan ID}': - { - 'scan': '{path to BOLD}', - 'scan_parameters': {scan parameter dictionary} - } - }, - 'site_id': 'site-ID', - 'subject_id': 'sub-01', - 'unique_id': 'ses-1', - 'derivatives_dir': '{derivatives_dir path}'} - """ - # TODO: refactor further, integrate with the ingress_data functionality - # TODO: used for BIDS-Derivatives (below), and possible refactoring of - # TODO: the raw data config to use 'T1w' label instead of 'anat' etc. - - kwargs = {"cfg": cfg, "wf": wf} - if data_paths: - part_id: str = data_paths["subject_id"] - ses_id: str = data_paths["unique_id"] - if "creds_path" not in data_paths: - creds_path = None - else: - creds_path: Optional[Path | str] = data_paths["creds_path"] - unique_id: str = f"{part_id}_{ses_id}" - kwargs.update( - { - "part_id": part_id, - "ses_id": ses_id, - "creds_path": creds_path, - "data_paths": data_paths, - } - ) - elif part_id: - unique_id = part_id - creds_path = None - kwargs.update({"part_id": part_id, "creds_path": creds_path}) - else: - unique_id = "" - kwargs.update({"unique_id": unique_id}) - - rpool = ResourcePool(name=unique_id, **kwargs) - - if data_paths: - # ingress outdir - try: - if ( - data_paths["derivatives_dir"] - and cfg.pipeline_setup["outdir_ingress"]["run"] - ): - rpool.ingress_output_dir() - except (AttributeError, KeyError): - rpool.ingress_raw_anat_data() - if "func" in data_paths: - rpool.ingress_raw_func_data() - - # grab any file paths from the pipeline config YAML - rpool.ingress_pipeconfig_paths() - - # output files with 4 different scans - - return rpool._init_wf, rpool - - class StratPool(ResourcePool): - """All resources for a strategy.""" + """A pool of ResourcePools keyed by strategy.""" - def __init__(self): - """Initialize a ResourcePool.""" + def __init__(self, rpool: Optional[dict[ResourcePool]] = None) -> None: + """Initialize a StratPool.""" + if not rpool: + self.rpool = {} + else: + self.rpool = rpool def append_name(self, name): self.name.append(name) + + def get_strats(self, resources, debug) -> None: + """ResourcePool method that is not valid for a StratPool.""" + raise NotImplementedError diff --git a/CPAC/pipeline/test/test_engine.py b/CPAC/pipeline/test/test_engine.py index 46df0a2dec..8193fc744d 100644 --- a/CPAC/pipeline/test/test_engine.py +++ b/CPAC/pipeline/test/test_engine.py @@ -24,12 +24,8 @@ build_anat_preproc_stack, build_workflow, connect_pipeline, - initialize_nipype_wf, -) -from CPAC.pipeline.engine import ( - initiate_rpool, - 
ResourcePool, ) +from CPAC.pipeline.engine import ResourcePool from CPAC.utils.bids_utils import create_cpac_data_config from CPAC.utils.configuration import Configuration, Preconfiguration @@ -53,14 +49,8 @@ def test_ingress_func_raw_data( ) -> None: """Test :py:method:~`CPAC.pipeline.engine.resource.ResourcePool.ingress_raw_func_data`.""" cfg, sub_data_dct = _set_up_test(bids_examples, preconfig, tmp_path) - wf = initialize_nipype_wf(cfg, sub_data_dct) - part_id = sub_data_dct["subject_id"] - ses_id = sub_data_dct["unique_id"] - unique_id = f"{part_id}_{ses_id}" - rpool = ResourcePool(name=unique_id, cfg=cfg, data_paths=sub_data_dct, wf=wf) - if "func" in sub_data_dct: - rpool.ingress_raw_func_data() - rpool.gather_pipes(wf, cfg, all=True) + rpool = ResourcePool(cfg=cfg, data_paths=sub_data_dct) + rpool.gather_pipes(rpool.wf, cfg, all=True) @pytest.mark.parametrize("preconfig", ["default"]) @@ -69,21 +59,12 @@ def test_ingress_anat_raw_data( ) -> None: """Test :py:method:~`CPAC.pipeline.engine.resource.ResourcePool.ingress_raw_anat_data`.""" cfg, sub_data_dct = _set_up_test(bids_examples, preconfig, tmp_path) - wf = initialize_nipype_wf(cfg, sub_data_dct) - part_id = sub_data_dct["subject_id"] - ses_id = sub_data_dct["unique_id"] - unique_id = f"{part_id}_{ses_id}" rpool = ResourcePool( - name=unique_id, cfg=cfg, data_paths=sub_data_dct, - unique_id=unique_id, - part_id=part_id, - ses_id=ses_id, - wf=wf, ) rpool.ingress_raw_anat_data() - rpool.gather_pipes(wf, cfg, all=True) + rpool.gather_pipes(rpool.wf, cfg, all=True) @pytest.mark.parametrize("preconfig", ["default"]) @@ -92,20 +73,11 @@ def test_ingress_pipeconfig_data( ) -> None: """Test :py:method:~`CPAC.pipeline.engine.resource.ResourcePool.ingress_pipeconfig_paths`.""" cfg, sub_data_dct = _set_up_test(bids_examples, preconfig, tmp_path) - wf = initialize_nipype_wf(cfg, sub_data_dct) - part_id = sub_data_dct["subject_id"] - ses_id = sub_data_dct["unique_id"] - unique_id = f"{part_id}_{ses_id}" rpool = ResourcePool( - name=unique_id, cfg=cfg, data_paths=sub_data_dct, - part_id=part_id, - ses_id=ses_id, - unique_id=unique_id, ) - rpool.ingress_pipeconfig_paths() - rpool.gather_pipes(wf, cfg, all=True) + rpool.gather_pipes(rpool.wf, cfg, all=True) @pytest.mark.parametrize("preconfig", ["anat-only"]) @@ -115,10 +87,9 @@ def test_build_anat_preproc_stack( """Test :py:func:~`CPAC.pipeline.cpac_pipeline.build_anat_preproc_stack`.""" cfg, sub_data_dct = _set_up_test(bids_examples, preconfig, tmp_path) - wf = initialize_nipype_wf(cfg, sub_data_dct) - rpool = initiate_rpool(wf, cfg, sub_data_dct) + rpool = ResourcePool(cfg=cfg, data_paths=sub_data_dct) pipeline_blocks = build_anat_preproc_stack(rpool, cfg) - wf = connect_pipeline(wf, cfg, rpool, pipeline_blocks) + wf = connect_pipeline(rpool.wf, cfg, rpool, pipeline_blocks) rpool.gather_pipes(wf, cfg) @@ -126,7 +97,6 @@ def test_build_anat_preproc_stack( def test_build_workflow(bids_examples: Path, preconfig: str, tmp_path: Path) -> None: """Test :py:func:~`CPAC.pipeline.cpac_pipeline.build_workflow`.""" cfg, sub_data_dct = _set_up_test(bids_examples, preconfig, tmp_path) - wf = initialize_nipype_wf(cfg, sub_data_dct) - rpool = initiate_rpool(wf, cfg, sub_data_dct) + rpool = ResourcePool(cfg=cfg, data_paths=sub_data_dct) wf = build_workflow(sub_data_dct["subject_id"], sub_data_dct, cfg) rpool.gather_pipes(wf, cfg) diff --git a/CPAC/utils/configuration/configuration.py b/CPAC/utils/configuration/configuration.py index 8444cce105..bcac06df3a 100644 --- 
--- a/CPAC/utils/configuration/configuration.py
+++ b/CPAC/utils/configuration/configuration.py
@@ -622,6 +622,13 @@ def key_type_error(self, key):
         )
 
 
+class EmptyConfiguration(Configuration):
+    """A Configuration with all methods and no values."""
+
+    def __init__(self) -> None:
+        """Initialize an empty configuration."""
+
+
 def check_pname(p_name: str, pipe_config: Configuration) -> str:
     """Check / set `p_name`, the str representation of a pipeline for use in filetrees.
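
A minimal sketch of the call-site migration this commit implies, assuming an already-loaded Configuration `cfg` and a C-PAC subject-session dict `sub_data_dct`; the pipeline name "my_pipeline" is illustrative, and the build/connect helpers shown are the ones this patch itself imports from CPAC.pipeline.cpac_pipeline:

    # Before this patch, the nipype workflow and the resource pool were built separately:
    #   wf = initialize_nipype_wf(cfg, sub_data_dct, name="my_pipeline")
    #   rpool = initiate_rpool(wf, cfg, sub_data_dct)
    #
    # After this patch, ResourcePool.__init__ creates (or accepts) the workflow,
    # ingests raw/derivative data, and exposes the workflow as `rpool.wf`.
    from CPAC.pipeline.cpac_pipeline import build_anat_preproc_stack, connect_pipeline
    from CPAC.pipeline.engine import ResourcePool

    rpool = ResourcePool(
        cfg=cfg,                      # Configuration object
        data_paths=sub_data_dct,      # subject-session dict (or DataPaths instance)
        pipeline_name="my_pipeline",  # optional suffix for the workflow name
    )
    pipeline_blocks = build_anat_preproc_stack(rpool, cfg)
    wf = connect_pipeline(rpool.wf, cfg, rpool, pipeline_blocks)
    rpool.gather_pipes(wf, cfg)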