
[RLlib] Cleanup examples folder (vol 24): Mixed-precision training (and float16 inference) through new example script. #47116

Merged: 24 commits merged into ray-project:master on Aug 29, 2024

Conversation

@sven1977 (Contributor) commented on Aug 13, 2024:

This PR adds a new example script demonstrating:

  • how to write a custom RLlib callback that converts the RLModules to float16 precision on the EnvRunners only (not on the Learners).
  • how to write a custom env-to-module ConnectorV2 piece that adds float16 observations to the action-computing forward batch on the EnvRunners, but does NOT permanently write these changes into the episodes, so that on the Learner side the original float32 observations are used (for the mixed-precision forward_train and loss computations).
  • how to plug in torch's built-in GradScaler class, used by the TorchLearner to scale losses and unscale gradients for more stability when training with mixed precision.
  • how to write a custom TorchLearner that runs the update step (overriding `_update()`) within a `torch.amp.autocast()` context.
  • how to plug all of the above custom components into an AlgorithmConfig instance and start training with mixed precision while performing inference on the EnvRunners in float16 precision (see the condensed sketch after this list).
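The condensed sketch below shows how these pieces could fit together. It is not the merged example script: the callback hook (`on_algorithm_init`), the EnvRunner traversal (`env_runner_group.foreach_worker`), the ConnectorV2 helpers (`single_agent_episode_iterator`, `add_batch_item`), and the `_torch_grad_scaler_class` experimental setting are assumptions based on RLlib's new API stack around this release and may differ in detail from the actual code in rllib/examples/gpus/float16_training_and_inference.py.

import numpy as np
import torch

from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.ppo.torch.ppo_torch_learner import PPOTorchLearner
from ray.rllib.connectors.connector_v2 import ConnectorV2


class Float16InferenceCallback(DefaultCallbacks):
    """Casts the RLModules to float16 on the EnvRunners only (not the Learners)."""

    def on_algorithm_init(self, *, algorithm, **kwargs):
        # Each EnvRunner holds its RLModule as a torch.nn.Module, so .half()
        # converts all of its weights to float16 in place.
        algorithm.env_runner_group.foreach_worker(
            lambda env_runner: env_runner.module.half()
        )


class Float16Connector(ConnectorV2):
    """Env-to-module connector feeding float16 observations to the RLModule.

    Only the forward batch is cast; the episodes keep their original float32
    observations, which the Learner later uses for mixed-precision training.
    """

    def __call__(self, *, rl_module, data, episodes, **kwargs):
        for sa_episode in self.single_agent_episode_iterator(episodes):
            obs = sa_episode.get_observations(-1)
            self.add_batch_item(data, "obs", obs.astype(np.float16), sa_episode)
        return data


class PPOTorchMixedPrecisionLearner(PPOTorchLearner):
    """Runs the whole update step inside a torch autocast (mixed-precision) context."""

    def _update(self, *args, **kwargs):
        with torch.cuda.amp.autocast():
            return super()._update(*args, **kwargs)


config = (
    PPOConfig()
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .environment("CartPole-v1")
    .callbacks(Float16InferenceCallback)
    .env_runners(env_to_module_connector=lambda env: Float16Connector())
    .training(learner_class=PPOTorchMixedPrecisionLearner)
    # torch's built-in GradScaler scales losses and unscales gradients.
    .experimental(_torch_grad_scaler_class=torch.cuda.amp.GradScaler)
)

In this setup the EnvRunners compute actions entirely in float16, while the Learner keeps float32 weights and relies on autocast plus the GradScaler for numerically stable mixed-precision updates.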

Why are these changes needed?

Related issue number

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

@can-anyscale (Collaborator) left a comment:

stamp the doc changes

On Aug 28, 2024, @sven1977 changed the title from "[RLlib] Add experimental setting for mixed-precision learning (new API stack only)." to "[RLlib] Cleanup examples folder (vol 24): Mixed-precision training (and float16 inference) through new example script."
@sven1977 enabled auto-merge (squash) on August 29, 2024, 06:17.
The github-actions bot added the "go (add ONLY when ready to merge, run all tests)" label on Aug 29, 2024.
@simonsays1980 (Collaborator) left a comment:

LGTM. Again, such a great example!

On the following excerpt from the example's docstring:

    only `compute_loss_for_module()` should be overridden instead. If the algorithm
    uses independent multi-agent learning (default behavior for RLlib's multi-agent
    setups), also only `compute_loss_for_module()` should be overridden, but it will
    be called for each individual RLModule inside the MultiRLModule.
A collaborator commented:

Maybe add a point for when to specifically override `compute_losses` instead of `compute_loss_for_module`.
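To illustrate the distinction the reviewer asks about: `compute_losses()` computes the losses for all modules at once, so it is the natural hook to override when a single loss term couples several RLModules; for everything else, including RLlib's default independent multi-agent learning, the per-module hook suffices. A minimal hypothetical sketch (the class name and penalty placement are invented; the signature follows the new API stack):

from ray.rllib.algorithms.ppo.torch.ppo_torch_learner import PPOTorchLearner


class PenalizedPPOLearner(PPOTorchLearner):
    def compute_loss_for_module(self, *, module_id, config, batch, fwd_out):
        # Under independent multi-agent learning, RLlib calls this hook once
        # per RLModule inside the MultiRLModule; return that module's total loss.
        loss = super().compute_loss_for_module(
            module_id=module_id, config=config, batch=batch, fwd_out=fwd_out
        )
        # A (hypothetical) extra penalty term would be added to `loss` here.
        return loss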

On the following excerpt from the example's docstring:

    TorchLearner to scale losses and unscale gradients in order to gain more stability
    when training with mixed precision.
    - shows how to write a custom TorchLearner to run the update step (overrides
    `_update()`) within a `torch.amp.autocast()` context. This makes sure that .
A collaborator commented:
Sentence has no end: "This makes sure that ...?"


On this code from the example script (the excerpt was cut off; completed here so it runs):

class PPOTorchMixedPrecisionLearner(PPOTorchLearner):
    def _update(self, *args, **kwargs):
        # Run the whole update step under autocast so forward and loss
        # computations use mixed precision.
        with torch.cuda.amp.autocast():
            return super()._update(*args, **kwargs)
A collaborator commented:
Nice! I learned something new! Thanks!
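A side note on the snippet above: newer PyTorch releases deprecate torch.cuda.amp.autocast() in favor of torch.amp.autocast("cuda"); both enter the same mixed-precision autocast context on CUDA devices.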

@sven1977 merged commit 751dbb1 into ray-project:master on Aug 29, 2024, with 6 of 7 checks passed.
@sven1977 deleted the float16_precision branch on August 30, 2024, 11:10.
ujjawal-khare pushed commits referencing this pull request to ujjawal-khare-27/ray on Oct 12 and Oct 15, 2024:

…nd float16 inference) through new example script. (ray-project#47116)

Signed-off-by: ujjawal-khare <ujjawal.khare@dream11.com>
Labels: go (add ONLY when ready to merge, run all tests)
Projects: none yet
3 participants