
Commit

Merge branch 'dev' into dev
m-reuter committed Jun 19, 2024
2 parents 1a38900 + d4767d9 commit 198a2ff
Showing 37 changed files with 633 additions and 608 deletions.
7 changes: 4 additions & 3 deletions .github/workflows/deploy.yml
@@ -1,9 +1,10 @@
name: deploy-docker

on:
-  release:
-    types:
-      - published
+#  release:
+#    types:
+#      - published
  workflow_dispatch:

jobs:
  deploy-gpu:
23 changes: 21 additions & 2 deletions CerebNet/datasets/load_data.py
@@ -65,7 +65,26 @@ def _process_segm_volumes(

def _load_volumes(self, subject_path, store_talairach=False):
"""
[MISSING].
Loads the original image and cerebellum sub-segmentation from the given subject path.
Also loads the Talairach coordinates if store_talairach is set to True.
Parameters
----------
subject_path : str
The path to the subject's data directory.
store_talairach : bool, default=False
If True, the method will attempt to load the Talairach coordinates.
Returns
-------
orig : np.ndarray
The original image.
cereb_subseg : np.ndarray
The cerebellum sub-segmentation loaded from the subject's data directory.
img_meta_data : dict
Dictionary containing the affine transformation and header from cereb_subseg file.
If store_talairach is True and Talairach coordinates file exists, also contains the
Talairach coordinates.
"""
orig_path = join(subject_path, self.cfg.IMAGE_NAME)
subseg_path = join(subject_path, self.cfg.CEREB_SUBSEG_NAME)
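For orientation, the return contract documented above (orig, cereb_subseg, img_meta_data) could be satisfied by a nibabel-based loader along these lines. This is only a sketch: the file names, the nibabel dependency, and the Talairach handling are assumptions for illustration, not the verified CerebNet code.

from os.path import join, exists
import nibabel as nib
import numpy as np

def load_volumes_sketch(subject_path, image_name, subseg_name, store_talairach=False):
    # Load the original image and the cerebellum sub-segmentation.
    orig_img = nib.load(join(subject_path, image_name))
    subseg_img = nib.load(join(subject_path, subseg_name))
    orig = np.asarray(orig_img.dataobj)
    cereb_subseg = np.asarray(subseg_img.dataobj)

    # Affine and header come from the sub-segmentation file, as documented.
    img_meta_data = {"affine": subseg_img.affine, "header": subseg_img.header}

    if store_talairach:
        # Hypothetical file name and format, only to show the optional branch.
        talairach_path = join(subject_path, "talairach.coords")
        if exists(talairach_path):
            img_meta_data["talairach_coordinates"] = np.loadtxt(talairach_path)

    return orig, cereb_subseg, img_meta_data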
@@ -181,7 +200,7 @@ def load_subject(self, current_subject, store_talairach=False, load_aux_data=Fal
Parameters
----------
current_subject : [MISSING]
current_subject : str
Subject ID.
store_talairach : bool, optional
Whether to store Talairach coordinates. Defaults to False.
12 changes: 6 additions & 6 deletions CerebNet/datasets/utils.py
@@ -634,15 +634,15 @@ def _crop_transform_pad_fn(image, pad_tuples, pad):
Parameters
----------
image : [MISSING]
[MISSING].
pad_tuples : [MISSING]
[MISSING].
image : np.ndarray, torch.Tensor
Input image.
pad_tuples : List[Tuple[int, int]]
List of padding tuples for each axis.
Returns
-------
[MISSING TYPE]
[MISSING Discription].
partial
A partial function to pad the image.
"""
if all(p1 == 0 and p2 == 0 for p1, p2 in pad_tuples):
return None
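To make the documented "partial" return type concrete, here is a minimal numpy-only sketch of the same pattern (signature simplified; the real helper also receives the image and supports more padding behaviour than shown here).

from functools import partial
import numpy as np

def crop_transform_pad_fn_sketch(pad_tuples, pad_value):
    # No-op if no axis actually needs padding.
    if all(p1 == 0 and p2 == 0 for p1, p2 in pad_tuples):
        return None
    # Bind the pad widths and fill value; the caller applies the returned
    # callable to the image later.
    return partial(np.pad, pad_width=pad_tuples, mode="constant", constant_values=pad_value)

img = np.ones((4, 4))
pad_fn = crop_transform_pad_fn_sketch([(1, 1), (2, 0)], 0)
padded = pad_fn(img) if pad_fn is not None else img   # shape (6, 6)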
134 changes: 123 additions & 11 deletions CerebNet/utils/meters.py
@@ -36,6 +36,11 @@ class TestMeter:
def __init__(self, classname_to_ids):
"""
Constructor function.
Parameters
----------
classname_to_ids : dict
Dictionary containing class names and their corresponding ids.
"""
# class_id: class_name
self.classname_to_ids = classname_to_ids
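A minimal usage sketch for the constructor documented above; the class names and ids below are placeholders, not the actual CerebNet label protocol.

from CerebNet.utils.meters import TestMeter

# Placeholder names/ids for illustration only.
classname_to_ids = {"Left_Cerebellum_Cortex": 1, "Right_Cerebellum_Cortex": 2}
meter = TestMeter(classname_to_ids)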
@@ -46,7 +51,20 @@ def __init__(self, classname_to_ids):

def _compute_hd(self, pred_bin, gt_bin):
"""
[MISSING].
Compute the Hausdorff Distance (HD) between the predicted binary segmentation map
and the ground truth binary segmentation map.
Parameters
----------
pred_bin : np.array
Predicted binary segmentation map.
gt_bin : np.array
Ground truth binary segmentation map.
Returns
-------
hd_dict : dict
Dictionary containing the maximum HD and 95th percentile HD.
"""
hd_dict = {}
if np.count_nonzero(pred_bin) == 0:
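Only the empty-prediction guard is visible above. For readers who want the gist of the max HD and HD95 values the docstring promises, here is a hedged sketch based on SciPy distance transforms (one common convention, not necessarily the exact implementation behind this method); it assumes non-empty boolean masks and works in voxel units.

import numpy as np
from scipy.ndimage import binary_erosion, distance_transform_edt

def hd_sketch(pred_bin, gt_bin):
    # Surface voxels = mask minus its erosion (inputs are boolean arrays).
    pred_surf = pred_bin & ~binary_erosion(pred_bin)
    gt_surf = gt_bin & ~binary_erosion(gt_bin)
    # Distance from every voxel to the nearest surface voxel of the other mask.
    dist_to_gt = distance_transform_edt(~gt_surf)
    dist_to_pred = distance_transform_edt(~pred_surf)
    # Symmetric surface distances, pooled over both directions.
    d = np.concatenate([dist_to_gt[pred_surf], dist_to_pred[gt_surf]])
    return {"HD_Max": float(d.max()), "HD95": float(np.percentile(d, 95))}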
@@ -61,14 +79,38 @@ def _get_binray_map(self, lbl_map, class_names):

def _get_binray_map(self, lbl_map, class_names):
"""
[MISSING].
Generate binary map based on the label map and class names.
Parameters
----------
lbl_map : np.array
Label map where each pixel/voxel is assigned a class label.
class_names : list
List of class names to be considered in the binary map.
Returns
-------
bin_map : np.array
Binary map where True marks the given classes and False marks their absence.
"""
bin_map = np.logical_or.reduce(list(map(lambda l: lbl_map == l, class_names)))
return bin_map
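A small worked example of the np.logical_or.reduce line above (label values chosen arbitrarily):

import numpy as np

lbl_map = np.array([[0, 1, 2],
                    [2, 3, 0]])
# Voxels labelled 1 or 2 become True, everything else False.
bin_map = np.logical_or.reduce([lbl_map == l for l in (1, 2)])
# -> [[False,  True,  True],
#     [ True, False, False]]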

def metrics_per_class(self, pred, gt):
"""
[MISSING].
Compute metrics for each class in the predicted and ground truth segmentation maps.
Parameters
----------
pred : np.array
Predicted segmentation map.
gt : np.array
Ground truth segmentation map.
Returns
-------
metrics : dict
Dict containing metrics for each class.
"""
metrics = {"Label": [], "Dice": [], "HD95": [], "HD_Max": [], "VS": []}
for lbl_name, lbl_id in self.classname_to_ids.items():
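The rest of the loop is collapsed in this view. As a hedged sketch of the per-class Dice and volume-similarity arithmetic such a loop typically performs (the formulas are the standard definitions; the helper name is hypothetical):

import numpy as np

def dice_and_vs_sketch(pred_bin, gt_bin):
    vol_pred = np.count_nonzero(pred_bin)
    vol_gt = np.count_nonzero(gt_bin)
    intersection = np.count_nonzero(np.logical_and(pred_bin, gt_bin))
    denom = vol_pred + vol_gt
    if denom == 0:
        return np.nan, np.nan
    dice = 2.0 * intersection / denom            # Dice = 2|A∩B| / (|A|+|B|)
    vs = 1.0 - abs(vol_pred - vol_gt) / denom    # volume similarity
    return dice, vs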
@@ -116,8 +158,27 @@ def __init__(
device=None,
writer=None,
):
""""
"""
Constructor function.
Parameters
----------
cfg : object
Configuration object containing all the configuration parameters.
mode : str
Mode of operation ("Train" or "Val").
global_step : int
The global step count.
total_iter : int, optional
Total number of iterations.
total_epoch : int, optional
Total number of epochs.
class_names : list, optional
List of class names.
device : str, optional
Device to be used for computation.
writer : object, optional
Writer object for tensorboard.
"""
self._cfg = cfg
self.mode = mode.capitalize()
@@ -144,6 +205,15 @@ def update_stats(self, pred, labels, loss_dict=None):
def update_stats(self, pred, labels, loss_dict=None):
"""
Update stats.
Parameters
----------
pred : torch.Tensor
Predicted labels.
labels : torch.Tensor
Ground truth labels.
loss_dict : dict, optional
Dictionary containing loss values.
"""
self.dice_score.update((pred, labels))
if loss_dict is None:
@@ -154,18 +224,34 @@ def write_summary(self, loss_dict):
def write_summary(self, loss_dict):
"""
Write summary.
Parameters
----------
loss_dict : dict
Dictionary containing loss values.
"""
if self.writer is None:
return
for name, loss in loss_dict.items():
self.writer.add_scalar(f"{self.mode}/{name}", loss.item(), self.global_iter)
self.global_iter += 1
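For context, loss_dict is expected to map names to scalar tensors, since loss.item() is called on each value. An illustrative dictionary (keys and values are placeholders):

import torch

# Keys are placeholders; values must be 0-dim tensors so that .item() works.
loss_dict = {"total_loss": torch.tensor(0.42), "dice_loss": torch.tensor(0.21)}
# In Train mode, write_summary would log them as "Train/total_loss" and
# "Train/dice_loss" at the current global_iter, then advance global_iter by one.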

def prediction_visualize(
self, cur_iter, cur_epoch, img_batch, label_batch, pred_batch
):
def prediction_visualize(self, cur_iter, cur_epoch, img_batch, label_batch, pred_batch):
"""
[MISSING].
Visualize prediction results for current iteration and epoch.
Parameters
----------
cur_iter : int
Current iteration number.
cur_epoch : int
Current epoch number.
img_batch : torch.Tensor
Input image batch.
label_batch : torch.Tensor
Ground truth label batch.
pred_batch : torch.Tensor
Predicted label batch.
"""
if self.writer is None:
return
@@ -179,7 +265,14 @@ def log_iter(self, cur_iter, cur_epoch):

def log_iter(self, cur_iter, cur_epoch):
"""
[MISSING].
Log training or validation progress at each iteration.
Parameters
----------
cur_iter : int
The current iteration number.
cur_epoch : int
The current epoch number.
"""
if (cur_iter + 1) % self._cfg.TRAIN.LOG_INTERVAL == 0:
out_losses = {}
@@ -203,15 +296,34 @@ def log_lr(self, lr, step=None):

def log_lr(self, lr, step=None):
"""
[MISSING].
Log learning rate at each step.
Parameters
----------
lr : list
Learning rate at the current step. Expected to be a list where the first
element is the learning rate.
step : int, optional
Current step number. If not provided, the global iteration
number is used.
"""
if step is None:
step = self.global_iter
self.writer.add_scalar("Train/lr", lr[0], step)
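Because lr is documented as a list whose first element is the learning rate, a natural call site pairs log_lr with a PyTorch scheduler, whose get_last_lr() returns exactly such a list. The setup below is illustrative, not the project's training loop.

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

lr_list = scheduler.get_last_lr()          # e.g. [0.01], one entry per param group
# meter.log_lr(lr_list, step=global_step)  # hypothetical call site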

def log_epoch(self, cur_epoch):
"""
[MISSING].
Log mean Dice score and confusion matrix at the end of each epoch.
Parameters
----------
cur_epoch : int
Current epoch number.
Returns
-------
dice_score : float
The mean Dice score for the non-background classes.
"""
dice_score_per_class, confusion_mat = self.dice_score.compute(per_class=True)
dice_score = dice_score_per_class[1:].mean()
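The [1:] slice is what makes the returned score a mean over the non-background classes only; a tiny numeric illustration (values made up):

import torch

dice_score_per_class = torch.tensor([0.99, 0.85, 0.80, 0.90])  # index 0 = background
dice_score = dice_score_per_class[1:].mean()                   # tensor(0.8500)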
22 changes: 10 additions & 12 deletions CerebNet/utils/metrics.py
@@ -68,7 +68,7 @@ def __init__(

def reset(self):
"""
[MISSING].
Reset the state of the object.
"""
self.union = torch.zeros(self.n_classes, self.n_classes)
self.intersection = torch.zeros(self.n_classes, self.n_classes)
@@ -84,7 +84,7 @@ def _check_output_type(self, output):
"""
if not (isinstance(output, tuple)):
raise TypeError(
"Output should a tuple consist of of torch.Tensors, but given {}".format(
"Output should be a tuple consisting of torch.Tensors, but given {}".format(
type(output)
)
)
@@ -93,15 +93,13 @@ def _update_union_intersection(self, batch_output, labels_batch):
"""
Update the union and intersection matrices based on batch predictions and labels.
[MISSING DESCRIPTION]
Parameters:
-----------
batch_output : [MISSING TYPE]
[MISSING DESCRIPTION]
batch_output : torch.Tensor
Batch predictions from the model.
labels_batch : [MISSING TYPE]
[MISSING DESCRIPTION]
labels_batch : np.ndarray or torch.Tensor
Batch labels from the dataset.
"""
# self.union.to(batch_output.device)
# self.intersection.to(batch_output.device)
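The accumulation body is collapsed in this view. As an assumption about the bookkeeping, here is one plausible layout in which cell (i, j) tracks ground-truth class i against predicted class j; this is a sketch, not the verified implementation.

import torch

def update_union_intersection_sketch(intersection, union, pred, labels, n_classes):
    # pred and labels are integer label maps of the same shape;
    # intersection and union are (n_classes, n_classes) float tensors.
    for i in range(n_classes):
        gt_i = labels == i
        for j in range(n_classes):
            pred_j = pred == j
            intersection[i, j] += torch.sum(gt_i & pred_j)
            union[i, j] += torch.sum(gt_i) + torch.sum(pred_j)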
@@ -120,8 +118,8 @@ def update(self, output):
Parameters
----------
output : [MISSING]
[MISSING DESCRIPTION].
output : tuple of torch.Tensor
Tuple of predictions and labels.
"""
self._check_output_type(output)

@@ -144,7 +142,7 @@ def compute(self, per_class=False, class_idxs=None):

def compute(self, per_class=False, class_idxs=None):
"""
[MISSING].
Compute the Dice score.
"""
dice_cm_mat = self._dice_confusion_matrix(class_idxs)
dice_score_per_class = dice_cm_mat.diagonal()
@@ -156,7 +154,7 @@ def _dice_confusion_matrix(self, class_idxs):

def _dice_confusion_matrix(self, class_idxs):
"""
[MISSING].
Compute the Dice score confusion matrix.
"""
dice_intersection = self.intersection.cpu().numpy()
dice_union = self.union.cpu().numpy()
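Continuing from the two arrays above, a Dice confusion matrix is conventionally 2·intersection / union, with the per-class Dice on the diagonal. The sketch below follows the accumulation layout assumed earlier; the zero-guarding is an added assumption.

import numpy as np

def dice_confusion_sketch(dice_intersection, dice_union):
    with np.errstate(divide="ignore", invalid="ignore"):
        dice_cm = 2.0 * dice_intersection / dice_union
    # Classes that never occur yield 0 instead of NaN/inf.
    return np.nan_to_num(dice_cm)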