From 7b576d5af56a6ea2a0f8c93cf4a118ba4caa31ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20K=C3=BCgler?= Date: Mon, 29 Apr 2024 15:37:20 +0200 Subject: [PATCH] - Fix errors in t1_only mode - Fix errors in t1t2 mode --- FastSurferCNN/utils/run_tools.py | 35 ++++++++++++--- HypVINN/run_prediction.py | 64 +++++++++++++++++---------- HypVINN/utils/img_processing_utils.py | 11 +++-- HypVINN/utils/preproc.py | 22 ++++++--- HypVINN/utils/stats_utils.py | 4 +- HypVINN/utils/visualization_utils.py | 6 +-- run_fastsurfer.sh | 2 +- 7 files changed, 101 insertions(+), 43 deletions(-) diff --git a/FastSurferCNN/utils/run_tools.py b/FastSurferCNN/utils/run_tools.py index 732f6d5b..2a15d6d2 100644 --- a/FastSurferCNN/utils/run_tools.py +++ b/FastSurferCNN/utils/run_tools.py @@ -2,7 +2,8 @@ from concurrent.futures import Executor, Future from dataclasses import dataclass from functools import partialmethod -from typing import Generator, Optional, Sequence +from typing import Generator, Optional, Sequence, Callable, Any, Collection, Iterable +from datetime import datetime # TODO: python3.9+ # from collections.abc import Generator @@ -17,13 +18,17 @@ class MessageBuffer: out: bytes = b"" err: bytes = b"" retcode: Optional[int] = None + runtime: float = 0. def __add__(self, other: "MessageBuffer") -> "MessageBuffer": if not isinstance(other, MessageBuffer): raise ValueError("Can only append another MessageBuffer!") return MessageBuffer( - out=self.out + other.out, err=self.err + other.err, retcode=other.retcode + out=self.out + other.out, + err=self.err + other.err, + retcode=other.retcode, + runtime=max(self.runtime or 0.0, other.runtime or 0.0), ) def __iadd__(self, other: "MessageBuffer"): @@ -35,6 +40,7 @@ def __iadd__(self, other: "MessageBuffer"): self.out += other.out self.err += other.err self.retcode = other.retcode + self.runtime = max(self.runtime or 0.0, other.runtime or 0.0) return self def out_str(self, encoding=None): @@ -48,10 +54,16 @@ class Popen(subprocess.Popen): """ Extension of subprocess.Popen for convenience. """ + _starttime: Optional[datetime] = None + + def __init__(self, *args, **kwargs): + self._starttime = datetime.now() + super().__init__(*args, **kwargs) def messages(self, timeout: float) -> Generator[MessageBuffer, None, None]: from subprocess import TimeoutExpired + start = self._starttime or datetime.now() while self.poll() is None: try: stdout, stderr = self.communicate(timeout=timeout) @@ -59,6 +71,7 @@ def messages(self, timeout: float) -> Generator[MessageBuffer, None, None]: out=stdout if stdout else b"", err=stderr if stderr else b"", retcode=self.returncode, + runtime=(datetime.now() - start).total_seconds(), ) except TimeoutExpired: pass @@ -70,15 +83,22 @@ def messages(self, timeout: float) -> Generator[MessageBuffer, None, None]: b"" if self.stderr is None or self.stderr.closed else self.stderr.read() ) if _stderr != b"" or _stdout != b"": - yield MessageBuffer(out=_stdout, err=_stderr, retcode=self.returncode) + yield MessageBuffer( + out=_stdout, + err=_stderr, + retcode=self.returncode, + runtime=(datetime.now() - start).total_seconds(), + ) def next_message(self, timeout: float) -> MessageBuffer: + start = self._starttime or datetime.now() if self.poll() is None: stdout, stderr = self.communicate(timeout=timeout) return MessageBuffer( out=stdout if stdout else b"", err=stderr if stderr else b"", retcode=self.returncode, + runtime=(datetime.now() - start).total_seconds(), ) else: @@ -89,7 +109,12 @@ def next_message(self, timeout: float) -> MessageBuffer: b"" if self.stderr is None or self.stderr.closed else self.stderr.read() ) if _stderr or _stdout: - return MessageBuffer(out=_stdout, err=_stderr, retcode=self.returncode) + return MessageBuffer( + out=_stdout, + err=_stderr, + retcode=self.returncode, + runtime=(datetime.now() - start).total_seconds(), + ) else: raise StopIteration() @@ -119,7 +144,7 @@ def finish(self, timeout: float = None) -> MessageBuffer: self.wait(timeout) except subprocess.TimeoutExpired: self.terminate() - msg = MessageBuffer() + msg = MessageBuffer(runtime=0.0) i = 0 for _msg in self.messages(timeout=0.25): msg += _msg diff --git a/HypVINN/run_prediction.py b/HypVINN/run_prediction.py index 6725541f..2d542019 100644 --- a/HypVINN/run_prediction.py +++ b/HypVINN/run_prediction.py @@ -89,7 +89,7 @@ def option_parse() -> argparse.ArgumentParser: # 1. Directory information (where to read from, where to write from and to incl. search-tag) parser = parser_defaults.add_arguments( - parser, ["in_dir", "sd", "sid"], + parser, ["sd", "sid"], ) parser = parser_defaults.add_arguments(parser, ["seg_log"]) @@ -173,7 +173,7 @@ def main( cfg_ax: Path, cfg_cor: Path, cfg_sag: Path, - seg_file: str = HYPVINN_SEG_NAME, + hypo_segfile: str = HYPVINN_SEG_NAME, allow_root: bool = False, qc_snapshots: bool = False, reg_mode: Literal["coreg", "robust", "none"] = "coreg", @@ -231,7 +231,7 @@ def main( ) # Create output directory if it does not already exist. - create_expand_output_directory(out_dir, qc_snapshots) + create_expand_output_directory(subject_dir, qc_snapshots) logger.info( f"Running HypVINN segmentation pipeline on subject {sid}" ) @@ -250,7 +250,8 @@ def main( hyvinn_preproc, mode, reg_mode, - out_dir=Path(out_dir), + subject_dir=Path(subject_dir), + threads=threads, **kwargs, ) @@ -264,7 +265,7 @@ def main( for plane, _cfg_file, _ckpt_file in zip(PLANES, cfgs, ckpts): logger.info(f"{plane} model configuration from {_cfg_file}") view_ops[plane] = { - "cfg": set_up_cfgs(_cfg_file, out_dir, batch_size), + "cfg": set_up_cfgs(_cfg_file, subject_dir, batch_size), "ckpt": _ckpt_file, } @@ -315,9 +316,8 @@ def main( mode=mode, ) logger.info(f"Model prediction finished in {time() - pred:0.4f} seconds") - logger.info(f"Saving prediction at {out_dir}") + logger.info(f"Saving results in {subject_dir}") - save = time() if mode == 't1t2' or mode == 't1': orig_path = t1_path else: @@ -329,38 +329,52 @@ def main( orig_path=orig_path, ras_affine=affine, ras_header=header, - save_dir=out_dir, - seg_file=seg_file, + subject_dir=subject_dir, + seg_file=hypo_segfile, save_mask=True, ) - save_future.add_done_callback(lambda x: logger.info(f"Prediction successfully saved as {x}")) + save_future.add_done_callback( + lambda x: logger.info( + f"Prediction successfully saved in {x.result()} seconds." + ), + ) if qc_snapshots: - plot_qc_images( - save_dir=out_dir / "qc_snapshots", + qc_future: Optional[Future] = pool.submit( + plot_qc_images, + subject_qc_dir=subject_dir / "qc_snapshots", orig_path=orig_path, - prediction_path=pred_path, + prediction_path=Path(hypo_segfile), + ) + qc_future.add_done_callback( + lambda x: logger.info(f"QC snapshots saved in {x.result()} seconds."), ) + else: + qc_future = None logger.info("Computing stats") return_value = compute_stats( orig_path=orig_path, - prediction_path=pred_path, - save_dir=out_dir / "stats", + prediction_path=Path(hypo_segfile), + stats_dir=subject_dir / "stats", threads=threads, ) if return_value != 0: logger.error(return_value) logger.info( - f"Processing segmentation finished in {time() - seg:0.4f} seconds" + f"Processing segmentation finished in {time() - seg:0.4f} seconds." ) except (FileNotFoundError, RuntimeError) as e: logger.info(f"Failed Evaluation on {subject_name}:") logger.exception(e) else: + if qc_future: + # finish qc + qc_future.result() + save_future.result() + logger.info( - f"Processing whole pipeline finished in {time() - start:.4f} " - f"seconds" + f"Processing whole pipeline finished in {time() - start:.4f} seconds." ) @@ -460,21 +474,23 @@ def get_prediction( out_scale=None, mode: ModalityMode = "t1t2", ) -> npt.NDArray[int]: + + # TODO There are probably several possibilities to accelerate this script. + # FastSurferVINN takes 7-8s vs. HypVINN 10+s per slicing direction. + # Solution: make this script/function more similar to the optimized FastSurferVINN device, viewagg_device = model.get_device() dim = model.get_max_size() # Coronal model logger.info(f"Evaluating Coronal model, cpkt: " f"{view_opts['coronal']['ckpt']}") - model.set_model(view_opts["coronal"]["cfg"]) - model.load_checkpoint(view_opts["coronal"]["ckpt"]) pred_shape = (dim, dim, dim, model.get_num_classes()) # Set up tensor to hold probabilities and run inference pred_prob = torch.zeros(pred_shape, dtype=torch.float, device=viewagg_device) for plane, opts in view_opts.items(): logger.info(f"Evaluating {plane} model, cpkt :{opts['ckpt']}") - model.set_cfg(opts["cfg"]) + model.set_model(opts["cfg"]) model.load_checkpoint(opts["ckpt"]) pred_prob += model.run(subject_name, modalities, orig_zoom, pred_prob, out_scale, mode=mode) @@ -505,7 +521,7 @@ def set_up_cfgs( batch_size: int = 1, ) -> "yacs.config.CfgNode": cfg = load_config(cfg) - cfg.OUT_LOG_DIR = out_dir or cfg.LOG_DIR + cfg.OUT_LOG_DIR = str(out_dir or cfg.LOG_DIR) cfg.TEST.BATCH_SIZE = batch_size out_dims = cfg.DATA.PADDED_SIZE @@ -522,7 +538,9 @@ def set_up_cfgs( # arguments parser = option_parse() args = vars(parser.parse_args()) - log_name = args["log_name"] or args["out_dir"] / "scripts" / "hypvinn_seg.log" + log_name = (args["log_name"] or + args["out_dir"] / args["sid"] / "scripts/hypvinn_seg.log") + del args["log_name"] from FastSurferCNN.utils.logging import setup_logging setup_logging(log_name) diff --git a/HypVINN/utils/img_processing_utils.py b/HypVINN/utils/img_processing_utils.py index 5748c3bc..efcf97c1 100644 --- a/HypVINN/utils/img_processing_utils.py +++ b/HypVINN/utils/img_processing_utils.py @@ -35,10 +35,12 @@ def save_segmentation( orig_path: Path, ras_affine: npt.NDArray[float], ras_header, - save_dir: Path, + subject_dir: Path, seg_file: Path, save_mask: bool = False, -): +) -> float: + from time import time + starttime = time() from HypVINN.data_loader.data_utils import reorient_img from HypVINN.config.hypvinn_files import HYPVINN_MASK_NAME, HYPVINN_SEG_NAME @@ -55,7 +57,7 @@ def save_segmentation( LOGGER.info( f"HypoVINN Mask after re-orientation: {img2axcodes(mask_img)}" ) - nib.save(mask_img, save_dir / "mri" / HYPVINN_MASK_NAME) + nib.save(mask_img, subject_dir / "mri" / HYPVINN_MASK_NAME) pred_img = nib.Nifti1Image(pred_arr, affine=ras_affine, header=ras_header) LOGGER.info(f"HypoVINN Prediction orientation: {img2axcodes(pred_img)}") @@ -64,7 +66,8 @@ def save_segmentation( f"HypoVINN Prediction after re-orientation: {img2axcodes(pred_img)}" ) pred_img.set_data_dtype(np.int16) # Maximum value 939 - nib.save(pred_img, save_dir / seg_file) + nib.save(pred_img, subject_dir / seg_file) + return time() - starttime def save_logits( diff --git a/HypVINN/utils/preproc.py b/HypVINN/utils/preproc.py index d327a79b..25b36b6c 100644 --- a/HypVINN/utils/preproc.py +++ b/HypVINN/utils/preproc.py @@ -30,15 +30,19 @@ def t1_to_t2_registration( t1_path: Path, t2_path: Path, - out_dir: Path, + subject_dir: Path, registration_type: RegistrationMode = "coreg", + threads: int = -1, ) -> Path: from FastSurferCNN.utils.run_tools import Popen + from FastSurferCNN.utils.threads import get_num_threads import shutil - lta_path = out_dir / "mri/transforms/t2tot1.lta" + if threads <= 0: + threads = get_num_threads() - t2_reg_path = out_dir / "mri/T2_nu_reg.mgz" + lta_path = subject_dir / "mri/transforms/t2tot1.lta" + t2_reg_path = subject_dir / "mri/T2_nu_reg.mgz" if registration_type == "coreg": exe = shutil.which("mri_coreg") @@ -51,6 +55,7 @@ def t1_to_t2_registration( "FREESURFER_HOME environment variable" ) args = [exe, "--mov", t2_path, "--targ", t1_path, "--reg", lta_path] + args = list(map(str, args)) + ["--threads", str(threads)] LOGGER.info("Running " + " ".join(args)) retval = Popen(args).finish() if retval.retcode != 0: @@ -58,6 +63,7 @@ def t1_to_t2_registration( raise RuntimeError("mri_coreg failed registration") else: + LOGGER.info(f"{exe} finished in {retval.runtime}!") exe = shutil.which("mri_vol2vol") if not bool(exe): if os.environ.get("FREESURFER_HOME", ""): @@ -76,6 +82,7 @@ def t1_to_t2_registration( "--cubic", "--keep-precision", ] + args = list(map(str, args)) LOGGER.info("Running " + " ".join(args)) retval = Popen(args).finish() if retval.retcode != 0: @@ -83,6 +90,7 @@ def t1_to_t2_registration( f"mri_vol2vol failed with error code {retval.retcode}." ) raise RuntimeError("mri_vol2vol failed applying registration") + LOGGER.info(f"{exe} finished in {retval.runtime}!") else: exe = shutil.which("mri_robust_register") if not bool(exe): @@ -101,6 +109,7 @@ def t1_to_t2_registration( "--mapmov", t2_reg_path, "--cost NMI", ] + args = list(map(str, args)) LOGGER.info("Running " + " ".join(args)) retval = Popen(args).finish() if retval.retcode != 0: @@ -108,6 +117,7 @@ def t1_to_t2_registration( f"mri_robust_register failed with error code {retval.retcode}." ) raise RuntimeError("mri_robust_register failed registration") + LOGGER.info(f"{exe} finished in {retval.runtime}!") return t2_reg_path @@ -117,7 +127,8 @@ def hyvinn_preproc( reg_mode: RegistrationMode, t1_path: Path, t2_path: Path, - out_dir: Path, + subject_dir: Path, + threads: int = -1, ) -> Path: if mode != "t1t2": @@ -142,8 +153,9 @@ def hyvinn_preproc( t2_path = t1_to_t2_registration( t1_path=t1_path, t2_path=t2_path, - out_dir=out_dir, + subject_dir=subject_dir, registration_type=reg_mode, + threads=threads, ) LOGGER.info( f"Registration finish in {time.time() - load_res:0.4f} seconds!" diff --git a/HypVINN/utils/stats_utils.py b/HypVINN/utils/stats_utils.py index ef8a5e15..0d088617 100644 --- a/HypVINN/utils/stats_utils.py +++ b/HypVINN/utils/stats_utils.py @@ -18,7 +18,7 @@ def compute_stats( orig_path: Path, prediction_path: Path, - save_dir: Path, + stats_dir: Path, threads: int, ) -> int | str: from collections import namedtuple @@ -37,7 +37,7 @@ def compute_stats( args.normfile = orig_path args.segfile = prediction_path - args.segstatsfile = save_dir / HYPVINN_STATS_NAME + args.segstatsfile = stats_dir / HYPVINN_STATS_NAME args.excludeid = [0] args.ids = labels args.merged_labels = [] diff --git a/HypVINN/utils/visualization_utils.py b/HypVINN/utils/visualization_utils.py index 0d6d9ab0..4bde6be6 100644 --- a/HypVINN/utils/visualization_utils.py +++ b/HypVINN/utils/visualization_utils.py @@ -156,7 +156,7 @@ def select_index_to_plot(hyposeg, slice_step=2): def plot_qc_images( - save_dir: Path, + subject_qc_dir: Path, orig_path: Path, prediction_path: Path, padd: int = 45, @@ -167,7 +167,7 @@ def plot_qc_images( from HypVINN.data_loader.data_utils import transform_axial2coronal, hypo_map_subseg_2_fsseg from HypVINN.config.hypvinn_files import HYPVINN_QC_IMAGE_NAME - save_dir.mkdir(exist_ok=True, parents=True) + subject_qc_dir.mkdir(exist_ok=True, parents=True) image = nib.as_closest_canonical(nib.load(orig_path)) pred = nib.as_closest_canonical(nib.load(prediction_path)) @@ -216,6 +216,6 @@ def plot_qc_images( img_per_row=crop_image.shape[0], ) - fig.savefig(save_dir / HYPVINN_QC_IMAGE_NAME, transparent=False) + fig.savefig(subject_qc_dir / HYPVINN_QC_IMAGE_NAME, transparent=False) plt.close(fig) diff --git a/run_fastsurfer.sh b/run_fastsurfer.sh index f7740639..cdf91cca 100755 --- a/run_fastsurfer.sh +++ b/run_fastsurfer.sh @@ -936,7 +936,7 @@ if [[ "$run_seg_pipeline" == "1" ]] cmd=($python "$hypvinndir/run_prediction.py" --sd "${sd}" --sid "${subject}" "${hypvinn_flags[@]}" "${allow_root[@]}" --threads "$threads" --async_io --batch_size "$batch_size" --seg_log "$seg_log" --device "$device" - --viewagg_device "$viewagg_device" --t1) + --viewagg_device "$viewagg" --t1) if [[ "$run_biasfield" == "1" ]] then cmd+=("$norm_name")