[tune] Use public methods for trainable (#9184)
richardliaw authored Jul 1, 2020
1 parent 1491508 commit d35f0e4
Showing 40 changed files with 350 additions and 220 deletions.
4 changes: 2 additions & 2 deletions doc/source/rllib-dev.rst
@@ -31,7 +31,7 @@ Contributing Algorithms
These are the guidelines for merging new algorithms into RLlib:

* Contributed algorithms (`rllib/contrib <https://github.com/ray-project/ray/tree/master/rllib/contrib>`__):
- - must subclass Trainer and implement the ``_train()`` method
+ - must subclass Trainer and implement the ``step()`` method
- must include a lightweight test (`example <https://github.com/ray-project/ray/blob/6bb110393008c9800177490688c6ed38b2da52a9/test/jenkins_tests/run_multi_node_tests.sh#L45>`__) to ensure the algorithm runs
- should include tuned hyperparameter examples and documentation
- should offer functionality not present in existing algorithms
@@ -46,7 +46,7 @@ Both integrated and contributed algorithms ship with the ``ray`` PyPI package, a

How to add an algorithm to ``contrib``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- It takes just two changes to add an algorithm to `contrib <https://github.com/ray-project/ray/tree/master/rllib/contrib>`__. A minimal example can be found `here <https://github.com/ray-project/ray/tree/master/rllib/contrib/random_agent/random_agent.py>`__. First, subclass `Trainer <https://github.com/ray-project/ray/tree/master/rllib/agents/agent.py>`__ and implement the ``_init`` and ``_train`` methods:
+ It takes just two changes to add an algorithm to `contrib <https://github.com/ray-project/ray/tree/master/rllib/contrib>`__. A minimal example can be found `here <https://github.com/ray-project/ray/tree/master/rllib/contrib/random_agent/random_agent.py>`__. First, subclass `Trainer <https://github.com/ray-project/ray/tree/master/rllib/agents/agent.py>`__ and implement the ``_init`` and ``step`` methods:

.. literalinclude:: ../../rllib/contrib/random_agent/random_agent.py
:language: python
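
For orientation, here is a hedged sketch of the shape such a contrib algorithm takes; the class name and config are illustrative stand-ins, not the linked example:

.. code-block:: python

    from ray.rllib.agents.trainer import Trainer, with_common_config

    class MyContribAgent(Trainer):
        """Sketch of a contrib algorithm: acts randomly, learns nothing."""

        _name = "MyContribAgent"
        _default_config = with_common_config({})

        def _init(self, config, env_creator):
            self.env = env_creator(config["env_config"])

        def step(self):
            # Run one episode with random actions and report its return.
            obs, done, total = self.env.reset(), False, 0.0
            while not done:
                obs, reward, done, _ = self.env.step(
                    self.env.action_space.sample())
                total += reward
            return {"episode_reward_mean": total, "episodes_this_iter": 1}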
4 changes: 2 additions & 2 deletions doc/source/tune/_tutorials/tune-60-seconds.rst
@@ -42,13 +42,13 @@ The other is a :ref:`class-based API <tune-class-api>`. Here's an example of spe
from ray import tune
class Trainable(tune.Trainable):
- def _setup(self, config):
+ def setup(self, config):
# config (dict): A dict of hyperparameters
self.x = 0
self.a = config["a"]
self.b = config["b"]
- def _train(self): # This is called iteratively.
+ def step(self): # This is called iteratively.
score = objective(self.x, self.a, self.b)
self.x += 1
return {"score": score}
2 changes: 1 addition & 1 deletion doc/source/tune/_tutorials/tune-distributed.rst
@@ -215,7 +215,7 @@ In GCP, you can use the following configuration modification:
Spot instances may be removed suddenly while trials are still running. Oftentimes this is difficult to deal with in other distributed hyperparameter optimization frameworks. Tune allows users to mitigate these interruptions by preserving the progress of model training through checkpointing.

- The easiest way to do this is to subclass the pre-defined ``Trainable`` class and implement ``_save``, and ``_restore`` abstract methods, as seen in the example below:
+ The easiest way to do this is to subclass the pre-defined ``Trainable`` class and implement the ``save_checkpoint`` and ``load_checkpoint`` abstract methods, as seen in the example below:

.. literalinclude:: /../../python/ray/tune/examples/mnist_pytorch_trainable.py
:language: python
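
As a minimal sketch of the two methods (assuming a ``self.model`` built in ``setup``):

.. code-block:: python

    import os
    import torch

    class MyTrainable(tune.Trainable):
        def save_checkpoint(self, tmp_checkpoint_dir):
            # Persist enough state to resume after a preemption.
            path = os.path.join(tmp_checkpoint_dir, "model.pth")
            torch.save(self.model.state_dict(), path)
            return tmp_checkpoint_dir

        def load_checkpoint(self, tmp_checkpoint_dir):
            path = os.path.join(tmp_checkpoint_dir, "model.pth")
            self.model.load_state_dict(torch.load(path))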
4 changes: 2 additions & 2 deletions doc/source/tune/_tutorials/tune-usage.rst
@@ -122,7 +122,7 @@ You can log arbitrary values and metrics in both training APIs:
class Trainable(tune.Trainable):
...
- def _train(self): # this is called iteratively
+ def step(self): # this is called iteratively
accuracy = self.model.train()
metric_1 = f(self.model)
metric_2 = self.model.get_loss()
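# (the diff view truncates this snippet here; presumably the metrics are then
# returned so Tune can log them, e.g.:)
return {"mean_accuracy": accuracy, "metric_1": metric_1, "metric_2": metric_2}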
@@ -223,7 +223,7 @@ Stopping Trials

You can control when trials are stopped early by passing the ``stop`` argument to ``tune.run``. This argument takes either a dictionary or a function.

- If a dictionary is passed in, the keys may be any field in the return result of ``tune.report`` in the Function API or ``_train()`` (including the results from ``_train`` and auto-filled metrics).
+ If a dictionary is passed in, the keys may be any field in the return result of ``tune.report`` in the Function API or ``step()`` (including the results from ``step`` and auto-filled metrics).

In the example below, each trial will be stopped either when it completes 10 iterations OR when it reaches a mean accuracy of 0.98. These metrics are assumed to be **increasing**.
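
The example itself is elided by the diff view; a minimal sketch of such a stopping condition:

.. code-block:: python

    tune.run(
        Trainable,
        stop={"training_iteration": 10, "mean_accuracy": 0.98})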

6 changes: 3 additions & 3 deletions doc/source/tune/api_docs/logging.rst
@@ -99,7 +99,7 @@ You can do this in the trainable, as shown below:
.. code-block:: python
class CustomLogging(tune.Trainable):
- def _setup(self, config):
+ def setup(self, config):
trial_id = self.trial_id
library.init(
name=trial_id,
@@ -109,10 +109,10 @@ You can do this in the trainable, as shown below:
allow_val_change=True)
library.set_log_path(self.logdir)
- def _train(self):
+ def step(self):
library.log_model(...)
- def _log_result(self, result):
+ def log_result(self, result):
res_dict = {
str(k): v
for k, v in result.items()
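}
# (the diff view truncates this snippet here; presumably the sanitized dict is
# then handed to the placeholder ``library``, e.g.:)
library.log(res_dict)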
26 changes: 13 additions & 13 deletions doc/source/tune/api_docs/trainable.rst
@@ -61,7 +61,7 @@ Many Tune features rely on checkpointing, including the usage of certain Trial S
for iter in range(start, 100):
time.sleep(1)
#
checkpoint_dir = tune.make_checkpoint_dir(step=step)
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
@@ -101,13 +101,13 @@ The Trainable **class API** will require users to subclass ``ray.tune.Trainable`
from ray import tune
class Trainable(tune.Trainable):
- def _setup(self, config):
+ def setup(self, config):
# config (dict): A dict of hyperparameters
self.x = 0
self.a = config["a"]
self.b = config["b"]
- def _train(self): # This is called iteratively.
+ def step(self): # This is called iteratively.
score = objective(self.x, self.a, self.b)
self.x += 1
return {"score": score}
@@ -124,11 +124,11 @@ The Trainable **class API** will require users to subclass ``ray.tune.Trainable`
When you subclass ``tune.Trainable``, Tune creates a ``Trainable`` object on a separate process (using the :ref:`Ray Actor API <actor-guide>`).

- 1. ``_setup`` function is invoked once training starts.
- 2. ``_train`` is invoked **multiple times**. Each time, the Trainable object executes one logical iteration of training in the tuning process, which may include one or more iterations of actual training.
- 3. ``_stop`` is invoked when training is finished.
+ 1. ``setup`` function is invoked once training starts.
+ 2. ``step`` is invoked **multiple times**. Each time, the Trainable object executes one logical iteration of training in the tuning process, which may include one or more iterations of actual training.
+ 3. ``cleanup`` is invoked when training is finished.

- .. tip:: As a rule of thumb, the execution time of ``_train`` should be large enough to avoid overheads (i.e. more than a few seconds), but short enough to report progress periodically (i.e. at most a few minutes).
+ .. tip:: As a rule of thumb, the execution time of ``step`` should be large enough to avoid overheads (e.g. more than a few seconds), but short enough to report progress periodically (e.g. at most a few minutes).
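
A minimal sketch of all three hooks together (``build_model`` and the model's methods are hypothetical stand-ins):

.. code-block:: python

    class MyTrainable(tune.Trainable):
        def setup(self, config):
            self.model = build_model(config)  # hypothetical helper

        def step(self):
            loss = self.model.train_one_epoch()  # hypothetical method
            return {"loss": loss}

        def cleanup(self):
            # Release resources once training is finished.
            self.model.close()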


.. _tune-trainable-save-restore:
@@ -141,24 +141,24 @@ You can also implement checkpoint/restore using the Trainable Class API:
.. code-block:: python
class MyTrainableClass(Trainable):
- def _save(self, tmp_checkpoint_dir):
+ def save_checkpoint(self, tmp_checkpoint_dir):
checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
torch.save(self.model.state_dict(), checkpoint_path)
return tmp_checkpoint_dir
- def _restore(self, tmp_checkpoint_dir):
+ def load_checkpoint(self, tmp_checkpoint_dir):
checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
self.model.load_state_dict(torch.load(checkpoint_path))
tune.run(MyTrainableClass, checkpoint_freq=2)
You can checkpoint with three different mechanisms: manually, periodically, and at termination.

- **Manual Checkpointing**: A custom Trainable can manually trigger checkpointing by returning ``should_checkpoint: True`` (or ``tune.result.SHOULD_CHECKPOINT: True``) in the result dictionary of `_train`. This can be especially helpful in spot instances:
+ **Manual Checkpointing**: A custom Trainable can manually trigger checkpointing by returning ``should_checkpoint: True`` (or ``tune.result.SHOULD_CHECKPOINT: True``) in the result dictionary of `step`. This can be especially helpful when running on spot instances:

.. code-block:: python
- def _train(self):
+ def step(self):
# training code
result = {"mean_accuracy": accuracy}
if detect_instance_preemption():
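# (the diff view truncates this snippet here; presumably the flag is set and
# the result returned, e.g.:)
result.update(should_checkpoint=True)
return result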
@@ -190,7 +190,7 @@ of a trial, you can additionally set the ``checkpoint_at_end=True``:
)
- Use ``validate_save_restore`` to catch ``_save``/``_restore`` errors before execution.
+ Use ``validate_save_restore`` to catch ``save_checkpoint``/``load_checkpoint`` errors before execution.

.. code-block:: python
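# (the code-block body is elided by the diff view; a sketch of such a check,
# assuming the utility's usual location:)
from ray.tune.utils import validate_save_restore

# Instantiates the trainable, saves, restores, and surfaces errors up front.
validate_save_restore(MyTrainableClass)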
@@ -214,7 +214,7 @@ This requires you to implement ``Trainable.reset_config``, which provides a new
class PytorchTrainble(tune.Trainable):
"""Train a Pytorch ConvNet."""
- def _setup(self, config):
+ def setup(self, config):
self.train_loader, self.test_loader = get_data_loaders()
self.model = ConvNet()
self.optimizer = optim.SGD(
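lr=config.get("lr", 0.01),
momentum=config.get("momentum", 0.9))

# (the diff view truncates the class here; a sketch of the ``reset_config``
# this section describes, updating the optimizer in place:)
def reset_config(self, new_config):
    for param_group in self.optimizer.param_groups:
        param_group["lr"] = new_config.get("lr", 0.01)
        param_group["momentum"] = new_config.get("momentum", 0.9)
    self.config = new_config
    return True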
3 changes: 1 addition & 2 deletions python/ray/tune/durable_trainable.py
@@ -11,8 +11,7 @@ class DurableTrainable(Trainable):
"""Abstract class for a remote-storage backed fault-tolerant Trainable.
Supports checkpointing to and restoring from remote storage. To use this
- class, implement the same private methods as ray.tune.Trainable (`_save`,
- `_train`, `_restore`, `reset_config`, `_setup`, `_stop`).
+ class, implement the same private methods as ray.tune.Trainable.
.. warning:: This class is currently **experimental** and may
be subject to change.
8 changes: 4 additions & 4 deletions python/ray/tune/examples/async_hyperband_example.py
@@ -19,10 +19,10 @@ class MyTrainableClass(Trainable):
maximum reward value reached.
"""

- def _setup(self, config):
+ def setup(self, config):
self.timestep = 0

- def _train(self):
+ def step(self):
self.timestep += 1
v = np.tanh(float(self.timestep) / self.config.get("width", 1))
v *= self.config.get("height", 1)
@@ -31,13 +31,13 @@ def _train(self):
# objectives such as loss or accuracy.
return {"episode_reward_mean": v}

- def _save(self, checkpoint_dir):
+ def save_checkpoint(self, checkpoint_dir):
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps({"timestep": self.timestep}))
return path

- def _restore(self, checkpoint_path):
+ def load_checkpoint(self, checkpoint_path):
with open(checkpoint_path) as f:
self.timestep = json.loads(f.read())["timestep"]
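
# (the rest of this example file is elided by the diff view; a sketch of a
# typical launch of this Trainable, with illustrative values:)
from ray import tune
from ray.tune.schedulers import AsyncHyperBandScheduler

tune.run(
    MyTrainableClass,
    scheduler=AsyncHyperBandScheduler(
        metric="episode_reward_mean", mode="max"),
    num_samples=20,
    stop={"training_iteration": 100},
    config={"width": tune.grid_search([10, 50]),
            "height": tune.grid_search([0, 100])})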

8 changes: 4 additions & 4 deletions python/ray/tune/examples/bohb_example.py
@@ -18,10 +18,10 @@ class MyTrainableClass(Trainable):
maximum reward value reached.
"""

- def _setup(self, config):
+ def setup(self, config):
self.timestep = 0

- def _train(self):
+ def step(self):
self.timestep += 1
v = np.tanh(float(self.timestep) / self.config.get("width", 1))
v *= self.config.get("height", 1)
@@ -30,13 +30,13 @@ def _train(self):
# objectives such as loss or accuracy.
return {"episode_reward_mean": v}

- def _save(self, checkpoint_dir):
+ def save_checkpoint(self, checkpoint_dir):
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps({"timestep": self.timestep}))
return path

- def _restore(self, checkpoint_path):
+ def load_checkpoint(self, checkpoint_path):
with open(checkpoint_path) as f:
self.timestep = json.loads(f.read())["timestep"]

8 changes: 4 additions & 4 deletions python/ray/tune/examples/durable_trainable_example.py
@@ -51,7 +51,7 @@ def eval(self, k, add_noise=True):

def get_optimus_trainable(parent_cls):
class OptimusTrainable(parent_cls):
- def _setup(self, config):
+ def setup(self, config):
self.iter = 0
if config.get("seed"):
np.random.seed(config["seed"])
@@ -61,7 +61,7 @@ def _setup(self, config):
self.initial_samples_per_step = 500
self.mock_data = open("/dev/urandom", "rb").read(1024)

- def _train(self):
+ def step(self):
self.iter += 1
new_loss = self.func.eval(self.iter)
time.sleep(0.5)
@@ -71,7 +71,7 @@ def _train(self):
"samples": self.initial_samples_per_step
}

- def _save(self, checkpoint_dir):
+ def save_checkpoint(self, checkpoint_dir):
time.sleep(0.5)
return {
"func": cloudpickle.dumps(self.func),
@@ -80,7 +80,7 @@ def _save(self, checkpoint_dir):
"iter": self.iter
}

- def _restore(self, checkpoint):
+ def load_checkpoint(self, checkpoint):
self.func = cloudpickle.loads(checkpoint["func"])
self.data = checkpoint["data"]
self.iter = checkpoint["iter"]
8 changes: 4 additions & 4 deletions python/ray/tune/examples/hyperband_example.py
@@ -19,10 +19,10 @@ class MyTrainableClass(Trainable):
maximum reward value reached.
"""

- def _setup(self, config):
+ def setup(self, config):
self.timestep = 0

- def _train(self):
+ def step(self):
self.timestep += 1
v = np.tanh(float(self.timestep) / self.config.get("width", 1))
v *= self.config.get("height", 1)
@@ -31,13 +31,13 @@ def _train(self):
# objectives such as loss or accuracy.
return {"episode_reward_mean": v}

- def _save(self, checkpoint_dir):
+ def save_checkpoint(self, checkpoint_dir):
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps({"timestep": self.timestep}))
return path

- def _restore(self, checkpoint_path):
+ def load_checkpoint(self, checkpoint_path):
with open(checkpoint_path) as f:
self.timestep = json.loads(f.read())["timestep"]

8 changes: 4 additions & 4 deletions python/ray/tune/examples/logging_example.py
@@ -27,10 +27,10 @@ class MyTrainableClass(Trainable):
maximum reward value reached.
"""

- def _setup(self, config):
+ def setup(self, config):
self.timestep = 0

- def _train(self):
+ def step(self):
self.timestep += 1
v = np.tanh(float(self.timestep) / self.config.get("width", 1))
v *= self.config.get("height", 1)
@@ -39,13 +39,13 @@ def _train(self):
# objectives such as loss or accuracy.
return {"episode_reward_mean": v}

- def _save(self, checkpoint_dir):
+ def save_checkpoint(self, checkpoint_dir):
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps({"timestep": self.timestep}))
return path

- def _restore(self, checkpoint_path):
+ def load_checkpoint(self, checkpoint_path):
with open(checkpoint_path) as f:
self.timestep = json.loads(f.read())["timestep"]

8 changes: 4 additions & 4 deletions python/ray/tune/examples/mnist_pytorch_trainable.py
@@ -34,7 +34,7 @@
# yapf: disable
# __trainable_example_begin__
class TrainMNIST(tune.Trainable):
- def _setup(self, config):
+ def setup(self, config):
use_cuda = config.get("use_gpu") and torch.cuda.is_available()
self.device = torch.device("cuda" if use_cuda else "cpu")
self.train_loader, self.test_loader = get_data_loaders()
@@ -44,18 +44,18 @@ def _setup(self, config):
lr=config.get("lr", 0.01),
momentum=config.get("momentum", 0.9))

- def _train(self):
+ def step(self):
train(
self.model, self.optimizer, self.train_loader, device=self.device)
acc = test(self.model, self.test_loader, self.device)
return {"mean_accuracy": acc}

- def _save(self, checkpoint_dir):
+ def save_checkpoint(self, checkpoint_dir):
checkpoint_path = os.path.join(checkpoint_dir, "model.pth")
torch.save(self.model.state_dict(), checkpoint_path)
return checkpoint_path

- def _restore(self, checkpoint_path):
+ def load_checkpoint(self, checkpoint_path):
self.model.load_state_dict(torch.load(checkpoint_path))
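
# (a sketch, not part of this file, of a typical launch of this Trainable;
# hyperparameter values are illustrative:)
from ray import tune

tune.run(
    TrainMNIST,
    stop={"mean_accuracy": 0.98},
    checkpoint_freq=2,  # periodic checkpointing every 2 iterations
    config={"lr": 0.01, "momentum": 0.9, "use_gpu": False})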


8 changes: 4 additions & 4 deletions python/ray/tune/examples/pbt_convnet_example.py
@@ -27,25 +27,25 @@ class PytorchTrainble(tune.Trainable):
changing the original training code.
"""

- def _setup(self, config):
+ def setup(self, config):
self.train_loader, self.test_loader = get_data_loaders()
self.model = ConvNet()
self.optimizer = optim.SGD(
self.model.parameters(),
lr=config.get("lr", 0.01),
momentum=config.get("momentum", 0.9))

- def _train(self):
+ def step(self):
train(self.model, self.optimizer, self.train_loader)
acc = test(self.model, self.test_loader)
return {"mean_accuracy": acc}

- def _save(self, checkpoint_dir):
+ def save_checkpoint(self, checkpoint_dir):
checkpoint_path = os.path.join(checkpoint_dir, "model.pth")
torch.save(self.model.state_dict(), checkpoint_path)
return checkpoint_path

- def _restore(self, checkpoint_path):
+ def load_checkpoint(self, checkpoint_path):
self.model.load_state_dict(torch.load(checkpoint_path))

def _export_model(self, export_formats, export_dir):
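# (the diff view truncates here; a sketch of a typical export body, assuming
# ``ExportFormat`` is imported from ``ray.tune.trial`` elsewhere in the file:)
if export_formats == [ExportFormat.MODEL]:
    path = os.path.join(export_dir, "exported_convnet.pt")
    torch.save(self.model.state_dict(), path)
    return {ExportFormat.MODEL: path}
return {}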
(The remaining 26 changed files are not shown.)