From c39f18c2326a6a75ec4ef13f51733d5a9f5df5f4 Mon Sep 17 00:00:00 2001 From: engrosamaali91 Date: Thu, 1 Feb 2024 10:09:42 +0100 Subject: [PATCH 1/7] modified numpy docstring continuation --- FastSurferCNN/data_loader/augmentation.py | 171 ++++++------ FastSurferCNN/data_loader/conform.py | 148 +++++----- FastSurferCNN/data_loader/data_utils.py | 323 +++++++++++----------- pyproject.toml | 2 +- 4 files changed, 324 insertions(+), 320 deletions(-) diff --git a/FastSurferCNN/data_loader/augmentation.py b/FastSurferCNN/data_loader/augmentation.py index 7499bb3e..a1688765 100644 --- a/FastSurferCNN/data_loader/augmentation.py +++ b/FastSurferCNN/data_loader/augmentation.py @@ -25,27 +25,28 @@ # Transformations for evaluation ## class ToTensorTest(object): - """Convert np.ndarrays in sample to Tensors. + """ + Convert np.ndarrays in sample to Tensors. Methods ------- __call__ - converts image + Converts image. """ def __call__(self, img: npt.NDArray) -> np.ndarray: - """Convert the image to float within range [0, 1] and make it torch compatible. + """ + Convert the image to float within range [0, 1] and make it torch compatible. Parameters ---------- img : npt.NDArray - Image to be converted + Image to be converted. Returns ------- img : np.ndarray - Conformed image - + Conformed image. """ img = img.astype(np.float32) @@ -61,37 +62,37 @@ def __call__(self, img: npt.NDArray) -> np.ndarray: class ZeroPad2DTest(object): - """Pad the input with zeros to get output size. + """ + Pad the input with zeros to get output size. Attributes ---------- output_size : Union[Number, Tuple[Number, Number]] - size of the output image either as Number or tuple of two Number + Size of the output image either as Number or tuple of two Number. pos : str - position to put the input + Position to put the input. Methods ------- pad - pad zeroes of image + Pad zeroes of image. call - call _pad() + Call _pad(). """ - def __init__( self, output_size: Union[Number, Tuple[Number, Number]], pos: str = 'top_left' ): - """Construct object. + """ + Construct object. Parameters ---------- output_size : Union[Number, Tuple[Number, Number]] - size of the output image either as Number or tuple of two Number + Size of the output image either as Number or tuple of two Number. pos : Union[Number, Tuple[Number, Number]] - position to put the input. Defaults to 'top_left' - + Position to put the input. Defaults to 'top_left'. """ if isinstance(output_size, Number): output_size = (int(output_size),) * 2 @@ -99,18 +100,18 @@ def __init__( self.pos = pos def _pad(self, image: npt.NDArray) -> np.ndarray: - """Pad with zeros of the input image. + """ + Pad with zeros of the input image. Parameters ---------- image : npt.NDArray - The image to pad + The image to pad. Returns ------- padded_img : np.ndarray - original image with padded zeros - + Original image with padded zeros. """ if len(image.shape) == 2: h, w = image.shape @@ -125,18 +126,18 @@ def _pad(self, image: npt.NDArray) -> np.ndarray: return padded_img def __call__(self, img: npt.NDArray) -> np.ndarray: - """Call the _pad() function. + """ + Call the _pad() function. Parameters ---------- img : npt.NDArray - the image to pad + The image to pad. Returns ------- img : np.ndarray - original image with padded zeros - + Original image with padded zeros. """ img = self._pad(img) @@ -147,28 +148,28 @@ def __call__(self, img: npt.NDArray) -> np.ndarray: # Transformations for training ## class ToTensor(object): - """Convert ndarrays in sample to Tensors. 
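As a quick usage sketch for the evaluation transforms above — the shapes are hypothetical and the exact dtype/axis handling lives in the class bodies:

```python
import numpy as np
from FastSurferCNN.data_loader.augmentation import ToTensorTest, ZeroPad2DTest

# Pad a 256x256 thick-slice stack up to 320x320, then convert for inference.
pad = ZeroPad2DTest(output_size=320, pos="top_left")
to_tensor = ToTensorTest()
thick_slices = np.random.rand(256, 256, 7).astype(np.float32)  # (H, W, thickness)
tensor_ready = to_tensor(pad(thick_slices))
```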
+ """ + Convert ndarrays in sample to Tensors. Methods ------- __call__ - Convert image - + Convert image. """ def __call__(self, sample: npt.NDArray) -> Dict[str, Any]: - """Convert the image to float within range [0, 1] and make it torch compatible. + """ + Convert the image to float within range [0, 1] and make it torch compatible. Parameters ---------- sample : npt.NDArray - sample image + Sample image. Returns ------- Dict[str, Any] - Converted image - + Converted image. """ img, label, weight, sf = ( sample["img"], @@ -196,39 +197,38 @@ def __call__(self, sample: npt.NDArray) -> Dict[str, Any]: class ZeroPad2D(object): - """Pad the input with zeros to get output size. + """ + Pad the input with zeros to get output size. Attributes ---------- output_size : Union[Number, Tuple[Number, Number]] - Size of the output image either as Number or tuple of two Number + Size of the output image either as Number or tuple of two Number. pos : str, Optional - Position to put the input + Position to put the input. Methods ------- _pad - Pads zeroes of image + Pads zeroes of image. __call__ - Cals _pad for sample - + Cals _pad for sample. """ - def __init__( self, output_size: Union[Number, Tuple[Number, Number]], pos: Union[None, str] = 'top_left' ): - """Initialize position and output_size (as Tuple[float]). + """ + Initialize position and output_size (as Tuple[float]). Parameters ---------- output_size : Union[Number, Tuple[Number, Number]] Size of the output image either as Number or - tuple of two Number + tuple of two Number. pos : str, Optional - Position to put the input. Default = 'top_left' - + Position to put the input. Default = 'top_left'. """ if isinstance(output_size, Number): output_size = (int(output_size),) * 2 @@ -236,18 +236,18 @@ def __init__( self.pos = pos def _pad(self, image: npt.NDArray) -> np.ndarray: - """Pad the input image with zeros. + """ + Pad the input image with zeros. Parameters ---------- image : npt.NDArray - The image to pad + The image to pad. Returns ------- padded_img : np.ndarray - Original image with padded zeros - + Original image with padded zeros. """ if len(image.shape) == 2: h, w = image.shape @@ -262,18 +262,18 @@ def _pad(self, image: npt.NDArray) -> np.ndarray: return padded_img def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Pad the image, label and weights. + """ + Pad the image, label and weights. Parameters ---------- - sample :Dict[str, Any] - Sample image + sample : Dict[str, Any] + Sample image. Returns ------- Dict[str, Any] - Dictionary including the padded image, label, weight and scale factor - + Dictionary including the padded image, label, weight and scale factor. """ img, label, weight, sf = ( sample["img"], @@ -290,48 +290,48 @@ def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]: class AddGaussianNoise(object): - """Add gaussian noise to sample. + """ + Add gaussian noise to sample. Attributes ---------- std - Standard deviation + Standard deviation. mean - Gaussian mean + Gaussian mean. Methods ------- __call__ - Adds noise to scale factor + Adds noise to scale factor. """ - def __init__(self, mean: Real = 0, std: Real = 0.1): - """Construct object. + """ + Construct object. Parameters ---------- mean : Real - Standard deviation. Default = 0 + Standard deviation. Default = 0. std : Real - Gaussian mean. Default = 0.1 - + Gaussian mean. Default = 0.1. """ self.std = std self.mean = mean def __call__(self, sample: Dict[str, Real]) -> Dict[str, Real]: - """Add gaussian noise to scalefactor. 
+ """ + Add gaussian noise to scalefactor. Parameters ---------- - sample :Dict[str, Real] - Sample data to add noise + sample : Dict[str, Real] + Sample data to add noise. Returns ------- Dict[str, Real] - Sample with noise - + Sample with noise. """ img, label, weight, sf = ( sample["img"], @@ -345,37 +345,36 @@ def __call__(self, sample: Dict[str, Real]) -> Dict[str, Real]: class AugmentationPadImage(object): - """Pad Image with either zero padding or reflection padding of img, label and weight. + """ + Pad Image with either zero padding or reflection padding of img, label and weight. Attributes ---------- pad_size_imag - [missing] + [missing]. pad_size_mask - [missing] + [missing]. Methods ------- __call - add zeroes - + Add zeroes. """ - def __init__( self, pad_size: Tuple[Tuple[int, int], Tuple[int, int]] = ((16, 16), (16, 16)), pad_type: str = "edge" ): - """Construct object. + """ + Construct object. Attributes ---------- pad_size - [MISSING] + [MISSING]. pad_type - [MISSING] - + [MISSING]. """ assert isinstance(pad_size, (int, tuple)) @@ -391,13 +390,13 @@ def __init__( self.pad_type = pad_type def __call__(self, sample: Dict[str, Number]): - """Pad zeroes of sample image, label and weight. + """ + Pad zeroes of sample image, label and weight. Attributes ---------- sample : Dict[str, Number] - Sample image and data - + Sample image and data. """ img, label, weight, sf = ( sample["img"], @@ -414,7 +413,9 @@ def __call__(self, sample: Dict[str, Number]): class AugmentationRandomCrop(object): - """Randomly Crop Image to given size.""" + """ + Randomly Crop Image to given size. + """ def __init__(self, output_size: Union[int, Tuple], crop_type: str = 'Random'): """Construct object. @@ -422,9 +423,9 @@ def __init__(self, output_size: Union[int, Tuple], crop_type: str = 'Random'): Attributes ---------- output_size - Size of the output image either an integer or a tuple + Size of the output image either an integer or a tuple. crop_type - [MISSING] + [MISSING]. """ assert isinstance(output_size, (int, tuple)) @@ -437,18 +438,18 @@ def __init__(self, output_size: Union[int, Tuple], crop_type: str = 'Random'): self.crop_type = crop_type def __call__(self, sample: Dict[str, Number]) -> Dict[str, Number]: - """Crops the augmentation. + """ + Crops the augmentation. Attributes ---------- sample : Dict[str, Number] - Sample image with data + Sample image with data. Returns ------- Dict[str, Number] - Cropped sample image - + Cropped sample image. """ img, label, weight, sf = ( sample["img"], diff --git a/FastSurferCNN/data_loader/conform.py b/FastSurferCNN/data_loader/conform.py index f167a6d3..6e283b3a 100644 --- a/FastSurferCNN/data_loader/conform.py +++ b/FastSurferCNN/data_loader/conform.py @@ -30,7 +30,9 @@ ) HELPTEXT = """ -Script to conform an MRI brain image to UCHAR, RAS orientation, and 1mm or minimal isotropic voxels +Script to conform an MRI brain image to UCHAR, RAS orientation, +and 1mm or minimal isotropic voxels + USAGE: conform.py -i -o OR @@ -51,13 +53,13 @@ def options_parse(): - """Command line option parser. + """ + Command line option parser. Returns ------- options - object holding options - + Object holding options. """ parser = argparse.ArgumentParser(usage=HELPTEXT) parser.add_argument( @@ -145,28 +147,28 @@ def map_image( order: int = 1, dtype: Optional[Type] = None ) -> np.ndarray: - """Map image to new voxel space (RAS orientation). + """ + Map image to new voxel space (RAS orientation). 
Parameters ---------- img : nib.analyze.SpatialImage - the src 3D image with data and affine set + The src 3D image with data and affine set. out_affine : np.ndarray - trg image affine + Trg image affine. out_shape : tuple[int, ...], np.ndarray - the trg shape information + The trg shape information. ras2ras : Optional[np.ndarray] - an additional mapping that should be applied (default=id to just reslice) + An additional mapping that should be applied (default=id to just reslice). order : int - order of interpolation (0=nearest,1=linear(default),2=quadratic,3=cubic) + Order of interpolation (0=nearest,1=linear(default),2=quadratic,3=cubic). dtype : Optional[Type] - target dtype of the resulting image (relevant for reorientation, default=same as img) + Target dtype of the resulting image (relevant for reorientation, default=same as img). Returns ------- np.ndarray - mapped image data array - + Mapped image data array. """ from scipy.ndimage import affine_transform from numpy.linalg import inv @@ -219,30 +221,30 @@ def getscale( f_low: float = 0.0, f_high: float = 0.999 ) -> Tuple[float, float]: - """Get offset and scale of image intensities to robustly rescale to range dst_min..dst_max. + """ + Get offset and scale of image intensities to robustly rescale to range dst_min..dst_max. Equivalent to how mri_convert conforms images. Parameters ---------- data : np.ndarray - image data (intensity values) + Image data (intensity values). dst_min : float - future minimal intensity value + Future minimal intensity value. dst_max : float - future maximal intensity value + Future maximal intensity value. f_low : float - robust cropping at low end (0.0 no cropping, default) + Robust cropping at low end (0.0 no cropping, default). f_high : float - robust cropping at higher end (0.999 crop one thousandth of high intensity voxels, default) + Robust cropping at higher end (0.999 crop one thousandth of high intensity voxels, default). Returns ------- float src_min - (adjusted) offset + (adjusted) offset. float - scale factor - + Scale factor. """ # get min and max from source src_min = np.min(data) @@ -318,26 +320,26 @@ def scalecrop( src_min: float, scale: float ) -> np.ndarray: - """Crop the intensity ranges to specific min and max values. + """ + Crop the intensity ranges to specific min and max values. Parameters ---------- data : np.ndarray - Image data (intensity values) + Image data (intensity values). dst_min : float - future minimal intensity value + Future minimal intensity value. dst_max : float - future maximal intensity value + Future maximal intensity value. src_min : float - minimal value to consider from source (crops below) + Minimal value to consider from source (crops below). scale : float - scale value by which source will be shifted + Scale value by which source will be shifted. Returns ------- np.ndarray - scaled image data - + Scaled image data. """ data_new = dst_min + scale * (data - src_min) @@ -357,26 +359,26 @@ def rescale( f_low: float = 0.0, f_high: float = 0.999 ) -> np.ndarray: - """Rescale image intensity values (0-255). + """ + Rescale image intensity values (0-255). Parameters ---------- data : np.ndarray - image data (intensity values) + Image data (intensity values). dst_min : float - future minimal intensity value + Future minimal intensity value. dst_max : float - future maximal intensity value + Future maximal intensity value. f_low : float - robust cropping at low end (0.0 no cropping, default) + Robust cropping at low end (0.0 no cropping, default). 
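For intuition, the combined effect of `getscale` plus `scalecrop` can be approximated by a percentile-based rescale — a simplified stand-in for the histogram-based implementation here, ignoring `f_low`:

```python
import numpy as np

def robust_rescale(data, dst_min=0.0, dst_max=255.0, f_high=0.999):
    src_min = float(data.min())
    src_max = float(np.quantile(data, f_high))  # crop the brightest 0.1 %
    scale = (dst_max - dst_min) / (src_max - src_min)
    return np.clip(dst_min + scale * (data - src_min), dst_min, dst_max)
```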
f_high : float - robust cropping at higher end (0.999 crop one thousandth of high intensity voxels, default) + Robust cropping at higher end (0.999 crop one thousandth of high intensity voxels, default). Returns ------- np.ndarray - scaled image data - + Scaled image data. """ src_min, scale = getscale(data, dst_min, dst_max, f_low, f_high) data_new = scalecrop(data, dst_min, dst_max, src_min, scale) @@ -384,24 +386,24 @@ def rescale( def find_min_size(img: nib.analyze.SpatialImage, max_size: float = 1) -> float: - """Find minimal voxel size <= 1mm. + """ + Find minimal voxel size <= 1mm. Parameters ---------- img : nib.analyze.SpatialImage - loaded source image + Loaded source image. max_size : float - maximal voxel size in mm (default: 1.0) + Maximal voxel size in mm (default: 1.0). Returns ------- float - Rounded minimal voxel size + Rounded minimal voxel size. Notes ----- This function only needs the header (not the data). - """ # find minimal voxel side length sizes = np.array(img.header.get_zooms()[:3]) @@ -415,18 +417,19 @@ def find_img_size_by_fov( vox_size: float, min_dim: int = 256 ) -> int: - """Find the cube dimension (>= 256) to cover the field of view of img. + """ + Find the cube dimension (>= 256) to cover the field of view of img. If vox_size is one, the img_size MUST always be min_dim (the FreeSurfer standard). Parameters ---------- img : nib.analyze.SpatialImage - loaded source image + Loaded source image. vox_size : float - the target voxel size in mm + The target voxel size in mm. min_dim : int - minimal image dimension in voxels (default 256) + Minimal image dimension in voxels (default 256). Returns ------- @@ -436,7 +439,6 @@ def find_img_size_by_fov( Notes ----- This function only needs the header (not the data). - """ if vox_size == 1.0: return min_dim @@ -467,31 +469,30 @@ def conform( Parameters ---------- img : nib.analyze.SpatialImage - loaded source image + Loaded source image. order : int - interpolation order (0=nearest,1=linear(default),2=quadratic,3=cubic) + Interpolation order (0=nearest,1=linear(default),2=quadratic,3=cubic). conform_vox_size : VoxSizeOption - conform image the image to voxel size 1. (default), a + Conform image the image to voxel size 1. (default), a specific smaller voxel size (0-1, for high-res), or automatically determine the 'minimum voxel size' from the image (value 'min'). This assumes the smallest of the three voxel sizes. dtype : Optional[Type] - the dtype to enforce in the image (default: UCHAR, as mri_convert -c) + The dtype to enforce in the image (default: UCHAR, as mri_convert -c). conform_to_1mm_threshold : Optional[float] - the threshold above which the image is conformed to 1mm + The threshold above which the image is conformed to 1mm (default: ignore). Returns ------- nib.MGHImage - conformed image + Conformed image. Notes ----- Unlike mri_convert -c, we first interpolate (float image), and then rescale to uchar. mri_convert is doing it the other way around. However, we compute the scale factor from the input to increase similarity. - """ from nibabel.freesurfer.mghformat import MGHHeader @@ -577,33 +578,34 @@ def is_conform( verbose: bool = True, conform_to_1mm_threshold: Optional[float] = None ) -> bool: - """Check if an image is already conformed or not. + """ + Check if an image is already conformed or not. Dimensions: 256x256x256, Voxel size: 1x1x1, LIA orientation, and data type UCHAR. Parameters ---------- img : nib.analyze.SpatialImage - Loaded source image + Loaded source image. 
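A typical call pattern for the `is_conform`/`conform` pair (the file paths are hypothetical):

```python
import nibabel as nib
from FastSurferCNN.data_loader.conform import conform, is_conform

img = nib.load("orig.nii.gz")  # hypothetical input
if not is_conform(img, conform_vox_size=1.0):
    img = conform(img, order=1, conform_vox_size=1.0)
nib.save(img, "conformed.mgz")
```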
conform_vox_size : VoxSizeOption - which voxel size to conform to. Can either be a float between 0.0 and + Which voxel size to conform to. Can either be a float between 0.0 and 1.0 or 'min' check, whether the image is conformed to the minimal voxels size, i.e. conforming to smaller, but isotropic voxel sizes for high-res (default: 1.0). eps : float - allowed deviation from zero for LIA orientation check (default: 1e-06). + Allowed deviation from zero for LIA orientation check (default: 1e-06). Small inaccuracies can occur through the inversion operation. Already conformed images are thus sometimes not correctly recognized. The epsilon accounts for these small shifts. check_dtype : bool - specifies whether the UCHAR dtype condition is checked for; + Specifies whether the UCHAR dtype condition is checked for; this is not done when the input is a segmentation (default: True). dtype : Optional[Type] - specifies the intended target dtype (default: uint8 = UCHAR) + Specifies the intended target dtype (default: uint8 = UCHAR). verbose : bool - if True, details of which conformance conditions are violated (if any) + If True, details of which conformance conditions are violated (if any) are displayed (default: True). conform_to_1mm_threshold : Optional[float] - the threshold above which the image is conformed to 1mm + The threshold above which the image is conformed to 1mm (default: ignore). Returns @@ -614,7 +616,6 @@ def is_conform( Notes ----- This function only needs the header (not the data). - """ conformed_vox_size, conformed_img_size = get_conformed_vox_img_size( img, conform_vox_size, conform_to_1mm_threshold=conform_to_1mm_threshold @@ -683,7 +684,8 @@ def get_conformed_vox_img_size( conform_vox_size: VoxSizeOption, conform_to_1mm_threshold: Optional[float] = None ) -> Tuple[float, int]: - """Extract the voxel size and the image size. + """ + Extract the voxel size and the image size. This function only needs the header (not the data). @@ -702,9 +704,9 @@ def get_conformed_vox_img_size( Returns ------- conformed_vox_size : float - The conformed voxel size of the image. + The determined voxel size to conform the image to. conformed_img_size : int - The conformed image size of the image. + The size of the image adjusted to the conformed voxel size. """ # this is similar to mri_convert --conform_min if isinstance(conform_vox_size, str) and conform_vox_size.lower() in [ @@ -730,7 +732,8 @@ def check_affine_in_nifti( img: Union[nib.Nifti1Image, nib.Nifti2Image], logger: Optional[logging.Logger] = None ) -> bool: - """Check the affine in nifti Image. + """ + Check the affine in nifti Image. Sets affine with qform, if it exists and differs from sform. If qform does not exist, voxel sizes between header information and information @@ -740,16 +743,15 @@ def check_affine_in_nifti( Parameters ---------- img : Union[nib.Nifti1Image, nib.Nifti2Image] - loaded nifti-image + Loaded nifti-image. logger : Optional[logging.Logger] Logger object or None (default) to log or print an info message to - stdout (for None) + stdout (for None). Returns ------- - False, if - voxel sizes in affine and header differ - + bool + False, if voxel sizes in affine and header differ. 
""" check = True message = "" diff --git a/FastSurferCNN/data_loader/data_utils.py b/FastSurferCNN/data_loader/data_utils.py index d1098e10..86ce14ee 100644 --- a/FastSurferCNN/data_loader/data_utils.py +++ b/FastSurferCNN/data_loader/data_utils.py @@ -54,37 +54,37 @@ def load_and_conform_image( logger: logging.Logger = LOGGER, conform_min: bool = False ) -> Tuple[_Header, np.ndarray, np.ndarray]: - """Load MRI image and conform it to UCHAR, RAS orientation and 1mm or minimum isotropic voxels size. + """ + Load MRI image and conform it to UCHAR, RAS orientation and 1mm or minimum isotropic voxels size. Only, if it does not already have this format. Parameters ---------- img_filename : str - path and name of volume to read + Path and name of volume to read. interpol : int - interpolation order for image conformation (0=nearest,1=linear(default),2=quadratic,3=cubic) + Interpolation order for image conformation (0=nearest,1=linear(default),2=quadratic,3=cubic). logger : logging.Logger - Logger to write output to (default = STDOUT) + Logger to write output to (default = STDOUT). conform_min : bool - conform image to minimal voxel size (for high-res) (Default = False) + Conform image to minimal voxel size (for high-res) (Default = False). Returns ------- nibabel.Header header_info - header information of the conformed image + Header information of the conformed image. numpy.ndarray affine_info - affine information of the conformed image + Affine information of the conformed image. numpy.ndarray orig_data - conformed image data + Conformed image data. Raises ------ RuntimeError - Multiple input frames not supported + Multiple input frames not supported. RuntimeError - Inconsistency in nifti-header - + Inconsistency in nifti-header. """ orig = nib.load(img_filename) # is_conform and conform accept numeric values and the string 'min' instead of the bool value @@ -121,15 +121,17 @@ def load_image( name: str = "image", **kwargs ) -> Tuple[nib.analyze.SpatialImage, np.ndarray]: - """Load file 'file' with nibabel, including all data. + """ + Load file 'file' with nibabel, including all data. Parameters ---------- file : str - path to the file to load. + Path to the file to load. name : str - name of the file (optional), only effects error messages. (Default value = "image") + Name of the file (optional), only effects error messages. (Default value = "image"). **kwargs : + Additional keyword arguments. Returns ------- @@ -149,7 +151,6 @@ def load_image( image, data = future1.result() image2, data2 = future2.result() } - """ try: img = nib.load(file, **kwargs) @@ -166,22 +167,22 @@ def load_maybe_conform( alt_file: str, vox_size: VoxSizeOption = "min" ) -> Tuple[str, nib.analyze.SpatialImage, np.ndarray]: - """Load an image by file, check whether it is conformed to vox_size and conform to vox_size if it is not. + """ + Load an image by file, check whether it is conformed to vox_size and conform to vox_size if it is not. Parameters ---------- file : str - path to the file to load. + Path to the file to load. alt_file : str - alternative file to interpolate from + Alternative file to interpolate from. vox_size : VoxSizeOption - Voxel Size (Default value = "min") + Voxel Size (Default value = "min"). Returns ------- Tuple[str, nib.analyze.SpatialImage, np.ndarray] - [MISSING] - + [MISSING]. """ from os.path import isfile @@ -243,24 +244,24 @@ def save_image( save_as: str, dtype: Optional[npt.DTypeLike] = None ) -> None: - """Save an image (nibabel MGHImage), according to the desired output file format. 
+ """ + Save an image (nibabel MGHImage), according to the desired output file format. - Supported formats are defined in supported_output_file_formats. Saves predictions to save_as + Supported formats are defined in supported_output_file_formats. Saves predictions to save_as. Parameters ---------- header_info : _Header - image header information + Image header information. affine_info : npt.NDArray - image affine information + Image affine information. img_array : npt.NDArray - an array containing image data + An array containing image data. save_as : str - name under which to save prediction; this determines output file format + Name under which to save prediction; this determines output file format. dtype : Optional[npt.DTypeLike] - image array type; if provided, the image object is explicitly set to match this type - (Default value = None) - + Image array type; if provided, the image object is explicitly set to match this type + (Default value = None). """ assert any( save_as.endswith(file_ext) for file_ext in SUPPORTED_OUTPUT_FILE_FORMATS @@ -291,20 +292,20 @@ def transform_axial( vol: npt.NDArray, coronal2axial: bool = True ) -> np.ndarray: - """Transform volume into Axial axis and back. + """ + Transform volume into Axial axis and back. Parameters ---------- vol : npt.NDArray - image volume to transform + Image volume to transform. coronal2axial : bool - transform from coronal to axial = True (default), + Transform from coronal to axial = True (default). Returns ------- np.ndarray - Transformed image - + Transformed image. """ if coronal2axial: return np.moveaxis(vol, [0, 1, 2], [1, 2, 0]) @@ -316,20 +317,20 @@ def transform_sagittal( vol: npt.NDArray, coronal2sagittal: bool = True ) -> np.ndarray: - """Transform volume into Sagittal axis and back. + """ + Transform volume into Sagittal axis and back. Parameters ---------- vol : npt.NDArray - image volume to transform + Image volume to transform. coronal2sagittal : bool - transform from coronal to sagittal = True (default), + Transform from coronal to sagittal = True (default). Returns ------- np.ndarray: - transformed image - + Transformed image. """ if coronal2sagittal: return np.moveaxis(vol, [0, 1, 2], [2, 1, 0]) @@ -342,23 +343,23 @@ def get_thick_slices( img_data: npt.NDArray, slice_thickness: int = 3 ) -> np.ndarray: - """Extract thick slices from the image. + """ + Extract thick slices from the image. Feed slice_thickness preceding and succeeding slices to network, - label only middle one + label only middle one. Parameters ---------- img_data : npt.NDArray - 3D MRI image read in with nibabel + 3D MRI image read in with nibabel. slice_thickness : int - number of slices to stack on top and below slice of interest (default=3) + Number of slices to stack on top and below slice of interest (default=3). Returns ------- np.ndarray - image data with the thick slices of the n-th axis appended into the n+1-th axis. - + Image data with the thick slices of the n-th axis appended into the n+1-th axis. """ img_data_pad = np.pad( img_data, ((0, 0), (0, 0), (slice_thickness, slice_thickness)), mode="edge" @@ -376,28 +377,28 @@ def filter_blank_slices_thick( weight_vol: npt.NDArray, threshold: int = 50 ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """Filter blank slices from the volume using the label volume. + """ + Filter blank slices from the volume using the label volume. Parameters ---------- img_vol : npt.NDArray - orig image volume + Orig image volume. 
label_vol : npt.NDArray - label images (ground truth) + Label images (ground truth). weight_vol : npt.NDArray - weight corresponding to labels + Weight corresponding to labels. threshold : int - threshold for number of pixels needed to keep slice (below = dropped). (Default value = 50) + Threshold for number of pixels needed to keep slice (below = dropped). (Default value = 50). Returns ------- filtered img_vol : np.ndarray - [MISSING] + [MISSING]. label_vol : np.ndarray - [MISSING] + [MISSING]. weight_vol : np.ndarray - [MISSING] - + [MISSING]. """ # Get indices of all slices with more than threshold labels/pixels select_slices = np.sum(label_vol, axis=(0, 1)) > threshold @@ -421,32 +422,32 @@ def create_weight_mask( cortex_mask: bool = True, gradient: bool = True ) -> np.ndarray: - """Create weighted mask - with median frequency balancing and edge-weighting. + """ + Create weighted mask - with median frequency balancing and edge-weighting. Parameters ---------- mapped_aseg : np.ndarray - segmentation to create weight mask from. + Segmentation to create weight mask from. max_weight : int - maximal weight on median weights (cap at this value). (Default value = 5) + Maximal weight on median weights (cap at this value). (Default value = 5). max_edge_weight : int - maximal weight on gradient weight (cap at this value). (Default value = 5) + Maximal weight on gradient weight (cap at this value). (Default value = 5). max_hires_weight : int - maximal weight on hires weight (cap at this value). (Default value = None) + Maximal weight on hires weight (cap at this value). (Default value = None). ctx_thresh : int - label value of cortex (above = cortical parcels). (Default value = 33) + Label value of cortex (above = cortical parcels). (Default value = 33). mean_filter : bool - flag, set to add mean_filter mask (default = False). + Flag, set to add mean_filter mask (default = False). cortex_mask : bool - flag, set to create cortex weight mask (default=True). + Flag, set to create cortex weight mask (default=True). gradient : bool - flag, set to create gradient mask (default = True). + Flag, set to create gradient mask (default = True). Returns ------- np.ndarray - Weights - + Weights. """ unique, counts = np.unique(mapped_aseg, return_counts=True) @@ -495,22 +496,22 @@ def cortex_border_mask( structure: npt.NDArray, ctx_thresh: int = 33 ) -> np.ndarray: - """Erode the cortex of a given mri image to create the inner gray matter mask (outer most cortex voxels). + """ + Erode the cortex of a given mri image to create the inner gray matter mask (outer most cortex voxels). Parameters ---------- label : npt.NDArray - ground truth labels. + Ground truth labels. structure : npt.NDArray - structuring element to erode with + Structuring element to erode with. ctx_thresh : int - label value of cortex (above = cortical parcels). Defaults to 33 + Label value of cortex (above = cortical parcels). Defaults to 33. Returns ------- np.ndarray - inner grey matter layer - + Inner grey matter layer. """ # create aseg brainmask, erode it and subtract from itself bm = np.clip(label, a_max=1, a_min=0) @@ -529,24 +530,24 @@ def deep_sulci_and_wm_strand_mask( iteration: int = 1, ctx_thresh: int = 33 ) -> np.ndarray: - """Get a binary mask of deep sulci and small white matter strands by using binary closing (erosion and dilation). + """ + Get a binary mask of deep sulci and small white matter strands by using binary closing (erosion and dilation). 
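The median-frequency-balancing core of `create_weight_mask`, sketched without the edge-gradient and hires terms:

```python
import numpy as np

def median_freq_weights(mapped_aseg, max_weight=5):
    classes, counts = np.unique(mapped_aseg, return_counts=True)
    class_wise = np.clip(np.median(counts) / counts, 0, max_weight)
    weights = np.zeros(mapped_aseg.shape, dtype=float)
    for c, w in zip(classes, class_wise):
        weights[mapped_aseg == c] = w
    return weights  # rare classes end up with weights near max_weight
```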
Parameters ---------- volume : npt.NDArray - loaded image (aseg, label space) + Loaded image (aseg, label space). structure : npt.NDArray - structuring element (e.g. np.ones((3, 3, 3))) + Structuring element (e.g. np.ones((3, 3, 3))). iteration : int - number of times mask should be dilated + eroded. Defaults to 1 + Number of times mask should be dilated + eroded. Defaults to 1. ctx_thresh : int - label value of cortex (above = cortical parcels). Defaults to 33 + Label value of cortex (above = cortical parcels). Defaults to 33. Returns ------- np.ndarray - sulcus + wm mask - + Sulcus + wm mask. """ # Binarize label image (cortex = 1, everything else = 0) empty_im = np.zeros(shape=volume.shape) @@ -565,22 +566,22 @@ def deep_sulci_and_wm_strand_mask( # Label mapping functions (to aparc (eval) and to label (train)) def read_classes_from_lut(lut_file: str) -> pd.DataFrame: - """Read in FreeSurfer-like LUT table. + """ + Read in FreeSurfer-like LUT table. Parameters ---------- lut_file : str - path and name of FreeSurfer-style LUT file with classes of interest + Path and name of FreeSurfer-style LUT file with classes of interest Example entry: ID LabelName R G B A 0 Unknown 0 0 0 0 - 1 Left-Cerebral-Exterior 70 130 180 0 + 1 Left-Cerebral-Exterior 70 130 180 0. Returns ------- pd.Dataframe - DataFrame with ids present, name of ids, color for plotting - + DataFrame with ids present, name of ids, color for plotting. """ # Read in file separator = {"tsv": "\t", "csv": ",", "txt": " "} @@ -591,20 +592,20 @@ def map_label2aparc_aseg( mapped_aseg: torch.Tensor, labels: Union[torch.Tensor, npt.NDArray] ) -> torch.Tensor: - """Perform look-up table mapping from sequential label space to LUT space. + """ + Perform look-up table mapping from sequential label space to LUT space. Parameters ---------- mapped_aseg : torch.Tensor - label space segmentation (aparc.DKTatlas + aseg) + Label space segmentation (aparc.DKTatlas + aseg). labels : Union[torch.Tensor, npt.NDArray] - list of labels defining LUT space + List of labels defining LUT space. Returns ------- torch.Tensor - labels in LUT space - + Labels in LUT space. """ if isinstance(labels, np.ndarray): labels = torch.from_numpy(labels) @@ -613,24 +614,24 @@ def map_label2aparc_aseg( def clean_cortex_labels(aparc: npt.NDArray) -> np.ndarray: - """Clean up aparc segmentations. + """ + Clean up aparc segmentations. Map undetermined and optic chiasma to BKG Map Hypointensity classes to one Vessel to WM 5th Ventricle to CSF - Remaining cortical labels to BKG + Remaining cortical labels to BKG. Parameters ---------- aparc : npt.NDArray - aparc segmentations + Aparc segmentations. Returns ------- np.ndarray - cleaned aparc - + Cleaned aparc. """ aparc[aparc == 80] = 77 # Hypointensities Class aparc[aparc == 85] = 0 # Optic Chiasma to BKG @@ -650,22 +651,22 @@ def fill_unknown_labels_per_hemi( unknown_label: int, cortex_stop: int ) -> np.ndarray: - """Replace label 1000 (lh unknown) and 2000 (rh unknown) with closest class for each voxel. + """ + Replace label 1000 (lh unknown) and 2000 (rh unknown) with closest class for each voxel. Parameters ---------- gt : npt.NDArray - ground truth segmentation with class unknown + Ground truth segmentation with class unknown. unknown_label : int - class label for unknown (lh: 1000, rh: 2000) + Class label for unknown (lh: 1000, rh: 2000). cortex_stop : int - class label at which cortical labels of this hemi stop (lh: 2000, rh: 3000) + Class label at which cortical labels of this hemi stop (lh: 2000, rh: 3000). 
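`map_label2aparc_aseg` boils down to a tensor lookup; the principle with toy labels (not the real DKT LUT):

```python
import torch

labels = torch.tensor([0, 2, 41])      # sequential index -> LUT label
pred = torch.tensor([[0, 1], [2, 0]])  # network output, sequential label space
print(labels[pred.long()])             # tensor([[ 0,  2], [41,  0]])
```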
Returns ------- np.ndarray - ground truth segmentation with all classes - + Ground truth segmentation with all classes. """ # Define shape of image and dilation element h, w, d = gt.shape @@ -700,18 +701,18 @@ class label at which cortical labels of this hemi stop (lh: 2000, rh: 3000) def fuse_cortex_labels(aparc: npt.NDArray) -> np.ndarray: - """Fuse cortical parcels on left/right hemisphere (reduce aparc classes). + """ + Fuse cortical parcels on left/right hemisphere (reduce aparc classes). Parameters ---------- aparc : npt.NDArray - anatomical segmentation with cortical parcels + Anatomical segmentation with cortical parcels. Returns ------- np.ndarray - anatomical segmentation with reduced number of cortical parcels - + Anatomical segmentation with reduced number of cortical parcels. """ aparc_temp = aparc.copy() @@ -748,18 +749,18 @@ def fuse_cortex_labels(aparc: npt.NDArray) -> np.ndarray: def split_cortex_labels(aparc: npt.NDArray) -> np.ndarray: - """Splot cortex labels to completely de-lateralize structures. + """ + Splot cortex labels to completely de-lateralize structures. Parameters ---------- aparc : npt.NDArray - anatomical segmentation and parcellation from network + Anatomical segmentation and parcellation from network. Returns ------- np.ndarray - re-lateralized aparc - + Re-lateralized aparc. """ # Post processing - Splitting classes # Quick Fix for 2026 vs 1026; 2029 vs. 1029; 2025 vs. 1025 @@ -840,24 +841,24 @@ def unify_lateralized_labels( lut: Union[str, pd.DataFrame], combi: Tuple[str, str] = ("Left-", "Right-") ) -> Mapping: - """Generate lookup dictionary of left-right labels. + """ + Generate lookup dictionary of left-right labels. Parameters ---------- lut : Union[str, pd.DataFrame] - either lut-file string to load or pandas dataframe + Either lut-file string to load or pandas dataframe Example entry: ID LabelName R G B A 0 Unknown 0 0 0 0 - 1 Left-Cerebral-Exterior 70 130 180 0 + 1 Left-Cerebral-Exterior 70 130 180 0. combi : Tuple[str, str] - Prefix or labelnames to combine. Default: Left- and Right- + Prefix or labelnames to combine. Default: Left- and Right-. Returns ------- Mapping - dictionary mapping between left and right hemispheres - + Dictionary mapping between left and right hemispheres. """ if isinstance(lut, str): lut = read_classes_from_lut(lut) @@ -873,7 +874,8 @@ def get_labels_from_lut( lut: Union[str, pd.DataFrame], label_extract: Tuple[str, str] = ("Left-", "ctx-rh") ) -> Tuple[np.ndarray, np.ndarray]: - """Extract labels from the lookup tables. + """ + Extract labels from the lookup tables. Parameters ---------- @@ -883,18 +885,17 @@ def get_labels_from_lut( Example entry: ID LabelName R G B A 0 Unknown 0 0 0 0 - 1 Left-Cerebral-Exterior 70 130 180 0 + 1 Left-Cerebral-Exterior 70 130 180 0. label_extract : Tuple[str, str] - suffix of label names to mask for sagittal labels - Default: "Left-" and "ctx-rh" + Suffix of label names to mask for sagittal labels + Default: "Left-" and "ctx-rh". Returns ------- np.ndarray - full label list, + Full label list. np.ndarray - sagittal label list - + Sagittal label list. """ if isinstance(lut, str): lut = read_classes_from_lut(lut) @@ -910,32 +911,32 @@ def map_aparc_aseg2label( aseg_nocc: Optional[npt.NDArray] = None, processing: str = "aparc" ) -> Tuple[np.ndarray, np.ndarray]: - """Perform look-up table mapping of aparc.DKTatlas+aseg.mgz data to label space. + """ + Perform look-up table mapping of aparc.DKTatlas+aseg.mgz data to label space. 
Parameters ---------- aseg : npt.NDArray - ground truth aparc+aseg + Ground truth aparc+aseg. labels : npt.NDArray - labels to use (extracted from LUT with get_labels_from_lut) + Labels to use (extracted from LUT with get_labels_from_lut). labels_sag : npt.NDArray - sagittal labels to use (extracted from LUT with - get_labels_from_lut) + Sagittal labels to use (extracted from LUT with + get_labels_from_lut). sagittal_lut_dict : Mapping - left-right label mapping (can be extracted with - unify_lateralized_labels from LUT) + Left-right label mapping (can be extracted with + unify_lateralized_labels from LUT). aseg_nocc : Optional[npt.NDArray] - ground truth aseg without corpus callosum segmentation (Default value = None) + Ground truth aseg without corpus callosum segmentation (Default value = None). processing : str - should be set to "aparc" or "aseg" for additional mappings (hard-coded) (Default value = "aparc") + Should be set to "aparc" or "aseg" for additional mappings (hard-coded) (Default value = "aparc"). Returns ------- np.ndarray - mapped aseg for coronal and axial, + Mapped aseg for coronal and axial. np.ndarray - mapped aseg for sagital - + Mapped aseg for sagital. """ # If corpus callosum is not removed yet, do it now if aseg_nocc is not None: @@ -1001,18 +1002,18 @@ def map_aparc_aseg2label( def sagittal_coronal_remap_lookup(x: int) -> int: - """Convert left labels to corresponding right labels for aseg with dictionary mapping. + """ + Convert left labels to corresponding right labels for aseg with dictionary mapping. Parameters ---------- x : int - label to look up + Label to look up. Returns ------- np.ndarray - mapped label - + Mapped label. """ return { 2: 41, @@ -1037,20 +1038,20 @@ def infer_mapping_from_lut( num_classes_full: int, lut: Union[str, pd.DataFrame] ) -> np.ndarray: - """[MISSING]. + """ + [MISSING]. Parameters ---------- num_classes_full : int - number of classes + Number of classes. lut : Union[str, pd.DataFrame] - look-up table listing class labels + Look-up table listing class labels. Returns ------- np.ndarray - list of indexes for - + List of indexes for. """ labels, labels_sag = unify_lateralized_labels(lut) idx_list = np.ndarray(shape=(num_classes_full,), dtype=np.int16) @@ -1072,24 +1073,24 @@ def map_prediction_sagittal2full( num_classes: int = 51, lut: Optional[str] = None ) -> np.ndarray: - """Remap the prediction on the sagittal network to full label space used by coronal and axial networks. + """ + Remap the prediction on the sagittal network to full label space used by coronal and axial networks. Create full aparc.DKTatlas+aseg.mgz. Parameters ---------- prediction_sag : npt.NDArray - sagittal prediction (labels) + Sagittal prediction (labels). num_classes : int - number of SAGITTAL classes (96 for full classes, 51 for hemi split, 21 for aseg) (Default value = 51) + Number of SAGITTAL classes (96 for full classes, 51 for hemi split, 21 for aseg) (Default value = 51). lut : Optional[str] - look-up table listing class labels (Default value = None) + Look-up table listing class labels (Default value = None). Returns ------- np.ndarray - Remapped prediction - + Remapped prediction. """ if num_classes == 96: idx_list = np.asarray( @@ -1334,28 +1335,28 @@ def map_prediction_sagittal2full( def bbox_3d( img: npt.NDArray ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - """Extract the three-dimensional bounding box coordinates. + """ + Extract the three-dimensional bounding box coordinates. 
Parameters ---------- img : npt.NDArray - mri image + Mri image. Returns ------- np.ndarray - rmin + Rmin. np.ndarray - rmax + Rmax. np.ndarray - cmin + Cmin. np.ndarray - cmax + Cmax. np.ndarray - zmin + Zmin. np.ndarray - zmax - + Zmax. """ r = np.any(img, axis=(1, 2)) c = np.any(img, axis=(0, 2)) @@ -1369,18 +1370,18 @@ def bbox_3d( def get_largest_cc(segmentation: npt.NDArray) -> np.ndarray: - """Find the largest connected component of segmentation. + """ + Find the largest connected component of segmentation. Parameters ---------- segmentation : npt.NDArray - segmentation + Segmentation. Returns ------- np.ndarray - largest connected component of segmentation (binary mask) - + Largest connected component of segmentation (binary mask). """ labels = label(segmentation, connectivity=3, background=0) diff --git a/pyproject.toml b/pyproject.toml index e4126347..8b890d09 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,7 @@ dependencies = [ 'nibabel>=3.2.2', 'numpy>=1.21', 'pandas>=1.4.3', - 'pytorch>=1.12.0', + 'torch>=1.12.0', 'pyyaml>=6.0', 'scipy>=1.8.0', 'yacs>=0.1.8', From 2a56581a94a00c843eed9f7de56afcb4ca60facc Mon Sep 17 00:00:00 2001 From: engrosamaali91 Date: Thu, 1 Feb 2024 12:14:26 +0100 Subject: [PATCH 2/7] Missing docstrings added in new files --- FastSurferCNN/data_loader/dataset.py | 130 +++++++++++++++------------ FastSurferCNN/data_loader/loader.py | 11 ++- 2 files changed, 78 insertions(+), 63 deletions(-) diff --git a/FastSurferCNN/data_loader/dataset.py b/FastSurferCNN/data_loader/dataset.py index 1434bc7b..ec02aff0 100644 --- a/FastSurferCNN/data_loader/dataset.py +++ b/FastSurferCNN/data_loader/dataset.py @@ -31,7 +31,9 @@ # Operator to load imaged for inference class MultiScaleOrigDataThickSlices(Dataset): - """Load MRI-Image and process it to correct format for network inference.""" + """ + Load MRI-Image and process it to correct format for network inference. + """ def __init__( self, @@ -40,19 +42,19 @@ def __init__( cfg: yacs.config.CfgNode, transforms: Optional = None ): - """Construct object. + """ + Construct object. Parameters ---------- orig_data : npt.NDArray - Orignal Data + Orignal Data. orig_zoom : npt.NDArray - Original zoomfactors + Original zoomfactors. cfg : yacs.config.CfgNode - Configuration Node + Configuration Node. transforms : Optional - Transformer for the image. Defaults to None - + Transformer for the image. Defaults to None. """ assert ( orig_data.max() > 0.8 @@ -83,7 +85,8 @@ def __init__( self.transforms = transforms def _get_scale_factor(self) -> npt.NDArray[float]: - """Get scaling factor to match original resolution of input image to final resolution of FastSurfer base network. + """ + Get scaling factor to match original resolution of input image to final resolution of FastSurfer base network. Input resolution is taken from voxel size in image header. ToDO: This needs to be updated based on the plane we are looking at in case we @@ -92,26 +95,25 @@ def _get_scale_factor(self) -> npt.NDArray[float]: Returns ------- npt.NDArray[float] - scale factor along x and y dimension - + Scale factor along x and y dimension. """ scale = self.base_res / np.asarray(self.zoom) return scale def __getitem__(self, index: int) -> Dict: - """Return a single image and its scale factor. + """ + Return a single image and its scale factor. Parameters ---------- index : int - Index of image to get + Index of image to get. Returns ------- dict - Dictionary of image and scale factor - + Dictionary of image and scale factor. 
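`get_largest_cc` keeps only the biggest connected component; a compact equivalent built around the same `skimage.measure.label` call (assumes at least one foreground voxel):

```python
import numpy as np
from skimage.measure import label

def largest_cc(segmentation):
    labels = label(segmentation, connectivity=3, background=0)
    sizes = np.bincount(labels.ravel())[1:]   # component sizes, background dropped
    return labels == (np.argmax(sizes) + 1)   # binary mask of the largest component
```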
""" img = self.images[index] @@ -122,19 +124,22 @@ def __getitem__(self, index: int) -> Dict: return {"image": img, "scale_factor": scale_factor} def __len__(self) -> int: - """Return length. + """ + Return length. Returns ------- int - count + Count. """ return self.count # Operator to load hdf5-file for training class MultiScaleDataset(Dataset): - """Class for loading aseg file with augmentations (transforms).""" + """ + Class for loading aseg file with augmentations (transforms). + """ def __init__( self, @@ -143,19 +148,19 @@ def __init__( gn_noise: bool = False, transforms: Optional = None ): - """Construct object. + """ + Construct object. Parameters ---------- dataset_path : str - Path to the dataset + Path to the dataset. cfg : yacs.config.CfgNode - Configuration node + Configuration node. gn_noise : bool - Whether to add gaussian noise (Default value = False) + Whether to add gaussian noise (Default value = False). transforms : Optional - Transformer to apply to the image (Default value = None) - + Transformer to apply to the image (Default value = None). """ self.max_size = cfg.DATA.PADDED_SIZE self.base_res = cfg.MODEL.BASE_RES @@ -223,12 +228,13 @@ def __init__( ) def get_subject_names(self): - """Get the subject name. + """ + Get the subject name. Returns ------- list - list of subject names + List of subject names. """ return self.subjects @@ -237,7 +243,8 @@ def _get_scale_factor( img_zoom: torch.Tensor, scale_aug: torch.Tensor ) -> npt.NDArray[float]: - """Get scaling factor to match original resolution of input image to final resolution of FastSurfer base network. + """ + Get scaling factor to match original resolution of input image to final resolution of FastSurfer base network. Input resolution is taken from voxel size in image header. @@ -247,15 +254,14 @@ def _get_scale_factor( Parameters ---------- img_zoom : torch.Tensor - Image zoom factor + Image zoom factor. scale_aug : torch.Tensor - [MISSING] + [MISSING]. Returns ------- npt.NDArray[float] - scale factor along x and y dimension - + Scale factor along x and y dimension. """ if torch.all(scale_aug > 0): img_zoom *= 1 / scale_aug @@ -274,18 +280,18 @@ def _pad( self, image: npt.NDArray ) -> np.ndarray: - """Pad the image with zeros. + """ + Pad the image with zeros. Parameters ---------- image : npt.NDArray - Image to pad + Image to pad. Returns ------- padded_image - Padded image - + Padded image. """ if len(image.shape) == 2: h, w = image.shape @@ -308,26 +314,26 @@ def unify_imgs( label: npt.NDArray, weight: npt.NDArray ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """Pad img, label and weight. + """ + Pad img, label and weight. Parameters ---------- img : npt.NDArray - image to unify + Image to unify. label : npt.NDArray - labels of the image + Labels of the image. weight : npt.NDArray - weights of the image + Weights of the image. Returns ------- np.ndarray - img + Img. np.ndarray - label + Label. np.ndarray - weight - + Weight. """ img = self._pad(img) label = self._pad(label) @@ -336,17 +342,18 @@ def unify_imgs( return img, label, weight def __getitem__(self, index): - """[MISSING]. + """ + Retrieve processed data at the specified index. Parameters ---------- - index : - [MISSING] + index : int + Index to retrieve data for. Returns ------- - [MISSING] - + dict + Dictionary containing torch tensors for image, label, weight, and scale factor. 
""" padded_img, padded_label, padded_weight = self.unify_imgs( self.images[index], self.labels[index], self.weights[index] @@ -395,14 +402,17 @@ def __getitem__(self, index): } def __len__(self): - """Return count.""" + """ + Return count. + """ return self.count # Operator to load hdf5-file for validation class MultiScaleDatasetVal(Dataset): - """Class for loading aseg file with augmentations (transforms).""" - + """ + Class for loading aseg file with augmentations (transforms). + """ def __init__(self, dataset_path, cfg, transforms=None): self.max_size = cfg.DATA.PADDED_SIZE @@ -469,11 +479,14 @@ def __init__(self, dataset_path, cfg, transforms=None): ) def get_subject_names(self): - """Get subject names.""" + """ + Get subject names. + """ return self.subjects def _get_scale_factor(self, img_zoom): - """Get scaling factor to match original resolution of input image to final resolution of FastSurfer base network. + """ + Get scaling factor to match original resolution of input image to final resolution of FastSurfer base network. Input resolution is taken from voxel size in image header. @@ -483,19 +496,20 @@ def _get_scale_factor(self, img_zoom): Parameters ---------- img_zoom : - zooming factor [MISSING] + Zooming factor [MISSING]. Returns ------- np.ndarray : float32 - scale factor along x and y dimension - + Scale factor along x and y dimension. """ scale = self.base_res / img_zoom return scale def __getitem__(self, index): - """Get item.""" + """ + Get item. + """ img = self.images[index] label = self.labels[index] weight = self.weights[index] @@ -524,5 +538,7 @@ def __getitem__(self, index): } def __len__(self): - """Get count.""" + """ + Get count. + """ return self.count diff --git a/FastSurferCNN/data_loader/loader.py b/FastSurferCNN/data_loader/loader.py index 3a838b82..1ae09d9a 100644 --- a/FastSurferCNN/data_loader/loader.py +++ b/FastSurferCNN/data_loader/loader.py @@ -24,21 +24,20 @@ def get_dataloader(cfg: yacs.config.CfgNode, mode: str): - """Create the dataset and pytorch data loader. + """ + Create the dataset and pytorch data loader. Parameters ---------- cfg : yacs.config.CfgNode - configuration node + Configuration node. mode : str - loading data for train, val and test mode + Loading data for train, val and test mode. Returns ------- torch.utils.data.DataLoader - dataloader with given configs and mode - - + Dataloader with given configs and mode. """ assert mode in ["train", "val"], f"dataloader mode is incorrect {mode}" From 5a6b35032594b74e93e604e3cc421ebd2bdb038d Mon Sep 17 00:00:00 2001 From: engrosamaali91 Date: Thu, 1 Feb 2024 15:48:21 +0100 Subject: [PATCH 3/7] recon_surf directory --- recon_surf/align_points.py | 54 +++++++++++++++++++------------------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/recon_surf/align_points.py b/recon_surf/align_points.py index 0824dc2c..028d9a41 100755 --- a/recon_surf/align_points.py +++ b/recon_surf/align_points.py @@ -28,18 +28,18 @@ def rmat2angles(R: npt.NDArray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """Extract rotation angles (alpha,beta,gamma) in FreeSurfer format (mris_register) from a rotation matrix. + """ + Extract rotation angles (alpha,beta,gamma) in FreeSurfer format (mris_register) from a rotation matrix. Parameters ---------- R : npt.NDArray - Rotation matrix + Rotation matrix. Returns ------- alpha, beta, gamma - Rotation degree - + Rotation degree. 
""" alpha = np.degrees(-np.arctan2(R[1, 0], R[0, 0])) beta = np.degrees(np.arcsin(R[2, 0])) @@ -48,22 +48,22 @@ def rmat2angles(R: npt.NDArray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: def angles2rmat(alpha: float, beta: float, gamma: float) -> np.array: - """Convert FreeSurfer angles (alpha,beta,gamma) in degrees to a rotation matrix. + """ + Convert FreeSurfer angles (alpha,beta,gamma) in degrees to a rotation matrix. Parameters ---------- alpha : float - FreeSurfer angle in degrees + FreeSurfer angle in degrees. beta : float - FreeSurfer angle in degrees + FreeSurfer angle in degrees. gamma : float - FreeSurfer angle in degrees + FreeSurfer angle in degrees. Returns ------- R - rotation angles - + Rotation angles. """ sa = np.sin(np.radians(alpha)) sb = np.sin(np.radians(beta)) @@ -82,25 +82,25 @@ def angles2rmat(alpha: float, beta: float, gamma: float) -> np.array: def find_rotation(p_mov: npt.NDArray, p_dst: npt.NDArray) -> np.ndarray: - """Find the rotation matrix. + """ + Find the rotation matrix. Parameters ---------- p_mov : npt.NDArray - [MISSING] + Source points. p_dst : npt.NDArray - [MISSING] + Destination points. Returns ------- R - Rotation matrix + Rotation matrix. Raises ------ ValueError - Shape of points should be identical - + Shape of points should be identical. """ if p_mov.shape != p_dst.shape: raise ValueError( @@ -131,20 +131,20 @@ def find_rotation(p_mov: npt.NDArray, p_dst: npt.NDArray) -> np.ndarray: def find_rigid(p_mov: npt.NDArray, p_dst: npt.NDArray) -> np.ndarray: - """[MISSING]. + """ + [MISSING]. Parameters ---------- p_mov : npt.NDArray - [MISSING] + Source points. p_dst : npt.NDArray - [MISSING] + Destination points. Returns ------- T - Homogeneous transformation matrix - + Homogeneous transformation matrix. """ if p_mov.shape != p_dst.shape: raise ValueError( @@ -175,27 +175,27 @@ def find_rigid(p_mov: npt.NDArray, p_dst: npt.NDArray) -> np.ndarray: return T def find_affine(p_mov: npt.NDArray, p_dst: npt.NDArray) -> np.ndarray: - """Find affine by least squares solution of overdetermined system. + """ + Find affine by least squares solution of overdetermined system. Assuming we have more than 4 point pairs Parameters ---------- p_mov : npt.NDArray - [MISSING] + The source points. p_dst : npt.NDArray - [MISSING] + The destination points. Returns ------- T - Affine transformation matrix + Affine transformation matrix. Raises ------ ValueError - Shape of points should be identical - + Shape of points should be identical. 
""" from scipy.linalg import pinv From c9c03acf21e1456fe74ccd68c7e83c2e6b35ed39 Mon Sep 17 00:00:00 2001 From: engrosamaali91 Date: Wed, 7 Feb 2024 11:43:11 +0100 Subject: [PATCH 4/7] sphinx doc files --- doc/_templates/autosummary/class.rst | 10 ++ doc/_templates/autosummary/function.rst | 8 + doc/_templates/autosummary/module.rst | 6 + doc/api/FastSurferCNN.data_loader.rst | 16 ++ doc/api/FastSurferCNN.models.rst | 14 ++ doc/api/FastSurferCNN.rst | 20 +++ doc/api/FastSurferCNN.utils.rst | 24 +++ doc/api/index.rst | 12 ++ doc/api/recon_surf.rst | 15 ++ doc/conf.py | 224 ++++++++++++++++++++++++ doc/index.rst | 12 ++ recon_surf/align_seg.py | 33 ++-- recon_surf/create_annotation.py | 110 ++++++------ 13 files changed, 433 insertions(+), 71 deletions(-) create mode 100644 doc/_templates/autosummary/class.rst create mode 100644 doc/_templates/autosummary/function.rst create mode 100644 doc/_templates/autosummary/module.rst create mode 100644 doc/api/FastSurferCNN.data_loader.rst create mode 100644 doc/api/FastSurferCNN.models.rst create mode 100644 doc/api/FastSurferCNN.rst create mode 100644 doc/api/FastSurferCNN.utils.rst create mode 100644 doc/api/index.rst create mode 100644 doc/api/recon_surf.rst create mode 100644 doc/conf.py create mode 100644 doc/index.rst diff --git a/doc/_templates/autosummary/class.rst b/doc/_templates/autosummary/class.rst new file mode 100644 index 00000000..3322b321 --- /dev/null +++ b/doc/_templates/autosummary/class.rst @@ -0,0 +1,10 @@ +{{ fullname | escape | underline }} + +.. currentmodule:: {{ module }} + +.. autoclass:: {{ objname }} + :members: + :inherited-members: + +.. minigallery:: {{ fullname }} + :add-heading: diff --git a/doc/_templates/autosummary/function.rst b/doc/_templates/autosummary/function.rst new file mode 100644 index 00000000..cdbecc4f --- /dev/null +++ b/doc/_templates/autosummary/function.rst @@ -0,0 +1,8 @@ +{{ fullname | escape | underline }} + +.. currentmodule:: {{ module }} + +.. autofunction:: {{ objname }} + +.. minigallery:: {{ fullname }} + :add-heading: diff --git a/doc/_templates/autosummary/module.rst b/doc/_templates/autosummary/module.rst new file mode 100644 index 00000000..13a2c278 --- /dev/null +++ b/doc/_templates/autosummary/module.rst @@ -0,0 +1,6 @@ +{{ fullname }} +{{ underline }} + +.. automodule:: {{ fullname }} + :members: + diff --git a/doc/api/FastSurferCNN.data_loader.rst b/doc/api/FastSurferCNN.data_loader.rst new file mode 100644 index 00000000..c33c3182 --- /dev/null +++ b/doc/api/FastSurferCNN.data_loader.rst @@ -0,0 +1,16 @@ +API data_loader References +========================== + + +.. currentmodule:: FastSurferCNN.data_loader + +.. autosummary:: + :toctree: generated/ + + augmentation + conform + data_utils + dataset + loader + + diff --git a/doc/api/FastSurferCNN.models.rst b/doc/api/FastSurferCNN.models.rst new file mode 100644 index 00000000..595001a8 --- /dev/null +++ b/doc/api/FastSurferCNN.models.rst @@ -0,0 +1,14 @@ +API models References +===================== + + +.. currentmodule:: FastSurferCNN.models + +.. autosummary:: + :toctree: generated/ + + interpolation_layer + losses + networks + sub_module + diff --git a/doc/api/FastSurferCNN.rst b/doc/api/FastSurferCNN.rst new file mode 100644 index 00000000..ed273efb --- /dev/null +++ b/doc/api/FastSurferCNN.rst @@ -0,0 +1,20 @@ +API References +============== + + +.. currentmodule:: FastSurferCNN + + +.. 
autosummary:: + :toctree: generated/ + + + download_checkpoints + generate_hdf5 + inference + quick_qc + reduce_to_aseg + run_prediction + segstats + version + diff --git a/doc/api/FastSurferCNN.utils.rst b/doc/api/FastSurferCNN.utils.rst new file mode 100644 index 00000000..d239173f --- /dev/null +++ b/doc/api/FastSurferCNN.utils.rst @@ -0,0 +1,24 @@ +API Utils References +==================== + + +.. currentmodule:: FastSurferCNN.utils + +.. autosummary:: + :toctree: generated/ + + arg_types + checkpoint + common + load_config + logging + lr_scheduler + mapper + meters + metrics + misc + parser_defaults + run_tools + threads + + diff --git a/doc/api/index.rst b/doc/api/index.rst new file mode 100644 index 00000000..ebbf707c --- /dev/null +++ b/doc/api/index.rst @@ -0,0 +1,12 @@ +API +=== + +.. toctree:: + :maxdepth: 2 + + FastSurferCNN.rst + FastSurferCNN.models.rst + FastSurferCNN.utils.rst + FastSurferCNN.data_loader.rst + recon_surf.rst + diff --git a/doc/api/recon_surf.rst b/doc/api/recon_surf.rst new file mode 100644 index 00000000..4e683970 --- /dev/null +++ b/doc/api/recon_surf.rst @@ -0,0 +1,15 @@ +API recon_surf References +========================= + + +.. currentmodule:: recon_surf + +.. autosummary:: + :toctree: generated/ + + align_points + align_seg + create_annotation + + + diff --git a/doc/conf.py b/doc/conf.py new file mode 100644 index 00000000..6a8c744f --- /dev/null +++ b/doc/conf.py @@ -0,0 +1,224 @@ +# Configuration file for the Sphinx documentation builder. +# +# For the full list of built-in configuration values, see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Project information ----------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information + + + +import inspect +from datetime import date +from importlib import import_module +from typing import Dict, Optional + + + +# here i added the relative path because sphinx was not able +# to locate FastSurferCNN module directly for autosummary +import sys +import os +sys.path.append(os.path.dirname(__file__) + '/..') +sys.path.append(os.path.dirname(__file__) + '/../recon_surf') +# sys.path.insert(0, '..') #after + +# autodoc_mock_imports = ["torch", "yacs"] + + +project = 'FastSurfer' +copyright = '2023, Martin Reuter' +author = 'Martin Reuter' +copyright = f"{date.today().year}, {author}" +# release = FastSurferCNN.__version__ +# package = fsqc.__name__ +gh_url = "https://github.com/deep-mi/FastSurfer" + + +# -- General configuration --------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration + +# If your documentation needs a minimal Sphinx version, state it here. +needs_sphinx = "5.0" + +# The document name of the “root” document, that is, the document that contains +# the root toctree directive. +root_doc = "index" + + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named "sphinx.ext.*") or your custom +# ones. 
+extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.autosectionlabel", + "sphinx.ext.autosummary", + "sphinx.ext.intersphinx", + "sphinx.ext.linkcode", + "numpydoc", + "sphinxcontrib.bibtex", + "sphinx_copybutton", + "sphinx_design", + "sphinx_issues", + "nbsphinx", + "IPython.sphinxext.ipython_console_highlighting" +] + + +templates_path = ['_templates'] +exclude_patterns = [ + "_build", + "Thumbs.db", + ".DS_Store", + "**.ipynb_checkpoints", +] + + +# Sphinx will warn about all references where the target cannot be found. +nitpicky = False +nitpick_ignore = [] + +# A list of ignored prefixes for module index sorting. +# modindex_common_prefix = [f"{package}."] + +# The name of a reST role (builtin or Sphinx extension) to use as the default +# role, that is, for text marked up `like this`. This can be set to 'py:obj' to +# make `filter` a cross-reference to the Python function “filter”. +default_role = "py:obj" + +# -- Options for HTML output ------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output +html_theme = "furo" +html_static_path = ["_static"] +html_title = project +html_show_sphinx = False + +# Documentation to change footer icons: +# https://pradyunsg.me/furo/customisation/footer/#changing-footer-icons +html_theme_options = { + "footer_icons": [ + { + "name": "GitHub", + "url": gh_url, + "html": """ + + + + """, + "class": "", + }, + ], +} + + + +# -- autosummary ------------------------------------------------------------- +autosummary_generate = True + +# -- autodoc ----------------------------------------------------------------- +autodoc_typehints = "none" +autodoc_member_order = "groupwise" +autodoc_warningiserror = True +autoclass_content = "class" + + +# -- intersphinx ------------------------------------------------------------- +intersphinx_mapping = { + "matplotlib": ("https://matplotlib.org/stable", None), + "mne": ("https://mne.tools/stable/", None), + "numpy": ("https://numpy.org/doc/stable", None), + "pandas": ("https://pandas.pydata.org/pandas-docs/stable/", None), + "python": ("https://docs.python.org/3", None), + # "scipy": ("https://docs.scipy.org/doc/scipy", None), + "sklearn": ("https://scikit-learn.org/stable/", None), +} +intersphinx_timeout = 5 + + +# -- sphinx-issues ----------------------------------------------------------- +issues_github_path = gh_url.split("https://github.com/")[-1] + +# -- autosectionlabels ------------------------------------------------------- +autosectionlabel_prefix_document = True + +# -- numpydoc ---------------------------------------------------------------- +numpydoc_class_members_toctree = False +numpydoc_attributes_as_param_list = False +# numpydoc_show_class_members = True + + +# x-ref +numpydoc_xref_param_type = True +numpydoc_xref_aliases = { + # Matplotlib + "Axes": "matplotlib.axes.Axes", + "Figure": "matplotlib.figure.Figure", + # Python + "bool": ":class:`python:bool`", + "Path": "pathlib.Path", + "TextIO": "io.TextIOBase", + # Scipy + "csc_matrix": "scipy.sparse.csc_matrix", +} +# numpydoc_xref_ignore = {} + +# validation +# https://numpydoc.readthedocs.io/en/latest/validation.html#validation-checks +error_ignores = { + "GL01", # docstring should start in the line immediately after the quotes + "EX01", # section 'Examples' not found + "ES01", # no extended summary found + "SA01", # section 'See Also' not found + "RT02", # The first line of the Returns section should contain only the type, unless multiple values are being returned # 
noqa + "PR01", # Parameters {missing_params} not documented + "GL08", # The object does not have a docstring + "SS05", # Summary must start with infinitive verb, not third person + "RT01", # No Returns section found + "SS06", # Summary should fit in a single line + "GL02", # Closing quotes should be placed in the line after the last text + "GL03", # Double line break found; please use only one blank line to + "SS03", # Summary does not end with a period + "YD01", # No Yields section found + "PR02" # Unknown parameters {unknown_params} +} +numpydoc_validate = True +numpydoc_validation_checks = {"all"} | set(error_ignores) +numpydoc_validation_exclude = { # regex to ignore during docstring check + r"\.__getitem__", + r"\.__contains__", + r"\.__hash__", + r"\.__mul__", + r"\.__sub__", + r"\.__add__", + r"\.__iter__", + r"\.__div__", + r"\.__neg__", +} + +# -- sphinxcontrib-bibtex ---------------------------------------------------- +bibtex_bibfiles = ["./references.bib"] + +# -- sphinx.ext.linkcode ----------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/extensions/linkcode.html + +# Alternative method for linking to code by Osama, not sure which one is better +from urllib.parse import quote + + +def linkcode_resolve(domain, info): + if domain != "py": + return None + if not info["module"]: + return None + filename = quote(info["module"].replace(".", "/")) + if not filename.startswith("tests"): + filename = "/" + filename + if "fullname" in info: + anchor = info["fullname"] + anchor = "#:~:text=" + quote(anchor.split(".")[-1]) + else: + anchor = "" + result = f"{gh_url}/blob/stable/{filename}.py{anchor}" + return result + + diff --git a/doc/index.rst b/doc/index.rst new file mode 100644 index 00000000..00d04046 --- /dev/null +++ b/doc/index.rst @@ -0,0 +1,12 @@ +.. FastSurfer documentation master file, created by + sphinx-quickstart on Thu Nov 30 15:48:44 2023. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +Welcome to FastSurfer's documentation! +====================================== + +.. toctree:: + :hidden: + + api/index diff --git a/recon_surf/align_seg.py b/recon_surf/align_seg.py index 244199b6..f0b60104 100755 --- a/recon_surf/align_seg.py +++ b/recon_surf/align_seg.py @@ -65,13 +65,13 @@ def options_parse(): - """Command line option parser. + """ + Command line option parser. Returns ------- options - object holding options - + Object holding options. """ parser = optparse.OptionParser( version="$Id:align_seg.py,v 1.0 2022/08/24 21:22:08 mreuter Exp $", @@ -100,24 +100,24 @@ def options_parse(): def get_seg_centroids(seg_mov: sitk.Image, seg_dst: sitk.Image, label_ids: Optional[npt.NDArray[int]] = []) -> Tuple[npt.NDArray, npt.NDArray]: - """Extract the centroids of the segmentation labels for mov and dst in RAS coords. + """ + Extract the centroids of the segmentation labels for mov and dst in RAS coords. Parameters ---------- seg_mov : sitk.Image - Source segmentation image + Source segmentation image. seg_dst : sitk.Image - Target segmentation image + Target segmentation image. label_ids : Optional[npt.NDArray[int]] - List of label ids to extract (Default value = []) + List of label ids to extract (Default value = []). Returns ------- centroids_mov - List of centroids of source segmentation + List of centroids of source segmentation. centroids_dst - List of centroids of target segmentation - + List of centroids of target segmentation. 
""" if not label_ids: # use all joint labels except -1 and 0: @@ -159,7 +159,8 @@ def align_seg_centroids( label_ids: Optional[npt.NDArray[int]] = [], affine: bool = False ) -> npt.NDArray: - """Align the segmentations based on label centroids (rigid is default). + """ + Align the segmentations based on label centroids (rigid is default). Parameters ---------- @@ -177,7 +178,6 @@ def align_seg_centroids( ------- T Aligned centroids RAS2RAS transform. - """ # get centroids of each label in image centroids_mov, centroids_dst = get_seg_centroids(seg_mov, seg_dst, label_ids) @@ -191,12 +191,13 @@ def align_seg_centroids( def get_vox2ras(img:sitk.Image) -> npt.NDArray: - """Extract voxel to RAS (affine) from sitk image. + """ + Extract voxel to RAS (affine) from sitk image. Parameters ---------- seg : sitk.Image - sitk Image. + Sitk Image. Returns ------- @@ -219,7 +220,8 @@ def get_vox2ras(img:sitk.Image) -> npt.NDArray: return vox2ras def align_flipped(seg: sitk.Image, mid_slice: Optional[float] = None) -> npt.NDArray: - """Registrate Left - right (make upright). + """ + Registrate Left - right (make upright). Register cortial lables @@ -235,7 +237,6 @@ def align_flipped(seg: sitk.Image, mid_slice: Optional[float] = None) -> npt.NDA ------- Tsqrt RAS2RAS transformation matrix for registration. - """ lhids = np.array( [ diff --git a/recon_surf/create_annotation.py b/recon_surf/create_annotation.py index 2d43fe84..59eab680 100755 --- a/recon_surf/create_annotation.py +++ b/recon_surf/create_annotation.py @@ -90,13 +90,13 @@ def options_parse(): - """Command line option parser. + """ + Command line option parser. Returns ------- options - object holding options - + Object holding options. """ parser = optparse.OptionParser( version="$Id:create_annotation.py,v 1.0 2022/08/24 21:22:08 mreuter Exp $", @@ -156,45 +156,45 @@ def map_multiple_labels( out_dir: Optional[str] = None, stop_missing: bool = True ) -> Tuple[npt.ArrayLike, npt.ArrayLike]: - """Map a list of labels from one surface (e.g. fsavaerage sphere.reg) to another. + """ + Map a list of labels from one surface (e.g. fsavaerage sphere.reg) to another. Labels are just names without hemisphere or path, - which are passed via hemi, src_dir, out_dir) + which are passed via hemi, src_dir, out_dir). Parameters ---------- hemi : str - "lh" or "rh" for reading labels + "lh" or "rh" for reading labels. src_dir : str - director of the source file + Director of the source file. src_labels : npt.ArrayLike - List of labels + List of labels. src_sphere_name : str - filename of source sphere + Filename of source sphere. trg_sphere_name : str - filename of target sphere + Filename of target sphere. trg_white_name : str - filename of target white + Filename of target white. trg_sid : str - target subject id + Target subject id. out_dir : Optional[str] - directory for output, defaults to None + Directory for output, defaults to None. stop_missing : bool - determines whether to stop on a missing src label file, or continue - with a warning. Defaults to True + Determines whether to stop on a missing src label file, or continue + with a warning. Defaults to True. Returns ------- all_labels - mapped labels + Mapped labels. all_values - values of mapped labels + Values of mapped labels. Raises ------ ValueError - Label file missing - + Label file missing. 
""" # get reverse mapping (trg->src) for sampling rev_mapping, _, _ = getSurfCorrespondence(trg_sphere_name, src_sphere_name) @@ -236,24 +236,24 @@ def read_multiple_labels( input_dir: str, label_names: npt.ArrayLike ) -> Tuple[ List[npt.NDArray], List[npt.NDArray]]: - """Read multiple label files from input_dir. + """ + Read multiple label files from input_dir. Parameters ---------- hemi : str - "lh" or "rh" for reading labels + "lh" or "rh" for reading labels. input_dir : str - director of the source + Director of the source. label_names : npt.ArrayLike - List of labels + List of labels. Returns ------- all_labels - read labels + Read labels. all_values - values of read labels - + Values of read labels. """ all_labels = [] all_values = [] @@ -275,7 +275,8 @@ def build_annot(all_labels: npt.ArrayLike, all_values: npt.ArrayLike, col_ids: npt.ArrayLike, trg_white: Union[str, npt.NDArray], cortex_label_name: Optional[str] = None ) -> Tuple[npt.NDArray, npt.NDArray]: - """Create an annotation from multiple labels. + """ + Create an annotation from multiple labels. Here we also consider the label values and overwrite existing labels if values of current are larger (or equal, so the order of the labels matters). @@ -284,23 +285,22 @@ def build_annot(all_labels: npt.ArrayLike, all_values: npt.ArrayLike, Parameters ---------- all_labels : npt.ArrayLike - List of all Labels + List of all Labels. all_values : npt.ArrayLike - List of all values + List of all values. col_ids : npt.ArrayLike - List of col ids + List of col ids. trg_white : Union[str, npt.NDArray] - target file of white + Target file of white. cortex_label_name : Optional[str] - Path to the cortex label file. Defaults to None + Path to the cortex label file. Defaults to None. Returns ------- annot_ids - Ids of build Annotations + Ids of build Annotations. annot_vals - Values of build Annotations - + Values of build Annotations. """ # create annot from a bunch of labels (and values) if isinstance(trg_white, str): @@ -339,22 +339,22 @@ def build_annot(all_labels: npt.ArrayLike, all_values: npt.ArrayLike, def read_colortable(colortab_name: str) -> Tuple[npt.ArrayLike, List[str], npt.ArrayLike]: - """Read the colortable of given name. + """ + Read the colortable of given name. Parameters ---------- colortab_name : str - Path and Name of the colortable file + Path and Name of the colortable file. Returns ------- ids - List of ids + List of ids. names - List of names + List of names. colors - List of colors corresponding to ids and names - + List of colors corresponding to ids and names. """ colortab = np.genfromtxt(colortab_name, dtype="i8", usecols=(0, 2, 3, 4, 5)) ids = colortab[:, 0] @@ -371,25 +371,25 @@ def write_annot( out_annot: str, append: Union[None, str] = "" ) -> None: - """Combine the colortable with the annotations ids to write an annotation file. + """ + Combine the colortable with the annotations ids to write an annotation file. The annotation file contains colortable information Care needs to be taken that the colortable file has the same number - and order of labels as specified in the label_names list + and order of labels as specified in the label_names list. Parameters ---------- annot_ids : npt.ArrayLike - List of annotation ids + List of annotation ids. label_names : npt.ArrayLike - list of label names + List of label names. colortab_name : str - Path and name of colortable file + Path and name of colortable file. out_annot : str - Path and name of output annotation file + Path and name of output annotation file. 
append : Union[None, str] - String to append to colour name. Defaults to "" - + String to append to colour name. Defaults to "". """ # colortab_name="colortable_BA.txt" col_ids, col_names, col_colors = read_colortable(colortab_name) @@ -413,14 +413,15 @@ def write_annot( def create_annotation(options, verbose: bool = True) -> None: - """Map (if required), build and write annotation. + """ + Map (if required), build and write annotation. (Main function) Parameters ---------- - options : - object holding options + options : Any + Object holding options hemi: "lh" or "rh" for reading labels colortab: colortab with label ids, names and colors labeldir: dir where to find the label files (when reading) @@ -429,10 +430,9 @@ def create_annotation(options, verbose: bool = True) -> None: cortex: optional path to hemi.cortex for optional masking of annotation to only cortex append: optional, e.g. ".thresh" can be appended to label names (I/O) for exvivo FS labels srcsphere: optional, when mapping: path to src sphere.reg - trgsphere: optional, when mapping: path to trg sphere.reg + trgsphere: optional, when mapping: path to trg sphere.reg. verbose : bool - True if options should be printed. Defaults to True - + True if options should be printed. Defaults to True. """ print() print("Map BA Labels Parameters:") From 300a39bd2bf5296f7e2e362aba7c0ab30ddc1950 Mon Sep 17 00:00:00 2001 From: engrosamaali91 Date: Wed, 7 Feb 2024 14:57:08 +0100 Subject: [PATCH 5/7] DevOp rebased on dev --- recon_surf/create_annotation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/recon_surf/create_annotation.py b/recon_surf/create_annotation.py index 59eab680..a92af7eb 100755 --- a/recon_surf/create_annotation.py +++ b/recon_surf/create_annotation.py @@ -14,7 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +# This is test line just to check rebase commit # IMPORTS import optparse From e1f72e45e60f01315b7d8cf9d7022679a989d174 Mon Sep 17 00:00:00 2001 From: engrosamaali91 Date: Thu, 8 Feb 2024 16:10:55 +0100 Subject: [PATCH 6/7] Implemented requested changes --- FastSurferCNN/data_loader/data_utils.py | 4 +-- FastSurferCNN/data_loader/dataset.py | 6 ++--- FastSurferCNN/generate_hdf5.py | 2 +- FastSurferCNN/inference.py | 18 ++++++------- FastSurferCNN/models/interpolation_layer.py | 25 +++++++---------- FastSurferCNN/models/losses.py | 16 +++++------ FastSurferCNN/models/networks.py | 30 +++++++++------------ FastSurferCNN/quick_qc.py | 15 ++++++----- FastSurferCNN/reduce_to_aseg.py | 4 +-- FastSurferCNN/run_prediction.py | 14 +++++----- FastSurferCNN/utils/arg_types.py | 16 ++++++++--- FastSurferCNN/utils/common.py | 21 +++++++++------ doc/api/FastSurferCNN.utils.rst | 2 +- doc/api/index.rst | 4 +-- 14 files changed, 89 insertions(+), 88 deletions(-) diff --git a/FastSurferCNN/data_loader/data_utils.py b/FastSurferCNN/data_loader/data_utils.py index 86ce14ee..da232888 100644 --- a/FastSurferCNN/data_loader/data_utils.py +++ b/FastSurferCNN/data_loader/data_utils.py @@ -572,7 +572,7 @@ def read_classes_from_lut(lut_file: str) -> pd.DataFrame: Parameters ---------- lut_file : str - Path and name of FreeSurfer-style LUT file with classes of interest + Path and name of FreeSurfer-style LUT file with classes of interest. 
Example entry: ID LabelName R G B A 0 Unknown 0 0 0 0 @@ -1039,7 +1039,7 @@ def infer_mapping_from_lut( lut: Union[str, pd.DataFrame] ) -> np.ndarray: """ - [MISSING]. + Guess the mapping from a lookup table. Parameters ---------- diff --git a/FastSurferCNN/data_loader/dataset.py b/FastSurferCNN/data_loader/dataset.py index ec02aff0..fc68f1cb 100644 --- a/FastSurferCNN/data_loader/dataset.py +++ b/FastSurferCNN/data_loader/dataset.py @@ -495,12 +495,12 @@ def _get_scale_factor(self, img_zoom): Parameters ---------- - img_zoom : - Zooming factor [MISSING]. + img_zoom : np.ndarray + Voxel sizes of the image. Returns ------- - np.ndarray : float32 + np.ndarray : numpy.typing.NDArray[float] Scale factor along x and y dimension. """ scale = self.base_res / img_zoom diff --git a/FastSurferCNN/generate_hdf5.py b/FastSurferCNN/generate_hdf5.py index cf03c304..e608f8fe 100644 --- a/FastSurferCNN/generate_hdf5.py +++ b/FastSurferCNN/generate_hdf5.py @@ -278,7 +278,7 @@ def create_hdf5_dataset(self, blt: int): Parameters ---------- blt : int - Blank sliec threshold. + Blank slice threshold. """ data_per_size = defaultdict(lambda: defaultdict(list)) start_d = time.time() diff --git a/FastSurferCNN/inference.py b/FastSurferCNN/inference.py index 008ee7e7..6a1b7f83 100644 --- a/FastSurferCNN/inference.py +++ b/FastSurferCNN/inference.py @@ -223,7 +223,7 @@ def load_checkpoint(self, ckpt: Union[str, os.PathLike]): if self.model_parallel: self.model = torch.nn.DataParallel(self.model) - def get_modelname(self): + def get_modelname(self) -> str: """ Return the model name. @@ -234,7 +234,7 @@ def get_modelname(self): """ return self.model_name - def get_cfg(self): + def get_cfg(self) -> yacs.config.CfgNode: """ Return the configurations. @@ -245,7 +245,7 @@ def get_cfg(self): """ return self.cfg - def get_num_classes(self): + def get_num_classes(self) -> int: """ Return the number of classes. @@ -256,7 +256,7 @@ def get_num_classes(self): """ return self.cfg.MODEL.NUM_CLASSES - def get_plane(self): + def get_plane(self) -> str: """ Return the plane. @@ -267,7 +267,7 @@ def get_plane(self): """ return self.cfg.DATA.PLANE - def get_model_height(self): + def get_model_height(self) -> int: """ Return the model height. @@ -278,7 +278,7 @@ def get_model_height(self): """ return self.cfg.MODEL.HEIGHT - def get_model_width(self): + def get_model_width(self) -> int: """ Return the model width. @@ -289,13 +289,13 @@ def get_model_width(self): """ return self.cfg.MODEL.WIDTH - def get_max_size(self): + def get_max_size(self) -> int | tuple[int, int]: """ Return the max size. Returns ------- - Union[int, Tuple[int, int]] + int | tuple[int, int] The maximum size, either a single value or a tuple (width, height). """ if self.cfg.MODEL.OUT_TENSOR_WIDTH == self.cfg.MODEL.OUT_TENSOR_HEIGHT: @@ -303,7 +303,7 @@ def get_max_size(self): else: return self.cfg.MODEL.OUT_TENSOR_WIDTH, self.cfg.MODEL.OUT_TENSOR_HEIGHT - def get_device(self): + def get_device(self) -> torch.device: """ Return the device. diff --git a/FastSurferCNN/models/interpolation_layer.py b/FastSurferCNN/models/interpolation_layer.py index 4970f7a5..627d7fb5 100644 --- a/FastSurferCNN/models/interpolation_layer.py +++ b/FastSurferCNN/models/interpolation_layer.py @@ -68,7 +68,7 @@ def __init__( target_shape : _T.Optional[_T.Sequence[int]] Target tensor size for after this module, not including batchsize and channels. - interpolation_mode : str + interpolation_mode : str, default="nearest" Interpolation mode as in `torch.nn.interpolate` (default: 'neareast'). 
""" @@ -132,7 +132,7 @@ def forward( image: The first dimension corresponds to and must be equal to the batch size of the image. The second dimension is optional and may contain different values for the _scale_limits factor per axis. In consequence, this dimension can have 1 or {dim} values. - rescale : bool + rescale : bool, default="False" (Default value = False). Returns @@ -201,7 +201,7 @@ def _fix_scale_factors( Yields ------ - _T.Iterable[_T.Tuple[T_Scale, int]] + tuple[T_Scale, int] The next fixed scale factor. Raises @@ -370,17 +370,10 @@ class Zoom2d(_ZoomNd): """ Perform a crop and interpolation on a Four-dimensional Tensor respecting batch and channel. - Attributes - ---------- - _N - Number of dimensions (Here 2). - _crop_position - Position to crop. - Methods ------- _interpolate - Crops, interpolates and pads the tensor. + (Protected) Crops, interpolates and pads the tensor. """ def __init__( @@ -396,9 +389,9 @@ def __init__( ---------- target_shape : _T.Optional[_T.Sequence[int]] Target tensor size for after this module, not including batchsize and channels. - interpolation_mode : str + interpolation_mode : str, default="nearest" Interpolation mode as in `torch.nn.interpolate` (default: 'nearest') - crop_position : str + crop_position : str, default="top_left" Crop position to use from 'top_left', 'bottom_left', top_right', 'bottom_right', 'center' (default: 'top_left'). """ @@ -416,7 +409,7 @@ def __init__( self._N = 2 super(Zoom2d, self).__init__(target_shape, interpolation_mode) - self._crop_position = crop_position + self.crop_position = crop_position def _interpolate( self, @@ -510,10 +503,10 @@ def __init__( target_shape : _T.Optional[_T.Sequence[int]] Target tensor size for after this module, not including batchsize and channels. - interpolation_mode : str + interpolation_mode : str, default="nearest" Interpolation mode as in `torch.nn.interpolate`, (default: 'neareast'). - crop_position : str + crop_position : str, default="front_top_left" Crop position to use from 'front_top_left', 'back_top_left', 'front_bottom_left', 'back_bottom_left', 'front_top_right', 'back_top_right', 'front_bottom_right', 'back_bottom_right', 'center' (default: 'front_top_left'). diff --git a/FastSurferCNN/models/losses.py b/FastSurferCNN/models/losses.py index 891a94c2..de667f6e 100644 --- a/FastSurferCNN/models/losses.py +++ b/FastSurferCNN/models/losses.py @@ -13,16 +13,16 @@ # limitations under the License. -from numbers import Real -from typing import Optional, Tuple, Union # IMPORTS import torch import yacs.config + from torch import Tensor, nn from torch.nn import functional as F from torch.nn.modules.loss import _Loss - +from numbers import Real +from typing import Optional, Tuple, Union class DiceLoss(_Loss): """ @@ -40,7 +40,7 @@ def forward( target: Tensor, weights: Optional[int] = None, ignore_index: Optional[int] = None, - ) -> float: + ) -> torch.Tensor: """ Calulate the DiceLoss. @@ -50,14 +50,14 @@ def forward( N x C x H x W Variable. target : Tensor N x C x W LongTensor with starting class at 0. - weights : Optional[int] + weights : int, optional C FloatTensor with class wise weights(Default value = None). - ignore_index : Optional[int] + ignore_index : int, optional Ignore label with index x in the loss calculation (Default value = None). Returns ------- - float + torch.Tensor Calculated Diceloss. 
""" eps = 0.001 @@ -114,7 +114,7 @@ def __init__(self, weight: Optional[Tensor] = None, reduction: str = "none"): Parameters ---------- - weight : Optional[Tensor] + weight : Tensor, optional A manual rescaling weight given to each class. If given, has to be a Tensor of size `C`. Defaults to None. reduction : str Specifies the reduction to apply to the output, as in nn.CrossEntropyLoss. Defaults to 'None'. diff --git a/FastSurferCNN/models/networks.py b/FastSurferCNN/models/networks.py index d3506001..0782bfb3 100644 --- a/FastSurferCNN/models/networks.py +++ b/FastSurferCNN/models/networks.py @@ -12,18 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. - -from typing import Dict, Optional, Union - +# IMPORTS import numpy as np import yacs - -# IMPORTS -from torch import Tensor, nn - import FastSurferCNN.models.interpolation_layer as il import FastSurferCNN.models.sub_module as sm +from typing import Dict, Optional, Union +from torch import Tensor, nn class FastSurferCNNBase(nn.Module): """ @@ -61,7 +57,7 @@ def __init__(self, params: Dict, padded_size: int = 256): params : Dict Parameters in dictionary format - padded_size : int + padded_size : int, default = 256 Size of image when padded (Default value = 256). """ super(FastSurferCNNBase, self).__init__() @@ -105,10 +101,10 @@ def forward( ---------- x : Tensor Input image [N, C, H, W] representing the input data. - scale_factor : Optional[Tensor] - [N, 1] Defaults to None (Default value = None). - scale_factor_out : Optional[Tensor] - (Default value = None). + scale_factor : Tensor, optional + [N, 1] Defaults to None. + scale_factor_out : Tensor, optional + [Missing]. Returns ------- @@ -195,10 +191,10 @@ def forward( ---------- x : Tensor Input image [N, C, H, W]. - scale_factor : Optional[Tensor] + scale_factor : Tensor, optional [N, 1] Defaults to None. - scale_factor_out : Optional[Tensor] - Defaults to None. + scale_factor_out : Tensor, optional + [Missing]. Returns ------- @@ -260,7 +256,7 @@ def __init__(self, params: Dict, padded_size: int = 256): ---------- params : Dict Dictionary of configurations. - padded_size : int + padded_size : int, default = 256 Size of image when padded (Default value = 256). """ num_c = params["num_channels"] @@ -342,7 +338,7 @@ def forward( scale_factor : Tensor [MISSING] [N, 1]. scale_factor_out : Tensor, Optional - [MISSING] Defaults to None. + [MISSING]. Returns ------- diff --git a/FastSurferCNN/quick_qc.py b/FastSurferCNN/quick_qc.py index d3ceefd5..dfea521c 100644 --- a/FastSurferCNN/quick_qc.py +++ b/FastSurferCNN/quick_qc.py @@ -68,17 +68,17 @@ def options_parse(): return options -def check_volume(asegdkt_segfile, voxvol, thres=0.70): +def check_volume(asegdkt_segfile:np.ndarray, voxvol: float, thres: float = 0.70): """ Check if total volume is bigger or smaller than threshold. Parameters ---------- - asegdkt_segfile : - + asegdkt_segfile : np.ndarray [MISSING]. - voxvol : - + voxvol : float [MISSING]. - thres : - + thres : float [MISSING]. Returns @@ -113,8 +113,11 @@ def get_region_bg_intersection_mask( seg_array : numpy.ndarray Segmentation array. region_labels : Dict - Dict whose values correspond to the desired region's labels (Default value = VENT_LABELS). - bg_label : int + Dictionary whose values correspond to the desired region's labels (Default value = VENT_LABELS). 
+        VENT_LABELS is a dictionary containing labels for different regions related to the ventricles,
+        such as "Left-Lateral-Ventricle", "Right-Lateral-Ventricle", etc., along with their
+        corresponding numeric values.
+    bg_label : int, default="BG_LABEL"
         (Default value = BG_LABEL).
 
     Returns
diff --git a/FastSurferCNN/reduce_to_aseg.py b/FastSurferCNN/reduce_to_aseg.py
index 7ab1276c..ec0ad400 100644
--- a/FastSurferCNN/reduce_to_aseg.py
+++ b/FastSurferCNN/reduce_to_aseg.py
@@ -12,10 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-import copy
-
 # IMPORTS
+import copy
 import optparse
 import sys
 
diff --git a/FastSurferCNN/run_prediction.py b/FastSurferCNN/run_prediction.py
index 0285b6d0..9847955e 100644
--- a/FastSurferCNN/run_prediction.py
+++ b/FastSurferCNN/run_prediction.py
@@ -12,21 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# IMPORTS
 import argparse
 import copy
 import os
-
-# IMPORTS
 import sys
-from concurrent.futures import Executor
-from typing import Any, Dict, Iterator, Literal, Optional, Tuple, Union
-
 import nibabel as nib
 import numpy as np
 import torch
 import yacs.config
-
 import FastSurferCNN.reduce_to_aseg as rta
+
+from concurrent.futures import Executor
+from typing import Any, Dict, Iterator, Literal, Optional, Tuple, Union
 from FastSurferCNN.data_loader import conform as conf
 from FastSurferCNN.data_loader import data_utils as du
 from FastSurferCNN.inference import Inference
@@ -154,11 +152,11 @@ class RunModelOnData:
     pred_name : str
     conf_name : str
     orig_name : str
-    vox_size : Union[float, Literal["min"]]
+    vox_size : float, 'min'
     current_plane : str
     models : Dict[str, Inference]
     view_ops : Dict[str, Dict[str, Any]]
-    conform_to_1mm_threshold : Optional[float]
+    conform_to_1mm_threshold : float, optional
         threshold until which the image will be conformed to 1mm res
 
     Methods
diff --git a/FastSurferCNN/utils/arg_types.py b/FastSurferCNN/utils/arg_types.py
index 6cb63d7b..6891f1a6 100644
--- a/FastSurferCNN/utils/arg_types.py
+++ b/FastSurferCNN/utils/arg_types.py
@@ -36,11 +36,10 @@ def vox_size(a: str) -> VoxSizeOption:
     If 'auto' or 'min' is provided, it returns a string('auto' or 'min').
     If a valid voxel size (between 0 and 1) is provided, it returns a float.
 
-
     Raises
     ------
     argparse.ArgumentTypeError
-        An error from creating or using an argument. Additionally, vox_sizes may be 'min'.
+        If the argument is not "min", "auto" or convertible to a float between 0 and 1.
     """
     if a.lower() in ["auto", "min"]:
         return "min"
@@ -66,7 +65,11 @@ def float_gt_zero_and_le_one(a: str) -> Optional[float]:
     float or None
         If `a` is a valid float between 0 and 1, return the float value.
         If `a` is 'none' or 'infinity', return None.
-        Otherwise, raise an argparse.ArgumentTypeError.
+
+    Raises
+    ------
+    argparse.ArgumentTypeError
+        If `a` is neither a float between 0 and 1 nor 'none' or 'infinity'.
     """
     if a is None or a.lower() in ["none", "infinity"]:
         return None
@@ -84,7 +87,7 @@ def target_dtype(a: str) -> str:
     Parameters
     ----------
     a : str
-        Datatype.
+        Datatype descriptor.
 
     Returns
     -------
@@ -95,6 +98,11 @@ def target_dtype(a: str) -> str:
     Invalid dtype.
 
+    See Also
+    --------
+    numpy.dtype
+        For more information on numpy data types and their properties. 
""" dtypes = nib.freesurfer.mghformat.data_type_codes.value_set("label") dtypes.add("any") diff --git a/FastSurferCNN/utils/common.py b/FastSurferCNN/utils/common.py index 0dc9a9d4..1d38b723 100644 --- a/FastSurferCNN/utils/common.py +++ b/FastSurferCNN/utils/common.py @@ -200,10 +200,10 @@ def pipeline( pipeline_size : int Size of the processing pipeline (default is 1). - Returns - ------- + Yields + ------ Iterator[Tuple[_Ti, _T]] - Iterator yielding input elements and corresponding results. + Yielding elements from iterable and corresponding results from func(element). """ # do pipeline loading the next element from collections import deque @@ -231,8 +231,8 @@ def iterate( Parameters ---------- pool : Executor - [MISSING]. - func : Callable[[_Ti], _T] + Callable. + func : Iterable Function to use. iterable : Iterable[_Ti] Iterable. @@ -242,7 +242,7 @@ def iterate( element : _Ti Elements _T - [MISSING]. + Func(element). """ for element in iterable: yield element, func(element) @@ -318,7 +318,7 @@ def filename_in_subject_folder(self, filepath: str) -> str: Parameters ---------- filepath : str - Abs path to the file or name of the file. + Absolute path. Returns ------- @@ -679,8 +679,13 @@ def get_attribute(self, attr_name: str): Returns ------- - AttributeError + object The value of the requested attribute. + + Raises + ------ + AttributeError + If the subject has no attribute with the given name. """ if not self.has_attribute(attr_name): raise AttributeError(f"The subject has no attribute named {attr_name}.") diff --git a/doc/api/FastSurferCNN.utils.rst b/doc/api/FastSurferCNN.utils.rst index d239173f..f505d0e4 100644 --- a/doc/api/FastSurferCNN.utils.rst +++ b/doc/api/FastSurferCNN.utils.rst @@ -1,4 +1,4 @@ -API Utils References +API utils References ==================== diff --git a/doc/api/index.rst b/doc/api/index.rst index ebbf707c..40ac8b38 100644 --- a/doc/api/index.rst +++ b/doc/api/index.rst @@ -1,5 +1,5 @@ -API -=== +FastSurfer API +============== .. 
toctree:: :maxdepth: 2 From bb0ea35f9017a1ae891f5eaf37028b6299912fe9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20K=C3=BCgler?= Date: Thu, 8 Feb 2024 16:24:00 +0100 Subject: [PATCH 7/7] Changes arg_types.py - update unquote_str docstring common.py - remove the obsolete removesuffix function run_prediction.py - remove the obsolete removesuffix function quick_qc.py - Fix the docstring of get_region_by_intersection_mask generate_hdf5.py - replace seg-mask with sag_mask interpolation_layer.py - Update of the crop_position property common.py - Update the docstrings of pipeline and iterate --- FastSurferCNN/generate_hdf5.py | 4 +- FastSurferCNN/models/interpolation_layer.py | 73 ++++++++++++++++----- FastSurferCNN/quick_qc.py | 20 +++--- FastSurferCNN/run_prediction.py | 31 --------- FastSurferCNN/utils/arg_types.py | 4 +- FastSurferCNN/utils/common.py | 63 +++++------------- 6 files changed, 88 insertions(+), 107 deletions(-) diff --git a/FastSurferCNN/generate_hdf5.py b/FastSurferCNN/generate_hdf5.py index e608f8fe..84e8cb82 100644 --- a/FastSurferCNN/generate_hdf5.py +++ b/FastSurferCNN/generate_hdf5.py @@ -155,7 +155,7 @@ def __init__(self, params: Dict, processing: str = "aparc"): self.gm_mask = params["gm_mask"] self.lut = read_classes_from_lut(params["lut"]) - self.labels, self.labels_sag = get_labels_from_lut(self.lut, params["sag-mask"]) + self.labels, self.labels_sag = get_labels_from_lut(self.lut, params["sag_mask"]) self.lateralization = unify_lateralized_labels(self.lut, params["combi"]) if params["csv_file"] is not None: @@ -540,7 +540,7 @@ def create_hdf5_dataset(self, blt: int): "plane": args.plane, "lut": args.lut, "combi": args.combi, - "sag-mask": args.sag_mask, + "sag_mask": args.sag_mask, "hires_weight": args.hires_w, "gm_mask": args.gm, "gradient": not args.no_grad, diff --git a/FastSurferCNN/models/interpolation_layer.py b/FastSurferCNN/models/interpolation_layer.py index 627d7fb5..b656b803 100644 --- a/FastSurferCNN/models/interpolation_layer.py +++ b/FastSurferCNN/models/interpolation_layer.py @@ -36,23 +36,23 @@ class _ZoomNd(nn.Module): Attributes ---------- _mode - Interpolation mode as in `torch.nn.interpolate` (default: 'neareast'). + (Protected) Interpolation mode as in `torch.nn.interpolate` (default: 'neareast'). _target_shape - Target tensor size for after this module, + (Protected) Target tensor size for after this module, not including batchsize and channels. _N - Number of dimensions. + (Protected) Internal number of dimensions. Methods ------- forward Forward propagation. _fix_scale_factors - Checking and fixing the conformity of scale_factors. + (Protected) Checking and fixing the conformity of scale_factors. _interpolate - Abstract method. - -calculate_crop_pad - Return start- and end- coordinate. + (Protected) Abstract method. + _calculate_crop_pad + (Protected) Return start- and end- coordinate. """ def __init__( @@ -376,6 +376,8 @@ class Zoom2d(_ZoomNd): (Protected) Crops, interpolates and pads the tensor. """ + _crop_position: str + def __init__( self, target_shape: _T.Optional[_T.Sequence[int]], @@ -398,6 +400,28 @@ def __init__( if interpolation_mode not in ["nearest", "bilinear", "bicubic", "area"]: raise ValueError(f"invalid interpolation_mode, got {interpolation_mode}") + self._N = 2 + super(Zoom2d, self).__init__(target_shape, interpolation_mode) + self.crop_position = crop_position + + @property + def crop_position(self) -> str: + """ + Property associated with the position of the image in the data. 
+ """ + return self._crop_position + + @crop_position.setter + def crop_position(self, crop_position: str) -> None: + """ + Set the crop position. + + Parameters + ---------- + crop_position : str + The crop position key from 'top_left', 'bottom_left', top_right', + 'bottom_right', 'center'. + """ if crop_position not in [ "top_left", "bottom_left", @@ -406,11 +430,8 @@ def __init__( "center", ]: raise ValueError(f"invalid crop_position, got {crop_position}") - - self._N = 2 - super(Zoom2d, self).__init__(target_shape, interpolation_mode) - self.crop_position = crop_position - + self._crop_position = crop_position + def _interpolate( self, data: Tensor, @@ -514,6 +535,29 @@ def __init__( if interpolation_mode not in ["nearest", "trilinear", "area"]: raise ValueError(f"invalid interpolation_mode, got {interpolation_mode}") + self._N = 3 + super(Zoom3d, self).__init__(target_shape, interpolation_mode) + self.crop_position = crop_position + + @property + def crop_position(self) -> str: + """ + Property associated with the position of the image in the data. + """ + return self._crop_position + + @crop_position.setter + def crop_position(self, crop_position: str) -> None: + """ + Set the crop position. + + Parameters + ---------- + crop_position : str + Crop position to use from 'front_top_left', 'back_top_left', + 'front_bottom_left', 'back_bottom_left', 'front_top_right', 'back_top_right', + 'front_bottom_right', 'back_bottom_right', 'center' (default: 'front_top_left'). + """ if crop_position not in [ "front_top_left", "back_top_left", @@ -526,11 +570,8 @@ def __init__( "center", ]: raise ValueError(f"invalid crop_position, got {crop_position}") - - self._N = 3 - super(Zoom3d, self).__init__(target_shape, interpolation_mode) self._crop_position = crop_position - + def _interpolate( self, data: Tensor, scale_factor: _T.Union[Tensor, np.ndarray, _T.Sequence[int]] ): diff --git a/FastSurferCNN/quick_qc.py b/FastSurferCNN/quick_qc.py index dfea521c..bf531d9d 100644 --- a/FastSurferCNN/quick_qc.py +++ b/FastSurferCNN/quick_qc.py @@ -100,7 +100,7 @@ def check_volume(asegdkt_segfile:np.ndarray, voxvol: float, thres: float = 0.70) def get_region_bg_intersection_mask( seg_array, region_labels=VENT_LABELS, bg_label=BG_LABEL ): - """ + f""" Return a mask of the intersection between the voxels of a given region and background voxels. This is obtained by dilating the region by 1 voxel and computing the intersection with the @@ -112,18 +112,22 @@ def get_region_bg_intersection_mask( ---------- seg_array : numpy.ndarray Segmentation array. - region_labels : Dict - Dictionary whose values correspond to the desired region's labels (Default value = VENT_LABELS). - VENT_LABELS is a dictionary containing labels for different regions related to the ventricles, - such as "Left-Lateral-Ventricle", "Right-Lateral-Ventricle", etc., along with their - corresponding numeric values. - bg_label : int, default="BG_LABEL" - (Default value = BG_LABEL). + region_labels : dict, default= + Dictionary whose values correspond to the desired region's labels (see Note). + bg_label : int, default={BG_LABEL} + (Default value = {BG_LABEL}). Returns ------- bg_intersect : numpy.ndarray Region and background intersection mask array. + + Notes + ----- + VENT_LABELS is a dictionary containing labels for four regions related to the ventricles: + "Left-Lateral-Ventricle", "Right-Lateral-Ventricle", "Left-choroid-plexus", + "Right-choroid-plexus" along with their corresponding integer label values + (see also FreeSurferColorLUT.txt). 
""" region_array = seg_array.copy() conditions = np.all( diff --git a/FastSurferCNN/run_prediction.py b/FastSurferCNN/run_prediction.py index 9847955e..7e4fbf8f 100644 --- a/FastSurferCNN/run_prediction.py +++ b/FastSurferCNN/run_prediction.py @@ -107,37 +107,6 @@ def args2cfg( return cfg_fin, cfg_cor, cfg_sag, cfg_ax -def removesuffix(string: str, suffix: str) -> str: - """ - Remove a suffix from a string. - - Similar to string.removesuffix in PY3.9+, - - Parameters - ---------- - string : str - String to be cut. - suffix : str - Suffix to be removed. - - Returns - ------- - Str - Suffix removed string. - """ - import sys - - if sys.version_info.minor >= 9: - # removesuffix is a Python3.9 feature - return string.removesuffix(suffix) - else: - return ( - string[: -len(suffix)] - if len(suffix) > 0 and string.endswith(suffix) - else string - ) - - ## # Input array preparation ## diff --git a/FastSurferCNN/utils/arg_types.py b/FastSurferCNN/utils/arg_types.py index 6891f1a6..7f4dc404 100644 --- a/FastSurferCNN/utils/arg_types.py +++ b/FastSurferCNN/utils/arg_types.py @@ -173,7 +173,7 @@ def int_ge_zero(value) -> int: def unquote_str(value) -> str: """ - Unquote a (single quoted) string. + Unquote a (single quoted) string, i.e. remove one level of single-quotes. Parameters ---------- @@ -183,7 +183,7 @@ def unquote_str(value) -> str: Returns ------- val : str - A string of the value without quoting with '''. + A string of the value without leading and trailing single-quotes. """ val = str(value) if val.startswith("'") and val.endswith("'"): diff --git a/FastSurferCNN/utils/common.py b/FastSurferCNN/utils/common.py index 1d38b723..6adf3aad 100644 --- a/FastSurferCNN/utils/common.py +++ b/FastSurferCNN/utils/common.py @@ -41,7 +41,6 @@ "iterate", "NoParallelExecutor", "pipeline", - "removesuffix", "SubjectList", "SubjectDirectory", ] @@ -190,20 +189,19 @@ def pipeline( ---------- pool : Executor Thread pool executor for parallel execution. - - func : Callable[[_Ti], _T] : + func : callable Function to use. - - iterable : Iterable[_Ti] + iterable : Iterable Iterable containing input elements. - - pipeline_size : int - Size of the processing pipeline (default is 1). + pipeline_size : int, default=1 + Size of the processing pipeline. Yields ------ - Iterator[Tuple[_Ti, _T]] - Yielding elements from iterable and corresponding results from func(element). + element : _Ti + Elements + _T + Results of func corresponding to element: func(element). """ # do pipeline loading the next element from collections import deque @@ -231,54 +229,23 @@ def iterate( Parameters ---------- pool : Executor - Callable. - func : Iterable + The Executor object (dummy object to have a common API with pipeline). + func : callable Function to use. - iterable : Iterable[_Ti] - Iterable. + iterable : Iterable + Iterable to draw objects to process with func from. Yields ------ - element : _Ti + element : _Ti Elements _T - Func(element). + Results of func corresponding to element: func(element). """ for element in iterable: yield element, func(element) -def removesuffix(string: str, suffix: str) -> str: - """ - Remove a suffix from a string. - - Similar to string.removesuffix in PY3.9+. - - Parameters - ---------- - string : str - String that should be edited. - suffix : str - Suffix to remove. - - Returns - ------- - str - Input string with removed suffix. 
- """ - import sys - - if sys.version_info.minor >= 9: - # removesuffix is a Python3.9 feature - return string.removesuffix(suffix) - else: - return ( - string[: -len(suffix)] - if len(suffix) > 0 and string.endswith(suffix) - else string - ) - - class SubjectDirectory: """ Represent a subject directory. @@ -1012,7 +979,7 @@ def __getitem__(self, item: Union[int, str]) -> SubjectDirectory: # subject is always an absolute path (or relative to the working directory) ... of the input file subject = self._subjects[item] sid = ( - os.path.basename(removesuffix(subject, self._remove_suffix)) + os.path.basename(subject.removesuffix(self._remove_suffix)) if self._sid is None else self._sid )