Skip to content

Commit

Permalink
👔 Use max_iter parameter for longitudinal template generation.
Browse files Browse the repository at this point in the history
  • Loading branch information
shnizzedy committed Nov 13, 2024
1 parent 9d09358 commit dc4c2df
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 78 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Required positional parameter "wf" in input and output of `ingress_pipeconfig_paths` function, where a node to reorient templates is added to the `wf`.
- Required positional parameter "orientation" to `resolve_resolution`.
- Optional positional argument "cfg" to `create_lesion_preproc`.
- Added `mri_robust_template` for longitudinal template generation.
- `mri_robust_template` for longitudinal template generation.
- `max_iter` parameter for longitudinal template generation.

### Changed

Expand Down
85 changes: 46 additions & 39 deletions CPAC/longitudinal/preproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@

from collections import Counter
from multiprocessing.dummy import Pool as ThreadPool
from multiprocessing.pool import Pool
import os
from typing import Literal, Optional

import numpy as np
import nibabel as nib
Expand Down Expand Up @@ -131,27 +133,23 @@ def norm_transformation(input_mat):


def template_convergence(
mat_file, mat_type="matrix", convergence_threshold=np.finfo(np.float64).eps
):
mat_file: str,
mat_type: Literal["matrix", "ITK"] = "matrix",
convergence_threshold: float | np.float64 = np.finfo(np.float64).eps,
) -> bool:
"""Check that the deistance between matrices is smaller than the threshold.
Calculate the distance between transformation matrix with a matrix of no transformation.
Parameters
----------
mat_file : str
mat_file
path to an fsl flirt matrix
mat_type : str
'matrix'(default), 'ITK'
mat_type
The type of matrix used to represent the transformations
convergence_threshold : float
(numpy.finfo(np.float64).eps (default)) threshold for the convergence
convergence_threshold
The threshold is how different from no transformation is the
transformation matrix.
Returns
-------
bool
"""
if mat_type == "matrix":
translation, oth_transform = read_mat(mat_file)
Expand Down Expand Up @@ -347,50 +345,51 @@ def flirt_node(in_img, output_img, output_mat):


def template_creation_flirt(
input_brain_list,
input_skull_list,
init_reg=None,
avg_method="median",
dof=12,
interp="trilinear",
cost="corratio",
mat_type="matrix",
convergence_threshold=-1,
thread_pool=2,
unique_id_list=None,
):
input_brain_list: list[str],
input_skull_list: list[str],
init_reg: Optional[list[pe.Node]] = None,
avg_method: Literal["median", "mean", "std"] = "median",
dof: Literal[12, 9, 7, 6] = 12,
interp: Literal["trilinear", "nearestneighbour", "sinc", "spline"] = "trilinear",
cost: Literal[
"corratio", "mutualinfo", "normmi", "normcorr", "leastsq", "labeldiff", "bbr"
] = "corratio",
mat_type: Literal["matrix", "ITK"] = "matrix",
convergence_threshold: float | np.float64 = -1,
max_iter: int = 5,
thread_pool: int | Pool = 2,
unique_id_list: Optional[list[str]] = None,
) -> tuple[str, str, list[str], list[str], list[str]]:
"""Create a temporary template from a list of images.
Parameters
----------
input_brain_list : list of str
input_brain_list
list of brain images paths
input_skull_list : list of str
input_skull_list
list of skull images paths
init_reg : list of Node
init_reg
(default None so no initial registration performed)
the output of the function register_img_list with another reference
Reuter et al. 2012 (NeuroImage) section "Improved template estimation"
doi:10.1016/j.neuroimage.2012.02.084 uses a ramdomly
selected image from the input dataset
avg_method : str
function names from numpy library such as 'median', 'mean', 'std' ...
dof : integer (int of long)
number of transform degrees of freedom (FLIRT) (12 by default)
interp : str
('trilinear' (default) or 'nearestneighbour' or 'sinc' or 'spline')
avg_method
function names from numpy library
dof
number of transform degrees of freedom (FLIRT)
interp
final interpolation method used in reslicing
cost : str
('mutualinfo' or 'corratio' (default) or 'normcorr' or 'normmi' or
'leastsq' or 'labeldiff' or 'bbr')
cost
cost function
mat_type : str
'matrix'(default), 'ITK'
mat_type
The type of matrix used to represent the transformations
convergence_threshold : float
convergence_threshold
(numpy.finfo(np.float64).eps (default)) threshold for the convergence
The threshold is how different from no transformation is the
transformation matrix.
max_iter
Maximum number of iterations if transformation does not converge
thread_pool : int or multiprocessing.dummy.Pool
(default 2) number of threads. You can also provide a Pool so the
node will be added to it to be run.
Expand Down Expand Up @@ -496,7 +495,14 @@ def template_creation_flirt(
and the loop stops when this temporary template is close enough (with a transformation
distance smaller than the threshold) to all the images of the precedent iteration.
"""
while not converged:
iterator = 1
iteration = 0
if max_iter == -1:
# make iteration < max_iter always True
iterator = 0
iteration = -2
while not converged and iteration < max_iter:
iteration += iterator
temporary_brain_template, temporary_skull_template = create_temporary_template(
input_brain_list=output_brain_list,
input_skull_list=output_skull_list,
Expand Down Expand Up @@ -628,6 +634,7 @@ def subject_specific_template(
"cost",
"mat_type",
"convergence_threshold",
"max_iter",
"thread_pool",
"unique_id_list",
],
Expand Down
5 changes: 3 additions & 2 deletions CPAC/longitudinal/wf/anat.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,8 +387,9 @@ def anat_longitudinal_wf(
interp=config.longitudinal_template_generation["legacy-specific"]["interp"],
cost=config.longitudinal_template_generation["legacy-specific"]["cost"],
convergence_threshold=config.longitudinal_template_generation[
"convergence_threshold"
],
"legacy-specific"
]["convergence_threshold"],
max_iter=config.longitudinal_template_generation["max_iter"],
thread_pool=config.longitudinal_template_generation["legacy-specific"][
"thread_pool"
],
Expand Down
77 changes: 51 additions & 26 deletions CPAC/pipeline/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from itertools import chain, permutations
import re
from subprocess import CalledProcessError
from typing import Any as AnyType

import numpy as np
from pathvalidate import sanitize_filename
Expand Down Expand Up @@ -851,28 +852,34 @@ def sanitize(filename):
"using": In({"mri_robust_template", "C-PAC legacy"}),
"average_method": In({"median", "mean", "std"}),
"dof": In({12, 9, 7, 6}),
"convergence_threshold": Number,
"max_iter": int,
"max_iter": Any(
All(Number, Range(min=0, min_included=False)), In([-1, "default"])
),
"legacy-specific": Maybe(
{
"interp": Maybe(
In({"trilinear", "nearestneighbour", "sinc", "spline"})
),
"cost": Maybe(
In(
{
"corratio",
"mutualinfo",
"normmi",
"normcorr",
"leastsq",
"labeldiff",
"bbr",
}
)
),
"thread_pool": Maybe(int),
}
Schema(
{
"convergence_threshold": Any(
All(Number, Range(min=0, max=1, min_included=False)), -1
),
"interp": Maybe(
In({"trilinear", "nearestneighbour", "sinc", "spline"})
),
"cost": Maybe(
In(
{
"corratio",
"mutualinfo",
"normmi",
"normcorr",
"leastsq",
"labeldiff",
"bbr",
}
)
),
"thread_pool": Maybe(int),
}
)
),
},
"functional_preproc": {
Expand Down Expand Up @@ -1266,6 +1273,20 @@ def sanitize(filename):
)


def check_unimplemented(
to_check: dict[str, AnyType], k_v_pairs: list[tuple[str, AnyType]], category: str
) -> None:
"""Check for unimplemented combinations in subschema.
Raise NotImplementedError if any found.
"""
error_msg = "`{value}` is not implemented for {category} `{key}`."
for key, value in k_v_pairs:
if to_check[key] == value:
msg = error_msg.format(category=category, key=key, value=value)
raise NotImplementedError(msg)


def schema(config_dict):
"""Validate a participant-analysis pipeline configuration.
Expand Down Expand Up @@ -1432,11 +1453,15 @@ def schema(config_dict):
# check for incompatible longitudinal options
lgt = partially_validated["longitudinal_template_generation"]
if lgt["using"] == "mri_robust_template":
error_msg = "{value} is not implemented for longitudinal {key} in `mri_robust_template`."
for key, value in [("average_method", "std"), ("dof", 9), ("max_iter", -1)]:
if lgt[key] == value:
msg = error_msg.format(key=key, value=value)
raise NotImplementedError(msg)
check_unimplemented(
lgt,
[("average_method", "std"), ("dof", 9), ("max_iter", -1)],
"longitudinal `mri_robust_template`",
)
if lgt["using"] == "C-PAC legacy":
check_unimplemented(
lgt, [("max_iter", "default")], "C-PAC legacy longitudinal"
)
except KeyError:
pass
return partially_validated
Expand Down
11 changes: 6 additions & 5 deletions CPAC/resources/configs/pipeline_config_blank.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1528,18 +1528,19 @@ longitudinal_template_generation:
# Additional option if using "C-PAC legacy": 9 (traditional)
dof: 12

# Threshold of transformation distance to consider that the loop converged
# (-1 means numpy.finfo(np.float64).eps and is the default)
convergence_threshold: -1

# Maximum iterations
# Stop after this many iterations, even if still above convergence_threshold
# Additional option if using "mri_robust_template": "default" means 5 for 2 sessions, 6 for more than 2 sessions
# Additional option if using "C-PAC legacy": -1 means loop forever until reaching convergence threshold
max_iter: 5
max_iter: default

# Options for C-PAC legacy implementation that are not configurable in mri_robust_template
legacy-specific:

# Threshold of transformation distance to consider that the loop converged
# (-1 means numpy.finfo(np.float64).eps and is the default)
convergence_threshold: -1

# Interpolation parameter for FLIRT in the template creation
# Options: trilinear, nearestneighbour, sinc or spline
interp:
Expand Down
11 changes: 6 additions & 5 deletions CPAC/resources/configs/pipeline_config_default.yml
Original file line number Diff line number Diff line change
Expand Up @@ -264,18 +264,19 @@ longitudinal_template_generation:
# Additional option if using "C-PAC legacy": 9 (traditional)
dof: 12

# Threshold of transformation distance to consider that the loop converged
# (-1 means numpy.finfo(np.float64).eps and is the default)
convergence_threshold: -1

# Maximum iterations
# Stop after this many iterations, even if still above convergence_threshold
# Additional option if using "mri_robust_template": "default" means 5 for 2 sessions, 6 for more than 2 sessions
# Additional option if using "C-PAC legacy": -1 means loop forever until reaching convergence threshold
max_iter: 5
max_iter: 6

# Options for C-PAC legacy implementation that are not configurable in mri_robust_template
legacy-specific:

# Threshold of transformation distance to consider that the loop converged
# (-1 means numpy.finfo(np.float64).eps and is the default)
convergence_threshold: -1

# Interpolation parameter for FLIRT in the template creation
# Options: trilinear, nearestneighbour, sinc or spline
interp: trilinear
Expand Down

0 comments on commit dc4c2df

Please sign in to comment.