Skip to content

Commit

Permalink
Applying changes requested in code review
Browse files Browse the repository at this point in the history
  • Loading branch information
taha-abdullah authored and dkuegler committed Jun 20, 2024
1 parent 706cb46 commit 78b7a9a
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 17 deletions.
9 changes: 5 additions & 4 deletions HypVINN/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,10 +294,10 @@ def get_device(self):
@torch.no_grad()
def eval(self, val_loader: DataLoader, pred_prob: torch.Tensor, out_scale: float = None) -> torch.Tensor:
"""
Evaluate the model on a validation set.
Evaluate the model on a HypVINN dataset.
This method runs the model in evaluation mode on a validation set. It iterates over the validation set,
computes the model's predictions, and updates the prediction probabilities based on the plane of the data.
This method runs the model in evaluation mode on a HypVINN Dataset. It iterates over the given dataset and
computes the model's predictions.
Parameters
----------
Expand Down Expand Up @@ -355,7 +355,8 @@ def run(
"""
Run the inference process on a single subject.
This method sets up a DataLoader for the subject, runs the model in evaluation mode on the subject's data,
This method sets up the HypVINN DataLoader for the subject, runs the model in evaluation mode on the subject's
data,
and returns the updated prediction probabilities.
Parameters
Expand Down
7 changes: 3 additions & 4 deletions HypVINN/run_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# IMPORTS
from typing import TYPE_CHECKING, Optional, cast, Literal
import argparse
Expand Down Expand Up @@ -183,7 +182,7 @@ def main(
device: str = "auto",
viewagg_device: str = "auto",
) -> int | str:
"""
f"""
Main function of the hypothalamus segmentation module.
Parameters
Expand All @@ -193,7 +192,7 @@ def main(
t2 : Path, optional
The path to the T2 image to process.
orig_name : Path, optional
The original name of the input image.
The path to the T1 image to process or FastSurfer orig image.
sid : str
The subject ID.
ckpt_ax : Path
Expand All @@ -209,7 +208,7 @@ def main(
cfg_sag : Path
The path to the sagittal configuration file.
hypo_segfile : str, default="{HYPVINN_SEG_NAME}"
The name of the hypothalamus segmentation file. Default is HYPVINN_SEG_NAME.
The name of the hypothalamus segmentation file. Default is {HYPVINN_SEG_NAME}.
allow_root : bool, default=False
Whether to allow running as root user. Default is False.
qc_snapshots : bool, optional
Expand Down
9 changes: 5 additions & 4 deletions HypVINN/utils/img_processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def save_segmentation(
prediction: np.ndarray,
orig_path: Path,
ras_affine: npt.NDArray[float],
ras_header: nib.nifti1.Nifti1Header,
ras_header: nib.nifti1.Nifti1Header | nib.nifti2.Nifti2Header | nib.freesurfer.mghformat.MGHHeader,
subject_dir: Path,
seg_file: Path,
save_mask: bool = False,
Expand Down Expand Up @@ -118,7 +118,7 @@ def save_logits(
logits: npt.NDArray[float],
orig_path: Path,
ras_affine: npt.NDArray[float],
ras_header: nib.nifti1.Nifti1Header,
ras_header: nib.nifti1.Nifti1Header | nib.nifti2.Nifti2Header | nib.freesurfer.mghformat.MGHHeader,
save_dir: Path,
mode: str,
) -> Path:
Expand Down Expand Up @@ -171,7 +171,7 @@ def save_logits(
def get_clean_mask(segmentation: np.ndarray, optic=False) \
-> tuple[np.ndarray, np.ndarray, bool]:
"""
Get a clean mask by removing not connected components.
Get a clean mask by removing non-connected components from a dilated mask.
This function takes a segmentation mask and an optional boolean flag indicating whether to consider optic labels.
It removes not connected components from the segmentation mask and returns the cleaned segmentation mask, the
Expand Down Expand Up @@ -244,7 +244,8 @@ def get_clean_mask(segmentation: np.ndarray, optic=False) \

def get_clean_labels(segmentation: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
"""
Function to find the largest connected component of the segmentation.
Get clean labels by removing non-connected components from a dilated mask and any connected component with size
less than 3.
Parameters
----------
Expand Down
13 changes: 8 additions & 5 deletions HypVINN/utils/visualization_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os.path
from pathlib import Path

import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt

from HypVINN.config.hypvinn_files import HYPVINN_LUT
from FastSurferCNN.utils.parser_defaults import FASTSURFER_ROOT

_doc_HYPVINN_LUT = os.path.relpath(HYPVINN_LUT, FASTSURFER_ROOT)


def remove_values_from_list(the_list, val):
Expand All @@ -42,13 +45,13 @@ def remove_values_from_list(the_list, val):

def get_lut(lookup_table_path: Path = HYPVINN_LUT):
f"""
Retrieve a lookup table (LUT) from a file.
Retrieve a color lookup table (LUT) from a file.
This function reads a file and constructs a lookup table (LUT) from it.
Parameters
----------
lookup_table_path: Path, default="{HYPVINN_LUT}"
lookup_table_path: Path, default="{_doc_HYPVINN_LUT}"
The path to the file from which the LUT will be constructed.
Returns
Expand Down Expand Up @@ -77,7 +80,7 @@ def map_hyposeg2label(hyposeg: np.ndarray, lut_file: Path = HYPVINN_LUT):
----------
hyposeg : np.ndarray
The original segmentation map.
lut_file : Path, default="{HYPVINN_LUT}"
lut_file : Path, default="{_doc_HYPVINN_LUT}"
The path to the lookup table file.
Returns
Expand Down Expand Up @@ -255,7 +258,7 @@ def plot_qc_images(
The path to the predicted image.
padd : int, default=45
The padding value for cropping the images and segmentations.
lut_file : Path, default="{HYPVINN_LUT}"
lut_file : Path, default="{_doc_HYPVINN_LUT}"
The path to the lookup table file.
slice_step : int, default=2
The step size for selecting indices from the predicted segmentation.
Expand Down

0 comments on commit 78b7a9a

Please sign in to comment.