[RLlib] Cleanup examples folder (vol 24): Mixed-precision training (and float16 inference) through new example script. #47116
Conversation
stamp the doc changes
…nup_examples_folder_23_float16
…nup_examples_folder_24_mixed_precision
…es_folder_24_mixed_precision # Conflicts: # rllib/algorithms/algorithm_config.py # rllib/core/learner/torch/torch_learner.py
…nup_examples_folder_24_mixed_precision Signed-off-by: sven1977 <svenmika1977@gmail.com> # Conflicts: # rllib/BUILD # rllib/algorithms/algorithm_config.py # rllib/core/learner/learner.py # rllib/core/models/torch/primitives.py # rllib/examples/gpus/float16_training_and_inference.py
LGTM. Again, such a great example!
only `compute_loss_for_module()` should be overridden instead. If the algorithm
uses independent multi-agent learning (default behavior for RLlib's multi-agent
setups), also only `compute_loss_for_module()` should be overridden, but it will
be called for each individual RLModule inside the MultiRLModule.
Maybe add a point for when to specifically override `compute_losses()` instead of `compute_loss_for_module()`.
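For illustration, here is a rough sketch of the two override points under discussion. This is not code from the PR; it assumes the new API stack's keyword-only Learner signatures (`compute_loss_for_module(*, module_id, config, batch, fwd_out)` and `compute_losses(*, fwd_out, batch)`):

```python
from ray.rllib.algorithms.ppo.torch.ppo_torch_learner import PPOTorchLearner


class PerModuleLossLearner(PPOTorchLearner):
    # For single-agent and independent multi-agent learning, override this:
    # RLlib calls it once per RLModule inside the MultiRLModule.
    def compute_loss_for_module(self, *, module_id, config, batch, fwd_out):
        loss = super().compute_loss_for_module(
            module_id=module_id, config=config, batch=batch, fwd_out=fwd_out
        )
        # Any additional per-module loss terms would be added to `loss` here.
        return loss


class CrossModuleLossLearner(PPOTorchLearner):
    # Override `compute_losses()` only when a loss term has to look at several
    # modules' outputs at once (e.g. shared or coordinated multi-agent terms).
    def compute_losses(self, *, fwd_out, batch):
        losses = super().compute_losses(fwd_out=fwd_out, batch=batch)
        # `losses` maps each module ID to its loss tensor; cross-module terms
        # could be mixed in here before returning the dict.
        return losses
```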
TorchLearner to scale losses and unscale gradients in order to gain more stability
when training with mixed precision.
- shows how to write a custom TorchLearner to run the update step (overrides
`_update()`) within a `torch.amp.autocast()` context. This makes sure that .
Sentence has no end: "This makes sure that ...?"
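For readers new to AMP, here is a minimal plain-PyTorch sketch of the loss-scaling and gradient-unscaling loop that `GradScaler` provides (generic torch code, not RLlib's TorchLearner internals; the model, data, and clipping value are placeholders):

```python
import torch

model = torch.nn.Linear(8, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

for _ in range(10):
    x = torch.randn(32, 8, device="cuda")
    y = torch.randn(32, 1, device="cuda")

    optimizer.zero_grad(set_to_none=True)
    # Run the forward pass and loss computation in mixed precision.
    with torch.cuda.amp.autocast():
        loss = torch.nn.functional.mse_loss(model(x), y)

    # Scale the loss so small float16 gradients don't underflow in backward ...
    scaler.scale(loss).backward()
    # ... then unscale the gradients before clipping and stepping the optimizer.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
    scaler.step(optimizer)
    scaler.update()
```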
class PPOTorchMixedPrecisionLearner(PPOTorchLearner):
    def _update(self, *args, **kwargs):
        with torch.cuda.amp.autocast():
Nice! I learned something new! Thanks!
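The hunk above is cut off in the diff view; a plausible, self-contained completion of the pattern (a sketch only, not necessarily the exact code merged in this PR) delegates to the parent's `_update()` inside the autocast context:

```python
import torch
from ray.rllib.algorithms.ppo.torch.ppo_torch_learner import PPOTorchLearner


class PPOTorchMixedPrecisionLearner(PPOTorchLearner):
    def _update(self, *args, **kwargs):
        # Run the whole update step (forward_train, loss, backward) under
        # autocast so eligible ops execute in reduced precision.
        with torch.cuda.amp.autocast():
            results = super()._update(*args, **kwargs)
        return results
```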
…nd float16 inference) through new example script. (ray-project#47116) Signed-off-by: ujjawal-khare <ujjawal.khare@dream11.com>
This PR adds a new example script demo'ing:
- how to run the Learner's update in mixed precision (affecting the `forward_train` and `loss` computations).
- how to plug in torch's `GradScaler` class, to be used by the TorchLearner to scale losses and unscale gradients in order to gain more stability when training with mixed precision.
- how to write a custom TorchLearner that runs the update step (overrides `_update()`) within a `torch.amp.autocast()` context.
- how to configure an `AlgorithmConfig` instance and start training with mixed-precision while performing the inference on the EnvRunners with float16 precision (a rough config sketch follows below).
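As a loose illustration of that last bullet (not the PR's actual script; the `learner_class` hook and the reuse of the `PPOTorchMixedPrecisionLearner` sketched above are assumptions), a config could look roughly like this. The float16 inference side is wired up through callbacks/connectors in the real example and is omitted here:

```python
from ray.rllib.algorithms.ppo import PPOConfig

# Assumes `PPOTorchMixedPrecisionLearner` (sketched earlier in this
# conversation) is defined or importable in the same module.
config = (
    PPOConfig()
    .environment("CartPole-v1")
    .training(learner_class=PPOTorchMixedPrecisionLearner)
)
algo = config.build()
print(algo.train())
```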
Why are these changes needed?
Related issue number
Checks
- I've signed off every commit (by using the -s flag, i.e., `git commit -s`) in this PR.
- I've run `scripts/format.sh` to lint the changes in this PR.
- If I've added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.