[RLlib] Fix torch scheduler stepping and reporting. #48125
Conversation
…t no duplicates or fragments are added to the replay buffer b/c it cannot handle these. Furthermore, refined tests for 'OfflinePreLearner'. Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
…mplete_episodes' when recording episodes. This ensures that episodes can be read in again for training. Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
…dium'. Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
…ent read batch size in case 'EpisodeType' or 'BatchType' data is stored in offline datasets. Added some docstrings. Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
…d RLlib schedulers in case a user sets torch schedulers. These learning rates are now also correctly reported via this method. Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
# If the module uses learning rate schedulers, step them here.
if module_id in self._lr_schedulers:
@OverrideToImplementCustomLogic_CallToSuperRecommended
Great! It's so much better to do this in this method here (instead of in `apply_gradients`).
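For readers following along, here is a minimal sketch of the pattern being discussed. The subclass name is hypothetical, and the exact `after_gradient_based_update` signature and the layout of `self._lr_schedulers` are assumptions inferred from the diff above, not the PR's literal code:

```python
from ray.rllib.core.learner.torch.torch_learner import TorchLearner
from ray.rllib.utils.annotations import (
    OverrideToImplementCustomLogic_CallToSuperRecommended,
)


class LRSchedulerTorchLearner(TorchLearner):  # hypothetical name, for illustration
    @OverrideToImplementCustomLogic_CallToSuperRecommended
    def after_gradient_based_update(self, *, timesteps):
        if self._lr_schedulers:
            # Torch schedulers present: step them once per training iteration
            # (i.e. after all minibatch updates), instead of inside
            # `apply_gradients`, so the lr reported here is the one in effect.
            for module_id, schedulers in self._lr_schedulers.items():
                for scheduler in schedulers:
                    scheduler.step()
        else:
            # No torch schedulers: let RLlib's own lr schedules do their work.
            super().after_gradient_based_update(timesteps=timesteps)
```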
LGTM. Awesome fix, thanks @simonsays1980. Just one important request: add the assertion to the example script so that the script really acts as a unit test for this functionality.
…er and switched the if-else condition to save us an indent. Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
…rates. Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
…he number of timesteps for testing in the example because CI tests were failing. Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.test_utils import add_rllib_example_script_args

torch, _ = try_import_torch()

parser = add_rllib_example_script_args(default_reward=450.0, default_timesteps=200000)

class LRChecker(DefaultCallbacks):
Nice!
Suggested change:
class LRChecker(DefaultCallbacks):
    """A custom callback that asserts the functionality of the lr-scheduler used."""
for optimizer_name in optimizer_names:
    # If learning rate schedulers are provided step them here. Note,
    # stepping them in `TorchLearner.apply_gradients` updates the
    # learning rates during minibatch updates; we want to update
    # them only once per SGD epoch instead.
thanks for clarifying this.
LGTM now! Thanks for the additional fixes and adding the test case (through the callback) to the example script! Very cool.
Signed-off-by: JP-sDEV <jon.pablo80@gmail.com>
Signed-off-by: mohitjain2504 <mohit.jain@dream11.com>
Why are these changes needed?
When using `torch.optim.lr_scheduler.LRScheduler` schedulers, learning rates were wrongly updated and reported because they were stepped in `TorchLearner.apply_gradients` but reported in `Learner.after_gradient_based_update`. The latter method also updated any RLlib-specific learning rate schedules and therefore reset the learning rates set by the torch-specific schedulers.

This PR fixes this error by instead overriding `after_gradient_based_update` to step the torch learning rate schedulers there and report them correctly. It also keeps the RLlib schedulers from overriding the updated learning rates. In case no torch-specific schedulers are used, `super()` is called and the RLlib schedules do their work.

This setup also steps the torch-specific schedulers at the right point with regard to SGD: after an SGD epoch, not in each minibatch step.
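To see why the placement matters, here is a small, self-contained torch-only sketch (not the RLlib code itself) comparing one scheduler step per minibatch with one step per epoch:

```python
import torch


def final_lr(num_scheduler_steps: int) -> float:
    """Learning rate after `num_scheduler_steps` steps of a 0.9-decay scheduler."""
    net = torch.nn.Linear(4, 2)
    opt = torch.optim.SGD(net.parameters(), lr=0.1)
    sched = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.9)
    for _ in range(num_scheduler_steps):
        opt.step()    # optimizer step first, then scheduler step
        sched.step()  # multiplies the lr by 0.9
    return sched.get_last_lr()[0]


num_epochs, minibatches_per_epoch = 10, 32

# Old placement (in `TorchLearner.apply_gradients`): stepped per minibatch.
print(final_lr(num_epochs * minibatches_per_epoch))  # ~0.1 * 0.9**320, vanishes
# Fixed placement (in `after_gradient_based_update`): stepped per epoch.
print(final_lr(num_epochs))                          # 0.1 * 0.9**10 ≈ 0.035
```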
Related issue number
Checks

- I've signed off every commit (by using the `-s` flag, i.e., `git commit -s`) in this PR.
- I've run `scripts/format.sh` to lint the changes in this PR.
- I've added any new APIs to the API Reference. For example, if I added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.