Skip to content

Commit

Permalink
Changes
Browse files Browse the repository at this point in the history
arg_types.py
- update unquote_str docstring

common.py
- remove the obsolete removesuffix function

run_prediction.py
- remove the obsolete removesuffix function

quick_qc.py
- Fix the docstring of get_region_by_intersection_mask

generate_hdf5.py
- replace seg-mask with sag_mask

interpolation_layer.py
- Update of the crop_position property

common.py
- Update the docstrings of pipeline and iterate
  • Loading branch information
dkuegler committed Feb 8, 2024
1 parent e1f72e4 commit bb0ea35
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 107 deletions.
4 changes: 2 additions & 2 deletions FastSurferCNN/generate_hdf5.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def __init__(self, params: Dict, processing: str = "aparc"):
self.gm_mask = params["gm_mask"]

self.lut = read_classes_from_lut(params["lut"])
self.labels, self.labels_sag = get_labels_from_lut(self.lut, params["sag-mask"])
self.labels, self.labels_sag = get_labels_from_lut(self.lut, params["sag_mask"])
self.lateralization = unify_lateralized_labels(self.lut, params["combi"])

if params["csv_file"] is not None:
Expand Down Expand Up @@ -540,7 +540,7 @@ def create_hdf5_dataset(self, blt: int):
"plane": args.plane,
"lut": args.lut,
"combi": args.combi,
"sag-mask": args.sag_mask,
"sag_mask": args.sag_mask,
"hires_weight": args.hires_w,
"gm_mask": args.gm,
"gradient": not args.no_grad,
Expand Down
73 changes: 57 additions & 16 deletions FastSurferCNN/models/interpolation_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,23 +36,23 @@ class _ZoomNd(nn.Module):
Attributes
----------
_mode
Interpolation mode as in `torch.nn.interpolate` (default: 'neareast').
(Protected) Interpolation mode as in `torch.nn.interpolate` (default: 'neareast').
_target_shape
Target tensor size for after this module,
(Protected) Target tensor size for after this module,
not including batchsize and channels.
_N
Number of dimensions.
(Protected) Internal number of dimensions.
Methods
-------
forward
Forward propagation.
_fix_scale_factors
Checking and fixing the conformity of scale_factors.
(Protected) Checking and fixing the conformity of scale_factors.
_interpolate
Abstract method.
-calculate_crop_pad
Return start- and end- coordinate.
(Protected) Abstract method.
_calculate_crop_pad
(Protected) Return start- and end- coordinate.
"""

def __init__(
Expand Down Expand Up @@ -376,6 +376,8 @@ class Zoom2d(_ZoomNd):
(Protected) Crops, interpolates and pads the tensor.
"""

_crop_position: str

def __init__(
self,
target_shape: _T.Optional[_T.Sequence[int]],
Expand All @@ -398,6 +400,28 @@ def __init__(
if interpolation_mode not in ["nearest", "bilinear", "bicubic", "area"]:
raise ValueError(f"invalid interpolation_mode, got {interpolation_mode}")

self._N = 2
super(Zoom2d, self).__init__(target_shape, interpolation_mode)
self.crop_position = crop_position

@property
def crop_position(self) -> str:
"""
Property associated with the position of the image in the data.
"""
return self._crop_position

@crop_position.setter
def crop_position(self, crop_position: str) -> None:
"""
Set the crop position.
Parameters
----------
crop_position : str
The crop position key from 'top_left', 'bottom_left', top_right',
'bottom_right', 'center'.
"""
if crop_position not in [
"top_left",
"bottom_left",
Expand All @@ -406,11 +430,8 @@ def __init__(
"center",
]:
raise ValueError(f"invalid crop_position, got {crop_position}")

self._N = 2
super(Zoom2d, self).__init__(target_shape, interpolation_mode)
self.crop_position = crop_position

self._crop_position = crop_position

def _interpolate(
self,
data: Tensor,
Expand Down Expand Up @@ -514,6 +535,29 @@ def __init__(
if interpolation_mode not in ["nearest", "trilinear", "area"]:
raise ValueError(f"invalid interpolation_mode, got {interpolation_mode}")

self._N = 3
super(Zoom3d, self).__init__(target_shape, interpolation_mode)
self.crop_position = crop_position

@property
def crop_position(self) -> str:
"""
Property associated with the position of the image in the data.
"""
return self._crop_position

@crop_position.setter
def crop_position(self, crop_position: str) -> None:
"""
Set the crop position.
Parameters
----------
crop_position : str
Crop position to use from 'front_top_left', 'back_top_left',
'front_bottom_left', 'back_bottom_left', 'front_top_right', 'back_top_right',
'front_bottom_right', 'back_bottom_right', 'center' (default: 'front_top_left').
"""
if crop_position not in [
"front_top_left",
"back_top_left",
Expand All @@ -526,11 +570,8 @@ def __init__(
"center",
]:
raise ValueError(f"invalid crop_position, got {crop_position}")

self._N = 3
super(Zoom3d, self).__init__(target_shape, interpolation_mode)
self._crop_position = crop_position

def _interpolate(
self, data: Tensor, scale_factor: _T.Union[Tensor, np.ndarray, _T.Sequence[int]]
):
Expand Down
20 changes: 12 additions & 8 deletions FastSurferCNN/quick_qc.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def check_volume(asegdkt_segfile:np.ndarray, voxvol: float, thres: float = 0.70)
def get_region_bg_intersection_mask(
seg_array, region_labels=VENT_LABELS, bg_label=BG_LABEL
):
"""
f"""
Return a mask of the intersection between the voxels of a given region and background voxels.
This is obtained by dilating the region by 1 voxel and computing the intersection with the
Expand All @@ -112,18 +112,22 @@ def get_region_bg_intersection_mask(
----------
seg_array : numpy.ndarray
Segmentation array.
region_labels : Dict
Dictionary whose values correspond to the desired region's labels (Default value = VENT_LABELS).
VENT_LABELS is a dictionary containing labels for different regions related to the ventricles,
such as "Left-Lateral-Ventricle", "Right-Lateral-Ventricle", etc., along with their
corresponding numeric values.
bg_label : int, default="BG_LABEL"
(Default value = BG_LABEL).
region_labels : dict, default=<dict VENT_LABELS>
Dictionary whose values correspond to the desired region's labels (see Note).
bg_label : int, default={BG_LABEL}
(Default value = {BG_LABEL}).
Returns
-------
bg_intersect : numpy.ndarray
Region and background intersection mask array.
Notes
-----
VENT_LABELS is a dictionary containing labels for four regions related to the ventricles:
"Left-Lateral-Ventricle", "Right-Lateral-Ventricle", "Left-choroid-plexus",
"Right-choroid-plexus" along with their corresponding integer label values
(see also FreeSurferColorLUT.txt).
"""
region_array = seg_array.copy()
conditions = np.all(
Expand Down
31 changes: 0 additions & 31 deletions FastSurferCNN/run_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,37 +107,6 @@ def args2cfg(
return cfg_fin, cfg_cor, cfg_sag, cfg_ax


def removesuffix(string: str, suffix: str) -> str:
"""
Remove a suffix from a string.
Similar to string.removesuffix in PY3.9+,
Parameters
----------
string : str
String to be cut.
suffix : str
Suffix to be removed.
Returns
-------
Str
Suffix removed string.
"""
import sys

if sys.version_info.minor >= 9:
# removesuffix is a Python3.9 feature
return string.removesuffix(suffix)
else:
return (
string[: -len(suffix)]
if len(suffix) > 0 and string.endswith(suffix)
else string
)


##
# Input array preparation
##
Expand Down
4 changes: 2 additions & 2 deletions FastSurferCNN/utils/arg_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def int_ge_zero(value) -> int:

def unquote_str(value) -> str:
"""
Unquote a (single quoted) string.
Unquote a (single quoted) string, i.e. remove one level of single-quotes.
Parameters
----------
Expand All @@ -183,7 +183,7 @@ def unquote_str(value) -> str:
Returns
-------
val : str
A string of the value without quoting with '''.
A string of the value without leading and trailing single-quotes.
"""
val = str(value)
if val.startswith("'") and val.endswith("'"):
Expand Down
63 changes: 15 additions & 48 deletions FastSurferCNN/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
"iterate",
"NoParallelExecutor",
"pipeline",
"removesuffix",
"SubjectList",
"SubjectDirectory",
]
Expand Down Expand Up @@ -190,20 +189,19 @@ def pipeline(
----------
pool : Executor
Thread pool executor for parallel execution.
func : Callable[[_Ti], _T] :
func : callable
Function to use.
iterable : Iterable[_Ti]
iterable : Iterable
Iterable containing input elements.
pipeline_size : int
Size of the processing pipeline (default is 1).
pipeline_size : int, default=1
Size of the processing pipeline.
Yields
------
Iterator[Tuple[_Ti, _T]]
Yielding elements from iterable and corresponding results from func(element).
element : _Ti
Elements
_T
Results of func corresponding to element: func(element).
"""
# do pipeline loading the next element
from collections import deque
Expand Down Expand Up @@ -231,54 +229,23 @@ def iterate(
Parameters
----------
pool : Executor
Callable.
func : Iterable
The Executor object (dummy object to have a common API with pipeline).
func : callable
Function to use.
iterable : Iterable[_Ti]
Iterable.
iterable : Iterable
Iterable to draw objects to process with func from.
Yields
------
element : _Ti
element : _Ti
Elements
_T
Func(element).
Results of func corresponding to element: func(element).
"""
for element in iterable:
yield element, func(element)


def removesuffix(string: str, suffix: str) -> str:
"""
Remove a suffix from a string.
Similar to string.removesuffix in PY3.9+.
Parameters
----------
string : str
String that should be edited.
suffix : str
Suffix to remove.
Returns
-------
str
Input string with removed suffix.
"""
import sys

if sys.version_info.minor >= 9:
# removesuffix is a Python3.9 feature
return string.removesuffix(suffix)
else:
return (
string[: -len(suffix)]
if len(suffix) > 0 and string.endswith(suffix)
else string
)


class SubjectDirectory:
"""
Represent a subject directory.
Expand Down Expand Up @@ -1012,7 +979,7 @@ def __getitem__(self, item: Union[int, str]) -> SubjectDirectory:
# subject is always an absolute path (or relative to the working directory) ... of the input file
subject = self._subjects[item]
sid = (
os.path.basename(removesuffix(subject, self._remove_suffix))
os.path.basename(subject.removesuffix(self._remove_suffix))
if self._sid is None
else self._sid
)
Expand Down

0 comments on commit bb0ea35

Please sign in to comment.