diff --git a/HypVINN/run_prediction.py b/HypVINN/run_prediction.py index f64b117b..5a275cfe 100644 --- a/HypVINN/run_prediction.py +++ b/HypVINN/run_prediction.py @@ -32,7 +32,7 @@ ) from FastSurferCNN.utils.common import assert_no_root, SerialExecutor -from HypVINN.config.hypvinn_files import HYPVINN_SEG_NAME +from HypVINN.config.hypvinn_files import HYPVINN_SEG_NAME, HYPVINN_MASK_NAME from HypVINN.data_loader.data_utils import hypo_map_label2subseg, rescale_image from HypVINN.inference import Inference from HypVINN.utils import ModalityDict, ModalityMode, ViewOperations @@ -173,6 +173,7 @@ def main( cfg_cor: Path, cfg_sag: Path, hypo_segfile: str = HYPVINN_SEG_NAME, + hypo_maskfile: str = HYPVINN_MASK_NAME, allow_root: bool = False, qc_snapshots: bool = False, reg_mode: Literal["coreg", "robust", "none"] = "coreg", @@ -209,6 +210,8 @@ def main( The path to the sagittal configuration file. hypo_segfile : str, default="{HYPVINN_SEG_NAME}" The name of the hypothalamus segmentation file. Default is {HYPVINN_SEG_NAME}. + hypo_maskfile : str, default="{HYPVINN_MASK_NAME}" + The name of the hypothalamus mask file. Default is {HYPVINN_MASK_NAME}. allow_root : bool, default=False Whether to allow running as root user. Default is False. qc_snapshots : bool, optional @@ -224,7 +227,8 @@ def main( device : str, default="auto" The device to use. Default is "auto", which automatically selects the device. viewagg_device : str, default="auto" - The view aggregation device to use. Default is "auto", which automatically selects the device. + The view aggregation device to use. Default is "auto", which automatically + selects the device. Returns ------- @@ -360,21 +364,17 @@ def main( else: orig_path = t2_path - save_future: Future = pool.submit( - save_segmentation, + time_needed = save_segmentation( pred_classes, orig_path=orig_path, ras_affine=affine, ras_header=header, subject_dir=subject_dir, seg_file=hypo_segfile, + mask_file=hypo_maskfile, save_mask=True, ) - save_future.add_done_callback( - lambda x: logger.info( - f"Prediction successfully saved in {x.result()} seconds." - ), - ) + logger.info(f"Prediction successfully saved in {time_needed} seconds.") if qc_snapshots: qc_future: Optional[Future] = pool.submit( plot_qc_images, @@ -408,7 +408,6 @@ def main( if qc_future: # finish qc qc_future.result() - save_future.result() logger.info( f"Processing whole pipeline finished in {time() - start:.4f} seconds." diff --git a/HypVINN/utils/img_processing_utils.py b/HypVINN/utils/img_processing_utils.py index ea0d902c..765f8baf 100644 --- a/HypVINN/utils/img_processing_utils.py +++ b/HypVINN/utils/img_processing_utils.py @@ -51,7 +51,8 @@ def save_segmentation( ras_affine: npt.NDArray[float], ras_header: nib.nifti1.Nifti1Header | nib.nifti2.Nifti2Header | nib.freesurfer.mghformat.MGHHeader, subject_dir: Path, - seg_file: Path, + seg_file: str, + mask_file: str, save_mask: bool = False, ) -> float: """ @@ -73,7 +74,9 @@ def save_segmentation( subject_dir : Path The directory where the subject's data is stored. seg_file : Path - The file where the segmentation results will be saved. + The file where the segmentation will be saved (relative to subject_dir/mri). + mask_file : str + The file where the mask will be saved (relative to subject_dir/mri). save_mask : bool, default=False Whether to save the mask or not. Default is False. @@ -86,7 +89,6 @@ def save_segmentation( from time import time starttime = time() from HypVINN.data_loader.data_utils import reorient_img - from HypVINN.config.hypvinn_files import HYPVINN_MASK_NAME, HYPVINN_SEG_NAME pred_arr, labels_cc = get_clean_labels(np.array(prediction, dtype=np.uint8)) # Mapped HypVINN labelst to FreeSurfer Hypvinn Labels @@ -101,7 +103,7 @@ def save_segmentation( LOGGER.info( f"HypoVINN Mask after re-orientation: {img2axcodes(mask_img)}" ) - nib.save(mask_img, subject_dir / "mri" / HYPVINN_MASK_NAME) + nib.save(mask_img, subject_dir / "mri" / mask_file) pred_img = nib.Nifti1Image(pred_arr, affine=ras_affine, header=ras_header) LOGGER.info(f"HypoVINN Prediction orientation: {img2axcodes(pred_img)}") @@ -110,7 +112,7 @@ def save_segmentation( f"HypoVINN Prediction after re-orientation: {img2axcodes(pred_img)}" ) pred_img.set_data_dtype(np.int16) # Maximum value 939 - nib.save(pred_img, subject_dir / seg_file) + nib.save(pred_img, subject_dir / "mri" / seg_file) return time() - starttime