
[doc][tune] Add tune checkpoint user guide. #33145

Merged
merged 14 commits into from Mar 22, 2023
1 change: 1 addition & 0 deletions doc/source/_toc.yml
@@ -164,6 +164,7 @@ parts:
- file: tune/tutorials/tune-stopping
title: "How to Stop and Resume"
- file: tune/tutorials/tune-storage
- file: tune/tutorials/tune-trial-checkpoint
- file: tune/tutorials/tune-metrics
title: "Using Callbacks and Metrics"
- file: tune/tutorials/tune_get_data_in_and_out
10 changes: 5 additions & 5 deletions doc/source/tune/api/schedulers.rst
@@ -172,7 +172,7 @@ This can be enabled by setting the ``scheduler`` parameter of ``tune.TuneConfig``

When the PBT scheduler is enabled, each trial variant is treated as a member of the population.
Periodically, **top-performing trials are checkpointed**
(this requires your Trainable to support :ref:`save and restore <tune-function-checkpointing>`).
(this requires your Trainable to support :ref:`save and restore <tune-trial-checkpoint>`).
**Low-performing trials clone the hyperparameter configurations of top performers and
perturb them** slightly in the hopes of discovering even better hyperparameter settings.
**Low-performing trials also resume from the checkpoints of the top performers**, allowing
@@ -261,7 +261,7 @@ PB2 can be enabled by setting the ``scheduler`` parameter of ``tune.TuneConfig``

When the PB2 scheduler is enabled, each trial variant is treated as a member of the population.
Periodically, top-performing trials are checkpointed (this requires your Trainable to
support :ref:`save and restore <tune-function-checkpointing>`).
support :ref:`save and restore <tune-trial-checkpoint>`).
Low-performing trials clone the checkpoints of top performers and perturb the configurations
in the hope of discovering an even better variation.
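
For illustration, a minimal sketch of enabling a population-based scheduler (``my_trainable`` is assumed to support save/restore; the interval, metric, and mutation values are example choices, not recommendations from this guide):

.. code-block:: python

    from ray import tune
    from ray.tune.schedulers import PopulationBasedTraining

    pbt = PopulationBasedTraining(
        time_attr="training_iteration",
        perturbation_interval=5,  # example value
        hyperparam_mutations={"lr": tune.loguniform(1e-4, 1e-1)},
    )

    tuner = tune.Tuner(
        my_trainable,
        tune_config=tune.TuneConfig(metric="mean_accuracy", mode="max", scheduler=pbt),
    )
    results = tuner.fit()

PB2 is wired up the same way, except the scheduler class is ``PB2`` (from ``ray.tune.schedulers.pb2``) and it takes ``hyperparam_bounds`` instead of ``hyperparam_mutations``.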

@@ -308,9 +308,9 @@ It wraps around another scheduler and uses its decisions.
which will let your model know about the new resources assigned. You can also obtain the current trial resources
by calling ``Trainable.trial_resources``.

* If you are using the functional API for tuning, the current trial resources can be
obtained by calling `tune.get_trial_resources()` inside the training function.
The function should be able to :ref:`load and save checkpoints <tune-function-checkpointing>`
* If you are using the functional API for tuning, get the current trial resources by calling
  `tune.get_trial_resources()` inside the training function (see the sketch below).
The function should be able to :ref:`load and save checkpoints <tune-function-trainable-checkpointing>`
(the latter preferably every iteration).

An example of this in use can be found here: :doc:`/tune/examples/includes/xgboost_dynamic_resources_example`.
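
For orientation, a minimal sketch of this pattern in a function trainable, using the ``tune.get_trial_resources()`` call referenced above (the metric name, epoch count, and checkpoint contents are placeholders for the example):

.. code-block:: python

    from ray import tune
    from ray.air import session
    from ray.air.checkpoint import Checkpoint


    def train_func(config):
        for step in range(config.get("epochs", 2)):
            # Ask Tune which resources are currently assigned to this trial,
            # e.g. to resize the number of data-loader workers or the batch size.
            trial_resources = tune.get_trial_resources()

            # ... one iteration of training here, sized to `trial_resources` ...

            # Save a checkpoint every iteration so the scheduler can pause,
            # rescale, and restore this trial.
            session.report(
                {"mean_accuracy": 0.5},  # placeholder metric
                checkpoint=Checkpoint.from_dict({"step": step}),
            )
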
109 changes: 11 additions & 98 deletions doc/source/tune/api/trainable.rst
@@ -7,7 +7,8 @@
Training in Tune (tune.Trainable, session.report)
=================================================

Training can be done with either a **Function API** (:ref:`session.report <tune-function-docstring>`) or **Class API** (:ref:`tune.Trainable <tune-trainable-docstring>`).
Training can be done with either a **Function API** (:ref:`session.report <tune-function-docstring>`) or
**Class API** (:ref:`tune.Trainable <tune-trainable-docstring>`).

For the sake of example, let's maximize this objective function:

@@ -18,11 +19,11 @@ For the sake of example, let's maximize this objective function:

.. _tune-function-api:

Tune's Function API
-------------------
Function Trainable API
----------------------

The Function API allows you to define a custom training function that Tune will run in parallel Ray actor processes,
one for each Tune trial.
Use the Function API to define a custom training function that Tune runs in Ray actor processes. Each trial runs
in its own actor process, in parallel with the other trials.

The ``config`` argument in the function is a dictionary populated automatically by Ray Tune and corresponding to
the hyperparameters selected for the trial from the :ref:`search space <tune-key-concepts-search-spaces>`.
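
For orientation, a minimal sketch of this pattern (``objective_fn`` and its hyperparameters are invented for this example, not taken from the page):

.. code-block:: python

    from ray import tune
    from ray.air import session


    def objective_fn(config):
        # `config` is populated by Tune from the search space given below.
        a, b = config["a"], config["b"]
        for x in range(20):
            score = a * (x ** 0.5) + b  # stand-in for real training logic
            session.report({"score": score})  # report intermediate metrics to Tune


    tuner = tune.Tuner(
        objective_fn,
        param_space={"a": tune.uniform(0, 1), "b": tune.uniform(0, 1)},
    )
    results = tuner.fit()
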
Expand Down Expand Up @@ -51,36 +52,14 @@ It's also possible to return a final set of metrics to Tune by returning them fr
:start-after: __function_api_return_final_metrics_start__
:end-before: __function_api_return_final_metrics_end__

You'll notice that Ray Tune will output extra values in addition to the user reported metrics,
such as ``iterations_since_restore``. See :ref:`tune-autofilled-metrics` for an explanation/glossary of these values.

.. _tune-function-checkpointing:

Function API Checkpointing
~~~~~~~~~~~~~~~~~~~~~~~~~~

Many Tune features rely on checkpointing, including the usage of certain Trial Schedulers and fault tolerance.
You can save and load checkpoints in Ray Tune in the following manner:

.. literalinclude:: /tune/doc_code/trainable.py
:language: python
:start-after: __function_api_checkpointing_start__
:end-before: __function_api_checkpointing_end__

.. note:: ``checkpoint_frequency`` and ``checkpoint_at_end`` will not work with Function API checkpointing.

In this example, checkpoints will be saved by training iteration to ``<local_dir>/<exp_name>/trial_name/checkpoint_<step>``.

Tune also may copy or move checkpoints during the course of tuning. For this purpose,
it is important not to depend on absolute paths in the implementation of ``save``.
Note that Ray Tune outputs extra values in addition to the user reported metrics,
such as ``iterations_since_restore``. See :ref:`tune-autofilled-metrics` for an explanation of these values.

See :ref:`here for more information on creating checkpoints <air-checkpoint-ref>`.
If using framework-specific trainers from Ray AIR, see :ref:`here <air-trainer-ref>` for
references to framework-specific checkpoints such as `TensorflowCheckpoint`.
See how to configure checkpointing for a function trainable :ref:`here <tune-function-trainable-checkpointing>`.

.. _tune-class-api:

Tune's Trainable Class API
Class Trainable API
--------------------------

.. caution:: Do not use ``session.report`` within a ``Trainable`` class.
@@ -111,73 +90,7 @@ You'll notice that Ray Tune will output extra values in addition to the user reported metrics,
such as ``iterations_since_restore``.
See :ref:`tune-autofilled-metrics` for an explanation/glossary of these values.
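
For illustration, a minimal sketch of a class Trainable (the class name, hyperparameter, and metric are invented for this example); note that, per the caution above, metrics are returned from ``step()`` rather than reported via ``session.report``:

.. code-block:: python

    from ray import tune


    class MyClassTrainable(tune.Trainable):
        def setup(self, config):
            # One-time setup; `config` holds this trial's hyperparameters.
            self.lr = config.get("lr", 0.01)
            self.score = 0.0

        def step(self):
            # One training iteration; metrics are returned from `step`,
            # not reported with `session.report`.
            self.score += self.lr  # stand-in for real training logic
            return {"mean_accuracy": self.score}


    tuner = tune.Tuner(
        MyClassTrainable,
        param_space={"lr": tune.grid_search([0.01, 0.1])},
    )
    results = tuner.fit()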

.. _tune-trainable-save-restore:

Class API Checkpointing
~~~~~~~~~~~~~~~~~~~~~~~

You can also implement checkpoint/restore using the Trainable Class API:

.. literalinclude:: /tune/doc_code/trainable.py
:language: python
:start-after: __class_api_checkpointing_start__
:end-before: __class_api_checkpointing_end__

You can checkpoint with three different mechanisms: manually, periodically, and at termination.

**Manual Checkpointing**: A custom Trainable can manually trigger checkpointing by returning ``should_checkpoint: True``
(or ``tune.result.SHOULD_CHECKPOINT: True``) in the result dictionary of ``step``.
This can be especially helpful when running on spot instances:

.. code-block:: python

def step(self):
# training code
result = {"mean_accuracy": accuracy}
if detect_instance_preemption():
result.update(should_checkpoint=True)
return result


**Periodic Checkpointing**: Periodic checkpointing can be used to provide fault tolerance for experiments.
This can be enabled by setting ``checkpoint_frequency=<int>`` and ``max_failures=<int>`` to checkpoint trials
every *N* iterations and recover from up to *M* crashes per trial, e.g.:

.. code-block:: python

tuner = tune.Tuner(
my_trainable,
run_config=air.RunConfig(
checkpoint_config=air.CheckpointConfig(checkpoint_frequency=10),
failure_config=air.FailureConfig(max_failures=5))
)
results = tuner.fit()

**Checkpointing at Termination**: The ``checkpoint_frequency`` may not coincide with the exact end of an experiment.
If you want a checkpoint to be created at the end of a trial, you can additionally set ``checkpoint_at_end=True``:

.. code-block:: python
:emphasize-lines: 5

tuner = tune.Tuner(
my_trainable,
run_config=air.RunConfig(
checkpoint_config=air.CheckpointConfig(checkpoint_frequency=10, checkpoint_at_end=True),
failure_config=air.FailureConfig(max_failures=5))
)
results = tuner.fit()


Use ``validate_save_restore`` to catch ``save_checkpoint``/``load_checkpoint`` errors before execution.

.. code-block:: python

from ray.tune.utils import validate_save_restore

# both of these should return
validate_save_restore(MyTrainableClass)
validate_save_restore(MyTrainableClass, use_object_store=True)

See how to configure checkpointing for a class trainable :ref:`here <tune-class-trainable-checkpointing>`.


Advanced: Reusing Actors in Tune
66 changes: 0 additions & 66 deletions doc/source/tune/doc_code/trainable.py
@@ -1,71 +1,5 @@
# flake8: noqa

# __class_api_checkpointing_start__
import os
import torch
from torch import nn

from ray import air, tune


class MyTrainableClass(tune.Trainable):
def setup(self, config):
self.model = nn.Sequential(
nn.Linear(config.get("input_size", 32), 32), nn.ReLU(), nn.Linear(32, 10)
)

def step(self):
return {}

def save_checkpoint(self, tmp_checkpoint_dir):
checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
torch.save(self.model.state_dict(), checkpoint_path)
return tmp_checkpoint_dir

def load_checkpoint(self, tmp_checkpoint_dir):
checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
self.model.load_state_dict(torch.load(checkpoint_path))


tuner = tune.Tuner(
MyTrainableClass,
param_space={"input_size": 64},
run_config=air.RunConfig(
stop={"training_iteration": 2},
checkpoint_config=air.CheckpointConfig(checkpoint_frequency=2),
),
)
tuner.fit()
# __class_api_checkpointing_end__

# __function_api_checkpointing_start__
from ray import tune
from ray.air import session
from ray.air.checkpoint import Checkpoint


def train_func(config):
epochs = config.get("epochs", 2)
start = 0
loaded_checkpoint = session.get_checkpoint()
if loaded_checkpoint:
last_step = loaded_checkpoint.to_dict()["step"]
start = last_step + 1

for step in range(start, epochs):
# Model training here
# ...

# Report metrics and save a checkpoint
metrics = {"metric": "my_metric"}
checkpoint = Checkpoint.from_dict({"step": step})
session.report(metrics, checkpoint=checkpoint)


tuner = tune.Tuner(train_func)
results = tuner.fit()
# __function_api_checkpointing_end__

# fmt: off
# __example_objective_start__
def objective(x, a, b):