This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Added ability to run the segmentation inference module on test data with missing or partial ground truth files. #465

Merged
merged 51 commits into from
Jul 5, 2021
Changes from 7 commits
9c08fc6
Added initial unit test to evaluate model predictions.
May 21, 2021
f867c9f
Updated log PR #465.
May 24, 2021
8507ae7
Added unit test for segmentation inference in test data and no ground…
May 25, 2021
6137efb
Renamed variable from: "for_inference" to "allow_incomplete_labels".
May 25, 2021
2e8f173
Completed type annotation.
May 25, 2021
a500663
Extended unit test "test_evaluate_model_predictions" to account for m…
May 26, 2021
e9723e5
Improved inference support for missing labels.
asantamariapang May 30, 2021
73955f6
Revert "Improved inference support for missing labels."
asantamariapang Jun 1, 2021
5caf245
Updated missing ground truth labels and masks.
asantamariapang Jun 2, 2021
d703f2c
Merge branch 'main' into alberto/inference
asantamariapang Jun 3, 2021
77d3244
Fixed bug and improved documentation.
asantamariapang Jun 8, 2021
051681a
Merge branch 'main' into alberto/inference
asantamariapang Jun 8, 2021
fc6c9cc
Merge branch 'main' into alberto/inference
javier-alvarez Jun 15, 2021
652b6aa
WiP allowing NaNs in averaging to count them
Jun 24, 2021
899e648
Deeper testing in test_evaluate_model_predictions
Jun 25, 2021
2b46a44
Reverting CHANGELOG for now
Jun 25, 2021
94b90ab
Unused import
Jun 25, 2021
36f3994
WiP testing partial_ground_truth metrics output
Jun 25, 2021
03e530c
messy WiP with testing ground truths
dumbledad Jun 26, 2021
a79b1d0
Removing started small test data script
dumbledad Jun 26, 2021
fcacfcd
Adding labels for partial test
dumbledad Jun 27, 2021
75b1c37
WiP on partial ground truth unit test of model_test
dumbledad Jun 27, 2021
6f4bbd8
Unit test of partial ground truth works, but other fails :(
dumbledad Jun 27, 2021
713c2f4
mypy fixes
dumbledad Jun 27, 2021
554c6d2
tidy
dumbledad Jun 27, 2021
a7f7a19
flake fixes
dumbledad Jun 27, 2021
a821f9b
Documentation typos
Jun 28, 2021
45c4d2a
Adding allow_incomplete_labels to Seg'ModelBase
Jun 28, 2021
65b9df3
Merge branch 'main' into alberto/inference
dumbledad Jun 29, 2021
86a8be7
Checking that partial ground truth is not allowed unless explicit
dumbledad Jun 29, 2021
58ab13b
Added unit test of IPYNB-HTML for partial ground truth
Jun 30, 2021
a4402f7
flake fixes
Jun 30, 2021
5f68941
temp change for end2end test
dumbledad Jul 1, 2021
987234a
WSL package for .Net on linux
dumbledad Jul 1, 2021
b6241d2
Passing on allow_partial in lightning model base
dumbledad Jul 1, 2021
5f173f4
Reverting change made for end2end test only
dumbledad Jul 1, 2021
edb88d6
Fixing param comment
Jul 2, 2021
53ea4fc
Fixing unnecessary allow_incomplete, already in args
Jul 2, 2021
3a86cf9
Removing redundant isinstance
Jul 2, 2021
b47e197
Moving duplicated fragment to util function
Jul 2, 2021
615ae1a
missing bracket
dumbledad Jul 4, 2021
6a0bd3a
Clearer comments for save_aggregates_to_csv
dumbledad Jul 4, 2021
b3ed3f1
Unit test for InnerEyeContainer setup pass through
dumbledad Jul 4, 2021
da79dd3
flake fixes (I hope!)
dumbledad Jul 4, 2021
eadd7f9
changelog
dumbledad Jul 4, 2021
f9540ed
Odd fix, moved definition to enclosing scope
dumbledad Jul 4, 2021
5f4fe5b
removing Path(__file__)
Jul 5, 2021
411277c
using is_missing again
Jul 5, 2021
5a5f946
no partial images!
Jul 5, 2021
f1fd3c8
Merge branch 'main' into alberto/inference
Jul 5, 2021
96840e8
Merge branch 'main' into alberto/inference
Jul 5, 2021
3 changes: 2 additions & 1 deletion CHANGELOG.md
Expand Up @@ -12,7 +12,8 @@ created.
## Upcoming

### Added

- ([#465](https://github.com/microsoft/InnerEye-DeepLearning/pull/465/)) Added ability to run segmentation inference
module in the test data without or partial ground truth files.
- ([#454](https://github.com/microsoft/InnerEye-DeepLearning/pull/454)) Checking that labels are mutually exclusive.
- ([#447](https://github.com/microsoft/InnerEye-DeepLearning/pull/447/)) Added a sanity check to ensure there are no
missing channels, nor missing files. If missing channels in the csv file or filenames associated with channels are
Expand Down
3 changes: 2 additions & 1 deletion InnerEye/ML/config.py
Expand Up @@ -768,7 +768,8 @@ def create_and_set_torch_datasets(self, for_training: bool = True, for_inference
mode: FullImageDataset(
self,
dataset_splits[mode],
full_image_sample_transforms=full_image_transforms.test) # type: ignore
full_image_sample_transforms=full_image_transforms.test, # type: ignore
allow_incomplete_labels=True)
for mode in ModelExecutionMode if len(dataset_splits[mode]) > 0
}

Expand Down
2 changes: 2 additions & 0 deletions InnerEye/ML/dataset/cropping_dataset.py
Expand Up @@ -110,6 +110,7 @@ def create_random_cropped_sample(sample: Sample,
mask_center_crop = image_util.get_center_crop(image=sample.mask, crop_shape=center_size)
labels_center_crop = np.zeros(shape=[len(sample.labels)] + list(center_size), # type: ignore
dtype=ImageDataType.SEGMENTATION.value)
assert sample.labels is not None
for c in range(len(sample.labels)): # type: ignore
labels_center_crop[c] = image_util.get_center_crop(
image=sample.labels[c],
Expand All @@ -120,6 +121,7 @@ def create_random_cropped_sample(sample: Sample,
image=sample.image,
mask=sample.mask,
labels=sample.labels,
missing_labels=sample.missing_labels,
mask_center_crop=mask_center_crop,
labels_center_crop=labels_center_crop,
center_indices=center_point,
Expand Down
40 changes: 26 additions & 14 deletions InnerEye/ML/dataset/full_image_dataset.py
Expand Up @@ -213,10 +213,11 @@ class FullImageDataset(GeneralDataset):
"""

def __init__(self, args: SegmentationModelBase, data_frame: pd.DataFrame,
full_image_sample_transforms: Optional[Compose3D[Sample]] = None):
full_image_sample_transforms: Optional[Compose3D[Sample]] = None,
allow_incomplete_labels: bool = False):
super().__init__(args, data_frame)
self.full_image_sample_transforms = full_image_sample_transforms

self.allow_incomplete_labels = allow_incomplete_labels
# Check base_path
assert self.args.local_dataset is not None
if not self.args.local_dataset.is_dir():
Expand Down Expand Up @@ -250,7 +251,8 @@ def _extension_from_df_file_paths(file_paths: List[str]) -> str:
def get_samples_at_index(self, index: int) -> List[Sample]:
# load the channels into memory
ds = self.dataset_sources[self.dataset_indices[index]]
samples = [io_util.load_images_from_dataset_source(dataset_source=ds, check_exclusive=self.args.check_exclusive)] # type: ignore
samples = [io_util.load_images_from_dataset_source(dataset_source=ds,
check_exclusive=self.args.check_exclusive)] # type: ignore
return [Compose3D.apply(self.full_image_sample_transforms, x) for x in samples]

def _load_dataset_sources(self) -> Dict[str, PatientDatasetSource]:
Expand All @@ -259,34 +261,40 @@ def _load_dataset_sources(self) -> Dict[str, PatientDatasetSource]:
local_dataset_root_folder=self.args.local_dataset,
image_channels=self.args.image_channels,
ground_truth_channels=self.args.ground_truth_ids,
mask_channel=self.args.mask_id
mask_channel=self.args.mask_id,
allow_incomplete_labels=self.allow_incomplete_labels
)


def convert_channels_to_file_paths(channels: List[str],
rows: pd.DataFrame,
local_dataset_root_folder: Path,
patient_id: str) -> Tuple[List[Path], str]:
patient_id: str,
allow_incomplete_labels: bool = False) -> Tuple[List[Optional[Path]], str]:
"""
Returns: 1) The full path for files specified in the training, validation and testing datasets, and
2) Missing channels or missing files.

:param allow_incomplete_labels: flag to allow missing ground truth labels; if False, all ground truth labels are enforced
:param channels: channel type defined in the configuration file
:param rows: Input Pandas dataframe object containing subjectIds, path of local dataset, channel information
:param local_dataset_root_folder: Root directory which points to the local dataset
:param patient_id: string which contains subject identifier
"""
paths: List[Path] = []
paths: List[Optional[Path]] = []
failed_channel_info: str = ''

for channel_id in channels:
row = rows.loc[rows[CSV_CHANNEL_HEADER] == channel_id]
if len(row) == 0:
if len(row) == 0 and not allow_incomplete_labels:
failed_channel_info += f"Patient {patient_id} does not have channel '{channel_id}'" + os.linesep
elif len(row) == 0 and allow_incomplete_labels:
# Keeps track of missing channels order
paths.append(None)
elif len(row) > 1:
failed_channel_info += f"Patient {patient_id} has more than one entry for channel '{channel_id}'" + \
os.linesep
else:
elif len(row) == 1:
image_path = local_dataset_root_folder / row[CSV_PATH_HEADER].values[0]
if not image_path.is_file():
failed_channel_info += f"Patient {patient_id}, file {image_path} does not exist" + os.linesep
Expand All @@ -300,7 +308,8 @@ def load_dataset_sources(dataframe: pd.DataFrame,
local_dataset_root_folder: Path,
image_channels: List[str],
ground_truth_channels: List[str],
mask_channel: Optional[str]) -> Dict[str, PatientDatasetSource]:
mask_channel: Optional[str],
allow_incomplete_labels: bool = False) -> Dict[str, PatientDatasetSource]:
"""
Prepares a patient-to-images mapping from a dataframe read directly from a dataset CSV file.
The dataframe contains per-patient per-channel image information, relative to a root directory.
Expand All @@ -311,6 +320,7 @@ def load_dataset_sources(dataframe: pd.DataFrame,
:param image_channels: The names of the image channels that should be used in the result.
:param ground_truth_channels: The names of the ground truth channels that should be used in the result.
:param mask_channel: The name of the mask channel that should be used in the result. This can be None.
:param allow_incomplete_labels: Boolean variable to allow missing ground truth files.
:return: A dictionary mapping from an integer subject ID to a PatientDatasetSource.
"""
expected_headers = {CSV_SUBJECT_HEADER, CSV_PATH_HEADER, CSV_CHANNEL_HEADER}
Expand All @@ -328,16 +338,19 @@ def load_dataset_sources(dataframe: pd.DataFrame,
def get_mask_channel_or_default() -> Optional[Path]:
if mask_channel is None:
return None
paths = get_paths_for_channel_ids(channels=[mask_channel])
if len(paths) == 0:
return None
else:
return get_paths_for_channel_ids(channels=[mask_channel])[0]
return paths[0]

def get_paths_for_channel_ids(channels: List[str]) -> List[Path]:
def get_paths_for_channel_ids(channels: List[str]) -> List[Optional[Path]]:
if len(set(channels)) < len(channels):
raise ValueError(f"ids have duplicated entries: {channels}")
rows = dataframe.loc[dataframe[CSV_SUBJECT_HEADER] == patient_id]
# converts channels to paths and makes second sanity check for channel data
paths, failed_channel_info = convert_channels_to_file_paths(channels, rows, local_dataset_root_folder,
patient_id)
patient_id, allow_incomplete_labels)

if failed_channel_info:
raise ValueError(failed_channel_info)
Expand All @@ -351,7 +364,6 @@ def get_paths_for_channel_ids(channels: List[str]) -> List[Path]:
metadata=metadata,
image_channels=get_paths_for_channel_ids(channels=image_channels), # type: ignore
mask_channel=get_mask_channel_or_default(),
ground_truth_channels=get_paths_for_channel_ids(channels=ground_truth_channels) # type: ignore
)
ground_truth_channels=get_paths_for_channel_ids(channels=ground_truth_channels)) # type: ignore

return dataset_sources
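The placeholder behaviour introduced in `convert_channels_to_file_paths` can be illustrated with a minimal, self-contained sketch. This is not the InnerEye implementation: `resolve_channel_paths` and its plain-dict `rows` argument are hypothetical stand-ins for the DataFrame lookup, and the file-existence check is omitted.

```python
from pathlib import Path
from typing import Dict, List, Optional, Tuple


def resolve_channel_paths(channels: List[str],
                          rows: Dict[str, List[str]],
                          root: Path,
                          allow_incomplete_labels: bool = False) -> Tuple[List[Optional[Path]], str]:
    """Hypothetical stand-in: `rows` maps a channel name to the relative paths listed for it."""
    paths: List[Optional[Path]] = []
    errors: List[str] = []
    for channel in channels:
        entries = rows.get(channel, [])
        if len(entries) == 0:
            if allow_incomplete_labels:
                # Keep a placeholder so positions still line up with `channels`.
                paths.append(None)
            else:
                errors.append(f"missing channel '{channel}'")
        elif len(entries) > 1:
            errors.append(f"more than one entry for channel '{channel}'")
        else:
            paths.append(root / entries[0])
    return paths, "\n".join(errors)


rows = {"lung": ["p1/lung.nii.gz"]}
# Strict mode reports the missing channel as an error.
strict_paths, strict_errors = resolve_channel_paths(["lung", "heart"], rows, Path("data"))
# Lenient mode records None in the position of the missing channel instead.
lenient_paths, lenient_errors = resolve_channel_paths(["lung", "heart"], rows, Path("data"),
                                                      allow_incomplete_labels=True)
```

Because the `None` entry preserves the position of the missing channel, a caller can later tell which ground truth channel was absent rather than only how many were.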
16 changes: 8 additions & 8 deletions InnerEye/ML/dataset/sample.py
Expand Up @@ -127,9 +127,10 @@ def get_dict(self) -> Dict[str, Any]:
class PatientDatasetSource(SampleBase):
"""
Dataset source locations for channels associated with a given patient in a particular dataset.
Please note that "ground_truth_channels" is optional.
"""
image_channels: List[PathOrString]
ground_truth_channels: List[PathOrString]
ground_truth_channels: List[Optional[PathOrString]]
mask_channel: Optional[PathOrString]
metadata: PatientMetadata

Expand All @@ -139,8 +140,6 @@ def __post_init__(self) -> None:

if not self.image_channels:
raise ValueError("image_channels cannot be empty")
if not self.ground_truth_channels:
raise ValueError("ground_truth_channels cannot be empty")


@dataclass(frozen=True)
Expand All @@ -153,19 +152,20 @@ class Sample(SampleBase):
image: Union[np.ndarray, torch.Tensor]
# (Batches if from data loader) x Z x Y x X
mask: Union[np.ndarray, torch.Tensor]
# (Batches if from data loader) x Classes x Z x Y x X
labels: Union[np.ndarray, torch.Tensor]
# (Batches if from data loader) x Classes x Z x Y x X, where the first class is background
labels: Optional[Union[np.ndarray, torch.Tensor]]
metadata: PatientMetadata
missing_labels: List[bool]

def __post_init__(self) -> None:
# make sure all properties are populated
common_util.check_properties_are_not_none(self)

ml_util.check_size_matches(arg1=self.image, arg2=self.mask,
matching_dimensions=self._get_matching_dimensions())

ml_util.check_size_matches(arg1=self.image, arg2=self.labels,
matching_dimensions=self._get_matching_dimensions())
if self.labels is not None:
ml_util.check_size_matches(arg1=self.image, arg2=self.labels,
matching_dimensions=self._get_matching_dimensions())

@property
def patient_id(self) -> int:
Expand Down
1 change: 1 addition & 0 deletions InnerEye/ML/lightning_models.py
Expand Up @@ -116,6 +116,7 @@ def compute_metrics(self, cropped_sample: CroppedSample, segmentation: torch.Ten
ground_truth=cropped_sample.labels_center_crop,
allow_multiple_classes_for_each_pixel=True)[:, 1:]
# Number of foreground voxels per class, across all crops
assert cropped_sample.labels is not None
foreground_voxels = metrics_util.get_number_of_voxels_per_class(cropped_sample.labels)[:, 1:]
# Store Dice and voxel count per sample in the minibatch. We need a custom aggregation logic for Dice
# because it can be NaN. Also use custom logging for voxel count because Lightning's batch-size weighted
Expand Down
54 changes: 36 additions & 18 deletions InnerEye/ML/metrics.py
Expand Up @@ -223,46 +223,63 @@ def _add_zero_distances(num_segmented_surface_pixels: int, seg2ref_distance_map_

def calculate_metrics_per_class(segmentation: np.ndarray,
ground_truth: np.ndarray,
missing_labels: List[bool],
ground_truth_ids: List[str],
voxel_spacing: TupleFloat3,
patient_id: Optional[int] = None) -> MetricsDict:
"""
Calculate the dice for all foreground structures (the background class is completely ignored).
Returns a MetricsDict with metrics for each of the foreground
Calculate the dice for provided foreground structures (the background class is completely ignored).
Returns a MetricsDict with metrics values for provided foreground class
structures. Metrics are NaN if both ground truth and prediction are all zero for a class.
:param ground_truth_ids: The names of all foreground classes.
:param segmentation: predictions multi-value array with dimensions: [Z x Y x X]
:param ground_truth: ground truth binary array with dimensions: [C x Z x Y x X]
:param ground_truth: ground truth binary array with dimensions: [C x Z x Y x X]. Note that the value of the 'C'
dimension is a function of the provided ground truth channels. The minimal value for
C is 2: one background channel and one provided ground truth channel
:param missing_labels: list of booleans with one entry per foreground class; True indicates that the ground truth
channel for that class was not provided
:param ground_truth_ids: The names of all foreground classes
:param voxel_spacing: voxel_spacing in 3D Z x Y x X
:param patient_id: for logging
"""
number_of_classes = ground_truth.shape[0]
if len(ground_truth_ids) != (number_of_classes - 1):
raise ValueError(f"Received {len(ground_truth_ids)} foreground class names, but "
f"the label tensor indicates that there are {number_of_classes - 1} classes.")
binaries = binaries_from_multi_label_array(segmentation, number_of_classes)

all_classes_are_binary = [is_binary_array(ground_truth[label_id]) for label_id in range(ground_truth.shape[0])]
if not np.all(all_classes_are_binary):
# For 'ground_truth', the expected C dimension is (Background Channel) + (Provided Ground Truth Channels)
# We can resolve the number of provided channels by subtracting the number of ground truth channels that were
# not provided from the number of classes
assert ground_truth is not None
assert ground_truth.shape[0] >= 2
num_classes_including_background = len(ground_truth_ids) + 1
if len(ground_truth_ids) - missing_labels.count(True) != (ground_truth.shape[0] - 1):
raise ValueError(f"Received {len(ground_truth_ids) - missing_labels.count(True)} foreground class names, but "
f"the label tensor indicates that there are {ground_truth.shape[0] - 1} classes.")
binaries = binaries_from_multi_label_array(segmentation, num_classes_including_background)

# Note that: i) len(binary_classes) >= 2, since we count the background class and at least one ground truth class,
# ii) len(binary_classes) <= num_classes_including_background
binary_classes = [is_binary_array(ground_truth[label_id]) for label_id in range(ground_truth.shape[0])]

# Validates that all binary images should be 0 or 1
if not np.all(binary_classes):
raise ValueError("Ground truth values should be 0 or 1")
overlap_measures_filter = sitk.LabelOverlapMeasuresImageFilter()
hausdorff_distance_filter = sitk.HausdorffDistanceImageFilter()
metrics = MetricsDict(hues=ground_truth_ids)

ground_truth_index_counter = 1
for i, prediction in enumerate(binaries):
# Skips if background image or nan_image
if i == 0:
continue
check_size_matches(prediction, ground_truth[i], arg1_name="prediction", arg2_name="ground_truth")
if not is_binary_array(prediction):
raise ValueError("Predictions values should be 0 or 1")
# simpleitk returns a Dice score of 0 if both ground truth and prediction are all zeros.
# Skips if ground truth channel was not provided
if missing_labels[i-1]:
continue
# We want to be able to fish out those cases, and treat them specially later.
prediction_zero = np.all(prediction == 0)
gt_zero = np.all(ground_truth[i] == 0)
gt_zero = np.all(ground_truth[ground_truth_index_counter] == 0)
dice = mean_surface_distance = hausdorff_distance = math.nan
if not (prediction_zero and gt_zero):
prediction_image = sitk.GetImageFromArray(prediction.astype(np.uint8))
prediction_image.SetSpacing(sitk.VectorDouble(reverse_tuple_float3(voxel_spacing)))
ground_truth_image = sitk.GetImageFromArray(ground_truth[i].astype(np.uint8))
# Use 'ground_truth_index_counter' to index the 'C' dimension
ground_truth_image = sitk.GetImageFromArray(ground_truth[ground_truth_index_counter].astype(np.uint8))
ground_truth_image.SetSpacing(sitk.VectorDouble(reverse_tuple_float3(voxel_spacing)))
overlap_measures_filter.Execute(prediction_image, ground_truth_image)
dice = overlap_measures_filter.GetDiceCoefficient()
Expand All @@ -287,6 +304,7 @@ def add_metric(metric_type: MetricType, value: float) -> None:
add_metric(MetricType.DICE, dice)
add_metric(MetricType.HAUSDORFF_mm, hausdorff_distance)
add_metric(MetricType.MEAN_SURFACE_DIST_mm, mean_surface_distance)
ground_truth_index_counter += 1
return metrics
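The bookkeeping done by `ground_truth_index_counter` can be sketched in isolation. The helper below is hypothetical (it does not appear in the PR) and assumes, as the docstring above describes, that the ground truth array holds the background at index 0 followed only by the channels that were provided.

```python
from typing import List, Optional


def ground_truth_channel_for_class(missing_labels: List[bool]) -> List[Optional[int]]:
    """
    Hypothetical helper: for each foreground class, return the index of its channel in the
    ground truth array, or None if the class was not provided.
    """
    mapping: List[Optional[int]] = []
    next_channel = 1  # channel 0 is the background
    for is_missing in missing_labels:
        if is_missing:
            # A missing class consumes no channel, so the counter does not advance.
            mapping.append(None)
        else:
            mapping.append(next_channel)
            next_channel += 1
    return mapping


# Three foreground classes, the second one has no ground truth.
mapping = ground_truth_channel_for_class([False, True, False])
```

With the second class missing, the third class is read from channel 2 rather than channel 3, which is why the loop cannot index the ground truth with the class index alone.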


Expand Down
20 changes: 15 additions & 5 deletions InnerEye/ML/model_testing.py
Expand Up @@ -229,22 +229,32 @@ def evaluate_model_predictions(process_id: int,
:param results_folder: Path to results folder
:returns [PatientMetadata, list[list]]: Patient metadata and list of computed metrics for each image.
"""

sample = dataset.get_samples_at_index(index=process_id)[0]
assert sample.missing_labels is not None
if sample.labels is None:
logging.info(f"Ground truth label were not provided for patient {sample.patient_id}, skipping evaluation from "
f"predictions")
Review comment from a contributor, suggested change:
logging.info(f"Ground truth label were not provided for patient {sample.patient_id}, skipping evaluation from "
f"predictions")
logging.info(f"Ground truth labels were not provided for patient {sample.patient_id}, skipping evaluation.")

return sample.metadata, MetricsDict(hues=config.ground_truth_ids)

logging.info(f"Evaluating predictions for patient {sample.patient_id}")

patient_results_folder = get_patient_results_folder(results_folder, sample.patient_id)
segmentation = load_nifti_image(patient_results_folder / DEFAULT_RESULT_IMAGE_NAME).image
metrics_per_class = metrics.calculate_metrics_per_class(segmentation,
sample.labels,
sample.missing_labels,
ground_truth_ids=config.ground_truth_ids,
voxel_spacing=sample.image_spacing,
patient_id=sample.patient_id)
thumbnails_folder = results_folder / THUMBNAILS_FOLDER
thumbnails_folder.mkdir(exist_ok=True)
plotting.plot_contours_for_all_classes(sample,
segmentation=segmentation,
foreground_class_names=config.ground_truth_ids,
result_folder=thumbnails_folder,
image_range=config.output_range)
if sample.missing_labels.count(True) == 0:
plotting.plot_contours_for_all_classes(sample,
Review comment from a contributor:
Why not draw contours only for the present classes?

segmentation=segmentation,
foreground_class_names=config.ground_truth_ids,
result_folder=thumbnails_folder,
image_range=config.output_range)
return sample.metadata, metrics_per_class
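The reviewer's question about drawing contours only for the present classes would need the list of class names filtered by `missing_labels`. A minimal sketch of that filtering step follows; `present_class_names` is a hypothetical helper, and whether `plot_contours_for_all_classes` would accept a reduced list is not shown in this diff.

```python
from typing import List


def present_class_names(ground_truth_ids: List[str], missing_labels: List[bool]) -> List[str]:
    """Hypothetical helper: keep only the class names whose ground truth was provided."""
    if len(ground_truth_ids) != len(missing_labels):
        raise ValueError("Expected one missing_labels entry per foreground class.")
    return [name for name, is_missing in zip(ground_truth_ids, missing_labels) if not is_missing]


# The class names here are illustrative only.
present = present_class_names(["lung", "heart", "spine"], [False, True, False])
```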


Expand Down