Skip to content

Commit

Permalink
Refactor and cleanup of hypvinn
Browse files Browse the repository at this point in the history
- rename mode "multi" to "t1t2"
- reformatting for line length
- replace single quotes to double quotes
- add typing information
- clean up docstrings
- replace os.path with pathlib.Path
  • Loading branch information
dkuegler committed Apr 18, 2024
1 parent d82f78d commit fa1c678
Show file tree
Hide file tree
Showing 16 changed files with 308 additions and 205 deletions.
5 changes: 2 additions & 3 deletions CerebNet/run_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,12 @@
from FastSurferCNN.utils.checkpoint import (
get_checkpoints,
load_checkpoint_config_defaults,
YAML_DEFAULT as CHECKPOINT_PATHS_FILE,
)
from FastSurferCNN.utils.common import assert_no_root, SubjectList
from FastSurferCNN.utils.parser_defaults import FASTSURFER_ROOT

logger = logging.get_logger(__name__)
DEFAULT_CEREBELLUM_STATSFILE = Path("stats/cerebellum.CerebNet.stats")
CHECKPOINT_PATHS_FILE = FASTSURFER_ROOT / "CerebNet/config/checkpoint_paths.yaml"


def setup_options():
Expand Down Expand Up @@ -90,7 +89,7 @@ def setup_options():
advanced,
"checkpoint",
files,
CHECKPOINT_PATHS_FILE
CHECKPOINT_PATHS_FILE,
)

parser.add_argument(
Expand Down
19 changes: 12 additions & 7 deletions FastSurferCNN/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@

# IMPORTS
import os
from functools import lru_cache
from pathlib import Path
from typing import MutableSequence, Optional, Union, Literal, TypedDict, overload
from typing import MutableSequence, Optional, Union, Literal, TypedDict, cast, overload

import requests
import torch
Expand All @@ -33,11 +34,15 @@


class CheckpointConfigDict(TypedDict, total=False):
URL: list[str]
CKPT: dict[Plane, Path]
CFG: dict[Plane, Path]
url: list[str]
checkpoint: dict[Plane, Path]
config: dict[Plane, Path]


CheckpointConfigFields = Literal["checkpoint", "config", "url"]


@lru_cache
def load_checkpoint_config(filename: Path | str = YAML_DEFAULT) -> CheckpointConfigDict:
"""
Load the plane dictionary from the yaml file.
Expand Down Expand Up @@ -90,15 +95,15 @@ def load_checkpoint_config_defaults(


def load_checkpoint_config_defaults(
configtype: Literal["checkpoint", "config", "url"],
configtype: CheckpointConfigFields,
filename: str | Path = YAML_DEFAULT,
) -> dict[Plane, Path] | list[str]:
"""
Get the default value for a specific plane or the url.
Parameters
----------
configtype : "checkpoint", "config", "url
configtype : "checkpoint", "config", "url"
Type of value.
filename : str, Path
The path to the yaml file. Either absolute or relative to the FastSurfer root
Expand All @@ -112,7 +117,7 @@ def load_checkpoint_config_defaults(
if not isinstance(filename, Path):
filename = Path(filename)

configtype = configtype.lower()
configtype = cast(CheckpointConfigFields, configtype.lower())
if configtype not in ("url", "checkpoint", "config"):
raise ValueError("Type must be 'url', 'checkpoint' or 'config'")

Expand Down
4 changes: 2 additions & 2 deletions FastSurferCNN/utils/parser_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,11 +381,11 @@ def add_arguments(parser: T_AddArgs, flags: Iterable[str]) -> T_AddArgs:


def add_plane_flags(
parser: argparse.ArgumentParser,
parser: T_AddArgs,
configtype: Literal["checkpoint", "config"],
files: Mapping[Plane, Path | str],
defaults_path: Path | str,
) -> argparse.ArgumentParser:
) -> T_AddArgs:
"""
Add plane arguments.
Expand Down
2 changes: 1 addition & 1 deletion HypVINN/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ Run the HypVINN/run_pipeline.py which has the following arguments:
* `--t2 </dir/T2**.nii.gz>` : T2 image path
* `--seg_log` : Path to file in which run logs will be saved. If not set logs will be stored in /sd/sid/logs/hypvinn_seg.log
### Image processing options
* `--no_reg` : Deactivate registration of T2 to T1. If multi modal input is used; images need to be registered externally,
* `--no_reg` : Deactivate registration of T2 to T1. If multi-modal input is used; images need to be registered externally,
* `--reg_mode` : Freesurfer Registration type to run. coreg : mri_coreg (Default) or robust : mri_robust_register.
* `--qc_snap`: Activate the creation of QC snapshots of the predicted HypVINN segmentation.
### FastSurfer Technical parameters (see FastSurfer documentation)
Expand Down
3 changes: 1 addition & 2 deletions HypVINN/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,4 @@
"inference",
"run_prediction",
"run_pipeline",
"run_prepoc"
]
]
2 changes: 1 addition & 1 deletion HypVINN/config/HypVINN_axial_v1.0.0.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ MODEL:
OUT_TENSOR_HEIGHT: 320
HEIGHT: 256
WIDTH: 256
MODE : 'multi'
MODE : 't1t2'
MULTI_AUTO_W : True
HETERO_INPUT : True

Expand Down
2 changes: 1 addition & 1 deletion HypVINN/config/HypVINN_coronal_v1.0.0.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ MODEL:
OUT_TENSOR_HEIGHT: 320
HEIGHT: 256
WIDTH: 256
MODE : 'multi'
MODE : 't1t2'
MULTI_AUTO_W : True
HETERO_INPUT : True

Expand Down
2 changes: 1 addition & 1 deletion HypVINN/config/HypVINN_sagittal_v1.0.0.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ MODEL:
OUT_TENSOR_HEIGHT: 320
HEIGHT: 256
WIDTH: 256
MODE : 'multi'
MODE : 't1t2'
MULTI_AUTO_W : True
HETERO_INPUT : True

Expand Down
4 changes: 2 additions & 2 deletions HypVINN/config/hypvinn.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
# Name of model
_C.MODEL.MODEL_NAME = ""

#modalities 't1', 't2' or multi
_C.MODEL.MODE ='t1'
#modalities 't1', 't2' or 't1t2'
_C.MODEL.MODE = "t1"

# Number of classes to predict, including background
_C.MODEL.NUM_CLASSES = 79
Expand Down
71 changes: 51 additions & 20 deletions HypVINN/data_loader/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import numpy as np
from numpy import typing as npt
import torch
from torch.utils.data import Dataset

Expand All @@ -21,9 +22,11 @@
from FastSurferCNN.data_loader.data_utils import get_thick_slices

import FastSurferCNN.utils.logging as logging
from HypVINN.utils import ModalityDict, ModalityMode

logger = logging.get_logger(__name__)


# Operator to load imaged for inference
class HypoVINN_dataset(Dataset):
"""
Expand All @@ -32,33 +35,41 @@ class HypoVINN_dataset(Dataset):
The Weight factor determines the running mode of the HypVINN model
if wT1 =1 and wT2 =0. The HypVINN model will only allow the flow of the T1 information (mode = t1)
if wT1 =0 and wT2 =1. The HypVINN model will only allow the flow of the T2 information (mode = t2)
if wT1 !=1 and wT2 !=1. The HypVINN model will automatically weigh the T1 information and the T2 information based on the learned modality weights (mode = multi)
if wT1 !=1 and wT2 !=1. The HypVINN model will automatically weigh the T1 information and the T2 information based on the learned modality weights (mode = t1t2)
"""
def __init__(self, subject_name, modalities, orig_zoom, cfg, mode='multi', transforms=None):
def __init__(
self,
subject_name: str,
modalities: ModalityDict,
orig_zoom: npt.NDArray[float],
cfg,
mode: ModalityMode = "t1t2",
transforms=None,
):
self.subject_name = subject_name
self.plane = cfg.DATA.PLANE
#Inference Mode
self.mode = mode
#set thickness base on train paramters
if cfg.MODEL.MODE in ['t1','t2']:
if cfg.MODEL.MODE in ["t1", "t2"]:
self.slice_thickness = cfg.MODEL.NUM_CHANNELS//2
else:
self.slice_thickness = cfg.MODEL.NUM_CHANNELS//4

self.base_res = cfg.MODEL.BASE_RES

if self.mode == 't1':
orig_thick = self._standarized_img(modalities['t1'],orig_zoom, modalitie='t1')
if self.mode == "t1":
orig_thick = self._standarized_img(modalities["t1"], orig_zoom, modality="t1")
orig_thick = np.concatenate((orig_thick, orig_thick), axis=-1)
self.weight_factor = torch.from_numpy(np.asarray([1.0, 0.0]))

elif self.mode == 't2':
orig_thick = self._standarized_img(modalities['t2'],orig_zoom, modalitie='t2')
elif self.mode == "t2":
orig_thick = self._standarized_img(modalities["t2"], orig_zoom, modality="t2")
orig_thick = np.concatenate((orig_thick, orig_thick), axis=-1)
self.weight_factor = torch.from_numpy(np.asarray([0.0, 1.0]))
else:
t1_orig_thick = self._standarized_img(modalities['t1'], orig_zoom, modalitie='t1')
t2_orig_thick = self._standarized_img(modalities['t2'],orig_zoom, modalitie='t2')
t1_orig_thick = self._standarized_img(modalities["t1"], orig_zoom, modality="t1")
t2_orig_thick = self._standarized_img(modalities["t2"], orig_zoom, modality="t2")
orig_thick = np.concatenate((t1_orig_thick, t2_orig_thick), axis=-1)
self.weight_factor = torch.from_numpy(np.asarray([0.5, 0.5]))

Expand All @@ -68,33 +79,49 @@ def __init__(self, subject_name, modalities, orig_zoom, cfg, mode='multi', trans
self.count = self.images.shape[0]
self.transforms = transforms

logger.info(f"Successfully loaded Image from {subject_name} for {self.plane} model")
logger.info(
f"Successfully loaded Image from {subject_name} for {self.plane} "
f"model"
)

if (cfg.MODEL.MULTI_AUTO_W or cfg.MODEL.MULTI_AUTO_W_CHANNELS) and (self.mode == 'multi' or cfg.MODEL.DUPLICATE_INPUT) :
logger.info(f"For inference T1 block weight and the T2 block are set to the weights learn during training")
if (cfg.MODEL.MULTI_AUTO_W or cfg.MODEL.MULTI_AUTO_W_CHANNELS) and (self.mode == 't1t2' or cfg.MODEL.DUPLICATE_INPUT) :
logger.info(
f"For inference T1 block weight and the T2 block are set to "
f"the weights learn during training"
)
else:
logger.info(f"For inference T1 block weight was set to : {self.weight_factor.numpy()[0]} and the T2 block was set to: {self.weight_factor.numpy()[1]}")
logger.info(
f"For inference T1 block weight was set to: "
f"{self.weight_factor.numpy()[0]} and the T2 block was set to: "
f"{self.weight_factor.numpy()[1]}")

def _standarized_img(self,orig_data,orig_zoom,modalitie):
def _standarized_img(self, orig_data, orig_zoom, modality):
if self.plane == "sagittal":
orig_data = transform_axial2sagittal(orig_data)
self.zoom = orig_zoom[::-1][:2]
logger.info("Loading {} sagittal with input voxelsize {}".format(modalitie,self.zoom))
logger.info(
f"Loading {modality} sagittal with input voxelsize {self.zoom}"
)

elif self.plane == "coronal":
orig_data = transform_axial2coronal(orig_data)
self.zoom = orig_zoom[1:]
logger.info("Loading {} coronal with input voxelsize {}".format(modalitie,self.zoom))
logger.info(
f"Loading {modality} coronal with input voxelsize {self.zoom}"
)

else:
self.zoom = orig_zoom[:2]
logger.info("Loading {} axial with input voxelsize {}".format(modalitie,self.zoom))
logger.info(
f"Loading {modality} axial with input voxelsize {self.zoom}"
)

# Create thick slices
orig_thick = get_thick_slices(orig_data, self.slice_thickness)

return orig_thick
def _get_scale_factor(self):

def _get_scale_factor(self) -> npt.NDArray[float]:
"""
Get scaling factor to match original resolution of input image to
final resolution of FastSurfer base network. Input resolution is
Expand All @@ -108,14 +135,18 @@ def _get_scale_factor(self):

return scale

def __getitem__(self, index):
def __getitem__(self, index: int) -> dict[str, torch.Tensor | np.ndarray]:
img = self.images[index]

scale_factor = self._get_scale_factor()
if self.transforms is not None:
img = self.transforms(img)

return {'image': img, 'scale_factor': scale_factor,'weight_factor' : self.weight_factor}
return {
"image": img,
"scale_factor": scale_factor,
"weight_factor": self.weight_factor,
}

def __len__(self):
return self.count
Expand Down
2 changes: 0 additions & 2 deletions HypVINN/models/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
import numpy as np




class HypVINN(FastSurferCNNBase):
"""
Construct HypVINN object.
Expand Down
Loading

0 comments on commit fa1c678

Please sign in to comment.