Skip to content

Commit

Permalink
♻️ Move post_process method back into ResourcePool
Browse files Browse the repository at this point in the history
  • Loading branch information
shnizzedy committed Jul 17, 2024
1 parent 3613f8c commit 4bf5f00
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 148 deletions.
4 changes: 2 additions & 2 deletions CPAC/pipeline/engine/nodeblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
)

if TYPE_CHECKING:
from CPAC.pipeline.engine.resource import Resource, StratPool
from CPAC.pipeline.engine.resource import ResourceData, StratPool

NODEBLOCK_INPUTS = list[str | list | tuple]
PIPELINE_BLOCKS = list["NodeBlockFunction | PIPELINE_BLOCKS"]
Expand Down Expand Up @@ -101,7 +101,7 @@ def __call__(
strat_pool: "StratPool",
pipe_num: Optional[int | str],
opt: Optional[str] = None,
) -> tuple[Workflow, dict[str, "Resource"]]:
) -> tuple[Workflow, dict[str, "ResourceData"]]:
"""Call a NodeBlockFunction.
All node block functions have the same signature.
Expand Down
304 changes: 158 additions & 146 deletions CPAC/pipeline/engine/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
)
from CPAC.pipeline import nipype_pipeline_engine as pe
from CPAC.pipeline.check_outputs import ExpectedOutputs
from CPAC.pipeline.engine.nodeblock import NodeBlock
from CPAC.pipeline.engine.nodeblock import NodeBlock, NodeBlockFunction
from CPAC.pipeline.utils import name_fork, source_set
from CPAC.registration.registration import transform_derivative
from CPAC.resources.templates.lookup_table import lookup_identifier
Expand Down Expand Up @@ -999,149 +999,6 @@ def filter_name(self, cfg: Configuration) -> str:
return sidecar["CpacVariant"][key][0][::-1].split("_", 1)[0][::-1]
return "none"

def post_process(self, wf, label, connection, json_info, pipe_idx, pipe_x, outs):
input_type = "func_derivative"

post_labels = [(label, connection[0], connection[1])]

if re.match(r"(.*_)?[ed]c[bw]$", label) or re.match(r"(.*_)?lfcd[bw]$", label):
# suffix: [eigenvector or degree] centrality [binarized or weighted]
# or lfcd [binarized or weighted]
mask = "template-specification-file"
elif "space-template" in label:
if "space-template_res-derivative_desc-bold_mask" in self.keys():
mask = "space-template_res-derivative_desc-bold_mask"
else:
mask = "space-template_desc-bold_mask"
else:
mask = "space-bold_desc-brain_mask"

mask_idx = None
for entry in json_info["CpacProvenance"]:
if isinstance(entry, list):
if entry[-1].split(":")[0] == mask:
mask_prov = entry
mask_idx = self.generate_prov_string(mask_prov)[1]
break

if self.smoothing_bool:
if label in Outputs.to_smooth:
for smooth_opt in self.smooth_opts:
sm = spatial_smoothing(
f"{label}_smooth_{smooth_opt}_{pipe_x}",
self.fwhm,
input_type,
smooth_opt,
)
wf.connect(connection[0], connection[1], sm, "inputspec.in_file")
node, out = self.get_data(
mask, pipe_idx=mask_idx, quick_single=mask_idx is None
)
wf.connect(node, out, sm, "inputspec.mask")

if "desc-" not in label:
if "space-" in label:
for tag in label.split("_"):
if "space-" in tag:
smlabel = label.replace(tag, f"{tag}_desc-sm")
break
else:
smlabel = f"desc-sm_{label}"
else:
for tag in label.split("_"):
if "desc-" in tag:
newtag = f"{tag}-sm"
smlabel = label.replace(tag, newtag)
break

post_labels.append((smlabel, sm, "outputspec.out_file"))

self.set_data(
smlabel,
sm,
"outputspec.out_file",
json_info,
pipe_idx,
f"spatial_smoothing_{smooth_opt}",
fork=True,
)
self.set_data(
"fwhm",
sm,
"outputspec.fwhm",
json_info,
pipe_idx,
f"spatial_smoothing_{smooth_opt}",
fork=True,
)

if self.zscoring_bool:
for label_con_tpl in post_labels:
label = label_con_tpl[0]
connection = (label_con_tpl[1], label_con_tpl[2])
if label in Outputs.to_zstd:
zstd = z_score_standardize(f"{label}_zstd_{pipe_x}", input_type)

wf.connect(connection[0], connection[1], zstd, "inputspec.in_file")

node, out = self.get_data(mask, pipe_idx=mask_idx)
wf.connect(node, out, zstd, "inputspec.mask")

if "desc-" not in label:
if "space-template" in label:
new_label = label.replace(
"space-template", "space-template_desc-zstd"
)
else:
new_label = f"desc-zstd_{label}"
else:
for tag in label.split("_"):
if "desc-" in tag:
newtag = f"{tag}-zstd"
new_label = label.replace(tag, newtag)
break

post_labels.append((new_label, zstd, "outputspec.out_file"))

self.set_data(
new_label,
zstd,
"outputspec.out_file",
json_info,
pipe_idx,
"zscore_standardize",
fork=True,
)

elif label in Outputs.to_fisherz:
zstd = fisher_z_score_standardize(
f"{label}_zstd_{pipe_x}", label, input_type
)

wf.connect(
connection[0], connection[1], zstd, "inputspec.correlation_file"
)

# if the output is 'space-template_desc-MeanSCA_correlations', we want 'desc-MeanSCA_timeseries'
oned = label.replace("correlations", "timeseries")

node, out = outs[oned]
wf.connect(node, out, zstd, "inputspec.timeseries_oned")

post_labels.append((new_label, zstd, "outputspec.out_file"))

self.set_data(
new_label,
zstd,
"outputspec.out_file",
json_info,
pipe_idx,
"fisher_zscore_standardize",
fork=True,
)

return wf, post_labels


class ResourcePool(_Pool):
"""A pool of Resources."""
Expand Down Expand Up @@ -2800,7 +2657,7 @@ def connect_block(self, wf: pe.Workflow, block: NodeBlock) -> pe.Workflow: # no
inputs: NODEBLOCK_INPUTS = block.check_null(block_dct["inputs"])
outputs = block.check_null(block_dct["outputs"])

block_function = block_dct["block_function"]
block_function: NodeBlockFunction = block_dct["block_function"]

opts = []
if option_key and option_val:
Expand Down Expand Up @@ -3065,7 +2922,9 @@ def connect_block(self, wf: pe.Workflow, block: NodeBlock) -> pe.Workflow: # no

if self.func_reg:
for postlabel in post_labels:
connection = (postlabel[1], postlabel[2]) # noqa: PLW2901
connection = ResourceData( # noqa: PLW2901
postlabel[1], postlabel[2]
)
wf = self.derivative_xfm(
wf,
postlabel[0],
Expand Down Expand Up @@ -3132,6 +2991,159 @@ def connect_pipeline(

return wf

def post_process(
self,
wf: pe.Workflow,
label: str,
connection: ResourceData | tuple[pe.Node, str],
json_info: dict,
pipe_idx: str | tuple,
pipe_x: int,
outs: dict[str, ResourceData],
) -> tuple[pe.Workflow, list[tuple[str, pe.Node | pe.Workflow, str]]]:
"""Connect smoothing and z-scoring, if configured."""
input_type = "func_derivative"

post_labels = [(label, connection[0], connection[1])]

if re.match(r"(.*_)?[ed]c[bw]$", label) or re.match(r"(.*_)?lfcd[bw]$", label):
# suffix: [eigenvector or degree] centrality [binarized or weighted]
# or lfcd [binarized or weighted]
mask = "template-specification-file"
elif "space-template" in label:
if "space-template_res-derivative_desc-bold_mask" in self.keys():
mask = "space-template_res-derivative_desc-bold_mask"
else:
mask = "space-template_desc-bold_mask"
else:
mask = "space-bold_desc-brain_mask"

mask_idx = None
for entry in json_info["CpacProvenance"]:
if isinstance(entry, list):
if entry[-1].split(":")[0] == mask:
mask_prov = entry
mask_idx = self.generate_prov_string(mask_prov)[1]
break

if self.smoothing_bool:
if label in Outputs.to_smooth:
for smooth_opt in self.smooth_opts:
sm = spatial_smoothing(
f"{label}_smooth_{smooth_opt}_{pipe_x}",
self.fwhm,
input_type,
smooth_opt,
)
wf.connect(connection[0], connection[1], sm, "inputspec.in_file")
node, out = self.get_data(
mask, pipe_idx=mask_idx, quick_single=mask_idx is None
)
wf.connect(node, out, sm, "inputspec.mask")

if "desc-" not in label:
if "space-" in label:
for tag in label.split("_"):
if "space-" in tag:
smlabel = label.replace(tag, f"{tag}_desc-sm")
break
else:
smlabel = f"desc-sm_{label}"
else:
for tag in label.split("_"):
if "desc-" in tag:
newtag = f"{tag}-sm"
smlabel = label.replace(tag, newtag)
break

post_labels.append((smlabel, sm, "outputspec.out_file"))

self.set_data(
smlabel,
sm,
"outputspec.out_file",
json_info,
pipe_idx,
f"spatial_smoothing_{smooth_opt}",
fork=True,
)
self.set_data(
"fwhm",
sm,
"outputspec.fwhm",
json_info,
pipe_idx,
f"spatial_smoothing_{smooth_opt}",
fork=True,
)

if self.zscoring_bool:
for label_con_tpl in post_labels:
label = label_con_tpl[0]
connection = (label_con_tpl[1], label_con_tpl[2])
if label in Outputs.to_zstd:
zstd = z_score_standardize(f"{label}_zstd_{pipe_x}", input_type)

wf.connect(connection[0], connection[1], zstd, "inputspec.in_file")

node, out = self.get_data(mask, pipe_idx=mask_idx)
wf.connect(node, out, zstd, "inputspec.mask")

if "desc-" not in label:
if "space-template" in label:
new_label = label.replace(
"space-template", "space-template_desc-zstd"
)
else:
new_label = f"desc-zstd_{label}"
else:
for tag in label.split("_"):
if "desc-" in tag:
newtag = f"{tag}-zstd"
new_label = label.replace(tag, newtag)
break

post_labels.append((new_label, zstd, "outputspec.out_file"))

self.set_data(
new_label,
zstd,
"outputspec.out_file",
json_info,
pipe_idx,
"zscore_standardize",
fork=True,
)

elif label in Outputs.to_fisherz:
zstd = fisher_z_score_standardize(
f"{label}_zstd_{pipe_x}", label, input_type
)

wf.connect(
connection[0], connection[1], zstd, "inputspec.correlation_file"
)

# if the output is 'space-template_desc-MeanSCA_correlations', we want 'desc-MeanSCA_timeseries'
oned = label.replace("correlations", "timeseries")

node, out = outs[oned]
wf.connect(node, out, zstd, "inputspec.timeseries_oned")

post_labels.append((new_label, zstd, "outputspec.out_file"))

self.set_data(
new_label,
zstd,
"outputspec.out_file",
json_info,
pipe_idx,
"fisher_zscore_standardize",
fork=True,
)

return wf, post_labels

def _get_unlabelled(self, resource: str) -> set[str]:
"""Get unlabelled resources (that need integer suffixes to differentiate)."""
from CPAC.func_preproc.func_motion import motion_estimate_filter
Expand Down

0 comments on commit 4bf5f00

Please sign in to comment.