
[AIR/Train] Add Trainer.restore API for train experiment-level fault tolerance #31920

Merged — 94 commits merged on Feb 17, 2023.
Changes shown from 55 commits.

Commits
36b607c
Add API skeleton for trainer restore
justinvyu Jan 24, 2023
864f2b1
Fix some errors (imports, wrong return type)
justinvyu Jan 24, 2023
2af61de
Save trainer.pkl
justinvyu Jan 24, 2023
fd93dae
Add object ref check utility
justinvyu Jan 24, 2023
bbbf5fa
Add `test_trainer_restore`
justinvyu Jan 24, 2023
0e10972
Add preprocessor loading in restore+no new preprocessor case (and als…
justinvyu Jan 24, 2023
ded1d27
Fix typo
justinvyu Jan 24, 2023
d12e155
Change test to check exception type
justinvyu Jan 24, 2023
0a4bec3
Fix training failed error capture
justinvyu Jan 24, 2023
a9f10bb
Fix should fit preprocessor logic for new train run
justinvyu Jan 25, 2023
149863b
Improve check object refs (search for actor handles too)
justinvyu Jan 25, 2023
4ace2c9
Add more unit tests (gbdt, obj ref in train loop/config, obj ref in p…
justinvyu Jan 25, 2023
73d1fde
Remove unused imports
justinvyu Jan 25, 2023
847978e
Fix HF test function if no eval dataset is passed
justinvyu Jan 25, 2023
e481610
Add trainer w/ init tests + fix gbdt tests
justinvyu Jan 25, 2023
cfb3859
Fix lightgbm test
justinvyu Jan 25, 2023
4f3caf1
Add to bazel build file
justinvyu Jan 25, 2023
d180135
Disable mosaic trainer restore functionality
justinvyu Jan 25, 2023
d5ab478
Add config validation
justinvyu Jan 25, 2023
41b2e90
Fix case where datasets = None
justinvyu Jan 25, 2023
5de0b9a
Add test restoring from a different trainer class
justinvyu Jan 25, 2023
56c80e4
Fix gbdt test assertion
justinvyu Jan 25, 2023
8c59952
Add new args to dummy trainer used in ingest tests
justinvyu Jan 25, 2023
202ea18
Merge branch 'master' of https://github.com/ray-project/ray into trai…
justinvyu Jan 25, 2023
1cb225b
Change to specifying optional restore fields
justinvyu Jan 25, 2023
8638d75
Fix for HF special case
justinvyu Jan 25, 2023
484a873
Clean up mosaic restore args
justinvyu Jan 25, 2023
b98ae79
Add error message for invalid restore kwargs
justinvyu Jan 25, 2023
61af0ad
Fix 'should fit preprocessor' logic
justinvyu Jan 25, 2023
396fa43
Fix missing preprocessor import
justinvyu Jan 25, 2023
4bbd79c
Revert "Add to bazel build file"
justinvyu Jan 25, 2023
7721ac8
Add back to build file without formatting
justinvyu Jan 25, 2023
a486c4e
Revert "Fix for HF special case"
justinvyu Jan 26, 2023
30393d2
Revert "Change to specifying optional restore fields"
justinvyu Jan 26, 2023
2ce1c31
Remove validation logic
justinvyu Jan 27, 2023
ecbcd1a
Fix unit tests
justinvyu Jan 27, 2023
5c03050
Merge branch 'master' of https://github.com/ray-project/ray into trai…
justinvyu Jan 27, 2023
74d99cc
Simplify skipping preprocessor fit logic
justinvyu Jan 27, 2023
3ac12e2
Merge branch 'master' of https://github.com/ray-project/ray into trai…
justinvyu Jan 30, 2023
964ebe5
Add BaseTrainer.can_restore
justinvyu Jan 30, 2023
f55f5cd
Always save the trainer pkl, even on restore
justinvyu Jan 30, 2023
d885f15
Fill in restore error messages
justinvyu Jan 30, 2023
b070b01
Add tests for can restore utility and restoring from invalid dir
justinvyu Jan 30, 2023
e2ff5a9
Add docstrings for tests
justinvyu Jan 30, 2023
9adcd38
New way of doing restore where param_dict actually gets updated
justinvyu Jan 31, 2023
1f9db34
Update tests to actually catch param dict not being updated
justinvyu Jan 31, 2023
b3d8d0e
Fix loading logic for restored + re-specified preprocessor
justinvyu Jan 31, 2023
0433128
Merge branch 'master' of https://github.com/ray-project/ray into trai…
justinvyu Jan 31, 2023
ea122ad
Add restore docstrings
justinvyu Jan 31, 2023
8abe4a7
Revert "Fix training failed error capture"
justinvyu Jan 31, 2023
60cff3f
Fix tests after reverting error handling change
justinvyu Jan 31, 2023
62aed3e
Remove duplicate api ref
justinvyu Jan 31, 2023
3945c70
Fix preprocessor not found error
justinvyu Jan 31, 2023
f7c341a
Fix test failures
justinvyu Jan 31, 2023
72cf533
Merge branch 'master' of https://github.com/ray-project/ray into trai…
justinvyu Jan 31, 2023
28913f3
Update to raise value errors instead of asserts
justinvyu Jan 31, 2023
5080cb3
Update test to expect value errors
justinvyu Jan 31, 2023
543b2a8
Fix lint
justinvyu Jan 31, 2023
8fabbb7
Expand user path in can restore utility
justinvyu Feb 1, 2023
c5eeb53
Explicit ray cluster shutdown in tests
justinvyu Feb 1, 2023
1d1e9ba
Fix typo (fit_status -> fit_status())
justinvyu Feb 1, 2023
ca2a476
Add a comment about BaseTrainer._save
justinvyu Feb 1, 2023
c2e24b0
Merge branch 'master' of https://github.com/ray-project/ray into trai…
justinvyu Feb 1, 2023
5a75eaa
Apply suggestions from code review
justinvyu Feb 1, 2023
63a7dc3
Add rl trainer restore test
justinvyu Feb 1, 2023
1ed3fc2
Merge branch 'master' of https://github.com/ray-project/ray into trai…
justinvyu Feb 1, 2023
9c0d8a8
Update trainable param name in tuner restore
justinvyu Feb 1, 2023
43153db
Merge branch 'train/restore' of https://github.com/justinvyu/ray into…
justinvyu Feb 1, 2023
3f17166
Make can restore consistent with the other pr
justinvyu Feb 1, 2023
1f8132e
Merge branch 'master' of https://github.com/ray-project/ray into trai…
justinvyu Feb 1, 2023
334a69a
Add api stability for session
justinvyu Feb 1, 2023
9f6e04c
Merge branch 'master' of https://github.com/ray-project/ray into trai…
justinvyu Feb 8, 2023
0529739
Merge branch 'master' of https://github.com/ray-project/ray into trai…
justinvyu Feb 15, 2023
410a34f
Re-specify param space in trainer restore + reenable test
justinvyu Feb 15, 2023
d1c2e11
Improve comments
justinvyu Feb 15, 2023
4196ce3
Add trainer restore to API ref
justinvyu Feb 15, 2023
4e1f532
Add FAQ post
justinvyu Feb 15, 2023
1a415bc
Convert BaseTrainer example to be framework agnostic
justinvyu Feb 15, 2023
d865435
Merge branch 'master' of https://github.com/ray-project/ray into trai…
justinvyu Feb 16, 2023
e1b41e8
Improve faq + make example work
justinvyu Feb 16, 2023
859ecf3
Merge branch 'master' of https://github.com/ray-project/ray into trai…
justinvyu Feb 16, 2023
da94da8
Add typing imports
justinvyu Feb 16, 2023
5c5c009
Remove accidentally included files
justinvyu Feb 16, 2023
a5fde62
Remove ipdb (oops)
justinvyu Feb 16, 2023
4b53d64
Fix duplicate HFTrainer.restore ref
justinvyu Feb 16, 2023
726cf81
Merge branch 'master' of https://github.com/ray-project/ray into trai…
justinvyu Feb 16, 2023
ae7568e
Shouldn't check for subclass
justinvyu Feb 16, 2023
8b4a876
Merge branch 'master' of https://github.com/ray-project/ray into trai…
justinvyu Feb 16, 2023
a7e4336
Convert error to warning instead + try/catch new trainer instantiation
justinvyu Feb 17, 2023
61bb672
Add public api alpha decorators
justinvyu Feb 17, 2023
5a62a5e
Update unit test to check for warning
justinvyu Feb 17, 2023
f47970f
Remove trailing header chars
justinvyu Feb 17, 2023
94058a0
Merge branch 'master' of https://github.com/ray-project/ray into trai…
justinvyu Feb 17, 2023
4589b52
Update docstring example to define trainer subclass inline
justinvyu Feb 17, 2023
12 changes: 12 additions & 0 deletions python/ray/train/BUILD
@@ -568,6 +568,18 @@ py_test(
deps = [":train_lib"]
)

py_test(
name = "test_trainer_restore",
size = "medium",
srcs = ["tests/test_trainer_restore.py"],
tags = [
"exclusive",
"ray_air",
"team:ml",
],
deps = [":train_lib"],
)

# This is a dummy test dependency that causes the above tests to be
# re-run if any of these files changes.
py_library(
11 changes: 8 additions & 3 deletions python/ray/train/_internal/dataset_spec.py
@@ -5,12 +5,12 @@
from ray.air.config import DatasetConfig

from ray.data import Dataset, DatasetPipeline
from ray.data.preprocessor import Preprocessor
from ray.data.preprocessors import Chain
from ray.air._internal.util import _estimate_avail_object_store_memory

if TYPE_CHECKING:
from ray.data import DatasetIterator
from ray.data.preprocessor import Preprocessor

RayDataset = Union["Dataset", "DatasetPipeline"]

@@ -113,7 +113,9 @@ def __init__(self, dataset_config: Dict[str, DatasetConfig]):
self.preprocessor: Optional["Preprocessor"] = None

def preprocess_datasets(
self, prep: "Preprocessor", datasets: Dict[str, "Dataset"]
self,
prep: "Preprocessor",
datasets: Dict[str, "Dataset"],
) -> Dict[str, "Dataset"]:
"""Preprocess the given datasets.

@@ -142,7 +144,10 @@ def preprocess_datasets(
continue
if conf.fit:
ds_to_fit = datasets[k]
if ds_to_fit:
if ds_to_fit and prep.fit_status() in (
Preprocessor.FitStatus.NOT_FITTED,
Preprocessor.FitStatus.PARTIALLY_FITTED,
):
Comment on lines +147 to +150

@justinvyu (Contributor, Author) Jan 27, 2023:
Is this a controversial change? A loaded preprocessor that's been fitted already shouldn't fit again. This way, I don't need to pass a fit_preprocessor=False flag in.

Reply (Member):
seems reasonable. can you make this an actual comment for this logic?
prep.fit(ds_to_fit)
new_datasets = {}

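The fit-status gating discussed in the thread above can be sketched in isolation. The `Preprocessor` and `FitStatus` classes below are minimal stand-ins for the Ray AIR ones, not the real implementation; only the skip-refit check mirrors the PR's logic:

```python
import enum


class FitStatus(enum.Enum):
    NOT_FITTED = 1
    PARTIALLY_FITTED = 2
    FITTED = 3


class Preprocessor:
    """Minimal stand-in for ray.data.preprocessor.Preprocessor."""

    FitStatus = FitStatus

    def __init__(self):
        self._fitted = False
        self.fit_calls = 0

    def fit_status(self):
        return FitStatus.FITTED if self._fitted else FitStatus.NOT_FITTED

    def fit(self, dataset):
        self.fit_calls += 1
        self._fitted = True


def maybe_fit(prep, dataset):
    # Mirror of the PR's check: only (re)fit when the preprocessor is not
    # already fully fitted, so one restored from a checkpoint is not fit again.
    if dataset is not None and prep.fit_status() in (
        Preprocessor.FitStatus.NOT_FITTED,
        Preprocessor.FitStatus.PARTIALLY_FITTED,
    ):
        prep.fit(dataset)


prep = Preprocessor()
maybe_fit(prep, dataset=[1, 2, 3])  # fits: status was NOT_FITTED
maybe_fit(prep, dataset=[1, 2, 3])  # skipped: already FITTED
print(prep.fit_calls)  # prints 1
```

This is why no `fit_preprocessor=False` flag is needed: the fit status of the loaded preprocessor itself carries that information.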
244 changes: 234 additions & 10 deletions python/ray/train/base_trainer.py
@@ -2,11 +2,21 @@
import copy
import inspect
import logging
import os
from pathlib import Path
import tempfile
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union

import ray
import ray.cloudpickle as pickle
from ray.air._internal.config import ensure_only_allowed_dataclass_keys_updated
from ray.air._internal.remote_storage import (
download_from_uri,
is_non_local_path_uri,
list_at_uri,
)
from ray.air.checkpoint import Checkpoint
from ray.air import session
from ray.air.config import RunConfig, ScalingConfig
from ray.air.result import Result
from ray.train.constants import TRAIN_DATASET_KEY
@@ -20,6 +30,8 @@

from ray.tune import Trainable

_TRAINER_PKL = "trainer.pkl"

# A type representing either a ray.data.Dataset or a function that returns a
# ray.data.Dataset and accepts no arguments.
GenDataset = Union["Dataset", Callable[[], "Dataset"]]
@@ -164,8 +176,169 @@ def __init__(
self.preprocessor = preprocessor
self.resume_from_checkpoint = resume_from_checkpoint

# This path should only be set through restore
self._restore_path = None

self._validate_attributes()

@classmethod
def restore(
cls: Type["BaseTrainer"],
path: str,
datasets: Optional[Dict[str, GenDataset]] = None,
preprocessor: Optional["Preprocessor"] = None,
scaling_config: Optional[ScalingConfig] = None,
**kwargs,
) -> "BaseTrainer":
"""Restores a Train experiment from a previously interrupted/failed run.

Restore should be used for experiment-level fault tolerance in the event
that the head node crashes (e.g. OOM or some other runtime error) or the
entire cluster goes down (e.g. network error affecting all nodes).

The following example can be paired with implementing job retry using
:ref:`Ray Jobs <jobs-overview>` to produce a Train experiment that will
attempt to resume on both experiment-level and trial-level failures:

.. code-block:: python

import os

import ray
from ray import air, tune
from ray.train.torch import TorchTrainer

experiment_name = "unique_experiment_name"
upload_dir = "s3://bucket"
experiment_dir = os.path.join(upload_dir, experiment_name)

# Pretend this is a large object that's been loaded
large_data = {}
# Use an object reference to share this object across the cluster
large_data_ref = ray.put(large_data)

datasets = {"train": ray.data.from_items([{"a": i} for i in range(10)])}

def train_loop_per_worker(config):
pass

train_loop_config = {"obj_ref": large_data_ref}

if TorchTrainer.can_restore(experiment_dir):
trainer = TorchTrainer.restore(
experiment_dir,
train_loop_per_worker=train_loop_per_worker,
train_loop_config=train_loop_config,
datasets=datasets,
)
else:
trainer = TorchTrainer(
train_loop_per_worker,
train_loop_config,
datasets=datasets,
run_config=air.RunConfig(
name=experiment_name,
sync_config=tune.SyncConfig(upload_dir=upload_dir),
# Tip: Add trial-level fault-tolerance on top.
failure_config=air.FailureConfig(max_failures=3),
),
)

result = trainer.fit()


Args:
path: The path to the experiment directory of the training run to restore.
This can be a local path or a remote URI if the experiment was
uploaded to the cloud.
datasets: Re-specified datasets used in the original training run.
This must include all the datasets that were passed in the
original trainer constructor.
preprocessor: Optionally re-specified preprocessor that was passed in
the original trainer constructor. This should be used to re-supply
the preprocessor if it is not restorable in a new Ray cluster.
This preprocessor will be fit at the start before resuming training.
If no preprocessor is passed in restore, then the old preprocessor
will be loaded from the latest checkpoint and will not be re-fit.
scaling_config: Optionally re-specified scaling config. This can be
modified to be different from the original spec.
**kwargs: Other optionally re-specified arguments, passed in by subclasses.

Raises:
ValueError: If all datasets were not re-supplied on restore.

Returns:
BaseTrainer: A restored instance of the class that is calling this method.
"""
assert cls.can_restore(path), (
f"Invalid restore path: {path}. Make sure that this path exists and "
"is the experiment directory that results from a call to `trainer.fit()`."
)
trainer_state_path = cls._maybe_sync_down_trainer_state(path)
assert (
trainer_state_path.exists()
), f"Did not find trainer state at {str(trainer_state_path)}."

with open(trainer_state_path, "rb") as fp:
original_trainer = pickle.load(fp)
assert type(original_trainer) == cls, (
f"Invalid trainer type. Cannot restore a trainer of type "
f"{type(original_trainer)} with `{cls.__name__}.restore`. "
f"Use `{type(original_trainer).__name__}.restore` instead."
)

# Get the param dict used to initialize the original trainer
param_dict = original_trainer._param_dict

original_datasets = original_trainer.datasets or {}
if original_datasets and not datasets:
raise ValueError(
"The following datasets need to be provided again on restore: "
f"{list(original_datasets.keys())}\n"
f"Use {cls.__name__}.restore(..., datasets=datasets) "
"with the datasets that were provided to the original trainer."
)
datasets = datasets or {}
assert set(original_datasets.keys()) == set(datasets.keys()), (
"The provided datasets don't match the original dataset keys.\n"
f" Expected datasets for the keys: {list(original_datasets.keys())}\n"
f" Actual datasets provided: {list(datasets.keys())}"
)
param_dict["datasets"] = datasets

# If no preprocessor is re-specified, then it will be set to None
# here and loaded from the latest checkpoint
param_dict["preprocessor"] = preprocessor

if scaling_config:
param_dict["scaling_config"] = scaling_config

for param_name, val in kwargs.items():
# Overwrite the old value if something is passed into restore
if val is not None:
param_dict[param_name] = val

trainer = cls(**param_dict)
trainer._restore_path = path
return trainer

@classmethod
def can_restore(cls: Type["BaseTrainer"], path: Union[str, Path]) -> bool:
"""Checks whether a given directory contains a restorable Train experiment.

Args:
path: The path to the experiment directory of the Train experiment.
This can be either a local directory (e.g. ~/ray_results/exp_name)
or a remote URI (e.g. s3://bucket/exp_name).

Returns:
bool: Whether or not this path exists and contains the pickled Trainer
"""
path = str(path)
if is_non_local_path_uri(path):
dir_contents = list_at_uri(path)
else:
dir_contents = [] if not os.path.exists(path) else os.listdir(path)
return bool(dir_contents) and _TRAINER_PKL in dir_contents
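The local-path branch of this check can be exercised with the standard library alone. This is a sketch of the same logic with stand-in names; the remote-URI branch (`is_non_local_path_uri` / `list_at_uri`) is omitted here:

```python
import os
import tempfile
from pathlib import Path

_TRAINER_PKL = "trainer.pkl"


def can_restore(path):
    # Mirror of the PR's local check: the directory must exist, be
    # non-empty, and contain the pickled trainer state file.
    path = str(path)
    dir_contents = [] if not os.path.exists(path) else os.listdir(path)
    return bool(dir_contents) and _TRAINER_PKL in dir_contents


exp_dir = Path(tempfile.mkdtemp())
assert not can_restore(exp_dir)            # empty dir: nothing to restore
(exp_dir / _TRAINER_PKL).write_bytes(b"")  # simulate a saved trainer
assert can_restore(exp_dir)                # trainer.pkl present
assert not can_restore(exp_dir / "nope")   # nonexistent dir
```

Checking for `trainer.pkl` specifically (rather than just directory existence) is what lets callers distinguish a Train experiment directory from an arbitrary results folder.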

def __repr__(self):
# A dictionary that maps parameters to their default values.
default_values: Dict[str, Any] = {
@@ -265,6 +438,22 @@ def _validate_scaling_config(cls, scaling_config: ScalingConfig) -> ScalingConfi
)
return scaling_config

@classmethod
def _maybe_sync_down_trainer_state(cls, restore_path: str) -> Path:
"""Sync down trainer state from remote storage.

Returns:
local_dir of the synced trainer state
"""
if not is_non_local_path_uri(restore_path):
return Path(os.path.expanduser(restore_path)) / _TRAINER_PKL

tempdir = Path(tempfile.mkdtemp("tmp_experiment_dir"))

path = Path(restore_path)
download_from_uri(str(path / _TRAINER_PKL), str(tempdir / _TRAINER_PKL))
return tempdir / _TRAINER_PKL

def setup(self) -> None:
"""Called during fit() to perform initial setup on the Trainer.

@@ -300,7 +489,10 @@ def preprocess_datasets(self) -> None:

if self.preprocessor:
train_dataset = self.datasets.get(TRAIN_DATASET_KEY, None)
if train_dataset:
if train_dataset and self.preprocessor.fit_status() in (
ray.data.Preprocessor.FitStatus.NOT_FITTED,
ray.data.Preprocessor.FitStatus.PARTIALLY_FITTED,
):
Comment (Member):
for my education, can you share a bit why we have this logic both here and in dataset_spec.py?

Reply from @justinvyu (Contributor, Author) Jan 31, 2023:
The DataParallelTrainer goes through a separate path for preprocessing datasets. It uses a data parallel ingest spec. Other trainers like xgboost will let the underlying framework handle the data ingest (ex: params can be passed in to the underlying xgboost_ray RayDMatrix object.)

self.preprocessor.fit(train_dataset)

# Execute dataset transformations serially for now.
@@ -351,15 +543,34 @@ def fit(self) -> Result:
TrainingFailedError: If any failures during the execution of
``self.as_trainable()``.
"""
from ray.tune.tuner import Tuner
from ray.tune.error import TuneError
from ray.tune.tuner import Tuner, TunerInternal
from ray.tune import TuneError

trainable = self.as_trainable()
param_space = self._extract_fields_for_tuner_param_space()

tuner = Tuner(
trainable=trainable, param_space=param_space, run_config=self.run_config
if self._restore_path:
# TODO(justinvyu): Pass in the new trainable + param_space after Jun's PR
# This is because some params get propagated to the Tuner and will
# overwrite new ones from Trainer.restore.
tuner = Tuner.restore(
self._restore_path,
overwrite_trainable=trainable,
resume_unfinished=True,
resume_errored=True,
)
else:
tuner = Tuner(
trainable=trainable, param_space=param_space, run_config=self.run_config
)

experiment_path = Path(
TunerInternal.setup_create_experiment_checkpoint_dir(
trainable, self.run_config
)
)
self._save(experiment_path)

result_grid = tuner.fit()
assert len(result_grid) == 1
try:
@@ -370,6 +581,11 @@
raise TrainingFailedError from e
return result

def _save(self, experiment_path: Union[str, Path]):
experiment_path = Path(experiment_path)
with open(experiment_path / _TRAINER_PKL, "wb") as fp:
pickle.dump(self, fp)
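The `trainer.pkl` round-trip that `_save` and `restore` form together can be sketched with plain `pickle` and a hypothetical `ToyTrainer` class (the names and signatures below are illustrative, not the Ray API):

```python
import pickle
import tempfile
from pathlib import Path

_TRAINER_PKL = "trainer.pkl"


class ToyTrainer:
    """Hypothetical stand-in for a BaseTrainer subclass."""

    def __init__(self, train_loop_config=None, datasets=None):
        # The param dict records the constructor arguments, as in the PR.
        self._param_dict = {
            "train_loop_config": train_loop_config,
            "datasets": datasets,
        }
        self.datasets = datasets

    def _save(self, experiment_path):
        # Mirrors BaseTrainer._save: pickle the whole trainer into the
        # experiment directory so restore() can rebuild it later.
        with open(Path(experiment_path) / _TRAINER_PKL, "wb") as fp:
            pickle.dump(self, fp)

    @classmethod
    def restore(cls, path, datasets=None):
        with open(Path(path) / _TRAINER_PKL, "rb") as fp:
            original = pickle.load(fp)
        param_dict = original._param_dict
        if datasets is not None:
            # Datasets are not restorable across clusters, so they are
            # re-specified on restore, just as in the PR.
            param_dict["datasets"] = datasets
        return cls(**param_dict)


exp_dir = tempfile.mkdtemp()
trainer = ToyTrainer(train_loop_config={"lr": 0.1})
trainer._save(exp_dir)
restored = ToyTrainer.restore(exp_dir, datasets={"train": [1, 2, 3]})
print(restored._param_dict["train_loop_config"])  # prints {'lr': 0.1}
```

Re-instantiating from the saved param dict, rather than reusing the unpickled object directly, is what lets restore-time arguments (datasets, preprocessor, scaling config) cleanly override the originals.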

def _extract_fields_for_tuner_param_space(self) -> Dict:
"""Extracts fields to be included in `Tuner.param_space`.

@@ -399,16 +615,24 @@ def _generate_trainable_cls(self) -> Type["Trainable"]:

trainer_cls = self.__class__
scaling_config = self.scaling_config
restored = bool(self._restore_path)

def train_func(config, checkpoint_dir=None):
def train_func(config):
# config already contains merged values.
# Instantiate new Trainer in Trainable.
trainer = trainer_cls(**config)

if checkpoint_dir:
trainer.resume_from_checkpoint = Checkpoint.from_directory(
checkpoint_dir
)
# Get the checkpoint from the Tune session, and use it to initialize
# the restored trainer.
# This handles recovery from both trial-level and experiment-level failures.
checkpoint = session.get_checkpoint()
if checkpoint:
trainer.resume_from_checkpoint = checkpoint
# Always load the preprocessor from checkpoint
# Unless we are restoring the experiment and have passed in a new
# preprocessor
if not (restored and trainer.preprocessor):
trainer.preprocessor = checkpoint.get_preprocessor()

trainer.setup()
trainer.preprocess_datasets()