Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Downloading checkpoints from AML if not found on disk #614

Merged
merged 9 commits into from
Dec 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ jobs that run in AzureML.
- ([#589](https://github.com/microsoft/InnerEye-DeepLearning/pull/589)) Add `LightningContainer.update_azure_config()`
hook to enable overriding `AzureConfig` parameters from a container (e.g. `experiment_name`, `cluster`, `num_nodes`).
-([#603](https://github.com/microsoft/InnerEye-DeepLearning/pull/603)) Add histopathology module
-([#614](https://github.com/microsoft/InnerEye-DeepLearning/pull/614)) Checkpoint downloading falls back to looking into AzureML if no checkpoints on disk
-([#613](https://github.com/microsoft/InnerEye-DeepLearning/pull/613)) Add additional tests for histopathology datasets


Expand Down
101 changes: 16 additions & 85 deletions InnerEye/ML/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from __future__ import annotations

import abc
import logging
import re
from datetime import datetime
from enum import Enum, unique
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from typing import Any, Dict, List

DATASET_CSV_FILE_NAME = "dataset.csv"
CHECKPOINT_SUFFIX = ".ckpt"
Expand All @@ -26,6 +24,13 @@
LAST_CHECKPOINT_FILE_NAME = "last"
LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX = LAST_CHECKPOINT_FILE_NAME + CHECKPOINT_SUFFIX

FINAL_MODEL_FOLDER = "final_model"
FINAL_ENSEMBLE_MODEL_FOLDER = "final_ensemble_model"
CHECKPOINT_FOLDER = "checkpoints"
VISUALIZATION_FOLDER = "visualizations"
EXTRA_RUN_SUBFOLDER = "extra_run_id"
ARGS_TXT = "args.txt"


@unique
class ModelExecutionMode(Enum):
Expand Down Expand Up @@ -64,18 +69,14 @@ def get_feature_length(self, column: str) -> int:
raise NotImplementedError("get_feature_length must be implemented by sub classes")


def get_recovery_checkpoint_path(path: Path) -> Path:
def create_unique_timestamp_id() -> str:
"""
Returns the path to the last recovery checkpoint in the given folder or the provided filename. Raises a
FileNotFoundError if no
recovery checkpoint file is present.
:param path: Path to checkpoint folder
Creates a unique string using the current time in UTC, up to seconds precision, with characters that
are suitable for use in filenames. For example, on 31 Dec 2019 at 11:59:59pm UTC, the result would be
2019-12-31T235959Z.
"""
recovery_ckpt_and_epoch = find_recovery_checkpoint_and_epoch(path)
if recovery_ckpt_and_epoch is not None:
return recovery_ckpt_and_epoch[0]
files = list(path.glob("*"))
raise FileNotFoundError(f"No checkpoint files found in {path}. Existing files: {' '.join(p.name for p in files)}")
unique_id = datetime.utcnow().strftime("%Y-%m-%dT%H%M%SZ")
return unique_id


def get_best_checkpoint_path(path: Path) -> Path:
Expand All @@ -84,73 +85,3 @@ def get_best_checkpoint_path(path: Path) -> Path:
:param path to checkpoint folder
"""
return path / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX


def find_all_recovery_checkpoints(path: Path) -> Optional[List[Path]]:
"""
Extracts all file starting with RECOVERY_CHECKPOINT_FILE_NAME in path
:param path:
:return:
"""
all_recovery_files = [f for f in path.glob(RECOVERY_CHECKPOINT_FILE_NAME + "*")]
if len(all_recovery_files) == 0:
return None
return all_recovery_files


PathAndEpoch = Tuple[Path, int]


def extract_latest_checkpoint_and_epoch(available_files: List[Path]) -> PathAndEpoch:
"""
Checkpoints are saved as recovery_epoch={epoch}.ckpt, find the latest ckpt and epoch number.
:param available_files: all available checkpoints
:return: path the checkpoint from latest epoch and epoch number
"""
recovery_epochs = [int(re.findall(r"[\d]+", f.stem)[0]) for f in available_files]
idx_max_epoch = int(np.argmax(recovery_epochs))
return available_files[idx_max_epoch], recovery_epochs[idx_max_epoch]


def find_recovery_checkpoint_and_epoch(path: Path) -> Optional[PathAndEpoch]:
"""
Looks at all the recovery files, extracts the epoch number for all of them and returns the most recent (latest
epoch)
checkpoint path along with the corresponding epoch number. If no recovery checkpoint are found, return None.
:param path: The folder to start searching in.
:return: None if there is no file matching the search pattern, or a Tuple with Path object and integer pointing to
recovery checkpoint path and recovery epoch.
"""
available_checkpoints = find_all_recovery_checkpoints(path)
if available_checkpoints is not None:
return extract_latest_checkpoint_and_epoch(available_checkpoints)
return None


def create_best_checkpoint(path: Path) -> Path:
"""
Creates the best checkpoint file. "Best" is at the moment defined as being the last checkpoint, but could be
based on some defined policy.
The best checkpoint will be renamed to `best_checkpoint.ckpt`.
:param path: The folder that contains all checkpoint files.
"""
logging.debug(f"Files in checkpoint folder: {' '.join(p.name for p in path.glob('*'))}")
last_ckpt = path / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
all_files = f"Existing files: {' '.join(p.name for p in path.glob('*'))}"
if not last_ckpt.is_file():
raise FileNotFoundError(f"Checkpoint file {LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX} not found. {all_files}")
logging.info(f"Using {LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX} as the best checkpoint: Renaming to "
f"{BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX}")
best = path / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
last_ckpt.rename(best)
return best


def create_unique_timestamp_id() -> str:
"""
Creates a unique string using the current time in UTC, up to seconds precision, with characters that
are suitable for use in filenames. For example, on 31 Dec 2019 at 11:59:59pm UTC, the result would be
2019-12-31T235959Z.
"""
unique_id = datetime.utcnow().strftime("%Y-%m-%dT%H%M%SZ")
return unique_id
19 changes: 4 additions & 15 deletions InnerEye/ML/deep_learning_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,9 @@
from InnerEye.Common.fixed_paths import DEFAULT_AML_UPLOAD_DIR, DEFAULT_LOGS_DIR_NAME
from InnerEye.Common.generic_parsing import GenericConfig
from InnerEye.Common.type_annotations import PathOrString, T, TupleFloat2
from InnerEye.ML.common import DATASET_CSV_FILE_NAME, ModelExecutionMode, create_unique_timestamp_id, \
get_best_checkpoint_path, get_recovery_checkpoint_path

# A folder inside of the outputs folder that will contain all information for running the model in inference mode

FINAL_MODEL_FOLDER = "final_model"
FINAL_ENSEMBLE_MODEL_FOLDER = "final_ensemble_model"

# The checkpoints must be stored inside of the final model folder, if we want to avoid copying
# them before registration.
CHECKPOINT_FOLDER = "checkpoints"
VISUALIZATION_FOLDER = "visualizations"
EXTRA_RUN_SUBFOLDER = "extra_run_id"

ARGS_TXT = "args.txt"
from InnerEye.ML.common import CHECKPOINT_FOLDER, DATASET_CSV_FILE_NAME, \
ModelExecutionMode, VISUALIZATION_FOLDER, \
create_unique_timestamp_id, get_best_checkpoint_path


@unique
Expand Down Expand Up @@ -487,6 +475,7 @@ def get_path_to_checkpoint(self) -> Path:
"""
Returns the full path to a recovery checkpoint.
"""
from InnerEye.ML.utils.checkpoint_handling import get_recovery_checkpoint_path
return get_recovery_checkpoint_path(self.checkpoint_folder)

def get_path_to_best_checkpoint(self) -> Path:
Expand Down
4 changes: 2 additions & 2 deletions InnerEye/ML/lightning_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
WorkflowParams
from InnerEye.ML.utils import model_util
from InnerEye.ML.utils.lr_scheduler import SchedulerWithWarmUp
from InnerEye.ML.utils.run_recovery import RunRecovery


class InnerEyeInference(abc.ABC):
Expand Down Expand Up @@ -151,7 +150,8 @@ def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
self._model: Optional[LightningModule] = None
self._model_name = type(self).__name__
self.pretraining_run_checkpoints: Optional[RunRecovery] = None
# This should be typed RunRecovery, but causes circular imports
self.pretraining_run_checkpoints: Optional[Any] = None
self.num_nodes = 1

def validate(self) -> None:
Expand Down
4 changes: 2 additions & 2 deletions InnerEye/ML/model_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from InnerEye.Azure.azure_util import RUN_CONTEXT, is_offline_run_context
from InnerEye.Common.common_util import SUBJECT_METRICS_FILE_NAME, change_working_directory
from InnerEye.Common.resource_monitor import ResourceMonitor
from InnerEye.ML.common import ModelExecutionMode, RECOVERY_CHECKPOINT_FILE_NAME, create_best_checkpoint
from InnerEye.ML.deep_learning_config import ARGS_TXT, VISUALIZATION_FOLDER
from InnerEye.ML.common import ARGS_TXT, ModelExecutionMode, RECOVERY_CHECKPOINT_FILE_NAME, VISUALIZATION_FOLDER
from InnerEye.ML.utils.checkpoint_handling import create_best_checkpoint
from InnerEye.ML.lightning_base import InnerEyeContainer, InnerEyeLightning
from InnerEye.ML.lightning_container import LightningContainer
from InnerEye.ML.lightning_loggers import StoringLogger
Expand Down
3 changes: 1 addition & 2 deletions InnerEye/ML/normalize_and_visualize_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@
from InnerEye.Common.common_util import logging_to_stdout
from InnerEye.Common.generic_parsing import GenericConfig
from InnerEye.ML import plotting
from InnerEye.ML.common import DATASET_CSV_FILE_NAME
from InnerEye.ML.common import ARGS_TXT, DATASET_CSV_FILE_NAME
from InnerEye.ML.config import SegmentationModelBase
from InnerEye.ML.dataset.full_image_dataset import load_dataset_sources
from InnerEye.ML.deep_learning_config import ARGS_TXT
from InnerEye.ML.photometric_normalization import PhotometricNormalization
from InnerEye.ML.utils.config_loader import ModelConfigLoader
from InnerEye.ML.utils.io_util import load_images_from_dataset_source
Expand Down
19 changes: 10 additions & 9 deletions InnerEye/ML/run_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,12 @@
from InnerEye.Common.fixed_paths import INNEREYE_PACKAGE_NAME, PYTHON_ENVIRONMENT_NAME
from InnerEye.Common.type_annotations import PathOrString
from InnerEye.ML.baselines_util import compare_folders_and_run_outputs
from InnerEye.ML.common import ModelExecutionMode
from InnerEye.ML.common import CHECKPOINT_FOLDER, EXTRA_RUN_SUBFOLDER, FINAL_ENSEMBLE_MODEL_FOLDER, \
FINAL_MODEL_FOLDER, \
ModelExecutionMode
from InnerEye.ML.config import SegmentationModelBase
from InnerEye.ML.deep_learning_config import CHECKPOINT_FOLDER, DeepLearningConfig, EXTRA_RUN_SUBFOLDER, \
FINAL_ENSEMBLE_MODEL_FOLDER, FINAL_MODEL_FOLDER, ModelCategory, MultiprocessingStartMethod, load_checkpoint
from InnerEye.ML.deep_learning_config import DeepLearningConfig, ModelCategory, MultiprocessingStartMethod, \
load_checkpoint
from InnerEye.ML.lightning_base import InnerEyeContainer
from InnerEye.ML.lightning_container import InnerEyeInference, LightningContainer
from InnerEye.ML.lightning_loggers import StoringLogger
Expand All @@ -53,8 +55,7 @@
get_ipynb_report_name, reports_folder
from InnerEye.ML.scalar_config import ScalarModelBase
from InnerEye.ML.sequence_config import SequenceModelBase
from InnerEye.ML.utils.checkpoint_handling import CheckpointHandler
from InnerEye.ML.utils.run_recovery import RunRecovery
from InnerEye.ML.utils.checkpoint_handling import CheckpointHandler, download_all_checkpoints_from_run
from InnerEye.ML.visualizers import activation_maps
from InnerEye.ML.visualizers.plot_cross_validation import \
get_config_and_results_for_offline_runs, plot_cross_validation_from_files
Expand Down Expand Up @@ -183,10 +184,10 @@ def setup(self, azure_run_info: Optional[AzureRunInfo] = None) -> None:
if self.container.pretraining_run_recovery_id is not None:
run_to_recover = self.azure_config.fetch_run(self.container.pretraining_run_recovery_id.strip())
only_return_path = not is_global_rank_zero()
run_recovery_object = RunRecovery.download_all_checkpoints_from_run(self.container,
run_to_recover,
EXTRA_RUN_SUBFOLDER,
only_return_path=only_return_path)
run_recovery_object = download_all_checkpoints_from_run(self.container,
run_to_recover,
EXTRA_RUN_SUBFOLDER,
only_return_path=only_return_path)
self.container.pretraining_run_checkpoints = run_recovery_object

# A lot of the code for the built-in InnerEye models expects the output paths directly in the config files.
Expand Down
Loading