Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Added ability to run segmentation inference module in the test data w…
Browse files Browse the repository at this point in the history
…ithout or partial ground truth files. (#465)

Inference should be flexible enough to also work on datasets where we do not have segmentations yet, or where only some of the structures are segmented. This will prove especially useful for partners like UCLH. Done by adding an `allow_incomplete_labels` flag to the relevant methods.
  • Loading branch information
asantamariapang authored Jul 5, 2021
1 parent cab68cc commit d6d67bd
Show file tree
Hide file tree
Showing 20 changed files with 638 additions and 171 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ wheels/
.installed.cfg
*.egg
MANIFEST
packages-microsoft-prod.deb

# PyInstaller
# Usually these files are written by a python script from a template
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ created.
## Upcoming

### Added
- ([#465](https://github.com/microsoft/InnerEye-DeepLearning/pull/465/)) Adding ability to run the segmentation inference
module on test data with missing or partial ground truth files.
- ([#502](https://github.com/microsoft/InnerEye-DeepLearning/pull/502)) More flags for fine control of when to run inference.
- ([#492](https://github.com/microsoft/InnerEye-DeepLearning/pull/492)) Adding capability for regression tests for test
jobs that run in AzureML.
Expand Down
8 changes: 8 additions & 0 deletions InnerEye/ML/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,17 +474,25 @@ class SegmentationModelBase(ModelConfigBase):
is_plotting_enabled: bool = param.Boolean(True, doc="If true, various overview plots with results are generated "
"during model evaluation. Set to False if you see "
"non-deterministic pull request build failures.")

show_patch_sampling: int = param.Integer(1, bounds=(0, None),
doc="Number of patients from the training set for which the effect of"
"patch sampling will be shown. Nifti images and thumbnails for each"
"of the first N subjects in the training set will be "
"written to the outputs folder.")

#: If true an error is raised in InnerEye.ML.utils.io_util.load_labels_from_dataset_source if the labels are not
#: mutually exclusive. Some loss functions (e.g. SoftDice) may produce results on overlapping labels, but others (e.g.
#: FocalLoss) will fail with a cryptic error message. Set to false if you are sure that you want to use labels that
#: are not mutually exclusive.
check_exclusive: bool = param.Boolean(True, doc="Raise an error if the segmentation labels are not mutually exclusive.")

allow_incomplete_labels: bool = param.Boolean(
default=False,
doc="If False, the default, then test patient data must include all of the ground truth labels. If true then "
"some test patient data with missing ground truth data is allowed and will be reflected in the patient "
"counts in the metrics and report.")

def __init__(self, center_size: Optional[TupleInt3] = None,
inference_stride_size: Optional[TupleInt3] = None,
min_l_rate: float = 0,
Expand Down
51 changes: 33 additions & 18 deletions InnerEye/ML/dataset/full_image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def __init__(self, args: SegmentationModelBase, data_frame: pd.DataFrame,
full_image_sample_transforms: Optional[Compose3D[Sample]] = None):
super().__init__(args, data_frame)
self.full_image_sample_transforms = full_image_sample_transforms

self.allow_incomplete_labels = args.allow_incomplete_labels
# Check base_path
assert self.args.local_dataset is not None
if not self.args.local_dataset.is_dir():
Expand Down Expand Up @@ -250,7 +250,8 @@ def _extension_from_df_file_paths(file_paths: List[str]) -> str:
def get_samples_at_index(self, index: int) -> List[Sample]:
# load the channels into memory
ds = self.dataset_sources[self.dataset_indices[index]]
samples = [io_util.load_images_from_dataset_source(dataset_source=ds, check_exclusive=self.args.check_exclusive)] # type: ignore
samples = [io_util.load_images_from_dataset_source(dataset_source=ds,
check_exclusive=self.args.check_exclusive)] # type: ignore
return [Compose3D.apply(self.full_image_sample_transforms, x) for x in samples]

def _load_dataset_sources(self) -> Dict[str, PatientDatasetSource]:
Expand All @@ -259,34 +260,40 @@ def _load_dataset_sources(self) -> Dict[str, PatientDatasetSource]:
local_dataset_root_folder=self.args.local_dataset,
image_channels=self.args.image_channels,
ground_truth_channels=self.args.ground_truth_ids,
mask_channel=self.args.mask_id
)
mask_channel=self.args.mask_id,
allow_incomplete_labels=self.allow_incomplete_labels)


def convert_channels_to_file_paths(channels: List[str],
rows: pd.DataFrame,
local_dataset_root_folder: Path,
patient_id: str) -> Tuple[List[Path], str]:
patient_id: str,
allow_incomplete_labels: bool = False) -> Tuple[List[Optional[Path]], str]:
"""
Returns: 1) The full path for files specified in the training, validation and testing datasets, and
2) Missing channels or missing files.
Returns: 1) A list of file paths for the channels specified in the training, validation and testing datasets, and
2) a string describing missing channels, missing files, and channels with more than one entry per patient.
:param channels: channel type defined in the configuration file
:param rows: Input Pandas dataframe object containing subjectIds, path of local dataset, channel information
:param local_dataset_root_folder: Root directory which points to the local dataset
:param patient_id: string which contains subject identifier
:param allow_incomplete_labels: boolean flag. If false, all ground truth files must be provided. If true, ground
truth files are optional
"""
paths: List[Path] = []
failed_channel_info: str = ''
paths: List[Optional[Path]] = []
failed_channel_info = ''

for channel_id in channels:
row = rows.loc[rows[CSV_CHANNEL_HEADER] == channel_id]
if len(row) == 0:
if len(row) == 0 and not allow_incomplete_labels:
failed_channel_info += f"Patient {patient_id} does not have channel '{channel_id}'" + os.linesep
elif len(row) == 0 and allow_incomplete_labels:
# Keeps track of missing channels order
paths.append(None)
elif len(row) > 1:
failed_channel_info += f"Patient {patient_id} has more than one entry for channel '{channel_id}'" + \
os.linesep
else:
elif len(row) == 1:
image_path = local_dataset_root_folder / row[CSV_PATH_HEADER].values[0]
if not image_path.is_file():
failed_channel_info += f"Patient {patient_id}, file {image_path} does not exist" + os.linesep
Expand All @@ -300,7 +307,8 @@ def load_dataset_sources(dataframe: pd.DataFrame,
local_dataset_root_folder: Path,
image_channels: List[str],
ground_truth_channels: List[str],
mask_channel: Optional[str]) -> Dict[str, PatientDatasetSource]:
mask_channel: Optional[str],
allow_incomplete_labels: bool = False) -> Dict[str, PatientDatasetSource]:
"""
Prepares a patient-to-images mapping from a dataframe read directly from a dataset CSV file.
The dataframe contains per-patient per-channel image information, relative to a root directory.
Expand All @@ -311,6 +319,8 @@ def load_dataset_sources(dataframe: pd.DataFrame,
:param image_channels: The names of the image channels that should be used in the result.
:param ground_truth_channels: The names of the ground truth channels that should be used in the result.
:param mask_channel: The name of the mask channel that should be used in the result. This can be None.
:param allow_incomplete_labels: Boolean flag. If false, all ground truth files must be provided. If true, ground
truth files are optional. Default value is false.
:return: A dictionary mapping from an integer subject ID to a PatientDatasetSource.
"""
expected_headers = {CSV_SUBJECT_HEADER, CSV_PATH_HEADER, CSV_CHANNEL_HEADER}
Expand All @@ -328,16 +338,19 @@ def load_dataset_sources(dataframe: pd.DataFrame,
def get_mask_channel_or_default() -> Optional[Path]:
if mask_channel is None:
return None
paths = get_paths_for_channel_ids(channels=[mask_channel], allow_incomplete_labels_flag=allow_incomplete_labels)
if len(paths) == 0:
return None
else:
return get_paths_for_channel_ids(channels=[mask_channel])[0]
return paths[0]

def get_paths_for_channel_ids(channels: List[str]) -> List[Path]:
def get_paths_for_channel_ids(channels: List[str], allow_incomplete_labels_flag: bool) -> List[Optional[Path]]:
if len(set(channels)) < len(channels):
raise ValueError(f"ids have duplicated entries: {channels}")
rows = dataframe.loc[dataframe[CSV_SUBJECT_HEADER] == patient_id]
# converts channels to paths and makes second sanity check for channel data
paths, failed_channel_info = convert_channels_to_file_paths(channels, rows, local_dataset_root_folder,
patient_id)
patient_id, allow_incomplete_labels_flag)

if failed_channel_info:
raise ValueError(failed_channel_info)
Expand All @@ -349,9 +362,11 @@ def get_paths_for_channel_ids(channels: List[str]) -> List[Path]:
metadata = PatientMetadata.from_dataframe(dataframe, patient_id)
dataset_sources[patient_id] = PatientDatasetSource(
metadata=metadata,
image_channels=get_paths_for_channel_ids(channels=image_channels), # type: ignore
image_channels=get_paths_for_channel_ids(channels=image_channels, # type: ignore
allow_incomplete_labels_flag=False),
mask_channel=get_mask_channel_or_default(),
ground_truth_channels=get_paths_for_channel_ids(channels=ground_truth_channels) # type: ignore
)
ground_truth_channels=get_paths_for_channel_ids(channels=ground_truth_channels, # type: ignore
allow_incomplete_labels_flag=allow_incomplete_labels),
allow_incomplete_labels=allow_incomplete_labels)

return dataset_sources
7 changes: 6 additions & 1 deletion InnerEye/ML/dataset/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,19 +129,24 @@ class PatientDatasetSource(SampleBase):
Dataset source locations for channels associated with a given patient in a particular dataset.
"""
image_channels: List[PathOrString]
ground_truth_channels: List[PathOrString]
ground_truth_channels: List[Optional[PathOrString]]
mask_channel: Optional[PathOrString]
metadata: PatientMetadata
allow_incomplete_labels: Optional[bool] = False

def __post_init__(self) -> None:
# make sure all properties are populated
common_util.check_properties_are_not_none(self, ignore=["mask_channel"])

if not self.image_channels:
raise ValueError("image_channels cannot be empty")

if not self.ground_truth_channels:
raise ValueError("ground_truth_channels cannot be empty")

if self.ground_truth_channels.count(None) > 0 and not self.allow_incomplete_labels:
raise ValueError("all ground_truth_channels must be provided")


@dataclass(frozen=True)
class Sample(SampleBase):
Expand Down
8 changes: 5 additions & 3 deletions InnerEye/ML/lightning_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,11 +161,13 @@ def setup(self) -> None:
unique_ids = set(split_data[CSV_SUBJECT_HEADER])
for patient_id in unique_ids:
rows = split_data.loc[split_data[CSV_SUBJECT_HEADER] == patient_id]
allow_incomplete_labels = self.config.allow_incomplete_labels # type: ignore
# Converts channels from data frame to file paths and gets errors if any
__, failed_channel_info = convert_channels_to_file_paths(all_channels,
rows,
local_dataset_root_folder,
patient_id)
rows,
local_dataset_root_folder,
patient_id,
allow_incomplete_labels)
full_failed_channel_info += failed_channel_info

if full_failed_channel_info:
Expand Down
42 changes: 30 additions & 12 deletions InnerEye/ML/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import SimpleITK as sitk
import numpy as np
from numpy.core.numeric import NaN
import torch
import torch.nn.functional as F
from azureml.core import Run
Expand All @@ -21,12 +22,13 @@
from InnerEye.Common.type_annotations import DictStrFloat, TupleFloat3
from InnerEye.ML.common import ModelExecutionMode
from InnerEye.ML.config import BACKGROUND_CLASS_NAME
from InnerEye.ML.metrics_dict import DataframeLogger, INTERNAL_TO_LOGGING_COLUMN_NAMES, MetricsDict, \
ScalarMetricsDict
from InnerEye.ML.metrics_dict import (DataframeLogger, INTERNAL_TO_LOGGING_COLUMN_NAMES, MetricsDict,
ScalarMetricsDict)
from InnerEye.ML.scalar_config import ScalarLoss
from InnerEye.ML.utils.image_util import binaries_from_multi_label_array, is_binary_array
from InnerEye.ML.utils.io_util import reverse_tuple_float3
from InnerEye.ML.utils.metrics_util import binary_classification_accuracy, mean_absolute_error, r2_score
from InnerEye.ML.utils.metrics_util import (binary_classification_accuracy, mean_absolute_error,
r2_score, is_missing_ground_truth)
from InnerEye.ML.utils.ml_util import check_size_matches
from InnerEye.ML.utils.sequence_utils import get_masked_model_outputs_and_labels

Expand Down Expand Up @@ -56,15 +58,15 @@ class InferenceMetricsForSegmentation(InferenceMetrics):
"""
Stores metrics for segmentation models, per execution mode and epoch.
"""
data_split: ModelExecutionMode
execution_mode: ModelExecutionMode
metrics: float

def get_metrics_log_key(self) -> str:
"""
Gets a string name for logging the metrics specific to the execution mode (train, val, test)
:return:
"""
return f"InferenceMetrics_{self.data_split.value}"
return f"InferenceMetrics_{self.execution_mode.value}"

def log_metrics(self, run_context: Run = None) -> None:
"""
Expand Down Expand Up @@ -230,9 +232,10 @@ def calculate_metrics_per_class(segmentation: np.ndarray,
Calculate the dice for all foreground structures (the background class is completely ignored).
Returns a MetricsDict with metrics for each of the foreground
structures. Metrics are NaN if both ground truth and prediction are all zero for a class.
If the first element of a ground truth image channel is NaN, the image is flagged as NaN and is not used.
:param ground_truth_ids: The names of all foreground classes.
:param segmentation: predictions multi-value array with dimensions: [Z x Y x X]
:param ground_truth: ground truth binary array with dimensions: [C x Z x Y x X]
:param ground_truth: ground truth binary array with dimensions: [C x Z x Y x X].
:param voxel_spacing: voxel_spacing in 3D Z x Y x X
:param patient_id: for logging
"""
Expand All @@ -242,15 +245,34 @@ def calculate_metrics_per_class(segmentation: np.ndarray,
f"the label tensor indicates that there are {number_of_classes - 1} classes.")
binaries = binaries_from_multi_label_array(segmentation, number_of_classes)

all_classes_are_binary = [is_binary_array(ground_truth[label_id]) for label_id in range(ground_truth.shape[0])]
if not np.all(all_classes_are_binary):
binary_classes = [is_binary_array(ground_truth[label_id]) for label_id in range(ground_truth.shape[0])]

# If a ground truth image is NaN, it will not be used for metrics computation.
nan_images = [is_missing_ground_truth(ground_truth[label_id]) for label_id in range(ground_truth.shape[0])]

# Checks element-wise that each channel is binary exactly when it is not flagged as NaN (missing).
assert np.all(np.array(binary_classes) == ~np.array(nan_images))

# Validates that all binary images should be 0 or 1
if not np.all(np.array(binary_classes)[~np.array(nan_images)]):
raise ValueError("Ground truth values should be 0 or 1")
overlap_measures_filter = sitk.LabelOverlapMeasuresImageFilter()
hausdorff_distance_filter = sitk.HausdorffDistanceImageFilter()
metrics = MetricsDict(hues=ground_truth_ids)

def add_metric(metric_type: MetricType, value: float) -> None:
metrics.add_metric(metric_type, value, skip_nan_when_averaging=True, hue=ground_truth_ids[i - 1])

for i, prediction in enumerate(binaries):
# Skip if background image
if i == 0:
continue
# Skip but record if nan_image
elif nan_images[i]:
add_metric(MetricType.DICE, NaN)
add_metric(MetricType.HAUSDORFF_mm, NaN)
add_metric(MetricType.MEAN_SURFACE_DIST_mm, NaN)
continue
check_size_matches(prediction, ground_truth[i], arg1_name="prediction", arg2_name="ground_truth")
if not is_binary_array(prediction):
raise ValueError("Predictions values should be 0 or 1")
Expand Down Expand Up @@ -280,10 +302,6 @@ def calculate_metrics_per_class(segmentation: np.ndarray,
except Exception as e:
logging.warning(f"Cannot calculate mean distance for structure {i} of patient {patient_id}: {e}")
logging.debug(f"Patient {patient_id}, class {i} has Dice score {dice}")

def add_metric(metric_type: MetricType, value: float) -> None:
metrics.add_metric(metric_type, value, skip_nan_when_averaging=True, hue=ground_truth_ids[i - 1])

add_metric(MetricType.DICE, dice)
add_metric(MetricType.HAUSDORFF_mm, hausdorff_distance)
add_metric(MetricType.MEAN_SURFACE_DIST_mm, mean_surface_distance)
Expand Down
Loading

0 comments on commit d6d67bd

Please sign in to comment.