Skip to content

Commit

Permalink
♻️ Move get_cpac_provenance and regressor_dct into StratPool
Browse files Browse the repository at this point in the history
  • Loading branch information
shnizzedy committed Jul 17, 2024
1 parent 52c38bf commit f2423a2
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 56 deletions.
2 changes: 1 addition & 1 deletion CPAC/nuisance/nuisance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2681,7 +2681,7 @@ def nuisance_regression(wf, cfg, strat_pool, pipe_num, opt, space, res=None):
outputs : dict
"""
opt = strat_pool.regressor_dct(cfg)
opt = strat_pool.regressor_dct
bandpass = "Bandpass" in opt
bandpass_before = (
bandpass
Expand Down
111 changes: 56 additions & 55 deletions CPAC/pipeline/engine/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,6 @@ def __init__(self) -> None:
self.unique_id: str
self.zscoring_bool: bool
self.wf: pe.Workflow
self._regressor_dct: dict

def __repr__(self) -> str:
"""Return reproducible ResourcePool string."""
Expand Down Expand Up @@ -647,43 +646,6 @@ def get_resource_from_prov(prov: LIST_OF_LIST_OF_STR) -> Optional[str]:
return prov[-1].split(":")[0]
return None

def regressor_dct(self, cfg) -> dict:
"""Return the regressor dictionary for the current strategy if one exists.
Raises KeyError otherwise.
"""
# pylint: disable=attribute-defined-outside-init
if hasattr(self, "_regressor_dct"): # memoized
# pylint: disable=access-member-before-definition
return self._regressor_dct
key_error = KeyError(
"[!] No regressors in resource pool. \n\n"
"Try turning on create_regressors or "
"ingress_regressors."
)
_nr = cfg["nuisance_corrections", "2-nuisance_regression"]
if not hasattr(self, "timeseries"):
if _nr["Regressors"]:
self.regressors = {reg["Name"]: reg for reg in _nr["Regressors"]}
else:
self.regressors = []
if self.check_rpool("parsed_regressors"): # ingressed regressor
# name regressor workflow without regressor_prov
strat_name = _nr["ingress_regressors"]["Regressors"]["Name"]
if strat_name in self.regressors:
self._regressor_dct = self.regressors[strat_name]
return self._regressor_dct
self._regressor_dct = _nr["ingress_regressors"]["Regressors"]
return self._regressor_dct
prov = self.get_cpac_provenance("desc-confounds_timeseries")
strat_name_components = prov[-1].split("_")
for _ in list(range(prov[-1].count("_"))):
reg_name = "_".join(strat_name_components[-_:])
if isinstance(self.regressors, dict) and reg_name in self.regressors:
self._regressor_dct = self.regressors[reg_name]
return self._regressor_dct
raise key_error

def set_data(
self,
resource: str,
Expand Down Expand Up @@ -830,20 +792,6 @@ def get_json(self, resource, strat=None):
raise Exception(msg)
return strat_json

def get_cpac_provenance(
self, resource: list[str] | str, strat: Optional[str | list | tuple] = None
) -> list:
# NOTE: strat_resource has to be entered properly by the developer
# it has to either be rpool[resource][strat] or strat_pool[resource]
if isinstance(resource, list):
for _resource in resource:
try:
return self.get_cpac_provenance(_resource, strat)
except KeyError:
continue
json_data = self.get_json(resource, strat)
return json_data["CpacProvenance"]


class ResourcePool(_Pool):
"""A pool of Resources."""
Expand Down Expand Up @@ -1611,7 +1559,7 @@ def get_strats( # noqa: PLR0912,PLR0915
# make the merged strat label from the multiple inputs
# strat_list is actually the merged CpacProvenance lists
pipe_idx = str(strat_list)
new_strats[pipe_idx] = StratPool(name=pipe_idx)
new_strats[pipe_idx] = StratPool(name=pipe_idx, cfg=self.cfg)
# new_strats is A DICTIONARY OF StratPool OBJECTS!
new_strats[pipe_idx].json = {"CpacProvenance": strat_list}

Expand Down Expand Up @@ -1647,7 +1595,7 @@ def get_strats( # noqa: PLR0912,PLR0915
strat_resource = self.rpool[resource][pipe_idx]
# remember, `strat_resource` is a Resource.
new_strats[pipe_idx] = StratPool(
rpool={resource: strat_resource}, name=pipe_idx
rpool={resource: strat_resource}, name=pipe_idx, cfg=self.cfg
) # <----- again, new_strats is A DICTIONARY OF StratPool OBJECTS!
new_strats[pipe_idx].json = strat_resource.json
new_strats[pipe_idx].json["subjson"] = {}
Expand Down Expand Up @@ -3142,8 +3090,9 @@ class StratPool(_Pool):

def __init__(
self,
rpool: Optional[dict] = None,
cfg: Configuration,
*,
rpool: Optional[dict] = None,
name: str | list[str] = "",
) -> None:
"""Initialize a StratPool."""
Expand All @@ -3153,9 +3102,11 @@ def __init__(
else:
self.rpool = STRAT_DICT(rpool)
self._json: dict[str, dict] = {"subjson": {}}
self.cfg = cfg
if not isinstance(name, list):
name = [name]
self.name: list[str] = name
self._regressor_dct: dict

def append_name(self, name: str) -> None:
"""Append a name to the StratPool."""
Expand Down Expand Up @@ -3256,6 +3207,18 @@ def get_data(self, resource, report_fetched=False):
doc="""Return a deep copy of strategy-specific JSON.""",
)

def get_cpac_provenance(self, resource: list[str] | str) -> list:
"""Get CpacProvenance for a given Resource."""
# NOTE: strat_resource has to be entered properly by the developer
# it has to either be rpool[resource][strat] or strat_pool[resource]
if isinstance(resource, list):
for _resource in resource:
try:
return self.get_cpac_provenance(_resource)
except KeyError:
continue
return self.get(resource).cpac_provenance

def filter_name(self, cfg: Configuration) -> str:
"""
Return the name of the filter for this strategy.
Expand Down Expand Up @@ -3295,6 +3258,44 @@ def preserve_json_info(self, resource: str, strat_resource: Resource) -> None:
self._json["subjson"][data_type] = {}
self._json["subjson"][data_type].update(strat_resource.json)

@property
def regressor_dct(self) -> dict:
"""Return the regressor dictionary for the current strategy if one exists.
Raises KeyError otherwise.
"""
# pylint: disable=attribute-defined-outside-init
if hasattr(self, "_regressor_dct"): # memoized
# pylint: disable=access-member-before-definition
return self._regressor_dct
key_error = KeyError(
"[!] No regressors in resource pool. \n\n"
"Try turning on create_regressors or "
"ingress_regressors."
)
_nr = self.cfg["nuisance_corrections", "2-nuisance_regression"]
if not hasattr(self, "timeseries"):
if _nr["Regressors"]:
self.regressors = {reg["Name"]: reg for reg in _nr["Regressors"]}
else:
self.regressors = []
if self.check_rpool("parsed_regressors"): # ingressed regressor
# name regressor workflow without regressor_prov
strat_name = _nr["ingress_regressors"]["Regressors"]["Name"]
if strat_name in self.regressors:
self._regressor_dct = self.regressors[strat_name]
return self._regressor_dct
self._regressor_dct = _nr["ingress_regressors"]["Regressors"]
return self._regressor_dct
prov = self.get_cpac_provenance("desc-confounds_timeseries")
strat_name_components = prov[-1].split("_")
for _ in list(range(prov[-1].count("_"))):
reg_name = "_".join(strat_name_components[-_:])
if isinstance(self.regressors, dict) and reg_name in self.regressors:
self._regressor_dct = self.regressors[reg_name]
return self._regressor_dct
raise key_error

@property
def filtered_movement(self) -> bool:
"""Check if the movement parameters have been filtered in this StratPool."""
Expand Down

0 comments on commit f2423a2

Please sign in to comment.