Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Multi-node checkpoint recovery fix #478

Merged: 23 commits merged on Jun 10, 2021
4 changes: 2 additions & 2 deletions InnerEye/ML/run_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from InnerEye.ML.model_config_base import ModelConfigBase
from InnerEye.ML.model_inference_config import ModelInferenceConfig
from InnerEye.ML.model_testing import model_test
from InnerEye.ML.model_training import create_lightning_trainer, model_train
from InnerEye.ML.model_training import create_lightning_trainer, is_global_rank_zero, model_train
from InnerEye.ML.reports.notebook_report import generate_classification_crossval_notebook, \
generate_classification_multilabel_notebook, generate_classification_notebook, generate_segmentation_notebook, \
get_ipynb_report_name, reports_folder
Expand Down Expand Up @@ -222,7 +222,7 @@ def setup(self, use_mount_or_download_dataset: bool = True) -> None:
azure_config=self.azure_config,
project_root=self.project_root,
run_context=RUN_CONTEXT)
self.checkpoint_handler.download_recovery_checkpoints_or_weights()
self.checkpoint_handler.download_recovery_checkpoints_or_weights(only_return_path=not is_global_rank_zero())

# A lot of the code for the built-in InnerEye models expects the output paths directly in the config files.
if isinstance(self.container, InnerEyeContainer):
Expand Down
11 changes: 8 additions & 3 deletions InnerEye/ML/utils/checkpoint_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,23 +57,28 @@ def download_checkpoints_from_hyperdrive_child_runs(self, hyperdrive_parent_run:
if not path.is_dir():
raise NotADirectoryError(f"Does not exist or is not a directory: {path}")

def download_recovery_checkpoints_or_weights(self) -> None:
def download_recovery_checkpoints_or_weights(self, only_return_path: bool = False) -> None:
"""
Download checkpoints from a run recovery object or from a weights url. Set the checkpoints path based on the
run_recovery_object, weights_url or local_weights_path.
This is called at the start of training.
:param only_return_path: if True, return a RunRecovery object with the path to the checkpoint without actually
downloading the checkpoints. This is useful to avoid duplicating checkpoint download when running on multiple
nodes. If False, return the RunRecovery object and download the checkpoint to disk.
"""
if self.azure_config.run_recovery_id:
run_to_recover = self.azure_config.fetch_run(self.azure_config.run_recovery_id.strip())
self.run_recovery = RunRecovery.download_all_checkpoints_from_run(self.output_params, run_to_recover)
self.run_recovery = RunRecovery.download_all_checkpoints_from_run(self.output_params, run_to_recover,
only_return_path=only_return_path)
else:
self.run_recovery = None

if self.azure_config.pretraining_run_recovery_id is not None:
run_to_recover = self.azure_config.fetch_run(self.azure_config.pretraining_run_recovery_id.strip())
run_recovery_object = RunRecovery.download_all_checkpoints_from_run(self.output_params,
run_to_recover,
EXTRA_RUN_SUBFOLDER)
EXTRA_RUN_SUBFOLDER,
only_return_path=only_return_path)
self.container.extra_downloaded_run_id = run_recovery_object
else:
self.container.extra_downloaded_run_id = None
Expand Down
17 changes: 11 additions & 6 deletions InnerEye/ML/utils/run_recovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,25 +63,30 @@ def download_best_checkpoints_from_child_runs(config: OutputParams, run: Run) ->

@staticmethod
def download_all_checkpoints_from_run(config: OutputParams, run: Run,
                                      subfolder: Optional[str] = None,
                                      only_return_path: bool = False) -> RunRecovery:
    """
    Downloads all checkpoints of the provided run inside the checkpoints folder.

    :param config: Model related configs.
    :param run: Run whose checkpoints should be recovered.
    :param subfolder: optional subfolder name, if provided the checkpoints will be downloaded to
        CHECKPOINT_FOLDER / subfolder. If None, the checkpoints are downloaded to CHECKPOINT_FOLDER
        of the current run.
    :param only_return_path: if True, return a RunRecovery object with the path to the checkpoint without
        actually downloading the checkpoints. This is useful to avoid duplicating checkpoint download when
        running on multiple nodes. If False, download the checkpoints to disk and return the RunRecovery object.
    :return: run recovery information
    :raises ValueError: if the given run has child runs (e.g. a HyperDrive parent run), which this
        method does not support.
    """
    if fetch_child_runs(run):
        raise ValueError(f"AzureML run {run.id} has child runs, this method does not support those.")

    destination_folder = config.checkpoint_folder / subfolder if subfolder else config.checkpoint_folder

    if not only_return_path:
        download_outputs_from_run(
            blobs_path=Path(CHECKPOINT_FOLDER),
            destination=destination_folder,
            run=run
        )
    # Needed because AML is not fast enough to download. The sleep is deliberately unconditional:
    # presumably nodes that skip the download (only_return_path=True) still wait here so that the
    # downloading node's files are on disk before training resumes — TODO confirm.
    time.sleep(60)
    return RunRecovery(checkpoints_roots=[destination_folder])

Expand Down
38 changes: 37 additions & 1 deletion Tests/AfterTraining/test_after_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import sys
from pathlib import Path
from typing import List
from unittest import mock

import numpy as np
import pytest
Expand All @@ -29,7 +30,7 @@
from InnerEye.Common.common_util import CROSSVAL_RESULTS_FOLDER, ENSEMBLE_SPLIT_NAME, get_best_epoch_results_path
from InnerEye.Common.fixed_paths import DEFAULT_AML_LOGS_DIR, DEFAULT_RESULT_IMAGE_NAME, \
DEFAULT_RESULT_ZIP_DICOM_NAME, \
PYTHON_ENVIRONMENT_NAME
PYTHON_ENVIRONMENT_NAME, repository_root_directory
from InnerEye.Common.fixed_paths_for_tests import full_ml_test_data_path
from InnerEye.Common.output_directories import OutputFolderForTests
from InnerEye.Common.spawn_subprocess import spawn_and_monitor_subprocess
Expand All @@ -38,6 +39,7 @@
from InnerEye.ML.deep_learning_config import CHECKPOINT_FOLDER, ModelCategory
from InnerEye.ML.model_inference_config import read_model_inference_config
from InnerEye.ML.reports.notebook_report import get_html_report_name
from InnerEye.ML.runner import main
from InnerEye.ML.utils.config_loader import ModelConfigLoader
from InnerEye.ML.utils.image_util import get_unit_image_header
from InnerEye.ML.utils.io_util import zip_random_dicom_series
Expand Down Expand Up @@ -351,3 +353,37 @@ def test_training_2nodes(test_output_dirs: OutputFolderForTests) -> None:
assert "initializing ddp: GLOBAL_RANK: 1, MEMBER: 2/4" in log0_txt
assert "initializing ddp: GLOBAL_RANK: 2, MEMBER: 3/4" in log1_txt
assert "initializing ddp: GLOBAL_RANK: 3, MEMBER: 4/4" in log1_txt


@pytest.mark.after_training_2node
def test_recovery_on_2_nodes(test_output_dirs: OutputFolderForTests) -> None:
    """
    Submit a 2-node AzureML training job that recovers from the most recent 2-node run, wait for it
    to finish, and then check the per-node driver logs: only one node should download the recovery
    checkpoints, but both nodes must load them.
    """
    recovery_id = str(get_most_recent_run_id(fallback_run_id_for_local_execution=FALLBACK_2NODE_RUN))
    arguments = ["--model", "BasicModel2EpochsMoreData",
                 "--azureml", "True",
                 "--num_nodes", "2",
                 "--run_recovery_id", recovery_id,
                 "--num_epochs", "4",
                 "--wait_for_completion", "True"
                 ]
    runner_script = str(repository_root_directory() / "InnerEye" / "ML" / "runner.py")
    with mock.patch("sys.argv", [runner_script] + arguments):
        main()
    run = get_most_recent_run(fallback_run_id_for_local_execution=FALLBACK_2NODE_RUN)
    assert run.status == RunStatus.COMPLETED
    available_files = run.get_file_names()
    # There are two nodes, so there should be one log file per node. Download both and keep their text.
    log_texts = []
    for rank, log_path in enumerate(["azureml-logs/70_driver_log_0.txt",
                                     "azureml-logs/70_driver_log_1.txt"]):
        assert log_path in available_files, f"Node rank {rank} log file is missing"
        local_copy = test_output_dirs.root_dir / log_path
        run.download_file(log_path, output_file_path=str(local_copy))
        log_texts.append(local_copy.read_text())
    log0_txt, log1_txt = log_texts
    # Only the node at global rank zero downloads the checkpoints...
    assert "Downloading multiple files from run" in log0_txt
    assert "Downloading multiple files from run" not in log1_txt
    # ...but both nodes resume training from the recovered checkpoint.
    assert "Loading checkpoint that was created at (epoch = 2, global_step = 2)" in log0_txt
    assert "Loading checkpoint that was created at (epoch = 2, global_step = 2)" in log1_txt