Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RLlib] Fix torch scheduler stepping and reporting. #48125

Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
173e762
Added validation method for episodes to 'OfflinePreLearner', such tha…
simonsays1980 Oct 17, 2024
6a609f1
Added a validation for user settings that ensures that 'batch_mode=co…
simonsays1980 Oct 17, 2024
63c1ca9
Decreased iterations for recording and set test for prelearner to 'me…
simonsays1980 Oct 17, 2024
b726bbd
Merge branch 'master' into offline-rl-handle-duplicates-in-buffer
simonsays1980 Oct 19, 2024
42ea05f
Added 'input_read_batch_size' to enable users to control for a differ…
simonsays1980 Oct 21, 2024
859ffd1
Merge branch 'master' into offline-rl-handle-duplicates-in-buffer
simonsays1980 Oct 21, 2024
2c7e56b
Added @sven1977's review.
simonsays1980 Oct 21, 2024
3ec4fa4
Added 'after_gradient_based_update' to step torch schedulers and avoi…
simonsays1980 Oct 21, 2024
c0e442d
Fixed failing test. Decay was too strong.
simonsays1980 Oct 22, 2024
bda96f3
Merge branch 'master' into fix-torch-scheduler-stepping-and-reporting
simonsays1980 Oct 22, 2024
f8a896d
Merge branch 'master' into fix-torch-scheduler-stepping-and-reporting
simonsays1980 Oct 28, 2024
3f9aa06
Added a quick callback for checking the learning rates in the optimiz…
simonsays1980 Oct 28, 2024
38e0d7a
Added callbacks to torch lr-scheduler examples to check for learning …
simonsays1980 Oct 29, 2024
c131e4f
Merge branch 'master' into fix-torch-scheduler-stepping-and-reporting
simonsays1980 Oct 29, 2024
b76abb7
Merge branch 'master' into fix-torch-scheduler-stepping-and-reporting
simonsays1980 Oct 29, 2024
5871709
Fixed a small index error in the 'LRChecker' callback and increased t…
simonsays1980 Nov 1, 2024
04c4094
Merge branch 'master' into fix-torch-scheduler-stepping-and-reporting
simonsays1980 Nov 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 56 additions & 9 deletions rllib/core/learner/torch/torch_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
AlgorithmConfig,
TorchCompileWhatToCompile,
)
from ray.rllib.core.learner.learner import Learner
from ray.rllib.core.learner.learner import Learner, LR_KEY
from ray.rllib.core.rl_module.multi_rl_module import (
MultiRLModule,
MultiRLModuleSpec,
Expand All @@ -32,13 +32,15 @@
from ray.rllib.utils.annotations import (
override,
OverrideToImplementCustomLogic,
OverrideToImplementCustomLogic_CallToSuperRecommended,
)
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.metrics import (
ALL_MODULES,
NUM_TRAINABLE_PARAMETERS,
NUM_NON_TRAINABLE_PARAMETERS,
)
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.torch_utils import convert_to_torch_tensor, copy_torch_tensors
from ray.rllib.utils.typing import (
ModuleID,
Expand Down Expand Up @@ -209,7 +211,10 @@ def apply_gradients(self, gradients_dict: ParamDict) -> None:
# If we have learning rate schedulers for a module add them, if
# necessary.
if self._lr_scheduler_classes is not None:
if module_id not in self._lr_schedulers:
if (
module_id not in self._lr_schedulers
or optimizer_name not in self._lr_schedulers[module_id]
):
# Set for each module and optimizer a scheduler.
self._lr_schedulers[module_id] = {optimizer_name: []}
# If the classes are in a dictionary each module might have
Expand Down Expand Up @@ -257,15 +262,57 @@ def apply_gradients(self, gradients_dict: ParamDict) -> None:
"`False`."
)

# If the module uses learning rate schedulers, step them here.
if module_id in self._lr_schedulers:
for scheduler in self._lr_schedulers[module_id][optimizer_name]:
scheduler.step()

# If the module uses learning rate schedulers, step them here.
if module_id in self._lr_schedulers:
@OverrideToImplementCustomLogic_CallToSuperRecommended
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great! This is so much better to do this in this method here (instead of in apply_gradients).

@override(Learner)
def after_gradient_based_update(self, *, timesteps: Dict[str, Any]) -> None:
"""Called after gradient-based updates are completed.

Should be overridden to implement custom cleanup-, logging-, or non-gradient-
based Learner/RLModule update logic after(!) gradient-based updates have been
completed.

Note, for `framework="torch"` users can register
`torch.optim.lr_scheduler.LRScheduler` via
`AlgorithmConfig._torch_lr_scheduler_classes`. These schedulers need to be
stepped here after gradient updates and reported.

Args:
timesteps: Timesteps dict, which must have the key
`NUM_ENV_STEPS_SAMPLED_LIFETIME`.
# TODO (sven): Make this a more formal structure with its own type.
"""
# If we have `torch.optim.lr_scheduler.LRScheduler` we need to step them here
# and report learning rates.
if self._lr_schedulers:
# Only update this optimizer's lr, if a scheduler has been registered
# along with it.
for module_id, optimizer_names in self._module_optimizers.items():
for optimizer_name in optimizer_names:
# If learning rate schedulers are provided step them here. Note,
# stepping them in `TorchLearner.apply_gradients` updates the
# learning rates during minibatch updates; we want to update
# between whole batch updates.
if (
module_id in self._lr_schedulers
and optimizer_name in self._lr_schedulers[module_id]
):
for scheduler in self._lr_schedulers[module_id][optimizer_name]:
scheduler.step()
optimizer = self.get_optimizer(module_id, optimizer_name)
self.metrics.log_value(
# Cut out the module ID from the beginning since it's already
# part of the key sequence: (ModuleID, "[optim name]_lr").
key=(
module_id,
f"{optimizer_name[len(module_id) + 1:]}_{LR_KEY}",
),
value=convert_to_numpy(self._get_optimizer_lr(optimizer)),
window=1,
)
# Otherwise call the `super()`'s method to update RLlib's learning rate
# schedules.
else:
simonsays1980 marked this conversation as resolved.
Show resolved Hide resolved
return super().after_gradient_based_update(timesteps=timesteps)

@override(Learner)
def _get_optimizer_state(self) -> StateDict:
Expand Down
12 changes: 6 additions & 6 deletions rllib/examples/learners/ppo_with_torch_lr_schedulers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

How to run this script
----------------------
`python [script file name].py --enable-new-api-stack --lr-const-factor=0.1
--lr-const-iters=10 --lr-exp-decay=0.3`
`python [script file name].py --enable-new-api-stack --lr-const-factor=0.9
simonsays1980 marked this conversation as resolved.
Show resolved Hide resolved
--lr-const-iters=10 --lr-exp-decay=0.9`

Use the `--lr-const-factor` to define the facotr by which to multiply the
learning rate in the first `--lr-const-iters` iterations. Use the
Expand Down Expand Up @@ -68,7 +68,7 @@
parser.add_argument(
"--lr-const-factor",
type=float,
default=0.1,
default=0.9,
help="The factor by which the learning rate should be multiplied.",
)
parser.add_argument(
Expand All @@ -83,20 +83,20 @@
parser.add_argument(
"--lr-exp-decay",
type=float,
default=0.3,
default=0.99,
help="The rate by which the learning rate should exponentially decay.",
)

if __name__ == "__main__":
# Use `parser` to add your own custom command line options to this script
# and (if needed) use their values toset up `config` below.
# and (if needed) use their values to set up `config` below.
args = parser.parse_args()

config = (
PPOConfig()
.environment("CartPole-v1")
.training(
lr=0.0003,
lr=0.03,
num_sgd_iter=6,
vf_loss_coeff=0.01,
)
Expand Down