diff --git a/CerebNet/run_prediction.py b/CerebNet/run_prediction.py index f9b2c4c2..d37c3183 100644 --- a/CerebNet/run_prediction.py +++ b/CerebNet/run_prediction.py @@ -25,13 +25,12 @@ from FastSurferCNN.utils.checkpoint import ( get_checkpoints, load_checkpoint_config_defaults, + YAML_DEFAULT as CHECKPOINT_PATHS_FILE, ) from FastSurferCNN.utils.common import assert_no_root, SubjectList -from FastSurferCNN.utils.parser_defaults import FASTSURFER_ROOT logger = logging.get_logger(__name__) DEFAULT_CEREBELLUM_STATSFILE = Path("stats/cerebellum.CerebNet.stats") -CHECKPOINT_PATHS_FILE = FASTSURFER_ROOT / "CerebNet/config/checkpoint_paths.yaml" def setup_options(): @@ -90,7 +89,7 @@ def setup_options(): advanced, "checkpoint", files, - CHECKPOINT_PATHS_FILE + CHECKPOINT_PATHS_FILE, ) parser.add_argument( diff --git a/FastSurferCNN/utils/checkpoint.py b/FastSurferCNN/utils/checkpoint.py index 635b5db4..a886a7a0 100644 --- a/FastSurferCNN/utils/checkpoint.py +++ b/FastSurferCNN/utils/checkpoint.py @@ -14,8 +14,9 @@ # IMPORTS import os +from functools import lru_cache from pathlib import Path -from typing import MutableSequence, Optional, Union, Literal, TypedDict, overload +from typing import MutableSequence, Optional, Union, Literal, TypedDict, cast, overload import requests import torch @@ -33,11 +34,15 @@ class CheckpointConfigDict(TypedDict, total=False): - URL: list[str] - CKPT: dict[Plane, Path] - CFG: dict[Plane, Path] + url: list[str] + checkpoint: dict[Plane, Path] + config: dict[Plane, Path] +CheckpointConfigFields = Literal["checkpoint", "config", "url"] + + +@lru_cache def load_checkpoint_config(filename: Path | str = YAML_DEFAULT) -> CheckpointConfigDict: """ Load the plane dictionary from the yaml file. @@ -90,7 +95,7 @@ def load_checkpoint_config_defaults( def load_checkpoint_config_defaults( - configtype: Literal["checkpoint", "config", "url"], + configtype: CheckpointConfigFields, filename: str | Path = YAML_DEFAULT, ) -> dict[Plane, Path] | list[str]: """ @@ -98,7 +103,7 @@ def load_checkpoint_config_defaults( Parameters ---------- - configtype : "checkpoint", "config", "url + configtype : "checkpoint", "config", "url" Type of value. filename : str, Path The path to the yaml file. Either absolute or relative to the FastSurfer root @@ -112,7 +117,7 @@ def load_checkpoint_config_defaults( if not isinstance(filename, Path): filename = Path(filename) - configtype = configtype.lower() + configtype = cast(CheckpointConfigFields, configtype.lower()) if configtype not in ("url", "checkpoint", "config"): raise ValueError("Type must be 'url', 'checkpoint' or 'config'") diff --git a/FastSurferCNN/utils/parser_defaults.py b/FastSurferCNN/utils/parser_defaults.py index 7db47143..966784a3 100644 --- a/FastSurferCNN/utils/parser_defaults.py +++ b/FastSurferCNN/utils/parser_defaults.py @@ -381,11 +381,11 @@ def add_arguments(parser: T_AddArgs, flags: Iterable[str]) -> T_AddArgs: def add_plane_flags( - parser: argparse.ArgumentParser, + parser: T_AddArgs, configtype: Literal["checkpoint", "config"], files: Mapping[Plane, Path | str], defaults_path: Path | str, -) -> argparse.ArgumentParser: +) -> T_AddArgs: """ Add plane arguments. diff --git a/HypVINN/README.md b/HypVINN/README.md index f82f201c..2077efb2 100644 --- a/HypVINN/README.md +++ b/HypVINN/README.md @@ -26,7 +26,7 @@ Run the HypVINN/run_pipeline.py which has the following arguments: * `--t2 ` : T2 image path * `--seg_log` : Path to file in which run logs will be saved. If not set logs will be stored in /sd/sid/logs/hypvinn_seg.log ### Image processing options - * `--no_reg` : Deactivate registration of T2 to T1. If multi modal input is used; images need to be registered externally, + * `--no_reg` : Deactivate registration of T2 to T1. If multi-modal input is used; images need to be registered externally, * `--reg_mode` : Freesurfer Registration type to run. coreg : mri_coreg (Default) or robust : mri_robust_register. * `--qc_snap`: Activate the creation of QC snapshots of the predicted HypVINN segmentation. ### FastSurfer Technical parameters (see FastSurfer documentation) diff --git a/HypVINN/__init__.py b/HypVINN/__init__.py index d399b982..0dc9ff4b 100644 --- a/HypVINN/__init__.py +++ b/HypVINN/__init__.py @@ -20,5 +20,4 @@ "inference", "run_prediction", "run_pipeline", - "run_prepoc" -] \ No newline at end of file +] diff --git a/HypVINN/config/HypVINN_axial_v1.0.0.yaml b/HypVINN/config/HypVINN_axial_v1.0.0.yaml index f73ec293..45f4c948 100644 --- a/HypVINN/config/HypVINN_axial_v1.0.0.yaml +++ b/HypVINN/config/HypVINN_axial_v1.0.0.yaml @@ -12,7 +12,7 @@ MODEL: OUT_TENSOR_HEIGHT: 320 HEIGHT: 256 WIDTH: 256 - MODE : 'multi' + MODE : 't1t2' MULTI_AUTO_W : True HETERO_INPUT : True diff --git a/HypVINN/config/HypVINN_coronal_v1.0.0.yaml b/HypVINN/config/HypVINN_coronal_v1.0.0.yaml index ddda1752..828caaec 100644 --- a/HypVINN/config/HypVINN_coronal_v1.0.0.yaml +++ b/HypVINN/config/HypVINN_coronal_v1.0.0.yaml @@ -12,7 +12,7 @@ MODEL: OUT_TENSOR_HEIGHT: 320 HEIGHT: 256 WIDTH: 256 - MODE : 'multi' + MODE : 't1t2' MULTI_AUTO_W : True HETERO_INPUT : True diff --git a/HypVINN/config/HypVINN_sagittal_v1.0.0.yaml b/HypVINN/config/HypVINN_sagittal_v1.0.0.yaml index a3f61166..ee63ab31 100644 --- a/HypVINN/config/HypVINN_sagittal_v1.0.0.yaml +++ b/HypVINN/config/HypVINN_sagittal_v1.0.0.yaml @@ -12,7 +12,7 @@ MODEL: OUT_TENSOR_HEIGHT: 320 HEIGHT: 256 WIDTH: 256 - MODE : 'multi' + MODE : 't1t2' MULTI_AUTO_W : True HETERO_INPUT : True diff --git a/HypVINN/config/hypvinn.py b/HypVINN/config/hypvinn.py index 91fcc4ec..48c3e151 100644 --- a/HypVINN/config/hypvinn.py +++ b/HypVINN/config/hypvinn.py @@ -24,8 +24,8 @@ # Name of model _C.MODEL.MODEL_NAME = "" -#modalities 't1', 't2' or multi -_C.MODEL.MODE ='t1' +#modalities 't1', 't2' or 't1t2' +_C.MODEL.MODE = "t1" # Number of classes to predict, including background _C.MODEL.NUM_CLASSES = 79 diff --git a/HypVINN/data_loader/dataset.py b/HypVINN/data_loader/dataset.py index bee66347..c22c4bc2 100644 --- a/HypVINN/data_loader/dataset.py +++ b/HypVINN/data_loader/dataset.py @@ -13,6 +13,7 @@ # limitations under the License. import numpy as np +from numpy import typing as npt import torch from torch.utils.data import Dataset @@ -21,9 +22,11 @@ from FastSurferCNN.data_loader.data_utils import get_thick_slices import FastSurferCNN.utils.logging as logging +from HypVINN.utils import ModalityDict, ModalityMode logger = logging.get_logger(__name__) + # Operator to load imaged for inference class HypoVINN_dataset(Dataset): """ @@ -32,33 +35,41 @@ class HypoVINN_dataset(Dataset): The Weight factor determines the running mode of the HypVINN model if wT1 =1 and wT2 =0. The HypVINN model will only allow the flow of the T1 information (mode = t1) if wT1 =0 and wT2 =1. The HypVINN model will only allow the flow of the T2 information (mode = t2) - if wT1 !=1 and wT2 !=1. The HypVINN model will automatically weigh the T1 information and the T2 information based on the learned modality weights (mode = multi) + if wT1 !=1 and wT2 !=1. The HypVINN model will automatically weigh the T1 information and the T2 information based on the learned modality weights (mode = t1t2) """ - def __init__(self, subject_name, modalities, orig_zoom, cfg, mode='multi', transforms=None): + def __init__( + self, + subject_name: str, + modalities: ModalityDict, + orig_zoom: npt.NDArray[float], + cfg, + mode: ModalityMode = "t1t2", + transforms=None, + ): self.subject_name = subject_name self.plane = cfg.DATA.PLANE #Inference Mode self.mode = mode #set thickness base on train paramters - if cfg.MODEL.MODE in ['t1','t2']: + if cfg.MODEL.MODE in ["t1", "t2"]: self.slice_thickness = cfg.MODEL.NUM_CHANNELS//2 else: self.slice_thickness = cfg.MODEL.NUM_CHANNELS//4 self.base_res = cfg.MODEL.BASE_RES - if self.mode == 't1': - orig_thick = self._standarized_img(modalities['t1'],orig_zoom, modalitie='t1') + if self.mode == "t1": + orig_thick = self._standarized_img(modalities["t1"], orig_zoom, modality="t1") orig_thick = np.concatenate((orig_thick, orig_thick), axis=-1) self.weight_factor = torch.from_numpy(np.asarray([1.0, 0.0])) - elif self.mode == 't2': - orig_thick = self._standarized_img(modalities['t2'],orig_zoom, modalitie='t2') + elif self.mode == "t2": + orig_thick = self._standarized_img(modalities["t2"], orig_zoom, modality="t2") orig_thick = np.concatenate((orig_thick, orig_thick), axis=-1) self.weight_factor = torch.from_numpy(np.asarray([0.0, 1.0])) else: - t1_orig_thick = self._standarized_img(modalities['t1'], orig_zoom, modalitie='t1') - t2_orig_thick = self._standarized_img(modalities['t2'],orig_zoom, modalitie='t2') + t1_orig_thick = self._standarized_img(modalities["t1"], orig_zoom, modality="t1") + t2_orig_thick = self._standarized_img(modalities["t2"], orig_zoom, modality="t2") orig_thick = np.concatenate((t1_orig_thick, t2_orig_thick), axis=-1) self.weight_factor = torch.from_numpy(np.asarray([0.5, 0.5])) @@ -68,33 +79,49 @@ def __init__(self, subject_name, modalities, orig_zoom, cfg, mode='multi', trans self.count = self.images.shape[0] self.transforms = transforms - logger.info(f"Successfully loaded Image from {subject_name} for {self.plane} model") + logger.info( + f"Successfully loaded Image from {subject_name} for {self.plane} " + f"model" + ) - if (cfg.MODEL.MULTI_AUTO_W or cfg.MODEL.MULTI_AUTO_W_CHANNELS) and (self.mode == 'multi' or cfg.MODEL.DUPLICATE_INPUT) : - logger.info(f"For inference T1 block weight and the T2 block are set to the weights learn during training") + if (cfg.MODEL.MULTI_AUTO_W or cfg.MODEL.MULTI_AUTO_W_CHANNELS) and (self.mode == 't1t2' or cfg.MODEL.DUPLICATE_INPUT) : + logger.info( + f"For inference T1 block weight and the T2 block are set to " + f"the weights learn during training" + ) else: - logger.info(f"For inference T1 block weight was set to : {self.weight_factor.numpy()[0]} and the T2 block was set to: {self.weight_factor.numpy()[1]}") + logger.info( + f"For inference T1 block weight was set to: " + f"{self.weight_factor.numpy()[0]} and the T2 block was set to: " + f"{self.weight_factor.numpy()[1]}") - def _standarized_img(self,orig_data,orig_zoom,modalitie): + def _standarized_img(self, orig_data, orig_zoom, modality): if self.plane == "sagittal": orig_data = transform_axial2sagittal(orig_data) self.zoom = orig_zoom[::-1][:2] - logger.info("Loading {} sagittal with input voxelsize {}".format(modalitie,self.zoom)) + logger.info( + f"Loading {modality} sagittal with input voxelsize {self.zoom}" + ) elif self.plane == "coronal": orig_data = transform_axial2coronal(orig_data) self.zoom = orig_zoom[1:] - logger.info("Loading {} coronal with input voxelsize {}".format(modalitie,self.zoom)) + logger.info( + f"Loading {modality} coronal with input voxelsize {self.zoom}" + ) else: self.zoom = orig_zoom[:2] - logger.info("Loading {} axial with input voxelsize {}".format(modalitie,self.zoom)) + logger.info( + f"Loading {modality} axial with input voxelsize {self.zoom}" + ) # Create thick slices orig_thick = get_thick_slices(orig_data, self.slice_thickness) return orig_thick - def _get_scale_factor(self): + + def _get_scale_factor(self) -> npt.NDArray[float]: """ Get scaling factor to match original resolution of input image to final resolution of FastSurfer base network. Input resolution is @@ -108,14 +135,18 @@ def _get_scale_factor(self): return scale - def __getitem__(self, index): + def __getitem__(self, index: int) -> dict[str, torch.Tensor | np.ndarray]: img = self.images[index] scale_factor = self._get_scale_factor() if self.transforms is not None: img = self.transforms(img) - return {'image': img, 'scale_factor': scale_factor,'weight_factor' : self.weight_factor} + return { + "image": img, + "scale_factor": scale_factor, + "weight_factor": self.weight_factor, + } def __len__(self): return self.count diff --git a/HypVINN/models/networks.py b/HypVINN/models/networks.py index e5fcdb94..a26b8024 100644 --- a/HypVINN/models/networks.py +++ b/HypVINN/models/networks.py @@ -24,8 +24,6 @@ import numpy as np - - class HypVINN(FastSurferCNNBase): """ Construct HypVINN object. diff --git a/HypVINN/run_pipeline.py b/HypVINN/run_pipeline.py index 68c74bd5..439b94f3 100644 --- a/HypVINN/run_pipeline.py +++ b/HypVINN/run_pipeline.py @@ -17,13 +17,14 @@ from typing import Optional import time -from FastSurferCNN.utils import logging, parser_defaults -from FastSurferCNN.utils.checkpoint import get_checkpoints +from FastSurferCNN.utils import PLANES, Plane, logging, parser_defaults +from FastSurferCNN.utils.checkpoint import get_checkpoints, load_checkpoint_config_defaults from FastSurferCNN.utils.common import assert_no_root from HypVINN.run_prediction import run_hypo_seg from HypVINN.utils.preproc import hyvinn_preproc -from HypVINN.utils.mode_config import get_hypinn_mode_config +from HypVINN.utils.mode_config import get_hypinn_mode from HypVINN.utils.misc import create_expand_output_directory +from HypVINN.utils.checkpoint import YAML_DEFAULT as CHECKPOINT_PATHS_FILE ## # Global Variables ## @@ -31,20 +32,41 @@ def optional_path(a: str) -> Optional[Path]: - if a.lower() in ["none", ""]: + """ + Convert a string to a Path object or None. + + Parameters + ---------- + a : str + The string to convert. + + Returns + ------- + Optional[Path] + The Path object or None. + """ + if a.lower() in ("none", ""): return None return Path(a) -def option_parse() -> argparse.Namespace: +def option_parse() -> argparse.ArgumentParser: + """ + A function to create an ArgumentParser object and parse the command line arguments. + + Returns + ------- + argparse.Ar + The parser object to parse arguments from the command line. + """ from HypVINN.config.hypvinn_files import HYPVINN_SEG_NAME parser = argparse.ArgumentParser( - description='Hypothalamus Segmentation', + description="Script for Hypothalamus Segmentation.", ) # 1. Directory information (where to read from, where to write from and to incl. search-tag) parser = parser_defaults.add_arguments( - parser, ["in_dir", "sd", "sid"] + parser, ["in_dir", "sd", "sid"], ) parser = parser_defaults.add_arguments(parser, ["seg_log"]) @@ -93,13 +115,13 @@ def option_parse() -> argparse.Namespace: ["device", "viewagg_device", "threads", "batch_size", "async_io", "allow_root"], ) - from HypVINN.utils.checkpoint import HYPVINN_AXI, HYPVINN_COR, HYPVINN_SAG - + files: dict[Plane, str | Path] = {k: "default" for k in PLANES} # 5. Checkpoint to load parser_defaults.add_plane_flags( advanced, "checkpoint", - {"coronal": HYPVINN_COR, "axial": HYPVINN_AXI, "sagittal": HYPVINN_SAG}, + files, + CHECKPOINT_PATHS_FILE, ) parser_defaults.add_plane_flags( @@ -110,11 +132,26 @@ def option_parse() -> argparse.Namespace: "axial": Path("HypVINN/config/HypVINN_axial_v1.0.0.yaml"), "sagittal": Path("HypVINN/config/HypVINN_sagittal_v1.0.0.yaml"), }, + CHECKPOINT_PATHS_FILE, ) - return parser.parse_args() + return parser def main(args: argparse.Namespace) -> int | str: + """ + Main function of the hypothalamus segmentation module. + + Parameters + ---------- + args: argparse.Namespace + The arguments to the script as created by `options_parse`. + + Returns + ------- + int, str + 0, if successful, an error message describing the cause for the + failure otherwise. + """ # mapped freesurfer orig input name to the hypvinn t1 name args.t1 = optional_path(args.orig_name) @@ -131,13 +168,20 @@ def main(args: argparse.Namespace) -> int | str: setup_logging(args.log_name) LOGGER.info("Checking or downloading default checkpoints ...") - from HypVINN.utils.checkpoint import URL as HYPVINN_URL - get_checkpoints(args.ckpt_ax, args.ckpt_cor, args.ckpt_sag, url=HYPVINN_URL) + urls = load_checkpoint_config_defaults( + "url", + filename=CHECKPOINT_PATHS_FILE, + ) + get_checkpoints(args.ckpt_ax, args.ckpt_cor, args.ckpt_sag, urls=urls) # Get configuration to run multi-modal or uni-modal - args = get_hypinn_mode_config(args) + mode = get_hypinn_mode( + getattr(args, "t1", None), + getattr(args, "t2", None), + ) + args.mode = mode - if args.mode: + if mode: # Create output directory if it does not already exist. create_expand_output_directory(args.out_dir, args.qc_snapshots) LOGGER.info( @@ -148,12 +192,24 @@ def main(args: argparse.Namespace) -> int | str: LOGGER.info(f"T2 image input {args.t2}") # Pre-processing -- T1 and T2 registration - args = hyvinn_preproc(args) + if mode == "t1t2": + # Note, that args.t1 and args.t2 are guaranteed to be not None + # via get_hypvinn_mode, which only returns t1t2, if t1 and t2 + # exist. + # hypvinn_preproc returns the path to the t2 that is registered + # to the t1 + args.t2 = hyvinn_preproc( + mode, + getattr(args, "reg_mode", "coreg"), + Path(args.t1), + Path(args.t2), + Path(args.out_dir), + ) # Segmentation pipeline run_hypo_seg( args, subject_name=args.sid, - out_dir=args.out_dir, + out_dir=Path(args.out_dir), t1_path=Path(args.t1), t2_path=Path(args.t2), mode=args.mode, @@ -179,6 +235,7 @@ def main(args: argparse.Namespace) -> int | str: if __name__ == "__main__": # arguments - args = option_parse() + parser = option_parse() + args = parser.parse_args() import sys sys.exit(main(args)) diff --git a/HypVINN/run_prediction.py b/HypVINN/run_prediction.py index c34e5f27..ed50005c 100644 --- a/HypVINN/run_prediction.py +++ b/HypVINN/run_prediction.py @@ -11,41 +11,36 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +# IMPORTS import argparse from pathlib import Path -from typing import Any, Literal, Optional +from time import time -# IMPORTS import numpy as np import torch -import os import nibabel as nib -from time import time -from collections import defaultdict import FastSurferCNN.utils.logging as logging -from HypVINN.config.hypvinn_files import HYPVINN_SEG_NAME +from HypVINN.config.hypvinn_files import HYPVINN_SEG_NAME from HypVINN.config.hypvinn_global_var import Plane, planes -from HypVINN.data_loader.data_utils import hypo_map_label2subseg, hypo_map_subseg_2_fsseg +from HypVINN.data_loader.data_utils import hypo_map_label2subseg from HypVINN.inference import Inference +from HypVINN.models.networks import HypVINN +from HypVINN.utils import ModalityDict, ModalityMode, ViewOperations from HypVINN.utils.load_config import load_config from HypVINN.data_loader.data_utils import rescale_image from HypVINN.utils.stats_utils import compute_stats from HypVINN.utils.img_processing_utils import save_segmentation from HypVINN.utils.visualization_utils import plot_qc_images -ViewOperations = dict[Plane, Optional[dict[Literal["cfg", "ckpt"], Any]]] - logger = logging.get_logger(__name__) ## # Input array preparation ## -ModalityMode = Literal["t1", "t2", "t1t2"] -ModalityDict = dict[Literal["t1", "t2"], np.ndarray] - def load_volumes( mode: ModalityMode, @@ -112,46 +107,41 @@ def load_volumes( return modalities, affine, header, zoom, size -def run_model(model, subject_name, modalities, orig_zoom, pred_prob, out_scale, mode='multi'): - # get prediction - pred_prob = model.run(subject_name, modalities, orig_zoom, pred_prob, out_res=out_scale, mode=mode) - - return pred_prob - - -def get_prediction(subject_name, modalities, orig_zoom, model, gt_shape, view_opts, logger, out_scale=None, - mode='multi'): +def get_prediction( + subject_name: str, + modalities: ModalityDict, + orig_zoom, + model: HypVINN, + target_shape: tuple[int, int, int], + view_opts: ViewOperations, + out_scale=None, + mode: ModalityMode = "t1t2", +) -> torch.Tensor: device, viewagg_device = model.get_device() dim = model.get_max_size() # Coronal model - logger.info(f'Evaluating Coronal model, cpkt :{view_opts["coronal"]["ckpt"]}') + logger.info(f"Evaluating Coronal model, cpkt: " + f"{view_opts['coronal']['ckpt']}") model.set_model(view_opts["coronal"]["cfg"]) model.load_checkpoint(view_opts["coronal"]["ckpt"]) - pred_prob = torch.zeros((dim, dim, dim, model.get_num_classes()), dtype=torch.float).to(viewagg_device) - - # Set up tensor to hold probabilities and run inference (coronal model by default) - pred_prob = run_model(model, subject_name, modalities, orig_zoom, pred_prob, out_scale, mode=mode) - - # Axial model - logger.info(f'Evaluating Axial model, cpkt :{view_opts["axial"]["ckpt"]}') - model.set_cfg(view_opts["axial"]["cfg"]) - model.load_checkpoint(view_opts["axial"]["ckpt"]) - pred_prob += run_model(model, subject_name, modalities, orig_zoom, pred_prob, out_scale, mode=mode) - - # Sagittal model - logger.info(f'Evaluating Sagittal model, cpkt :{view_opts["sagittal"]["ckpt"]}') - model.set_model(view_opts["sagittal"]["cfg"]) - model.load_checkpoint(view_opts["sagittal"]["ckpt"]) - pred_prob += run_model(model, subject_name, modalities, orig_zoom, pred_prob, out_scale, mode=mode) + pred_shape = (dim, dim, dim, model.get_num_classes()) + # Set up tensor to hold probabilities and run inference + pred_prob = torch.zeros(pred_shape, dtype=torch.float, device=viewagg_device) + for plane, opts in view_opts.items(): + logger.info(f"Evaluating {plane} model, cpkt :{opts['ckpt']}") + model.set_cfg(opts["cfg"]) + model.load_checkpoint(opts["ckpt"]) + pred_prob += model.run(subject_name, modalities, orig_zoom, pred_prob, out_scale, mode=mode) # Post processing - h, w, d = gt_shape # final prediction shape equivalent to input ground truth shape + h, w, d = target_shape # final prediction shape equivalent to input ground truth shape - if np.any(gt_shape < pred_prob.shape[:3]): - # if orig was padded before running through model (difference in aseg_size and pred_shape), select - # slices of interest only. This currently works only for "top_left" padding (see augmentation) + if np.any(target_shape < pred_prob.shape[:3]): + # if orig was padded before running through model (difference in + # aseg_size and pred_shape), select slices of interest only. + # This currently works only for "top_left" padding (see augmentation) pred_prob = pred_prob[0:h, 0:w, 0:d, :] # Get hard predictions and map to freesurfer label space @@ -235,7 +225,7 @@ def run_hypo_seg( modalities, orig_zoom, model, - gt_shape=orig_size, + target_shape=orig_size, view_opts=view_ops, out_scale=None, mode=mode, diff --git a/HypVINN/utils/__init__.py b/HypVINN/utils/__init__.py index e69de29b..2f2b06fe 100644 --- a/HypVINN/utils/__init__.py +++ b/HypVINN/utils/__init__.py @@ -0,0 +1,10 @@ +from typing import Any, Literal, Optional + +from numpy import ndarray + +from FastSurferCNN.utils import Plane + +ViewOperations = dict[Plane, Optional[dict[Literal["cfg", "ckpt"], Any]]] +ModalityMode = Literal["t1", "t2", "t1t2"] +ModalityDict = dict[Literal["t1", "t2"], ndarray] +RegistrationMode = Literal["robust", "coreg", "none"] diff --git a/HypVINN/utils/mode_config.py b/HypVINN/utils/mode_config.py index 314634c5..b21eda0d 100644 --- a/HypVINN/utils/mode_config.py +++ b/HypVINN/utils/mode_config.py @@ -13,59 +13,54 @@ # limitations under the License. import os +from pathlib import Path +from typing import Optional + from FastSurferCNN.utils import logging +from HypVINN.utils import ModalityMode LOGGER = logging.get_logger(__name__) -def get_hypinn_mode_config(args): +def get_hypinn_mode( + t1_path: Optional[Path], + t2_path: Optional[Path], +) -> ModalityMode: - LOGGER.info('Setting up input mode...') - if hasattr(args, 't1') and hasattr(args, 't2'): - if os.path.isfile(str(args.t1)) and os.path.isfile(str(args.t2)): - args.mode = 'multi' - elif os.path.isfile(str(args.t1)): - args.mode ='t1' - args.t2 = None - elif os.path.isfile(str(args.t2)): - args.mode ='t2' - args.t1 = None - LOGGER.info('Warning: T2 mode selected. Only passing a T2 image can generate not so accurate results.\n ' - 'Best results are obtained when a T2 image is accompanied with a T1 image.') - else: - args.mode= None + LOGGER.info("Setting up input mode...") + if t1_path is not None and t2_path is not None: + if t1_path.is_file() and t2_path.is_file(): + return "t1t2" + msg = [] + if not t1_path.is_file(): + msg.append(f"the t1 file does not exist ({t1_path})") + if not t2_path.is_file(): + msg.append(f"the t2 file does not exist ({t2_path})") + raise RuntimeError( + f"ERROR: Both the t1 and the t2 flags were passed, but " + f"{' and '.join(msg)}." + ) - elif hasattr(args, 't1'): - if os.path.isfile(str(args.t1)): - args.mode = 't1' - args.t2 = None - else: - if hasattr(args,'t2'): - if os.path.isfile(str(args.t2)): - args.mode = 't2' - args.t1 = None - LOGGER.info( - 'Warning: T2 mode selected. Only passing a T2 image can generate not so accurate results.\n ' - 'Best results are obtained when a T2 image is accompanied with a T1 image.') - else: - args.mode = None - else: - args.mode = None - elif hasattr(args,'t2'): - if os.path.isfile(str(args.t2)): - args.mode = 't2' - args.t1 = None - LOGGER.info('Warning: T2 mode selected. Only passing a T2 image can generate not so accurate results.\n ' - 'Best results are obtained when a T2 image is accompanied with a T1 image.') - else: - args.mode = None + elif t1_path: + if t1_path.is_file(): + return "t1" + raise RuntimeError( + f"ERROR: The t1 flag was passed, but the t1 file does not exist " + f"({t1_path})." + ) + elif t2_path: + if t2_path.is_file(): + LOGGER.info( + "Warning: T2 mode selected. The quality of segmentations based " + "on only a T2 image is significantly worse than when T1 images " + "are included." + ) + return "t2" + raise RuntimeError( + f"ERROR: The t2 flag was passed, but the t1 file does not exist " + f"({t1_path})." + ) else: - args.mode = None - - if args.mode: - LOGGER.info('HypVINN mode is setup to {} input mode'.format(args.mode)) - - return args - - - + raise RuntimeError( + "No t1 or t2 flags were passed, invalid configuration." + ) diff --git a/HypVINN/utils/preproc.py b/HypVINN/utils/preproc.py index 46ebba6d..d327a79b 100644 --- a/HypVINN/utils/preproc.py +++ b/HypVINN/utils/preproc.py @@ -15,22 +15,30 @@ import argparse import time +from pathlib import Path + import nibabel as nib import os import numpy as np from FastSurferCNN.utils import logging +from HypVINN.utils import ModalityMode, RegistrationMode LOGGER = logging.get_logger(__name__) -def t1_to_t2_registration(t1_path, t2_path, out_dir, registration_type="coreg"): +def t1_to_t2_registration( + t1_path: Path, + t2_path: Path, + out_dir: Path, + registration_type: RegistrationMode = "coreg", +) -> Path: from FastSurferCNN.utils.run_tools import Popen import shutil - lta_path = os.path.join(out_dir, "mri", "transforms", "t2tot1.lta") + lta_path = out_dir / "mri/transforms/t2tot1.lta" - t2_reg_path = os.path.join(out_dir, "mri", "T2_nu_reg.mgz") + t2_reg_path = out_dir / "mri/T2_nu_reg.mgz" if registration_type == "coreg": exe = shutil.which("mri_coreg") @@ -57,16 +65,17 @@ def t1_to_t2_registration(t1_path, t2_path, out_dir, registration_type="coreg"): else: raise RuntimeError( "Could not find mri_vol2vol, source FreeSurfer or set " - "the FREESURFER_HOME environment variable" + "the FREESURFER_HOME environment variable" ) - args = [exe, - "--mov", t2_path, - "--targ", t1_path, - "--reg", lta_path, - "--o", t2_reg_path, - "--cubic", - "--keep-precision", - ] + args = [ + exe, + "--mov", t2_path, + "--targ", t1_path, + "--reg", lta_path, + "--o", t2_reg_path, + "--cubic", + "--keep-precision", + ] LOGGER.info("Running " + " ".join(args)) retval = Popen(args).finish() if retval.retcode != 0: @@ -82,15 +91,16 @@ def t1_to_t2_registration(t1_path, t2_path, out_dir, registration_type="coreg"): else: raise RuntimeError( "Could not find mri_robust_register, source FreeSurfer or " - "set the FREESURFER_HOME environment variable" + "set the FREESURFER_HOME environment variable" ) - args = [exe, - "--mov", t2_path, - "--dst", t1_path, - "--lta", lta_path, - "--mapmov", t2_reg_path, - "--cost NMI", - ] + args = [ + exe, + "--mov", t2_path, + "--dst", t1_path, + "--lta", lta_path, + "--mapmov", t2_reg_path, + "--cost NMI", + ] LOGGER.info("Running " + " ".join(args)) retval = Popen(args).finish() if retval.retcode != 0: @@ -102,41 +112,50 @@ def t1_to_t2_registration(t1_path, t2_path, out_dir, registration_type="coreg"): return t2_reg_path -def hyvinn_preproc(args: argparse.Namespace) -> argparse.Namespace: - - if args.mode == "multi": - if args.reg_mode != "none": - load_res = time.time() - # Print Warning if Resolution from both images is different - t1_zoom = nib.load(args.t1).header.get_zooms() - t2_zoom = nib.load(args.t2).header.get_zooms() - - if not np.allclose(np.array(t1_zoom), np.array(t2_zoom),rtol=0.05): - LOGGER.info( - f"Warning: Resolution from T1 ({t1_zoom}) and T2 " - f"({t2_zoom}) image are different.\n " - "Resolution of the T2 image will be interpolated " - "to the one of the T1 image." - ) - - LOGGER.info("Registering T1 to T2 ...") - args.t2 = t1_to_t2_registration( - t1_path=args.t1, - t2_path=args.t2, - out_dir=args.out_dir, - registration_type=args.reg_mode, - ) - LOGGER.info( - f"Registration finish in {time.time() - load_res:0.4f} seconds!" - ) - else: +def hyvinn_preproc( + mode: ModalityMode, + reg_mode: RegistrationMode, + t1_path: Path, + t2_path: Path, + out_dir: Path, +) -> Path: + + if mode != "t1t2": + raise RuntimeError( + "hypvinn_preproc should only be called for t1t2 mode." + ) + if reg_mode != "none": + load_res = time.time() + # Print Warning if Resolution from both images is different + t1_zoom = nib.load(t1_path).header.get_zooms() + t2_zoom = nib.load(t2_path).header.get_zooms() + + if not np.allclose(np.array(t1_zoom), np.array(t2_zoom), rtol=0.05): LOGGER.info( - "Warning: No registration step, registering T1w and T2w is " - "required when running the multi-modal input mode.\n " - "No register images can generate wrong predictions. Omit this " - "message if input images are already registered." + f"Warning: Resolution from T1 ({t1_zoom}) and T2 " + f"({t2_zoom}) image are different.\n " + "Resolution of the T2 image will be interpolated " + "to the one of the T1 image." ) - LOGGER.info("---" * 30) + LOGGER.info("Registering T1 to T2 ...") + t2_path = t1_to_t2_registration( + t1_path=t1_path, + t2_path=t2_path, + out_dir=out_dir, + registration_type=reg_mode, + ) + LOGGER.info( + f"Registration finish in {time.time() - load_res:0.4f} seconds!" + ) + else: + LOGGER.info( + "Warning: No registration step, registering T1w and T2w is " + "required when running the multi-modal input mode.\n " + "No register images can generate wrong predictions. Omit this " + "message if input images are already registered." + ) + + LOGGER.info("---" * 30) - return args + return t2_path