This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Enable store_dataset_sample (#525)
* Enable store_dataset_sample

* Add missing property

* Enable test

* Add changelog
javier-alvarez authored Jul 7, 2021
1 parent 3fc71b8 commit f3446c8
Showing 6 changed files with 46 additions and 38 deletions.
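
For context, `store_dataset_sample` is a flag on segmentation model configs; when enabled, training crops are written to the config's `example_images_folder` as Nifti files (image, label map, prediction) for visual inspection. A minimal sketch of enabling it in code, assuming the InnerEye package is installed; validation is skipped only to keep the illustration short:

    # Hedged sketch: enabling dataset-sample storage on a segmentation config.
    # Any config derived from SegmentationModelBase carries this flag.
    from InnerEye.ML.config import SegmentationModelBase

    config = SegmentationModelBase(should_validate=False)
    config.store_dataset_sample = True  # training crops get written to config.example_images_folder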
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -26,7 +26,7 @@ jobs that run in AzureML.
gets uploaded to AzureML, by skipping all test folders.

### Fixed

- ([#525](https://github.com/microsoft/InnerEye-DeepLearning/pull/525)) Enable --store_dataset_sample
- ([#495](https://github.com/microsoft/InnerEye-DeepLearning/pull/495)) Fix model comparison.
- ([#482](https://github.com/microsoft/InnerEye-DeepLearning/pull/482)) Check bool parameter is either true or false.
- ([#475](https://github.com/microsoft/InnerEye-DeepLearning/pull/475)) Bug in AML SDK meant that we could not train
20 changes: 11 additions & 9 deletions InnerEye/ML/config.py
@@ -292,13 +292,13 @@ class SegmentationModelBase(ModelConfigBase):
"`inside/outside body` information."
"This channel must be present in the dataset")

#: The type of image normalization that should be applied. Must be None, or of type
#: The type of image normalization that should be applied. Must be of type
# :attr:`PhotometricNormalizationMethod`: Unchanged, SimpleNorm, MriWindow , CtWindow, TrimmedNorm
norm_method: PhotometricNormalizationMethod = \
param.ClassSelector(default=PhotometricNormalizationMethod.CtWindow,
class_=PhotometricNormalizationMethod,
instantiate=False,
doc="The type of image normalization that should be applied. Must be one of None, "
doc="The type of image normalization that should be applied. Must be one of "
"Unchanged, SimpleNorm, MriWindow , CtWindow, TrimmedNorm")

#: The Window setting for the :attr:`PhotometricNormalizationMethod.CtWindow` normalization.
@@ -436,9 +436,9 @@ class SegmentationModelBase(ModelConfigBase):
"in the same order as in ground_truth_ids_display_names")

roi_interpreted_types: List[str] = param.List(None, class_=str, bounds=(1, None), instantiate=False,
allow_None=True,
doc="List of str with the ROI interpreted Types. Possible values "
"(None, CTV, ORGAN, EXTERNAL)")
allow_None=True,
doc="List of str with the ROI interpreted Types. Possible values "
"(None, CTV, ORGAN, EXTERNAL)")

interpreter: str = param.String("Default_Interpreter", doc="The interpreter that created the DICOM-RT file")

@@ -482,16 +482,18 @@ class SegmentationModelBase(ModelConfigBase):
"written to the outputs folder.")

#: If true an error is raised in InnerEye.ML.utils.io_util.load_labels_from_dataset_source if the labels are not
#: mutually exclusive. Some loss functions (e.g. SoftDice) may produce results on overlapping labels, but others (e.g.
#: mutually exclusive. Some loss functions (e.g. SoftDice) may produce results on overlapping labels, but others
# (e.g.
#: FocalLoss) will fail with a cryptic error message. Set to false if you are sure that you want to use labels that
#: are not mutually exclusive.
check_exclusive: bool = param.Boolean(True, doc="Raise an error if the segmentation labels are not mutually exclusive.")
check_exclusive: bool = param.Boolean(True,
doc="Raise an error if the segmentation labels are not mutually exclusive.")

allow_incomplete_labels: bool = param.Boolean(
default=False,
doc="If False, the default, then test patient data must include all of the ground truth labels. If true then "
"some test patient data with missing ground truth data is allowed and will be reflected in the patient "
"counts in the metrics and report.")
"some test patient data with missing ground truth data is allowed and will be reflected in the patient "
"counts in the metrics and report.")

def __init__(self, center_size: Optional[TupleInt3] = None,
inference_stride_size: Optional[TupleInt3] = None,
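
The hunks above adjust parameter declarations on `SegmentationModelBase`. As a rough illustration of how a derived config could set the parameters touched here (values are illustrative, not recommendations):

    # Hedged sketch: setting the SegmentationModelBase parameters shown in the diff above.
    from InnerEye.ML.config import PhotometricNormalizationMethod, SegmentationModelBase

    config = SegmentationModelBase(
        should_validate=False,                                # skip config validation for this sketch
        norm_method=PhotometricNormalizationMethod.CtWindow,  # one of Unchanged, SimpleNorm, MriWindow, CtWindow, TrimmedNorm
        check_exclusive=True,                                 # raise if segmentation labels overlap
        allow_incomplete_labels=False)                        # require all ground truth labels for test patients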
21 changes: 16 additions & 5 deletions InnerEye/ML/lightning_models.py
@@ -25,6 +25,7 @@
from InnerEye.ML.scalar_config import ScalarModelBase
from InnerEye.ML.sequence_config import SequenceModelBase
from InnerEye.ML.utils import image_util, metrics_util, model_util
from InnerEye.ML.utils.dataset_util import DatasetExample, store_and_upload_example
from InnerEye.ML.utils.model_util import get_scalar_model_inputs_and_labels
from InnerEye.ML.utils.sequence_utils import apply_sequence_model_loss, get_masked_model_outputs_and_labels

@@ -38,6 +39,7 @@ class SegmentationLightning(InnerEyeLightning):

def __init__(self, config: SegmentationModelBase, *args: Any, **kwargs: Any) -> None:
super().__init__(config, *args, **kwargs)
self.config = config
self.model = config.create_model()
self.loss_fn = model_util.create_segmentation_loss_function(config)
self.ground_truth_ids = config.ground_truth_ids
@@ -108,6 +110,8 @@ def compute_metrics(self, cropped_sample: CroppedSample, segmentation: torch.Ten
Computes and stores all metrics coming out of a single training step.
:param cropped_sample: The batched image crops used for training or validation.
:param segmentation: The segmentation that was produced by the model.
:param is_training: If true, the method is called from `training_step`, otherwise it is called from
`validation_step`.
"""
# dice_per_crop_and_class has one row per crop, with background class removed
# Dice NaN means that both ground truth and prediction are empty.
@@ -133,11 +137,18 @@ def compute_metrics(self, cropped_sample: CroppedSample, segmentation: torch.Ten
self.storing_logger.train_diagnostics.append(center_indices)
else:
self.storing_logger.val_diagnostics.append(center_indices)
# if self.train_val_params.in_training_mode:
# # store the sample train patch from this epoch for visualization
# if batch_index == self.example_to_save and self.config.store_dataset_sample:
# _store_dataset_sample(self.config, self.train_val_params.epoch, forward_pass_result,
# cropped_sample)

if is_training and self.config.store_dataset_sample:
# store the sample train patch from this epoch for visualization
# remove batches and channels
dataset_example = DatasetExample(image=cropped_sample.image[0][0].cpu().detach().numpy(),
labels=cropped_sample.labels[0].cpu().detach().numpy(),
prediction=segmentation[0].cpu().detach().numpy(),
header=cropped_sample.metadata[0].image_header, # type: ignore
patient_id=cropped_sample.metadata[0].patient_id, # type: ignore
epoch=self.current_epoch)
store_and_upload_example(dataset_example, self.config)

num_subjects = cropped_sample.image.shape[0]
self.log_on_epoch(name=MetricType.SUBJECT_COUNT,
value=num_subjects,
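
With the flag enabled, the new block above packages the first item of the batch into a `DatasetExample` and hands it to `store_and_upload_example`, which writes three Nifti files named from the patient id and epoch (per `create_file_name` in `dataset_util.py` below). A sketch of the expected output for patient 2 at epoch 1, assuming the default `example_images_folder` under the config's output folder:

    # Hedged sketch: files written per stored sample by store_and_upload_example.
    # Names follow "p<patient_id>_e_<epoch>_<suffix>.nii.gz" (see create_file_name below).
    #   p2_e_1_image.nii.gz        # the cropped input image
    #   p2_e_1_label.nii.gz        # binary masks merged into one multi-label map
    #   p2_e_1_prediction.nii.gz   # the model's segmentation for the crop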
20 changes: 6 additions & 14 deletions InnerEye/ML/utils/dataset_util.py
@@ -8,7 +8,7 @@
from collections import OrderedDict
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional
from typing import Dict, List

import numpy as np
import pandas as pd
@@ -139,23 +139,15 @@ def __post_init__(self) -> None:


def store_and_upload_example(dataset_example: DatasetExample,
args: Optional[SegmentationModelBase] = None,
images_folder: Optional[Path] = None) -> None:
segmentation_config: SegmentationModelBase) -> None:
"""
Stores an example input and output of the network to Nifti files.
:param dataset_example: The dataset example, with image, label and prediction, that should be written.
:param args: configuration information to be used for normalization.
:param images_folder: The folder to which the result Nifti files should be written. If args is not None,
the args.example_images_folder is used instead.
:param segmentation_config: configuration information to be used for normalization and example_images_folder
"""

if images_folder is not None:
folder = images_folder
else:
folder = args.example_images_folder if args else Path("")
if folder != "" and not os.path.exists(folder):
os.mkdir(folder)
folder = segmentation_config.example_images_folder
os.makedirs(folder, exist_ok=True)

def create_file_name(suffix: str) -> str:
fn = "p" + str(dataset_example.patient_id) + "_e_" + str(dataset_example.epoch) + "_" + suffix + ".nii.gz"
@@ -165,7 +157,7 @@ def create_file_name(suffix: str) -> str:
io_util.store_image_as_short_nifti(image=dataset_example.image,
header=dataset_example.header,
file_name=create_file_name(suffix="image"),
args=args)
args=segmentation_config)

# merge multiple binary masks (one per class) into a single multi-label map image
labels = image_util.merge_masks(dataset_example.labels)
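
The simplified signature means callers now pass a segmentation config instead of an explicit output folder. A minimal sketch of the new call pattern, mirroring the updated test in `Tests/ML/utils/test_io_util.py` below; `dataset_example` and `output_folder` are assumed to already exist:

    # Hedged sketch: calling store_and_upload_example with the new signature.
    from InnerEye.ML.config import PhotometricNormalizationMethod, SegmentationModelBase
    from InnerEye.ML.utils.dataset_util import store_and_upload_example

    config = SegmentationModelBase(should_validate=False,
                                   norm_method=PhotometricNormalizationMethod.Unchanged)
    config.set_output_to(output_folder)  # example_images_folder is derived from the output folder
    store_and_upload_example(dataset_example, config)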
10 changes: 5 additions & 5 deletions Tests/ML/test_model_training.py
@@ -48,6 +48,7 @@ def test_get_total_number_of_training_epochs() -> None:
c.recovery_start_epoch = 2
assert c.get_total_number_of_training_epochs() == 8


@pytest.mark.parametrize("image_channels", [["region"], ["random_123"]])
@pytest.mark.parametrize("ground_truth_ids", [["region", "region"], ["region", "other_region"]])
def test_invalid_model_train(test_output_dirs: OutputFolderForTests, image_channels: Any,
@@ -98,7 +99,7 @@ def _mean_list(lists: List[List[float]]) -> List[float]:
train_config.mask_id = None if no_mask_channel else train_config.mask_id
train_config.random_seed = 42
train_config.class_weights = [0.5, 0.25, 0.25]
train_config.store_dataset_sample = True
train_config.store_dataset_sample = no_mask_channel
train_config.recovery_checkpoint_save_interval = 1
train_config.check_exclusive = False

@@ -197,11 +198,10 @@ def assert_all_close(metric: str, expected: List[float], **kwargs: Any) -> None:
model_training_result.get_val_metric(MetricType.SECONDS_PER_BATCH.value)
model_training_result.get_train_metric(MetricType.SECONDS_PER_BATCH.value)

# Issue #372
# # Test for saving of example images
# assert train_config.example_images_folder.is_dir()
# example_files = list(train_config.example_images_folder.rglob("*.*"))
# assert len(example_files) == 3 * 2
assert train_config.example_images_folder.is_dir() if train_config.store_dataset_sample else True
example_files = list(train_config.example_images_folder.rglob("*.*"))
assert len(example_files) == (3 * 2 * 2 if train_config.store_dataset_sample else 0) # images x epochs x patients


def test_create_data_loaders() -> None:
11 changes: 7 additions & 4 deletions Tests/ML/utils/test_io_util.py
@@ -18,6 +18,7 @@

from InnerEye.Common.fixed_paths_for_tests import full_ml_test_data_path
from InnerEye.Common.output_directories import OutputFolderForTests
from InnerEye.ML.config import PhotometricNormalizationMethod, SegmentationModelBase
from InnerEye.ML.dataset.sample import PatientDatasetSource, PatientMetadata
from InnerEye.ML.utils import io_util
from InnerEye.ML.utils.dataset_util import DatasetExample, store_and_upload_example
@@ -191,10 +192,12 @@ def test_save_dataset_example(test_output_dirs: OutputFolderForTests) -> None:
labels=labels)

images_folder = test_output_dirs.root_dir
store_and_upload_example(dataset_sample, images_folder=images_folder)
image_from_disk = io_util.load_nifti_image(os.path.join(images_folder, "p2_e_1_image.nii.gz"))
labels_from_disk = io_util.load_nifti_image(os.path.join(images_folder, "p2_e_1_label.nii.gz"))
prediction_from_disk = io_util.load_nifti_image(os.path.join(images_folder, "p2_e_1_prediction.nii.gz"))
config = SegmentationModelBase(should_validate=False, norm_method=PhotometricNormalizationMethod.Unchanged)
config.set_output_to(images_folder)
store_and_upload_example(dataset_sample, config)
image_from_disk = io_util.load_nifti_image(os.path.join(config.example_images_folder, "p2_e_1_image.nii.gz"))
labels_from_disk = io_util.load_nifti_image(os.path.join(config.example_images_folder, "p2_e_1_label.nii.gz"))
prediction_from_disk = io_util.load_nifti_image(os.path.join(config.example_images_folder, "p2_e_1_prediction.nii.gz"))
assert image_from_disk.header.spacing == spacing
# When no photometric normalization is provided when saving, image is multiplied by 1000.
# It is then rounded to int64, but converted back to float when read back in.
