This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Allow cross validation with 'bring your own' Lightning models (without ensemble building) #483

Merged · 10 commits · Jun 24, 2021
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -42,6 +42,7 @@ multiple large checkpoints can time out.

### Added

- ([#483](https://github.com/microsoft/InnerEye-DeepLearning/pull/483)) Allow cross validation with 'bring your own' Lightning models (without ensemble building).
- ([#489](https://github.com/microsoft/InnerEye-DeepLearning/pull/489)) Remove portal query for outliers.
- ([#488](https://github.com/microsoft/InnerEye-DeepLearning/pull/488)) Better handling of missing seriesId in segmentation cross validation reports.
- ([#454](https://github.com/microsoft/InnerEye-DeepLearning/pull/454)) Checking that labels are mutually exclusive.
89 changes: 73 additions & 16 deletions InnerEye/ML/configs/other/HelloContainer.py
@@ -3,7 +3,7 @@
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from pathlib import Path
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
@@ -12,6 +12,7 @@
from torch.optim import Adam, Optimizer
from torch.optim.lr_scheduler import StepLR, _LRScheduler
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import KFold

from InnerEye.Common import fixed_paths
from InnerEye.ML.lightning_container import LightningContainer
@@ -30,16 +31,13 @@ class HelloDataset(Dataset):
# x = torch.rand((N, 1)) * 10
# y = 0.2 * x + 0.1 * torch.randn(x.size())
# xy = torch.cat((x, y), dim=1)
# np.savetxt("InnerEye/ML/configs/other/hellocontainer.csv", xy.numpy(), delimiter=",")
def __init__(self, root_folder: Path, start_index: int, end_index: int) -> None:
# np.savetxt("Tests/ML/test_data/hellocontainer.csv", xy.numpy(), delimiter=",")
def __init__(self, raw_data: List) -> None:
"""
Creates the 1-dim regression dataset.
:param root_folder: The folder in which the data file lives ("hellocontainer.csv")
:param start_index: The first row to read.
:param end_index: The last row to read (exclusive)
:param raw_data: The raw data, e.g. from a cross validation split or loaded from file.
"""
super().__init__()
raw_data = np.loadtxt(str(root_folder / "hellocontainer.csv"), delimiter=",")[start_index:end_index]
super().__init__()
self.data = torch.tensor(raw_data, dtype=torch.float)

def __len__(self) -> int:
@@ -48,17 +46,67 @@ def __len__(self) -> int:
def __getitem__(self, item: int) -> Dict[str, torch.Tensor]:
return {'x': self.data[item][0:1], 'y': self.data[item][1:2]}

@staticmethod
def from_path_and_indexes(
root_folder: Path,
start_index: int,
end_index: int) -> 'HelloDataset':
"""
Static method to instantiate a HelloDataset from the root folder with the start and end indexes.
:param root_folder: The folder in which the data file lives ("hellocontainer.csv")
:param start_index: The first row to read.
:param end_index: The last row to read (exclusive)
:return: A new instance based on the root folder and the start and end indexes.
"""
raw_data = np.loadtxt(root_folder / "hellocontainer.csv", delimiter=",")[start_index:end_index]
return HelloDataset(raw_data)


class HelloDataModule(LightningDataModule):
"""
A data module that provides the training, validation, and test data for a simple 1-dim regression task.
When not using cross validation, a basic 50% / 20% / 30% split between train, validation, and test data
is made on the whole dataset.
For cross validation we use a k-fold scheme: the test set remains unchanged, while the training and
validation data cycle through the k folds of the remaining data.
"""

def __init__(self, root_folder: Path) -> None:
def __init__(
self,
root_folder: Path,
number_of_cross_validation_splits: int = 0,
cross_validation_split_index: int = 0) -> None:
super().__init__()
self.train = HelloDataset(root_folder, start_index=0, end_index=50)
self.val = HelloDataset(root_folder, start_index=50, end_index=70)
self.test = HelloDataset(root_folder, start_index=70, end_index=100)
if number_of_cross_validation_splits <= 1:
# For 0 or 1 splits, just use the default 50% / 20% / 30% split on the whole dataset.
self.train = HelloDataset.from_path_and_indexes(root_folder, start_index=0, end_index=50)
self.val = HelloDataset.from_path_and_indexes(root_folder, start_index=50, end_index=70)
self.test = HelloDataset.from_path_and_indexes(root_folder, start_index=70, end_index=100)
else:
# Raise exceptions for unreasonable values
if cross_validation_split_index >= number_of_cross_validation_splits:
raise IndexError(f"The cross_validation_split_index ({cross_validation_split_index}) is too large "
f"given the number_of_cross_validation_splits ({number_of_cross_validation_splits}) requested")
raw_data = np.loadtxt(root_folder / "hellocontainer.csv", delimiter=",")
np.random.seed(42)
np.random.shuffle(raw_data)
if number_of_cross_validation_splits >= len(raw_data):
raise ValueError(f"Asked for {number_of_cross_validation_splits} cross validation splits from a "
f"dataset of length {len(raw_data)}")
# Hold out the last 30% as test data
self.test = HelloDataset(raw_data[70:100])
# Create k folds from the remaining 70% of the dataset. Use one fold for the validation
# data and the rest for the training data.
raw_data_remaining = raw_data[0:70]
k_fold = KFold(n_splits=number_of_cross_validation_splits)
train_indexes, val_indexes = list(k_fold.split(raw_data_remaining))[cross_validation_split_index]
self.train = HelloDataset(raw_data_remaining[train_indexes])
self.val = HelloDataset(raw_data_remaining[val_indexes])

def prepare_data(self, *args: Any, **kwargs: Any) -> None:
pass

def setup(self, stage: Optional[str] = None) -> None:
pass

def train_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader:
return DataLoader(self.train, batch_size=5)
@@ -135,7 +183,7 @@ def configure_optimizers(self) -> Tuple[List[Optimizer], List[_LRScheduler]]:
This method is part of the standard PyTorch Lightning interface. For an introduction, please see
https://pytorch-lightning.readthedocs.io/en/stable/starter/converting.html
It returns the PyTorch optimizer(s) and learning rate scheduler(s) that should be used for training.
= """
"""
optimizer = Adam(self.parameters(), lr=1e-1)
scheduler = StepLR(optimizer, step_size=20, gamma=0.5)
return [optimizer], [scheduler]
@@ -203,10 +251,19 @@ def create_model(self) -> LightningModule:
return HelloRegression()

# This method must be overridden by any subclass of LightningContainer. It returns a data module, which
# in turn contains 3 data loaders for training, validation, and test set.
#
# If the container is used for cross validation then this method must handle the cross validation splits.
# Because this deals with data loaders, not loaded data, the LightningContainer base class cannot check
# automatically that cross validation is handled correctly: if you forget to do the cross validation split
# in your subclass, nothing will fail, but each child run will be identical, since each will be given the
# full dataset.
def get_data_module(self) -> LightningDataModule:
assert self.local_dataset is not None
return HelloDataModule(root_folder=self.local_dataset) # type: ignore
return HelloDataModule(
root_folder=self.local_dataset,
number_of_cross_validation_splits=self.number_of_cross_validation_splits,
cross_validation_split_index=self.cross_validation_split_index) # type: ignore

# This is an optional override: This report creation method can read out any files that were written during
# training, and cook them into a nice looking report. Here, the report is a simple text file.
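As an aside for readers following the HelloContainer.py changes: the sketch below reproduces the new split logic in isolation. It is not part of the PR; the random array merely stands in for the 100-row hellocontainer.csv, and the fixed seed, 70/30 hold-out, and k-fold selection mirror the code above.

import numpy as np
from sklearn.model_selection import KFold

number_of_cross_validation_splits = 5
cross_validation_split_index = 2

raw_data = np.random.rand(100, 2)  # stand-in for the data loaded from hellocontainer.csv
np.random.seed(42)
np.random.shuffle(raw_data)
test_data = raw_data[70:100]  # last 30% held out, identical in every child run
remaining = raw_data[0:70]
k_fold = KFold(n_splits=number_of_cross_validation_splits)
train_indexes, val_indexes = list(k_fold.split(remaining))[cross_validation_split_index]
train_data, val_data = remaining[train_indexes], remaining[val_indexes]
print(train_data.shape, val_data.shape, test_data.shape)  # (56, 2) (14, 2) (30, 2)

Because the grid sampler (further down) enumerates every split index, the k child runs together use each non-test row exactly once as validation data.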
47 changes: 47 additions & 0 deletions InnerEye/ML/lightning_container.py
@@ -11,8 +11,12 @@
from pytorch_lightning import LightningDataModule, LightningModule
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from azureml.core import ScriptRunConfig
from azureml.train.hyperdrive import GridParameterSampling, HyperDriveConfig, PrimaryMetricGoal, choice

from InnerEye.Azure.azure_util import CROSS_VALIDATION_SPLIT_INDEX_TAG_KEY
from InnerEye.Common.generic_parsing import GenericConfig, create_from_matching_params
from InnerEye.Common.metrics_constants import TrackedMetrics
from InnerEye.ML.common import ModelExecutionMode
from InnerEye.ML.deep_learning_config import DatasetParams, OptimizerParams, OutputParams, TrainerParams, \
WorkflowParams, load_checkpoint
@@ -175,6 +179,9 @@ def get_data_module(self) -> LightningDataModule:
The format of the data is not specified any further.
The method must take cross validation into account, and ensure that the logic to create training and validation
sets correctly handles the given number of cross validation splits.
Because the method deals with data loaders, not loaded data, the base class cannot check automatically that
cross validation is handled correctly: if the split is not handled in the method, nothing will fail, but each
child run will be identical, since each will be given the full dataset.
:return: A LightningDataModule
"""
return None # type: ignore
@@ -200,6 +207,12 @@ def get_trainer_arguments(self) -> Dict[str, Any]:
"""
return dict()

def get_parameter_search_hyperdrive_config(self, _: ScriptRunConfig) -> HyperDriveConfig: # type: ignore
"""
Parameter search is not implemented. It should be implemented in a subclass if needed.
"""
raise NotImplementedError("Parameter search is not implemented. It should be implemented in a subclass if needed.")

def create_report(self) -> None:
"""
This method is called after training and testing has been completed. It can aggregate all files that were
@@ -288,6 +301,40 @@ def create_lightning_module_and_store(self) -> None:
self._model._optimizer_params = create_from_matching_params(self, OptimizerParams)
self._model._trainer_params = create_from_matching_params(self, TrainerParams)

def get_cross_validation_hyperdrive_config(self, run_config: ScriptRunConfig) -> HyperDriveConfig:
"""
Returns a configuration for AzureML Hyperdrive that varies the cross validation split index.
:param run_config: The AzureML run configuration object that defines training for an individual model.
:return: A hyperdrive configuration object.
"""
return HyperDriveConfig(
run_config=run_config,
hyperparameter_sampling=self.get_cross_validation_hyperdrive_sampler(),
primary_metric_name=TrackedMetrics.Val_Loss.value,
primary_metric_goal=PrimaryMetricGoal.MINIMIZE,
max_total_runs=self.number_of_cross_validation_splits
)

def get_hyperdrive_config(self, run_config: ScriptRunConfig) -> HyperDriveConfig:
"""
Returns the HyperDrive config for either parameter search or cross validation
(if number_of_cross_validation_splits > 1).
:param run_config: The AzureML run configuration object.
:return: A HyperDriveConfig object.
"""
if self.perform_cross_validation:
return self.get_cross_validation_hyperdrive_config(run_config)
else:
return self.get_parameter_search_hyperdrive_config(run_config)

def get_cross_validation_hyperdrive_sampler(self) -> GridParameterSampling:
"""
Returns the cross validation sampler, required to sample the entire parameter space for cross validation.
"""
return GridParameterSampling(parameter_space={
CROSS_VALIDATION_SPLIT_INDEX_TAG_KEY: choice(list(range(self.number_of_cross_validation_splits))),
})

def __str__(self) -> str:
"""Returns a string describing the present object, as a list of key: value strings."""
arguments_str = "\nContainer:\n"
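To make the documented contract concrete, here is a minimal sketch of a 'bring your own' container wired up for cross validation. MyDataModule is a hypothetical user-defined data module (along the lines of HelloDataModule above) that performs its own k-fold split; as the docstrings note, the base class cannot verify that the split is actually applied.

from pathlib import Path
from pytorch_lightning import LightningDataModule
from InnerEye.ML.lightning_container import LightningContainer

class MyDataModule(LightningDataModule):
    # Hypothetical stand-in for a user-defined data module that splits its
    # dataset into k folds, as HelloDataModule does.
    def __init__(self, root_folder: Path,
                 number_of_cross_validation_splits: int = 0,
                 cross_validation_split_index: int = 0) -> None:
        super().__init__()

class MyContainer(LightningContainer):
    def get_data_module(self) -> LightningDataModule:
        assert self.local_dataset is not None
        # Forward both cross validation parameters: forgetting to do so fails
        # silently, with every child run training on the full dataset.
        return MyDataModule(
            root_folder=self.local_dataset,
            number_of_cross_validation_splits=self.number_of_cross_validation_splits,
            cross_validation_split_index=self.cross_validation_split_index)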
9 changes: 1 addition & 8 deletions InnerEye/ML/model_config_base.py
@@ -150,13 +150,6 @@ def create_model(self) -> Any:
# because this would prevent us from easily instantiating this class in tests.
raise NotImplementedError("create_model must be overridden")

def get_total_number_of_cross_validation_runs(self) -> int:
"""
Returns the total number of HyperDrive/offline runs required to sample the entire
cross validation parameter space.
"""
return self.number_of_cross_validation_splits

def get_cross_validation_hyperdrive_sampler(self) -> GridParameterSampling:
"""
Returns the cross validation sampler, required to sample the entire parameter space for cross validation.
@@ -176,7 +169,7 @@ def get_cross_validation_hyperdrive_config(self, run_config: ScriptRunConfig) ->
hyperparameter_sampling=self.get_cross_validation_hyperdrive_sampler(),
primary_metric_name=TrackedMetrics.Val_Loss.value,
primary_metric_goal=PrimaryMetricGoal.MINIMIZE,
max_total_runs=self.get_total_number_of_cross_validation_runs()
max_total_runs=self.number_of_cross_validation_splits
)

def get_cross_validation_dataset_splits(self, dataset_split: DatasetSplits) -> DatasetSplits:
2 changes: 1 addition & 1 deletion InnerEye/ML/run_ml.py
@@ -802,7 +802,7 @@ def are_sibling_runs_finished(self) -> bool:
"""
if (not self.is_offline_run) \
and (azure_util.is_cross_validation_child_run(RUN_CONTEXT)):
n_splits = self.innereye_config.get_total_number_of_cross_validation_runs()
n_splits = self.innereye_config.number_of_cross_validation_splits
child_runs = azure_util.fetch_child_runs(PARENT_RUN_CONTEXT,
expected_number_cross_validation_splits=n_splits)
pending_runs = [x.id for x in child_runs
8 changes: 3 additions & 5 deletions InnerEye/ML/runner.py
@@ -192,8 +192,6 @@ def run(self) -> Tuple[Optional[DeepLearningConfig], Optional[Run]]:
user_agent.append(azure_util.INNEREYE_SDK_NAME, azure_util.INNEREYE_SDK_VERSION)
self.parse_and_load_model()
if self.lightning_container.perform_cross_validation:
if self.model_config is None:
raise NotImplementedError("Cross validation for LightingContainer models is not yet supported.")
# force hyperdrive usage if performing cross validation
self.azure_config.hyperdrive = True
run_object: Optional[Run] = None
@@ -219,14 +217,14 @@ def submit_to_azureml(self) -> Run:
if isinstance(self.model_config, DeepLearningConfig) and not self.lightning_container.azure_dataset_id:
raise ValueError("When running an InnerEye built-in model in AzureML, the 'azure_dataset_id' "
"property must be set.")
hyperdrive_func = lambda run_config: self.model_config.get_hyperdrive_config(run_config) # type: ignore
source_config = SourceConfig(
root_folder=self.project_root,
entry_script=Path(sys.argv[0]).resolve(),
conda_dependencies_files=get_all_environment_files(self.project_root),
hyperdrive_config_func=hyperdrive_func,
hyperdrive_config_func=(self.model_config.get_hyperdrive_config if self.model_config
else self.lightning_container.get_hyperdrive_config),
# For large jobs, upload of results can time out because of large checkpoint files. Default is 600
upload_timeout_seconds=86400,
upload_timeout_seconds=86400
)
source_config.set_script_params_except_submit_flag()
# Reduce the size of the snapshot by adding unused folders to amlignore. The Test* subfolders are only needed
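For intuition about what submission now does for a Lightning container: when perform_cross_validation is set, the HyperDrive config built above expands into one child run per split index. The sketch below shows just that expansion; the literal tag key is an assumption here (the code uses the CROSS_VALIDATION_SPLIT_INDEX_TAG_KEY constant from InnerEye.Azure.azure_util).

from azureml.train.hyperdrive import GridParameterSampling, choice

number_of_cross_validation_splits = 5
sampler = GridParameterSampling(parameter_space={
    # assumed literal value of CROSS_VALIDATION_SPLIT_INDEX_TAG_KEY
    "cross_validation_split_index": choice(list(range(number_of_cross_validation_splits))),
})
# HyperDrive turns this grid into five child runs, each started with its own
# split index; the container's get_data_module() then selects the matching fold.

Each child run otherwise shares the same ScriptRunConfig, so the runs differ only in which fold serves as validation data.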
2 changes: 1 addition & 1 deletion InnerEye/ML/visualizers/plot_cross_validation.py
@@ -440,7 +440,7 @@ def crossval_config_from_model_config(train_config: DeepLearningConfig) -> PlotC
model_category=train_config.model_category,
epoch=epoch,
should_validate=False,
number_of_cross_validation_splits=train_config.get_total_number_of_cross_validation_runs())
number_of_cross_validation_splits=train_config.number_of_cross_validation_splits)


def get_config_and_results_for_offline_runs(train_config: DeepLearningConfig) -> OfflineCrossvalConfigAndFiles: