Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RLlib] Fix torch scheduler stepping and reporting. #48125

Conversation

simonsays1980
Copy link
Collaborator

@simonsays1980 simonsays1980 commented Oct 21, 2024

Why are these changes needed?

When using torch.optim.lr_scheduler.LRScheduler schedulers learning rates were wrongly updated and reported b/c they were stepped in TorchLearner.apply_gradients and reported in Learner.after_gradient_based_update. The latter function also updated any RLlib specific learning rate schedules and therefore reset any learning rate schedules of the torch-specific schedulers.

This PR fixes this error and instead overrides the after_gradient_based_update to step learning rate schedulers in there and report them correctly. It also avoids the RLlib schedulers to override updated learning rates. In case no torch-specific learning rates are used, the super() is called an RLlib schedules do their work.

This setup also updates the torch-specific schedulers at the right point in regard to SGD: after an SGD epoch and not in each minibatch step.

Related issue number

Checks

  • I've signed off every commit(by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

…t no duplicates or fragments are added to the replay buffer b/c it cannot handle these. Furthermore, refined tests for 'OfflinePreLearner'.

Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
…mplete_episodes' when recording episodes. This ensures that episodes can be read in again for training.

Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
…dium'.

Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
…ent read batch size in case 'EpisodeType' or 'BatchType' data is stored in offline datasets. Added some docstrings.

Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
…d RLlib schedulers in case a user set torch schedulers. These learning rates get now also correctly reported via this method.

Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
@simonsays1980 simonsays1980 added rllib RLlib related issues rllib-algorithms An RLlib algorithm/Trainer is not learning. rllib-logging This problem is related to logging metrics bug Something that is supposed to be working; but isn't labels Oct 21, 2024
@simonsays1980 simonsays1980 marked this pull request as ready for review October 21, 2024 17:26
@sven1977 sven1977 changed the title [RLlib] - Fix torch scheduler stepping and reporting [RLlib] Fix torch scheduler stepping and reporting. Oct 23, 2024

# If the module uses learning rate schedulers, step them here.
if module_id in self._lr_schedulers:
@OverrideToImplementCustomLogic_CallToSuperRecommended
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great! This is so much better to do this in this method here (instead of in apply_gradients).

Copy link
Contributor

@sven1977 sven1977 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Awesome fix, thanks @simonsays1980 . Just one important request about adding the assertion into the example script such that this script really acts as a unit test for this functionality.

@sven1977 sven1977 enabled auto-merge (squash) November 4, 2024 09:14
@github-actions github-actions bot added the go add ONLY when ready to merge, run all tests label Nov 4, 2024
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.test_utils import add_rllib_example_script_args

torch, _ = try_import_torch()

parser = add_rllib_example_script_args(default_reward=450.0, default_timesteps=200000)

class LRChecker(DefaultCallbacks):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.test_utils import add_rllib_example_script_args

torch, _ = try_import_torch()

parser = add_rllib_example_script_args(default_reward=450.0, default_timesteps=200000)

class LRChecker(DefaultCallbacks):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
class LRChecker(DefaultCallbacks):
class LRChecker(DefaultCallbacks):
"""A custom callback that asserts the functionality of the lr-scheduler used."""

for optimizer_name in optimizer_names:
# If learning rate schedulers are provided step them here. Note,
# stepping them in `TorchLearner.apply_gradients` updates the
# learning rates during minibatch updates; we want to update
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for clarifying this.

Copy link
Contributor

@sven1977 sven1977 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM now! Thanks for the additional fixes and adding the test case (through the callback) to the example script! Very cool.

@sven1977 sven1977 merged commit ec9775d into ray-project:master Nov 4, 2024
7 checks passed
JP-sDEV pushed a commit to JP-sDEV/ray that referenced this pull request Nov 14, 2024
mohitjain2504 pushed a commit to mohitjain2504/ray that referenced this pull request Nov 15, 2024
Signed-off-by: mohitjain2504 <mohit.jain@dream11.com>
@simonsays1980 simonsays1980 deleted the fix-torch-scheduler-stepping-and-reporting branch November 22, 2024 10:56
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something that is supposed to be working; but isn't go add ONLY when ready to merge, run all tests rllib RLlib related issues rllib-algorithms An RLlib algorithm/Trainer is not learning. rllib-logging This problem is related to logging metrics
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants