[AIR/Train] Add Trainer.restore API for train experiment-level fault tolerance
#31920
Changes from 55 commits
@@ -2,11 +2,21 @@
import copy
import inspect
import logging
import os
from pathlib import Path
import tempfile
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union

import ray
import ray.cloudpickle as pickle
from ray.air._internal.config import ensure_only_allowed_dataclass_keys_updated
from ray.air._internal.remote_storage import (
    download_from_uri,
    is_non_local_path_uri,
    list_at_uri,
)
from ray.air.checkpoint import Checkpoint
from ray.air import session
from ray.air.config import RunConfig, ScalingConfig
from ray.air.result import Result
from ray.train.constants import TRAIN_DATASET_KEY
@@ -20,6 +30,8 @@
    from ray.tune import Trainable

_TRAINER_PKL = "trainer.pkl"

# A type representing either a ray.data.Dataset or a function that returns a
# ray.data.Dataset and accepts no arguments.
GenDataset = Union["Dataset", Callable[[], "Dataset"]]
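As context for the GenDataset alias above, a minimal sketch (the dataset contents are made up) of the two accepted forms: a materialized ray.data.Dataset, or a zero-argument callable that builds one.

    import ray

    # Form 1: a materialized ray.data.Dataset.
    train_ds = ray.data.from_items([{"a": i} for i in range(10)])

    # Form 2: a zero-argument callable that returns a ray.data.Dataset,
    # deferring dataset construction until it is actually needed.
    def make_train_ds():
        return ray.data.from_items([{"a": i} for i in range(10)])

    # Both values satisfy GenDataset = Union["Dataset", Callable[[], "Dataset"]].
    datasets_a = {"train": train_ds}
    datasets_b = {"train": make_train_ds}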
@@ -164,8 +176,169 @@ def __init__(
        self.preprocessor = preprocessor
        self.resume_from_checkpoint = resume_from_checkpoint

        # This path should only be set through restore
        self._restore_path = None

        self._validate_attributes()
    @classmethod
    def restore(
        cls: Type["BaseTrainer"],
        path: str,
        datasets: Optional[Dict[str, GenDataset]] = None,
        preprocessor: Optional["Preprocessor"] = None,
        scaling_config: Optional[ScalingConfig] = None,
        **kwargs,
    ) -> "BaseTrainer":
        """Restores a Train experiment from a previously interrupted/failed run.

        Restore should be used for experiment-level fault tolerance in the event
        that the head node crashes (e.g., OOM or some other runtime error) or the
        entire cluster goes down (e.g., a network error affecting all nodes).

        The following example can be paired with implementing job retry using
        :ref:`Ray Jobs <jobs-overview>` to produce a Train experiment that will
        attempt to resume on both experiment-level and trial-level failures:

        .. code-block:: python

            import os

            import ray
            from ray import air, tune
            from ray.train.torch import TorchTrainer

            experiment_name = "unique_experiment_name"
            upload_dir = "s3://bucket"
            experiment_dir = os.path.join(upload_dir, experiment_name)

            # Pretend this is a large object that's been loaded
            large_data = {}
            # Use an object reference to share this object across the cluster
            large_data_ref = ray.put(large_data)

            datasets = {"train": ray.data.from_items([{"a": i} for i in range(10)])}

            def train_loop_per_worker(config):
                pass

            train_loop_config = {"obj_ref": large_data_ref}

            if TorchTrainer.can_restore(experiment_dir):
                trainer = TorchTrainer.restore(
                    experiment_dir,
                    train_loop_per_worker=train_loop_per_worker,
                    train_loop_config=train_loop_config,
                    datasets=datasets,
                )
            else:
                trainer = TorchTrainer(
                    train_loop_per_worker=train_loop_per_worker,
                    train_loop_config=train_loop_config,
                    datasets=datasets,
                    run_config=air.RunConfig(
                        name=experiment_name,
                        sync_config=tune.SyncConfig(upload_dir=upload_dir),
                        # Tip: Add trial-level fault tolerance on top.
                        failure_config=air.FailureConfig(max_failures=3),
                    ),
                )

            result = trainer.fit()
        Args:
            path: The path to the experiment directory of the training run to restore.
                This can be a local path or a remote URI if the experiment was
                uploaded to the cloud.
            datasets: Re-specified datasets used in the original training run.
                This must include all the datasets that were passed to the
                original trainer constructor.
            preprocessor: Optionally re-specified preprocessor that was passed to
                the original trainer constructor. This should be used to re-supply
                the preprocessor if it is not restorable in a new Ray cluster.
                This preprocessor will be fit at the start, before resuming training.
                If no preprocessor is passed in on restore, then the old preprocessor
                will be loaded from the latest checkpoint and will not be re-fit.
            scaling_config: Optionally re-specified scaling config. This can be
                modified to be different from the original spec.
            **kwargs: Other optionally re-specified arguments, passed in by subclasses.

        Raises:
            ValueError: If not all of the original datasets are re-supplied on restore.

        Returns:
            BaseTrainer: A restored instance of the class that is calling this method.
        """
        assert cls.can_restore(path), (
            f"Invalid restore path: {path}. Make sure that this path exists and "
            "is the experiment directory that results from a call to `trainer.fit()`."
        )
        trainer_state_path = cls._maybe_sync_down_trainer_state(path)
        assert (
            trainer_state_path.exists()
        ), f"Did not find trainer state at {str(trainer_state_path)}."

        with open(trainer_state_path, "rb") as fp:
            original_trainer = pickle.load(fp)
        assert type(original_trainer) == cls, (
            f"Invalid trainer type. Cannot restore a trainer of type "
            f"{type(original_trainer)} with `{cls.__name__}.restore`. "
            f"Use `{type(original_trainer).__name__}.restore` instead."
        )

        # Get the param dict used to initialize the original trainer
        param_dict = original_trainer._param_dict

        original_datasets = original_trainer.datasets or {}
        if original_datasets and not datasets:
            raise ValueError(
                "The following datasets need to be provided again on restore: "
                f"{list(original_datasets.keys())}\n"
                f"Use {cls.__name__}.restore(..., datasets=datasets) "
                "with the datasets that were provided to the original trainer."
            )
        datasets = datasets or {}
        assert set(original_datasets.keys()) == set(datasets.keys()), (
            "The provided datasets don't match the original dataset keys.\n"
            f"  Expected datasets for the keys: {list(original_datasets.keys())}\n"
            f"  Actual datasets provided: {list(datasets.keys())}"
        )
        param_dict["datasets"] = datasets

        # If no preprocessor is re-specified, then it will be set to None
        # here and loaded from the latest checkpoint
        param_dict["preprocessor"] = preprocessor

        if scaling_config:
            param_dict["scaling_config"] = scaling_config

        for param_name, val in kwargs.items():
            # Overwrite the old value if something is passed into restore
            if val is not None:
                param_dict[param_name] = val

        trainer = cls(**param_dict)
        trainer._restore_path = path
        return trainer
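For illustration, a minimal sketch (the experiment URI and worker counts are hypothetical, and depending on the trainer other arguments such as the training loop may also need to be re-specified) of restoring a run while overriding the scaling config, which the docstring above notes may differ from the original spec:

    import ray
    from ray.air import ScalingConfig
    from ray.train.torch import TorchTrainer

    # Datasets must always be re-supplied, matching the original dataset keys.
    datasets = {"train": ray.data.from_items([{"a": i} for i in range(10)])}

    # The restored run can use a different ScalingConfig than the original,
    # e.g. when resuming on a cluster with more (or fewer) resources.
    trainer = TorchTrainer.restore(
        "s3://bucket/unique_experiment_name",
        datasets=datasets,
        scaling_config=ScalingConfig(num_workers=8, use_gpu=True),
    )
    result = trainer.fit()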
    @classmethod
    def can_restore(cls: Type["BaseTrainer"], path: Union[str, Path]) -> bool:
        """Checks whether a given directory contains a restorable Train experiment.

        Args:
            path: The path to the experiment directory of the Train experiment.
                This can be either a local directory (e.g. ~/ray_results/exp_name)
                or a remote URI (e.g. s3://bucket/exp_name).

        Returns:
            bool: Whether or not this path exists and contains the pickled Trainer.
        """
        path = str(path)
        if is_non_local_path_uri(path):
            dir_contents = list_at_uri(path)
        else:
            dir_contents = [] if not os.path.exists(path) else os.listdir(path)
        return bool(dir_contents) and _TRAINER_PKL in dir_contents
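The restore docstring suggests pairing this with job-level retry via Ray Jobs. One hedged way to do that (the dashboard address, the train_script.py entrypoint name, and the retry loop are assumptions, not part of this PR) is to re-submit the training driver until it succeeds, with the driver itself using can_restore()/restore() as in the docstring example:

    import time

    from ray.job_submission import JobStatus, JobSubmissionClient

    client = JobSubmissionClient("http://127.0.0.1:8265")  # assumed dashboard address

    # Re-submit the training driver until it succeeds. `train_script.py` is a
    # hypothetical driver that calls TorchTrainer.can_restore()/restore().
    for attempt in range(3):
        job_id = client.submit_job(
            entrypoint="python train_script.py",
            runtime_env={"working_dir": "."},
        )
        while client.get_job_status(job_id) not in (
            JobStatus.SUCCEEDED,
            JobStatus.FAILED,
            JobStatus.STOPPED,
        ):
            time.sleep(10)
        if client.get_job_status(job_id) == JobStatus.SUCCEEDED:
            break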
    def __repr__(self):
        # A dictionary that maps parameters to their default values.
        default_values: Dict[str, Any] = {
@@ -265,6 +438,22 @@ def _validate_scaling_config(cls, scaling_config: ScalingConfig) -> ScalingConfig
        )
        return scaling_config

    @classmethod
    def _maybe_sync_down_trainer_state(cls, restore_path: str) -> Path:
        """Syncs down trainer state from remote storage, if needed.

        Returns:
            Local path to the synced trainer state file (trainer.pkl).
        """
        if not is_non_local_path_uri(restore_path):
            return Path(os.path.expanduser(restore_path)) / _TRAINER_PKL

        tempdir = Path(tempfile.mkdtemp("tmp_experiment_dir"))

        path = Path(restore_path)
        download_from_uri(str(path / _TRAINER_PKL), str(tempdir / _TRAINER_PKL))
        return tempdir / _TRAINER_PKL
    def setup(self) -> None:
        """Called during fit() to perform initial setup on the Trainer.
@@ -300,7 +489,10 @@ def preprocess_datasets(self) -> None:

         if self.preprocessor:
             train_dataset = self.datasets.get(TRAIN_DATASET_KEY, None)
-            if train_dataset:
+            if train_dataset and self.preprocessor.fit_status in (
+                ray.data.Preprocessor.FitStatus.NOT_FITTED,
+                ray.data.Preprocessor.FitStatus.PARTIALLY_FITTED,
+            ):
                 self.preprocessor.fit(train_dataset)

         # Execute dataset transformations serially for now.

Review comment: for my education, can you share a bit why we have this logic both here and in dataset_spec.py?
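To make the new guard concrete, a minimal self-contained sketch (using a stub preprocessor and a stand-in FitStatus enum rather than Ray's actual classes) of the intended behavior: a preprocessor recovered from a checkpoint is already fitted and is not fit again, while a freshly supplied one still gets fit.

    from enum import Enum


    class FitStatus(Enum):
        NOT_FITTED = 1
        PARTIALLY_FITTED = 2
        FITTED = 3


    class StubPreprocessor:
        def __init__(self, fit_status=FitStatus.NOT_FITTED):
            self.fit_status = fit_status
            self.fit_calls = 0

        def fit(self, dataset):
            self.fit_calls += 1
            self.fit_status = FitStatus.FITTED


    def maybe_fit(preprocessor, train_dataset):
        # Mirrors the guard in preprocess_datasets(): only (partially) unfitted
        # preprocessors are fit before training.
        if train_dataset and preprocessor.fit_status in (
            FitStatus.NOT_FITTED,
            FitStatus.PARTIALLY_FITTED,
        ):
            preprocessor.fit(train_dataset)


    fresh = StubPreprocessor()                      # e.g. passed to a new trainer
    restored = StubPreprocessor(FitStatus.FITTED)   # e.g. loaded from a checkpoint

    maybe_fit(fresh, train_dataset=[1, 2, 3])
    maybe_fit(restored, train_dataset=[1, 2, 3])

    assert fresh.fit_calls == 1       # fit once before training
    assert restored.fit_calls == 0    # not re-fit on restore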
@@ -351,15 +543,34 @@ def fit(self) -> Result:
             TrainingFailedError: If any failures during the execution of
                 ``self.as_trainable()``.
         """
-        from ray.tune.tuner import Tuner
-        from ray.tune.error import TuneError
+        from ray.tune.tuner import Tuner, TunerInternal
+        from ray.tune import TuneError

         trainable = self.as_trainable()
         param_space = self._extract_fields_for_tuner_param_space()

-        tuner = Tuner(
-            trainable=trainable, param_space=param_space, run_config=self.run_config
-        )
+        if self._restore_path:
+            # TODO(justinvyu): Pass in the new trainable + param_space after Jun's PR
+            # This is because some params get propagated to the Tuner and will
+            # overwrite new ones from Trainer.restore.
+            tuner = Tuner.restore(
+                self._restore_path,
+                overwrite_trainable=trainable,
+                resume_unfinished=True,
+                resume_errored=True,
+            )
+        else:
+            tuner = Tuner(
+                trainable=trainable, param_space=param_space, run_config=self.run_config
+            )
+
+        experiment_path = Path(
+            TunerInternal.setup_create_experiment_checkpoint_dir(
+                trainable, self.run_config
+            )
+        )
+        self._save(experiment_path)

         result_grid = tuner.fit()
         assert len(result_grid) == 1
         try:
@@ -370,6 +581,11 @@ def fit(self) -> Result:
            raise TrainingFailedError from e
        return result

    def _save(self, experiment_path: Union[str, Path]):
        experiment_path = Path(experiment_path)
        with open(experiment_path / _TRAINER_PKL, "wb") as fp:
            pickle.dump(self, fp)
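To illustrate the persistence mechanism in isolation: _save() pickles the whole trainer to trainer.pkl in the experiment directory, and restore() later unpickles it and rebuilds the trainer from its _param_dict. A minimal round-trip sketch (DummyTrainer is a stand-in class, not Ray's):

    import os
    import tempfile

    import ray.cloudpickle as pickle


    class DummyTrainer:  # stand-in for a BaseTrainer subclass
        def __init__(self, param_dict):
            self._param_dict = param_dict


    experiment_path = tempfile.mkdtemp()

    # What _save() does, conceptually: pickle the whole trainer object.
    trainer = DummyTrainer({"train_loop_config": {"lr": 1e-3}})
    with open(os.path.join(experiment_path, "trainer.pkl"), "wb") as fp:
        pickle.dump(trainer, fp)

    # What restore() does, conceptually: unpickle it and rebuild from _param_dict.
    with open(os.path.join(experiment_path, "trainer.pkl"), "rb") as fp:
        original_trainer = pickle.load(fp)
    assert original_trainer._param_dict["train_loop_config"]["lr"] == 1e-3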
    def _extract_fields_for_tuner_param_space(self) -> Dict:
        """Extracts fields to be included in `Tuner.param_space`.
@@ -399,16 +615,24 @@ def _generate_trainable_cls(self) -> Type["Trainable"]:

         trainer_cls = self.__class__
         scaling_config = self.scaling_config
+        restored = bool(self._restore_path)

-        def train_func(config, checkpoint_dir=None):
+        def train_func(config):
             # config already contains merged values.
             # Instantiate new Trainer in Trainable.
             trainer = trainer_cls(**config)

-            if checkpoint_dir:
-                trainer.resume_from_checkpoint = Checkpoint.from_directory(
-                    checkpoint_dir
-                )
+            # Get the checkpoint from the Tune session, and use it to initialize
+            # the restored trainer.
+            # This handles recovery from both trial-level and experiment-level
+            # failures.
+            checkpoint = session.get_checkpoint()
+            if checkpoint:
+                trainer.resume_from_checkpoint = checkpoint
+                # Always load the preprocessor from the checkpoint, unless we are
+                # restoring the experiment and have passed in a new preprocessor.
+                if not (restored and trainer.preprocessor):
+                    trainer.preprocessor = checkpoint.get_preprocessor()

             trainer.setup()
             trainer.preprocess_datasets()
Review comment: Is this a controversial change? A loaded preprocessor that's been fitted already shouldn't fit again. This way, I don't need to pass a fit_preprocessor=False flag in.

Reply: Seems reasonable. Can you make this an actual comment for this logic?
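Following up on that request, one possible wording for such a comment (a suggestion only, not part of the PR) on the guard in preprocess_datasets():

    if self.preprocessor:
        train_dataset = self.datasets.get(TRAIN_DATASET_KEY, None)
        # Only fit preprocessors that still need fitting. A preprocessor
        # recovered from a checkpoint (e.g. on Trainer.restore) is already
        # fitted and should not be fit a second time on the train dataset.
        if train_dataset and self.preprocessor.fit_status in (
            ray.data.Preprocessor.FitStatus.NOT_FITTED,
            ray.data.Preprocessor.FitStatus.PARTIALLY_FITTED,
        ):
            self.preprocessor.fit(train_dataset)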