Skip to content

Commit

Permalink
Add mask_file as a parameter (with default value) for saving instead …
Browse files Browse the repository at this point in the history
…of a constant

Make saving a main thread function instead of a future/task
  • Loading branch information
dkuegler committed Jun 20, 2024
1 parent 78b7a9a commit ca1d06a
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 15 deletions.
19 changes: 9 additions & 10 deletions HypVINN/run_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
)
from FastSurferCNN.utils.common import assert_no_root, SerialExecutor

from HypVINN.config.hypvinn_files import HYPVINN_SEG_NAME
from HypVINN.config.hypvinn_files import HYPVINN_SEG_NAME, HYPVINN_MASK_NAME
from HypVINN.data_loader.data_utils import hypo_map_label2subseg, rescale_image
from HypVINN.inference import Inference
from HypVINN.utils import ModalityDict, ModalityMode, ViewOperations
Expand Down Expand Up @@ -173,6 +173,7 @@ def main(
cfg_cor: Path,
cfg_sag: Path,
hypo_segfile: str = HYPVINN_SEG_NAME,
hypo_maskfile: str = HYPVINN_MASK_NAME,
allow_root: bool = False,
qc_snapshots: bool = False,
reg_mode: Literal["coreg", "robust", "none"] = "coreg",
Expand Down Expand Up @@ -209,6 +210,8 @@ def main(
The path to the sagittal configuration file.
hypo_segfile : str, default="{HYPVINN_SEG_NAME}"
The name of the hypothalamus segmentation file. Default is {HYPVINN_SEG_NAME}.
hypo_maskfile : str, default="{HYPVINN_MASK_NAME}"
The name of the hypothalamus mask file. Default is {HYPVINN_MASK_NAME}.
allow_root : bool, default=False
Whether to allow running as root user. Default is False.
qc_snapshots : bool, optional
Expand All @@ -224,7 +227,8 @@ def main(
device : str, default="auto"
The device to use. Default is "auto", which automatically selects the device.
viewagg_device : str, default="auto"
The view aggregation device to use. Default is "auto", which automatically selects the device.
The view aggregation device to use. Default is "auto", which automatically
selects the device.
Returns
-------
Expand Down Expand Up @@ -360,21 +364,17 @@ def main(
else:
orig_path = t2_path

save_future: Future = pool.submit(
save_segmentation,
time_needed = save_segmentation(
pred_classes,
orig_path=orig_path,
ras_affine=affine,
ras_header=header,
subject_dir=subject_dir,
seg_file=hypo_segfile,
mask_file=hypo_maskfile,
save_mask=True,
)
save_future.add_done_callback(
lambda x: logger.info(
f"Prediction successfully saved in {x.result()} seconds."
),
)
logger.info(f"Prediction successfully saved in {time_needed} seconds.")
if qc_snapshots:
qc_future: Optional[Future] = pool.submit(
plot_qc_images,
Expand Down Expand Up @@ -408,7 +408,6 @@ def main(
if qc_future:
# finish qc
qc_future.result()
save_future.result()

logger.info(
f"Processing whole pipeline finished in {time() - start:.4f} seconds."
Expand Down
12 changes: 7 additions & 5 deletions HypVINN/utils/img_processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def save_segmentation(
ras_affine: npt.NDArray[float],
ras_header: nib.nifti1.Nifti1Header | nib.nifti2.Nifti2Header | nib.freesurfer.mghformat.MGHHeader,
subject_dir: Path,
seg_file: Path,
seg_file: str,
mask_file: str,
save_mask: bool = False,
) -> float:
"""
Expand All @@ -73,7 +74,9 @@ def save_segmentation(
subject_dir : Path
The directory where the subject's data is stored.
seg_file : Path
The file where the segmentation results will be saved.
The file where the segmentation will be saved (relative to subject_dir/mri).
mask_file : str
The file where the mask will be saved (relative to subject_dir/mri).
save_mask : bool, default=False
Whether to save the mask or not. Default is False.
Expand All @@ -86,7 +89,6 @@ def save_segmentation(
from time import time
starttime = time()
from HypVINN.data_loader.data_utils import reorient_img
from HypVINN.config.hypvinn_files import HYPVINN_MASK_NAME, HYPVINN_SEG_NAME

pred_arr, labels_cc = get_clean_labels(np.array(prediction, dtype=np.uint8))
# Mapped HypVINN labelst to FreeSurfer Hypvinn Labels
Expand All @@ -101,7 +103,7 @@ def save_segmentation(
LOGGER.info(
f"HypoVINN Mask after re-orientation: {img2axcodes(mask_img)}"
)
nib.save(mask_img, subject_dir / "mri" / HYPVINN_MASK_NAME)
nib.save(mask_img, subject_dir / "mri" / mask_file)

pred_img = nib.Nifti1Image(pred_arr, affine=ras_affine, header=ras_header)
LOGGER.info(f"HypoVINN Prediction orientation: {img2axcodes(pred_img)}")
Expand All @@ -110,7 +112,7 @@ def save_segmentation(
f"HypoVINN Prediction after re-orientation: {img2axcodes(pred_img)}"
)
pred_img.set_data_dtype(np.int16) # Maximum value 939
nib.save(pred_img, subject_dir / seg_file)
nib.save(pred_img, subject_dir / "mri" / seg_file)
return time() - starttime


Expand Down

0 comments on commit ca1d06a

Please sign in to comment.