[RLlib] Cleanup examples folder (vol 30): BC pretraining, then PPO finetuning (new API stack with RLModule checkpoints). (#47838)
sven1977 authored Sep 28, 2024
1 parent 44bdd4d commit b676d02
Showing 2 changed files with 325 additions and 6 deletions.
25 changes: 19 additions & 6 deletions rllib/BUILD
@@ -334,7 +334,7 @@ py_test(
tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core"],
size = "medium",
srcs = ["tuned_examples/bc/cartpole_bc.py"],
# Include the zipped json data file as well.
# Include the offline data files.
data = [
"tests/data/cartpole/cartpole-v1_large",
],
@@ -557,7 +557,7 @@ py_test(
tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core"],
size = "large",
srcs = ["tuned_examples/marwil/cartpole_marwil.py"],
# Include the zipped json data file as well.
# Include the offline data files.
data = [
"tests/data/cartpole/cartpole-v1_large",
],
@@ -987,7 +987,7 @@ py_test(
name = "test_bc",
tags = ["team:rllib", "algorithms_dir"],
size = "medium",
# Include the parquet data files.
# Include the offline data files.
data = ["tests/data/cartpole/cartpole-v1_large"],
srcs = ["algorithms/bc/tests/test_bc.py"]
)
@@ -1052,7 +1052,7 @@ py_test(
name = "test_marwil",
tags = ["team:rllib", "algorithms_dir"],
size = "large",
# Include the parquet data folder.
# Include the offline data files.
data = [
"tests/data/cartpole/cartpole-v1_large",
"tests/data/pendulum/pendulum-v1_large",
@@ -1708,21 +1708,23 @@ py_test(
tags = ["team:rllib", "offline"],
size = "medium",
srcs = ["offline/tests/test_offline_data.py"],
# Include the offline data files.
data = [
"tests/data/cartpole/cartpole-v1_large",
"tests/data/cartpole/large.json",
],
]
)

py_test(
name = "test_offline_prelearner",
tags = ["team:rllib", "offline"],
size = "small",
srcs = ["offline/tests/test_offline_prelearner.py"],
# Include the offline data files.
data = [
"tests/data/cartpole/cartpole-v1_large",
"tests/data/cartpole/large.json",
],
]
)

# --------------------------------------------------------------------
@@ -3102,6 +3104,17 @@ py_test(
# subdirectory: offline_rl/
# ....................................

py_test(
    name = "examples/offline_rl/train_w_bc_finetune_w_ppo",
    main = "examples/offline_rl/train_w_bc_finetune_w_ppo.py",
    tags = ["team:rllib", "examples", "exclusive"],
    size = "medium",
    srcs = ["examples/offline_rl/train_w_bc_finetune_w_ppo.py"],
    args = ["--enable-new-api-stack", "--as-test", "--framework=torch"],
    # Include the offline data files.
    data = ["tests/data/cartpole/cartpole-v1_large"]
)
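# For reference, a roughly equivalent local (non-Bazel) invocation of this example
# (sketch; assumes a Ray source checkout so that the script and the offline data
# files under rllib/tests/data/ are available):
#   python rllib/examples/offline_rl/train_w_bc_finetune_w_ppo.py \
#       --enable-new-api-stack --as-test --framework=torch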

# @HybridAPIStack
# py_test(
# name = "examples/offline_rl/pretrain_bc_single_agent_evaluate_as_multi_agent",
306 changes: 306 additions & 0 deletions rllib/examples/offline_rl/train_w_bc_finetune_w_ppo.py
@@ -0,0 +1,306 @@
"""Example of training a custom RLModule with BC first, then finetuning it with PPO.
This example:
- demonstrates how to write a very simple custom BC RLModule.
- run a quick BC training experiment with the custom module and learn CartPole
until some episode return A, while checkpointing each iteration.
- shows how subclass the custom BC RLModule, add the ValueFunctionAPI to the
new class, and add a value-function branch and an implementation of
`compute_values` to the original model to make it work with a value-based algo
like PPO.
- shows how to plug this new PPO-capable RLModule (including its checkpointed state
from the BC run) into your algorithm's config.
- confirms that even after 1-2 training iterations with PPO, no catastrophic
forgetting occurs (due to the additional value function branch and the switched
optimizer).
- uses Tune and RLlib to continue training the model until a higher return of B
is reached.
How to run this script
----------------------
`python [script file name].py --enable-new-api-stack`
For debugging, use the following additional command line options
`--no-tune --num-env-runners=0`
which should allow you to set breakpoints anywhere in the RLlib code and
have the execution stop there for inspection and debugging.
For logging to your WandB account, use:
`--wandb-key=[your WandB API key] --wandb-project=[some project name]
--wandb-run-name=[optional: WandB run name (within the defined project)]`
Results to expect
-----------------
In the console output, you can first see BC's performance until return A is reached:
+----------------------------+------------+----------------+--------+
| Trial name | status | loc | iter |
| | | | |
|----------------------------+------------+----------------+--------+
| BC_CartPole-v1_95ba0_00000 | TERMINATED | 127.0.0.1:1515 | 51 |
+----------------------------+------------+----------------+--------+
+------------------+------------------------+------------------------+
| total time (s) | episode_return_mean | num_env_steps_traine |
| | | d_lifetime |
|------------------+------------------------|------------------------|
| 11.4828 | 250.5 | 42394 |
+------------------+------------------------+------------------------+
The script should confirm that no catastrophic forgetting has taken place:
PPO return after initialization: 292.3
PPO return after 2x training: 276.85
Then, after PPO training, you should see something like this (higher return):
+-----------------------------+------------+----------------+--------+
| Trial name | status | loc | iter |
| | | | |
|-----------------------------+------------+----------------+--------+
| PPO_CartPole-v1_e07ac_00000 | TERMINATED | 127.0.0.1:6032 | 37 |
+-----------------------------+------------+----------------+--------+
+------------------+------------------------+------------------------+
| total time (s) | episode_return_mean | num_episodes_lifetime |
| | | |
+------------------+------------------------+------------------------+
| 32.7647 | 450.76 | 406 |
+------------------+------------------------+------------------------+
"""
from pathlib import Path

from torch import nn

from ray.rllib.algorithms.bc import BCConfig
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core import (
    COMPONENT_LEARNER_GROUP,
    COMPONENT_LEARNER,
    COMPONENT_RL_MODULE,
)
from ray.rllib.core.columns import Columns
from ray.rllib.core.models.base import ENCODER_OUT
from ray.rllib.core.models.configs import MLPEncoderConfig, MLPHeadConfig
from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI
from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleSpec
from ray.rllib.core.rl_module.torch import TorchRLModule
from ray.rllib.utils.annotations import override
from ray.rllib.utils.metrics import (
    ENV_RUNNER_RESULTS,
    EPISODE_RETURN_MEAN,
    EVALUATION_RESULTS,
)
from ray.rllib.utils.test_utils import (
    add_rllib_example_script_args,
    run_rllib_example_script_experiment,
)

parser = add_rllib_example_script_args()
parser.set_defaults(
    enable_new_api_stack=True,
    env="CartPole-v1",
    checkpoint_freq=1,
)


class MyBCModel(TorchRLModule):
"""A very simple BC-usable model that only computes action logits."""

@override(TorchRLModule)
def setup(self):
# Create an encoder trunk.
# Observations are directly passed through it and feature vectors are output.
self._encoder = MLPEncoderConfig(
input_dims=[4], # CartPole
hidden_layer_dims=[256, 256],
hidden_layer_activation="relu",
output_layer_dim=None,
).build(framework="torch")

# The policy head sitting on top of the encoder. Feature vectors come in as
# input and action logits are output.
self._pi = MLPHeadConfig(
input_dims=[256], # from encoder
hidden_layer_dims=[256], # pi head
hidden_layer_activation="relu",
output_layer_dim=2, # CartPole
output_layer_activation="linear",
).build(framework="torch")

@override(TorchRLModule)
def _forward_inference(self, batch, **kwargs):
return {Columns.ACTION_DIST_INPUTS: self._pi(self._encoder(batch)[ENCODER_OUT])}

@override(RLModule)
def _forward_exploration(self, batch, **kwargs):
return self._forward_inference(batch)

@override(RLModule)
def _forward_train(self, batch, **kwargs):
return self._forward_inference(batch)
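
# A minimal standalone usage sketch of `MyBCModel` (not executed as part of this
# example; the `gymnasium` spaces and the dummy batch below are assumptions made
# purely for illustration):
#
#   import gymnasium as gym
#   import torch
#
#   env = gym.make("CartPole-v1")
#   spec = RLModuleSpec(
#       module_class=MyBCModel,
#       observation_space=env.observation_space,
#       action_space=env.action_space,
#   )
#   module = spec.build()
#   out = module.forward_inference(
#       {Columns.OBS: torch.from_numpy(env.reset()[0]).unsqueeze(0)}
#   )
#   # `out[Columns.ACTION_DIST_INPUTS]` should have shape (1, 2) for CartPole.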


class MyPPOModel(MyBCModel, ValueFunctionAPI):
"""Subclass of our simple BC model, but implementing the ValueFunctionAPI.
Implementing the `compute_values` method makes this RLModule usable by algos
like PPO, which require this.
"""

@override(MyBCModel)
def setup(self):
# Call super setup to create encoder trunk and policy head.
super().setup()
# Create the new value function head and zero-initialize it to not cause too
# much disruption.
self._vf = MLPHeadConfig(
input_dims=[256], # from encoder
hidden_layer_dims=[256], # pi head
hidden_layer_activation="relu",
hidden_layer_weights_initializer=nn.init.zeros_,
hidden_layer_bias_initializer=nn.init.zeros_,
output_layer_dim=1, # 1=value node
output_layer_activation="linear",
output_layer_weights_initializer=nn.init.zeros_,
output_layer_bias_initializer=nn.init.zeros_,
).build(framework="torch")

@override(MyBCModel)
def _forward_train(self, batch, **kwargs):
features = self._encoder(batch)[ENCODER_OUT]
logits = self._pi(features)
vf_out = self._vf(features).squeeze(-1)
return {
Columns.ACTION_DIST_INPUTS: logits,
Columns.VF_PREDS: vf_out,
}

@override(ValueFunctionAPI)
def compute_values(self, batch):
# Compute features ...
features = self._encoder(batch)[ENCODER_OUT]
# then values using our value head.
return self._vf(features).squeeze(-1)
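
# Because `_vf` above is zero-initialized (all weights and biases start at 0.0), the
# value head of a freshly built `MyPPOModel` outputs exactly 0.0 for any observation,
# so adding it does not disturb the pretrained policy branch. A hypothetical sanity
# check (assuming a module built via an `RLModuleSpec`, as sketched above):
#
#   import torch
#
#   values = ppo_module.compute_values({Columns.OBS: torch.randn(8, 4)})
#   assert torch.all(values == 0.0)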


if __name__ == "__main__":
    args = parser.parse_args()

    assert args.env == "CartPole-v1", "This example works only with --env=CartPole-v1!"

    # Define the data paths for our CartPole large dataset.
    base_path = Path(__file__).parents[2]
    assert base_path.is_dir(), base_path
    data_path = base_path / "tests/data/cartpole/cartpole-v1_large"
    assert data_path.is_dir(), data_path
    print(f"data_path={data_path}")

    # Define the BC config.
    base_config = (
        BCConfig()
        .api_stack(
            enable_rl_module_and_learner=True,
            enable_env_runner_and_connector_v2=True,
        )
        # Note that the `input_` argument is the main entry point of the
        # new offline API. Via `input_read_method_kwargs`, the arguments
        # for the `ray.data.Dataset` read method can be configured. The
        # read method needs at least as many blocks as remote learners.
        .offline_data(
            input_=[data_path.as_posix()],
            # Define the number of reading blocks; these should be larger than 1
            # and aligned with the data size.
            input_read_method_kwargs={"override_num_blocks": max(args.num_gpus * 2, 2)},
            # Concurrency defines the number of processes that run the
            # `map_batches` transformations. This should be aligned with the
            # `prefetch_batches` argument in `iter_batches_kwargs`.
            map_batches_kwargs={"concurrency": 2, "num_cpus": 2},
            # This dataset is small, so do not prefetch too many batches and use no
            # local shuffle.
            iter_batches_kwargs={
                "prefetch_batches": 1,
                "local_shuffle_buffer_size": None,
            },
            # The number of iterations to be run per learner when in multi-learner
            # mode in a single RLlib training iteration. Leave this at `None` to
            # run an entire epoch on the dataset during a single RLlib training
            # iteration. For single-learner mode, 1 is the only option.
            dataset_num_iters_per_learner=1 if args.num_gpus == 0 else None,
        )
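        # For intuition: the offline pipeline above roughly corresponds to a Ray Data
        # read of the Parquet files followed by batch iteration, e.g. (sketch;
        # `read_parquet` as the default read method is an assumption here):
        #   import ray
        #   ds = ray.data.read_parquet(data_path.as_posix(), override_num_blocks=2)
        #   for batch in ds.iter_batches(batch_size=1024, prefetch_batches=1):
        #       ...  # batches flow through the OfflinePreLearner into the Learner.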
        .training(
            train_batch_size_per_learner=1024,
            # To increase learning speed with multiple learners,
            # increase the learning rate correspondingly.
            lr=0.0008 * max(1, args.num_gpus**0.5),
        )
        # Plug in our simple custom BC model from above.
        .rl_module(rl_module_spec=RLModuleSpec(module_class=MyBCModel))
        # Run evaluation to observe how good our BC policy already is.
        .evaluation(
            evaluation_interval=3,
            evaluation_num_env_runners=1,
            evaluation_duration=5,
            evaluation_parallel_to_training=True,
        )
    )

    # Run the BC experiment and stop at R=250.0
    metric_key = f"{EVALUATION_RESULTS}/{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}"
    stop = {metric_key: 250.0}
    results = run_rllib_example_script_experiment(base_config, args, stop=stop)
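
    # With RLlib's metric constants, `metric_key` resolves (roughly; the exact
    # constant values are an assumption here) to
    # "evaluation/env_runners/episode_return_mean". The same nesting can be used
    # to inspect the BC return that was reached, e.g.:
    #   best = results.get_best_result(metric_key)
    #   print(
    #       best.metrics[EVALUATION_RESULTS][ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN]
    #   )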

    # Extract the RLModule checkpoint.
    best_result = results.get_best_result(metric_key)
    rl_module_checkpoint = (
        Path(best_result.checkpoint.path)
        / COMPONENT_LEARNER_GROUP
        / COMPONENT_LEARNER
        / COMPONENT_RL_MODULE
        / "default_policy"
    )
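
    # The path built above points at the RLModule sub-directory inside the Algorithm
    # checkpoint written by Tune, i.e. a layout along the lines of (directory names
    # are taken from the COMPONENT_* constants and shown here only for orientation):
    #   <algorithm checkpoint>/learner_group/learner/rl_module/default_policy/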

    # Create a new PPO config.
    base_config = (
        PPOConfig()
        .api_stack(
            enable_rl_module_and_learner=True,
            enable_env_runner_and_connector_v2=True,
        )
        .environment(args.env)
        .training(
            # Keep lr relatively low at the beginning to avoid catastrophic forgetting.
            lr=0.00002,
            num_epochs=6,
            vf_loss_coeff=0.01,
        )
        # Plug in our simple custom PPO model from above. Note that the checkpoint
        # for the BC model is loadable into the PPO model, because the BC model is a
        # subset of the PPO model (all weights/biases in the BC model are also found
        # in the PPO model; the PPO model only has an additional value function
        # branch).
        .rl_module(
            rl_module_spec=RLModuleSpec(
                module_class=MyPPOModel,
                load_state_path=rl_module_checkpoint,
            )
        )
    )

    # Quick test whether the initial performance of the loaded (now PPO) model is ok.
    ppo = base_config.build()
    eval_results = ppo.evaluate()
    R = eval_results[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN]
    assert R >= 200.0, f"Initial PPO performance bad! R={R} (expected 200.0+)."
    print(f"PPO return after initialization: {R}")
    # Check whether training two times causes catastrophic forgetting.
    ppo.train()
    train_results = ppo.train()
    R = train_results[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN]
    assert R >= 250.0, f"PPO performance (training) bad! R={R} (expected 250.0+)."
    print(f"PPO return after 2x training: {R}")

    # Perform the actual PPO training run (this time until a return of 450.0).
    stop = {
        f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": 450.0,
    }
    run_rllib_example_script_experiment(base_config, args, stop=stop)
