Skip to content

Commit

Permalink
update run_fastsurfer.sh for hypvinn
Browse files Browse the repository at this point in the history
formatting
  • Loading branch information
dkuegler committed Apr 18, 2024
1 parent f71d2fe commit 7565c68
Show file tree
Hide file tree
Showing 13 changed files with 730 additions and 435 deletions.
15 changes: 6 additions & 9 deletions HypVINN/config/hypvinn_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,15 @@
# limitations under the License.

# IMPORTS
import os
from FastSurferCNN.utils.checkpoint import FASTSURFER_ROOT

from FastSurferCNN.utils.checkpoint import (
FASTSURFER_ROOT)

HYPVINN_LUT = FASTSURFER_ROOT / "HypVINN/config/HypVINN_ColorLUT.txt"

HYPVINN_LUT = os.path.join(FASTSURFER_ROOT,'HypVINN','config','HypVINN_ColorLUT.txt')
HYPVINN_STATS_NAME = "hypothalamus.HypVINN.stats"

HYPVINN_STATS_NAME = 'hypothalamus.HypVINN.stats'
HYPVINN_MASK_NAME = "hypothalamus_mask.HypVINN.nii.gz"

HYPVINN_MASK_NAME = 'hypothalamus_mask.HypVINN.nii.gz'
HYPVINN_SEG_NAME = "hypothalamus.HypVINN.nii.gz"

HYPVINN_SEG_NAME = 'hypothalamus.HypVINN.nii.gz'

HYPVINN_QC_IMAGE_NAME = 'hypothalamus.HypVINN_qc_screenshoot.png'
HYPVINN_QC_IMAGE_NAME = "hypothalamus.HypVINN_qc_screenshoot.png"
19 changes: 12 additions & 7 deletions HypVINN/config/hypvinn_global_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Literal

import numpy as np

Plane = Literal["axial", "coronal", "sagittal"]


HYPVINN_CLASS_NAMES = {
0: "Background",

Expand Down Expand Up @@ -57,7 +63,7 @@
}

FS_CLASS_NAMES = {
"Background" : 0,
"Background": 0,

"R-N.opticus": 961,
"L-N.opticus": 962,
Expand Down Expand Up @@ -88,12 +94,12 @@
"L-Globus-pallidus": 986,
}

planes = ("axial", "coronal", "sagittal")


hyposubseg_labels = (np.array(list(HYPVINN_CLASS_NAMES.keys())),
np.array([0, 1, 3, 4, 7, 9, 10,
11, 14, 16, 17, 13, 122,
226, 227, 228, 229]))
hyposubseg_labels = (
np.array(list(HYPVINN_CLASS_NAMES.keys())),
np.array([0, 1, 3, 4, 7, 9, 10, 11, 14, 16, 17, 13, 122, 226, 227, 228, 229]),
)

SAG2FULL_MAP = {
# lbl: sag_lbl_index
Expand Down Expand Up @@ -125,4 +131,3 @@
129: 16,
229: 16
}

78 changes: 54 additions & 24 deletions HypVINN/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from time import time
from typing import Optional

import torch
import numpy as np
import time
from tqdm import tqdm

from torch.utils.data import DataLoader
from torchvision import transforms
from HypVINN.models.networks import build_model

import FastSurferCNN.utils.logging as logging
from FastSurferCNN.utils.common import find_device
from FastSurferCNN.data_loader.augmentation import ToTensorTest, ZeroPad2DTest
from HypVINN.models.networks import build_model
from HypVINN.data_loader.data_utils import hypo_map_prediction_sagittal2full
from HypVINN.data_loader.dataset import HypoVINN_dataset
import FastSurferCNN.utils.logging as logging
from FastSurferCNN.utils.common import find_device
from HypVINN.run_prediction import ModalityMode

logger = logging.get_logger(__name__)


class Inference:
def __init__(self, cfg,args):
def __init__(self, cfg, args):

self._threads = getattr(args, "threads", 1)
torch.set_num_threads(self._threads)
Expand Down Expand Up @@ -66,9 +70,10 @@ def __init__(self, cfg,args):
self.model = self.setup_model(cfg)
self.model_name = self.cfg.MODEL.MODEL_NAME



def setup_model(self, cfg=None):
def setup_model(
self,
cfg: Optional["yacs.config.CfgNode"] = None,
) -> torch.nn.Module:
if cfg is not None:
self.cfg = cfg

Expand Down Expand Up @@ -124,16 +129,15 @@ def get_device(self):

#TODO check is possible to modify to CerebNet inference mode from RAS directly to LIA (CerebNet.Inference._predict_single_subject)
@torch.no_grad()
def eval(self, val_loader, pred_prob, out_scale=None):
def eval(self, val_loader: DataLoader, pred_prob: torch.Tensor, out_scale=None):
self.model.eval()

start_index = 0
for batch_idx, batch in tqdm(enumerate(val_loader),total=len(val_loader)):

images, scale_factors,weight_factors = (batch['image'].to(self.device),
batch['scale_factor'].to(self.device),
batch['weight_factor'].to(self.device))
for batch_idx, batch in tqdm(enumerate(val_loader), total=len(val_loader)):

images = batch["image"].to(self.device)
scale_factors = batch["scale_factor"].to(self.device)
weight_factors = batch["weight_factor"].to(self.device)

pred = self.model(images, scale_factors, weight_factors, out_scale)

Expand All @@ -156,18 +160,44 @@ def eval(self, val_loader, pred_prob, out_scale=None):

return pred_prob

def run(self, subject_name, modalities, orig_zoom, pred_prob, out_res=None,mode='multi'):
def run(
self,
subject_name: str,
modalities,
orig_zoom,
pred_prob,
out_res=None,
mode: ModalityMode = "t1t2",
):
# Set up DataLoader
test_dataset = HypoVINN_dataset(subject_name, modalities, orig_zoom, self.cfg, mode = mode,
transforms=transforms.Compose([ZeroPad2DTest((self.cfg.DATA.PADDED_SIZE, self.cfg.DATA.PADDED_SIZE)), ToTensorTest()]))

test_data_loader = DataLoader(dataset=test_dataset, shuffle=False,
batch_size=self.cfg.TEST.BATCH_SIZE)
test_dataset = HypoVINN_dataset(
subject_name,
modalities,
orig_zoom,
self.cfg,
mode=mode,
transforms=transforms.Compose(
[
ZeroPad2DTest(
(self.cfg.DATA.PADDED_SIZE, self.cfg.DATA.PADDED_SIZE),
),
ToTensorTest(),
],
),
)

test_data_loader = DataLoader(
dataset=test_dataset,
shuffle=False,
batch_size=self.cfg.TEST.BATCH_SIZE,
)

# Run evaluation
start = time.time()
start = time()
pred_prob = self.eval(test_data_loader, pred_prob, out_scale=out_res)
logger.info("{} Inference on {} finished in {:0.4f} seconds".format(self.cfg.DATA.PLANE,subject_name, time.time()-start))
logger.info(
f"{self.cfg.DATA.PLANE} Inference on {subject_name} finished in "
f"{time()-start:0.4f} seconds"
)

return pred_prob

17 changes: 5 additions & 12 deletions HypVINN/models/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,6 @@ def __init__(self, params, padded_size=256):
nn.init.constant_(m.bias, 0)

def forward(self, x, scale_factor, weight_factor, scale_factor_out=None):
"""
Computational graph
:param tensor x: input image
:return tensor: prediction logits
"""
# Weight factor [wT1,wT2] has 3 stages [1,0],[0.5,0.5],[0,1],
#if the weight factor is [0.5,0.5] the automatically weights (s_weights) are passed
#If there is a 1 in the comparison the automatically weights will be replace by the first weight_factors pass
Expand Down Expand Up @@ -161,13 +156,12 @@ def forward(self, x, scale_factor, weight_factor, scale_factor_out=None):
return logits



_MODELS = {
"HypVinn": HypVINN,
}


def build_model(cfg):
def build_model(cfg) -> HypVINN:
"""
Build requested model.
Expand All @@ -181,9 +175,8 @@ def build_model(cfg):
model
Object of the initialized model.
"""
assert (
cfg.MODEL.MODEL_NAME in _MODELS.keys()
), f"Model {cfg.MODEL.MODEL_NAME} not supported"
if cfg.MODEL.MODEL_NAME not in _MODELS:
raise AssertionError(f"Model {cfg.MODEL.MODEL_NAME} not supported")
params = {k.lower(): v for k, v in dict(cfg.MODEL).items()}
model = _MODELS[cfg.MODEL.MODEL_NAME](params, padded_size=cfg.DATA.PADDED_SIZE)
return model
model_type = _MODELS[cfg.MODEL.MODEL_NAME]
return model_type(params, padded_size=cfg.DATA.PADDED_SIZE)
Loading

0 comments on commit 7565c68

Please sign in to comment.