Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Added ability to run segmentation inference module in the test data without or partial ground truth files. #465

Merged
merged 51 commits into from
Jul 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
9c08fc6
Added initial unit test to evaluate model predictions.
May 21, 2021
f867c9f
Updated log PR #465.
May 24, 2021
8507ae7
Added unit test for segmentation inference in test data and no ground…
May 25, 2021
6137efb
Renamed variable from: "for_inference" to "allow_incomplete_labels".
May 25, 2021
2e8f173
Completed type annotation.
May 25, 2021
a500663
Extended unit test "test_evaluate_model_predictions" to account for m…
May 26, 2021
e9723e5
Improved inference support for missing labels.
asantamariapang May 30, 2021
73955f6
Revert "Improved inference support for missing labels."
asantamariapang Jun 1, 2021
5caf245
Updated missing ground truth labels and masks.
asantamariapang Jun 2, 2021
d703f2c
Merge branch 'main' into alberto/inference
asantamariapang Jun 3, 2021
77d3244
Fixed bug and improved documentation.
asantamariapang Jun 8, 2021
051681a
Merge branch 'main' into alberto/inference
asantamariapang Jun 8, 2021
fc6c9cc
Merge branch 'main' into alberto/inference
javier-alvarez Jun 15, 2021
652b6aa
WiP allowing NaNs in averaging to count them
Jun 24, 2021
899e648
Deeper testing in test_evaluate_model_predictions
Jun 25, 2021
2b46a44
Reverting CHANGELOG for now
Jun 25, 2021
94b90ab
Unused import
Jun 25, 2021
36f3994
WiP testing partial_ground_truth metrics output
Jun 25, 2021
03e530c
messy WiP with testing ground truths
dumbledad Jun 26, 2021
a79b1d0
Removing started small test data script
dumbledad Jun 26, 2021
fcacfcd
Adding labels for partial test
dumbledad Jun 27, 2021
75b1c37
WiP on partial ground truth unit test of model_test
dumbledad Jun 27, 2021
6f4bbd8
Unit test of partial ground truth works, but other fails :(
dumbledad Jun 27, 2021
713c2f4
mypy fixes
dumbledad Jun 27, 2021
554c6d2
tidy
dumbledad Jun 27, 2021
a7f7a19
flake fixes
dumbledad Jun 27, 2021
a821f9b
Documentation typos
Jun 28, 2021
45c4d2a
Adding allow_incomplete_labels to Seg'ModelBase
Jun 28, 2021
65b9df3
Merge branch 'main' into alberto/inference
dumbledad Jun 29, 2021
86a8be7
Checking that partial ground truth is not allowed unless explicit
dumbledad Jun 29, 2021
58ab13b
Added unit test of IPYNB-HTML for partial ground truth
Jun 30, 2021
a4402f7
flake fixes
Jun 30, 2021
5f68941
temp change for end2end test
dumbledad Jul 1, 2021
987234a
WSL package for .Net on linux
dumbledad Jul 1, 2021
b6241d2
Passing on allow_partial in lightning model base
dumbledad Jul 1, 2021
5f173f4
Reverting change made for end2end test only
dumbledad Jul 1, 2021
edb88d6
Fixing param comment
Jul 2, 2021
53ea4fc
Fixing unnecessary allow_incomplete, already in args
Jul 2, 2021
3a86cf9
Removing redundant isinstance
Jul 2, 2021
b47e197
Moving duplicated fragment to util function
Jul 2, 2021
615ae1a
missing bracket
dumbledad Jul 4, 2021
6a0bd3a
Clearer comments for save_aggregates_to_csv
dumbledad Jul 4, 2021
b3ed3f1
Unit test for InnerEyeContainer setup pass through
dumbledad Jul 4, 2021
da79dd3
flake fixes (I hope!)
dumbledad Jul 4, 2021
eadd7f9
changelog
dumbledad Jul 4, 2021
f9540ed
Odd fix, moved definition to enclosing scope
dumbledad Jul 4, 2021
5f4fe5b
removing Path(__file__)
Jul 5, 2021
411277c
using is_missing again
Jul 5, 2021
5a5f946
no partial images!
Jul 5, 2021
f1fd3c8
Merge branch 'main' into alberto/inference
Jul 5, 2021
96840e8
Merge branch 'main' into alberto/inference
Jul 5, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ wheels/
.installed.cfg
*.egg
MANIFEST
packages-microsoft-prod.deb
ant0nsc marked this conversation as resolved.
Show resolved Hide resolved

# PyInstaller
# Usually these files are written by a python script from a template
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ created.
## Upcoming

### Added
- ([#465](https://github.com/microsoft/InnerEye-DeepLearning/pull/465/)) Adding ability to run segmentation inference
module in the test data without or partial ground truth files.
- ([#502](https://github.com/microsoft/InnerEye-DeepLearning/pull/502)) More flags for fine control of when to run inference.
- ([#492](https://github.com/microsoft/InnerEye-DeepLearning/pull/492)) Adding capability for regression tests for test
jobs that run in AzureML.
Expand Down
8 changes: 8 additions & 0 deletions InnerEye/ML/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,17 +474,25 @@ class SegmentationModelBase(ModelConfigBase):
is_plotting_enabled: bool = param.Boolean(True, doc="If true, various overview plots with results are generated "
"during model evaluation. Set to False if you see "
"non-deterministic pull request build failures.")

show_patch_sampling: int = param.Integer(1, bounds=(0, None),
doc="Number of patients from the training set for which the effect of"
"patch sampling will be shown. Nifti images and thumbnails for each"
"of the first N subjects in the training set will be "
"written to the outputs folder.")

#: If true an error is raised in InnerEye.ML.utils.io_util.load_labels_from_dataset_source if the labels are not
#: mutually exclusive. Some loss functions (e.g. SoftDice) may produce results on overlapping labels, but others (e.g.
#: FocalLoss) will fail with a cryptic error message. Set to false if you are sure that you want to use labels that
#: are not mutually exclusive.
check_exclusive: bool = param.Boolean(True, doc="Raise an error if the segmentation labels are not mutually exclusive.")

allow_incomplete_labels: bool = param.Boolean(
default=False,
doc="If False, the default, then test patient data must include all of the ground truth labels. If true then "
"some test patient data with missing ground truth data is allowed and will be reflected in the patient "
"counts in the metrics and report.")

def __init__(self, center_size: Optional[TupleInt3] = None,
inference_stride_size: Optional[TupleInt3] = None,
min_l_rate: float = 0,
Expand Down
51 changes: 33 additions & 18 deletions InnerEye/ML/dataset/full_image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def __init__(self, args: SegmentationModelBase, data_frame: pd.DataFrame,
full_image_sample_transforms: Optional[Compose3D[Sample]] = None):
super().__init__(args, data_frame)
self.full_image_sample_transforms = full_image_sample_transforms

self.allow_incomplete_labels = args.allow_incomplete_labels
# Check base_path
assert self.args.local_dataset is not None
if not self.args.local_dataset.is_dir():
Expand Down Expand Up @@ -250,7 +250,8 @@ def _extension_from_df_file_paths(file_paths: List[str]) -> str:
def get_samples_at_index(self, index: int) -> List[Sample]:
# load the channels into memory
ds = self.dataset_sources[self.dataset_indices[index]]
samples = [io_util.load_images_from_dataset_source(dataset_source=ds, check_exclusive=self.args.check_exclusive)] # type: ignore
samples = [io_util.load_images_from_dataset_source(dataset_source=ds,
check_exclusive=self.args.check_exclusive)] # type: ignore
return [Compose3D.apply(self.full_image_sample_transforms, x) for x in samples]

def _load_dataset_sources(self) -> Dict[str, PatientDatasetSource]:
Expand All @@ -259,34 +260,40 @@ def _load_dataset_sources(self) -> Dict[str, PatientDatasetSource]:
local_dataset_root_folder=self.args.local_dataset,
image_channels=self.args.image_channels,
ground_truth_channels=self.args.ground_truth_ids,
mask_channel=self.args.mask_id
)
mask_channel=self.args.mask_id,
allow_incomplete_labels=self.allow_incomplete_labels)


def convert_channels_to_file_paths(channels: List[str],
rows: pd.DataFrame,
local_dataset_root_folder: Path,
patient_id: str) -> Tuple[List[Path], str]:
patient_id: str,
allow_incomplete_labels: bool = False) -> Tuple[List[Optional[Path]], str]:
"""
Returns: 1) The full path for files specified in the training, validation and testing datasets, and
2) Missing channels or missing files.
Returns: 1) A list of path file objects specified in the training, validation and testing datasets, and
2) a string with description of missing channels, files and more than one channel per patient.

:param channels: channel type defined in the configuration file
:param rows: Input Pandas dataframe object containing subjectIds, path of local dataset, channel information
:param local_dataset_root_folder: Root directory which points to the local dataset
:param patient_id: string which contains subject identifier
:param allow_incomplete_labels: boolean flag. If false, all ground truth files must be provided. If true, ground
truth files are optional
"""
paths: List[Path] = []
failed_channel_info: str = ''
paths: List[Optional[Path]] = []
failed_channel_info = ''

for channel_id in channels:
row = rows.loc[rows[CSV_CHANNEL_HEADER] == channel_id]
if len(row) == 0:
if len(row) == 0 and not allow_incomplete_labels:
failed_channel_info += f"Patient {patient_id} does not have channel '{channel_id}'" + os.linesep
elif len(row) == 0 and allow_incomplete_labels:
# Keeps track of missing channels order
paths.append(None)
asantamariapang marked this conversation as resolved.
Show resolved Hide resolved
elif len(row) > 1:
failed_channel_info += f"Patient {patient_id} has more than one entry for channel '{channel_id}'" + \
os.linesep
else:
elif len(row) == 1:
image_path = local_dataset_root_folder / row[CSV_PATH_HEADER].values[0]
if not image_path.is_file():
failed_channel_info += f"Patient {patient_id}, file {image_path} does not exist" + os.linesep
Expand All @@ -300,7 +307,8 @@ def load_dataset_sources(dataframe: pd.DataFrame,
local_dataset_root_folder: Path,
image_channels: List[str],
ground_truth_channels: List[str],
mask_channel: Optional[str]) -> Dict[str, PatientDatasetSource]:
mask_channel: Optional[str],
allow_incomplete_labels: bool = False) -> Dict[str, PatientDatasetSource]:
"""
Prepares a patient-to-images mapping from a dataframe read directly from a dataset CSV file.
The dataframe contains per-patient per-channel image information, relative to a root directory.
Expand All @@ -311,6 +319,8 @@ def load_dataset_sources(dataframe: pd.DataFrame,
:param image_channels: The names of the image channels that should be used in the result.
:param ground_truth_channels: The names of the ground truth channels that should be used in the result.
:param mask_channel: The name of the mask channel that should be used in the result. This can be None.
:param allow_incomplete_labels: Boolean flag. If false, all ground truth files must be provided. If true, ground
truth files are optional. Default value is false.
:return: A dictionary mapping from an integer subject ID to a PatientDatasetSource.
"""
expected_headers = {CSV_SUBJECT_HEADER, CSV_PATH_HEADER, CSV_CHANNEL_HEADER}
Expand All @@ -328,16 +338,19 @@ def load_dataset_sources(dataframe: pd.DataFrame,
def get_mask_channel_or_default() -> Optional[Path]:
if mask_channel is None:
return None
paths = get_paths_for_channel_ids(channels=[mask_channel], allow_incomplete_labels_flag=allow_incomplete_labels)
if len(paths) == 0:
return None
else:
return get_paths_for_channel_ids(channels=[mask_channel])[0]
return paths[0]

def get_paths_for_channel_ids(channels: List[str]) -> List[Path]:
def get_paths_for_channel_ids(channels: List[str], allow_incomplete_labels_flag: bool) -> List[Optional[Path]]:
if len(set(channels)) < len(channels):
raise ValueError(f"ids have duplicated entries: {channels}")
rows = dataframe.loc[dataframe[CSV_SUBJECT_HEADER] == patient_id]
# converts channels to paths and makes second sanity check for channel data
paths, failed_channel_info = convert_channels_to_file_paths(channels, rows, local_dataset_root_folder,
patient_id)
patient_id, allow_incomplete_labels_flag)

if failed_channel_info:
raise ValueError(failed_channel_info)
Expand All @@ -349,9 +362,11 @@ def get_paths_for_channel_ids(channels: List[str]) -> List[Path]:
metadata = PatientMetadata.from_dataframe(dataframe, patient_id)
dataset_sources[patient_id] = PatientDatasetSource(
metadata=metadata,
image_channels=get_paths_for_channel_ids(channels=image_channels), # type: ignore
image_channels=get_paths_for_channel_ids(channels=image_channels, # type: ignore
allow_incomplete_labels_flag=False),
mask_channel=get_mask_channel_or_default(),
ground_truth_channels=get_paths_for_channel_ids(channels=ground_truth_channels) # type: ignore
)
ground_truth_channels=get_paths_for_channel_ids(channels=ground_truth_channels, # type: ignore
allow_incomplete_labels_flag=allow_incomplete_labels),
allow_incomplete_labels=allow_incomplete_labels)

return dataset_sources
7 changes: 6 additions & 1 deletion InnerEye/ML/dataset/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,19 +129,24 @@ class PatientDatasetSource(SampleBase):
Dataset source locations for channels associated with a given patient in a particular dataset.
"""
image_channels: List[PathOrString]
ground_truth_channels: List[PathOrString]
ground_truth_channels: List[Optional[PathOrString]]
mask_channel: Optional[PathOrString]
metadata: PatientMetadata
allow_incomplete_labels: Optional[bool] = False

def __post_init__(self) -> None:
# make sure all properties are populated
common_util.check_properties_are_not_none(self, ignore=["mask_channel"])

if not self.image_channels:
raise ValueError("image_channels cannot be empty")

if not self.ground_truth_channels:
raise ValueError("ground_truth_channels cannot be empty")

if self.ground_truth_channels.count(None) > 0 and not self.allow_incomplete_labels:
raise ValueError("all ground_truth_channels must be provided")


@dataclass(frozen=True)
class Sample(SampleBase):
Expand Down
8 changes: 5 additions & 3 deletions InnerEye/ML/lightning_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,11 +161,13 @@ def setup(self) -> None:
unique_ids = set(split_data[CSV_SUBJECT_HEADER])
for patient_id in unique_ids:
rows = split_data.loc[split_data[CSV_SUBJECT_HEADER] == patient_id]
allow_incomplete_labels = self.config.allow_incomplete_labels # type: ignore
# Converts channels from data frame to file paths and gets errors if any
__, failed_channel_info = convert_channels_to_file_paths(all_channels,
rows,
local_dataset_root_folder,
patient_id)
rows,
local_dataset_root_folder,
patient_id,
allow_incomplete_labels)
full_failed_channel_info += failed_channel_info

if full_failed_channel_info:
Expand Down
42 changes: 30 additions & 12 deletions InnerEye/ML/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import SimpleITK as sitk
import numpy as np
from numpy.core.numeric import NaN
import torch
import torch.nn.functional as F
from azureml.core import Run
Expand All @@ -21,12 +22,13 @@
from InnerEye.Common.type_annotations import DictStrFloat, TupleFloat3
from InnerEye.ML.common import ModelExecutionMode
from InnerEye.ML.config import BACKGROUND_CLASS_NAME
from InnerEye.ML.metrics_dict import DataframeLogger, INTERNAL_TO_LOGGING_COLUMN_NAMES, MetricsDict, \
ScalarMetricsDict
from InnerEye.ML.metrics_dict import (DataframeLogger, INTERNAL_TO_LOGGING_COLUMN_NAMES, MetricsDict,
ScalarMetricsDict)
from InnerEye.ML.scalar_config import ScalarLoss
from InnerEye.ML.utils.image_util import binaries_from_multi_label_array, is_binary_array
from InnerEye.ML.utils.io_util import reverse_tuple_float3
from InnerEye.ML.utils.metrics_util import binary_classification_accuracy, mean_absolute_error, r2_score
from InnerEye.ML.utils.metrics_util import (binary_classification_accuracy, mean_absolute_error,
r2_score, is_missing_ground_truth)
from InnerEye.ML.utils.ml_util import check_size_matches
from InnerEye.ML.utils.sequence_utils import get_masked_model_outputs_and_labels

Expand Down Expand Up @@ -56,15 +58,15 @@ class InferenceMetricsForSegmentation(InferenceMetrics):
"""
Stores metrics for segmentation models, per execution mode and epoch.
"""
data_split: ModelExecutionMode
execution_mode: ModelExecutionMode
metrics: float

def get_metrics_log_key(self) -> str:
"""
Gets a string name for logging the metrics specific to the execution mode (train, val, test)
dumbledad marked this conversation as resolved.
Show resolved Hide resolved
:return:
"""
return f"InferenceMetrics_{self.data_split.value}"
return f"InferenceMetrics_{self.execution_mode.value}"

def log_metrics(self, run_context: Run = None) -> None:
"""
Expand Down Expand Up @@ -230,9 +232,10 @@ def calculate_metrics_per_class(segmentation: np.ndarray,
Calculate the dice for all foreground structures (the background class is completely ignored).
Returns a MetricsDict with metrics for each of the foreground
structures. Metrics are NaN if both ground truth and prediction are all zero for a class.
If first element of a ground truth image channel is NaN, the image is flagged as NaN and not use.
:param ground_truth_ids: The names of all foreground classes.
:param segmentation: predictions multi-value array with dimensions: [Z x Y x X]
:param ground_truth: ground truth binary array with dimensions: [C x Z x Y x X]
:param ground_truth: ground truth binary array with dimensions: [C x Z x Y x X].
:param voxel_spacing: voxel_spacing in 3D Z x Y x X
:param patient_id: for logging
"""
Expand All @@ -242,15 +245,34 @@ def calculate_metrics_per_class(segmentation: np.ndarray,
f"the label tensor indicates that there are {number_of_classes - 1} classes.")
binaries = binaries_from_multi_label_array(segmentation, number_of_classes)

all_classes_are_binary = [is_binary_array(ground_truth[label_id]) for label_id in range(ground_truth.shape[0])]
if not np.all(all_classes_are_binary):
binary_classes = [is_binary_array(ground_truth[label_id]) for label_id in range(ground_truth.shape[0])]

# If ground truth image is nan, then will not be used for metrics computation.
nan_images = [is_missing_ground_truth(ground_truth[label_id]) for label_id in range(ground_truth.shape[0])]

# Compares element-wise if not binary then nan and checks all elements are True.
assert np.all(np.array(binary_classes) == ~np.array(nan_images))

# Validates that all binary images should be 0 or 1
if not np.all(np.array(binary_classes)[~np.array(nan_images)]):
raise ValueError("Ground truth values should be 0 or 1")
overlap_measures_filter = sitk.LabelOverlapMeasuresImageFilter()
hausdorff_distance_filter = sitk.HausdorffDistanceImageFilter()
metrics = MetricsDict(hues=ground_truth_ids)

def add_metric(metric_type: MetricType, value: float) -> None:
metrics.add_metric(metric_type, value, skip_nan_when_averaging=True, hue=ground_truth_ids[i - 1])

for i, prediction in enumerate(binaries):
# Skip if background image
if i == 0:
continue
# Skip but record if nan_image
elif nan_images[i]:
add_metric(MetricType.DICE, NaN)
add_metric(MetricType.HAUSDORFF_mm, NaN)
add_metric(MetricType.MEAN_SURFACE_DIST_mm, NaN)
continue
check_size_matches(prediction, ground_truth[i], arg1_name="prediction", arg2_name="ground_truth")
if not is_binary_array(prediction):
raise ValueError("Predictions values should be 0 or 1")
Expand Down Expand Up @@ -280,10 +302,6 @@ def calculate_metrics_per_class(segmentation: np.ndarray,
except Exception as e:
logging.warning(f"Cannot calculate mean distance for structure {i} of patient {patient_id}: {e}")
logging.debug(f"Patient {patient_id}, class {i} has Dice score {dice}")

def add_metric(metric_type: MetricType, value: float) -> None:
metrics.add_metric(metric_type, value, skip_nan_when_averaging=True, hue=ground_truth_ids[i - 1])

add_metric(MetricType.DICE, dice)
add_metric(MetricType.HAUSDORFF_mm, hausdorff_distance)
add_metric(MetricType.MEAN_SURFACE_DIST_mm, mean_surface_distance)
Expand Down
Loading