Skip to content

Commit

Permalink
♻️ Move _check_null from method to private function
Browse files Browse the repository at this point in the history
  • Loading branch information
shnizzedy committed Jul 17, 2024
1 parent 0d848b9 commit 7571455
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 18 deletions.
16 changes: 7 additions & 9 deletions CPAC/pipeline/engine/nodeblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from CPAC.pipeline.engine.resource import ResourceData, StratPool

NODEBLOCK_INPUTS = list[str | list | tuple]
NODEBLOCK_OUTPUTS = list[str] | dict[str, Any]
PIPELINE_BLOCKS = list["NodeBlockFunction | PIPELINE_BLOCKS"]


Expand All @@ -46,7 +47,7 @@ def __init__(
option_key: Optional[str | list[str]] = None,
option_val: Optional[str | list[str]] = None,
inputs: Optional[NODEBLOCK_INPUTS] = None,
outputs: Optional[list[str] | dict[str, Any]] = None,
outputs: Optional[NODEBLOCK_OUTPUTS] = None,
) -> None:
self.func = func
"""Nodeblock function reference."""
Expand All @@ -70,9 +71,7 @@ def __init__(
"""
self.option_val: Optional[str | list[str]] = option_val
"""Indicates values for which this NodeBlock should be active."""
if inputs is None:
inputs = []
self.inputs: list[str | list | tuple] = inputs
self.inputs: list[str | list | tuple] = inputs if inputs else []
"""ResourcePool keys indicating resources needed for the NodeBlock's functionality."""
self.outputs: list[str] | dict[str, Any] = outputs if outputs else []
"""
Expand Down Expand Up @@ -218,12 +217,11 @@ def __init__(
config.update_config({"logging": {"workflow_level": "INFO"}})
logging.update_logging(config)

def check_null(self, val):
if isinstance(val, str):
val = None if val.lower() == "none" else val
return val
def check_output(self, outputs: NODEBLOCK_OUTPUTS, label: str, name: str) -> None:
"""Check if a label is listed in a NodeBlock's ``outputs``.
def check_output(self, outputs, label, name):
Raises ``NameError`` if a mismatch is found.
"""
if label not in outputs:
msg = (
f'\n[!] Output name "{label}" in the block '
Expand Down
28 changes: 19 additions & 9 deletions CPAC/pipeline/engine/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,12 @@
)
from CPAC.pipeline import nipype_pipeline_engine as pe
from CPAC.pipeline.check_outputs import ExpectedOutputs
from CPAC.pipeline.engine.nodeblock import NodeBlock, NodeBlockFunction
from CPAC.pipeline.engine.nodeblock import (
NodeBlock,
NODEBLOCK_INPUTS,
NODEBLOCK_OUTPUTS,
NodeBlockFunction,
)
from CPAC.pipeline.utils import name_fork, source_set
from CPAC.registration.registration import transform_derivative
from CPAC.resources.templates.lookup_table import lookup_identifier
Expand Down Expand Up @@ -2341,8 +2346,6 @@ def ingress_raw_anat_data(self) -> None:

def connect_block(self, wf: pe.Workflow, block: NodeBlock) -> pe.Workflow: # noqa: PLR0912,PLR0915
"""Connect a NodeBlock via the ResourcePool."""
from CPAC.pipeline.engine.nodeblock import NODEBLOCK_INPUTS

debug = bool(self.cfg.pipeline_setup["Debugging"]["verbose"]) # type: ignore [attr-defined]
all_opts: list[str] = []

Expand All @@ -2360,12 +2363,12 @@ def connect_block(self, wf: pe.Workflow, block: NodeBlock) -> pe.Workflow: # no

for name, block_dct in block.node_blocks.items():
# iterates over either the single node block in the sequence, or a list of node blocks within the list of node blocks, i.e. for option forking.
switch = block.check_null(block_dct["switch"])
config = block.check_null(block_dct["config"])
option_key = block.check_null(block_dct["option_key"])
option_val = block.check_null(block_dct["option_val"])
inputs: NODEBLOCK_INPUTS = block.check_null(block_dct["inputs"])
outputs = block.check_null(block_dct["outputs"])
switch = _check_null(block_dct["switch"])
config = _check_null(block_dct["config"])
option_key = _check_null(block_dct["option_key"])
option_val = _check_null(block_dct["option_val"])
inputs: NODEBLOCK_INPUTS = _check_null(block_dct["inputs"])
outputs: NODEBLOCK_OUTPUTS = _check_null(block_dct["outputs"])

block_function: NodeBlockFunction = block_dct["block_function"]

Expand Down Expand Up @@ -3248,3 +3251,10 @@ def filtered_movement(self) -> bool:
except KeyError:
# not a strat_pool or no movement parameters in strat_pool
return False


def _check_null(val: Any) -> Any:
"""Return ``None`` if ``val`` == "none" (case-insensitive)."""
if isinstance(val, str):
val = None if val.lower() == "none" else val
return val

0 comments on commit 7571455

Please sign in to comment.