From c950c2fda2b32634916424ad7bed96f233b125f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20K=C3=BCgler?= Date: Fri, 13 Jan 2023 15:54:25 +0100 Subject: [PATCH] fix some errors due to torch optimizations add code for gpu/torch implementation (currently seems to require too much gpu memory) --- FastSurferCNN/segstats.py | 703 ++++++++++++++++++++++---------------- 1 file changed, 416 insertions(+), 287 deletions(-) diff --git a/FastSurferCNN/segstats.py b/FastSurferCNN/segstats.py index 44dc9238..d2ce5652 100644 --- a/FastSurferCNN/segstats.py +++ b/FastSurferCNN/segstats.py @@ -16,8 +16,9 @@ # IMPORTS import argparse from functools import partial, wraps, reduce -from itertools import product, repeat, chain +from itertools import product, repeat from numbers import Number +from packaging import version as _version from typing import Sequence, Tuple, Union, Optional, Dict, overload, cast, TypeVar, List, Iterable import nibabel as nib @@ -28,15 +29,13 @@ try: import torch from torch.nn import functional as _F + _HAS_TORCH = True - from torch import Tensor, BoolTensor, IntTensor, FloatTensor except ImportError: # ensure Tensor, etc. are defined for typing + torch = NotImplemented + _F = NotImplemented _HAS_TORCH = False - Tensor = 'Tensor' - BoolTensor = 'BoolTensor' - IntTensor = 'IntTensor' - FloatTensor = 'FloatTensor' try: import numba from numba.np import numpy_support @@ -44,6 +43,8 @@ _HAS_NUMBA = True except ImportError: _HAS_NUMBA = False + numba = NotImplemented + numpy_support = NotImplemented from FastSurferCNN.utils.parser_defaults import add_arguments from FastSurferCNN.utils.arg_types import (int_gt_zero as patch_size, int_ge_zero as id_type, @@ -75,10 +76,11 @@ _NumberType = TypeVar('_NumberType', bound=Number) _IntType = TypeVar("_IntType", bound=np.integer) _DType = TypeVar('_DType', bound=np.dtype) -_ArrayType = TypeVar("_ArrayType", np.ndarray, Tensor) +_ArrayType = TypeVar("_ArrayType", np.ndarray, 'torch.Tensor') PVStats = Dict[str, Union[int, float]] -UNITS = {"Volume_mm3": "mm^3", "normMean": "MR", "normStdDev": "MR", "normMin": "MR", "normMax": "MR", "normRange": "MR"} +UNITS = {"Volume_mm3": "mm^3", "normMean": "MR", "normStdDev": "MR", "normMin": "MR", "normMax": "MR", + "normRange": "MR"} FIELDS = {"Index": "Index", "SegId": "Segmentation Id", "NVoxels": "Number of Voxels", "Volume_mm3": "Volume", "StructName": "Structure Name", "normMean": "Intensity normMean", "normStdDev": "Intensity normStdDev", "normMin": "Intensity normMin", "normMax": "Intensity normMax", "normRange": "Intensity normRange"} @@ -88,6 +90,7 @@ class HelpFormatter(argparse.HelpFormatter): """Help formatter that keeps line breaks for texts that start with '/keeplinebreaks/'.""" + def _fill_text(self, text, width, indent): klb_str = '/keeplinebreaks/' if text.startswith(klb_str): @@ -98,7 +101,8 @@ def _fill_text(self, text, width, indent): def make_arguments() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser(usage=USAGE, epilog=HELPTEXT, description=DESCRIPTION, formatter_class=HelpFormatter) + parser = argparse.ArgumentParser(usage=USAGE, epilog=HELPTEXT, description=DESCRIPTION, + formatter_class=HelpFormatter) parser.add_argument('-norm', '--normfile', type=str, required=True, dest='normfile', help="Biasfield corrected image in the same image space as segmentation (required).") parser.add_argument('-i', '--segfile', type=str, dest='segfile', required=True, @@ -120,23 +124,25 @@ def make_arguments() -> argparse.ArgumentParser: "statistics == `--robust 1.0`).") advanced = 
parser.add_argument_group(title="Advanced options")
     advanced.add_argument('--legacy', dest='legacy', action='store_true',
-                          help="Use the lagacy FreeSurfer algorithm.")
+                          help="Use the legacy FreeSurfer algorithm.")
     advanced.add_argument('--patch_size', type=patch_size, dest='patch_size', default=32,
-                          help="Patch size to use in calculating the partial volumes (default: 32).")
+                          help="Patch size to use in calculating the partial volumes (default: 32).")
     parser = add_arguments(parser, ['device', 'lut', 'allow_root'])
     return parser
 
 
 def loadfile_full(file: str, name: str, device: Union[str, 'torch.device'] = 'cpu') \
-        -> Tuple[nib.analyze.SpatialImage, Union[np.ndarray, Tensor]]:
+        -> Tuple[nib.analyze.SpatialImage, Union[np.ndarray, 'torch.Tensor']]:
     try:
         img = nib.load(file)
     except (IOError, FileNotFoundError) as e:
         raise IOError(f"Failed loading the {name} '{file}' with error: {e.args[0]}") from e
-    if device == 'cpu':
-        return img, np.asarray(img.dataobj)
-    else:
-        return img, torch.as_tensor(np.asarray(img.dataobj), device=device)
+    data = np.asarray(img.dataobj)
+    if device != 'cpu':
+        if data.dtype.byteorder != "=":
+            data = data.astype(data.dtype.newbyteorder("native"))
+        data = torch.as_tensor(data, device=device)
+    return img, data
 
 
 def main(args):
@@ -153,7 +159,9 @@ def main(args):
 
     device = args.device if _HAS_TORCH and hasattr(args, 'device') and args.device is not None else 'cpu'
     if device == 'auto':
-        device = 'cuda' if _HAS_TORCH and torch.cuda.is_available() else 'cpu'
+        device = 'cpu'  # 'cuda' if _HAS_TORCH and torch.cuda.is_available() else 'cpu'
+    if device != 'cpu':
+        raise NotImplementedError("Only cpu is currently supported.")
 
     if _HAS_TORCH and device != "cpu":
         device = torch.device(device)
@@ -171,8 +179,8 @@ def main(args):
     else:
         lut = None
     try:
-        seg, seg_data = seg_future.result()  # type: nib.analyze.SpatialImage, Union[np.ndarray, IntTensor]
-        norm, norm_data = norm_future.result()  # type: nib.analyze.SpatialImage, Union[np.ndarray, Tensor]
+        seg, seg_data = seg_future.result()  # type: nib.analyze.SpatialImage, Union[np.ndarray, torch.IntTensor]
+        norm, norm_data = norm_future.result()  # type: nib.analyze.SpatialImage, Union[np.ndarray, torch.Tensor]
     except IOError as e:
         return e.args[0]
     if hasattr(args, 'ids') and args.ids is not None and len(args.ids) > 0:
@@ -192,7 +200,6 @@ def main(args):
         exclude_id = []
 
     kwargs = {
-        "patch_size": args.patch_size,
         "vox_vol": np.prod(seg.header.get_zooms()).item(),
         "robust_percentage": args.robust if hasattr(args, 'robust') else None
     }
@@ -201,28 +208,39 @@ def main(args):
         if hasattr(args, 'legacy') and args.legacy:
             if kwargs.get("robust") is not None:
                 return "robust statistics are not supported in --legacy mode"
-            if not _HAS_NUMBA:
+            if _HAS_NUMBA:
+                def make_pool():
+                    return tpe
+            else:
                 print("WARNING: Partial volume calculation in legacy mode without numba is VERY SLOW.")
+                from contextlib import contextmanager
+                from concurrent.futures import ProcessPoolExecutor
 
+                @contextmanager
+                def make_pool():
+                    with ProcessPoolExecutor() as ppe:
+                        yield ppe
+                    return
 
             from multiprocessing import cpu_count
             all_borders = tpe.map(seg_borders, repeat(seg_data), labels,
                                   chunksize=np.ceil(len(labels) / cpu_count()))
             border: npt.NDArray[bool] = np.logical_or.reduce(list(all_borders), axis=0)
 
-            table = list(tpe.map(partial(legacy_pv_calc, vox_vol=kwargs["vox_vol"]),
-                                 repeat(border), repeat(seg_data.astype(np.int32)),
-                                 repeat(norm_data.astype(np.float)), labels,
-                                 chunksize=np.ceil(len(labels) / cpu_count())))
+            with make_pool() as pool:
+                table = list(pool.map(partial(legacy_pv_calc, vox_vol=kwargs["vox_vol"]),
+                                      repeat(border), repeat(seg_data.astype(np.int32)),
+                                      repeat(norm_data.astype(float)), labels,
+                                      chunksize=np.ceil(len(labels) / cpu_count())))
         else:
-            table = pv_calc(seg_data, norm_data, labels, **kwargs)
+            table = pv_calc(seg_data, norm_data, labels, patch_size=args.patch_size, **kwargs)
     else:
-        table = pv_calc_torch(seg_data, norm_data, labels, **kwargs)
+        table = pv_calc_torch(seg_data, norm_data.to(torch.float), labels, **kwargs)
 
     if lut is not None:
         for i in range(len(table)):
             table[i]["StructName"] = lut[lut["ID"] == table[i]["SegId"]]["LabelName"].item()
     dataframe = pd.DataFrame(table, index=np.arange(len(table)))
     dataframe = dataframe[dataframe["NVoxels"] != 0].sort_values("SegId")
-    dataframe.index = np.arange(1, len(dataframe)+1)
+    dataframe.index = np.arange(1, len(dataframe) + 1)
     args.strict = True
     write_statsfile(args.segstatsfile, dataframe, exclude_id, vox_vol=kwargs["vox_vol"], args=args)
     return 0
@@ -269,6 +287,7 @@ def fmt_field(code: str, data) -> str:
             if is_f:
                 prec += int(code[-2]) + 1
             return filler + str(prec) + code
+
         fmts = ("{:" + fmt_field(FORMATS[k], dataframe[k]) + "}" for k in dataframe.columns)
         fmt = "{:>" + str(max_index) + "d} " + " ".join(fmts) + "\n"
         for index, row in dataframe.iterrows():
@@ -307,7 +326,7 @@ def read_classes_from_lut(lut_file):
 def seg_borders(_array: _ArrayType, label: Union[np.integer, bool],
                 out: Optional[_ArrayType] = None, cmp_dtype: npt.DTypeLike = "int8") -> _ArrayType:
     """Handle to fast 6-connected border computation."""
-    if _HAS_TORCH and isinstance(_array, Tensor):
+    if _HAS_TORCH and isinstance(_array, torch.Tensor):
         bin_array = _array if _array.dtype is torch.bool else _array == label
 
         def _laplace(data):
@@ -333,28 +352,34 @@ def borders(_array: _ArrayType, labels: Union[Iterable[np.integer], bool], max_l
     """Handle to fast border computation."""
     dim = _array.ndim
+    is_tensor = _HAS_TORCH and isinstance(_array, torch.Tensor)
+    array_alloc = _array.new_full if is_tensor else partial(np.full, dtype=_array.dtype)
+    _shape_plus2 = [s + 2 for s in _array.shape]
+
     if labels is True:  # already binarized
-        if _HAS_TORCH and torch.is_tensor(_array):
-            cmp = torch.logical_xor
-        else:
-            cmp = np.logical_xor
+        if not (is_tensor and _array.dtype is torch.bool) and not (not is_tensor and np.issubdtype(_array.dtype, bool)):
+            raise ValueError("If labels is True, the array should be boolean.")
+        cmp = torch.logical_xor if is_tensor else np.logical_xor
     else:
+        if (is_tensor and _array.dtype is torch.bool) or (not is_tensor and np.issubdtype(_array.dtype, bool)):
+            raise ValueError("If labels is a list/iterable, the array should not be boolean.")
+
        def cmp(a, b):
            return a == b
 
        if max_label is None:
            max_label = _array.max().item()
-        is_tensor = _HAS_TORCH and isinstance(_array, Tensor)
-        lab_dtype = np.dtype(int) if is_tensor else _array.dtype
-        lookup = np.zeros(max_label + 1, dtype=lab_dtype)
-        labels = list(labels)
+        lookup = array_alloc((max_label + 1,), fill_value=0)
+        # remove labels that are bigger than max_label
+        labels = list(filter(lambda x: x <= max_label, labels))
        if 0 not in labels:
            labels = [0] + labels
-        lookup[labels] = np.arange(len(labels), dtype=lab_dtype)
-        array_alloc = partial(np.full, dtype=_array.dtype) if is_tensor else _array.new_full
-        logical_or = torch.tensor if is_tensor else np.logical_or
-        __array = array_alloc([s + 2 for s in _array.shape], fill_value=0)
-        __array[(slice(1, -1),) * dim] = _array if _array is True else lookup[_array]
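+        # Illustration of the relabeling below (hypothetical values): with max_label == 42 and
+        # labels == [0, 2, 41], `lookup` has length 43 and maps raw label values to compact
+        # indices, e.g. lookup[[0, 2, 2, 41, 7]] -> [0, 1, 1, 2, 0]; any value not listed in
+        # `labels` collapses onto index 0 (background) before the border comparison.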
+        arange = partial(torch.arange, device=_array.device) if is_tensor else np.arange
+        lookup[labels] = arange(len(labels), dtype=lookup.dtype)
+        _array = lookup[_array.to(torch.long)] if is_tensor else lookup[_array]
+    logical_or = torch.logical_or if is_tensor else np.logical_or
+    __array = array_alloc(_shape_plus2, fill_value=0)
+    __array[(slice(1, -1),) * dim] = _array
 
     mid = (slice(1, -1),) * dim
     if six_connected:
@@ -371,8 +396,8 @@ def ii(off: Iterable[int]) -> Tuple[slice, ...]:
         nbr_same = [cmp(__array[mid], __array[ii(i - 1)]) for i in np.ndindex((3,) * dim) if np.all(i != 1)]
         if is_tensor:
-            for axis in range(dim-1, 2, -1):
-                nbr_same[axis-1].logical_or_(nbr_same[axis])
+            for axis in range(dim - 1, 2, -1):
+                nbr_same[axis - 1].logical_or_(nbr_same[axis])
             return torch.logical_or(nbr_same[0], nbr_same[1], out=out)
         else:
             return np.logical_or.reduce(nbr_same, out=out)
@@ -394,7 +419,7 @@ def legacy_pv_calc(border: npt.NDArray[bool], seg_array: npt.NDArray[_IntType],
 
 if _HAS_NUMBA:
     def numba_auto_cast_types(func):
-        # helper function to cast the types of the call
+        """Helper decorator to cast the argument types of the call."""
 
         def auto_cast(var):
             if isinstance(var, np.ndarray):
@@ -416,21 +441,30 @@ def wrapper_func(*args, **kwargs):
 
         return wrapper_func
 
-    from numba import types as nbt
-    _LD = nbt.LiteralStrKeyDict
-    _L = nbt.literal
-    _nb_Float3d = nbt.double[:, :, :]
-    _nb_Bool3d = nbt.boolean[:, :, :]
-    _nb_Int3d = nbt.int32[:, :, :]
-    _nb_ReturnType = nbt.Tuple([nbt.int32, nbt.int_, nbt.double, nbt.double, nbt.double, nbt.double, nbt.double,
-                                _nb_Int3d, _nb_Float3d, _nb_Float3d, _nb_Float3d])
-    _pv_calc_signatures = _nb_ReturnType(_nb_Bool3d, _nb_Int3d, _nb_Float3d, nbt.int_, nbt.double, nbt.boolean, nbt.int_)
-    _nb_ReturnType = nbt.Tuple([nbt.int_[:], nbt.double[:]])
-    _nbhd_signatures = _nb_ReturnType(_nb_Int3d, _nb_Float3d, nbt.int_, nbt.int_, nbt.int_, nbt.int_, nbt.int_)
-    _nb_ReturnType = nbt.int_[:]
-    _nbhd_nomean_signatures = _nb_ReturnType(_nb_Int3d, nbt.int_, nbt.int_, nbt.int_, nbt.int_, nbt.int_)
-
-    @numba.njit
+
+    def __pv_calc_signatures():
+        """Numba signatures for the __legacy_pv_calc_impl function."""
+        nbt = numba.types
+        _nb_Bool3d, _nb_Float3d, _nb_Int3d = nbt.boolean[:, :, :], nbt.double[:, :, :], nbt.int32[:, :, :]
+        _nb_ReturnType = nbt.Tuple([nbt.int32, nbt.int_, nbt.double, nbt.double, nbt.double, nbt.double, nbt.double,
+                                    _nb_Int3d, _nb_Float3d, _nb_Float3d, _nb_Float3d])
+        return _nb_ReturnType(_nb_Bool3d, _nb_Int3d, _nb_Float3d, nbt.int_, nbt.double, nbt.boolean, nbt.int_)
+
+
+    def __nbhd_signatures():
+        """Numba signatures for the mri_compute_label_nbhd function."""
+        nbt = numba.types
+        _nb_ReturnType = nbt.Tuple([nbt.int_[:], nbt.double[:]])
+        return _nb_ReturnType(nbt.int32[:, :, :], nbt.double[:, :, :], nbt.int_, nbt.int_, nbt.int_, nbt.int_, nbt.int_)
+
+
+    def __nbhd_nomean_signatures():
+        """Numba signatures for the mri_compute_label_nbhd_no_mean function."""
+        nbt = numba.types
+        return nbt.int_[:](nbt.int32[:, :, :], nbt.int_, nbt.int_, nbt.int_, nbt.int_, nbt.int_)
+
+
+    @numba.njit(cache=True)
     def numba_unique_with_counts(array):
         """Similar to `numpy.unique(array, return_counts=True)`."""
         # https://github.com/numba/numba/pull/2959
@@ -446,7 +480,7 @@ def numba_unique_with_counts(array):
         return np.asarray(unique), np.asarray(counts)
 
-    @numba.njit
+    @numba.njit(cache=True)
     def numba_repeat_axis0(array, repeats):
         """Similar to `numpy.repeat(array, repeats, axis=0)`."""
         r = np.empty((repeats,) + array.shape, dtype=array._dtype)
@@ -455,7 +489,7 @@ def numba_repeat_axis0(array, repeats):
         return r
 
-    @numba.njit(_nbhd_signatures, nogil=True)
+    @numba.njit(__nbhd_signatures(), nogil=True, cache=True)
     def mri_compute_label_nbhd(seg_data: npt.NDArray[_IntType], norm_data: Optional[np.ndarray],
                                x: int, y: int, z: int, whalf: int = 1, maxlabels: int = 20_000) \
             -> Tuple[npt.NDArray[int], Optional[npt.NDArray[float]]]:
@@ -472,31 +506,63 @@ def mri_compute_label_nbhd(seg_data: npt.NDArray[_IntType], norm_data: Optional[
                 label_means[e] = np.sum(norm_sub_array * (sub_array == e)) / label_counts[e]
         else:
             label_means = None
-
         return label_counts, label_means
 
-    @numba.njit(_nbhd_nomean_signatures, nogil=True)
+    @numba.njit(__nbhd_nomean_signatures(), nogil=True, cache=True)
     def mri_compute_label_nbhd_no_mean(seg_data: npt.NDArray[_IntType],
-                                      x: int, y: int, z: int, whalf: int = 1, maxlabels: int = 20_000) \
+                                       x: int, y: int, z: int, whalf: int = 1, maxlabels: int = 20_000) \
             -> npt.NDArray[int]:
         """Numba-compiled version of `mri_compute_label_nbhd(seg_data, norm_data, x, y, z, whalf, maxlabels)`."""
         label_counts = np.zeros(maxlabels, dtype='int')
         sub_array = seg_data[x - whalf:x + whalf + 1, y - whalf:y + whalf + 1, z - whalf:z + whalf + 1]
         elems, counts = numba_unique_with_counts(sub_array)
         label_counts[elems] = counts
-
         return label_counts
 
     @numba_auto_cast_types
-    def legacy_pv_calc(border: npt.NDArray[bool], seg_array: npt.NDArray[np.int_], norm_array: npt.NDArray[np.uint8],
+    def legacy_pv_calc(border: npt.NDArray[bool], seg: npt.NDArray[np.int_], norm: npt.NDArray[np.uint8],
                        label: int, vox_vol: float = 1.0, return_maps: bool = False, maxlabels: int = 20_000) \
             -> Union[PVStats, Tuple[PVStats, Dict[str, np.ndarray]]]:
-
+        """Calculate PV effects and volume statistics compatible with FreeSurfer's mri_seg_stats for a specific label.
+        It optionally returns the partial volume maps (see below).
+
+        This function not only produces the same results as FreeSurfer's mri_seg_stats, but also does it using the same
+        algorithm (i.e. it is just a port of the algorithm).
+
+        Args:
+            border: The global border of the image (which voxels should be considered for PV effects).
+            seg: The segmentation map of the primary class per voxel.
+            norm: The bias-field corrected T1w image (intensity image).
+            label: The class (from seg) to analyze for PV effects.
+            vox_vol: The volume per voxel in mm^3 (default: 1.0).
+            return_maps: Whether partial volume maps should be returned (default: False).
+            maxlabels: The biggest allowed label index (in seg) (default: 20000)
+
+        Returns:
+            Segmentation Statistics: A Dictionary of:
+                SegId: The label/class
+                NVoxels: The number of voxels for label in seg
+                Volume_mm3: The volume of label after partial volume correction
+                StructName: (empty) -- never populated, use the lut to populate, see for example main()
+                normMean: The average intensity in the bias-field corrected image of voxels labeled as label
+                normMin: The minimum intensity in the bias-field corrected image of voxels labeled as label
+                normMax: The maximum intensity in the bias-field corrected image of voxels labeled as label
+                normStdDev: The standard deviation of the intensities in the bias-field corrected image of voxels
+                    labeled as label
+                normRange: The range of intensities in the bias-field corrected image of voxels labeled as label
+            If return_maps is True, also returns: Partial Volume Maps: A Dictionary of:
+                nbr: An image of alternative labels that were considered instead of the voxel's label
+                nbrmean: The local mean intensity of the label nbr at the specific voxel
+                segmean: The local mean intensity of the current label at the specific voxel
+                pv: The partial volume map of the current label
+
+        Note:
+            This function is just a wrapper around the numba-compiled function `__legacy_pv_calc_impl` to harmonize
+            data types and provide a consistent interface.
+        """
         label, _voxel_count, volumes, _mean, _std, _min, _max, full_nbr_label, full_nbr_mean, full_seg_mean, full_pv = \
-            numba_pv_calc(border, seg_array, norm_array, label,
-                          vox_vol=vox_vol, return_maps=return_maps, maxlabels=maxlabels)
+            __legacy_pv_calc_impl(border, seg, norm, label,
+                                  vox_vol=vox_vol, return_maps=return_maps, maxlabels=maxlabels)
 
         result = {"SegId": int(label), "NVoxels": int(_voxel_count), "Volume_mm3": float(volumes), "StructName": "",
                   "normMean": _mean, "normStdDev": _std, "normMin": _min, "normMax": _max,
@@ -507,10 +573,13 @@ def legacy_pv_calc(border: npt.NDArray[bool], seg_array: npt.NDArray[np.int_], n
 
         return result
 
-    @numba.njit(_pv_calc_signatures, parallel=True, nogil=True)
-    def numba_pv_calc(border: npt.NDArray[bool], seg_array: npt.NDArray[np.int_], norm_array: npt.NDArray[np.uint8],
-                      label: int, vox_vol: float, return_maps: bool, maxlabels: int) \
-            -> Tuple[int, int, float, float, float, float, float, npt.NDArray]:
+
+    @numba.njit(__pv_calc_signatures(), parallel=True, nogil=True, cache=True)
+    def __legacy_pv_calc_impl(border: npt.NDArray[bool], seg_array: npt.NDArray[np.int_],
+                              norm_array: npt.NDArray[_NumberType],
+                              label: int, vox_vol: float, return_maps: bool, maxlabels: int) \
+            -> Tuple[int, int, float, float, float, float, float, npt.NDArray[np.int_], npt.NDArray[float],
+                     npt.NDArray[float], npt.NDArray[float]]:
         label, _min, _max = np.int32(label), np.infty, -np.infty
         _voxel_count, _sum, _sum_2, volumes = 0, 0., 0., 0.
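A minimal usage sketch of `legacy_pv_calc` as documented above (all inputs synthetic and
illustrative; assumes numba is installed and both helpers are imported from
FastSurferCNN.segstats):

    import numpy as np
    from FastSurferCNN.segstats import seg_borders, legacy_pv_calc

    seg = np.zeros((8, 8, 8), dtype=np.int32)
    seg[2:6, 2:6, 2:6] = 17                                   # one synthetic structure
    norm = np.random.default_rng(0).uniform(0., 110., seg.shape)
    border = seg_borders(seg, 17).astype(bool)                # 6-connected border of label 17
    stats = legacy_pv_calc(border, seg, norm, label=17, vox_vol=1.0)
    # stats is a dict like {"SegId": 17, "NVoxels": 64, "Volume_mm3": ..., "normMean": ..., ...}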
@@ -530,7 +599,7 @@ def numba_pv_calc(border: npt.NDArray[bool], seg_array: npt.NDArray[np.int_], no
                     vox_label = seg_array[pos]
                     border_val = border[pos]
 
-                    ## Addition for other stats:
+                    # Addition for other stats:
                     val = norm_array[pos]
                     if val < _min:
                         _min = val
@@ -618,7 +687,7 @@ def mri_compute_label_nbhd(seg_data: npt.NDArray[_IntType], norm_data: Optional[
                            x: int, y: int, z: int, whalf: int = 1, maxlabels: int = 20_000) \
         -> Tuple[npt.NDArray[int], npt.NDArray[float]]:
     """Port of the c++ function mri_compute_label_nbhd from mri.cpp of FreeSurfer."""
-    ## Almost 1-to-1 port of the function
+    # Almost 1-to-1 port of the function
     # def mri_compute_label_nbhd(mri, mri_vals, x, y, z, label_counts, label_means, whalf=1, maxlabels=20000):
     #     label_counts = np.zeros_like(label_counts)
     #     label_means = np.zeros_like(label_means)
@@ -636,7 +705,7 @@ def mri_compute_label_nbhd(seg_data: npt.NDArray[_IntType], norm_data: Optional[
     #             label_means[label] += val
     #     label_means = np.nan_to_num(label_means / label_counts)
     #     return label_counts, label_means
-    ## Optimized vectorized implementation:
+    # Optimized vectorized implementation:
     label_counts = np.zeros(maxlabels, dtype=int)
     label_means = np.zeros(maxlabels, dtype=float)
     sub_array = seg_data[x - whalf:x + whalf + 1, y - whalf:y + whalf + 1, z - whalf:z + whalf + 1]
@@ -654,8 +723,7 @@ def mri_compute_label_nbhd(seg_data: npt.NDArray[_IntType], norm_data: Optional[
 
 def legacy_pv_calc(border: npt.NDArray[bool], seg_array: npt.NDArray[_IntType], norm_array: np.ndarray,
-                   label: _IntType,
-                   vox_vol: float = 1.0, return_maps: bool = False, maxlabels: int = 20_000) \
+                   label: _IntType, vox_vol: float = 1.0, return_maps: bool = False, maxlabels: int = 20_000) \
         -> Union[PVStats, Tuple[PVStats, Dict[str, np.ndarray]]]:
     """mri_seg_stats from FreeSurfer equivalent."""
@@ -675,7 +743,7 @@ def legacy_pv_calc(border: npt.NDArray[bool], seg_array: npt.NDArray[_IntType],
             vox_label = seg_array[x, y, z]
             border_val = border[x, y, z]
 
-            ## Addition for other stats:
+            # Addition for other stats:
             val = norm_array[x, y, z]
             if val < _min:
                 _min = val
@@ -766,6 +834,17 @@ def unsqueeze(matrix, axis: Union[int, Sequence[int]] = -1):
     return matrix
 
 
+if _HAS_TORCH and _version.parse(torch.__version__) >= _version.parse("1.11"):
+    def _round(_arr: _ArrayType, decimals: int = 0) -> _ArrayType:
+        return _arr.round(decimals=decimals)
+else:
+    def _round(_arr: _ArrayType, decimals: int = 0) -> _ArrayType:
+        if _HAS_TORCH and isinstance(_arr, torch.Tensor):
+            return torch.round(_arr * 10 ** decimals) / 10 ** decimals
+        else:
+            return np.round(_arr, decimals=decimals)
+
+
 def grow_patch(patch: Sequence[slice], whalf: int, img_size: Union[np.ndarray, Sequence[float]]) -> Tuple[
         Tuple[slice, ...], Tuple[slice, ...]]:
     """Create two slicing tuples for indexing ndarrays/tensors that 'grow' and re-'ungrow' the patch `patch` by `whalf` (also considering the image shape)."""
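A quick sanity check of the `_round` shim added above (illustrative values; the pre-1.11
branch emulates `Tensor.round(decimals=...)` by scaling, rounding, and unscaling):

    import numpy as np
    import torch

    x = torch.tensor([0.1234567, 1.0000004])
    _round(x, 6)                         # tensor([0.123457, 1.000000]) on either branch
    _round(np.array([0.1234567]), 6)     # array([0.123457]) via the numpy code path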
The uniform filter is normalized + (weights add to one).""" if isinstance(arr, np.ndarray): _patch = (slice(None),) if patch is None else patch + arr = arr.astype(float) from scipy.ndimage import uniform_filter def _uniform_filter(_arr): return uniform_filter(_arr, size=filter_size, mode='constant', cval=0)[_patch] - - arr = arr.astype(float) else: - weight = torch.full((1, 1) + (filter_size,) * arr.ndim, 1 / (filter_size ** arr.ndim), device=arr.device) - _uniform_filter = partial(_Conv[arr.ndim], weight=weight) - arr = torch_pad_crop_for_conv(arr, filter_size, patch) + weight = torch.full((1,1) + (filter_size,) * arr.ndim, 1 / (filter_size ** arr.ndim), device=arr.device) + arr = arr.to(torch.float) + + def _uniform_filter(_arr, out = None): + kw = {} if out is None else {'out': unsqueeze(out, [0, 0])} + return _Conv[arr.ndim](unsqueeze(_arr, [0, 0]), + weight=weight, stride=[1]*_arr.ndim, padding='same', + **kw).squeeze(0).squeeze(0) if out is not None: - out[:] = _uniform_filter(arr) + _uniform_filter(arr, out) return out return _uniform_filter(arr) @@ -852,7 +935,7 @@ def pv_calc(seg: npt.NDArray[_IntType], norm: np.ndarray, labels: Sequence[_IntT def pv_calc(seg: npt.NDArray[_IntType], norm: np.ndarray, labels: Sequence[_IntType], - patch_size: int = 32, vox_vol: float = 1.0, eps: float = 1e-6, robust_keep_fraction: Optional[float] = None, + patch_size: int = 32, vox_vol: float = 1.0, eps: float = 1e-6, robust_percentage: Optional[float] = None, return_maps: bool = False) \ -> Union[List[PVStats], Tuple[List[PVStats], Dict[str, np.ndarray]]]: """Function to compute volume effects. @@ -864,14 +947,19 @@ def pv_calc(seg: npt.NDArray[_IntType], norm: np.ndarray, labels: Sequence[_IntT patch_size: Size of patches (default: 32) vox_vol: volume per voxel (default: 1.0) eps: threshold for computation of equality (default: 1e-6) - robust_keep_fraction: fraction for robust calculation of statistics (e.g. 0.95 drops both the 2.5% + robust_percentage: fraction for robust calculation of statistics (e.g. 0.95 drops both the 2.5% lowest and highest values per region) (default: None/1. == off) return_maps: returns a dictionary containing the computed maps. Returns: Table (list of dicts) with keys SegId, NVoxels, Volume_mm3, StructName, normMean, normStdDev, normMin, normMax, and normRange. (Note: StructName is unfilled) - if return_maps: a dictionary with the 5 meta-information pv-maps + if return_maps: a dictionary with the 5 meta-information pv-maps: + nbr: An image of alternative labels that were considered instead of the voxel's label + nbrmean: The local mean intensity of the label nbr at the specific voxel + segmean: The local mean intensity of the primary label at the specific voxel + pv: The partial volume of the primary label at the location + ipv: The partial volume of the alternative (nbr) label at the location """ mins, maxes, voxel_counts, __voxel_counts, sums, sums_2, volumes = [{} for _ in range(7)] @@ -881,16 +969,16 @@ def pv_calc(seg: npt.NDArray[_IntType], norm: np.ndarray, labels: Sequence[_IntT global_crop: Tuple[slice, ...] 
= tuple(slice(0, _shape) for _shape in seg.shape) # ignore all regions of the image that are background only if 0 not in labels: - # crop global_crop to the data - any_in_global, global_crop = crop_patch_to_mask(global_crop, seg != 0) + # crop global_crop to the data (plus one extra voxel) + any_in_global, global_crop = crop_patch_to_mask(seg != 0, sub_patch=global_crop) # grow global_crop by one, so all border voxels are included global_crop = grow_patch(global_crop, 1, seg.shape)[0] if not any_in_global: - raise RuntimeError("Segmentation map only consists background") + raise RuntimeError("Segmentation map only consists of background") global_stats_filled = partial(global_stats, norm=norm[global_crop], seg=seg[global_crop], - robust_keep_fraction=robust_keep_fraction) + robust_percentage=robust_percentage) from multiprocessing import cpu_count map_kwargs = {"chunksize": np.ceil(len(labels) / cpu_count())} @@ -924,8 +1012,7 @@ def pv_calc(seg: npt.NDArray[_IntType], norm: np.ndarray, labels: Sequence[_IntT # iterate through patches of the image patch_iters = [range(slice_.start, slice_.stop, patch_size) for slice_ in global_crop] # for 3D - if "chunksize" in map_kwargs: - map_kwargs["chunksize"] = np.ceil(len(voxel_counts) / cpu_count() / 4) # 4 chunks per core + map_kwargs["chunksize"] = np.ceil(len(voxel_counts) / cpu_count() / 4) # 4 chunks per core _patches = pool.map(partial(patch_filter, mask=border, global_crop=global_crop, patch_size=patch_size), product(*patch_iters), **map_kwargs) patches = (patch for has_pv_vox, patch in _patches if has_pv_vox) @@ -953,7 +1040,7 @@ def pv_calc(seg: npt.NDArray[_IntType], norm: np.ndarray, labels: Sequence[_IntT def global_stats(lab: _IntType, norm: npt.NDArray[_NumberType], seg: npt.NDArray[_IntType], - out: Optional[npt.NDArray[bool]] = None, robust_keep_fraction: Optional[float] = None) \ + out: Optional[npt.NDArray[bool]] = None, robust_percentage: Optional[float] = None) \ -> Union[Tuple[_IntType, int], Tuple[_IntType, int, int, _NumberType, _NumberType, float, float, float, npt.NDArray[bool]]]: """Computes Label, Number of voxels, 'robust' number of voxels, norm minimum, maximum, sum, sum of squares and @@ -970,9 +1057,9 @@ def global_stats(lab: _IntType, norm: npt.NDArray[_NumberType], seg: npt.NDArray else: out[:] = seg_borders(bin_array, True, cmp_dtype="int").astype(bool) - if robust_keep_fraction is not None: + if robust_percentage is not None: data = np.sort(data) - sym_drop_samples = int((1 - robust_keep_fraction / 2) * nvoxels) + sym_drop_samples = int((1 - robust_percentage / 2) * nvoxels) data = data[sym_drop_samples:-sym_drop_samples] _min: _NumberType = data[0].item() _max: _NumberType = data[-1].item() @@ -988,62 +1075,62 @@ def global_stats(lab: _IntType, norm: npt.NDArray[_NumberType], seg: npt.NDArray return lab, nvoxels, __voxel_count, _min, _max, _sum, sum_2, volume, out -def patch_filter(pos: Tuple[int, int, int], mask: Union[npt.NDArray[bool], BoolTensor], +def patch_filter(pos: Tuple[int, int, int], mask: Union[npt.NDArray[bool], 'torch.BoolTensor'], global_crop: Tuple[slice, ...], patch_size: int = 32) \ -> Tuple[bool, Sequence[slice]]: """Returns, whether there are mask-True voxels in the patch starting at pos with size patch_size and the resulting patch shrunk to mask-True regions.""" # create slices for current patch context (constrained by the global_crop) patch = [slice(p, min(p + patch_size, slice_.stop)) for p, slice_ in zip(pos, global_crop)] - pat_border = mask[tuple(patch)] # crop patch context to the image 
content - return crop_patch_to_mask(patch, mask=pat_border) + return crop_patch_to_mask(mask, sub_patch=patch) -def crop_patch_to_mask(patch: Sequence[slice], mask: Union[npt.NDArray[_NumberType], BoolTensor], - delayed_optim: bool = False) -> Tuple[bool, Sequence[slice]]: +def crop_patch_to_mask(mask: Union[npt.NDArray[_NumberType], 'torch.BoolTensor'], + sub_patch: Optional[Sequence[slice]] = None) \ + -> Tuple[bool, Sequence[slice]]: """Crop the patch to regions of the patch that are non-zero. Assumes mask is always positive. Returns whether there is any mask>0 in the patch and a patch shrunk to mask>0 regions. - If delayed_optim is True, more work will be done, but the code is less dependent on torch execution and should be - faster for torch. For numpy arrays, delayed_optim will always be False.""" - is_tensor = _HAS_TORCH and torch.is_tensor(mask) - delayed_optim = delayed_optim and is_tensor + Args: + mask: to crop to + sub_patch: subregion of mask to only consider (default: full mask) - def p_tup(__patch: Sequence[slice]) -> Tuple[slice, ...]: - if delayed_optim and is_tensor: - return tuple(slice(i.item(), j.item()) for i, j in __patch) - elif delayed_optim: - return tuple(__patch) - else: - return slice(None), + Note: + This function requires device synchronization.""" + + is_tensor = _HAS_TORCH and torch.is_tensor(mask) def axis_kw(axis): - return {"sim" if is_tensor else "axis": axis} + return {"dim" if is_tensor else "axis": axis} _patch = [] - _mask = mask.sum(**axis_kw(2)) - for i, pat in enumerate(patch): + patch = tuple([slice(0, s) for s in mask.shape] if sub_patch is None else sub_patch) + patch_in_patch_coords = tuple([slice(0, slice_.stop - slice_.start) for slice_ in patch]) + in_mask = True + _mask = mask[patch].sum(**axis_kw(2)) + for i, pat in enumerate(patch_in_patch_coords): p = pat.start - if i == 2: - _mask = mask[p_tup(_patch)].sum(**axis_kw(0)) - # can we shrink the patch context in i-th axis? - pat_has_mask_in_axis = _mask[p_tup(_patch[1:] if i != 2 else [])].sum(**axis_kw(0 if i == 1 else 1)) > 0 - # modify both the _patch_size and the coordinate p to shrink the patch - if is_tensor: - offset = pat_has_mask_in_axis.argwhere()[0] - p = offset.new_tensor(p) + offset - _patch_size = pat_has_mask_in_axis.argwhere()[-1] - offset + 1 - _patch.append((p, p + _patch_size)) + if in_mask: + if i == 2: + _mask = mask[patch][tuple(_patch)].sum(**axis_kw(0)) + # can we shrink the patch context in i-th axis? 
+ pat_has_mask_in_axis = _mask[tuple(_patch[1:] if i != 2 else [])].sum(**axis_kw(int(i == 0))) > 0 + # modify both the _patch_size and the coordinate p to shrink the patch + _pat_mask = pat_has_mask_in_axis.nonzero() if is_tensor else np.argwhere(pat_has_mask_in_axis) + if _pat_mask.shape[0] == 0: + _patch_size = 0 + in_mask = False + else: + offset = _pat_mask[0].item() + p += offset + _patch_size = _pat_mask[-1].item() - offset + 1 else: - offset = pat_has_mask_in_axis.argwhere()[0].item() - p += offset - _patch_size = pat_has_mask_in_axis.argwhere()[-1].item() - offset + 1 - _patch.append(slice(p, p + _patch_size)) - if is_tensor: - # optimization note: This is the point when the torch results are waited for if delayed_optim is True - _patch = [slice(i.item(), j.item()) for i, j in _patch] - return _patch[0].start != _patch[0].stop, _patch + _patch_size = 0 + _patch.append(slice(p, p + _patch_size)) + + out_patch = [slice(_p.start + p.start, p.start + _p.stop) for _p, p in zip(_patch, patch)] + return _patch[0].start != _patch[0].stop, out_patch def pv_calc_patch(patch: Tuple[slice, ...], global_crop: Tuple[slice, ...], @@ -1076,7 +1163,7 @@ def pv_calc_patch(patch: Tuple[slice, ...], global_crop: Tuple[slice, ...], pat_is_border, pat_is_nbr, pat_label_counts, pat_label_sums \ = patch_neighbors(label_lookup, norm, seg, pat_border, loc_border, patch_grow7, patch_in_gc, patch_shrink6, ungrow1_patch, ungrow7_patch, - ndarray_alloc=np.zeros, eps=eps) + ndarray_alloc=np.full, eps=eps) # both counts and sums are "normalized" by the local neighborhood size (15**3) label_lookup_fwd = np.zeros((maxlabels,), dtype="int") @@ -1108,7 +1195,7 @@ def pv_calc_patch(patch: Tuple[slice, ...], global_crop: Tuple[slice, ...], none_valid = ~is_valid.any(axis=0, keepdims=False) # select the label, that is valid or not valid but also exists and is not the current label - max_counts_index = (pat1d_label_counts * is_valid).round(log_eps).argmax(axis=0, keepdims=False) + max_counts_index = _round(pat1d_label_counts * is_valid, log_eps).argmax(axis=0, keepdims=False) nbr_label = label_lookup[max_counts_index] # label with max_counts nbr_label[none_valid] = 0 @@ -1148,110 +1235,138 @@ def pv_calc_patch(patch: Tuple[slice, ...], global_crop: Tuple[slice, ...], if _HAS_TORCH: @overload - def pv_calc_torch(seg: IntTensor, norm: Tensor, labels: Sequence[_IntType], patch_size: int = 32, + def pv_calc_torch(seg: 'torch.IntTensor', norm: 'torch.Tensor', labels: Sequence[_IntType], vox_vol: float = 1.0, eps: float = 1e-6, robust_percentage: Optional[float] = None, return_maps: False = False) -> List[PVStats]: ... @overload - def pv_calc_torch(seg: IntTensor, norm: Tensor, labels: Sequence[_IntType], patch_size: int = 32, + def pv_calc_torch(seg: 'torch.IntTensor', norm: 'torch.Tensor', labels: Sequence[_IntType], vox_vol: float = 1.0, eps: float = 1e-6, robust_percentage: Optional[float] = None, return_maps: True = True) \ -> Tuple[List[PVStats], Dict[str, Dict[int, np.ndarray]]]: ... - def pv_calc_torch(seg: IntTensor, norm: Tensor, labels: Sequence[_IntType], patch_size: int = 32, - vox_vol: float = 1.0, eps: float = 1e-6, robust_keep_fraction: Optional[float] = None, + def pv_calc_torch(seg: 'torch.IntTensor', norm: 'torch.Tensor', labels: Sequence[_IntType], + vox_vol: float = 1.0, eps: float = 1e-6, robust_percentage: Optional[float] = None, return_maps: bool = False) \ -> Union[List[PVStats], Tuple[List[PVStats], Dict[str, np.ndarray]]]: """torch-port of pv_calc(). 
As opposed to pv_calc, this function requires Tensors for seg and norm.""" - if not isinstance(seg, IntTensor) or not isinstance(norm, Tensor): + if not isinstance(seg, torch.Tensor) or not isinstance(norm, torch.Tensor): raise ValueError("Either seg or norm are not IntTensors or Tensors, respectively.") if seg.device != norm.device: raise ValueError("seg and norm are not on the same device.") - log_eps = -int(np.log10(eps)) - teps = seg.new_full(1, fill_value=eps, dtype=torch.float) - mins, maxes, voxel_counts, __voxel_counts, sums, sums_2, volumes = [{} for _ in range(7)] - loc_border = {} - - # initialize global_crop with the full image - global_crop: Tuple[slice, ...] = tuple(slice(0, _shape) for _shape in seg.shape) - # ignore all regions of the image that are background only - if 0 not in labels: - # crop global_crop to the data - any_in_global, global_crop = crop_patch_to_mask(global_crop, seg != 0) - # grow global_crop by one, so all border voxels are included - global_crop = grow_patch(global_crop, 1, seg.shape)[0] - if not any_in_global: - raise RuntimeError("Segmentation map only consists background") - - all_labels = torch.unique(seg) - max_label = all_labels[-1].item() + 1 - border = borders(seg, labels, max_label=max_label) - - global_stats_filled = partial(label_stats_torch, - global_crop=global_crop, norm=norm[global_crop], seg=seg[global_crop], - robust_keep_fraction=robust_keep_fraction, eps=eps) - # Maybe put this in a threading map? - label_stats = map(global_stats_filled, labels) - - means, counts = [torch.sparse_coo_tensor([], [], (max_label,) + seg.shape, - dtype=torch.float, device=seg.device) for _ in range(2)] - crop2label, this_6border = {}, {} - - for lab, *data in label_stats: - if data[0] != 0: - voxel_counts[lab], __voxel_counts[lab] = data[:2] - mins[lab], maxes[lab], sums[lab], sums_2[lab] = data[2:6] - volumes[lab], this_6border[lab] = data[6] * vox_vol, data[7] - crop2label[lab], indices = data[7:-2] - means[lab, indices], counts[lab, indices] = data[-2:] - - full_seg_mean = means[seg] - # 1. considered (mean of) alternative label must be on the other side of norm as the (mean of) the segmentation - # label of the current voxel - is_switched_sign = cast(BoolTensor, means > norm.unsqueeze(0)).logical_xor(cast(BoolTensor, full_seg_mean > norm).unsqueeze(0)) - # 3. (mean of) segmentation label must be different to norm of voxel - is_valid = is_switched_sign.logical_and(unsqueeze(full_seg_mean != norm, 0)) - - none_valid = ~is_valid.any(dim=0, keepdim=False) - # select the label, that is valid or not valid but also exists and is not the current label - full_nbr_label = (counts * is_valid).round(log_eps).argmax(dim=0, keepdim=False) - - full_nbr_label[none_valid] = 0 - full_nbr_label = full_nbr_label.to_dense() - - # get the mean label intensity of the "alternative label" - full_nbr_mean = torch.take_along_dim(means, unsqueeze(full_nbr_label, 0), dim=0)[0] - - # interpolate between the "local" and "alternative label" - mean_to_mean_nbr = full_seg_mean - full_nbr_mean - delta_gt_eps = mean_to_mean_nbr.abs() > teps - full_pv = (norm - full_nbr_mean) / mean_to_mean_nbr.where(delta_gt_eps, teps) # make sure no division by zero - - full_pv[~delta_gt_eps] = 1. # set pv fraction to 1 if division by zero - full_pv[none_valid] = 1. # set pv fraction to 1 for voxels that have no 'valid' nbr - full_pv[full_pv > 1.] = 1. - full_pv[full_pv < 0.] = 0. 
-
-        full_pv = full_pv.to_dense()
-
-        # re-create the "supposed" freesurfer inconsistency that does not count vertex neighbors, if the voxel label
-        # is not of question
-        mask_by_6border = full_pv.new_full(dtype=torch.bool)
-        for lab, crop in crop2label.items():
-            is_nbr = full_nbr_label[crop]
-            mask_by_6border[crop][is_nbr] = this_6border[is_nbr]
-
-        full_ipv = (1. - full_pv) * mask_by_6border
-
-        for lab in labels:
-            volumes[lab] += ((full_pv * (seg == lab)).sum() + (full_ipv * (full_nbr_label == lab)).sum()) * vox_vol
+        from concurrent.futures import ThreadPoolExecutor
+        with ThreadPoolExecutor() as tpe:
+            # executed in different thread as it causes host-device synchronization
+            labels_in_seg = tpe.submit(torch.unique, seg)
+
+            log_eps = -int(np.log10(eps))
+            teps = seg.new_full((1,), fill_value=eps, dtype=torch.float)
+            mins, maxes, voxel_counts, __voxel_counts, sums, sums_2, volumes = [{} for _ in range(7)]
+
+            # initialize global_crop with the full image
+            global_crop: Tuple[slice, ...] = tuple(slice(0, _shape) for _shape in seg.shape)
+            # ignore all regions of the image that are background only
+            if 0 not in labels:
+                # crop global_crop to the data
+                any_in_global, global_crop = crop_patch_to_mask(cast(torch.BoolTensor, seg != 0),
+                                                                sub_patch=global_crop)
+                # grow global_crop by one, so all border voxels are included
+                global_crop = grow_patch(global_crop, 1, seg.shape)[0]
+                if not any_in_global:
+                    raise RuntimeError("Segmentation map only consists of background")
+
+            labels_in_seg = labels_in_seg.result()
+            relevant_labels = sorted(set(labels_in_seg.cpu().tolist()).intersection([0] + list(labels)))
+            max_label = labels_in_seg[-1].item()
+            border = borders(seg, relevant_labels, max_label=max_label)
+            # executed in different thread as it causes host-device synchronization
+            norm_of_border = tpe.submit(torch.masked_select, norm, border)
+            if return_maps:
+                indices_of_border = tpe.submit(torch.nonzero, border.flatten())
+            indices_in_flat = border.to(torch.int, copy=True).flatten().cumsum_(0).sub_(1).reshape(border.shape)
+            num_border: int = indices_in_flat[tuple(-1 for _ in range(border.ndim))].item() + 1
+            norm_of_border = norm_of_border.result()
+
+            # allocate memory for the means and counts (first dim must be `labels_in_seg` for consistent results)
+            means_of_border = norm.new_full((labels_in_seg.shape[0], num_border), fill_value=0, dtype=torch.float)
+            counts_of_border = norm.new_full((labels_in_seg.shape[0], num_border), fill_value=0, dtype=torch.float)
+            mask_by_6border = norm.new_full((labels_in_seg.shape[0], num_border), fill_value=False, dtype=torch.bool)
+
+            label_stats_filled = partial(label_stats_torch, norm=norm[global_crop], seg=seg[global_crop],
+                                         indices=indices_in_flat[global_crop], border=border[global_crop],
+                                         robust_keep_fraction=robust_percentage, eps=eps)
+            # this map must be run on `labels_in_seg` (including labels that were not asked for) to keep results
+            # consistent between different selections of labels, i.e. the parameter `labels`
+            # TODO: tpe.map
+            label_stats = map(label_stats_filled, labels_in_seg, counts_of_border, means_of_border, mask_by_6border)
+            crop2label, this_6border = {}, {}
+
+            for lab, *data in label_stats:
+                if data[0] != 0:
+                    voxel_counts[lab], __voxel_counts[lab] = data[:2]
+                    mins[lab], maxes[lab], sums[lab], sums_2[lab] = data[2:6]
+                    volumes[lab], this_6border[lab] = data[6] * vox_vol, data[7]
+
+            seg_of_border = seg[border]
+            seg_mean_of_border = torch.gather(means_of_border, dim=0, index=labels_in_seg[seg_of_border])
+            if return_maps:
+                keys_of_maps = ["nbr", "segmean", "nbrmean", "pv", "ipv"]
+                dtypes, defaults = {'nbr': torch.int}, {"pv": 1}
+                maps = {key: seg.new_full(seg.shape,
+                                          fill_value=defaults.get(key, 0),
+                                          dtype=dtypes.get(key, torch.float)) for key in keys_of_maps}
+                indices_of_border = indices_of_border.result()
+                maps["segmean"][indices_of_border] = seg_mean_of_border
+
+            # 1. considered (mean of) alternative label must be on the other side of norm as the (mean of) the
+            # segmentation label of the current voxel
+            is_switched_sign = cast(torch.BoolTensor, means_of_border > norm_of_border.unsqueeze(0)) \
+                .logical_xor(cast(torch.BoolTensor, (seg_mean_of_border > norm_of_border).unsqueeze(0)))
+
+            # 3. (mean of) segmentation label must be different to norm of voxel
+            is_valid = is_switched_sign.logical_and(unsqueeze(seg_mean_of_border != norm_of_border, 0))
+            none_valid = ~is_valid.any(dim=0, keepdim=False)
+
+            # select the label, that is valid or not valid but also exists and is not the current label
+            nbr_label_of_border = _round(counts_of_border * is_valid, log_eps).argmax(dim=0, keepdim=False)
+            nbr_label_of_border[none_valid] = 0
+
+            # get the mean label intensity of the "alternative label"
+            nbr_mean_of_border = torch.take_along_dim(means_of_border, unsqueeze(nbr_label_of_border, 0), dim=0)[0]
+
+            if return_maps:
+                maps["nbr"][indices_of_border] = labels_in_seg[nbr_label_of_border]
+                maps["nbrmean"][indices_of_border] = nbr_mean_of_border
+
+            # interpolate between the "local" and "alternative label"
+            mean_to_mean_nbr = seg_mean_of_border - nbr_mean_of_border
+            delta_gt_eps = mean_to_mean_nbr.abs() > teps
+            # make sure no division by zero
+            pv_of_border = (norm_of_border - nbr_mean_of_border) / mean_to_mean_nbr.where(delta_gt_eps, teps)
+
+            pv_of_border[~delta_gt_eps] = 1.  # set pv fraction to 1 if division by zero
+            pv_of_border[none_valid] = 1.  # set pv fraction to 1 for voxels that have no 'valid' nbr
+            pv_of_border[pv_of_border > 1.] = 1.
+            pv_of_border[pv_of_border < 0.] = 0.
+
+            # re-create the "supposed" freesurfer inconsistency that does not count vertex neighbors, if the voxel
+            # label is not of question
+            ipv_of_border = (1. - pv_of_border) * mask_by_6border
+
+            if return_maps:
+                maps["pv"][indices_of_border] = pv_of_border
+                maps["ipv"][indices_of_border] = ipv_of_border
+
+            for lab in relevant_labels:
+                pv_vol = (pv_of_border * (seg_of_border == lab)).sum()
+                ipv_vol = (ipv_of_border * (nbr_label_of_border == lab)).sum()
+                volumes[lab] += (pv_vol + ipv_vol) * vox_vol
 
         # numpy algo
         # global_stats_filled = partial(global_stats_torch,
@@ -1295,15 +1410,15 @@ def pv_calc_torch(seg: IntTensor, norm: Tensor, labels: Sequence[_IntType], patc
         #     for lab in volumes.keys():
         #         volumes[lab] += vols.get(lab, 0.) * vox_vol
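+        # Worked example for the interpolation above (made-up numbers): a border voxel with
+        # intensity norm == 95, seg-label mean 110 and neighbor-label mean 80 gets
+        #   pv = (95 - 80) / (110 - 80) = 0.5
+        # for its own label (clamped to [0, 1]); the neighbor label receives ipv = 1 - pv
+        # wherever the 6-border mask is set.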
 
-        def item(voxel_counts: Dict[_IntType], lab: _IntType, default: _NumberType = 0) -> _NumberType:
-            v = voxel_counts.get(lab, 0)
-            if _HAS_TORCH and isinstance(v, Tensor):
+        def item(_arr: Dict[_IntType, Union[_NumberType, 'torch.Tensor']], lab: _IntType,
+                 default: _NumberType = 0) -> _NumberType:
+            v = _arr.get(lab, default)
+            if _HAS_TORCH and isinstance(v, torch.Tensor):
                 return v.item()
             return v
 
         means = {lab: s / __voxel_counts[lab] for lab, s in sums.items() if item(__voxel_counts, lab) > eps}
         # *std = sqrt((sum * (*mean) - 2 * (*mean) * sum + sum2) / (nvoxels - 1));
-        stds = {lab: np.sqrt((sums_2[lab] - means[lab] * sums[lab]) / (nvox - 1)) for lab, nvox in
+        stds = {lab: ((sums_2[lab] - means[lab] * sums[lab]) / (nvox - 1)).sqrt() for lab, nvox in
                 __voxel_counts.items() if nvox > 1}
         # ColHeaders Index SegId NVoxels Volume_mm3 StructName normMean normStdDev normMin normMax normRange
         table = [{"SegId": lab, "NVoxels": item(voxel_counts, lab), "Volume_mm3": item(volumes, lab, 0.),
@@ -1311,18 +1426,17 @@ def item(voxel_counts: Dict[_IntType], lab: _IntType, default: _NumberType = 0)
                   "normMin": item(mins, lab, 0.), "normMax": item(maxes, lab, 0.),
                   "normRange": item(maxes, lab, 0.) - item(mins, lab, 0.)} for lab in labels]
         if return_maps:
-            return table, {"nbr": full_nbr_label, "segmean": full_seg_mean, "nbrmean": full_nbr_mean, "pv": full_pv,
-                           "ipv": full_ipv}
+            return table, maps
         return table
 
-    def global_stats_torch(lab: _IntType, norm: Tensor, seg: IntTensor,
-                           out: Optional[BoolTensor] = None, robust_keep_fraction: Optional[float] = None) \
+    def global_stats_torch(lab: _IntType, norm: 'torch.Tensor', seg: 'torch.IntTensor',
+                           out: Optional['torch.BoolTensor'] = None, robust_keep_fraction: Optional[float] = None) \
             -> Union[Tuple[_IntType, int],
-                     Tuple[_IntType, int, int, _NumberType, _NumberType, float, float, float, BoolTensor]]:
+                     Tuple[_IntType, int, int, _NumberType, _NumberType, float, float, float, 'torch.BoolTensor']]:
         """Computes Label, Number of voxels, 'robust' number of voxels, norm minimum, maximum, sum, sum of squares and
         6-connected border of label lab."""
-        bin_array = cast(BoolTensor, seg == lab)
+        bin_array = cast(torch.BoolTensor, seg == lab)
         data = norm[bin_array].to(dtype=torch.int if np.issubdtype(norm.dtype, np.integer) else torch.float)
         nvoxels: int = data.shape[0]
         # if lab is not in the image at all
@@ -1353,18 +1467,19 @@ def global_stats_torch(lab: _IntType, norm: Tensor, seg: IntTensor,
 
     def pv_calc_patch_torch(patch: Tuple[slice, ...], global_crop: Tuple[slice, ...],
-                            loc_border: Dict[_IntType, BoolTensor],
-                            seg: IntTensor, norm: np.ndarray, border: BoolTensor,
-                            full_pv: Optional[FloatTensor] = None, full_ipv: Optional[FloatTensor] = None,
-                            full_nbr_label: Optional[IntTensor] = None,
-                            full_seg_mean: Optional[FloatTensor] = None,
-                            full_nbr_mean: Optional[FloatTensor] = None, eps: float = 1e-6) \
+                            loc_border: Dict[_IntType, 'torch.BoolTensor'],
+                            seg: 'torch.IntTensor', norm: np.ndarray, border: 'torch.BoolTensor',
+                            full_pv: Optional['torch.FloatTensor'] = None,
+                            full_ipv: Optional['torch.FloatTensor'] = None,
+                            full_nbr_label: Optional['torch.IntTensor'] = None,
+                            full_seg_mean: Optional['torch.FloatTensor'] = None,
+                            full_nbr_mean: Optional['torch.FloatTensor'] = None, eps: float = 1e-6) \
             -> Dict[_IntType, float]:
         """Calculates PV for patch. 
If full* keyword arguments are passed, also fills, per voxel results for the respective voxels in the patch.""" log_eps = -int(np.log10(eps)) - teps = seg.new_full(1, fill_value=eps, dtype=torch.float) + teps = seg.new_full((1,), fill_value=eps, dtype=torch.float) patch = tuple(patch) patch_grow1, ungrow1_patch = grow_patch(patch, 1, seg.shape) @@ -1393,7 +1508,7 @@ def pv_calc_patch_torch(patch: Tuple[slice, ...], global_crop: Tuple[slice, ...] pat1d_norm, pat1d_seg = norm[patch][pat_border], seg[patch][pat_border] pat1d_label_counts = pat_label_counts[:, pat_border] # both sums and counts are normalized by n-hood-size**3, so the output is not anymore - pat1d_label_means = (pat_label_sums[:, pat_border] / torch.maximum(pat1d_label_counts, teps)).round(log_eps) # float + pat1d_label_means = _round(pat_label_sums[:, pat_border] / torch.maximum(pat1d_label_counts, teps), log_eps) # float # get the mean label intensity of the "local label" mean_label = torch.take_along_dim(pat1d_label_means, unsqueeze(label_lookup_fwd[pat1d_seg], 0), dim=0)[0] @@ -1401,21 +1516,22 @@ def pv_calc_patch_torch(patch: Tuple[slice, ...], global_crop: Tuple[slice, ...] pat1d_is_this_6border = pat_is_border[:, pat_border] # calculate which classes to consider: is_valid = reduce(torch.logical_and, - # 1. considered (mean of) alternative label must be on the other side of norm as the (mean of) the segmentation - # label of the current voxel - [torch.logical_xor(pat1d_label_means > pat1d_norm.unsqueeze(0), (mean_label > pat1d_norm).unsqueeze(0)), - # 2. considered (mean of) alternative label must be different to norm of voxel - pat1d_label_means != pat1d_norm.unsqueeze(0), - # 3. (mean of) segmentation label must be different to norm of voxel - torch.broadcast_to(unsqueeze(mean_label != pat1d_norm, 0), pat1d_label_means.shape), - # 4. label must be a neighbor - pat_is_nbr[:, pat_border], - # 5. label must not be the segmentation - pat1d_seg.unsqueeze(0) != label_lookup.unsqueeze(1)]) + # 1. considered (mean of) alternative label must be on the other side of norm as the + # (mean of) the segmentation label of the current voxel + [torch.logical_xor(cast(torch.BoolTensor, pat1d_label_means > pat1d_norm.unsqueeze(0)), + cast(torch.BoolTensor, mean_label > pat1d_norm).unsqueeze(0)), + # 2. considered (mean of) alternative label must be different to norm of voxel + pat1d_label_means != pat1d_norm.unsqueeze(0), + # 3. (mean of) segmentation label must be different to norm of voxel + torch.broadcast_to(unsqueeze(mean_label != pat1d_norm, 0), pat1d_label_means.shape), + # 4. label must be a neighbor + pat_is_nbr[:, pat_border], + # 5. label must not be the segmentation + pat1d_seg.unsqueeze(0) != label_lookup.unsqueeze(1)]) none_valid = ~is_valid.any(dim=0, keepdim=False) # select the label, that is valid or not valid but also exists and is not the current label - max_counts_index = (pat1d_label_counts * is_valid).round(log_eps).argmax(dim=0, keepdim=False) + max_counts_index = _round(pat1d_label_counts * is_valid, log_eps).argmax(dim=0, keepdim=False) nbr_label = label_lookup[max_counts_index] # label with max_counts nbr_label[none_valid] = 0 @@ -1435,7 +1551,8 @@ def pv_calc_patch_torch(patch: Tuple[slice, ...], global_crop: Tuple[slice, ...] 
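+        # Illustration of the take_along_dim pattern used below (hypothetical shapes): with a
+        # (n_labels, n_voxels) table and one row index per voxel, e.g.
+        #   t = torch.tensor([[1, 0], [0, 1]]) > 0
+        #   torch.take_along_dim(t, torch.tensor([[1, 1]]), dim=0)[0]  # -> tensor([False,  True])
+        # each voxel selects the entry of its own (here: neighbor) label.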
         # re-create the "supposed" freesurfer inconsistency that does not count vertex neighbors, if the voxel label
         # is not of question
-        mask_by_6border = torch.take_along_dim(pat1d_is_this_6border, unsqueeze(label_lookup_fwd[nbr_label], 0), dim=0)[0]
+        mask_by_6border = torch.take_along_dim(pat1d_is_this_6border,
+                                               unsqueeze(label_lookup_fwd[nbr_label], 0), dim=0)[0]
 
         pat1d_inv_pv = (1. - pat1d_pv) * mask_by_6border
         if full_pv is not None:
@@ -1453,27 +1570,30 @@ def pv_calc_patch_torch(patch: Tuple[slice, ...], global_crop: Tuple[slice, ...]
                 label_lookup}
 
 
-    def label_stats_torch(label: _IntType, global_crop: Tuple[slice, ...],
-                          seg: IntTensor, norm: Tensor, robust_keep_fraction: Optional[float] = None,
-                          eps: float = 1e-6):
+    def label_stats_torch(label: _IntType, out_counts: 'torch.FloatTensor', out_means: 'torch.FloatTensor',
+                          out_6border: Optional['torch.BoolTensor'], seg: 'torch.IntTensor', norm: 'torch.Tensor',
+                          indices: 'torch.LongTensor', border: 'torch.BoolTensor',
+                          robust_keep_fraction: Optional[float] = None, eps: float = 1e-6):
         """Calculates PV for patch. If full* keyword arguments are passed, also fills, per voxel results for the respective
-        voxels in the patch."""
+        voxels in the patch.
+
+        Note:
+            This function has internal device synchronization points (e.g. indexing by bin_lab)"""
         # function constants
         log_eps = -int(np.log10(eps))
-        teps = seg.new_full(1, fill_value=eps, dtype=torch.float)
+        teps = seg.new_full((1,), fill_value=eps, dtype=torch.float)
 
-        bin_lab = cast(BoolTensor, seg == label)
-        any_label, crop2label = crop_patch_to_mask(global_crop, bin_lab, delayed_optim=True)
-        # if lab is not in the image at all
+        bin_lab = cast(torch.BoolTensor, seg == label)
+        any_label, crop2label = crop_patch_to_mask(bin_lab)
+        # if label is not in the image at all; when called from main(), this should never happen
         if not any_label:
             return label, 0
         crop2label, _ = grow_patch(crop2label, 1, seg.shape)
-        crop2label_offset = torch.as_tensor([sl.start for sl in crop2label], dtype=torch.int, device=seg.device)
-        crop2label_from_global_crop = tuple(slice(lc.start - gc.start, lc.stop - gc.start) for lc, gc in zip(crop2label, global_crop))
-        bin_cropped2label = bin_lab[crop2label_from_global_crop]
-        border = borders(bin_cropped2label, True)
-        intensities_of_label = norm[crop2label_from_global_crop][bin_cropped2label].to(dtype=torch.float if torch.is_floating_point(norm) else torch.int)
+        border_cropped = border[crop2label]
+        if out_6border is not None:
+            _6border = borders(bin_lab[crop2label], True)
+        intensities_of_label = norm[crop2label][bin_lab[crop2label]]
         nvoxels: int = intensities_of_label.shape[0]
         # compute/update the border
@@ -1481,36 +1601,43 @@ def label_stats_torch(label: _IntType, global_crop: Tuple[slice, ...],
             intensities_of_label = torch.sort(intensities_of_label)
             sym_drop_samples = int((1 - robust_keep_fraction / 2) * nvoxels)
             intensities_of_label = intensities_of_label[sym_drop_samples:-sym_drop_samples]
-            _min: _NumberType = intensities_of_label[0].cpu()
-            _max: _NumberType = intensities_of_label[-1].cpu()
+            _min: _NumberType = intensities_of_label[0]
+            _max: _NumberType = intensities_of_label[-1]
             __voxel_count = nvoxels - 2 * sym_drop_samples
         else:
-            _min = intensities_of_label.min().cpu()
-            _max = intensities_of_label.max().cpu()
+            _min = intensities_of_label.min()
+            _max = intensities_of_label.max()
             __voxel_count = nvoxels
-        _sum: float = intensities_of_label.sum().cpu()
-        sum_2: float = (intensities_of_label * intensities_of_label).sum().cpu()
+        _sum: float = 
intensities_of_label.sum() + sum_2: float = (intensities_of_label * intensities_of_label).sum() # this is independent of the robustness criterium - volume: float = torch.logical_and(bin_cropped2label, border.logical_not()).sum().to(dtype=torch.float).cpu() + volume: float = torch.logical_and(bin_lab[crop2label], border_cropped.logical_not()).sum().to(dtype=torch.float) # implicitly also a border detection: is lab a neighbor of the "current voxel" # lab is at least once a nbr in the patch (grown by one) - is_nbr = uniform_filter(bin_cropped2label, 3) > teps # as float (*filter_size**3) - label_counts = uniform_filter(bin_cropped2label, 15) # as float (*filter_size**3) - filtered_norm = norm[crop2label_from_global_crop] * bin_cropped2label - label_means = (uniform_filter(filtered_norm, 15) / label_counts).round(log_eps) + is_nbr = uniform_filter(bin_lab[crop2label], 3) > teps # as float (*filter_size**3) + label_counts = uniform_filter(bin_lab[crop2label], 15) # as float (*filter_size**3) + filtered_norm = norm[crop2label] * bin_lab[crop2label] + + label_means = _round(uniform_filter(filtered_norm, 15) / label_counts, log_eps) # conditions: - # 4. label must be a neighbor - # 5. label must not be the segmentation + # 4. label must be a neighbor (implicitly also border) + alt_mask = is_nbr.logical_and(border_cropped) # 2. considered (mean of) alternative label must be different to norm of voxel - mean_not_equal_to_norm = label_means != norm[crop2label_from_global_crop] - alt_mask = is_nbr.logical_and(bin_cropped2label.logical_not()).logical_and(mean_not_equal_to_norm) - indices = alt_mask.argwhere() + crop2label_offset - label_means_mask = label_means[alt_mask] - label_counts_mask = label_counts[alt_mask] + alt_mask = alt_mask.logical_and(label_means != norm[crop2label]) + # 5. label must not be the segmentation (bin_lab is where label is the segmentation) + alt_mask = alt_mask.logical_and(bin_lab[crop2label].logical_not()) + # for counts, we want all voxels that fulfill criteria 2., 4., 5. 
+        out_counts[indices[crop2label][alt_mask]] = label_counts[alt_mask]
 
-        return label, nvoxels, __voxel_count, _min, _max, _sum, sum_2, volume, border, crop2label, indices, label_means_mask, label_counts_mask
+        # for the means, we want all border voxels to be set
+        indices_border = indices[crop2label][border_cropped]
+        out_means[indices_border] = label_means[border_cropped]
+        if out_6border is not None:
+            out_6border[indices_border] = _6border[border_cropped]
+
+        return label, nvoxels, __voxel_count, _min, _max, _sum, sum_2, volume, out_counts, out_means, out_6border
 
 
 def patch_neighbors(labels, norm, seg, pat_border, loc_border, patch_grow7, patch_in_gc, patch_shrink6,
@@ -1521,10 +1648,11 @@ def patch_neighbors(labels, norm, seg, pat_border, loc_border, patch_grow7, patc
     pat_label_counts, pat_label_sums = ndarray_alloc((2,) + loc_shape, fill_value=0., dtype=float)
     pat_is_nbr, pat_is_border = ndarray_alloc((2,) + loc_shape, fill_value=False, dtype=bool)
     for i, lab in enumerate(labels):
-        pat7_bin_array = cast(BoolTensor, seg[patch_grow7] == lab)
+        pat7_bin_array = cast(torch.BoolTensor, seg[patch_grow7] == lab)
         # implicitly also a border detection: is lab a neighbor of the "current voxel"
         tmp_nbr_label_counts = uniform_filter(pat7_bin_array[patch_shrink6], 3)  # as float (*filter_size**3)
-        if _HAS_TORCH and isinstance(seg, Tensor) or tmp_nbr_label_counts.sum() > eps:  # TODO: check if this actually faster for torch than doing the if
+        # TODO: check if skipping this `if` is actually faster for torch
+        if (_HAS_TORCH and isinstance(seg, torch.Tensor)) or tmp_nbr_label_counts.sum() > eps:
             # lab is at least once a nbr in the patch (grown by one)
             if lab in loc_border:
                 pat_is_border[i] = loc_border[lab][patch_in_gc]
@@ -1545,5 +1673,6 @@
 #     main(make_arguments().parse_args('-norm $TSUB/mri/norm.mgz -i $TSUB/mri/wmparc.DKTatlas.mapped.mgz \
 #                                       -o $TSUB/stats/wmparc.DKTatlas.mapped.pyvstats --lut $FREESURFER/WMParcStatsLUT.txt'.split(' ')))"
     import sys
+
     args = make_arguments()
     sys.exit(main(args.parse_args()))
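For an end-to-end smoke test of the numpy code path, a sketch along the lines of the commented
example above (shapes, labels, and intensities are made up):

    import numpy as np
    from FastSurferCNN.segstats import pv_calc

    rng = np.random.default_rng(0)
    seg = np.zeros((16, 16, 16), dtype=np.int32)
    seg[4:12, 4:12, 4:12] = 2                  # outer synthetic structure
    seg[6:10, 6:10, 6:10] = 41                 # nested structure producing PV against label 2
    norm = rng.uniform(0., 110., seg.shape)
    table = pv_calc(seg, norm, labels=[2, 41], patch_size=8)
    for row in table:
        print(row["SegId"], row["NVoxels"], round(row["Volume_mm3"], 2))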