From 6d1735236477965b97ce36114a7252ef4c0a14c1 Mon Sep 17 00:00:00 2001 From: Taha Abdullah Date: Fri, 31 May 2024 16:12:52 +0200 Subject: [PATCH 1/8] Documentation for HypVINN: - Added docstring for functions in HypVINN --- HypVINN/config/hypvinn.py | 9 +- HypVINN/data_loader/data_utils.py | 147 ++++++++++++++---- HypVINN/data_loader/dataset.py | 104 +++++++++++-- HypVINN/inference.py | 210 +++++++++++++++++++++++++- HypVINN/models/networks.py | 118 ++++++++++++--- HypVINN/run_prediction.py | 145 +++++++++++++++++- HypVINN/utils/img_processing_utils.py | 113 +++++++++++++- HypVINN/utils/load_config.py | 24 ++- HypVINN/utils/mode_config.py | 21 +++ HypVINN/utils/preproc.py | 57 ++++++- HypVINN/utils/stats_utils.py | 25 +++ HypVINN/utils/visualization_utils.py | 104 ++++++++++++- doc/api/HypVINN_dataloader.rst | 3 + doc/api/HypVINN_models.rst | 2 + 14 files changed, 998 insertions(+), 84 deletions(-) diff --git a/HypVINN/config/hypvinn.py b/HypVINN/config/hypvinn.py index 48c3e1516..8df5807ee 100644 --- a/HypVINN/config/hypvinn.py +++ b/HypVINN/config/hypvinn.py @@ -258,7 +258,14 @@ def get_cfg_hypvinn(): - """Get a yacs CfgNode object with default values for my_project.""" + """ + Get a yacs CfgNode object with default values for HypVINN project. + + Returns + ------- + _C : yacs.config.CfgNode + A clone of the default configuration node for the HypVINN project. + """ # Return a clone so that the defaults will not be altered # This is for the "local variable" use pattern return _C.clone() \ No newline at end of file diff --git a/HypVINN/data_loader/data_utils.py b/HypVINN/data_loader/data_utils.py index b07a9af02..1ae0aca1f 100644 --- a/HypVINN/data_loader/data_utils.py +++ b/HypVINN/data_loader/data_utils.py @@ -26,7 +26,7 @@ ## -def calculate_flip_orientation(iornt, base_ornt): +def calculate_flip_orientation(iornt: np.ndarray, base_ornt: np.ndarray) -> np.ndarray: """ Compute the flip orientation transform. @@ -34,11 +34,15 @@ def calculate_flip_orientation(iornt, base_ornt): Parameters ---------- - iornt - base_ornt + iornt: np.ndarray + Initial orientation. + base_ornt: np.ndarray + Base orientation. Returns ------- + new_iornt: np.ndarray + New orientation. """ new_iornt=iornt.copy() @@ -56,21 +60,26 @@ def calculate_flip_orientation(iornt, base_ornt): def reorient_img(img, ref_img): """ + Reorient a Nibabel image based on the orientation of a reference nibabel image. + Function to reorient a Nibabel image based on the orientation of a reference nibabel image - The orientation transform. ornt[N,1]` is flip of axis N of the array implied by `shape`, where 1 means no flip and -1 means flip. - For example, if ``N==0 and ornt[0,1] == -1, and there’s an array arr of shape shape, the flip would correspond to the effect of - np.flipud(arr). ornt[:,0] is the transpose that needs to be done to the implied array, as in arr.transpose(ornt[:,0]) + The orientation transform. ornt[N,1]` is flip of axis N of the array implied by `shape`, where 1 means no flip and + -1 means flip. For example, if ``N==0 and ornt[0,1] == -1, and there’s an array arr of shape shape, the flip would + correspond to the effect of np.flipud(arr). ornt[:,0] is the transpose that needs to be done to the implied array, + as in arr.transpose(ornt[:,0]). Parameters ---------- - img: nibabel Image to reorient - base_img: referece orientation nibabel image + img : nibabel.Nifti1Image + Nibabel Image to reorient. + ref_img : nibabel.Nifti1Image + Reference orientation nibabel image. 
Returns ------- - + img : nibabel.Nifti1Image + Reoriented image. """ - ref_ornt =nib.io_orientation(ref_img.affine) iornt=nib.io_orientation(img.affine) @@ -86,13 +95,26 @@ def reorient_img(img, ref_img): return img -def transform_axial2coronal(vol, axial2coronal=True): +def transform_axial2coronal(vol: np.ndarray, axial2coronal: bool = True) -> np.ndarray: """ - Function to transform volume into coronal axis and back - :param np.ndarray vol: image volume to transform - :param bool axial2coronal: transform from axial to coronal = True (default), - transform from coronal to axial = False - :return: + Transforms a volume into the coronal axis and back. + + This function is used to transform a volume into the coronal axis and back. The transformation is done by moving + the axes of the volume. If the `axial2coronal` parameter is set to True, the function will transform from axial to + coronal. If it is set to False, the function will transform from coronal to axial. + + Parameters + ---------- + vol : np.ndarray + The image volume to transform. + axial2coronal : bool, optional + A flag to determine the direction of the transformation. If True, transform from axial to coronal. If False, + transform from coronal to axial. (Default: True). + + Returns + ------- + np.ndarray + The transformed volume. """ # TODO check compatibility with axis transform from CerebNet if axial2coronal: @@ -101,13 +123,26 @@ def transform_axial2coronal(vol, axial2coronal=True): return np.moveaxis(vol, [0, 1, 2], [0, 2, 1]) -def transform_axial2sagittal(vol, axial2sagittal=True): +def transform_axial2sagittal(vol: np.ndarray, axial2sagittal: bool = True) -> np.ndarray: """ - Function to transform volume into Sagittal axis and back - :param np.ndarray vol: image volume to transform - :param bool coronal2sagittal: transform from coronal to sagittal = True (default), - transform from sagittal to coronal = False - :return: + Transforms a volume into the sagittal axis and back. + + This function is used to transform a volume into the sagittal axis and back. The transformation is done by moving + the axes of the volume. If the `axial2sagittal` parameter is set to True, the function will transform from axial to + sagittal. If it is set to False, the function will transform from sagittal to axial. + + Parameters + ---------- + vol : np.ndarray + The image volume to transform. + axial2sagittal : bool, optional + A flag to determine the direction of the transformation. If True, transform from axial to sagittal. If False, + transform from sagittal to axial. (Default: True). + + Returns + ------- + np.ndarray + The transformed volume. """ # TODO check compatibility with axis transform from CerebNet if axial2sagittal: @@ -116,7 +151,22 @@ def transform_axial2sagittal(vol, axial2sagittal=True): return np.moveaxis(vol, [0, 1, 2], [1, 2, 0]) -def rescale_image(img_data): +def rescale_image(img_data: np.ndarray) -> np.ndarray: + """ + Rescale the image data to the range [0, 255]. + + This function rescales the input image data to the range [0, 255]. + + Parameters + ---------- + img_data : np.ndarray + The image data to rescale. + + Returns + ------- + np.ndarray + The rescaled image data. 
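+
+    Examples
+    --------
+    A minimal usage sketch (illustrative only; assumes a floating-point
+    volume such as the output of nibabel's ``get_fdata()``)::
+
+        import numpy as np
+
+        data = np.random.default_rng(0).uniform(0.0, 1200.0, (8, 8, 8))
+        rescaled = rescale_image(data)  # intensities now lie within [0, 255]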
+    """
     # Conform intensities
     # TODO move function into FastSurferCNN, same: CerebNet.datasets.utils.rescale_image
     src_min, scale = getscale(img_data, 0, 255)
@@ -131,7 +181,19 @@
 
 def hypo_map_label2subseg(mapped_subseg: npt.NDArray[int]) -> npt.NDArray[int]:
     """
-    Function to perform look-up table mapping from label space to subseg space
+    Perform look-up table mapping from label space to subseg space.
+
+    This function is used to perform a look-up table mapping from label space to subseg space.
+
+    Parameters
+    ----------
+    mapped_subseg : npt.NDArray[int]
+        The input array in label space to be mapped to subseg space.
+
+    Returns
+    -------
+    npt.NDArray[int]
+        The mapped array in subseg space.
     """
     # TODO can this function be replaced by a Mapper and a mapping file?
     labels, _ = hyposubseg_labels
@@ -143,13 +205,23 @@
 
 def hypo_map_prediction_sagittal2full(
     prediction_sag: npt.NDArray[int],
 ) -> npt.NDArray[int]:
     """
-    Function to remap the prediction on the sagittal network to full label space used by coronal and axial networks
-    :param prediction_sag: sagittal prediction (labels)
-    :param lbl_type: type of label
-    :return: Remapped prediction
+    Remap the prediction on the sagittal network to full label space.
+
+    This function is used to remap the prediction on the sagittal network to the full label space used by the coronal
+    and axial networks.
+
+    Parameters
+    ----------
+    prediction_sag : npt.NDArray[int]
+        The sagittal prediction in label space to be remapped to full label space.
+
+    Returns
+    -------
+    npt.NDArray[int]
+        The remapped prediction in full label space.
     """
     # TODO can this function be replaced by a Mapper and a mapping file?
 
@@ -163,15 +235,24 @@ def hypo_map_subseg_2_fsseg(
     reverse: bool = False,
 ) -> npt.NDArray[int]:
     """
-    Function to remap HypVINN internal labels to FastSurfer Labels and viceversa
+    Remap HypVINN internal labels to FastSurfer Labels and vice versa.
+
+    This function is used to remap HypVINN internal labels to FastSurfer Labels and vice versa. If the `reverse`
+    parameter is set to False, the function will map HypVINN labels to FastSurfer labels. If it is set to True,
+    the function will map FastSurfer labels to HypVINN labels.
+
     Parameters
     ----------
-    subseg
-    reverse
+    subseg : npt.NDArray[int]
+        The input array with HypVINN or FastSurfer labels to be remapped.
+    reverse : bool, optional
+        A flag to determine the direction of the remapping. If False, remap HypVINN labels to FastSurfer labels.
+        If True, remap FastSurfer labels to HypVINN labels. (Default: False).
 
     Returns
     -------
-
+    npt.NDArray[int]
+        The remapped array with FastSurfer or HypVINN labels.
     """
     # TODO can this function be replaced by a Mapper and a mapping file?
 
diff --git a/HypVINN/data_loader/dataset.py b/HypVINN/data_loader/dataset.py
index c22c4bc2a..7018cf31f 100644
--- a/HypVINN/data_loader/dataset.py
+++ b/HypVINN/data_loader/dataset.py
@@ -30,12 +30,27 @@
 # Operator to load imaged for inference
 class HypoVINN_dataset(Dataset):
     """
-    Class to load MRI-Image and process it to correct format for HypVINN network inference
-    The HypVINN Dataset passed during Inference the input images,the scale factor for the VINN layer and a weight factor (wT1,wT2)
-    The Weight factor determines the running mode of the HypVINN model
-    if wT1 =1 and wT2 =0. The HypVINN model will only allow the flow of the T1 information (mode = t1)
-    if wT1 =0 and wT2 =1. The HypVINN model will only allow the flow of the T2 information (mode = t2)
-    if wT1 !=1 and wT2 !=1. The HypVINN model will automatically weigh the T1 information and the T2 information based on the learned modality weights (mode = t1t2)
+    Class to load an MRI image and process it into the correct format for HypVINN network inference.
+
+    During inference, the HypVINN Dataset provides the input images, the scale factor for the VINN layer, and a
+    weight factor (wT1, wT2).
+    The weight factor determines the running mode of the HypVINN model:
+    if wT1 = 1 and wT2 = 0, the HypVINN model will only allow the flow of the T1 information (mode = t1);
+    if wT1 = 0 and wT2 = 1, the HypVINN model will only allow the flow of the T2 information (mode = t2);
+    if wT1 != 1 and wT2 != 1, the HypVINN model will automatically weigh the T1 and T2 information based on the
+    learned modality weights (mode = t1t2).
+
+    Methods
+    -------
+    _standarized_img(orig_data: np.ndarray, orig_zoom: npt.NDArray[float], modality: np.ndarray) -> np.ndarray
+        Standardize the image based on the original data, original zoom, and modality.
+    _get_scale_factor() -> npt.NDArray[float]
+        Get the scaling factor to match the original resolution of the input image to the final resolution of the
+        FastSurfer base network.
+    __getitem__(index: int) -> dict[str, torch.Tensor | np.ndarray]
+        Retrieve the image, scale factor, and weight factor for a given index.
+    __len__()
+        Return the number of images in the dataset.
     """
     def __init__(
         self,
@@ -46,6 +61,25 @@ def __init__(
         mode: ModalityMode = "t1t2",
         transforms=None,
     ):
+        """
+        Initialize the HypoVINN Dataset.
+
+        Parameters
+        ----------
+        subject_name : str
+            The name of the subject.
+        modalities : ModalityDict
+            The modalities of the subject.
+        orig_zoom : npt.NDArray[float]
+            The original zoom of the subject.
+        cfg : CfgNode
+            The configuration object.
+        mode : ModalityMode, optional
+            The running mode of the HypVINN model. (Default: "t1t2").
+        transforms : Callable, optional
+            The transformations to apply to the images. (Default: None).
+
+        """
         self.subject_name = subject_name
         self.plane = cfg.DATA.PLANE
         #Inference Mode
@@ -84,7 +118,8 @@ def __init__(
             f"model"
         )
 
-        if (cfg.MODEL.MULTI_AUTO_W or cfg.MODEL.MULTI_AUTO_W_CHANNELS) and (self.mode == 't1t2' or cfg.MODEL.DUPLICATE_INPUT) :
+        if ((cfg.MODEL.MULTI_AUTO_W or cfg.MODEL.MULTI_AUTO_W_CHANNELS) and
+                (self.mode == 't1t2' or cfg.MODEL.DUPLICATE_INPUT)) :
             logger.info(
                 f"For inference T1 block weight and the T2 block are set to "
                 f"the weights learn during training"
@@ -95,7 +130,25 @@ def __init__(
             f"{self.weight_factor.numpy()[0]} and the T2 block was set to: "
             f"{self.weight_factor.numpy()[1]}")
 
-    def _standarized_img(self, orig_data, orig_zoom, modality):
+    def _standarized_img(self, orig_data: np.ndarray, orig_zoom: npt.NDArray[float],
+                         modality: np.ndarray) -> np.ndarray:
+        """
+        Standardize the image based on the original data, original zoom, and modality.
+
+        Parameters
+        ----------
+        orig_data : np.ndarray
+            The original data of the image.
+        orig_zoom : npt.NDArray[float]
+            The original zoom of the image.
+        modality : np.ndarray
+            The modality of the image.
+
+        Returns
+        -------
+        orig_thick : np.ndarray
+            The standardized image.
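+
+        Notes
+        -----
+        A simplified sketch of the per-plane handling (abridged from the body
+        below; the zoom bookkeeping and thick-slice assembly are omitted)::
+
+            if plane == "sagittal":
+                data = transform_axial2sagittal(data)  # move axes to sagittal order
+            elif plane == "coronal":
+                data = transform_axial2coronal(data)   # move axes to coronal order
+            # the conformed volume is then stacked into ``orig_thick``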
+ """ if self.plane == "sagittal": orig_data = transform_axial2sagittal(orig_data) self.zoom = orig_zoom[::-1][:2] @@ -123,19 +176,40 @@ def _standarized_img(self, orig_data, orig_zoom, modality): def _get_scale_factor(self) -> npt.NDArray[float]: """ - Get scaling factor to match original resolution of input image to - final resolution of FastSurfer base network. Input resolution is - taken from voxel size in image header. - ToDO: This needs to be updated based on the plane we are looking at in case we - are dealing with non-isotropic images as inputs. - :param img_zoom: - :return np.ndarray(float32): scale factor along x and y dimension + Get the scaling factor to match the original resolution of the input image to + the final resolution of the FastSurfer base network. The input resolution is + taken from the voxel size in the image header. + + Returns + ------- + scale : npt.NDArray[float] + The scaling factor along the x and y dimensions. This is a numpy array of float values. """ + # TODO: This needs to be updated based on the plane we are looking at in case we + # are dealing with non-isotropic images as inputs. + scale = self.base_res / np.asarray(self.zoom) return scale def __getitem__(self, index: int) -> dict[str, torch.Tensor | np.ndarray]: + """ + Retrieve the image, scale factor, and weight factor for a given index. + + This method retrieves the image at the given index from the images attribute, calculates the scale factor, + applies any transformations to the image if they are defined, and returns a dictionary containing the image, + scale factor, and weight factor. + + Parameters + ---------- + index : int + The index of the image to retrieve. + + Returns + ------- + dict[str, torch.Tensor | np.ndarray] + A dictionary containing the image, scale factor, and weight factor. + """ img = self.images[index] scale_factor = self._get_scale_factor() diff --git a/HypVINN/inference.py b/HypVINN/inference.py index 85cac5f33..8630f3f92 100644 --- a/HypVINN/inference.py +++ b/HypVINN/inference.py @@ -17,6 +17,7 @@ import torch import numpy as np +import yacs.config from tqdm import tqdm from torch.utils.data import DataLoader from torchvision import transforms @@ -33,6 +34,21 @@ class Inference: + """ + Class for running inference on a single subject. + + Attributes + ---------- + model : torch.nn.Module + The model to use for inference. + model_name : str + The name of the model. + + Methods + ------- + setup_model(cfg) + Set up the model. + """ def __init__( self, cfg, @@ -41,7 +57,26 @@ def __init__( device: str = "auto", viewagg_device: str = "auto", ): - + """ + Initialize the Inference class. + + This method initializes the Inference class with the provided configuration, number of threads, async IO flag, + device, and view aggregation device. It sets the random seed, switches on denormal flushing, defines the device, + and sets up the initial model. + + Parameters + ---------- + cfg : yacs.config.CfgNode + The configuration node containing the parameters for the model. + threads : int, optional + The number of threads to use. Default is -1, which uses all available threads. + async_io : bool, optional + Whether to use asynchronous IO. Default is False. + device : str, optional + The device to use for computations. Can be 'auto', 'cpu', or 'cuda'. Default is 'auto'. + viewagg_device : str, optional + The device to use for view aggregation. Can be 'auto', 'cpu', or 'cuda'. Default is 'auto'. 
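+
+        Examples
+        --------
+        A hypothetical setup (the configuration path below is a placeholder,
+        not a shipped default)::
+
+            from HypVINN.utils.load_config import load_config
+
+            cfg = load_config("path/to/hypvinn_coronal.yaml")  # placeholder path
+            inference = Inference(cfg, threads=4, device="cpu")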
+ """ self._threads = threads torch.set_num_threads(self._threads) self._async_io = async_io @@ -81,6 +116,21 @@ def setup_model( self, cfg: Optional["yacs.config.CfgNode"] = None, ) -> torch.nn.Module: + """ + Set up the model. + + This method sets up the model for inference. + + Parameters + ---------- + cfg : yacs.config.CfgNode, optional + The configuration node containing the parameters for the model. + + Returns + ------- + model : torch.nn.Module + The model set up for inference. + """ if cfg is not None: self.cfg = cfg @@ -91,9 +141,25 @@ def setup_model( return model def set_cfg(self, cfg): + """ + Set the configuration node. + + Parameters + ---------- + cfg : yacs.config.CfgNode + The configuration node containing the parameters for the model. + """ self.cfg = cfg - def set_model(self, cfg=None): + def set_model(self, cfg: yacs.config.CfgNode = None): + """ + Set the model for the Inference instance. + + Parameters + ---------- + cfg : yacs.config.CfgNode, optional + The configuration node containing the parameters for the model. (Default = None). + """ if cfg is not None: self.cfg = cfg @@ -102,41 +168,151 @@ def set_model(self, cfg=None): model.to(self.device) self.model = model - def load_checkpoint(self, ckpt): + def load_checkpoint(self, ckpt: str): + """ + Load a model checkpoint. + + This method loads a model checkpoint from a .pth file containing a state dictionary of a model. + + Parameters + ---------- + ckpt : str + The path to the checkpoint file. The checkpoint file should be a .pth file containing a state dictionary + of a model. + """ logger.info("Loading checkpoint {}".format(ckpt)) model_state = torch.load(ckpt, map_location=self.device) self.model.load_state_dict(model_state["model_state"]) def get_modelname(self): + """ + Get the name of the model. + + This method returns the name of the model used in the Inference instance. + + Returns + ------- + str + The name of the model. + """ return self.model_name def get_cfg(self): + """ + Get the configuration node. + + This method returns the configuration node used in the Inference instance. + + Returns + ------- + yacs.config.CfgNode + The configuration node containing the parameters for the model. + """ return self.cfg def get_num_classes(self): + """ + Get the number of classes. + + This method returns the number of classes defined in the model configuration. + + Returns + ------- + int + The number of classes. + """ return self.cfg.MODEL.NUM_CLASSES def get_plane(self): + """ + Get the plane. + + This method returns the plane defined in the data configuration. + + Returns + ------- + str + The plane. + """ return self.cfg.DATA.PLANE def get_model_height(self): + """ + Get the model height. + + This method returns the height of the model defined in the model configuration. + + Returns + ------- + int + The height of the model. + """ return self.cfg.MODEL.HEIGHT def get_model_width(self): + """ + Get the model width. + + This method returns the width of the model defined in the model configuration. + + Returns + ------- + int + The width of the model. + """ return self.cfg.MODEL.WIDTH def get_max_size(self): + """ + Get the maximum size of the output tensor. + + Returns + ------- + int or tuple + The maximum size. If the width and height of the output tensor are equal, it returns the width. Otherwise, it + returns both the width and height. 
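+
+        Examples
+        --------
+        Illustrative only (the result depends on the loaded configuration)::
+
+            max_size = inference.get_max_size()  # an int, or a (width, height) tuple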
+ """ if self.cfg.MODEL.OUT_TENSOR_WIDTH == self.cfg.MODEL.OUT_TENSOR_HEIGHT: return self.cfg.MODEL.OUT_TENSOR_WIDTH else: return self.cfg.MODEL.OUT_TENSOR_WIDTH, self.cfg.MODEL.OUT_TENSOR_HEIGHT def get_device(self): + """ + Get the device. + + This method returns the device and view aggregation device used in the Inference instance. + + Returns + ------- + tuple + The device and view aggregation device. + """ return self.device,self.viewagg_device #TODO check is possible to modify to CerebNet inference mode from RAS directly to LIA (CerebNet.Inference._predict_single_subject) @torch.no_grad() - def eval(self, val_loader: DataLoader, pred_prob: torch.Tensor, out_scale=None): + def eval(self, val_loader: DataLoader, pred_prob: torch.Tensor, out_scale: float = None) -> torch.Tensor: + """ + Evaluate the model on a validation set. + + This method runs the model in evaluation mode on a validation set. It iterates over the validation set, + computes the model's predictions, and updates the prediction probabilities based on the plane of the data. + + Parameters + ---------- + val_loader : DataLoader + The DataLoader for the validation set. + pred_prob : torch.Tensor + The tensor to update with the prediction probabilities. + out_scale : float, optional + The scale factor for the output. Default is None. + + Returns + ------- + pred_prob: torch.Tensor + The updated prediction probabilities. + """ self.model.eval() start_index = 0 @@ -176,6 +352,32 @@ def run( out_res=None, mode: ModalityMode = "t1t2", ): + """ + Run the inference process on a single subject. + + This method sets up a DataLoader for the subject, runs the model in evaluation mode on the subject's data, + and returns the updated prediction probabilities. + + Parameters + ---------- + subject_name : str + The name of the subject. + modalities : ModalityDict + The modalities of the subject. + orig_zoom : npt.NDArray[float] + The original zoom of the subject. + pred_prob : torch.Tensor + The tensor to update with the prediction probabilities. + out_res : float, optional + The resolution of the output. Default is None. + mode : ModalityMode, optional + The mode of the modalities. Default is 't1t2'. + + Returns + ------- + pred_prob: torch.Tensor + The updated prediction probabilities. + """ # Set up DataLoader test_dataset = HypoVINN_dataset( subject_name, diff --git a/HypVINN/models/networks.py b/HypVINN/models/networks.py index a26b80245..ca2cfbfd2 100644 --- a/HypVINN/models/networks.py +++ b/HypVINN/models/networks.py @@ -16,6 +16,8 @@ # IMPORTS from typing import Dict + +import yacs.config from torch import Tensor, nn import torch import FastSurferCNN.models.sub_module as sm @@ -26,16 +28,63 @@ class HypVINN(FastSurferCNNBase): """ - Construct HypVINN object. - - Parameters - ---------- - params : Dict - Dictionary of configurations. - padded_size : int - Size of image when padded (Default value = 256). - """ + HypVINN class that extends the FastSurferCNNBase class. + + This class represents a HypVINN model. It includes methods for initializing the model, setting up the layers, + and performing forward propagation. + + Attributes + ---------- + height : int + The height of the output tensor. + width : int + The width of the output tensor. + out_tensor_shape : tuple + The shape of the output tensor. + interpolation_mode : str + The interpolation mode to use when resizing the images. This can be 'nearest', 'bilinear', 'bicubic', or 'area'. + crop_position : str + The position to crop the images from. 
This can be 'center', 'top_left', 'top_right', 'bottom_left', or 'bottom_right'. + m1_inp_block : InputDenseBlock + The input block for the first modality. + m2_inp_block : InputDenseBlock + The input block for the second modality. + mod_weights : nn.Parameter + The weights for the two modalities. + normalize_weights : nn.Softmax + A softmax function to normalize the modality weights. + outp_block : OutputDenseBlock + The output block of the model. + interpol1 : Zoom2d + The first interpolation layer. + interpol2 : Zoom2d + The second interpolation layer. + classifier : ClassifierBlock + The final classifier block of the model. + + Methods + ------- + forward(x, scale_factor, weight_factor, scale_factor_out=None) + Perform forward propagation through the model. + """ def __init__(self, params, padded_size=256): + """ + Initialize the HypVINN model. + + This method initializes the HypVINN model by calling the super class constructor and setting up the layers. + + Parameters + ---------- + params : Dict + A dictionary containing the configuration parameters for the model. + padded_size : int, optional + The size of the image when padded. (Default = 256). + + Raises + ------ + ValueError + If the interpolation mode or crop position is invalid. + """ num_c = params["num_channels"] params["num_channels"] = params["num_filters_interpol"] @@ -104,10 +153,40 @@ def __init__(self, params, padded_size=256): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) - def forward(self, x, scale_factor, weight_factor, scale_factor_out=None): + def forward(self, x: torch.Tensor, scale_factor: torch.Tensor, weight_factor: torch.Tensor, + scale_factor_out: torch.Tensor = None) -> torch.Tensor: + """ + Forward propagation method for the HypVINN model. + + This method takes an input tensor, a scale factor, a weight factor, and an optional output scale factor. + It performs forward propagation through the model, applying the input blocks, interpolation layers, output + block, and classifier block. It also handles the weighting of the two modalities and the rescaling of the + output. + + Parameters + ---------- + x : torch.Tensor + The input tensor. It should have a shape of (batch_size, num_channels, height, width). + scale_factor : torch.Tensor + The scale factor for the input images. It should have a shape of (batch_size, 2). + weight_factor : torch.Tensor + The weight factor for the two modalities. It should have a shape of (batch_size, 2). + scale_factor_out : torch.Tensor, optional + The scale factor for the output images. If not provided, it defaults to the scale factor of the input images. + + Returns + ------- + logits : torch.Tensor + The output logits from the classifier block. It has a shape of (batch_size, num_classes, height, width). + + Raises + ------ + ValueError + If the interpolation mode or crop position is invalid. 
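+
+        Examples
+        --------
+        A shape-level sketch (sizes are hypothetical; assumes ``model`` is a
+        constructed ``HypVINN`` instance and the channel count matches the
+        stacked T1/T2 slices expected by the configuration)::
+
+            x = torch.randn(1, 14, 320, 320)      # e.g. 7 T1 + 7 T2 slice channels
+            scale = torch.tensor([[0.8, 0.8]])    # per-sample voxel-size scale factor
+            weights = torch.tensor([[0.5, 0.5]])  # [wT1, wT2]; 0.5/0.5 uses learned weights
+            out = model(x, scale, weights)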
+ """ # Weight factor [wT1,wT2] has 3 stages [1,0],[0.5,0.5],[0,1], - #if the weight factor is [0.5,0.5] the automatically weights (s_weights) are passed - #If there is a 1 in the comparison the automatically weights will be replace by the first weight_factors pass + # if the weight factor is [0.5,0.5] the automatically weights (s_weights) are passed + # If there is a 1 in the comparison the automatically weights will be replace by the first weight_factors pass comparison = weight_factor[0] x = torch.tensor_split(x, 2, dim=1) @@ -159,19 +238,24 @@ def forward(self, x, scale_factor, weight_factor, scale_factor_out=None): } -def build_model(cfg) -> HypVINN: +def build_model(cfg: yacs.config.CfgNode) -> HypVINN: """ - Build requested model. + Build and return the requested model. Parameters ---------- cfg : yacs.config.CfgNode - Node of configs to be used. + The configuration node containing the parameters for the model. Returns ------- - model - Object of the initialized model. + HypVINN + An instance of the requested model. + + Raises + ------ + AssertionError + If the model specified in the configuration is not supported. """ if cfg.MODEL.MODEL_NAME not in _MODELS: raise AssertionError(f"Model {cfg.MODEL.MODEL_NAME} not supported") diff --git a/HypVINN/run_prediction.py b/HypVINN/run_prediction.py index df7efdb9d..14b1317de 100644 --- a/HypVINN/run_prediction.py +++ b/HypVINN/run_prediction.py @@ -59,13 +59,13 @@ def optional_path(a: Path | str) -> Optional[Path]: Parameters ---------- - a : str - The string to convert. + a : Path | str + The input to convert. Returns ------- Optional[Path] - The Path object or None. + The converted Path object. """ if isinstance(a, Path): return a @@ -80,7 +80,7 @@ def option_parse() -> argparse.ArgumentParser: Returns ------- - argparse.Ar + argparse.ArgumentParser The parser object to parse arguments from the command line. """ parser = argparse.ArgumentParser( @@ -188,6 +188,44 @@ def main( Parameters ---------- + out_dir : Path + The output directory where the results will be stored. + t2 : Optional[Path] + The path to the T2 image to process. + orig_name : Optional[Path] + The original name of the input image. + sid : str + The subject ID. + ckpt_ax : Path + The path to the axial checkpoint file. + ckpt_cor : Path + The path to the coronal checkpoint file. + ckpt_sag : Path + The path to the sagittal checkpoint file. + cfg_ax : Path + The path to the axial configuration file. + cfg_cor : Path + The path to the coronal configuration file. + cfg_sag : Path + The path to the sagittal configuration file. + hypo_segfile : str, optional + The name of the hypothalamus segmentation file. Default is HYPVINN_SEG_NAME. + allow_root : bool, optional + Whether to allow running as root user. Default is False. + qc_snapshots : bool, optional + Whether to create QC snapshots. Default is False. + reg_mode : Literal["coreg", "robust", "none"], optional + The registration mode to use. Default is "coreg". + threads : int, optional + The number of threads to use. Default is -1, which uses all available threads. + batch_size : int, optional + The batch size to use. Default is 1. + async_io : bool, optional + Whether to use asynchronous I/O. Default is False. + device : str, optional + The device to use. Default is "auto", which automatically selects the device. + viewagg_device : str, optional + The view aggregation device to use. Default is "auto", which automatically selects the device. 
Returns ------- @@ -394,6 +432,21 @@ def main( def prepare_checkpoints(ckpt_ax, ckpt_cor, ckpt_sag): + """ + Prepare the checkpoints for the Hypothalamus Segmentation model. + + This function checks if the checkpoint files for the axial, coronal, and sagittal planes exist. + If they do not exist, it downloads them from the default URLs specified in the configuration file. + + Parameters + ---------- + ckpt_ax : str + The path to the axial checkpoint file. + ckpt_cor : str + The path to the coronal checkpoint file. + ckpt_sag : str + The path to the sagittal checkpoint file. + """ logger.info("Checking or downloading default checkpoints ...") urls = load_checkpoint_config_defaults( "url", @@ -413,6 +466,40 @@ def load_volumes( tuple[float, float, float], tuple[int, int, int], ]: + """ + Load the volumes of T1 and T2 images. + + This function loads the T1 and T2 images, checks their compatibility based on the mode, and returns the loaded + volumes along with their affine transformations, headers, zoom levels, and sizes. + + Parameters + ---------- + mode : ModalityMode + The mode of operation. Can be 't1', 't2', or 't1t2'. + t1_path : Optional[Path], optional + The path to the T1 image. Default is None. + t2_path : Optional[Path], optional + The path to the T2 image. Default is None. + + Returns + ------- + tuple + A tuple containing the following elements: + - modalities: A dictionary with keys 't1' and/or 't2' and values being the corresponding loaded and rescaled images. + - affine: The affine transformation of the loaded image(s). + - header: The header of the loaded image(s). + - zoom: The zoom level of the loaded image(s). + - size: The size of the loaded image(s). + + Raises + ------ + RuntimeError + If the mode is inconsistent with the provided image paths, or if the number of dimensions of the data is invalid. + ValueError + If the mode is invalid, or if a header is missing. + AssertionError + If the mode is 't1t2' but the T1 and T2 images have different resolutions or sizes. + """ import nibabel as nib modalities: ModalityDict = {} @@ -489,7 +576,36 @@ def get_prediction( out_scale=None, mode: ModalityMode = "t1t2", ) -> npt.NDArray[int]: + """ + Run the prediction for the Hypothalamus Segmentation model. + + This function sets up the prediction process for the Hypothalamus Segmentation model. It runs the model for each + plane (axial, coronal, sagittal), accumulates the prediction probabilities, and then generates the final prediction. + Parameters + ---------- + subject_name : str + The name of the subject. + modalities : ModalityDict + A dictionary containing the modalities (T1 and/or T2) and their corresponding images. + orig_zoom : npt.NDArray[float] + The original zoom of the subject. + model : Inference + The Inference object of the model. + target_shape : tuple[int, int, int] + The target shape of the output prediction. + view_opts : ViewOperations + A dictionary containing the configurations for each plane. + out_scale : optional + The output scale. Default is None. + mode : ModalityMode, optional + The mode of operation. Can be 't1', 't2', or 't1t2'. Default is 't1t2'. + + Returns + ------- + pred_classes: npt.NDArray[int] + The final prediction of the model. + """ # TODO There are probably several possibilities to accelerate this script. # FastSurferVINN takes 7-8s vs. HypVINN 10+s per slicing direction. 
# Solution: make this script/function more similar to the optimized FastSurferVINN @@ -535,6 +651,27 @@ def set_up_cfgs( out_dir: Path, batch_size: int = 1, ) -> "yacs.config.CfgNode": + """ + Set up the configuration for the Hypothalamus Segmentation model. + + This function loads the configuration, sets the output directory and batch size, and adjusts the output tensor + dimensions based on the padded size specified in the configuration. + + Parameters + ---------- + cfg : yacs.config.CfgNode + The configuration node to load. + out_dir : Path + The output directory where the results will be stored. + batch_size : int, optional + The batch size to use. Default is 1. + + Returns + ------- + yacs.config.CfgNode + The loaded and adjusted configuration node. + + """ cfg = load_config(cfg) cfg.OUT_LOG_DIR = str(out_dir or cfg.LOG_DIR) cfg.TEST.BATCH_SIZE = batch_size diff --git a/HypVINN/utils/img_processing_utils.py b/HypVINN/utils/img_processing_utils.py index efcf97c12..2ea77262f 100644 --- a/HypVINN/utils/img_processing_utils.py +++ b/HypVINN/utils/img_processing_utils.py @@ -26,7 +26,22 @@ LOGGER = logging.get_logger(__name__) -def img2axcodes(img): +def img2axcodes(img: nib.Nifti1Image) -> tuple: + """ + Convert the affine matrix of an image to axis codes. + + This function takes an image as input and returns the axis codes corresponding to the affine matrix of the image. + + Parameters + ---------- + img : nibabel image object + The input image. + + Returns + ------- + tuple + The axis codes corresponding to the affine matrix of the image. + """ return nib.aff2axcodes(img.affine) @@ -34,11 +49,40 @@ def save_segmentation( prediction: np.ndarray, orig_path: Path, ras_affine: npt.NDArray[float], - ras_header, + ras_header: nib.nifti1.Nifti1Header, subject_dir: Path, seg_file: Path, save_mask: bool = False, ) -> float: + """ + Save the segmentation results. + + This function takes the prediction results, cleans the labels, maps them to FreeSurfer Hypvinn Labels, and saves + the results. It also reorients the mask and prediction images to match the original image's orientation. + + Parameters + ---------- + prediction : np.ndarray + The prediction results. + orig_path : Path + The path to the original image. + ras_affine : npt.NDArray[float] + The affine transformation of the RAS orientation. + ras_header : nibabel header object + The header of the RAS orientation. + subject_dir : Path + The directory where the subject's data is stored. + seg_file : Path + The file where the segmentation results will be saved. + save_mask : bool, optional + Whether to save the mask or not. Default is False. + + Returns + ------- + float + The time taken to save the segmentation. + + """ from time import time starttime = time() from HypVINN.data_loader.data_utils import reorient_img @@ -74,10 +118,37 @@ def save_logits( logits: npt.NDArray[float], orig_path: Path, ras_affine: npt.NDArray[float], - ras_header, + ras_header: nib.nifti1.Nifti1Header, save_dir: Path, mode: str, ) -> Path: + """ + Save the logits (raw model outputs) as a NIfTI image. + + This function takes the logits, reorients the image to match the original image's orientation, and saves the + results. + + Parameters + ---------- + logits : npt.NDArray[float] + The raw model outputs. + orig_path : Path + The path to the original image. + ras_affine : npt.NDArray[float] + The affine transformation of the RAS orientation. + ras_header : nib.nifti1.Nifti1Header + The header of the RAS orientation. 
+ save_dir : Path + The directory where the logits will be saved. + mode : str + The mode of operation. + + Returns + ------- + save_as: Path + The path where the logits were saved. + + """ from HypVINN.data_loader.data_utils import reorient_img orig_img = nib.load(orig_path) LOGGER.info(f"Orig data orientation: {img2axcodes(orig_img)}") @@ -97,7 +168,32 @@ def save_logits( return save_as -def get_clean_mask(segmentation, optic=False): +def get_clean_mask(segmentation: np.ndarray, optic=False) \ + -> tuple[np.ndarray, np.ndarray, bool]: + """ + Get a clean mask by removing not connected components. + + This function takes a segmentation mask and an optional boolean flag indicating whether to consider optic labels. + It removes not connected components from the segmentation mask and returns the cleaned segmentation mask, the + labels of the connected components, and a flag indicating whether to save the mask. + + Parameters + ---------- + segmentation : np.ndarray + The input segmentation mask. + optic : bool, optional + A flag indicating whether to consider optic labels. Default is False. + + Returns + ------- + clean_seg : np.ndarray + The cleaned segmentation mask. + labels_cc : np.ndarray + The labels of the connected components in the segmentation mask. + savemask : bool + A flag indicating whether to save the mask. + + """ savemask = False # Remove not connected components @@ -146,7 +242,7 @@ def get_clean_mask(segmentation, optic=False): return clean_seg, labels_cc, savemask -def get_clean_labels(segmentation): +def get_clean_labels(segmentation: np.ndarray) -> tuple[np.ndarray, np.ndarray]: """ Function to find the largest connected component of the segmentation. @@ -154,6 +250,13 @@ def get_clean_labels(segmentation): ---------- segmentation: np.ndarray The segmentation mask. + + Returns + ------- + clean_seg: np.ndarray + The cleaned segmentation mask. + labels_cc: np.ndarray + The labels of the connected components in the segmentation mask. """ # Mask largest CC without optic labels diff --git a/HypVINN/utils/load_config.py b/HypVINN/utils/load_config.py index 9da2635cc..ad9fd3484 100644 --- a/HypVINN/utils/load_config.py +++ b/HypVINN/utils/load_config.py @@ -19,7 +19,16 @@ def get_config(args): """ - Given the arguemnts, load and initialize the configs. + Given the arguments, load and initialize the configs. + + Parameters + ---------- + args : object + The arguments object. + Returns + ------- + cfg : yacs.config.CfgNode + The configuration node. """ # Setup cfg. cfg = get_cfg_hypvinn() @@ -41,6 +50,19 @@ def get_config(args): return cfg def load_config(cfg_file): + """ + Load and initialize the configuration from a given file. + + Parameters + ---------- + cfg_file : str + The path to the configuration file. The function will load configurations from this file. + + Returns + ------- + cfg : yacs.config.CfgNode + The configuration node, loaded and initialized with the given file. + """ # setup base cfg = get_cfg_hypvinn() cfg.EXPR_NUM = None diff --git a/HypVINN/utils/mode_config.py b/HypVINN/utils/mode_config.py index 15ba6d74c..5ac271895 100644 --- a/HypVINN/utils/mode_config.py +++ b/HypVINN/utils/mode_config.py @@ -26,7 +26,28 @@ def get_hypinn_mode( t1_path: Optional[Path] = None, t2_path: Optional[Path] = None, ) -> ModalityMode: + """ + Determine the input mode for HypVINN based on the existence of T1 and T2 files. + This function checks the existence of T1 and T2 files based on the provided paths. 
+ + Parameters + ---------- + t1_path : Optional[Path], default=None + The path to the T1 file. + t2_path : Optional[Path], default=None + The path to the T2 file. + + Returns + ------- + ModalityMode + The input mode for HypVINN, which can be "t1t2", "t1", or "t2". + + Raises + ------ + RuntimeError + If neither T1 nor T2 files exist, or if the corresponding flags were passed but the files do not exist. + """ LOGGER.info("Setting up input mode...") if t1_path is not None and t2_path is not None: if t1_path.is_file() and t2_path.is_file(): diff --git a/HypVINN/utils/preproc.py b/HypVINN/utils/preproc.py index 99d5a4cc5..38bddbdbd 100644 --- a/HypVINN/utils/preproc.py +++ b/HypVINN/utils/preproc.py @@ -35,6 +35,33 @@ def t1_to_t2_registration( registration_type: RegistrationMode = "coreg", threads: int = -1, ) -> Path: + """ + Register T1 to T2 images using either mri_coreg or mri_robust_register. + + Parameters + ---------- + t1_path : Path + The path to the T1 image. + t2_path : Path + The path to the T2 image. + subject_dir : Path + The directory of the subject. + registration_type : RegistrationMode, default="coreg" + The type of registration to be used. It can be either "coreg" or "robust". + threads : int, default=-1 + The number of threads to be used. If it is less than or equal to 0, the number of threads will be automatically + determined. + + Returns + ------- + Path + The path to the registered T2 image. + + Raises + ------ + RuntimeError + If mri_coreg, mri_vol2vol, or mri_robust_register fails to run or if they cannot be found. + """ from FastSurferCNN.utils.run_tools import Popen from FastSurferCNN.utils.threads import get_num_threads import shutil @@ -131,7 +158,35 @@ def hyvinn_preproc( subject_dir: Path, threads: int = -1, ) -> Path: - + """ + Preprocess the input images for HypVINN. + + Parameters + ---------- + mode : ModalityMode + The mode for HypVINN. It should be "t1t2". + reg_mode : RegistrationMode + The registration mode. If it is not "none", the function will register T1 to T2 images. + t1_path : Path + The path to the T1 image. + t2_path : Path + The path to the T2 image. + subject_dir : Path + The directory of the subject. + threads : int, default=-1 + The number of threads to be used. If it is less than or equal to 0, the number of threads will be automatically + determined. + + Returns + ------- + Path + The path to the preprocessed T2 image. + + Raises + ------ + RuntimeError + If the mode is not "t1t2", or if the registration mode is not "none" and the registration fails. + """ if mode != "t1t2": raise RuntimeError( "hypvinn_preproc should only be called for t1t2 mode." diff --git a/HypVINN/utils/stats_utils.py b/HypVINN/utils/stats_utils.py index 0d0886173..837a7f53b 100644 --- a/HypVINN/utils/stats_utils.py +++ b/HypVINN/utils/stats_utils.py @@ -21,6 +21,31 @@ def compute_stats( stats_dir: Path, threads: int, ) -> int | str: + """ + Compute statistics for the segmentation results. + + Parameters + ---------- + orig_path : Path + The path to the original image. + prediction_path : Path + The path to the predicted segmentation. + stats_dir : Path + The directory for storing the statistics. + threads : int + The number of threads to be used. + + Returns + ------- + int | str + The return value of the main function from FastSurferCNN.segstats. + Exit code. Returns 0 upon successful execution. + + Raises + ------ + RuntimeError + If the main function from FastSurferCNN.segstats fails to run. 
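+
+    Examples
+    --------
+    A hypothetical call (all paths are placeholders)::
+
+        from pathlib import Path
+
+        rc = compute_stats(
+            orig_path=Path("subject/mri/orig.mgz"),
+            prediction_path=Path("subject/mri/hypothalamus_seg.nii.gz"),
+            stats_dir=Path("subject/stats"),
+            threads=4,
+        )  # 0 on success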
+ """ from collections import namedtuple from FastSurferCNN.utils.checkpoint import FASTSURFER_ROOT diff --git a/HypVINN/utils/visualization_utils.py b/HypVINN/utils/visualization_utils.py index 4bde6be6e..96050a6c3 100644 --- a/HypVINN/utils/visualization_utils.py +++ b/HypVINN/utils/visualization_utils.py @@ -22,10 +22,40 @@ def remove_values_from_list(the_list, val): + """ + Removes values from a list. + + Parameters + ---------- + the_list : list + The original list from which values will be removed. + val : any + The value to be removed from the list. + + Returns + ------- + list + A new list with the specified value removed. + """ return [value for value in the_list if value != val] def get_lut(lookup_table_path: Path = HYPVINN_LUT): + """ + Retrieve a lookup table (LUT) from a file. + + This function reads a file and constructs a lookup table (LUT) from it. + + Parameters + ---------- + lookup_table_path: Path, default=HYPVINN_LUT + The path to the file from which the LUT will be constructed. + + Returns + ------- + lut: OrderedDict + The constructed LUT as an ordered dictionary. + """ from collections import OrderedDict lut = OrderedDict() with open(lookup_table_path, "r") as f: @@ -39,11 +69,26 @@ def get_lut(lookup_table_path: Path = HYPVINN_LUT): return lut -def map_hyposeg2label(hyposeg, lut_file: Path = HYPVINN_LUT): - import matplotlib.colors +def map_hyposeg2label(hyposeg: np.ndarray, lut_file: Path = HYPVINN_LUT): """ - Function to perform look-up table mapping of aseg.mgz data to label space (continue labels) + Map a HypVINN segmentation to a continuous label space using a lookup table. + + Parameters + ---------- + hyposeg : np.ndarray + The original segmentation map. + lut_file : Path, default=HYPVINN_LUT + The path to the lookup table file. + + Returns + ------- + mapped_hyposeg : ndarray + The mapped segmentation. + cmap : ListedColormap + The colormap for the mapped segmentation. """ + import matplotlib.colors + labels = np.unique(hyposeg) labels = np.int16(labels) @@ -65,6 +110,26 @@ def map_hyposeg2label(hyposeg, lut_file: Path = HYPVINN_LUT): def plot_coronal_predictions(cmap, images_batch=None, pred_batch=None, img_per_row=8): + """ + Plot the predicted segmentations on a grid layout. + + Parameters + ---------- + cmap : matplotlib.colors.Colormap + The colormap to be used for the predicted segmentations. + images_batch : np.ndarray, optional + The batch of input images. If not provided, the function will not plot anything. + pred_batch : np.ndarray, optional + The batch of predicted segmentations. If not provided, the function will not plot anything. + img_per_row : int, default=8 + The number of images to be plotted per row in the grid layout. + + Returns + ------- + fig: matplotlib.figure.Figure + The figure containing the plotted images and predictions. + + """ import matplotlib.pyplot as plt import torch from torchvision import utils @@ -121,6 +186,21 @@ def plot_coronal_predictions(cmap, images_batch=None, pred_batch=None, img_per_r def select_index_to_plot(hyposeg, slice_step=2): + """ + Select indices to plot based on the given segmentation map. + + Parameters + ---------- + hyposeg : np.ndarray + The segmentation map from which indices will be selected. + slice_step : int, default=2 + The step size for selecting indices from the remaining indices after removing certain indices. + + Returns + ------- + list + The selected indices, sorted in ascending order. 
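+
+    Examples
+    --------
+    A small synthetic sketch (illustrative)::
+
+        import numpy as np
+
+        seg = np.zeros((10, 5, 5), dtype=np.int16)
+        seg[3:7] = 1  # put labels on slices 3 to 6
+        idx = select_index_to_plot(seg, slice_step=2)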
+ """ # slices with labels idx = np.where(hyposeg > 0) idx = np.unique(idx[0]) @@ -162,6 +242,24 @@ def plot_qc_images( padd: int = 45, lut_file: Path = HYPVINN_LUT, slice_step: int = 2): + """ + Plot the quality control images for the subject. + + Parameters + ---------- + subject_qc_dir : Path + The directory for the subject. + orig_path : Path + The path to the original image. + prediction_path : Path + The path to the predicted image. + padd : int, default=45 + The padding value for cropping the images and segmentations. + lut_file : Path, default=HYPVINN_LUT + The path to the lookup table file. + slice_step : int, default=2 + The step size for selecting indices from the predicted segmentation. + """ from scipy import ndimage from HypVINN.data_loader.data_utils import transform_axial2coronal, hypo_map_subseg_2_fsseg diff --git a/doc/api/HypVINN_dataloader.rst b/doc/api/HypVINN_dataloader.rst index 302519213..4797d4209 100644 --- a/doc/api/HypVINN_dataloader.rst +++ b/doc/api/HypVINN_dataloader.rst @@ -6,3 +6,6 @@ HypVINN.data_loader .. autosummary:: :toctree: generated/ + + data_utils + dataset \ No newline at end of file diff --git a/doc/api/HypVINN_models.rst b/doc/api/HypVINN_models.rst index f632d9bff..76bfdbd63 100644 --- a/doc/api/HypVINN_models.rst +++ b/doc/api/HypVINN_models.rst @@ -6,3 +6,5 @@ HypVINN.models .. autosummary:: :toctree: generated/ + + networks From 87db776089ba018e6b4a49c5b473ae83f29b2936 Mon Sep 17 00:00:00 2001 From: Taha Abdullah Date: Fri, 31 May 2024 16:29:25 +0200 Subject: [PATCH 2/8] fixing issue with matplotlib 3.9.0: importing cmap from pyplot instead of cm --- FastSurferCNN/utils/mapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/FastSurferCNN/utils/mapper.py b/FastSurferCNN/utils/mapper.py index 02198bc7d..3c664ef25 100644 --- a/FastSurferCNN/utils/mapper.py +++ b/FastSurferCNN/utils/mapper.py @@ -49,7 +49,7 @@ import numpy as np import pandas import torch -from matplotlib.cm import get_cmap +from matplotlib.pyplot import get_cmap from matplotlib.colors import Colormap from numpy import typing as npt From bd6d263d29132a18d2162055fc44e7b6bc7b5511 Mon Sep 17 00:00:00 2001 From: Taha Abdullah Date: Fri, 31 May 2024 17:28:07 +0200 Subject: [PATCH 3/8] sphinx syntax issues: - adding space before colon separating parameter name and type --- HypVINN/data_loader/data_utils.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/HypVINN/data_loader/data_utils.py b/HypVINN/data_loader/data_utils.py index 1ae0aca1f..0e49432b5 100644 --- a/HypVINN/data_loader/data_utils.py +++ b/HypVINN/data_loader/data_utils.py @@ -34,14 +34,14 @@ def calculate_flip_orientation(iornt: np.ndarray, base_ornt: np.ndarray) -> np.n Parameters ---------- - iornt: np.ndarray + iornt : np.ndarray Initial orientation. - base_ornt: np.ndarray + base_ornt : np.ndarray Base orientation. Returns ------- - new_iornt: np.ndarray + new_iornt : np.ndarray New orientation. """ new_iornt=iornt.copy() @@ -62,12 +62,6 @@ def reorient_img(img, ref_img): """ Reorient a Nibabel image based on the orientation of a reference nibabel image. - Function to reorient a Nibabel image based on the orientation of a reference nibabel image - The orientation transform. ornt[N,1]` is flip of axis N of the array implied by `shape`, where 1 means no flip and - -1 means flip. For example, if ``N==0 and ornt[0,1] == -1, and there’s an array arr of shape shape, the flip would - correspond to the effect of np.flipud(arr). 
ornt[:,0] is the transpose that needs to be done to the implied array,
-    as in arr.transpose(ornt[:,0]).
-
     Parameters
     ----------
     img : nibabel.Nifti1Image

From bdc2663e66595add2116e3adfb65e0d41c4f3c6b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?David=20K=C3=BCgler?=
Date: Fri, 31 May 2024 18:24:27 +0200
Subject: [PATCH 4/8] Apply suggestions from code review

---
 HypVINN/data_loader/data_utils.py     |  2 +-
 HypVINN/data_loader/dataset.py        |  2 +-
 HypVINN/inference.py                  |  2 +-
 HypVINN/run_prediction.py             |  4 ++--
 HypVINN/utils/img_processing_utils.py |  4 ++--
 HypVINN/utils/mode_config.py          |  4 ++--
 HypVINN/utils/stats_utils.py          |  2 +-
 HypVINN/utils/visualization_utils.py  | 12 ++++++------
 8 files changed, 16 insertions(+), 16 deletions(-)

diff --git a/HypVINN/data_loader/data_utils.py b/HypVINN/data_loader/data_utils.py
index 0e49432b5..60f0ca0d8 100644
--- a/HypVINN/data_loader/data_utils.py
+++ b/HypVINN/data_loader/data_utils.py
@@ -129,7 +129,7 @@ def transform_axial2sagittal(vol: np.ndarray, axial2sagittal: bool = True) -> np
     ----------
     vol : np.ndarray
         The image volume to transform.
-    axial2sagittal : bool, optional
+    axial2sagittal : bool, default=True
         A flag to determine the direction of the transformation. If True, transform from axial to sagittal. If False,
         transform from sagittal to axial. (Default: True).
 
diff --git a/HypVINN/data_loader/dataset.py b/HypVINN/data_loader/dataset.py
index 7018cf31f..697fd53fe 100644
--- a/HypVINN/data_loader/dataset.py
+++ b/HypVINN/data_loader/dataset.py
@@ -74,7 +74,7 @@ def __init__(
         The original zoom of the subject.
     cfg : CfgNode
         The configuration object.
-    mode : ModalityMode, optional
+    mode : ModalityMode, default="t1t2"
         The running mode of the HypVINN model. (Default: "t1t2").
     transforms : Callable, optional
         The transformations to apply to the images. (Default: None).
diff --git a/HypVINN/inference.py b/HypVINN/inference.py
index 8630f3f92..d9f0a4740 100644
--- a/HypVINN/inference.py
+++ b/HypVINN/inference.py
@@ -370,7 +370,7 @@ def run(
         The tensor to update with the prediction probabilities.
     out_res : float, optional
         The resolution of the output. Default is None.
-    mode : ModalityMode, optional
+    mode : ModalityMode, default="t1t2"
         The mode of the modalities. Default is 't1t2'.
 
     Returns
diff --git a/HypVINN/run_prediction.py b/HypVINN/run_prediction.py
index 14b1317de..f885ec7bc 100644
--- a/HypVINN/run_prediction.py
+++ b/HypVINN/run_prediction.py
@@ -59,14 +59,14 @@ def optional_path(a: Path | str) -> Optional[Path]:
 
     Parameters
     ----------
-    a : Path | str
+    a : Path, str
         The input to convert.
 
     Returns
     -------
-    Optional[Path]
+    Path, optional
         The converted Path object.
     """
     if isinstance(a, Path):
         return a
     if a.lower() in ("none", ""):
diff --git a/HypVINN/utils/img_processing_utils.py b/HypVINN/utils/img_processing_utils.py
index 2ea77262f..27a5ea69d 100644
--- a/HypVINN/utils/img_processing_utils.py
+++ b/HypVINN/utils/img_processing_utils.py
@@ -74,7 +74,7 @@
         The directory where the subject's data is stored.
     seg_file : Path
         The file where the segmentation results will be saved.
-    save_mask : bool, optional
+    save_mask : bool, default=False
         Whether to save the mask or not. Default is False.
 
     Returns
@@ -181,7 +181,7 @@
     ----------
     segmentation : np.ndarray
         The input segmentation mask.
-    optic : bool, optional
+    optic : bool, default=False
         A flag indicating whether to consider optic labels. Default is False.
Returns diff --git a/HypVINN/utils/mode_config.py b/HypVINN/utils/mode_config.py index 5ac271895..6f140265f 100644 --- a/HypVINN/utils/mode_config.py +++ b/HypVINN/utils/mode_config.py @@ -33,9 +33,9 @@ def get_hypinn_mode( Parameters ---------- - t1_path : Optional[Path], default=None + t1_path : Path, optional The path to the T1 file. - t2_path : Optional[Path], default=None + t2_path : Path, optional The path to the T2 file. Returns diff --git a/HypVINN/utils/stats_utils.py b/HypVINN/utils/stats_utils.py index 837a7f53b..b28da5faf 100644 --- a/HypVINN/utils/stats_utils.py +++ b/HypVINN/utils/stats_utils.py @@ -37,7 +37,7 @@ def compute_stats( Returns ------- - int | str + int, str The return value of the main function from FastSurferCNN.segstats. Exit code. Returns 0 upon successful execution. diff --git a/HypVINN/utils/visualization_utils.py b/HypVINN/utils/visualization_utils.py index 96050a6c3..5e43ae962 100644 --- a/HypVINN/utils/visualization_utils.py +++ b/HypVINN/utils/visualization_utils.py @@ -41,14 +41,14 @@ def remove_values_from_list(the_list, val): def get_lut(lookup_table_path: Path = HYPVINN_LUT): - """ + f""" Retrieve a lookup table (LUT) from a file. This function reads a file and constructs a lookup table (LUT) from it. Parameters ---------- - lookup_table_path: Path, default=HYPVINN_LUT + lookup_table_path: Path, default="{HYPVINN_LUT}" The path to the file from which the LUT will be constructed. Returns @@ -70,14 +70,14 @@ def get_lut(lookup_table_path: Path = HYPVINN_LUT): def map_hyposeg2label(hyposeg: np.ndarray, lut_file: Path = HYPVINN_LUT): - """ + f""" Map a HypVINN segmentation to a continuous label space using a lookup table. Parameters ---------- hyposeg : np.ndarray The original segmentation map. - lut_file : Path, default=HYPVINN_LUT + lut_file : Path, default="{HYPVINN_LUT}" The path to the lookup table file. Returns @@ -242,7 +242,7 @@ def plot_qc_images( padd: int = 45, lut_file: Path = HYPVINN_LUT, slice_step: int = 2): - """ + f""" Plot the quality control images for the subject. Parameters @@ -255,7 +255,7 @@ def plot_qc_images( The path to the predicted image. padd : int, default=45 The padding value for cropping the images and segmentations. - lut_file : Path, default=HYPVINN_LUT + lut_file : Path, default="{HYPVINN_LUT}" The path to the lookup table file. slice_step : int, default=2 The step size for selecting indices from the predicted segmentation. From c01a3fe8103923f4e495b46cd64fc1c27dcb66e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20K=C3=BCgler?= Date: Fri, 31 May 2024 18:25:45 +0200 Subject: [PATCH 5/8] Apply suggestions from code review --- HypVINN/run_prediction.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/HypVINN/run_prediction.py b/HypVINN/run_prediction.py index f885ec7bc..2f2e55137 100644 --- a/HypVINN/run_prediction.py +++ b/HypVINN/run_prediction.py @@ -190,9 +190,9 @@ def main( ---------- out_dir : Path The output directory where the results will be stored. - t2 : Optional[Path] + t2 : Path, optional The path to the T2 image to process. - orig_name : Optional[Path] + orig_name : Path, optional The original name of the input image. sid : str The subject ID. @@ -208,23 +208,23 @@ def main( The path to the coronal configuration file. cfg_sag : Path The path to the sagittal configuration file. - hypo_segfile : str, optional + hypo_segfile : str, default="{HYPVINN_SEG_NAME}" The name of the hypothalamus segmentation file. Default is HYPVINN_SEG_NAME. 
From c01a3fe8103923f4e495b46cd64fc1c27dcb66e7 Mon Sep 17 00:00:00 2001
From: David Kügler
Date: Fri, 31 May 2024 18:25:45 +0200
Subject: [PATCH 5/8] Apply suggestions from code review

---
 HypVINN/run_prediction.py | 28 ++++++++++++++--------------
 1 file changed, 14 insertions(+), 14 deletions(-)

diff --git a/HypVINN/run_prediction.py b/HypVINN/run_prediction.py
index f885ec7bc..2f2e55137 100644
--- a/HypVINN/run_prediction.py
+++ b/HypVINN/run_prediction.py
@@ -190,9 +190,9 @@ def main(
     ----------
     out_dir : Path
         The output directory where the results will be stored.
-    t2 : Optional[Path]
+    t2 : Path, optional
        The path to the T2 image to process.
-    orig_name : Optional[Path]
+    orig_name : Path, optional
        The original name of the input image.
     sid : str
         The subject ID.
@@ -208,23 +208,23 @@ def main(
         The path to the coronal configuration file.
     cfg_sag : Path
         The path to the sagittal configuration file.
-    hypo_segfile : str, optional
+    hypo_segfile : str, default="{HYPVINN_SEG_NAME}"
        The name of the hypothalamus segmentation file. Default is HYPVINN_SEG_NAME.
-    allow_root : bool, optional
+    allow_root : bool, default=False
        Whether to allow running as root user. Default is False.
     qc_snapshots : bool, optional
         Whether to create QC snapshots. Default is False.
-    reg_mode : Literal["coreg", "robust", "none"], optional
+    reg_mode : "coreg", "robust", "none", default="coreg"
        The registration mode to use. Default is "coreg".
-    threads : int, optional
+    threads : int, default=-1
        The number of threads to use. Default is -1, which uses all available threads.
-    batch_size : int, optional
+    batch_size : int, default=1
        The batch size to use. Default is 1.
-    async_io : bool, optional
+    async_io : bool, default=False
        Whether to use asynchronous I/O. Default is False.
-    device : str, optional
+    device : str, default="auto"
        The device to use. Default is "auto", which automatically selects the device.
-    viewagg_device : str, optional
+    viewagg_device : str, default="auto"
        The view aggregation device to use. Default is "auto", which automatically selects the device.
 
     Returns
@@ -476,9 +476,9 @@ def load_volumes(
     ----------
     mode : ModalityMode
         The mode of operation. Can be 't1', 't2', or 't1t2'.
-    t1_path : Optional[Path], optional
+    t1_path : Path, optional
        The path to the T1 image. Default is None.
-    t2_path : Optional[Path], optional
+    t2_path : Path, optional
        The path to the T2 image. Default is None.
 
     Returns
@@ -598,7 +598,7 @@ def get_prediction(
         A dictionary containing the configurations for each plane.
     out_scale : optional
         The output scale. Default is None.
-    mode : ModalityMode, optional
+    mode : ModalityMode, default="t1t2"
        The mode of operation. Can be 't1', 't2', or 't1t2'. Default is 't1t2'.
 
     Returns
@@ -663,7 +663,7 @@ def set_up_cfgs(
         The configuration node to load.
     out_dir : Path
         The output directory where the results will be stored.
-    batch_size : int, optional
+    batch_size : int, default=1
        The batch size to use. Default is 1.
 
     Returns
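
Patch 5 normalizes main()'s parameter list to the `name : type, default=value` form used for numpydoc-style docstrings throughout these patches. For reference, a small, purely hypothetical function documented in exactly that style:

    def predict(batch_size: int = 1, device: str = "auto") -> int:
        """
        Run a toy prediction step (illustrative only, not part of the patch series).

        Parameters
        ----------
        batch_size : int, default=1
            The batch size to use. Default is 1.
        device : str, default="auto"
            The device to use. Default is "auto", which automatically selects the device.

        Returns
        -------
        int
            Exit code. Returns 0 upon successful execution.
        """
        return 0
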
From 6fc25ecb0d00e3ac50627e16458c85f1a3343576 Mon Sep 17 00:00:00 2001
From: Taha Abdullah
Date: Tue, 4 Jun 2024 10:54:19 +0200
Subject: [PATCH 6/8] Applying changes requested in code review

---
 HypVINN/inference.py                  | 9 +++++----
 HypVINN/run_prediction.py             | 2 +-
 HypVINN/utils/img_processing_utils.py | 9 +++++----
 HypVINN/utils/visualization_utils.py  | 2 +-
 4 files changed, 12 insertions(+), 10 deletions(-)

diff --git a/HypVINN/inference.py b/HypVINN/inference.py
index d9f0a4740..b01c618f2 100644
--- a/HypVINN/inference.py
+++ b/HypVINN/inference.py
@@ -294,10 +294,10 @@ def get_device(self):
     @torch.no_grad()
     def eval(self, val_loader: DataLoader, pred_prob: torch.Tensor, out_scale: float = None) -> torch.Tensor:
         """
-        Evaluate the model on a validation set.
+        Evaluate the model on a HypVINN dataset.
 
-        This method runs the model in evaluation mode on a validation set. It iterates over the validation set,
-        computes the model's predictions, and updates the prediction probabilities based on the plane of the data.
+        This method runs the model in evaluation mode on a HypVINN Dataset. It iterates over the given dataset and
+        computes the model's predictions.
 
         Parameters
         ----------
@@ -355,7 +355,8 @@ def run(
         """
         Run the inference process on a single subject.
 
-        This method sets up a DataLoader for the subject, runs the model in evaluation mode on the subject's data,
+        This method sets up the HypVINN DataLoader for the subject, runs the model in evaluation mode on the subject's
+        data,
         and returns the updated prediction probabilities.
 
         Parameters
diff --git a/HypVINN/run_prediction.py b/HypVINN/run_prediction.py
index 2f2e55137..0a92a0229 100644
--- a/HypVINN/run_prediction.py
+++ b/HypVINN/run_prediction.py
@@ -193,7 +193,7 @@ def main(
     t2 : Path, optional
         The path to the T2 image to process.
     orig_name : Path, optional
-        The original name of the input image.
+        The path to the T1 image to process or FastSurfer orig image.
     sid : str
         The subject ID.
 
diff --git a/HypVINN/utils/img_processing_utils.py b/HypVINN/utils/img_processing_utils.py
index 27a5ea69d..ea0d902c0 100644
--- a/HypVINN/utils/img_processing_utils.py
+++ b/HypVINN/utils/img_processing_utils.py
@@ -49,7 +49,7 @@ def save_segmentation(
     prediction: np.ndarray,
     orig_path: Path,
     ras_affine: npt.NDArray[float],
-    ras_header: nib.nifti1.Nifti1Header,
+    ras_header: nib.nifti1.Nifti1Header | nib.nifti2.Nifti2Header | nib.freesurfer.mghformat.MGHHeader,
     subject_dir: Path,
     seg_file: Path,
     save_mask: bool = False,
@@ -118,7 +118,7 @@ def save_logits(
     logits: npt.NDArray[float],
     orig_path: Path,
     ras_affine: npt.NDArray[float],
-    ras_header: nib.nifti1.Nifti1Header,
+    ras_header: nib.nifti1.Nifti1Header | nib.nifti2.Nifti2Header | nib.freesurfer.mghformat.MGHHeader,
     save_dir: Path,
     mode: str,
 ) -> Path:
@@ -171,7 +171,7 @@
 def get_clean_mask(segmentation: np.ndarray, optic=False) \
         -> tuple[np.ndarray, np.ndarray, bool]:
     """
-    Get a clean mask by removing not connected components.
+    Get a clean mask by removing non-connected components from a dilated mask.
 
     This function takes a segmentation mask and an optional boolean flag indicating whether to consider optic labels.
     It removes not connected components from the segmentation mask and returns the cleaned segmentation mask, the
@@ -244,7 +244,8 @@
 def get_clean_labels(segmentation: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
     """
-    Function to find the largest connected component of the segmentation.
+    Get clean labels by removing non-connected components from a dilated mask and any connected component with size
+    less than 3.
 
     Parameters
     ----------
     segmentation : np.ndarray
diff --git a/HypVINN/utils/visualization_utils.py b/HypVINN/utils/visualization_utils.py
index 5e43ae962..7cd151059 100644
--- a/HypVINN/utils/visualization_utils.py
+++ b/HypVINN/utils/visualization_utils.py
@@ -42,7 +42,7 @@ def remove_values_from_list(the_list, val):
 
 
 def get_lut(lookup_table_path: Path = HYPVINN_LUT):
     f"""
-    Retrieve a lookup table (LUT) from a file.
+    Retrieve a color lookup table (LUT) from a file.
 
     This function reads a file and constructs a lookup table (LUT) from it.
 
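
The reworded get_clean_mask/get_clean_labels summaries above describe a dilate-label-filter cleanup. The exact labels and size thresholds live in img_processing_utils.py; the general idea can be sketched with scipy.ndimage on a toy mask (names and sizes here are illustrative, not FastSurfer's actual values):

    import numpy as np
    from scipy import ndimage

    # Toy binary mask: one large blob plus a spurious one-voxel island.
    seg = np.zeros((8, 8), dtype=np.uint8)
    seg[1:4, 1:4] = 1
    seg[6, 6] = 1

    # Dilate first so components separated by a small gap can merge, then
    # label connected components and keep only the largest one.
    dilated = ndimage.binary_dilation(seg, iterations=1)
    labels, n = ndimage.label(dilated)
    sizes = ndimage.sum_labels(dilated, labels, index=range(1, n + 1))
    keep = labels == (int(np.argmax(sizes)) + 1)
    clean = np.where(keep, seg, 0)  # the island at (6, 6) is removed
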
From 79dc3850926e70690a8c7be8a900ada3f8f64028 Mon Sep 17 00:00:00 2001
From: Taha Abdullah
Date: Tue, 4 Jun 2024 12:44:26 +0200
Subject: [PATCH 7/8] Applying relative path changes in docstrings

---
 HypVINN/run_prediction.py            | 11 +++++++----
 HypVINN/utils/visualization_utils.py | 11 +++++++----
 2 files changed, 14 insertions(+), 8 deletions(-)

diff --git a/HypVINN/run_prediction.py b/HypVINN/run_prediction.py
index 0a92a0229..26903e14a 100644
--- a/HypVINN/run_prediction.py
+++ b/HypVINN/run_prediction.py
@@ -11,8 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 # IMPORTS
+import os.path
 from typing import TYPE_CHECKING, Optional, cast, Literal
 import argparse
 from pathlib import Path
@@ -32,6 +32,7 @@
     load_checkpoint_config_defaults,
 )
 from FastSurferCNN.utils.common import assert_no_root, SerialExecutor
+from FastSurferCNN.utils.parser_defaults import FASTSURFER_ROOT
 
 from HypVINN.config.hypvinn_files import HYPVINN_SEG_NAME
 from HypVINN.data_loader.data_utils import hypo_map_label2subseg, rescale_image
@@ -46,6 +47,8 @@
 from HypVINN.utils.stats_utils import compute_stats
 from HypVINN.utils.visualization_utils import plot_qc_images
 
+_doc_HYPVINN_SEG_NAME = os.path.relpath(HYPVINN_SEG_NAME, FASTSURFER_ROOT)
+
 logger = logging.get_logger(__name__)
 
 ##
@@ -183,7 +186,7 @@ def main(
     device: str = "auto",
     viewagg_device: str = "auto",
 ) -> int | str:
-    """
+    f"""
     Main function of the hypothalamus segmentation module.
 
     Parameters
@@ -208,8 +211,8 @@ def main(
         The path to the coronal configuration file.
     cfg_sag : Path
         The path to the sagittal configuration file.
-    hypo_segfile : str, default="{HYPVINN_SEG_NAME}"
-        The name of the hypothalamus segmentation file. Default is HYPVINN_SEG_NAME.
+    hypo_segfile : str, default="{_doc_HYPVINN_SEG_NAME}"
+        The name of the hypothalamus segmentation file. Default is {_doc_HYPVINN_SEG_NAME}.
     allow_root : bool, default=False
         Whether to allow running as root user. Default is False.
     qc_snapshots : bool, optional
diff --git a/HypVINN/utils/visualization_utils.py b/HypVINN/utils/visualization_utils.py
index 7cd151059..a43aaf9b2 100644
--- a/HypVINN/utils/visualization_utils.py
+++ b/HypVINN/utils/visualization_utils.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import os.path
 from pathlib import Path
 
 import numpy as np
@@ -19,6 +19,9 @@
 import matplotlib.pyplot as plt
 
 from HypVINN.config.hypvinn_files import HYPVINN_LUT
+from FastSurferCNN.utils.parser_defaults import FASTSURFER_ROOT
+
+_doc_HYPVINN_LUT = os.path.relpath(HYPVINN_LUT, FASTSURFER_ROOT)
 
 
 def remove_values_from_list(the_list, val):
@@ -48,7 +51,7 @@ def get_lut(lookup_table_path: Path = HYPVINN_LUT):
 
     Parameters
     ----------
-    lookup_table_path: Path, default="{HYPVINN_LUT}"
+    lookup_table_path: Path, default="{_doc_HYPVINN_LUT}"
        The path to the file from which the LUT will be constructed.
 
     Returns
@@ -77,7 +80,7 @@ def map_hyposeg2label(hyposeg: np.ndarray, lut_file: Path = HYPVINN_LUT):
     ----------
     hyposeg : np.ndarray
         The original segmentation map.
-    lut_file : Path, default="{HYPVINN_LUT}"
+    lut_file : Path, default="{_doc_HYPVINN_LUT}"
        The path to the lookup table file.
 
     Returns
@@ -255,7 +258,7 @@ def plot_qc_images(
         The path to the predicted image.
     padd : int, default=45
         The padding value for cropping the images and segmentations.
-    lut_file : Path, default="{HYPVINN_LUT}"
+    lut_file : Path, default="{_doc_HYPVINN_LUT}"
        The path to the lookup table file.
     slice_step : int, default=2
         The step size for selecting indices from the predicted segmentation.
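
The _doc_* constants that patch 7 adds exist only to shorten what the f-docstrings render: os.path.relpath reduces an absolute installation path to a repository-relative one. A standalone sketch with stand-in values (the real constants live in FastSurferCNN.utils.parser_defaults and HypVINN.config.hypvinn_files; the paths below are hypothetical):

    import os.path

    # Stand-in values; the actual paths depend on the installation.
    FASTSURFER_ROOT = "/opt/FastSurfer"
    HYPVINN_LUT = "/opt/FastSurfer/HypVINN/config/HypVINN_ColorLUT.txt"

    # Strip the installation prefix so the rendered docstring shows a short,
    # machine-independent path instead of an absolute one.
    _doc_HYPVINN_LUT = os.path.relpath(HYPVINN_LUT, FASTSURFER_ROOT)
    print(_doc_HYPVINN_LUT)  # HypVINN/config/HypVINN_ColorLUT.txt
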
From 9a6822bc3d8b96f155532911763c50d51f81338b Mon Sep 17 00:00:00 2001
From: Taha Abdullah
Date: Tue, 4 Jun 2024 16:43:13 +0200
Subject: [PATCH 8/8] changed back to HYPVINN_SEG_NAME

---
 HypVINN/run_prediction.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/HypVINN/run_prediction.py b/HypVINN/run_prediction.py
index 26903e14a..ea185040d 100644
--- a/HypVINN/run_prediction.py
+++ b/HypVINN/run_prediction.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # IMPORTS
-import os.path
 from typing import TYPE_CHECKING, Optional, cast, Literal
 import argparse
 from pathlib import Path
@@ -32,7 +31,6 @@
     load_checkpoint_config_defaults,
 )
 from FastSurferCNN.utils.common import assert_no_root, SerialExecutor
-from FastSurferCNN.utils.parser_defaults import FASTSURFER_ROOT
 
 from HypVINN.config.hypvinn_files import HYPVINN_SEG_NAME
 from HypVINN.data_loader.data_utils import hypo_map_label2subseg, rescale_image
@@ -47,8 +45,6 @@
 from HypVINN.utils.stats_utils import compute_stats
 from HypVINN.utils.visualization_utils import plot_qc_images
 
-_doc_HYPVINN_SEG_NAME = os.path.relpath(HYPVINN_SEG_NAME, FASTSURFER_ROOT)
-
 logger = logging.get_logger(__name__)
 
 ##
@@ -211,8 +207,8 @@ def main(
         The path to the coronal configuration file.
     cfg_sag : Path
         The path to the sagittal configuration file.
-    hypo_segfile : str, default="{_doc_HYPVINN_SEG_NAME}"
-        The name of the hypothalamus segmentation file. Default is {_doc_HYPVINN_SEG_NAME}.
+    hypo_segfile : str, default="{HYPVINN_SEG_NAME}"
+        The name of the hypothalamus segmentation file. Default is {HYPVINN_SEG_NAME}.
     allow_root : bool, default=False
         Whether to allow running as root user. Default is False.
     qc_snapshots : bool, optional
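
Patch 8 closes the series: the relative-path indirection is kept for the LUT in visualization_utils.py but rolled back in run_prediction.py, where HYPVINN_SEG_NAME is already a bare file name and gains nothing from relpath. As a closing usage sketch, the behaviour documented for optional_path in patch 4 pins down as follows; note the quoted hunk ends before the final branch, which is assumed here to return Path(a):

    from pathlib import Path
    from typing import Optional

    def optional_path(a: Path | str) -> Optional[Path]:
        """Convert a string to a Path object or None (re-stated from the docstring above)."""
        if isinstance(a, Path):
            return a
        if a.lower() in ("none", ""):
            return None
        return Path(a)  # assumed: this branch lies outside the quoted hunk

    assert optional_path(Path("/data/t1.nii.gz")) == Path("/data/t1.nii.gz")
    assert optional_path("none") is None
    assert optional_path("") is None
    assert optional_path("sub-01/mri/T2.nii.gz") == Path("sub-01/mri/T2.nii.gz")
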