[doc][tune] Add tune checkpoint user guide. (#33145)
* [doc][tune] Add tune checkpoint user guide.

Also updated storage-options and trainable page.

Signed-off-by: xwjiang2010 <xwjiang2010@gmail.com>

* fix doc tests.

Signed-off-by: xwjiang2010 <xwjiang2010@gmail.com>

* move _tune-persisted-experiment-data

Signed-off-by: xwjiang2010 <xwjiang2010@gmail.com>

* address comments

Signed-off-by: xwjiang2010 <xwjiang2010@gmail.com>

* address comments.

Signed-off-by: xwjiang2010 <xwjiang2010@gmail.com>

* fix test

Signed-off-by: xwjiang2010 <xwjiang2010@gmail.com>

* address timeout issue

Signed-off-by: xwjiang2010 <xwjiang2010@gmail.com>

* address comments

Signed-off-by: xwjiang2010 <xwjiang2010@gmail.com>

* address comments

Signed-off-by: xwjiang2010 <xwjiang2010@gmail.com>

---------

Signed-off-by: xwjiang2010 <xwjiang2010@gmail.com>
xwjiang2010 authored Mar 22, 2023
1 parent 7d47206 commit 979b9db
Showing 10 changed files with 430 additions and 235 deletions.
1 change: 1 addition & 0 deletions doc/source/_toc.yml
@@ -164,6 +164,7 @@ parts:
- file: tune/tutorials/tune-stopping
title: "How to Stop and Resume"
- file: tune/tutorials/tune-storage
+- file: tune/tutorials/tune-trial-checkpoint
- file: tune/tutorials/tune-metrics
title: "Using Callbacks and Metrics"
- file: tune/tutorials/tune_get_data_in_and_out
10 changes: 5 additions & 5 deletions doc/source/tune/api/schedulers.rst
@@ -172,7 +172,7 @@ This can be enabled by setting the ``scheduler`` parameter of ``tune.TuneConfig``
When the PBT scheduler is enabled, each trial variant is treated as a member of the population.
Periodically, **top-performing trials are checkpointed**
-(this requires your Trainable to support :ref:`save and restore <tune-function-checkpointing>`).
+(this requires your Trainable to support :ref:`save and restore <tune-trial-checkpoint>`).
**Low-performing trials clone the hyperparameter configurations of top performers and
perturb them** slightly in the hopes of discovering even better hyperparameter settings.
**Low-performing trials also resume from the checkpoints of the top performers**, allowing
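
For orientation, enabling PBT might look roughly like the following sketch. The mutation
space, metric name, and ``train_fn`` are illustrative, and the trainable is assumed to
implement checkpointing:

.. code-block:: python

    from ray import tune
    from ray.tune.schedulers import PopulationBasedTraining

    # Illustrative mutation space; `train_fn` must support save/restore.
    pbt = PopulationBasedTraining(
        time_attr="training_iteration",
        perturbation_interval=4,
        hyperparam_mutations={"lr": tune.loguniform(1e-4, 1e-1)},
    )
    tuner = tune.Tuner(
        train_fn,
        tune_config=tune.TuneConfig(
            scheduler=pbt, metric="mean_accuracy", mode="max", num_samples=4
        ),
    )
    results = tuner.fit()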
@@ -261,7 +261,7 @@ PB2 can be enabled by setting the ``scheduler`` parameter of ``tune.TuneConfig``
When the PB2 scheduler is enabled, each trial variant is treated as a member of the population.
Periodically, top-performing trials are checkpointed (this requires your Trainable to
-support :ref:`save and restore <tune-function-checkpointing>`).
+support :ref:`save and restore <tune-trial-checkpoint>`).
Low-performing trials clone the checkpoints of top performers and perturb the configurations
in the hope of discovering an even better variation.
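
A similar sketch for PB2; ``train_fn`` and the bounds are illustrative:

.. code-block:: python

    from ray import tune
    from ray.tune.schedulers.pb2 import PB2

    # PB2 searches within these bounds rather than mutating a fixed set.
    pb2 = PB2(
        time_attr="training_iteration",
        perturbation_interval=4,
        hyperparam_bounds={"lr": [1e-4, 1e-1]},
    )
    tuner = tune.Tuner(
        train_fn,
        tune_config=tune.TuneConfig(
            scheduler=pb2, metric="mean_accuracy", mode="max", num_samples=8
        ),
    )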

@@ -308,9 +308,9 @@ It wraps around another scheduler and uses its decisions.
which will let your model know about the new resources assigned. You can also obtain the current trial resources
by calling ``Trainable.trial_resources``.

-* If you are using the functional API for tuning, the current trial resources can be
-  obtained by calling `tune.get_trial_resources()` inside the training function.
-  The function should be able to :ref:`load and save checkpoints <tune-function-checkpointing>`
+* If you are using the functional API for tuning, obtain the current trial resources by calling
+  `tune.get_trial_resources()` inside the training function.
+  The function should be able to :ref:`load and save checkpoints <tune-function-trainable-checkpointing>`
(the latter preferably every iteration).

An example of this in use can be found here: :doc:`/tune/examples/includes/xgboost_dynamic_resources_example`.
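
A rough inline sketch of that pattern, assuming ``tune.get_trial_resources()`` returns a
``PlacementGroupFactory`` whose ``required_resources`` dict carries a ``"CPU"`` entry:

.. code-block:: python

    from ray import tune
    from ray.air import session
    from ray.air.checkpoint import Checkpoint


    def train_fn(config):
        # May change across restarts if the scheduler reallocates resources.
        trial_resources = tune.get_trial_resources()
        num_cpus = int(trial_resources.required_resources.get("CPU", 1))

        for step in range(config.get("epochs", 10)):
            # ... train with `num_cpus` workers/threads ...
            session.report(
                {"mean_accuracy": 0.5},  # placeholder metric
                checkpoint=Checkpoint.from_dict({"step": step}),
            )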
109 changes: 11 additions & 98 deletions doc/source/tune/api/trainable.rst
@@ -7,7 +7,8 @@
Training in Tune (tune.Trainable, session.report)
=================================================

-Training can be done with either a **Function API** (:ref:`session.report <tune-function-docstring>`) or **Class API** (:ref:`tune.Trainable <tune-trainable-docstring>`).
+Training can be done with either a **Function API** (:ref:`session.report <tune-function-docstring>`) or
+**Class API** (:ref:`tune.Trainable <tune-trainable-docstring>`).

For the sake of example, let's maximize this objective function:

@@ -18,11 +19,11 @@

.. _tune-function-api:

-Tune's Function API
--------------------
+Function Trainable API
+----------------------

-The Function API allows you to define a custom training function that Tune will run in parallel Ray actor processes,
-one for each Tune trial.
+Use the Function API to define a custom training function that Tune runs in Ray actor processes. Each trial is placed
+into a Ray actor process and runs in parallel.

The ``config`` argument in the function is a dictionary populated automatically by Ray Tune and corresponding to
the hyperparameters selected for the trial from the :ref:`search space <tune-key-concepts-search-spaces>`.
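
A minimal sketch of a function trainable; the search space and metric are hypothetical:

.. code-block:: python

    from ray import tune
    from ray.air import session


    def trainable(config):
        # `config` holds the hyperparameters sampled for this trial.
        score = config["a"] ** 2 + config["b"]
        session.report({"score": score})


    tuner = tune.Tuner(
        trainable,
        param_space={"a": tune.uniform(0, 1), "b": tune.grid_search([1, 2])},
    )
    results = tuner.fit()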
@@ -51,36 +52,14 @@ It's also possible to return a final set of metrics to Tune by returning them from your function:
.. literalinclude:: /tune/doc_code/trainable.py
    :language: python
    :start-after: __function_api_return_final_metrics_start__
    :end-before: __function_api_return_final_metrics_end__
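
The included snippet is not reproduced in this diff; the pattern is roughly the following
sketch (hypothetical computation):

.. code-block:: python

    from ray import tune


    def trainable(config):
        final_score = config["a"] * 2  # hypothetical computation
        # Returning a dict reports it as the trial's final metrics.
        return {"score": final_score}


    tuner = tune.Tuner(trainable, param_space={"a": tune.uniform(0, 1)})
    results = tuner.fit()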

-You'll notice that Ray Tune will output extra values in addition to the user reported metrics,
-such as ``iterations_since_restore``. See :ref:`tune-autofilled-metrics` for an explanation/glossary of these values.

-.. _tune-function-checkpointing:
-
-Function API Checkpointing
-~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-Many Tune features rely on checkpointing, including the usage of certain Trial Schedulers and fault tolerance.
-You can save and load checkpoints in Ray Tune in the following manner:
-
-.. literalinclude:: /tune/doc_code/trainable.py
-    :language: python
-    :start-after: __function_api_checkpointing_start__
-    :end-before: __function_api_checkpointing_end__
-
-.. note:: ``checkpoint_frequency`` and ``checkpoint_at_end`` will not work with Function API checkpointing.
-
-In this example, checkpoints will be saved by training iteration to ``<local_dir>/<exp_name>/trial_name/checkpoint_<step>``.
-
-Tune also may copy or move checkpoints during the course of tuning. For this purpose,
-it is important not to depend on absolute paths in the implementation of ``save``.
+Note that Ray Tune outputs extra values in addition to the user reported metrics,
+such as ``iterations_since_restore``. See :ref:`tune-autofilled-metrics` for an explanation of these values.

-See :ref:`here for more information on creating checkpoints <air-checkpoint-ref>`.
-If using framework-specific trainers from Ray AIR, see :ref:`here <air-trainer-ref>` for
-references to framework-specific checkpoints such as `TensorflowCheckpoint`.
+See how to configure checkpointing for a function trainable :ref:`here <tune-function-trainable-checkpointing>`.

.. _tune-class-api:

-Tune's Trainable Class API
+Class Trainable API
--------------------------

.. caution:: Do not use ``session.report`` within a ``Trainable`` class.
@@ -111,73 +90,7 @@ You'll notice that Ray Tune will output extra values in addition to the user reported metrics,
such as ``iterations_since_restore``.
See :ref:`tune-autofilled-metrics` for an explanation/glossary of these values.
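
For orientation, a minimal class-based Trainable sketch; the names and stopping criterion
are illustrative:

.. code-block:: python

    from ray import air, tune


    class MyTrainable(tune.Trainable):
        def setup(self, config):
            self.a = config["a"]
            self.score = 0.0

        def step(self):
            # Called repeatedly; each call is one training iteration.
            self.score += self.a
            return {"score": self.score}


    tuner = tune.Tuner(
        MyTrainable,
        param_space={"a": tune.grid_search([1, 2])},
        run_config=air.RunConfig(stop={"training_iteration": 5}),
    )
    results = tuner.fit()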

-.. _tune-trainable-save-restore:
-
-Class API Checkpointing
-~~~~~~~~~~~~~~~~~~~~~~~
-
-You can also implement checkpoint/restore using the Trainable Class API:
-
-.. literalinclude:: /tune/doc_code/trainable.py
-    :language: python
-    :start-after: __class_api_checkpointing_start__
-    :end-before: __class_api_checkpointing_end__
-
-You can checkpoint with three different mechanisms: manually, periodically, and at termination.
-
-**Manual Checkpointing**: A custom Trainable can manually trigger checkpointing by returning ``should_checkpoint: True``
-(or ``tune.result.SHOULD_CHECKPOINT: True``) in the result dictionary of ``step``.
-This can be especially helpful in spot instances:
-
-.. code-block:: python
-
-    def step(self):
-        # training code
-        result = {"mean_accuracy": accuracy}
-        if detect_instance_preemption():
-            result.update(should_checkpoint=True)
-        return result
-
-**Periodic Checkpointing**: Periodic checkpointing can be used to provide fault tolerance for experiments.
-This can be enabled by setting ``checkpoint_frequency=<int>`` and ``max_failures=<int>`` to checkpoint trials
-every *N* iterations and recover from up to *M* crashes per trial, e.g.:
-
-.. code-block:: python
-
-    tuner = tune.Tuner(
-        my_trainable,
-        run_config=air.RunConfig(
-            checkpoint_config=air.CheckpointConfig(checkpoint_frequency=10),
-            failure_config=air.FailureConfig(max_failures=5))
-    )
-    results = tuner.fit()
-
-**Checkpointing at Termination**: The ``checkpoint_frequency`` may not coincide with the exact end of an experiment.
-If you want a checkpoint to be created at the end of a trial, you can additionally set ``checkpoint_at_end=True``:
-
-.. code-block:: python
-    :emphasize-lines: 4
-
-    tuner = tune.Tuner(
-        my_trainable,
-        run_config=air.RunConfig(
-            checkpoint_config=air.CheckpointConfig(checkpoint_frequency=10, checkpoint_at_end=True),
-            failure_config=air.FailureConfig(max_failures=5))
-    )
-    results = tuner.fit()
-
-Use ``validate_save_restore`` to catch ``save_checkpoint``/``load_checkpoint`` errors before execution.
-
-.. code-block:: python
-
-    from ray.tune.utils import validate_save_restore
-
-    # both of these should return
-    validate_save_restore(MyTrainableClass)
-    validate_save_restore(MyTrainableClass, use_object_store=True)

+See how to configure checkpointing for a class trainable :ref:`here <tune-class-trainable-checkpointing>`.


Advanced: Reusing Actors in Tune
66 changes: 0 additions & 66 deletions doc/source/tune/doc_code/trainable.py
@@ -1,71 +1,5 @@
# flake8: noqa

# __class_api_checkpointing_start__
import os
import torch
from torch import nn

from ray import air, tune


class MyTrainableClass(tune.Trainable):
    def setup(self, config):
        self.model = nn.Sequential(
            nn.Linear(config.get("input_size", 32), 32), nn.ReLU(), nn.Linear(32, 10)
        )

    def step(self):
        return {}

    def save_checkpoint(self, tmp_checkpoint_dir):
        checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
        torch.save(self.model.state_dict(), checkpoint_path)
        return tmp_checkpoint_dir

    def load_checkpoint(self, tmp_checkpoint_dir):
        checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
        self.model.load_state_dict(torch.load(checkpoint_path))


tuner = tune.Tuner(
    MyTrainableClass,
    param_space={"input_size": 64},
    run_config=air.RunConfig(
        stop={"training_iteration": 2},
        checkpoint_config=air.CheckpointConfig(checkpoint_frequency=2),
    ),
)
tuner.fit()
# __class_api_checkpointing_end__

# __function_api_checkpointing_start__
from ray import tune
from ray.air import session
from ray.air.checkpoint import Checkpoint


def train_func(config):
    epochs = config.get("epochs", 2)
    start = 0
    loaded_checkpoint = session.get_checkpoint()
    if loaded_checkpoint:
        last_step = loaded_checkpoint.to_dict()["step"]
        start = last_step + 1

    for step in range(start, epochs):
        # Model training here
        # ...

        # Report metrics and save a checkpoint
        metrics = {"metric": "my_metric"}
        checkpoint = Checkpoint.from_dict({"step": step})
        session.report(metrics, checkpoint=checkpoint)


tuner = tune.Tuner(train_func)
results = tuner.fit()
# __function_api_checkpointing_end__

# fmt: off
# __example_objective_start__
def objective(x, a, b):
