Skip to content

Commit

Permalink
doc/overview/OUTPUT_FILES.md
Browse files Browse the repository at this point in the history
- remove/rename changed filenames
- format table

HypVINN/config/checkpoint_paths.yaml
- add config

HypVINN/data_loader/data_utils.py
- fix typing and formatting

HypVINN/utils/checkpoint.py
- fix YAML_DEFAULT

HypVINN/utils/mode_config.py
- set default values for get_hypvinn_mode

HypVINN/inference.py
- fix inclusion of ModalityMode

HypVINN/run_prediction.py
- move HypVINN/run_pipeline.py into run_prediction.py
- fix typing, e.g. FileBasedHeader
- fix function parameters
- add help text to hypo_segfile argument
- fix passing of t1_path and t2_path
- various other changes
  • Loading branch information
dkuegler committed Apr 20, 2024
1 parent b168655 commit f913211
Show file tree
Hide file tree
Showing 11 changed files with 527 additions and 481 deletions.
10 changes: 5 additions & 5 deletions HypVINN/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ Hypothalamic subfields segmentation pipeline
2. Hypothalamus Segmentation

### Running the tool
Run the HypVINN/run_pipeline.py which has the following arguments:
Run the HypVINN/run_prediction.py which has the following arguments:
### Input and output arguments
* `--sid <name>` : Subject ID, the subject data upon which to operate
* `--sd <name>` : Directory in which evaluation results should be written.
Expand Down Expand Up @@ -54,7 +54,7 @@ The pipeline can do all pre-processing by itself (step 1). This step can be skip

1. Run full pipeline
```
python HypVINN/run_pipeline.py --sid test_subject --sd /output \
python HypVINN/run_prediction.py --sid test_subject --sd /output \
--t1 /data/test_subject_t1.nii.gz \
--t2 /data/test_subject_t2.nii.gz \
--reg_mode coreg \
Expand All @@ -63,7 +63,7 @@ The pipeline can do all pre-processing by itself (step 1). This step can be skip
```
2. Run full pipeline only using a t1
```
python HypVINN/run_pipeline.py --sid test_subject --sd /output \
python HypVINN/run_prediction.py --sid test_subject --sd /output \
--t1 /data/test_subject_t1.nii.gz \
--reg_mode coreg \
--seg_log /outdir/test_subject.log \
Expand All @@ -72,7 +72,7 @@ The pipeline can do all pre-processing by itself (step 1). This step can be skip

3. Run pipeline without the registration step
```
python HypVINN/run_pipeline.py --sid test_subject --sd /output \
python HypVINN/run_prediction.py --sid test_subject --sd /output \
--t1 /data/test_subject_t1.nii.gz \
--t2 /data/test_subject_t2.nii.gz \
--reg_mode coreg \
Expand All @@ -82,7 +82,7 @@ The pipeline can do all pre-processing by itself (step 1). This step can be skip

4. Run pipeline with creation of qc snapshots
```
python HypVINN/run_pipeline.py --sid test_subject --sd /output \
python HypVINN/run_prediction.py --sid test_subject --sd /output \
--t1 /data/test_subject_t1.nii.gz \
--t2 /data/test_subject_t2.nii.gz \
--reg_mode coreg \
Expand Down
1 change: 0 additions & 1 deletion HypVINN/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,4 @@
"utils",
"inference",
"run_prediction",
"run_pipeline",
]
5 changes: 5 additions & 0 deletions HypVINN/config/checkpoint_paths.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,8 @@ checkpoint:
axial: "checkpoints/HypVINN_axial_v1.0.0.pkl"
coronal: "checkpoints/HypVINN_coronal_v1.0.0.pkl"
sagittal: "checkpoints/HypVINN_sagittal_v1.0.0.pkl"

config:
axial: "HypVINN/config/HypVINN_axial_v1.0.0.yaml"
coronal: "HypVINN/config/HypVINN_coronal_v1.0.0.yaml"
sagittal: "HypVINN/config/HypVINN_sagittal_v1.0.0.yaml"
103 changes: 36 additions & 67 deletions HypVINN/data_loader/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,25 @@
# limitations under the License.

# IMPORTS
import nibabel as nib
import numpy as np
from numpy import typing as npt

from FastSurferCNN.data_loader.conform import getscale, scalecrop
import nibabel as nib
import sys
from HypVINN.config.hypvinn_global_var import hyposubseg_labels, SAG2FULL_MAP, HYPVINN_CLASS_NAMES, FS_CLASS_NAMES


##
# Helper Functions
##
def calculate_flip_orientation(iornt,base_ornt):


def calculate_flip_orientation(iornt, base_ornt):
"""
Compute the flip orientation transform.
ornt[N, 1] is flip of axis N, where 1 means no flip and -1 means flip.
Parameters
----------
iornt
Expand All @@ -47,9 +53,9 @@ def calculate_flip_orientation(iornt,base_ornt):

return new_iornt

# reorient image based on base image
def reorient_img(img,ref_img):
'''

def reorient_img(img, ref_img):
"""
Function to reorient a Nibabel image based on the orientation of a reference nibabel image
The orientation transform. ornt[N,1]` is flip of axis N of the array implied by `shape`, where 1 means no flip and -1 means flip.
For example, if ``N==0 and ornt[0,1] == -1, and there’s an array arr of shape shape, the flip would correspond to the effect of
Expand All @@ -63,7 +69,7 @@ def reorient_img(img,ref_img):
Returns
-------
'''
"""

ref_ornt =nib.io_orientation(ref_img.affine)
iornt=nib.io_orientation(img.affine)
Expand All @@ -79,8 +85,7 @@ def reorient_img(img,ref_img):

return img

# Transformation for mapping
#TODO check compatibility with axis transform from CerebNet

def transform_axial2coronal(vol, axial2coronal=True):
"""
Function to transform volume into coronal axis and back
Expand All @@ -89,11 +94,13 @@ def transform_axial2coronal(vol, axial2coronal=True):
transform from coronal to axial = False
:return:
"""
# TODO check compatibility with axis transform from CerebNet
if axial2coronal:
return np.moveaxis(vol, [0, 1, 2], [0, 2, 1])
else:
return np.moveaxis(vol, [0, 1, 2], [0, 2, 1])
#TODO check compatibility with axis transform from CerebNet


def transform_axial2sagittal(vol, axial2sagittal=True):
"""
Function to transform volume into Sagittal axis and back
Expand All @@ -102,15 +109,16 @@ def transform_axial2sagittal(vol, axial2sagittal=True):
transform from sagittal to coronal = False
:return:
"""
# TODO check compatibility with axis transform from CerebNet
if axial2sagittal:
return np.moveaxis(vol, [0, 1, 2], [2, 0, 1])
else:
return np.moveaxis(vol, [0, 1, 2], [1, 2, 0])


# Same as CerebNet.datasets.utils.rescale_image
def rescale_image(img_data):
# Conform intensities
# TODO move function into FastSurferCNN, same: CerebNet.datasets.utils.rescale_image
src_min, scale = getscale(img_data, 0, 255)
mapped_data = img_data
if not img_data.dtype == np.dtype(np.uint8):
Expand All @@ -121,80 +129,39 @@ def rescale_image(img_data):
return new_data



def hypo_map_subseg2label(subseg):
'''
Function to perform look-up table mapping from subseg space to label space
Parameters
----------
subseg
Returns
-------
'''

h, w, d = subseg.shape
lbls, lbls_sag = hyposubseg_labels

lut_subseg = np.zeros(max(lbls) + 1, dtype='int')
for idx, value in enumerate(lbls):
lut_subseg[value] = idx

mapped_subseg = lut_subseg.ravel()[subseg.ravel()]
mapped_subseg = mapped_subseg.reshape((h, w, d))


# mapping left labels to right labels for sagittal view
subseg[subseg == 2] = 1
subseg[subseg == 5] = 4
subseg[subseg == 6] = 3
subseg[subseg == 8] = 7
subseg[subseg == 12] = 11
subseg[subseg == 20] = 13
subseg[subseg == 24] = 23

subseg[subseg == 126] = 226
subseg[subseg == 127] = 227
subseg[subseg == 128] = 228
subseg[subseg == 129] = 229

lut_subseg_sag = np.zeros(max(lbls_sag) + 1, dtype='int')
for idx, value in enumerate(lbls_sag):
lut_subseg_sag[value] = idx

mapped_subseg_sag = lut_subseg_sag.ravel()[subseg.ravel()]

mapped_subseg_sag = mapped_subseg_sag.reshape((h, w, d))

return mapped_subseg,mapped_subseg_sag
def hypo_map_label2subseg(mapped_subseg):
'''
Function to perform look-up table mapping from label space to subseg space
'''
def hypo_map_label2subseg(mapped_subseg: npt.NDArray[int]) -> npt.NDArray[int]:
"""
Function to perform look-up table mapping from label space to subseg space
"""
# TODO can this function be replaced by a Mapper and a mapping file?
labels, _ = hyposubseg_labels
subseg = np.zeros_like(mapped_subseg)
h, w, d = subseg.shape
subseg = labels[mapped_subseg.ravel()]

return subseg.reshape((h, w, d))

def hypo_map_prediction_sagittal2full(prediction_sag):

def hypo_map_prediction_sagittal2full(
prediction_sag: npt.NDArray[int],
) -> npt.NDArray[int]:
"""
Function to remap the prediction on the sagittal network to full label space used by coronal and axial networks
:param prediction_sag: sagittal prediction (labels)
:param lbl_type: type of label
:return: Remapped prediction
"""
# TODO can this function be replaced by a Mapper and a mapping file?

idx_list = list(SAG2FULL_MAP.values())
prediction_full = prediction_sag[:, idx_list, :, :]
return prediction_full


def hypo_map_subseg_2_fsseg(subseg,reverse=False):
def hypo_map_subseg_2_fsseg(
subseg: npt.NDArray[int],
reverse: bool = False,
) -> npt.NDArray[int]:
"""
Function to remap HypVINN internal labels to FastSurfer Labels and viceversa
Parameters
Expand All @@ -206,7 +173,9 @@ def hypo_map_subseg_2_fsseg(subseg,reverse=False):
-------
"""
fsseg = np.zeros_like(subseg,dtype=np.int16)
# TODO can this function be replaced by a Mapper and a mapping file?

fsseg = np.zeros_like(subseg, dtype=np.int16)

if not reverse:
for value, name in HYPVINN_CLASS_NAMES.items():
Expand Down
21 changes: 14 additions & 7 deletions HypVINN/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,24 @@
from HypVINN.models.networks import build_model
from HypVINN.data_loader.data_utils import hypo_map_prediction_sagittal2full
from HypVINN.data_loader.dataset import HypoVINN_dataset
from HypVINN.run_prediction import ModalityMode
from HypVINN.utils import ModalityMode

logger = logging.get_logger(__name__)


class Inference:
def __init__(self, cfg, args):
def __init__(
self,
cfg,
threads: int = -1,
async_io: bool = False,
device: str = "auto",
viewagg_device: str = "auto",
):

self._threads = getattr(args, "threads", 1)
self._threads = threads
torch.set_num_threads(self._threads)
self._async_io = getattr(args, "async_io", False)
self._async_io = async_io

# Set random seed from configs.
np.random.seed(cfg.RNG_SEED)
Expand All @@ -49,16 +56,16 @@ def __init__(self, cfg, args):
torch.set_flush_denormal(True)

# Define device and transfer model
self.device = find_device(args.device)
self.device = find_device(device)

if self.device.type == "cpu" and args.viewagg_device == "auto":
if self.device.type == "cpu" and viewagg_device == "auto":
self.viewagg_device = self.device
else:
# check, if GPU is big enough to run view agg on it
# (this currently takes the memory of the passed device)
self.viewagg_device = torch.device(
find_device(
args.viewagg_device,
viewagg_device,
flag_name="viewagg_device",
min_memory=4 * (2 ** 30),
)
Expand Down
Loading

0 comments on commit f913211

Please sign in to comment.