Skip to content

Commit

Permalink
Merge pull request #2 from taha-abdullah/hypvinn
Browse files Browse the repository at this point in the history
Documentation for HypVINN
  • Loading branch information
santiestrada32 committed Jun 5, 2024
2 parents 36380ee + 9a6822b commit 016a5ba
Show file tree
Hide file tree
Showing 15 changed files with 1,005 additions and 93 deletions.
2 changes: 1 addition & 1 deletion FastSurferCNN/utils/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
import numpy as np
import pandas
import torch
from matplotlib.cm import get_cmap
from matplotlib.pyplot import get_cmap
from matplotlib.colors import Colormap
from numpy import typing as npt

Expand Down
9 changes: 8 additions & 1 deletion HypVINN/config/hypvinn.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,14 @@


def get_cfg_hypvinn():
"""Get a yacs CfgNode object with default values for my_project."""
"""
Get a yacs CfgNode object with default values for HypVINN project.
Returns
-------
_C : yacs.config.CfgNode
A clone of the default configuration node for the HypVINN project.
"""
# Return a clone so that the defaults will not be altered
# This is for the "local variable" use pattern
return _C.clone()
143 changes: 109 additions & 34 deletions HypVINN/data_loader/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,23 @@
##


def calculate_flip_orientation(iornt, base_ornt):
def calculate_flip_orientation(iornt: np.ndarray, base_ornt: np.ndarray) -> np.ndarray:
"""
Compute the flip orientation transform.
ornt[N, 1] is flip of axis N, where 1 means no flip and -1 means flip.
Parameters
----------
iornt
base_ornt
iornt : np.ndarray
Initial orientation.
base_ornt : np.ndarray
Base orientation.
Returns
-------
new_iornt : np.ndarray
New orientation.
"""
new_iornt=iornt.copy()

Expand All @@ -56,21 +60,20 @@ def calculate_flip_orientation(iornt, base_ornt):

def reorient_img(img, ref_img):
"""
Function to reorient a Nibabel image based on the orientation of a reference nibabel image
The orientation transform. ornt[N,1]` is flip of axis N of the array implied by `shape`, where 1 means no flip and -1 means flip.
For example, if ``N==0 and ornt[0,1] == -1, and there’s an array arr of shape shape, the flip would correspond to the effect of
np.flipud(arr). ornt[:,0] is the transpose that needs to be done to the implied array, as in arr.transpose(ornt[:,0])
Reorient a Nibabel image based on the orientation of a reference nibabel image.
Parameters
----------
img: nibabel Image to reorient
base_img: referece orientation nibabel image
img : nibabel.Nifti1Image
Nibabel Image to reorient.
ref_img : nibabel.Nifti1Image
Reference orientation nibabel image.
Returns
-------
img : nibabel.Nifti1Image
Reoriented image.
"""

ref_ornt =nib.io_orientation(ref_img.affine)
iornt=nib.io_orientation(img.affine)

Expand All @@ -86,13 +89,26 @@ def reorient_img(img, ref_img):
return img


def transform_axial2coronal(vol, axial2coronal=True):
def transform_axial2coronal(vol: np.ndarray, axial2coronal: bool = True) -> np.ndarray:
"""
Function to transform volume into coronal axis and back
:param np.ndarray vol: image volume to transform
:param bool axial2coronal: transform from axial to coronal = True (default),
transform from coronal to axial = False
:return:
Transforms a volume into the coronal axis and back.
This function is used to transform a volume into the coronal axis and back. The transformation is done by moving
the axes of the volume. If the `axial2coronal` parameter is set to True, the function will transform from axial to
coronal. If it is set to False, the function will transform from coronal to axial.
Parameters
----------
vol : np.ndarray
The image volume to transform.
axial2coronal : bool, optional
A flag to determine the direction of the transformation. If True, transform from axial to coronal. If False,
transform from coronal to axial. (Default: True).
Returns
-------
np.ndarray
The transformed volume.
"""
# TODO check compatibility with axis transform from CerebNet
if axial2coronal:
Expand All @@ -101,13 +117,26 @@ def transform_axial2coronal(vol, axial2coronal=True):
return np.moveaxis(vol, [0, 1, 2], [0, 2, 1])


def transform_axial2sagittal(vol, axial2sagittal=True):
def transform_axial2sagittal(vol: np.ndarray, axial2sagittal: bool = True) -> np.ndarray:
"""
Function to transform volume into Sagittal axis and back
:param np.ndarray vol: image volume to transform
:param bool coronal2sagittal: transform from coronal to sagittal = True (default),
transform from sagittal to coronal = False
:return:
Transforms a volume into the sagittal axis and back.
This function is used to transform a volume into the sagittal axis and back. The transformation is done by moving
the axes of the volume. If the `axial2sagittal` parameter is set to True, the function will transform from axial to
sagittal. If it is set to False, the function will transform from sagittal to axial.
Parameters
----------
vol : np.ndarray
The image volume to transform.
axial2sagittal : bool, default=True
A flag to determine the direction of the transformation. If True, transform from axial to sagittal. If False,
transform from sagittal to axial. (Default: True).
Returns
-------
np.ndarray
The transformed volume.
"""
# TODO check compatibility with axis transform from CerebNet
if axial2sagittal:
Expand All @@ -116,7 +145,22 @@ def transform_axial2sagittal(vol, axial2sagittal=True):
return np.moveaxis(vol, [0, 1, 2], [1, 2, 0])


def rescale_image(img_data):
def rescale_image(img_data: np.ndarray) -> np.ndarray:
"""
Rescale the image data to the range [0, 255].
This function rescales the input image data to the range [0, 255].
Parameters
----------
img_data : np.ndarray
The image data to rescale.
Returns
-------
np.ndarray
The rescaled image data.
"""
# Conform intensities
# TODO move function into FastSurferCNN, same: CerebNet.datasets.utils.rescale_image
src_min, scale = getscale(img_data, 0, 255)
Expand All @@ -131,7 +175,19 @@ def rescale_image(img_data):

def hypo_map_label2subseg(mapped_subseg: npt.NDArray[int]) -> npt.NDArray[int]:
"""
Function to perform look-up table mapping from label space to subseg space
Perform look-up table mapping from label space to subseg space.
This function is used to perform a look-up table mapping from label space to subseg space.
Parameters
----------
mapped_subseg : npt.NDArray[int]
The input array in label space to be mapped to subseg space.
Returns
-------
npt.NDArray[int]
The mapped array in subseg space.
"""
# TODO can this function be replaced by a Mapper and a mapping file?
labels, _ = hyposubseg_labels
Expand All @@ -143,13 +199,23 @@ def hypo_map_label2subseg(mapped_subseg: npt.NDArray[int]) -> npt.NDArray[int]:


def hypo_map_prediction_sagittal2full(
prediction_sag: npt.NDArray[int],
prediction_sag: npt.NDArray[int],
) -> npt.NDArray[int]:
"""
Function to remap the prediction on the sagittal network to full label space used by coronal and axial networks
:param prediction_sag: sagittal prediction (labels)
:param lbl_type: type of label
:return: Remapped prediction
Remap the prediction on the sagittal network to full label space.
This function is used to remap the prediction on the sagittal network to the full label space used by the coronal
and axial networks.
Parameters
----------
prediction_sag : npt.NDArray[int]
The sagittal prediction in label space to be remapped to full label space.
Returns
-------
npt.NDArray[int]
The remapped prediction in full label space.
"""
# TODO can this function be replaced by a Mapper and a mapping file?

Expand All @@ -163,15 +229,24 @@ def hypo_map_subseg_2_fsseg(
reverse: bool = False,
) -> npt.NDArray[int]:
"""
Function to remap HypVINN internal labels to FastSurfer Labels and viceversa
Remap HypVINN internal labels to FastSurfer Labels and vice versa.
This function is used to remap HypVINN internal labels to FastSurfer Labels and vice versa. If the `reverse`
parameter is set to False, the function will map HypVINN labels to FastSurfer labels. If it is set to True,
the function will map FastSurfer labels to HypVINN labels.
Parameters
----------
subseg
reverse
subseg : npt.NDArray[int]
The input array with HypVINN or FastSurfer labels to be remapped.
reverse : bool, optional
A flag to determine the direction of the remapping. If False, remap HypVINN labels to FastSurfer labels.
If True, remap FastSurfer labels to HypVINN labels. (Default: False).
Returns
-------
npt.NDArray[int]
The remapped array with FastSurfer or HypVINN labels.
"""
# TODO can this function be replaced by a Mapper and a mapping file?

Expand Down
104 changes: 89 additions & 15 deletions HypVINN/data_loader/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,27 @@
# Operator to load imaged for inference
class HypoVINN_dataset(Dataset):
"""
Class to load MRI-Image and process it to correct format for HypVINN network inference
The HypVINN Dataset passed during Inference the input images,the scale factor for the VINN layer and a weight factor (wT1,wT2)
The Weight factor determines the running mode of the HypVINN model
if wT1 =1 and wT2 =0. The HypVINN model will only allow the flow of the T1 information (mode = t1)
if wT1 =0 and wT2 =1. The HypVINN model will only allow the flow of the T2 information (mode = t2)
if wT1 !=1 and wT2 !=1. The HypVINN model will automatically weigh the T1 information and the T2 information based on the learned modality weights (mode = t1t2)
Class to load MRI-Image and process it to correct format for HypVINN network inference.
The HypVINN Dataset passed during Inference the input images,the scale factor for the VINN layer and a weight factor
(wT1,wT2).
The Weight factor determines the running mode of the HypVINN model.
if wT1 =1 and wT2 =0. The HypVINN model will only allow the flow of the T1 information (mode = t1).
if wT1 =0 and wT2 =1. The HypVINN model will only allow the flow of the T2 information (mode = t2).
if wT1 !=1 and wT2 !=1. The HypVINN model will automatically weigh the T1 information and the T2 information based
on the learned modality weights (mode = t1t2).
Methods
-------
_standarized_img(orig_data: np.ndarray, orig_zoom: npt.NDArray[float], modality: np.ndarray) -> np.ndarray
Standardize the image based on the original data, original zoom, and modality.
_get_scale_factor() -> npt.NDArray[float]
Get the scaling factor to match the original resolution of the input image to the final resolution of the
FastSurfer base network.
__getitem__(index: int) -> dict[str, torch.Tensor | np.ndarray]
Retrieve the image, scale factor, and weight factor for a given index.
__len__()
Return the number of images in the dataset.
"""
def __init__(
self,
Expand All @@ -46,6 +61,25 @@ def __init__(
mode: ModalityMode = "t1t2",
transforms=None,
):
"""
Initialize the HypoVINN Dataset.
Parameters
----------
subject_name : str
The name of the subject.
modalities : ModalityDict
The modalities of the subject.
orig_zoom : npt.NDArray[float]
The original zoom of the subject.
cfg : CfgNode
The configuration object.
mode : ModalityMode, default="t1t2"
The running mode of the HypVINN model. (Default: "t1t2").
transforms : Callable, optional
The transformations to apply to the images. (Default: None).
"""
self.subject_name = subject_name
self.plane = cfg.DATA.PLANE
#Inference Mode
Expand Down Expand Up @@ -84,7 +118,8 @@ def __init__(
f"model"
)

if (cfg.MODEL.MULTI_AUTO_W or cfg.MODEL.MULTI_AUTO_W_CHANNELS) and (self.mode == 't1t2' or cfg.MODEL.DUPLICATE_INPUT) :
if ((cfg.MODEL.MULTI_AUTO_W or cfg.MODEL.MULTI_AUTO_W_CHANNELS) and
(self.mode == 't1t2' or cfg.MODEL.DUPLICATE_INPUT)) :
logger.info(
f"For inference T1 block weight and the T2 block are set to "
f"the weights learn during training"
Expand All @@ -95,7 +130,25 @@ def __init__(
f"{self.weight_factor.numpy()[0]} and the T2 block was set to: "
f"{self.weight_factor.numpy()[1]}")

def _standarized_img(self, orig_data, orig_zoom, modality):
def _standarized_img(self, orig_data: np.ndarray, orig_zoom: npt.NDArray[float],
modality: np.ndarray) -> np.ndarray:
"""
Standardize the image based on the original data, original zoom, and modality.
Parameters
----------
orig_data : np.ndarray
The original data of the image.
orig_zoom : npt.NDArray[float]
The original zoom of the image.
modality : np.ndarray
The modality of the image.
Returns
-------
orig_thick : np.ndarray
The standardized image.
"""
if self.plane == "sagittal":
orig_data = transform_axial2sagittal(orig_data)
self.zoom = orig_zoom[::-1][:2]
Expand Down Expand Up @@ -123,19 +176,40 @@ def _standarized_img(self, orig_data, orig_zoom, modality):

def _get_scale_factor(self) -> npt.NDArray[float]:
"""
Get scaling factor to match original resolution of input image to
final resolution of FastSurfer base network. Input resolution is
taken from voxel size in image header.
ToDO: This needs to be updated based on the plane we are looking at in case we
are dealing with non-isotropic images as inputs.
:param img_zoom:
:return np.ndarray(float32): scale factor along x and y dimension
Get the scaling factor to match the original resolution of the input image to
the final resolution of the FastSurfer base network. The input resolution is
taken from the voxel size in the image header.
Returns
-------
scale : npt.NDArray[float]
The scaling factor along the x and y dimensions. This is a numpy array of float values.
"""
# TODO: This needs to be updated based on the plane we are looking at in case we
# are dealing with non-isotropic images as inputs.

scale = self.base_res / np.asarray(self.zoom)

return scale

def __getitem__(self, index: int) -> dict[str, torch.Tensor | np.ndarray]:
"""
Retrieve the image, scale factor, and weight factor for a given index.
This method retrieves the image at the given index from the images attribute, calculates the scale factor,
applies any transformations to the image if they are defined, and returns a dictionary containing the image,
scale factor, and weight factor.
Parameters
----------
index : int
The index of the image to retrieve.
Returns
-------
dict[str, torch.Tensor | np.ndarray]
A dictionary containing the image, scale factor, and weight factor.
"""
img = self.images[index]

scale_factor = self._get_scale_factor()
Expand Down
Loading

0 comments on commit 016a5ba

Please sign in to comment.