[RLlib] Cleanup examples folder (vol 24): Mixed-precision training (and float16 inference) through new example script. #47116
Merged
Changes from all commits (24 commits, all by sven1977):
231ec2a  wip
9a583f8  wip
c1e4c8f  wip
c127eb5  wip
4fe7cf4  wip
d87305b  Merge branch 'master' of https://github.com/ray-project/ray into floa…
178fe9f  wip
41600c3  wip
f9b2355  Merge branch 'master' of https://github.com/ray-project/ray into floa…
9feb1ef  wip
33fefd8  wip
42da46b  wip
819c241  Merge branch 'master' of https://github.com/ray-project/ray into clea…
cc5e5c1  wip
6596ee5  fix
c9d07c6  Merge branch 'master' of https://github.com/ray-project/ray into clea…
e676226  fix
eac3b8b  Merge branch 'cleanup_examples_folder_23_float16' into cleanup_exampl…
c1e8d0f  wip
d48e0d3  wip
afba480  wip
5406943  wip
49bd23b  Merge branch 'master' of https://github.com/ray-project/ray into clea…
7ef8f9c  wip
rllib/examples/gpus/mixed_precision_training_float16_inference.py (178 additions & 0 deletions)
@@ -0,0 +1,178 @@
"""Example of using automatic mixed precision training on a torch RLModule. | ||
|
||
This example: | ||
- shows how to write a custom callback for RLlib to convert those RLModules | ||
only(!) on the EnvRunners to float16 precision. | ||
- shows how to write a custom env-to-module ConnectorV2 piece to add float16 | ||
observations to the action computing forward batch on the EnvRunners, but NOT | ||
permanently write these changes into the episodes, such that on the | ||
Learner side, the original float32 observations will be used (for the mixed | ||
precision `forward_train` and `loss` computations). | ||
- shows how to plugin torch's built-in `GradScaler` class to be used by the | ||
TorchLearner to scale losses and unscale gradients in order to gain more stability | ||
when training with mixed precision. | ||
- shows how to write a custom TorchLearner to run the update step (overrides | ||
`_update()`) within a `torch.amp.autocast()` context. This makes sure that . | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sentence has no end: "This makes sure that ...?" |
||
- demonstrates how to plug in all the above custom components into an | ||
`AlgorithmConfig` instance and start training with mixed-precision while | ||
performing the inference on the EnvRunners with float16 precision. | ||

How to run this script
----------------------
`python [script file name].py --enable-new-api-stack`

For debugging, use the following additional command line options
`--no-tune --num-env-runners=0`
which should allow you to set breakpoints anywhere in the RLlib code and
have the execution stop there for inspection and debugging.

Note that the GPU settings shown in this script also work if you are not
running via Tune, but are instead using the `--no-tune` command line option.

For logging to your WandB account, use:
`--wandb-key=[your WandB API key] --wandb-project=[some project name]
--wandb-run-name=[optional: WandB run name (within the defined project)]`

You can visualize experiment results in ~/ray_results using TensorBoard.


Results to expect
-----------------
In the console output, you should see something like this:
|
||
+-----------------------------+------------+-----------------+--------+ | ||
| Trial name | status | loc | iter | | ||
| | | | | | ||
|-----------------------------+------------+-----------------+--------+ | ||
| PPO_CartPole-v1_485af_00000 | TERMINATED | 127.0.0.1:81045 | 22 | | ||
+-----------------------------+------------+-----------------+--------+ | ||
+------------------+------------------------+------------------------+ | ||
| total time (s) | episode_return_mean | num_episodes_lifetime | | ||
| | | | | ||
|------------------+------------------------+------------------------+ | ||
| 281.3231 | 455.81 | 1426 | | ||
+------------------+------------------------+------------------------+ | ||
""" | ||
from typing import Optional

import gymnasium as gym
import numpy as np
import torch

from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.ppo.torch.ppo_torch_learner import PPOTorchLearner
from ray.rllib.connectors.connector_v2 import ConnectorV2
from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
from ray.rllib.utils.test_utils import (
    add_rllib_example_script_args,
    run_rllib_example_script_experiment,
)


parser = add_rllib_example_script_args(
    default_iters=200, default_reward=450.0, default_timesteps=200000
)
parser.set_defaults(
    algo="PPO",
    enable_new_api_stack=True,
)

class MakeEnvRunnerRLModulesFloat16(DefaultCallbacks):
    """Callback making sure that all RLModules in the algo are `half()`'ed."""

    def on_algorithm_init(
        self,
        *,
        algorithm: Algorithm,
        metrics_logger: Optional[MetricsLogger] = None,
        **kwargs,
    ) -> None:
        # Switch all EnvRunner RLModules (assuming single RLModules) to float16.
        algorithm.env_runner_group.foreach_worker(
            lambda env_runner: env_runner.module.half()
        )
        if algorithm.eval_env_runner_group:
            algorithm.eval_env_runner_group.foreach_worker(
                lambda env_runner: env_runner.module.half()
            )

class Float16Connector(ConnectorV2):
    """ConnectorV2 piece preprocessing observations and rewards to be float16.

    Note that users can also write a gymnasium.Wrapper for observations and
    rewards to achieve the same thing.
    """

    def recompute_output_observation_space(
        self,
        input_observation_space,
        input_action_space,
    ):
        return gym.spaces.Box(
            input_observation_space.low.astype(np.float16),
            input_observation_space.high.astype(np.float16),
            input_observation_space.shape,
            np.float16,
        )

    def __call__(self, *, rl_module, batch, episodes, **kwargs):
        for sa_episode in self.single_agent_episode_iterator(episodes):
            obs = sa_episode.get_observations(-1)
            float16_obs = obs.astype(np.float16)
            # Write the float16 observation only into the forward batch (used
            # for action computation); the episode keeps the original float32
            # observation for the Learner side.
            self.add_batch_item(
                batch,
                column="obs",
                item_to_add=float16_obs,
                single_agent_episode=sa_episode,
            )
        return batch

class PPOTorchMixedPrecisionLearner(PPOTorchLearner):
    def _update(self, *args, **kwargs):
        # Run the entire update (`forward_train` and loss computation) under
        # autocasting: ops execute in float16 where safe, float32 where needed.
        with torch.cuda.amp.autocast():
            results = super()._update(*args, **kwargs)
        return results

[Review comment] Nice! I learned something new! Thanks!

if __name__ == "__main__":
    args = parser.parse_args()

    assert (
        args.enable_new_api_stack
    ), "Must set --enable-new-api-stack when running this script!"
    assert args.algo == "PPO", "Must set --algo=PPO when running this script!"

    base_config = (
        (PPOConfig().environment("CartPole-v1"))
        .env_runners(env_to_module_connector=lambda env: Float16Connector())
        # Plug in our custom callback (on_algorithm_init) to make EnvRunner
        # RLModules float16 models.
        .callbacks(MakeEnvRunnerRLModulesFloat16)
        # Plug in torch's built-in loss scaler class to stabilize gradient
        # computations (by scaling the loss, then unscaling the gradients before
        # applying them). This uses the built-in, experimental feature of
        # TorchLearner.
        .experimental(_torch_grad_scaler_class=torch.cuda.amp.GradScaler)
        .training(
            # Plug in the custom Learner class to activate mixed-precision
            # training for our torch RLModule (uses `torch.amp.autocast()`).
            learner_class=PPOTorchMixedPrecisionLearner,
            # Switch off grad clipping entirely b/c we use the grad scaler,
            # which comes with built-in inf/nan detection (see the `step()`
            # method of torch's `GradScaler`).
            grad_clip=None,
            # Typical CartPole-v1 hyperparams known to work well:
            gamma=0.99,
            lr=0.0003,
            num_sgd_iter=6,
            vf_loss_coeff=0.01,
            use_kl_loss=True,
        )
    )

    run_rllib_example_script_experiment(base_config, args)
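
For readers who want the underlying torch mechanics in isolation, here are two minimal, self-contained sketches. They are illustrative only, not part of the PR, and both assume a CUDA device, since float16 matmuls are generally unsupported on CPU.

First, the EnvRunner-side effect of the callback plus connector, reduced to plain torch: `half()` turns the module's weights into float16, and inputs are fed as float16 to match (the toy policy and shapes are made up):

import numpy as np
import torch

policy = torch.nn.Linear(4, 2).cuda()
policy.half()  # All parameters and buffers become float16, as in the callback.

# Float16 observations match the float16 weights, as the connector ensures.
obs = np.random.randn(1, 4).astype(np.float16)
with torch.no_grad():
    logits = policy(torch.from_numpy(obs).cuda())
assert logits.dtype == torch.float16

Second, the Learner-side pattern that the `_update()` override and the `_torch_grad_scaler_class` setting correspond to: autocast the forward/loss, scale the loss before backprop, then unscale and step (model and data are toy placeholders):

import torch

model = torch.nn.Linear(4, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

for _ in range(10):
    x = torch.randn(32, 4, device="cuda")
    y = torch.randn(32, 2, device="cuda")
    optimizer.zero_grad()
    # Ops run in float16 where safe, float32 where needed (e.g. reductions).
    with torch.cuda.amp.autocast():
        loss = torch.nn.functional.mse_loss(model(x), y)
    # Scale the loss to avoid float16 gradient underflow, then backprop.
    scaler.scale(loss).backward()
    # `step` first unscales the gradients and skips the update on inf/nan.
    scaler.step(optimizer)
    scaler.update()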
[Review comment] Maybe add a point for when to specifically override `compute_losses` instead of `compute_loss_for_module`.
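
On the distinction this comment raises, a rough sketch, assuming the new-API-stack Learner hook signatures (these may differ across Ray versions): override `compute_loss_for_module` to customize the loss of a single RLModule, and override `compute_losses` only when the loss has to look at several modules at once. Both class names below are hypothetical:

from ray.rllib.algorithms.ppo.torch.ppo_torch_learner import PPOTorchLearner


class PerModuleLossLearner(PPOTorchLearner):
    # Typical case: adjust the loss for one RLModule at a time.
    def compute_loss_for_module(self, *, module_id, config, batch, fwd_out):
        loss = super().compute_loss_for_module(
            module_id=module_id, config=config, batch=batch, fwd_out=fwd_out
        )
        # ... modify the per-module loss here ...
        return loss


class CrossModuleLossLearner(PPOTorchLearner):
    # Rare case: the loss couples several modules (e.g. a shared critic),
    # so the per-module hook is insufficient.
    def compute_losses(self, *, fwd_out, batch):
        losses = super().compute_losses(fwd_out=fwd_out, batch=batch)
        # ... combine or re-weight losses across module IDs here ...
        return losses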