diff --git a/HypVINN/inference.py b/HypVINN/inference.py index d9f0a474..b01c618f 100644 --- a/HypVINN/inference.py +++ b/HypVINN/inference.py @@ -294,10 +294,10 @@ def get_device(self): @torch.no_grad() def eval(self, val_loader: DataLoader, pred_prob: torch.Tensor, out_scale: float = None) -> torch.Tensor: """ - Evaluate the model on a validation set. + Evaluate the model on a HypVINN dataset. - This method runs the model in evaluation mode on a validation set. It iterates over the validation set, - computes the model's predictions, and updates the prediction probabilities based on the plane of the data. + This method runs the model in evaluation mode on a HypVINN Dataset. It iterates over the given dataset and + computes the model's predictions. Parameters ---------- @@ -355,7 +355,8 @@ def run( """ Run the inference process on a single subject. - This method sets up a DataLoader for the subject, runs the model in evaluation mode on the subject's data, + This method sets up the HypVINN DataLoader for the subject, runs the model in evaluation mode on the subject's + data, and returns the updated prediction probabilities. Parameters diff --git a/HypVINN/run_prediction.py b/HypVINN/run_prediction.py index 3bc0f5ad..f64b117b 100644 --- a/HypVINN/run_prediction.py +++ b/HypVINN/run_prediction.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - # IMPORTS from typing import TYPE_CHECKING, Optional, cast, Literal import argparse @@ -183,7 +182,7 @@ def main( device: str = "auto", viewagg_device: str = "auto", ) -> int | str: - """ + f""" Main function of the hypothalamus segmentation module. Parameters @@ -193,7 +192,7 @@ def main( t2 : Path, optional The path to the T2 image to process. orig_name : Path, optional - The original name of the input image. + The path to the T1 image to process or FastSurfer orig image. sid : str The subject ID. ckpt_ax : Path @@ -209,7 +208,7 @@ def main( cfg_sag : Path The path to the sagittal configuration file. hypo_segfile : str, default="{HYPVINN_SEG_NAME}" - The name of the hypothalamus segmentation file. Default is HYPVINN_SEG_NAME. + The name of the hypothalamus segmentation file. Default is {HYPVINN_SEG_NAME}. allow_root : bool, default=False Whether to allow running as root user. Default is False. qc_snapshots : bool, optional diff --git a/HypVINN/utils/img_processing_utils.py b/HypVINN/utils/img_processing_utils.py index 27a5ea69..ea0d902c 100644 --- a/HypVINN/utils/img_processing_utils.py +++ b/HypVINN/utils/img_processing_utils.py @@ -49,7 +49,7 @@ def save_segmentation( prediction: np.ndarray, orig_path: Path, ras_affine: npt.NDArray[float], - ras_header: nib.nifti1.Nifti1Header, + ras_header: nib.nifti1.Nifti1Header | nib.nifti2.Nifti2Header | nib.freesurfer.mghformat.MGHHeader, subject_dir: Path, seg_file: Path, save_mask: bool = False, @@ -118,7 +118,7 @@ def save_logits( logits: npt.NDArray[float], orig_path: Path, ras_affine: npt.NDArray[float], - ras_header: nib.nifti1.Nifti1Header, + ras_header: nib.nifti1.Nifti1Header | nib.nifti2.Nifti2Header | nib.freesurfer.mghformat.MGHHeader, save_dir: Path, mode: str, ) -> Path: @@ -171,7 +171,7 @@ def save_logits( def get_clean_mask(segmentation: np.ndarray, optic=False) \ -> tuple[np.ndarray, np.ndarray, bool]: """ - Get a clean mask by removing not connected components. + Get a clean mask by removing non-connected components from a dilated mask. This function takes a segmentation mask and an optional boolean flag indicating whether to consider optic labels. It removes not connected components from the segmentation mask and returns the cleaned segmentation mask, the @@ -244,7 +244,8 @@ def get_clean_mask(segmentation: np.ndarray, optic=False) \ def get_clean_labels(segmentation: np.ndarray) -> tuple[np.ndarray, np.ndarray]: """ - Function to find the largest connected component of the segmentation. + Get clean labels by removing non-connected components from a dilated mask and any connected component with size + less than 3. Parameters ---------- diff --git a/HypVINN/utils/visualization_utils.py b/HypVINN/utils/visualization_utils.py index 5e43ae96..a43aaf9b 100644 --- a/HypVINN/utils/visualization_utils.py +++ b/HypVINN/utils/visualization_utils.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import os.path from pathlib import Path import numpy as np @@ -19,6 +19,9 @@ import matplotlib.pyplot as plt from HypVINN.config.hypvinn_files import HYPVINN_LUT +from FastSurferCNN.utils.parser_defaults import FASTSURFER_ROOT + +_doc_HYPVINN_LUT = os.path.relpath(HYPVINN_LUT, FASTSURFER_ROOT) def remove_values_from_list(the_list, val): @@ -42,13 +45,13 @@ def remove_values_from_list(the_list, val): def get_lut(lookup_table_path: Path = HYPVINN_LUT): f""" - Retrieve a lookup table (LUT) from a file. + Retrieve a color lookup table (LUT) from a file. This function reads a file and constructs a lookup table (LUT) from it. Parameters ---------- - lookup_table_path: Path, default="{HYPVINN_LUT}" + lookup_table_path: Path, default="{_doc_HYPVINN_LUT}" The path to the file from which the LUT will be constructed. Returns @@ -77,7 +80,7 @@ def map_hyposeg2label(hyposeg: np.ndarray, lut_file: Path = HYPVINN_LUT): ---------- hyposeg : np.ndarray The original segmentation map. - lut_file : Path, default="{HYPVINN_LUT}" + lut_file : Path, default="{_doc_HYPVINN_LUT}" The path to the lookup table file. Returns @@ -255,7 +258,7 @@ def plot_qc_images( The path to the predicted image. padd : int, default=45 The padding value for cropping the images and segmentations. - lut_file : Path, default="{HYPVINN_LUT}" + lut_file : Path, default="{_doc_HYPVINN_LUT}" The path to the lookup table file. slice_step : int, default=2 The step size for selecting indices from the predicted segmentation.