[RLlib] Cleanup examples folder (vol 24): Mixed-precision training (and float16 inference) through new example script. #47116

Merged: 24 commits, Aug 29, 2024
8 changes: 8 additions & 0 deletions rllib/BUILD
@@ -2977,6 +2977,14 @@ py_test(
srcs = ["examples/gpus/float16_training_and_inference.py"],
args = ["--enable-new-api-stack", "--as-test", "--stop-reward=150.0"]
)
py_test(
name = "examples/gpus/mixed_precision_training_float16_inference",
main = "examples/gpus/mixed_precision_training_float16_inference.py",
tags = ["team:rllib", "exclusive", "examples", "gpu"],
size = "medium",
srcs = ["examples/gpus/mixed_precision_training_float16_inference.py"],
args = ["--enable-new-api-stack", "--as-test"]
)
py_test(
name = "examples/gpus/fractional_0.5_gpus_per_learner",
main = "examples/gpus/fractional_gpus_per_learner.py",
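The BUILD entry above registers the new example as a GPU-tagged test target. Outside of Bazel, the script can also be run directly with the same flags (path taken from `srcs` above):

python rllib/examples/gpus/mixed_precision_training_float16_inference.py --enable-new-api-stack --as-test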
15 changes: 9 additions & 6 deletions rllib/algorithms/algorithm_config.py
@@ -3199,12 +3199,15 @@ def experimental(
Args:
_torch_grad_scaler_class: Class to use for torch loss scaling (and gradient
unscaling). The class must implement the following methods to be
compatible with a `TorchLearner`. These methods/APIs match exactly the
those of torch's own `torch.amp.GradScaler`:
`scale([loss])` to scale the loss.
`get_scale()` to get the current scale value.
`step([optimizer])` to unscale the grads and step the given optimizer.
`update()` to update the scaler after an optimizer step.
compatible with a `TorchLearner`. These methods/APIs match exactly those
of torch's own `torch.amp.GradScaler` (see here for more details
https://pytorch.org/docs/stable/amp.html#gradient-scaling):
`scale([loss])` to scale the loss by some factor.
`get_scale()` to get the current scale factor value.
`step([optimizer])` to unscale the grads (divide by the scale factor)
and step the given optimizer.
`update()` to update the scaler after an optimizer step (for example to
adjust the scale factor).
_tf_policy_handles_more_than_one_loss: Experimental flag.
If True, TFPolicy will handle more than one loss/optimizer.
Set this to True, if you would like to return more than
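For illustration, a minimal sketch of a class satisfying the four-method scaler interface described in the docstring above. The class name `MyGradScaler` and its fixed-scale behavior are illustrative assumptions, not RLlib or torch code:

import torch


class MyGradScaler:
    """Fixed-scale sketch of the TorchLearner-compatible scaler interface."""

    def __init__(self, init_scale: float = 2.0**16):
        self._scale = init_scale

    def scale(self, loss: torch.Tensor) -> torch.Tensor:
        # Scale the loss so small gradients don't underflow in low precision.
        return loss * self._scale

    def get_scale(self) -> float:
        # Return the current scale factor.
        return self._scale

    def step(self, optimizer: torch.optim.Optimizer) -> None:
        # Unscale the grads (divide by the scale factor), then step.
        for group in optimizer.param_groups:
            for param in group["params"]:
                if param.grad is not None:
                    param.grad.div_(self._scale)
        optimizer.step()

    def update(self) -> None:
        # A production scaler would adjust `self._scale` here (e.g., back off
        # after inf/nan grads); this sketch keeps the factor constant.
        pass

Such a class would then be plugged in via `config.experimental(_torch_grad_scaler_class=MyGradScaler)`.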
24 changes: 14 additions & 10 deletions rllib/core/learner/learner.py
@@ -152,7 +152,7 @@ class Learner(Checkpointable):
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import (
PPOTorchRLModule
)
from ray.rllib.core import COMPONENT_RL_MODULE
from ray.rllib.core import COMPONENT_RL_MODULE, DEFAULT_MODULE_ID
from ray.rllib.core.rl_module.rl_module import RLModuleSpec

env = gym.make("CartPole-v1")
@@ -215,9 +215,11 @@ class Learner(Checkpointable):
class MyLearner(TorchLearner):

def compute_losses(self, fwd_out, batch):
# Compute the loss based on batch and output of the forward pass
# to access the learner hyper-parameters use `self._hps`
return {ALL_MODULES: loss}
# Compute the losses per module based on `batch` and output of the
# forward pass (`fwd_out`). To access the (algorithm) config for a
# specific RLModule, do:
# `self.config.get_config_for_module([moduleID])`.
return {DEFAULT_MODULE_ID: module_loss}
"""

framework: str = None
@@ -849,14 +851,16 @@ def compute_losses(
"""Computes the loss(es) for the module being optimized.

This method must be overridden by MultiRLModule-specific Learners in order to
define the specific loss computation logic. If the algorithm is single-agent
`compute_loss_for_module()` should be overridden instead. If the algorithm uses
independent multi-agent learning (default behavior for multi-agent setups), also
`compute_loss_for_module()` should be overridden, but it will be called for each
individual RLModule inside the MultiRLModule.
define the specific loss computation logic. If the algorithm is single-agent,
only `compute_loss_for_module()` should be overridden instead. If the algorithm
uses independent multi-agent learning (default behavior for RLlib's multi-agent
setups), also only `compute_loss_for_module()` should be overridden, but it will
be called for each individual RLModule inside the MultiRLModule.
Collaborator review comment: Maybe add a point for when to specifically override `compute_losses` instead of `compute_loss_for_module`.

It is recommended to not compute any forward passes within this method, and to
use the `forward_train()` outputs of the RLModule(s) to compute the required
tensors for loss calculations.
loss tensors.
See here for a custom loss function example script:
https://github.com/ray-project/ray/blob/master/rllib/examples/learners/custom_loss_fn_simple.py # noqa

Args:
fwd_out: Output from a call to the `forward_train()` method of the
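A sketch of the single-agent path described in the docstring above: override only `compute_loss_for_module()` and let the base `compute_losses()` call it once per RLModule. The learner name and the batch/fwd_out keys are hypothetical, for illustration only:

import torch

from ray.rllib.core.learner.torch.torch_learner import TorchLearner
from ray.rllib.utils.annotations import override


class MySingleAgentTorchLearner(TorchLearner):
    @override(TorchLearner)
    def compute_loss_for_module(self, *, module_id, config, batch, fwd_out):
        # Use the `forward_train()` outputs (`fwd_out`); per the docstring,
        # avoid running additional forward passes in here.
        # "predictions" and "targets" are hypothetical keys for illustration.
        return torch.mean((fwd_out["predictions"] - batch["targets"]) ** 2)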
44 changes: 30 additions & 14 deletions rllib/examples/gpus/float16_training_and_inference.py
@@ -9,6 +9,9 @@
learning with float16 weight matrices and gradients. This custom scaler behaves
exactly like the torch built-in `torch.amp.GradScaler` but also works for float16
gradients (which the torch built-in one doesn't).
- shows how to write a custom TorchLearner to change the epsilon setting (to the
much larger 1e-4 to stabilize learning) on the default optimizer (Adam) registered
for each RLModule.
- demonstrates how to plug in all the above custom components into an
`AlgorithmConfig` instance and start training (and inference) with float16
precision.
@@ -75,7 +78,7 @@
)


class Float16InitCallback(DefaultCallbacks):
class MakeAllRLModulesFloat16(DefaultCallbacks):
"""Callback making sure that all RLModules in the algo are `half()`'ed."""

def on_algorithm_init(
@@ -99,7 +102,7 @@ def on_algorithm_init(
)


class Float16Connector(ConnectorV2):
class WriteObsAndRewardsAsFloat16(ConnectorV2):
"""ConnectorV2 piece preprocessing observations and rewards to be float16.

Note that users can also write a gymnasium.Wrapper for observations and rewards
@@ -196,20 +199,25 @@ def update(self):
self._found_inf_or_nan = False


class Float16TorchLearner(PPOTorchLearner):
class LargeEpsAdamTorchLearner(PPOTorchLearner):
"""A TorchLearner overriding the default optimizer (Adam) to use non-default eps."""

@override(TorchLearner)
def configure_optimizers_for_module(self, module_id, config):
module = self._module[module_id]

params = self.get_parameters(module)
# Create an Adam optimizer with a different eps for better float16 stability.
optimizer = torch.optim.Adam(params, eps=1e-4)
"""Registers an Adam optimizer with a larg epsilon under the given module_id."""
params = list(self._module[module_id].parameters())

# Register the created optimizer (under the default optimizer name).
# Register one Adam optimizer (under the default optimizer name:
# DEFAULT_OPTIMIZER) for the `module_id`.
self.register_optimizer(
module_id=module_id,
optimizer=optimizer,
# Create an Adam optimizer with a different eps for better float16
# stability.
optimizer=torch.optim.Adam(params, eps=1e-4),
params=params,
# Let RLlib handle the learning rate/learning rate schedule.
# You can leave `lr_or_lr_schedule` at None, but then you should
# pass a fixed learning rate into the Adam constructor above.
lr_or_lr_schedule=config.lr,
)

@@ -221,12 +229,20 @@ def configure_optimizers_for_module(self, module_id, config):
get_trainable_cls(args.algo)
.get_default_config()
.environment("CartPole-v1")
# Plug in our custom loss scaler class.
# Plug in our custom callback (on_algorithm_init) to make all RLModules
# float16 models.
.callbacks(MakeAllRLModulesFloat16)
# Plug in our custom loss scaler class to stabilize gradient computations
# (by scaling the loss, then unscaling the gradients before applying them).
# This is using the built-in, experimental feature of TorchLearner.
.experimental(_torch_grad_scaler_class=Float16GradScaler)
.env_runners(env_to_module_connector=lambda env: Float16Connector())
.callbacks(Float16InitCallback)
# Plug in our custom env-to-module ConnectorV2 piece to convert all observations
# and rewards in the episodes (permanently) to float16.
.env_runners(env_to_module_connector=lambda env: WriteObsAndRewardsAsFloat16())
.training(
learner_class=Float16TorchLearner,
# Plug in our custom TorchLearner (using a much larger, stabilizing epsilon
# on the Adam optimizer).
learner_class=LargeEpsAdamTorchLearner,
# Switch off grad clipping entirely b/c we use our custom grad scaler with
# built-in inf/nan detection (see `step` method of `Float16GradScaler`).
grad_clip=None,
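A quick, illustrative way to see why the much larger Adam eps (1e-4) helps once the whole module is float16: Adam adds eps to the denominator of its update, and the default 1e-8 underflows to zero in half precision (this snippet is an editor's illustration, not part of the PR):

import torch

# 1e-8 lies below float16's smallest positive (subnormal) value of ~6e-8,
# so the default Adam eps effectively vanishes after the cast:
print(torch.tensor(1e-8, dtype=torch.float16).item())  # -> 0.0
# 1e-4 is comfortably representable and keeps the update's denominator
# bounded away from zero:
print(torch.tensor(1e-4, dtype=torch.float16).item())  # -> ~0.0001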
178 changes: 178 additions & 0 deletions rllib/examples/gpus/mixed_precision_training_float16_inference.py
@@ -0,0 +1,178 @@
"""Example of using automatic mixed precision training on a torch RLModule.

This example:
- shows how to write a custom callback for RLlib to convert only(!) the RLModules
on the EnvRunners to float16 precision.
- shows how to write a custom env-to-module ConnectorV2 piece to add float16
observations to the action-computing forward batch on the EnvRunners, but NOT
to permanently write these changes into the episodes, such that on the
Learner side, the original float32 observations are used (for the
mixed-precision `forward_train` and `loss` computations).
- shows how to plug in torch's built-in `GradScaler` class to be used by the
TorchLearner to scale losses and unscale gradients in order to gain more stability
when training with mixed precision.
- shows how to write a custom TorchLearner to run the update step (overrides
`_update()`) within a `torch.amp.autocast()` context. This makes sure that all
forward and loss computations during the update run in mixed precision (autocast
executes eligible ops in float16 and keeps precision-sensitive ops in float32).
Collaborator review comment: Sentence has no end: "This makes sure that ...?"

- demonstrates how to plug in all the above custom components into an
`AlgorithmConfig` instance and start training with mixed-precision while
performing the inference on the EnvRunners with float16 precision.


How to run this script
----------------------
`python [script file name].py --enable-new-api-stack`

For debugging, use the following additional command line options
`--no-tune --num-env-runners=0`
which should allow you to set breakpoints anywhere in the RLlib code and
have the execution stop there for inspection and debugging.

Note that the shown GPU settings in this script also work in case you are not
running via tune, but instead are using the `--no-tune` command line option.

For logging to your WandB account, use:
`--wandb-key=[your WandB API key] --wandb-project=[some project name]
--wandb-run-name=[optional: WandB run name (within the defined project)]`

You can visualize experiment results in ~/ray_results using TensorBoard.


Results to expect
-----------------
In the console output, you should see something like this:

+-----------------------------+------------+-----------------+--------+
| Trial name | status | loc | iter |
| | | | |
|-----------------------------+------------+-----------------+--------+
| PPO_CartPole-v1_485af_00000 | TERMINATED | 127.0.0.1:81045 | 22 |
+-----------------------------+------------+-----------------+--------+
+------------------+------------------------+------------------------+
| total time (s) | episode_return_mean | num_episodes_lifetime |
| | | |
|------------------+------------------------+------------------------+
| 281.3231 | 455.81 | 1426 |
+------------------+------------------------+------------------------+
"""
from typing import Optional

import gymnasium as gym
import numpy as np
import torch

from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.ppo.torch.ppo_torch_learner import PPOTorchLearner
from ray.rllib.connectors.connector_v2 import ConnectorV2
from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
from ray.rllib.utils.test_utils import (
add_rllib_example_script_args,
run_rllib_example_script_experiment,
)


parser = add_rllib_example_script_args(
default_iters=200, default_reward=450.0, default_timesteps=200000
)
parser.set_defaults(
algo="PPO",
enable_new_api_stack=True,
)


class MakeEnvRunnerRLModulesFloat16(DefaultCallbacks):
"""Callback making sure that all RLModules in the algo are `half()`'ed."""

def on_algorithm_init(
self,
*,
algorithm: Algorithm,
metrics_logger: Optional[MetricsLogger] = None,
**kwargs,
) -> None:
# Switch all EnvRunner RLModules (assuming single RLModules) to float16.
algorithm.env_runner_group.foreach_worker(
lambda env_runner: env_runner.module.half()
)
if algorithm.eval_env_runner_group:
algorithm.eval_env_runner_group.foreach_worker(
lambda env_runner: env_runner.module.half()
)
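
# Note: `nn.Module.half()` casts all floating-point parameters and buffers
# to float16 in place. Only the EnvRunner-side module copies are converted
# here; the Learner-side RLModule stays float32 for mixed-precision training.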


class Float16Connector(ConnectorV2):
"""ConnectorV2 piece preprocessing observations and rewards to be float16.

Note that a gymnasium observation wrapper could perform the same conversion,
but it would also (permanently) write float16 observations into the episodes.
"""

def recompute_output_observation_space(
self,
input_observation_space,
input_action_space,
):
return gym.spaces.Box(
input_observation_space.low.astype(np.float16),
input_observation_space.high.astype(np.float16),
input_observation_space.shape,
np.float16,
)

def __call__(self, *, rl_module, batch, episodes, **kwargs):
for sa_episode in self.single_agent_episode_iterator(episodes):
obs = sa_episode.get_observations(-1)
float16_obs = obs.astype(np.float16)
self.add_batch_item(
batch,
column="obs",
item_to_add=float16_obs,
single_agent_episode=sa_episode,
)
return batch


class PPOTorchMixedPrecisionLearner(PPOTorchLearner):
def _update(self, *args, **kwargs):
with torch.cuda.amp.autocast():
Collaborator review comment: Nice! I learned something new! Thanks!

results = super()._update(*args, **kwargs)
return results
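
# Note: `torch.cuda.amp.autocast()` is the CUDA-specific spelling of
# `torch.amp.autocast("cuda")`. Inside the context, autocast-eligible ops
# (e.g., matmuls) run in float16, while precision-sensitive ops (e.g.,
# reductions) are kept in float32.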


if __name__ == "__main__":
args = parser.parse_args()

assert (
args.enable_new_api_stack
), "Must set --enable-new-api-stack when running this script!"
assert args.algo == "PPO", "Must set --algo=PPO when running this script!"

base_config = (
(PPOConfig().environment("CartPole-v1"))
.env_runners(env_to_module_connector=lambda env: Float16Connector())
# Plug in our custom callback (on_algorithm_init) to make EnvRunner RLModules
# float16 models.
.callbacks(MakeEnvRunnerRLModulesFloat16)
# Plug in the torch built-in loss scaler class to stabilize gradient
# computations (by scaling the loss, then unscaling the gradients before
# applying them). This is using the built-in, experimental feature of
# TorchLearner.
.experimental(_torch_grad_scaler_class=torch.cuda.amp.GradScaler)
.training(
# Plug in the custom Learner class to activate mixed-precision training for
# our torch RLModule (uses `torch.amp.autocast()`).
learner_class=PPOTorchMixedPrecisionLearner,
# Switch off grad clipping entirely b/c we use torch's built-in grad scaler
# with inf/nan detection (it skips optimizer steps on non-finite gradients).
grad_clip=None,
# Typical CartPole-v1 hyperparams known to work well:
gamma=0.99,
lr=0.0003,
num_sgd_iter=6,
vf_loss_coeff=0.01,
use_kl_loss=True,
)
)

run_rllib_example_script_experiment(base_config, args)
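
For reference, the custom pieces above compose torch's standard AMP recipe. A minimal standalone sketch of that recipe (plain torch, independent of RLlib; requires a CUDA device):

import torch

model = torch.nn.Linear(8, 2).cuda()
optim = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = torch.cuda.amp.GradScaler()

for _ in range(10):
    x = torch.randn(16, 8, device="cuda")
    with torch.cuda.amp.autocast():  # forward/loss in mixed precision
        loss = model(x).pow(2).mean()  # illustrative loss
    optim.zero_grad()
    scaler.scale(loss).backward()  # scale the loss to avoid float16 underflow
    scaler.step(optim)  # unscales grads; skips the step on inf/nan grads
    scaler.update()  # adjusts the scale factor for the next iteration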