Skip to content

Commit

Permalink
✅ Add tests for fetch refactor
Browse files Browse the repository at this point in the history
[rebuild base-lite]
[rebuild base-standard]
[run reg-suite]
  • Loading branch information
shnizzedy committed Jul 8, 2024
1 parent 6a5b723 commit 3ebb9f4
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 13 deletions.
1 change: 1 addition & 0 deletions .ruff.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ external = ["T20"] # Don't autoremove 'noqa` comments for these rules
[lint.per-file-ignores]
"CPAC/func_preproc/func_preproc.py" = ["E402"]
"CPAC/utils/sklearn.py" = ["RUF003"]
"CPAC/utils/tests/old_functions.py" = ["C", "D", "E", "EM", "PLW", "RET"]
"CPAC/utils/utils.py" = ["T201"] # until `repickle` is removed
"setup.py" = ["D1"]

Expand Down
67 changes: 67 additions & 0 deletions CPAC/utils/tests/old_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright (C) 2012-2024 C-PAC Developers

# This file is part of C-PAC.

# C-PAC is free software: you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the
# Free Software Foundation, either version 3 of the License, or (at your
# option) any later version.

# C-PAC is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.

# You should have received a copy of the GNU Lesser General Public
# License along with C-PAC. If not, see <https://www.gnu.org/licenses/>.
"""Functions from before refactoring."""


def check(params_dct, subject_id, scan_id, val_to_check, throw_exception):
"""https://github.com/FCP-INDI/C-PAC/blob/96db8b0b65ab1d5f55fb3b895855af34d72c17e4/CPAC/utils/utils.py#L630-L653"""
if val_to_check not in params_dct:
if throw_exception:
raise Exception(
f"Missing Value for {val_to_check} for participant " f"{subject_id}"
)
return None
if isinstance(params_dct[val_to_check], dict):
ret_val = params_dct[val_to_check][scan_id]
else:
ret_val = params_dct[val_to_check]
if ret_val == "None":
if throw_exception:
raise Exception(
f"'None' Parameter Value for {val_to_check} for participant "
f"{subject_id}"
)
else:
ret_val = None
if ret_val == "" and throw_exception:
raise Exception(
f"Missing Value for {val_to_check} for participant " f"{subject_id}"
)
return ret_val


def check2(val):
"""https://github.com/FCP-INDI/C-PAC/blob/96db8b0b65ab1d5f55fb3b895855af34d72c17e4/CPAC/utils/utils.py#L745-L746"""
return val if val == None or val == "" or isinstance(val, str) else int(val)


def try_fetch_parameter(scan_parameters, subject, scan, keys):
"""https://github.com/FCP-INDI/C-PAC/blob/96db8b0b65ab1d5f55fb3b895855af34d72c17e4/CPAC/utils/utils.py#L679-L703"""
scan_parameters = dict((k.lower(), v) for k, v in scan_parameters.items())
for key in keys:
key = key.lower()
if key not in scan_parameters:
continue
if isinstance(scan_parameters[key], dict):
value = scan_parameters[key][scan]
else:
value = scan_parameters[key]
if value == "None":
return None
if value is not None:
return value
return None
52 changes: 47 additions & 5 deletions CPAC/utils/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from CPAC.pipeline.nodeblock import NodeBlockFunction
from CPAC.utils.configuration import Configuration
from CPAC.utils.monitoring.custom_logging import log_subprocess
from CPAC.utils.tests import old_functions
from CPAC.utils.utils import (
check_config_resources,
check_system_deps,
Expand All @@ -30,11 +31,19 @@
"tr": 2.5,
"acquisition": "seq+z",
"reference": "24",
"first_tr": "",
"last_tr": "",
"first_TR": 1,
"last_TR": "",
},
"expected_TR": 2.5,
},
"nested": {
"params": {
"TR": {"scan": 3},
"first_TR": {"scan": 0},
"last_TR": {"scan": 450},
},
"expected_TR": 3,
},
}


Expand Down Expand Up @@ -78,7 +87,7 @@ def test_check_config_resources():
assert "threads available (2)" in error_string


@pytest.mark.parametrize("scan_params", ["BIDS", "C-PAC"])
@pytest.mark.parametrize("scan_params", ["BIDS", "C-PAC", "nested"])
@pytest.mark.parametrize("convert_to", [int, float, str])
def test_fetch_and_convert(
caplog: LogCaptureFixture, scan_params: str, convert_to: type
Expand All @@ -89,8 +98,25 @@ def test_fetch_and_convert(
keys=["TR", "RepetitionTime"],
convert_to=convert_to,
)
assert (TR == convert_to(SCAN_PARAMS[scan_params]["expected_TR"])) and isinstance(
TR, convert_to
if TR and "RepetitionTime" in params.params:
old_TR = convert_to(
old_functions.check(
params.params, params.subject, params.scan, "RepetitionTime", False
)
)
assert TR == old_TR
try:
old_TR = convert_to(
old_functions.try_fetch_parameter(
params.params, params.subject, params.scan, ["TR", "RepetitionTime"]
)
)
except TypeError:
old_TR = None
assert (
(TR == convert_to(SCAN_PARAMS[scan_params]["expected_TR"]))
and isinstance(TR, convert_to)
and TR == old_TR
)
if scan_params == "C-PAC":
assert "Using case-insenitive match: 'TR' ≅ 'tr'." in caplog.text
Expand All @@ -101,6 +127,22 @@ def test_fetch_and_convert(
convert_to=convert_to,
)
assert not_TR is None
if "first_TR" in params.params:
first_tr = params.fetch_and_convert(["first_TR"], int, 1, False)
old_first_tr = old_functions.check(
params.params, params.subject, params.scan, "first_TR", False
)
if old_first_tr:
old_first_tr = old_functions.check2(old_first_tr)
assert first_tr == old_first_tr
if "last_TR" in params.params:
last_tr = params.fetch_and_convert(["last_TR"], int, "", False)
old_last_tr = old_functions.check(
params.params, params.subject, params.scan, "last_TR", False
)
if old_last_tr:
old_last_tr = old_functions.check2(old_last_tr)
assert last_tr == old_last_tr


@pytest.mark.parametrize("executable", ["Xvfb"])
Expand Down
27 changes: 19 additions & 8 deletions CPAC/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,15 +533,17 @@ def fetch(
keys: Optional[list[str]] = None,
*,
match_case: Literal[False],
throw_exception: bool,
) -> Any: ...
@overload
def fetch(
self,
keys: Optional[list[str]] = None,
*,
match_case: Literal[True],
throw_exception: bool,
) -> tuple[Any, tuple[str, str]]: ...
def fetch(self, keys, *, match_case=False):
def fetch(self, keys, *, match_case=False, throw_exception=True):
"""Fetch the first found parameter from a scan params dictionary.
Returns
Expand All @@ -551,6 +553,9 @@ def fetch(self, keys, *, match_case=False):
keys, optional
The matched keys (only if ``match_case is True``)
throw_exception
Raise an exception if value is ``""`` or ``None``?
"""
if match_case:
keys = {key.lower(): key for key in keys}
Expand All @@ -561,11 +566,11 @@ def fetch(self, keys, *, match_case=False):
for key in keys:
if key in scan_parameters:
if match_case:
return self.check(key, True), (
return self.check(key, throw_exception), (
keys[key],
scan_param_keys[key],
)
return self.check(key, True)
return self.check(key, throw_exception)
msg = f"None of {keys} found in {list(scan_parameters.keys())}."
raise KeyError(msg)

Expand All @@ -575,6 +580,7 @@ def fetch_and_convert(
convert_to: Optional[type] = None,
fallback: Optional[Any] = None,
warn_typeerror: bool = True,
throw_exception: bool = False,
) -> Any:
"""Fetch a parameter from a scan params dictionary and convert it to a given type.
Expand All @@ -595,6 +601,9 @@ def fetch_and_convert(
warn_typeerror
log a warning if value cannot be converted to ``convert_to`` type?
throw_exception
raise an error for empty string or NoneTypes?
Returns
-------
value
Expand All @@ -605,10 +614,12 @@ def fetch_and_convert(
fallback_message = f"Falling back to {fallback} ({type(fallback)})."

try:
raw_value = self.fetch(keys)
raw_value = self.fetch(keys, throw_exception=throw_exception)
except KeyError:
try:
raw_value, matched_keys = self.fetch(keys, match_case=True)
raw_value, matched_keys = self.fetch(
keys, match_case=True, throw_exception=throw_exception
)
except KeyError:
WFLOGGER.warning(
f"None of {keys} found in {list(self.params.keys())}. "
Expand All @@ -622,7 +633,7 @@ def fetch_and_convert(
if convert_to:
try:
value = convert_to(raw_value)
except TypeError:
except (TypeError, ValueError):
if warn_typeerror:
WFLOGGER.warning(
f"Could not convert {value} to {convert_to}. {fallback_message}"
Expand Down Expand Up @@ -820,10 +831,10 @@ def get_scan_params(
)
ref_slice: Optional[int | str] = params.fetch_and_convert(["reference"], int, None)
first_tr: Optional[int | str] = params.fetch_and_convert(
["first_TR"], int, pipeconfig_start_indx
["first_TR"], int, pipeconfig_start_indx, False
)
last_tr: Optional[int | str] = params.fetch_and_convert(
["last_TR"], int, pipeconfig_stop_indx
["last_TR"], int, pipeconfig_stop_indx, False
)
pe_direction: PE_DIRECTION = params.fetch_and_convert(
["PhaseEncodingDirection"], str, ""
Expand Down

0 comments on commit 3ebb9f4

Please sign in to comment.