Skip to content

Commit

Permalink
- Fix errors in t1_only mode
Browse files Browse the repository at this point in the history
- Fix errors in t1t2 mode
  • Loading branch information
dkuegler committed Apr 29, 2024
1 parent 4b798b4 commit 7b576d5
Show file tree
Hide file tree
Showing 7 changed files with 101 additions and 43 deletions.
35 changes: 30 additions & 5 deletions FastSurferCNN/utils/run_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from concurrent.futures import Executor, Future
from dataclasses import dataclass
from functools import partialmethod
from typing import Generator, Optional, Sequence
from typing import Generator, Optional, Sequence, Callable, Any, Collection, Iterable
from datetime import datetime

# TODO: python3.9+
# from collections.abc import Generator
Expand All @@ -17,13 +18,17 @@ class MessageBuffer:
out: bytes = b""
err: bytes = b""
retcode: Optional[int] = None
runtime: float = 0.

def __add__(self, other: "MessageBuffer") -> "MessageBuffer":
if not isinstance(other, MessageBuffer):
raise ValueError("Can only append another MessageBuffer!")

return MessageBuffer(
out=self.out + other.out, err=self.err + other.err, retcode=other.retcode
out=self.out + other.out,
err=self.err + other.err,
retcode=other.retcode,
runtime=max(self.runtime or 0.0, other.runtime or 0.0),
)

def __iadd__(self, other: "MessageBuffer"):
Expand All @@ -35,6 +40,7 @@ def __iadd__(self, other: "MessageBuffer"):
self.out += other.out
self.err += other.err
self.retcode = other.retcode
self.runtime = max(self.runtime or 0.0, other.runtime or 0.0)
return self

def out_str(self, encoding=None):
Expand All @@ -48,17 +54,24 @@ class Popen(subprocess.Popen):
"""
Extension of subprocess.Popen for convenience.
"""
_starttime: Optional[datetime] = None

def __init__(self, *args, **kwargs):
self._starttime = datetime.now()
super().__init__(*args, **kwargs)

def messages(self, timeout: float) -> Generator[MessageBuffer, None, None]:
from subprocess import TimeoutExpired

start = self._starttime or datetime.now()
while self.poll() is None:
try:
stdout, stderr = self.communicate(timeout=timeout)
yield MessageBuffer(
out=stdout if stdout else b"",
err=stderr if stderr else b"",
retcode=self.returncode,
runtime=(datetime.now() - start).total_seconds(),
)
except TimeoutExpired:
pass
Expand All @@ -70,15 +83,22 @@ def messages(self, timeout: float) -> Generator[MessageBuffer, None, None]:
b"" if self.stderr is None or self.stderr.closed else self.stderr.read()
)
if _stderr != b"" or _stdout != b"":
yield MessageBuffer(out=_stdout, err=_stderr, retcode=self.returncode)
yield MessageBuffer(
out=_stdout,
err=_stderr,
retcode=self.returncode,
runtime=(datetime.now() - start).total_seconds(),
)

def next_message(self, timeout: float) -> MessageBuffer:
start = self._starttime or datetime.now()
if self.poll() is None:
stdout, stderr = self.communicate(timeout=timeout)
return MessageBuffer(
out=stdout if stdout else b"",
err=stderr if stderr else b"",
retcode=self.returncode,
runtime=(datetime.now() - start).total_seconds(),
)

else:
Expand All @@ -89,7 +109,12 @@ def next_message(self, timeout: float) -> MessageBuffer:
b"" if self.stderr is None or self.stderr.closed else self.stderr.read()
)
if _stderr or _stdout:
return MessageBuffer(out=_stdout, err=_stderr, retcode=self.returncode)
return MessageBuffer(
out=_stdout,
err=_stderr,
retcode=self.returncode,
runtime=(datetime.now() - start).total_seconds(),
)
else:
raise StopIteration()

Expand Down Expand Up @@ -119,7 +144,7 @@ def finish(self, timeout: float = None) -> MessageBuffer:
self.wait(timeout)
except subprocess.TimeoutExpired:
self.terminate()
msg = MessageBuffer()
msg = MessageBuffer(runtime=0.0)
i = 0
for _msg in self.messages(timeout=0.25):
msg += _msg
Expand Down
64 changes: 41 additions & 23 deletions HypVINN/run_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def option_parse() -> argparse.ArgumentParser:

# 1. Directory information (where to read from, where to write from and to incl. search-tag)
parser = parser_defaults.add_arguments(
parser, ["in_dir", "sd", "sid"],
parser, ["sd", "sid"],
)

parser = parser_defaults.add_arguments(parser, ["seg_log"])
Expand Down Expand Up @@ -173,7 +173,7 @@ def main(
cfg_ax: Path,
cfg_cor: Path,
cfg_sag: Path,
seg_file: str = HYPVINN_SEG_NAME,
hypo_segfile: str = HYPVINN_SEG_NAME,
allow_root: bool = False,
qc_snapshots: bool = False,
reg_mode: Literal["coreg", "robust", "none"] = "coreg",
Expand Down Expand Up @@ -231,7 +231,7 @@ def main(
)

# Create output directory if it does not already exist.
create_expand_output_directory(out_dir, qc_snapshots)
create_expand_output_directory(subject_dir, qc_snapshots)
logger.info(
f"Running HypVINN segmentation pipeline on subject {sid}"
)
Expand All @@ -250,7 +250,8 @@ def main(
hyvinn_preproc,
mode,
reg_mode,
out_dir=Path(out_dir),
subject_dir=Path(subject_dir),
threads=threads,
**kwargs,
)

Expand All @@ -264,7 +265,7 @@ def main(
for plane, _cfg_file, _ckpt_file in zip(PLANES, cfgs, ckpts):
logger.info(f"{plane} model configuration from {_cfg_file}")
view_ops[plane] = {
"cfg": set_up_cfgs(_cfg_file, out_dir, batch_size),
"cfg": set_up_cfgs(_cfg_file, subject_dir, batch_size),
"ckpt": _ckpt_file,
}

Expand Down Expand Up @@ -315,9 +316,8 @@ def main(
mode=mode,
)
logger.info(f"Model prediction finished in {time() - pred:0.4f} seconds")
logger.info(f"Saving prediction at {out_dir}")
logger.info(f"Saving results in {subject_dir}")

save = time()
if mode == 't1t2' or mode == 't1':
orig_path = t1_path
else:
Expand All @@ -329,38 +329,52 @@ def main(
orig_path=orig_path,
ras_affine=affine,
ras_header=header,
save_dir=out_dir,
seg_file=seg_file,
subject_dir=subject_dir,
seg_file=hypo_segfile,
save_mask=True,
)
save_future.add_done_callback(lambda x: logger.info(f"Prediction successfully saved as {x}"))
save_future.add_done_callback(
lambda x: logger.info(
f"Prediction successfully saved in {x.result()} seconds."
),
)
if qc_snapshots:
plot_qc_images(
save_dir=out_dir / "qc_snapshots",
qc_future: Optional[Future] = pool.submit(
plot_qc_images,
subject_qc_dir=subject_dir / "qc_snapshots",
orig_path=orig_path,
prediction_path=pred_path,
prediction_path=Path(hypo_segfile),
)
qc_future.add_done_callback(
lambda x: logger.info(f"QC snapshots saved in {x.result()} seconds."),
)
else:
qc_future = None

logger.info("Computing stats")
return_value = compute_stats(
orig_path=orig_path,
prediction_path=pred_path,
save_dir=out_dir / "stats",
prediction_path=Path(hypo_segfile),
stats_dir=subject_dir / "stats",
threads=threads,
)
if return_value != 0:
logger.error(return_value)

logger.info(
f"Processing segmentation finished in {time() - seg:0.4f} seconds"
f"Processing segmentation finished in {time() - seg:0.4f} seconds."
)
except (FileNotFoundError, RuntimeError) as e:
logger.info(f"Failed Evaluation on {subject_name}:")
logger.exception(e)
else:
if qc_future:
# finish qc
qc_future.result()
save_future.result()

logger.info(
f"Processing whole pipeline finished in {time() - start:.4f} "
f"seconds"
f"Processing whole pipeline finished in {time() - start:.4f} seconds."
)


Expand Down Expand Up @@ -460,21 +474,23 @@ def get_prediction(
out_scale=None,
mode: ModalityMode = "t1t2",
) -> npt.NDArray[int]:

# TODO There are probably several possibilities to accelerate this script.
# FastSurferVINN takes 7-8s vs. HypVINN 10+s per slicing direction.
# Solution: make this script/function more similar to the optimized FastSurferVINN
device, viewagg_device = model.get_device()
dim = model.get_max_size()

# Coronal model
logger.info(f"Evaluating Coronal model, cpkt: "
f"{view_opts['coronal']['ckpt']}")
model.set_model(view_opts["coronal"]["cfg"])
model.load_checkpoint(view_opts["coronal"]["ckpt"])

pred_shape = (dim, dim, dim, model.get_num_classes())
# Set up tensor to hold probabilities and run inference
pred_prob = torch.zeros(pred_shape, dtype=torch.float, device=viewagg_device)
for plane, opts in view_opts.items():
logger.info(f"Evaluating {plane} model, cpkt :{opts['ckpt']}")
model.set_cfg(opts["cfg"])
model.set_model(opts["cfg"])
model.load_checkpoint(opts["ckpt"])
pred_prob += model.run(subject_name, modalities, orig_zoom, pred_prob, out_scale, mode=mode)

Expand Down Expand Up @@ -505,7 +521,7 @@ def set_up_cfgs(
batch_size: int = 1,
) -> "yacs.config.CfgNode":
cfg = load_config(cfg)
cfg.OUT_LOG_DIR = out_dir or cfg.LOG_DIR
cfg.OUT_LOG_DIR = str(out_dir or cfg.LOG_DIR)
cfg.TEST.BATCH_SIZE = batch_size

out_dims = cfg.DATA.PADDED_SIZE
Expand All @@ -522,7 +538,9 @@ def set_up_cfgs(
# arguments
parser = option_parse()
args = vars(parser.parse_args())
log_name = args["log_name"] or args["out_dir"] / "scripts" / "hypvinn_seg.log"
log_name = (args["log_name"] or
args["out_dir"] / args["sid"] / "scripts/hypvinn_seg.log")
del args["log_name"]

from FastSurferCNN.utils.logging import setup_logging
setup_logging(log_name)
Expand Down
11 changes: 7 additions & 4 deletions HypVINN/utils/img_processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@ def save_segmentation(
orig_path: Path,
ras_affine: npt.NDArray[float],
ras_header,
save_dir: Path,
subject_dir: Path,
seg_file: Path,
save_mask: bool = False,
):
) -> float:
from time import time
starttime = time()
from HypVINN.data_loader.data_utils import reorient_img
from HypVINN.config.hypvinn_files import HYPVINN_MASK_NAME, HYPVINN_SEG_NAME

Expand All @@ -55,7 +57,7 @@ def save_segmentation(
LOGGER.info(
f"HypoVINN Mask after re-orientation: {img2axcodes(mask_img)}"
)
nib.save(mask_img, save_dir / "mri" / HYPVINN_MASK_NAME)
nib.save(mask_img, subject_dir / "mri" / HYPVINN_MASK_NAME)

pred_img = nib.Nifti1Image(pred_arr, affine=ras_affine, header=ras_header)
LOGGER.info(f"HypoVINN Prediction orientation: {img2axcodes(pred_img)}")
Expand All @@ -64,7 +66,8 @@ def save_segmentation(
f"HypoVINN Prediction after re-orientation: {img2axcodes(pred_img)}"
)
pred_img.set_data_dtype(np.int16) # Maximum value 939
nib.save(pred_img, save_dir / seg_file)
nib.save(pred_img, subject_dir / seg_file)
return time() - starttime


def save_logits(
Expand Down
Loading

0 comments on commit 7b576d5

Please sign in to comment.