Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Split validation and test infer config #502

Merged
merged 33 commits into from
Jul 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
6f5f73d
Simply split validation and test infer config
Jun 22, 2021
574c629
Try diff args for hyperdrive runs
Jun 22, 2021
2a45eb6
Merge branch 'main' into jontri/ensemble_infer
Jun 23, 2021
c6f3b80
Add flags for ensemble child inference
Jun 23, 2021
2f31aa1
Don't rewrite the args
Jun 23, 2021
900dae6
Test more run combinations
Jun 24, 2021
ad0f258
Clean up tests
Jun 24, 2021
2642f14
Exercise more cases of test runner
Jun 24, 2021
a59ba2d
Typo
Jun 24, 2021
3476e61
Typo2
Jun 24, 2021
adaef05
Use a dict instead of a tuple
Jun 25, 2021
d72ec8f
Merge with main
Jun 25, 2021
c38974d
Add type check
Jun 25, 2021
dff1863
Check metrics not none
Jun 25, 2021
e70f2ed
Reduce crop sizr
Jun 25, 2021
530187e
Test ensemble calls
Jun 25, 2021
aba326c
Run runner tests on smaller images
Jun 28, 2021
1be4224
Shorten options for inference#
Jun 28, 2021
995e85c
Make default None
Jun 28, 2021
7d01e54
Switch to option2
Jun 28, 2021
38a6e80
Use a dict
Jun 29, 2021
160b5fe
Reformat for flake
Jun 29, 2021
82915e4
try flake8 again
Jun 29, 2021
4007cb7
try flake8 again 2
Jun 29, 2021
ed5cea0
Better types for infer flags and fix pipeline
Jun 29, 2021
aed216d
Use a loop
Jun 29, 2021
5309499
Avoid getattr
Jun 29, 2021
d15c022
Update changelog
Jun 29, 2021
75e5e38
Reorder data loaders
Jun 29, 2021
474ea7c
Add some documentation
Jun 30, 2021
0553ac3
Reduce number of tests
Jun 30, 2021
a6e3961
Switch to the passthrough model for inference tests
Jul 2, 2021
df7b262
Merge branch 'main' into jontri/ensemble_infer
Jul 5, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@ created.
## Upcoming

### Added
- ([#502](https://github.com/microsoft/InnerEye-DeepLearning/pull/502)) More flags for fine control of when to run inference.
JonathanTripp marked this conversation as resolved.
Show resolved Hide resolved
- ([#492](https://github.com/microsoft/InnerEye-DeepLearning/pull/492)) Adding capability for regression tests for test
jobs that run in AzureML.

### Changed
- ([#502](https://github.com/microsoft/InnerEye-DeepLearning/pull/502)) Renamed command line option 'perform_training_set_inference' to 'inference_on_train_set'. Replaced command line option 'perform_validation_and_test_set_inference' with the pair of options 'inference_on_val_set' and 'inference_on_test_set'.
- ([#496](https://github.com/microsoft/InnerEye-DeepLearning/pull/496)) All plots are now saved as PNG, rather than JPG.
- ([#497](https://github.com/microsoft/InnerEye-DeepLearning/pull/497)) Reducing the size of the code snapshot that
gets uploaded to AzureML, by skipping all test folders.
Expand Down
28 changes: 2 additions & 26 deletions InnerEye/Azure/azure_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from azureml.train.hyperdrive import HyperDriveConfig
from git import Repo

from InnerEye.Azure.azure_util import fetch_run, is_offline_run_context
from InnerEye.Azure.azure_util import fetch_run, is_offline_run_context, remove_arg
from InnerEye.Azure.secrets_handling import SecretsHandling, read_all_settings
from InnerEye.Common import fixed_paths
from InnerEye.Common.generic_parsing import GenericConfig
Expand Down Expand Up @@ -324,31 +324,7 @@ def set_script_params_except_submit_flag(self) -> None:
Populates the script_param field of the present object from the arguments in sys.argv, with the exception
of the "azureml" flag.
"""
args = sys.argv[1:]
submit_flag = f"--{AZURECONFIG_SUBMIT_TO_AZUREML}"
retained_args = []
i = 0
while i < len(args):
arg = args[i]
if arg.startswith(submit_flag):
if len(arg) == len(submit_flag):
# The commandline argument is "--azureml", with something possibly following: This can either be
# "--azureml True" or "--azureml --some_other_param"
if i < (len(args) - 1):
# If the next argument starts with a "-" then assume that it does not belong to the --azureml
# flag. If there is no "-", assume it belongs to the --azureml flag, and skip both
if not args[i + 1].startswith("-"):
i = i + 1
elif arg[len(submit_flag)] == "=":
# The commandline argument is "--azureml=True" or "--azureml=False": Continue with next arg
pass
else:
# The argument list contains a flag like "--azureml_foo": Keep that.
retained_args.append(arg)
else:
retained_args.append(arg)
i = i + 1
self.script_params = retained_args
self.script_params = remove_arg(AZURECONFIG_SUBMIT_TO_AZUREML, sys.argv[1:])


@dataclass
Expand Down
40 changes: 40 additions & 0 deletions InnerEye/Azure/azure_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,3 +455,43 @@ def step_up_directories(path: Path) -> Generator[Path, None, None]:
if parent == path:
break
path = parent


def remove_arg(arg: str, args: List[str]) -> List[str]:
"""
Remove an argument from a list of arguments. The argument list is assumed to contain
elements of the form:
"-a", "--arg1", "--arg2", "value2", or "--arg3=value"
If there is an item matching "--arg" then it will be removed from the list.

:param arg: Argument to look for.
:param args: List of arguments to scan.
:return: List of arguments with --arg removed, if present.
"""
arg_opt = f"--{arg}"
no_arg_opt = f"--no-{arg}"
retained_args = []
i = 0
while i < len(args):
arg = args[i]
if arg.startswith(arg_opt):
if len(arg) == len(arg_opt):
# The commandline argument is "--arg", with something possibly following: This can either be
# "--arg_opt value" or "--arg_opt --some_other_param"
if i < (len(args) - 1):
# If the next argument starts with a "-" then assume that it does not belong to the --arg
# argument. If there is no "-", assume it belongs to the --arg_opt argument, and skip both
if not args[i + 1].startswith("-"):
i = i + 1
elif arg[len(arg_opt)] == "=":
# The commandline argument is "--arg=value": Continue with next arg
pass
else:
# The argument list contains an argument like "--arg_other_param": Keep that.
retained_args.append(arg)
elif arg == no_arg_opt:
pass
else:
retained_args.append(arg)
i = i + 1
return retained_args
5 changes: 3 additions & 2 deletions InnerEye/ML/SSL/lightning_containers/ssl_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,9 @@ def setup(self) -> None:
dataset_path=self.local_dataset,
batch_size=self.ssl_training_batch_size)})
self.data_module: InnerEyeDataModuleTypes = self.get_data_module()
self.perform_validation_and_test_set_inference = False
if self.number_of_cross_validation_splits > 1:
self.inference_on_val_set = False
self.inference_on_test_set = False
if self.perform_cross_validation:
raise NotImplementedError("Cross-validation logic is not implemented for this module.")

def _load_config(self) -> None:
Expand Down
82 changes: 69 additions & 13 deletions InnerEye/ML/deep_learning_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,16 @@
from __future__ import annotations

import logging
from enum import Enum, unique
from pathlib import Path
from typing import Any, Dict, List, Optional

import param
from enum import Enum, unique
from pandas import DataFrame
from param import Parameterized
from pathlib import Path
from typing import Any, Dict, List, Optional

from InnerEye.Azure.azure_util import DEFAULT_CROSS_VALIDATION_SPLIT_INDEX, RUN_CONTEXT, is_offline_run_context
from InnerEye.Common import fixed_paths
from InnerEye.Common.common_util import is_windows
from InnerEye.Common.common_util import ModelProcessing, is_windows
from InnerEye.Common.fixed_paths import DEFAULT_AML_UPLOAD_DIR, DEFAULT_LOGS_DIR_NAME
from InnerEye.Common.generic_parsing import GenericConfig
from InnerEye.Common.type_annotations import PathOrString, TupleFloat2
Expand Down Expand Up @@ -199,14 +198,24 @@ class WorkflowParams(param.Parameterized):
cross_validation_split_index: int = param.Integer(DEFAULT_CROSS_VALIDATION_SPLIT_INDEX, bounds=(-1, None),
doc="The index of the cross validation fold this model is "
"associated with when performing k-fold cross validation")
perform_training_set_inference: bool = \
param.Boolean(False,
doc="If True, run full image inference on the training set at the end of training. If False and "
"perform_validation_and_test_set_inference is True (default), only run inference on "
"validation and test set. If both flags are False do not run inference.")
perform_validation_and_test_set_inference: bool = \
param.Boolean(True,
doc="If True (default), run full image inference on validation and test set after training.")
inference_on_train_set: Optional[bool] = \
param.Boolean(None,
JonathanTripp marked this conversation as resolved.
Show resolved Hide resolved
doc="If set, enable/disable full image inference on training set after training.")
inference_on_val_set: Optional[bool] = \
param.Boolean(None,
doc="If set, enable/disable full image inference on validation set after training.")
inference_on_test_set: Optional[bool] = \
param.Boolean(None,
doc="If set, enable/disable full image inference on test set after training.")
ensemble_inference_on_train_set: Optional[bool] = \
param.Boolean(None,
doc="If set, enable/disable full image inference on the training set after ensemble training.")
ensemble_inference_on_val_set: Optional[bool] = \
param.Boolean(None,
doc="If set, enable/disable full image inference on validation set after ensemble training.")
ensemble_inference_on_test_set: Optional[bool] = \
param.Boolean(None,
doc="If set, enable/disable full image inference on test set after ensemble training.")
weights_url: str = param.String(doc="If provided, a url from which weights will be downloaded and used for model "
"initialization.")
local_weights_path: Optional[Path] = param.ClassSelector(class_=Path,
Expand Down Expand Up @@ -254,6 +263,53 @@ def validate(self) -> None:
f"found number_of_cross_validation_splits = {self.number_of_cross_validation_splits} "
f"and cross_validation_split_index={self.cross_validation_split_index}")

""" Defaults for when to run inference in the absence of any command line switches. """
INFERENCE_DEFAULTS: Dict[ModelProcessing, Dict[ModelExecutionMode, bool]] = {
ModelProcessing.DEFAULT: {
ModelExecutionMode.TRAIN: False,
ModelExecutionMode.TEST: True,
ModelExecutionMode.VAL: True,
},
ModelProcessing.ENSEMBLE_CREATION: {
ModelExecutionMode.TRAIN: False,
ModelExecutionMode.TEST: True,
ModelExecutionMode.VAL: False,
}
}

def inference_options(self) -> Dict[ModelProcessing, Dict[ModelExecutionMode, Optional[bool]]]:
"""
Return a mapping from ModelProcesing and ModelExecutionMode to command line switch.

:return: Command line switch for each combination of ModelProcessing and ModelExecutionMode.
"""
return {
ModelProcessing.DEFAULT: {
ModelExecutionMode.TRAIN: self.inference_on_train_set,
ModelExecutionMode.TEST: self.inference_on_test_set,
ModelExecutionMode.VAL: self.inference_on_val_set,
},
ModelProcessing.ENSEMBLE_CREATION: {
ModelExecutionMode.TRAIN: self.ensemble_inference_on_train_set,
ModelExecutionMode.TEST: self.ensemble_inference_on_test_set,
ModelExecutionMode.VAL: self.ensemble_inference_on_val_set,
}
}

def inference_on_set(self, model_proc: ModelProcessing, data_split: ModelExecutionMode) -> bool:
"""
Returns True if inference is required for this model_proc and data_split.

:param model_proc: Whether we are testing an ensemble or single model.
:param data_split: Indicates which of the 3 sets (training, test, or validation) is being processed.
:return: True if inference required.
"""
inference_option = self.inference_options()[model_proc][data_split]
if inference_option is not None:
return inference_option

return WorkflowParams.INFERENCE_DEFAULTS[model_proc][data_split]

@property
def is_offline_run(self) -> bool:
"""
Expand Down
77 changes: 38 additions & 39 deletions InnerEye/ML/run_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing import Any, Callable, Dict, List, Optional, Tuple

import pandas as pd
from pytorch_lightning.core.datamodule import LightningDataModule
import stopit
import torch.multiprocessing
from azureml._restclient.constants import RunStatus
Expand Down Expand Up @@ -120,19 +121,16 @@ def download_dataset(azure_dataset_id: str,
return expected_dataset_path


def log_metrics(val_metrics: Optional[InferenceMetricsForSegmentation],
test_metrics: Optional[InferenceMetricsForSegmentation],
train_metrics: Optional[InferenceMetricsForSegmentation],
def log_metrics(metrics: Dict[ModelExecutionMode, InferenceMetrics],
run_context: Run) -> None:
"""
Log metrics for each split to the provided run, or the current run context if None provided
:param val_metrics: Inference results for the validation split
:param test_metrics: Inference results for the test split
:param train_metrics: Inference results for the train split
:param metrics: Dictionary of inference results for each split.
:param run_context: Run for which to log the metrics to, use the current run context if None provided
"""
for split in [x for x in [val_metrics, test_metrics, train_metrics] if x]:
split.log_metrics(run_context)
for split in metrics.values():
if isinstance(split, InferenceMetricsForSegmentation):
split.log_metrics(run_context)


class MLRunner:
Expand Down Expand Up @@ -390,7 +388,7 @@ def run(self) -> None:

# If this is an cross validation run, and the present run is child run 0, then wait for the sibling
# runs, build the ensemble model, and write a report for that.
if self.container.number_of_cross_validation_splits > 0:
if self.container.perform_cross_validation:
should_wait_for_other_child_runs = (not self.is_offline_run) and \
self.container.cross_validation_split_index == 0
if should_wait_for_other_child_runs:
Expand Down Expand Up @@ -420,10 +418,24 @@ def is_normal_run_or_crossval_child_0(self) -> bool:
"""
Returns True if the present run is a non-crossvalidation run, or child run 0 of a crossvalidation run.
"""
if self.container.number_of_cross_validation_splits > 0:
if self.container.perform_cross_validation:
return self.container.cross_validation_split_index == 0
return True

@staticmethod
def lightning_data_module_dataloaders(data: LightningDataModule) -> Dict[ModelExecutionMode, Callable]:
"""
Given a lightning data module, return a dictionary of dataloader for each model execution mode.

:param data: Lightning data module.
:return: Data loader for each model execution mode.
"""
return {
ModelExecutionMode.TEST: data.test_dataloader,
ModelExecutionMode.VAL: data.val_dataloader,
ModelExecutionMode.TRAIN: data.train_dataloader
}

def run_inference_for_lightning_models(self, checkpoint_paths: List[Path]) -> None:
"""
Run inference on the test set for all models that are specified via a LightningContainer.
Expand All @@ -439,11 +451,10 @@ def run_inference_for_lightning_models(self, checkpoint_paths: List[Path]) -> No
# Read the data modules before changing the working directory, in case the code relies on relative paths
data = self.container.get_inference_data_module()
dataloaders: List[Tuple[DataLoader, ModelExecutionMode]] = []
if self.container.perform_validation_and_test_set_inference:
dataloaders.append((data.test_dataloader(), ModelExecutionMode.TEST)) # type: ignore
dataloaders.append((data.val_dataloader(), ModelExecutionMode.VAL)) # type: ignore
if self.container.perform_training_set_inference:
dataloaders.append((data.train_dataloader(), ModelExecutionMode.TRAIN)) # type: ignore
data_dataloaders = MLRunner.lightning_data_module_dataloaders(data)
for data_split, dataloader in data_dataloaders.items():
if self.container.inference_on_set(ModelProcessing.DEFAULT, data_split):
dataloaders.append((dataloader(), data_split))
checkpoint = load_checkpoint(checkpoint_paths[0], use_gpu=self.container.use_gpu)
lightning_model.load_state_dict(checkpoint['state_dict'])
lightning_model.eval()
Expand Down Expand Up @@ -491,8 +502,8 @@ def run_inference(self, checkpoint_handler: CheckpointHandler,
"""

# run full image inference on existing or newly trained model on the training, and testing set
test_metrics, val_metrics, _ = self.model_inference_train_and_test(checkpoint_handler=checkpoint_handler,
model_proc=model_proc)
self.model_inference_train_and_test(checkpoint_handler=checkpoint_handler,
model_proc=model_proc)

self.try_compare_scores_against_baselines(model_proc)

Expand Down Expand Up @@ -752,37 +763,25 @@ def copy_file(source: Path, destination_file: str) -> None:
def model_inference_train_and_test(self,
checkpoint_handler: CheckpointHandler,
model_proc: ModelProcessing = ModelProcessing.DEFAULT) -> \
Tuple[Optional[InferenceMetrics], Optional[InferenceMetrics], Optional[InferenceMetrics]]:
train_metrics = None
val_metrics = None
test_metrics = None
Dict[ModelExecutionMode, InferenceMetrics]:
metrics: Dict[ModelExecutionMode, InferenceMetrics] = {}

config = self.innereye_config

def run_model_test(data_split: ModelExecutionMode) -> Optional[InferenceMetrics]:
return model_test(config, data_split=data_split, checkpoint_handler=checkpoint_handler, # type: ignore
model_proc=model_proc)

if config.perform_validation_and_test_set_inference:
# perform inference on test set
test_metrics = run_model_test(ModelExecutionMode.TEST)
# perform inference on validation set (not for ensemble as current val is in the training fold
# for at least one of the models).
if model_proc != ModelProcessing.ENSEMBLE_CREATION:
val_metrics = run_model_test(ModelExecutionMode.VAL)

if config.perform_training_set_inference:
# perform inference on training set if required
train_metrics = run_model_test(ModelExecutionMode.TRAIN)
for data_split in ModelExecutionMode:
if self.container.inference_on_set(model_proc, data_split):
opt_metrics = model_test(config, data_split=data_split, checkpoint_handler=checkpoint_handler,
model_proc=model_proc)
if opt_metrics is not None:
metrics[data_split] = opt_metrics

# log the metrics to AzureML experiment if possible. When doing ensemble runs, log to the Hyperdrive parent run,
# so that we get the metrics of child run 0 and the ensemble separated.
if config.is_segmentation_model and not self.is_offline_run:
run_for_logging = PARENT_RUN_CONTEXT if model_proc.ENSEMBLE_CREATION else RUN_CONTEXT
log_metrics(val_metrics=val_metrics, test_metrics=test_metrics, # type: ignore
train_metrics=train_metrics, run_context=run_for_logging) # type: ignore
log_metrics(metrics=metrics, run_context=run_for_logging) # type: ignore

return test_metrics, val_metrics, train_metrics
return metrics

@stopit.threading_timeoutable()
def wait_for_runs_to_finish(self, delay: int = 60) -> None:
Expand Down
2 changes: 1 addition & 1 deletion Tests/ML/configs/lightning_test_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ class DummyContainerWithModel(LightningContainer):

def __init__(self) -> None:
super().__init__()
self.perform_training_set_inference = True
self.inference_on_train_set = True
self.num_epochs = 50
self.l_rate = 1e-1

Expand Down
5 changes: 3 additions & 2 deletions Tests/ML/models/test_scalar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,8 +313,9 @@ def test_run_ml_with_segmentation_model(test_output_dirs: OutputFolderForTests)
# This is for a bug in an earlier version of the code where the wrong execution mode was used to
# compute the expected mask size at training time.
config.test_crop_size = (75, 75, 75)
config.perform_training_set_inference = False
config.perform_validation_and_test_set_inference = True
config.inference_on_train_set = False
config.inference_on_val_set = True
config.inference_on_test_set = True
config.set_output_to(test_output_dirs.root_dir)
azure_config = get_default_azure_config()
azure_config.train = True
Expand Down
Loading