From f913211eb0d0825b4e29a00ce8440af52be72ae6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20K=C3=BCgler?= Date: Thu, 18 Apr 2024 18:55:22 +0200 Subject: [PATCH] doc/overview/OUTPUT_FILES.md - remove/rename changed filenames - format table HypVINN/config/checkpoint_paths.yaml - add config HypVINN/data_loader/data_utils.py - fix typing and formatting HypVINN/utils/checkpoint.py - fix YAML_DEFAULT HypVINN/utils/mode_config.py - set default values for get_hypvinn_mode HypVINN/inference.py - fix inclusion of ModalityMode HypVINN/run_prediction.py - move HypVINN/run_pipeline.py into run_prediction.py - fix typing, e.g. FileBasedHeader - fix function parameters - add help text to hypo_segfile argument - fix passing of t1_path and t2_path - various other changes --- HypVINN/README.md | 10 +- HypVINN/__init__.py | 1 - HypVINN/config/checkpoint_paths.yaml | 5 + HypVINN/data_loader/data_utils.py | 103 ++--- HypVINN/inference.py | 21 +- HypVINN/run_pipeline.py | 241 ------------ HypVINN/run_prediction.py | 544 ++++++++++++++++++++------- HypVINN/utils/checkpoint.py | 2 +- HypVINN/utils/mode_config.py | 4 +- doc/overview/OUTPUT_FILES.md | 19 +- run_fastsurfer.sh | 58 ++- 11 files changed, 527 insertions(+), 481 deletions(-) delete mode 100644 HypVINN/run_pipeline.py diff --git a/HypVINN/README.md b/HypVINN/README.md index 2077efb2..dada73b4 100644 --- a/HypVINN/README.md +++ b/HypVINN/README.md @@ -18,7 +18,7 @@ Hypothalamic subfields segmentation pipeline 2. Hypothalamus Segmentation ### Running the tool -Run the HypVINN/run_pipeline.py which has the following arguments: +Run the HypVINN/run_prediction.py which has the following arguments: ### Input and output arguments * `--sid ` : Subject ID, the subject data upon which to operate * `--sd ` : Directory in which evaluation results should be written. @@ -54,7 +54,7 @@ The pipeline can do all pre-processing by itself (step 1). This step can be skip 1. Run full pipeline ``` - python HypVINN/run_pipeline.py --sid test_subject --sd /output \ + python HypVINN/run_prediction.py --sid test_subject --sd /output \ --t1 /data/test_subject_t1.nii.gz \ --t2 /data/test_subject_t2.nii.gz \ --reg_mode coreg \ @@ -63,7 +63,7 @@ The pipeline can do all pre-processing by itself (step 1). This step can be skip ``` 2. Run full pipeline only using a t1 ``` - python HypVINN/run_pipeline.py --sid test_subject --sd /output \ + python HypVINN/run_prediction.py --sid test_subject --sd /output \ --t1 /data/test_subject_t1.nii.gz \ --reg_mode coreg \ --seg_log /outdir/test_subject.log \ @@ -72,7 +72,7 @@ The pipeline can do all pre-processing by itself (step 1). This step can be skip 3. Run pipeline without the registration step ``` - python HypVINN/run_pipeline.py --sid test_subject --sd /output \ + python HypVINN/run_prediction.py --sid test_subject --sd /output \ --t1 /data/test_subject_t1.nii.gz \ --t2 /data/test_subject_t2.nii.gz \ --reg_mode coreg \ @@ -82,7 +82,7 @@ The pipeline can do all pre-processing by itself (step 1). This step can be skip 4. 
Run pipeline with creation of qc snapshots ``` - python HypVINN/run_pipeline.py --sid test_subject --sd /output \ + python HypVINN/run_prediction.py --sid test_subject --sd /output \ --t1 /data/test_subject_t1.nii.gz \ --t2 /data/test_subject_t2.nii.gz \ --reg_mode coreg \ diff --git a/HypVINN/__init__.py b/HypVINN/__init__.py index 0dc9ff4b..870c1df8 100644 --- a/HypVINN/__init__.py +++ b/HypVINN/__init__.py @@ -19,5 +19,4 @@ "utils", "inference", "run_prediction", - "run_pipeline", ] diff --git a/HypVINN/config/checkpoint_paths.yaml b/HypVINN/config/checkpoint_paths.yaml index 1d0c380a..7d72753f 100644 --- a/HypVINN/config/checkpoint_paths.yaml +++ b/HypVINN/config/checkpoint_paths.yaml @@ -5,3 +5,8 @@ checkpoint: axial: "checkpoints/HypVINN_axial_v1.0.0.pkl" coronal: "checkpoints/HypVINN_coronal_v1.0.0.pkl" sagittal: "checkpoints/HypVINN_sagittal_v1.0.0.pkl" + +config: + axial: "HypVINN/config/HypVINN_axial_v1.0.0.yaml" + coronal: "HypVINN/config/HypVINN_coronal_v1.0.0.yaml" + sagittal: "HypVINN/config/HypVINN_sagittal_v1.0.0.yaml" diff --git a/HypVINN/data_loader/data_utils.py b/HypVINN/data_loader/data_utils.py index 6deb4d99..b07a9af0 100644 --- a/HypVINN/data_loader/data_utils.py +++ b/HypVINN/data_loader/data_utils.py @@ -13,19 +13,25 @@ # limitations under the License. # IMPORTS +import nibabel as nib import numpy as np +from numpy import typing as npt + from FastSurferCNN.data_loader.conform import getscale, scalecrop -import nibabel as nib -import sys from HypVINN.config.hypvinn_global_var import hyposubseg_labels, SAG2FULL_MAP, HYPVINN_CLASS_NAMES, FS_CLASS_NAMES + + ## # Helper Functions ## -def calculate_flip_orientation(iornt,base_ornt): + + +def calculate_flip_orientation(iornt, base_ornt): """ Compute the flip orientation transform. ornt[N, 1] is flip of axis N, where 1 means no flip and -1 means flip. + Parameters ---------- iornt @@ -47,9 +53,9 @@ def calculate_flip_orientation(iornt,base_ornt): return new_iornt -# reorient image based on base image -def reorient_img(img,ref_img): - ''' + +def reorient_img(img, ref_img): + """ Function to reorient a Nibabel image based on the orientation of a reference nibabel image The orientation transform. ornt[N,1]` is flip of axis N of the array implied by `shape`, where 1 means no flip and -1 means flip. 
For example, if ``N==0 and ornt[0,1] == -1, and there’s an array arr of shape shape, the flip would correspond to the effect of @@ -63,7 +69,7 @@ def reorient_img(img,ref_img): Returns ------- - ''' + """ ref_ornt =nib.io_orientation(ref_img.affine) iornt=nib.io_orientation(img.affine) @@ -79,8 +85,7 @@ def reorient_img(img,ref_img): return img -# Transformation for mapping -#TODO check compatibility with axis transform from CerebNet + def transform_axial2coronal(vol, axial2coronal=True): """ Function to transform volume into coronal axis and back @@ -89,11 +94,13 @@ def transform_axial2coronal(vol, axial2coronal=True): transform from coronal to axial = False :return: """ + # TODO check compatibility with axis transform from CerebNet if axial2coronal: return np.moveaxis(vol, [0, 1, 2], [0, 2, 1]) else: return np.moveaxis(vol, [0, 1, 2], [0, 2, 1]) -#TODO check compatibility with axis transform from CerebNet + + def transform_axial2sagittal(vol, axial2sagittal=True): """ Function to transform volume into Sagittal axis and back @@ -102,15 +109,16 @@ def transform_axial2sagittal(vol, axial2sagittal=True): transform from sagittal to coronal = False :return: """ + # TODO check compatibility with axis transform from CerebNet if axial2sagittal: return np.moveaxis(vol, [0, 1, 2], [2, 0, 1]) else: return np.moveaxis(vol, [0, 1, 2], [1, 2, 0]) -# Same as CerebNet.datasets.utils.rescale_image def rescale_image(img_data): # Conform intensities + # TODO move function into FastSurferCNN, same: CerebNet.datasets.utils.rescale_image src_min, scale = getscale(img_data, 0, 255) mapped_data = img_data if not img_data.dtype == np.dtype(np.uint8): @@ -121,59 +129,11 @@ def rescale_image(img_data): return new_data - -def hypo_map_subseg2label(subseg): - ''' - Function to perform look-up table mapping from subseg space to label space - - - Parameters - ---------- - subseg - - Returns - ------- - - ''' - - h, w, d = subseg.shape - lbls, lbls_sag = hyposubseg_labels - - lut_subseg = np.zeros(max(lbls) + 1, dtype='int') - for idx, value in enumerate(lbls): - lut_subseg[value] = idx - - mapped_subseg = lut_subseg.ravel()[subseg.ravel()] - mapped_subseg = mapped_subseg.reshape((h, w, d)) - - - # mapping left labels to right labels for sagittal view - subseg[subseg == 2] = 1 - subseg[subseg == 5] = 4 - subseg[subseg == 6] = 3 - subseg[subseg == 8] = 7 - subseg[subseg == 12] = 11 - subseg[subseg == 20] = 13 - subseg[subseg == 24] = 23 - - subseg[subseg == 126] = 226 - subseg[subseg == 127] = 227 - subseg[subseg == 128] = 228 - subseg[subseg == 129] = 229 - - lut_subseg_sag = np.zeros(max(lbls_sag) + 1, dtype='int') - for idx, value in enumerate(lbls_sag): - lut_subseg_sag[value] = idx - - mapped_subseg_sag = lut_subseg_sag.ravel()[subseg.ravel()] - - mapped_subseg_sag = mapped_subseg_sag.reshape((h, w, d)) - - return mapped_subseg,mapped_subseg_sag -def hypo_map_label2subseg(mapped_subseg): - ''' - Function to perform look-up table mapping from label space to subseg space - ''' +def hypo_map_label2subseg(mapped_subseg: npt.NDArray[int]) -> npt.NDArray[int]: + """ + Function to perform look-up table mapping from label space to subseg space + """ + # TODO can this function be replaced by a Mapper and a mapping file? 
labels, _ = hyposubseg_labels subseg = np.zeros_like(mapped_subseg) h, w, d = subseg.shape @@ -181,20 +141,27 @@ def hypo_map_label2subseg(mapped_subseg): return subseg.reshape((h, w, d)) -def hypo_map_prediction_sagittal2full(prediction_sag): + +def hypo_map_prediction_sagittal2full( + prediction_sag: npt.NDArray[int], +) -> npt.NDArray[int]: """ Function to remap the prediction on the sagittal network to full label space used by coronal and axial networks :param prediction_sag: sagittal prediction (labels) :param lbl_type: type of label :return: Remapped prediction """ + # TODO can this function be replaced by a Mapper and a mapping file? idx_list = list(SAG2FULL_MAP.values()) prediction_full = prediction_sag[:, idx_list, :, :] return prediction_full -def hypo_map_subseg_2_fsseg(subseg,reverse=False): +def hypo_map_subseg_2_fsseg( + subseg: npt.NDArray[int], + reverse: bool = False, +) -> npt.NDArray[int]: """ Function to remap HypVINN internal labels to FastSurfer Labels and viceversa Parameters @@ -206,7 +173,9 @@ def hypo_map_subseg_2_fsseg(subseg,reverse=False): ------- """ - fsseg = np.zeros_like(subseg,dtype=np.int16) + # TODO can this function be replaced by a Mapper and a mapping file? + + fsseg = np.zeros_like(subseg, dtype=np.int16) if not reverse: for value, name in HYPVINN_CLASS_NAMES.items(): diff --git a/HypVINN/inference.py b/HypVINN/inference.py index 274ef1f1..85cac5f3 100644 --- a/HypVINN/inference.py +++ b/HypVINN/inference.py @@ -27,17 +27,24 @@ from HypVINN.models.networks import build_model from HypVINN.data_loader.data_utils import hypo_map_prediction_sagittal2full from HypVINN.data_loader.dataset import HypoVINN_dataset -from HypVINN.run_prediction import ModalityMode +from HypVINN.utils import ModalityMode logger = logging.get_logger(__name__) class Inference: - def __init__(self, cfg, args): + def __init__( + self, + cfg, + threads: int = -1, + async_io: bool = False, + device: str = "auto", + viewagg_device: str = "auto", + ): - self._threads = getattr(args, "threads", 1) + self._threads = threads torch.set_num_threads(self._threads) - self._async_io = getattr(args, "async_io", False) + self._async_io = async_io # Set random seed from configs. np.random.seed(cfg.RNG_SEED) @@ -49,16 +56,16 @@ def __init__(self, cfg, args): torch.set_flush_denormal(True) # Define device and transfer model - self.device = find_device(args.device) + self.device = find_device(device) - if self.device.type == "cpu" and args.viewagg_device == "auto": + if self.device.type == "cpu" and viewagg_device == "auto": self.viewagg_device = self.device else: # check, if GPU is big enough to run view agg on it # (this currently takes the memory of the passed device) self.viewagg_device = torch.device( find_device( - args.viewagg_device, + viewagg_device, flag_name="viewagg_device", min_memory=4 * (2 ** 30), ) diff --git a/HypVINN/run_pipeline.py b/HypVINN/run_pipeline.py deleted file mode 100644 index 439b94f3..00000000 --- a/HypVINN/run_pipeline.py +++ /dev/null @@ -1,241 +0,0 @@ -# Copyright 2024 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -from pathlib import Path -from typing import Optional -import time - -from FastSurferCNN.utils import PLANES, Plane, logging, parser_defaults -from FastSurferCNN.utils.checkpoint import get_checkpoints, load_checkpoint_config_defaults -from FastSurferCNN.utils.common import assert_no_root -from HypVINN.run_prediction import run_hypo_seg -from HypVINN.utils.preproc import hyvinn_preproc -from HypVINN.utils.mode_config import get_hypinn_mode -from HypVINN.utils.misc import create_expand_output_directory -from HypVINN.utils.checkpoint import YAML_DEFAULT as CHECKPOINT_PATHS_FILE -## -# Global Variables -## -LOGGER = logging.get_logger(__name__) - - -def optional_path(a: str) -> Optional[Path]: - """ - Convert a string to a Path object or None. - - Parameters - ---------- - a : str - The string to convert. - - Returns - ------- - Optional[Path] - The Path object or None. - """ - if a.lower() in ("none", ""): - return None - return Path(a) - - -def option_parse() -> argparse.ArgumentParser: - """ - A function to create an ArgumentParser object and parse the command line arguments. - - Returns - ------- - argparse.Ar - The parser object to parse arguments from the command line. - """ - from HypVINN.config.hypvinn_files import HYPVINN_SEG_NAME - parser = argparse.ArgumentParser( - description="Script for Hypothalamus Segmentation.", - ) - - # 1. Directory information (where to read from, where to write from and to incl. search-tag) - parser = parser_defaults.add_arguments( - parser, ["in_dir", "sd", "sid"], - ) - - parser = parser_defaults.add_arguments(parser, ["seg_log"]) - - # 2. Options for the MRI volumes - parser = parser_defaults.add_arguments( - parser, ["t1"] - ) - parser.add_argument( - '--t2', - type=optional_path, - default=None, - required=False, - help="Path to the T2 image to process.", - ) - - # 3. Image processing options - parser.add_argument( - "--qc_snap", - action='store_true', - dest="qc_snapshots", - help="Create qc snapshots in //qc_snapshots.", - ) - parser.add_argument( - "--reg_mode", - type=str, - default="coreg", - choices=["none", "coreg", "robust"], - help="Freesurfer Registration type to run. coreg: mri_coreg, " - "robust : mri_robust_register, none: entirely deactivates " - "registration of T2 to T1, if both images are passed, " - "images need to be register properly externally.", - ) - parser.add_argument( - "--hypo_segfile", - type=Path, - default=Path("mri") / HYPVINN_SEG_NAME, - dest="hypo_segfile", - help="" - ) - - # 4. Options for advanced, technical parameters - advanced = parser.add_argument_group(title="Advanced options") - parser_defaults.add_arguments( - advanced, - ["device", "viewagg_device", "threads", "batch_size", "async_io", "allow_root"], - ) - - files: dict[Plane, str | Path] = {k: "default" for k in PLANES} - # 5. 
Checkpoint to load - parser_defaults.add_plane_flags( - advanced, - "checkpoint", - files, - CHECKPOINT_PATHS_FILE, - ) - - parser_defaults.add_plane_flags( - advanced, - "config", - { - "coronal": Path("HypVINN/config/HypVINN_coronal_v1.0.0.yaml"), - "axial": Path("HypVINN/config/HypVINN_axial_v1.0.0.yaml"), - "sagittal": Path("HypVINN/config/HypVINN_sagittal_v1.0.0.yaml"), - }, - CHECKPOINT_PATHS_FILE, - ) - return parser - - -def main(args: argparse.Namespace) -> int | str: - """ - Main function of the hypothalamus segmentation module. - - Parameters - ---------- - args: argparse.Namespace - The arguments to the script as created by `options_parse`. - - Returns - ------- - int, str - 0, if successful, an error message describing the cause for the - failure otherwise. - """ - - # mapped freesurfer orig input name to the hypvinn t1 name - args.t1 = optional_path(args.orig_name) - # set output dir - args.out_dir = args.out_dir / args.sid - # Warning if run as root user - args.allow_root or assert_no_root() - start = time.time() - try: - # Set up logging - from FastSurferCNN.utils.logging import setup_logging - if not args.log_name: - args.log_name = args.out_dir / "scripts" / "hypvinn_seg.log" - setup_logging(args.log_name) - - LOGGER.info("Checking or downloading default checkpoints ...") - urls = load_checkpoint_config_defaults( - "url", - filename=CHECKPOINT_PATHS_FILE, - ) - get_checkpoints(args.ckpt_ax, args.ckpt_cor, args.ckpt_sag, urls=urls) - - # Get configuration to run multi-modal or uni-modal - mode = get_hypinn_mode( - getattr(args, "t1", None), - getattr(args, "t2", None), - ) - args.mode = mode - - if mode: - # Create output directory if it does not already exist. - create_expand_output_directory(args.out_dir, args.qc_snapshots) - LOGGER.info( - f"Running HypVINN segmentation pipeline on subject {args.sid}" - ) - LOGGER.info(f"Output will be stored in: {args.out_dir}") - LOGGER.info(f"T1 image input {args.t1}") - LOGGER.info(f"T2 image input {args.t2}") - - # Pre-processing -- T1 and T2 registration - if mode == "t1t2": - # Note, that args.t1 and args.t2 are guaranteed to be not None - # via get_hypvinn_mode, which only returns t1t2, if t1 and t2 - # exist. - # hypvinn_preproc returns the path to the t2 that is registered - # to the t1 - args.t2 = hyvinn_preproc( - mode, - getattr(args, "reg_mode", "coreg"), - Path(args.t1), - Path(args.t2), - Path(args.out_dir), - ) - # Segmentation pipeline - run_hypo_seg( - args, - subject_name=args.sid, - out_dir=Path(args.out_dir), - t1_path=Path(args.t1), - t2_path=Path(args.t2), - mode=args.mode, - threads=args.threads, - seg_file=args.hypo_segfile, - ) - else: - return ( - f"Failed Evaluation on {args.sid} couldn't determine the " - f"processing mode. Please check that T1 or T2 images are " - f"available.\nT1 image path: {args.t1}\nT2 image path " - f"{args.t2}.\nNo T1 or T2 image available." - ) - except (FileNotFoundError, RuntimeError) as e: - LOGGER.info(f"Failed Evaluation on {args.sid}:") - LOGGER.exception(e) - else: - LOGGER.info( - f"Processing whole pipeline finished in {time.time() - start:.4f} " - f"seconds" - ) - - -if __name__ == "__main__": - # arguments - parser = option_parse() - args = parser.parse_args() - import sys - sys.exit(main(args)) diff --git a/HypVINN/run_prediction.py b/HypVINN/run_prediction.py index ed50005c..6725541f 100644 --- a/HypVINN/run_prediction.py +++ b/HypVINN/run_prediction.py @@ -13,26 +13,37 @@ # limitations under the License. 
# IMPORTS +from typing import TYPE_CHECKING, Optional, cast, Literal import argparse from pathlib import Path from time import time import numpy as np +from numpy import typing as npt import torch -import nibabel as nib -import FastSurferCNN.utils.logging as logging +if TYPE_CHECKING: + import yacs.config + from nibabel.filebasedimages import FileBasedHeader + +from FastSurferCNN.utils import PLANES, Plane, logging, parser_defaults +from FastSurferCNN.utils.checkpoint import ( + get_checkpoints, + load_checkpoint_config_defaults, +) +from FastSurferCNN.utils.common import assert_no_root, SerialExecutor from HypVINN.config.hypvinn_files import HYPVINN_SEG_NAME -from HypVINN.config.hypvinn_global_var import Plane, planes -from HypVINN.data_loader.data_utils import hypo_map_label2subseg +from HypVINN.data_loader.data_utils import hypo_map_label2subseg, rescale_image from HypVINN.inference import Inference -from HypVINN.models.networks import HypVINN from HypVINN.utils import ModalityDict, ModalityMode, ViewOperations +from HypVINN.utils.checkpoint import YAML_DEFAULT as CHECKPOINT_PATHS_FILE +from HypVINN.utils.img_processing_utils import save_segmentation from HypVINN.utils.load_config import load_config -from HypVINN.data_loader.data_utils import rescale_image +from HypVINN.utils.misc import create_expand_output_directory +from HypVINN.utils.mode_config import get_hypinn_mode +from HypVINN.utils.preproc import hyvinn_preproc from HypVINN.utils.stats_utils import compute_stats -from HypVINN.utils.img_processing_utils import save_segmentation from HypVINN.utils.visualization_utils import plot_qc_images logger = logging.get_logger(__name__) @@ -42,81 +53,413 @@ ## +def optional_path(a: Path | str) -> Optional[Path]: + """ + Convert a string to a Path object or None. + + Parameters + ---------- + a : str + The string to convert. + + Returns + ------- + Optional[Path] + The Path object or None. + """ + if isinstance(a, Path): + return a + if a.lower() in ("none", ""): + return None + return Path(a) + + +def option_parse() -> argparse.ArgumentParser: + """ + A function to create an ArgumentParser object and parse the command line arguments. + + Returns + ------- + argparse.Ar + The parser object to parse arguments from the command line. + """ + parser = argparse.ArgumentParser( + description="Script for Hypothalamus Segmentation.", + ) + + # 1. Directory information (where to read from, where to write from and to incl. search-tag) + parser = parser_defaults.add_arguments( + parser, ["in_dir", "sd", "sid"], + ) + + parser = parser_defaults.add_arguments(parser, ["seg_log"]) + + # 2. Options for the MRI volumes + parser = parser_defaults.add_arguments( + parser, ["t1"] + ) + parser.add_argument( + '--t2', + type=optional_path, + default=None, + required=False, + help="Path to the T2 image to process.", + ) + + # 3. Image processing options + parser.add_argument( + "--qc_snap", + action='store_true', + dest="qc_snapshots", + help="Create qc snapshots in //qc_snapshots.", + ) + parser.add_argument( + "--reg_mode", + type=str, + default="coreg", + choices=["none", "coreg", "robust"], + help="Freesurfer Registration type to run. 
coreg: mri_coreg, " + "robust : mri_robust_register, none: entirely deactivates " + "registration of T2 to T1, if both images are passed, " + "images need to be register properly externally.", + ) + default_hypo_segfile = Path("mri") / HYPVINN_SEG_NAME + parser.add_argument( + "--hypo_segfile", + type=Path, + default=default_hypo_segfile, + dest="hypo_segfile", + help=f"File pattern on where to save the hypothalamus segmentation file " + f"(default: {default_hypo_segfile})." + ) + + # 4. Options for advanced, technical parameters + advanced = parser.add_argument_group(title="Advanced options") + parser_defaults.add_arguments( + advanced, + ["device", "viewagg_device", "threads", "batch_size", "async_io", "allow_root"], + ) + + files: dict[Plane, str | Path] = {k: "default" for k in PLANES} + # 5. Checkpoint to load + parser_defaults.add_plane_flags( + advanced, + "checkpoint", + files, + CHECKPOINT_PATHS_FILE, + ) + + parser_defaults.add_plane_flags( + advanced, + "config", + { + "coronal": Path("HypVINN/config/HypVINN_coronal_v1.0.0.yaml"), + "axial": Path("HypVINN/config/HypVINN_axial_v1.0.0.yaml"), + "sagittal": Path("HypVINN/config/HypVINN_sagittal_v1.0.0.yaml"), + }, + CHECKPOINT_PATHS_FILE, + ) + return parser + + +def main( + out_dir: Path, + t2: Optional[Path], + orig_name: Optional[Path], + sid: str, + ckpt_ax: Path, + ckpt_cor: Path, + ckpt_sag: Path, + cfg_ax: Path, + cfg_cor: Path, + cfg_sag: Path, + seg_file: str = HYPVINN_SEG_NAME, + allow_root: bool = False, + qc_snapshots: bool = False, + reg_mode: Literal["coreg", "robust", "none"] = "coreg", + threads: int = -1, + batch_size: int = 1, + async_io: bool = False, + device: str = "auto", + viewagg_device: str = "auto", +) -> int | str: + """ + Main function of the hypothalamus segmentation module. + + Parameters + ---------- + + Returns + ------- + int, str + 0, if successful, an error message describing the cause for the + failure otherwise. + """ + from concurrent.futures import ProcessPoolExecutor, Future + if threads != 1: + pool = ProcessPoolExecutor(threads) + else: + pool = SerialExecutor() + prep_tasks: dict[str, Future] = {} + + # mapped freesurfer orig input name to the hypvinn t1 name + t1_path = orig_name + t2_path = t2 + subject_name = sid + subject_dir = out_dir / sid + # Warning if run as root user + allow_root or assert_no_root() + start = time() + try: + # Set up logging + prep_tasks["cp"] = pool.submit(prepare_checkpoints, ckpt_ax, ckpt_cor, ckpt_sag) + + kwargs = {} + if t1_path is not None: + kwargs["t1_path"] = Path(t1_path) + if t2_path: + kwargs["t2_path"] = Path(t2_path) + # Get configuration to run multi-modal or uni-modal + mode = get_hypinn_mode(**kwargs) + + if not mode: + return ( + f"Failed Evaluation on {subject_name} couldn't determine the " + f"processing mode. Please check that T1 or T2 images are " + f"available.\nT1 image path: {t1_path}\nT2 image path " + f"{t2_path}.\nNo T1 or T2 image available." + ) + + # Create output directory if it does not already exist. + create_expand_output_directory(out_dir, qc_snapshots) + logger.info( + f"Running HypVINN segmentation pipeline on subject {sid}" + ) + logger.info(f"Output will be stored in: {subject_dir}") + logger.info(f"T1 image input {t1_path}") + logger.info(f"T2 image input {t2_path}") + + # Pre-processing -- T1 and T2 registration + if mode == "t1t2": + # Note, that t1_path and t2_path are guaranteed to be not None + # via get_hypvinn_mode, which only returns t1t2, if t1 and t2 + # exist. 
+            # hypvinn_preproc returns the path to the t2 that is registered
+            # to the t1
+            prep_tasks["reg"] = pool.submit(
+                hyvinn_preproc,
+                mode,
+                reg_mode,
+                out_dir=Path(out_dir),
+                **kwargs,
+            )
+
+        # Segmentation pipeline
+        seg = time()
+        view_ops: ViewOperations = {a: None for a in PLANES}
+        logger.info("Setting up HypVINN run")
+
+        cfgs = (cfg_ax, cfg_cor, cfg_sag)
+        ckpts = (ckpt_ax, ckpt_cor, ckpt_sag)
+        for plane, _cfg_file, _ckpt_file in zip(PLANES, cfgs, ckpts):
+            logger.info(f"{plane} model configuration from {_cfg_file}")
+            view_ops[plane] = {
+                "cfg": set_up_cfgs(_cfg_file, out_dir, batch_size),
+                "ckpt": _ckpt_file,
+            }
+
+            model = view_ops[plane]["cfg"].MODEL
+            if mode != model.MODE and "HypVinn" not in model.MODEL_NAME:
+                raise AssertionError(
+                    f"Modality mode different between input arg: "
+                    f"{mode} and axial train cfg: {model.MODE}"
+                )
+
+        cfg_fin, ckpt_fin = view_ops["coronal"].values()
+
+        if "reg" in prep_tasks:
+            t2_path = prep_tasks["reg"].result()
+            kwargs["t2_path"] = t2_path
+        prep_tasks["load"] = pool.submit(load_volumes, mode=mode, **kwargs)
+
+        # Set up model
+        model = Inference(
+            cfg=cfg_fin,
+            async_io=async_io,
+            threads=threads,
+            viewagg_device=viewagg_device,
+            device=device,
+        )
+
+        logger.info('----' * 30)
+        logger.info(f"Evaluating hypothalamus model on {subject_name}")
+
+        # wait for all prep tasks to finish
+        for ptask in prep_tasks.values():
+            if e := ptask.exception():
+                raise e
+
+        # Load Images
+        image_data, affine, header, orig_zoom, orig_size = prep_tasks["load"].result()
+        logger.info(f"Scale factor: {orig_zoom}")
+
+        pred = time()
+        pred_classes = get_prediction(
+            subject_name,
+            image_data,
+            orig_zoom,
+            model,
+            target_shape=orig_size,
+            view_opts=view_ops,
+            out_scale=None,
+            mode=mode,
+        )
+        logger.info(f"Model prediction finished in {time() - pred:0.4f} seconds")
+        logger.info(f"Saving prediction at {out_dir}")
+
+        save = time()
+        if mode == 't1t2' or mode == 't1':
+            orig_path = t1_path
+        else:
+            orig_path = t2_path
+
+        save_future: Future = pool.submit(
+            save_segmentation,
+            pred_classes,
+            orig_path=orig_path,
+            ras_affine=affine,
+            ras_header=header,
+            save_dir=out_dir,
+            seg_file=seg_file,
+            save_mask=True,
+        )
+        # wait for the segmentation to be written; QC snapshots and stats need its path
+        pred_path = save_future.result()
+        logger.info(
+            f"Prediction successfully saved as {pred_path} in "
+            f"{time() - save:0.4f} seconds"
+        )
+        if qc_snapshots:
+            plot_qc_images(
+                save_dir=out_dir / "qc_snapshots",
+                orig_path=orig_path,
+                prediction_path=pred_path,
+            )
+
+        logger.info("Computing stats")
+        return_value = compute_stats(
+            orig_path=orig_path,
+            prediction_path=pred_path,
+            save_dir=out_dir / "stats",
+            threads=threads,
+        )
+        if return_value != 0:
+            logger.error(return_value)
+
+        logger.info(
+            f"Processing segmentation finished in {time() - seg:0.4f} seconds"
+        )
+    except (FileNotFoundError, RuntimeError) as e:
+        logger.info(f"Failed Evaluation on {subject_name}:")
+        logger.exception(e)
+    else:
+        logger.info(
+            f"Processing whole pipeline finished in {time() - start:.4f} "
+            f"seconds"
+        )
+
+
+def prepare_checkpoints(ckpt_ax, ckpt_cor, ckpt_sag):
+    logger.info("Checking or downloading default checkpoints ...")
+    urls = load_checkpoint_config_defaults(
+        "url",
+        filename=CHECKPOINT_PATHS_FILE,
+    )
+    get_checkpoints(ckpt_ax, ckpt_cor, ckpt_sag, urls=urls)
+
+
 def load_volumes(
     mode: ModalityMode,
-    t1_path: Path,
-    t2_path: Path,
+    t1_path: Optional[Path] = None,
+    t2_path: Optional[Path] = None,
 ) -> tuple[
     ModalityDict,
-    np.ndarray,
-    nib.FilebasedHeader,
-    np.ndarray,
-    tuple[int, ...],
+    npt.NDArray[float],
+    "FileBasedHeader",
+    tuple[float,
float, float], + tuple[int, int, int], ]: + import nibabel as nib modalities: ModalityDict = {} t1_size = () t2_size = () t1_zoom = () t2_zoom = () - affine = np.ndarray([0]) - header = None - zoom = () - size = () + affine: npt.NDArray[float] = np.ndarray([0]) + header: Optional["FileBasedHeader"] = None + zoom: tuple[float, float, float] = (0.0, 0.0, 0.0) + size: tuple[int, ...] = (0, 0, 0) - if "t1" in mode: + if t1_path: logger.info(f'Loading T1 image from : {t1_path}') t1 = nib.load(t1_path) t1 = nib.as_closest_canonical(t1) if mode in ('t1t2', 't1'): affine = t1.affine header = t1.header + else: + raise RuntimeError(f"Invalid mode {mode}, or inconsistent with t1_path!") t1_zoom = t1.header.get_zooms() - zoom = np.round(t1_zoom, 3) + zoom = cast(tuple[float, float, float], tuple(np.round(t1_zoom, 3))) # Conform Intensities - modalities['t1'] = rescale_image(np.asarray(t1.dataobj)) - t1_size = modalities['t1'].shape + modalities["t1"] = rescale_image(np.asarray(t1.dataobj)) + t1_size: tuple[int, ...] = modalities["t1"].shape size = t1_size - if "t2" in mode: - logger.info(f'Loading T2 image from : {t2_path}') + if t2_path: + logger.info(f"Loading T2 image from {t2_path}") t2 = nib.load(t2_path) t2 = nib.as_closest_canonical(t2) - if mode == 't2': + t2_zoom = t2.header.get_zooms() + if mode == "t2": affine = t2.affine header = t2.header - t2_zoom = t2.header.get_zooms() - zoom = np.round(t2_zoom, 3) + zoom = cast(tuple[float, float, float], tuple(np.round(t2_zoom, 3))) + elif mode == "t1t2": + pass + else: + raise RuntimeError(f"Invalid mode {mode}, or inconsistent with t2_path!") # Conform Intensities - modalities['t2'] = np.asarray(rescale_image(t2.get_fdata()), dtype=np.uint8) - t2_size = modalities['t2'].shape + modalities["t2"] = np.asarray(rescale_image(t2.get_fdata()), dtype=np.uint8) + t2_size = modalities["t2"].shape size = t2_size if mode == "t1t2": if not np.allclose(np.array(t1_zoom), np.array(t2_zoom), rtol=0.05): raise AssertionError( - f"T1 : {t1_zoom} and T2 : {t2_zoom} images have different " - f"resolutions" + f"T1 {t1_zoom} and T2 {t2_zoom} images have different resolutions!" ) if not np.allclose(np.array(t1_size), np.array(t2_size), rtol=0.05): raise AssertionError( - f"T1 : {t1_size} and T2 : {t2_size} images have different size" + f"T1 {t1_size} and T2 {t2_size} images have different size!" ) elif mode not in ("t1", "t2"): - raise ValueError(f"Invalid Mode in for modality {mode}") + raise ValueError(f"Invalid mode {mode}, vs. 
't1', 't2', 't1t2'") + + if header is None: + raise ValueError("Missing a header!") + if len(size) != 3: + raise RuntimeError("Invalid ndims of data!") + _size = cast(tuple[int, int, int], size) - return modalities, affine, header, zoom, size + return modalities, affine, header, zoom, _size def get_prediction( subject_name: str, modalities: ModalityDict, orig_zoom, - model: HypVINN, + model: Inference, target_shape: tuple[int, int, int], view_opts: ViewOperations, out_scale=None, mode: ModalityMode = "t1t2", -) -> torch.Tensor: +) -> npt.NDArray[int]: device, viewagg_device = model.get_device() dim = model.get_max_size() @@ -156,120 +499,33 @@ def get_prediction( ## # Processing ## -def set_up_cfgs(cfg, args): +def set_up_cfgs( + cfg: "yacs.config.CfgNode", + out_dir: Path, + batch_size: int = 1, +) -> "yacs.config.CfgNode": cfg = load_config(cfg) - cfg.OUT_LOG_DIR = args.out_dir if args.out_dir is not None else cfg.LOG_DIR - cfg.TEST.BATCH_SIZE = args.batch_size + cfg.OUT_LOG_DIR = out_dir or cfg.LOG_DIR + cfg.TEST.BATCH_SIZE = batch_size out_dims = cfg.DATA.PADDED_SIZE - cfg.MODEL.OUT_TENSOR_WIDTH = out_dims if out_dims > cfg.DATA.PADDED_SIZE else cfg.DATA.PADDED_SIZE - cfg.MODEL.OUT_TENSOR_HEIGHT = out_dims if out_dims > cfg.DATA.PADDED_SIZE else cfg.DATA.PADDED_SIZE + if out_dims > cfg.DATA.PADDED_SIZE: + cfg.MODEL.OUT_TENSOR_WIDTH = out_dims + cfg.MODEL.OUT_TENSOR_HEIGHT = out_dims + else: + cfg.MODEL.OUT_TENSOR_WIDTH = cfg.DATA.PADDED_SIZE + cfg.MODEL.OUT_TENSOR_HEIGHT = cfg.DATA.PADDED_SIZE return cfg -def run_hypo_seg( - args: argparse.Namespace, - subject_name: str, - mode: ModalityMode, - t1_path: Path, - t2_path: Path, - out_dir: Path, - threads: int, - seg_file: Path = Path("mri") / HYPVINN_SEG_NAME, -): - start = time() - - view_ops: ViewOperations = {a: None for a in planes} - logger.info('Setting up HypVINN run') - cfg_ax = set_up_cfgs(args.cfg_ax, args) - logger.info(f'Axial model configuration from : {args.cfg_ax}') - view_ops["axial"] = {"cfg": cfg_ax, "ckpt": args.ckpt_ax} - - cfg_sag = set_up_cfgs(args.cfg_sag, args) - logger.info(f'Sagittal model configuration from : {args.cfg_sag}') - view_ops["sagittal"] = {"cfg": cfg_sag, "ckpt": args.ckpt_sag} +if __name__ == "__main__": + # arguments + parser = option_parse() + args = vars(parser.parse_args()) + log_name = args["log_name"] or args["out_dir"] / "scripts" / "hypvinn_seg.log" - cfg_cor = set_up_cfgs(args.cfg_cor, args) - logger.info(f'Coronal model configuration from : {args.cfg_cor}') - view_ops["coronal"] = {"cfg": cfg_cor, "ckpt": args.ckpt_cor} + from FastSurferCNN.utils.logging import setup_logging + setup_logging(log_name) - for plane, pcfg in zip(planes, (cfg_ax, cfg_cor, cfg_sag)): - model = pcfg.MODEL - if mode != model.MODE and 'HypVinn' not in model.MODEL_NAME: - raise AssertionError( - f"Modality mode different between input arg: " - f"{mode} and axial train cfg: {model.MODE}" - ) - - cfg_fin, ckpt_fin = cfg_cor, args.ckpt_cor - - # Set up model - model = Inference(cfg=cfg_fin, args=args) - - logger.info('----' * 30) - logger.info(f"Evaluating hypothalamus model on {subject_name}") - load = time() - - # Load Images - modalities, ras_affine, ras_header, orig_zoom, orig_size = load_volumes( - mode=mode, - t1_path=t1_path, - t2_path=t2_path, - ) - logger.info(f"Scale factor: {orig_zoom}") - logger.info(f"images loaded in {time() - load:0.4f} seconds") - - load = time() - pred_classes = get_prediction( - subject_name, - modalities, - orig_zoom, - model, - target_shape=orig_size, - view_opts=view_ops, - 
out_scale=None, - mode=mode, - logger=logger, - ) - logger.info(f"Model prediction finished in {time() - load:0.4f} seconds") - logger.info(f"Saving prediction at {out_dir}") - - save = time() - if mode == 't1t2' or mode == 't1': - orig_path = t1_path - else: - orig_path = t2_path - - pred_path = save_segmentation( - pred_classes, - orig_path=orig_path, - ras_affine=ras_affine, - ras_header=ras_header, - save_dir=out_dir, - seg_file=seg_file, - save_mask=True, - ) - logger.info( - f"Prediction successfully saved as {pred_path} in " - f"{time() - save:0.4f} seconds" - ) - if getattr(args, "qc_snapshots", False): - plot_qc_images( - save_dir=out_dir / "qc_snapshots", - orig_path=orig_path, - prediction_path=pred_path, - ) - - logger.info("Computing stats") - return_value = compute_stats( - orig_path=orig_path, - prediction_path=pred_path, - save_dir=out_dir / "stats", - threads=threads, - ) - if return_value != 0: - logger.error(return_value) - - logger.info( - f"Processing segmentation finished in {time() - start:0.4f} seconds" - ) + import sys + sys.exit(main(**args)) diff --git a/HypVINN/utils/checkpoint.py b/HypVINN/utils/checkpoint.py index cd7f7f9d..f2e122b5 100644 --- a/HypVINN/utils/checkpoint.py +++ b/HypVINN/utils/checkpoint.py @@ -14,4 +14,4 @@ from FastSurferCNN.utils.parser_defaults import FASTSURFER_ROOT -YAML_DEFAULT = FASTSURFER_ROOT / "CerebNet/config/checkpoint_paths.yaml" +YAML_DEFAULT = FASTSURFER_ROOT / "HypVINN/config/checkpoint_paths.yaml" diff --git a/HypVINN/utils/mode_config.py b/HypVINN/utils/mode_config.py index b21eda0d..15ba6d74 100644 --- a/HypVINN/utils/mode_config.py +++ b/HypVINN/utils/mode_config.py @@ -23,8 +23,8 @@ def get_hypinn_mode( - t1_path: Optional[Path], - t2_path: Optional[Path], + t1_path: Optional[Path] = None, + t2_path: Optional[Path] = None, ) -> ModalityMode: LOGGER.info("Setting up input mode...") diff --git a/doc/overview/OUTPUT_FILES.md b/doc/overview/OUTPUT_FILES.md index 29e0eef5..0ef19454 100644 --- a/doc/overview/OUTPUT_FILES.md +++ b/doc/overview/OUTPUT_FILES.md @@ -30,19 +30,18 @@ The cerebellum module outputs the files in the table shown below. Unless switche The hypothalamus module outputs the files in the table shown below. Unless switched off by the `--no_hypvinn` argument, this module is automatically run whenever the segmentation module is run. It adds three files, an image with the sub-segmentation of the hypothalamus and a text file with summary statistics. -| directory | filename | module | description | -|:------------|----------------------------------|---------|-----------------------------------------------| -| mri | hypothalamus.HypVINN.nii.gz | hypvinn | hypothalamus sub-segmentation | -| mri | hypothalamus_mask.HypVINN.nii.gz | hypvinn | hypothalamus sub-segmentation mask | -| stats | hypothalamus.HypVINN.stats | hypvinn | table of hypothalamus segmentation statistics | +| directory | filename | module | description | +|:----------|----------------------------------|---------|-----------------------------------------------| +| mri | hypothalamus.HypVINN.nii.gz | hypvinn | hypothalamus sub-segmentation | +| mri | hypothalamus_mask.HypVINN.nii.gz | hypvinn | hypothalamus sub-segmentation mask | +| stats | hypothalamus.HypVINN.stats | hypvinn | table of hypothalamus segmentation statistics | If a T2 image is also passed, the following images are created. 
-| directory | filename        | module  | description                    |
-|:----------|-----------------|---------|--------------------------------|
-| mri       | T2_orig.mgz     | hypvinn | conformed T2 image             |
-| mri       | T2_orig_nu.mgz  | hypvinn | biasfield-corrected T2 image   |
-| mri       | T2_nu_reg.mgz   | hypvinn | co-registered T2 to orig image |
+| directory | filename      | module  | description                    |
+|:----------|---------------|---------|--------------------------------|
+| mri       | T2_nu.mgz     | hypvinn | biasfield-corrected T2 image   |
+| mri       | T2_nu_reg.mgz | hypvinn | co-registered T2 to orig image |
 
 ## Surface module
 
diff --git a/run_fastsurfer.sh b/run_fastsurfer.sh
index 89054b2e..f7740639 100755
--- a/run_fastsurfer.sh
+++ b/run_fastsurfer.sh
@@ -32,6 +32,7 @@ fi
 
 fastsurfercnndir="$FASTSURFER_HOME/FastSurferCNN"
 cerebnetdir="$FASTSURFER_HOME/CerebNet"
+hypvinndir="$FASTSURFER_HOME/HypVINN"
 reconsurfdir="$FASTSURFER_HOME/recon_surf"
 
 # Regular flags defaults
@@ -49,6 +50,8 @@ hypo_segfile=""
 hypo_statsfile=""
 hypvinn_flags=()
 conformed_name=""
+norm_name=""
+norm_name_t2=""
 seg_log=""
 run_talairach_registration="false"
 atlas3T="false"
@@ -374,7 +377,7 @@ case $key in
   --reg_mode)
     mode=$(echo "$2" | tr "[:upper:]" "[:lower:]")
     if [[ "$mode" =~ /^(none|coreg|robust)$/ ]] ; then
-      hypvinn_flags=("${hypvinn_flags[@]}" --regmode "$mode")
+      hypvinn_flags+=(--regmode "$mode")
     else
       echo "Invalid --reg_mode option, must be 'none', 'coreg' or 'robust'."
     fi
@@ -391,6 +394,11 @@ case $key in
     shift # past argument
     shift # past value
     ;;
+  --norm_name_t2)
+    norm_name_t2="$2"
+    shift # past argument
+    shift # past value
+    ;;
   --aseg_segfile)
     aseg_segfile="$2"
     shift # past argument
@@ -682,6 +690,11 @@ if [[ -z "$norm_name" ]]
     norm_name="${sd}/${subject}/mri/orig_nu.mgz"
 fi
 
+if [[ -z "$norm_name_t2" ]]
+  then
+    norm_name_t2="${sd}/${subject}/mri/T2_nu.mgz"
+fi
+
 if [[ -z "$seg_log" ]]
   then
     seg_log="${sd}/${subject}/scripts/deep-seg.log"
@@ -825,8 +838,6 @@ if [[ "$run_seg_pipeline" == "1" ]]
       exit 1
     fi
   fi
-  # TODO Bias field correct also the the t2 input
-  # compute the bias-field corrected image
   if [[ "$run_biasfield" == "1" ]]
   then
     # this will always run, since norm_name is set to subject_dir/mri/orig_nu.mgz, if it is not passed/empty
@@ -875,6 +886,23 @@ if [[ "$run_seg_pipeline" == "1" ]]
      exit 1
     fi
   fi
+
+  if [[ -n "$t2" ]]
+  then
+    # ... we have a t2 image, bias field-correct it
+    echo "INFO: Running N4 bias-field correction of the t2" | tee -a "$seg_log"
+    cmd=($python "${reconsurfdir}/N4_bias_correct.py" "--in" "$t2"
+         --out "$norm_name_t2" --threads "$threads")
+    echo "${cmd[@]}" |& tee -a "$seg_log"
+    "${cmd[@]}"
+    if [[ "${PIPESTATUS[0]}" -ne 0 ]]
+    then
+      echo "ERROR: T2 Biasfield correction failed" | tee -a "$seg_log"
+      exit 1
+    fi
+
+
+  fi
 fi
 
 if [[ "$run_cereb_module" == "1" ]]
@@ -903,6 +931,30 @@ if [[ "$run_seg_pipeline" == "1" ]]
     fi
   fi
 
+  if [[ "$run_hypvinn_module" == "1" ]]
+  then
+    cmd=($python "$hypvinndir/run_prediction.py" --sd "${sd}" --sid "${subject}"
+         "${hypvinn_flags[@]}" "${allow_root[@]}" --threads "$threads" --async_io
+         --batch_size "$batch_size" --seg_log "$seg_log" --device "$device"
+         --viewagg_device "$viewagg_device" --t1)
+    if [[ "$run_biasfield" == "1" ]]
+    then
+      cmd+=("$norm_name")
+      if [[ -n "$t2" ]] ; then cmd+=(--t2 "$norm_name_t2"); fi
+    else
+      echo "WARNING: We strongly recommend not to run the hypvinn module with --no_biasfield!"
+      cmd+=("$t1")
+      if [[ -n "$t2" ]] ; then cmd+=(--t2 "$t2"); fi
+    fi
+    echo "${cmd[@]}" |& tee -a "$seg_log"
+    "${cmd[@]}"
+    if [[ "${PIPESTATUS[0]}" -ne 0 ]]
+    then
+      echo "ERROR: Hypothalamus Segmentation failed" |& tee -a "$seg_log"
+      exit 1
+    fi
+  fi
+
   # if [[ ! -f "$merged_segfile" ]]
   # then
   #   ln -s -r "$asegdkt_segfile" "$merged_segfile"