[RLlib] Fix torch scheduler stepping and reporting. (#48125)
simonsays1980 authored Nov 4, 2024
1 parent d39c9df commit ec9775d
Showing 2 changed files with 128 additions and 17 deletions.
66 changes: 56 additions & 10 deletions rllib/core/learner/torch/torch_learner.py
@@ -14,7 +14,7 @@
AlgorithmConfig,
TorchCompileWhatToCompile,
)
from ray.rllib.core.learner.learner import Learner
from ray.rllib.core.learner.learner import Learner, LR_KEY
from ray.rllib.core.rl_module.multi_rl_module import (
MultiRLModule,
MultiRLModuleSpec,
@@ -32,6 +32,7 @@
from ray.rllib.utils.annotations import (
override,
OverrideToImplementCustomLogic,
OverrideToImplementCustomLogic_CallToSuperRecommended,
)
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.metrics import (
@@ -41,6 +42,7 @@
NUM_NON_TRAINABLE_PARAMETERS,
WEIGHTS_SEQ_NO,
)
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.torch_utils import convert_to_torch_tensor, copy_torch_tensors
from ray.rllib.utils.typing import (
ModuleID,
@@ -223,7 +225,10 @@ def apply_gradients(self, gradients_dict: ParamDict) -> None:
# If we have learning rate schedulers for a module add them, if
# necessary.
if self._lr_scheduler_classes is not None:
if module_id not in self._lr_schedulers:
if (
module_id not in self._lr_schedulers
or optimizer_name not in self._lr_schedulers[module_id]
):
# Set for each module and optimizer a scheduler.
self._lr_schedulers[module_id] = {optimizer_name: []}
# If the classes are in a dictionary each module might have
Expand Down Expand Up @@ -271,15 +276,56 @@ def apply_gradients(self, gradients_dict: ParamDict) -> None:
"`False`."
)

# If the module uses learning rate schedulers, step them here.
if module_id in self._lr_schedulers:
for scheduler in self._lr_schedulers[module_id][optimizer_name]:
scheduler.step()
@OverrideToImplementCustomLogic_CallToSuperRecommended
@override(Learner)
def after_gradient_based_update(self, *, timesteps: Dict[str, Any]) -> None:
"""Called after gradient-based updates are completed.
Should be overridden to implement custom cleanup-, logging-, or non-gradient-
based Learner/RLModule update logic after(!) gradient-based updates have been
completed.
Note, for `framework="torch"` users can register
`torch.optim.lr_scheduler.LRScheduler` via
`AlgorithmConfig._torch_lr_scheduler_classes`. These schedulers need to be
stepped here after gradient updates and reported.
Args:
timesteps: Timesteps dict, which must have the key
`NUM_ENV_STEPS_SAMPLED_LIFETIME`.
# TODO (sven): Make this a more formal structure with its own type.
"""

# If we have no `torch.optim.lr_scheduler.LRScheduler` registered, call the
# `super()`'s method to update RLlib's learning rate schedules.
if not self._lr_schedulers:
return super().after_gradient_based_update(timesteps=timesteps)

# Only update this optimizer's lr if a scheduler has been registered
# along with it.
for module_id, optimizer_names in self._module_optimizers.items():
for optimizer_name in optimizer_names:
# If learning rate schedulers are provided, step them here. Note,
# stepping them in `TorchLearner.apply_gradients` would update the
# learning rates during minibatch updates; we want to update
# between whole batch updates.
if (
module_id in self._lr_schedulers
and optimizer_name in self._lr_schedulers[module_id]
):
for scheduler in self._lr_schedulers[module_id][optimizer_name]:
scheduler.step()
optimizer = self.get_optimizer(module_id, optimizer_name)
self.metrics.log_value(
# Cut out the module ID from the beginning since it's already
# part of the key sequence: (ModuleID, "[optim name]_lr").
key=(
module_id,
f"{optimizer_name[len(module_id) + 1:]}_{LR_KEY}",
),
value=convert_to_numpy(self._get_optimizer_lr(optimizer)),
window=1,
)

@override(Learner)
def _get_optimizer_state(self) -> StateDict:
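
For context, a minimal sketch (not part of this commit) of how such scheduler classes can be registered so that the new `after_gradient_based_update()` above steps and reports them once per whole training update. It assumes the `_torch_lr_scheduler_classes` list is passed through `config.experimental()`, as in the example script changed below; the concrete factor/gamma values are illustrative only.

import functools

import torch

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("CartPole-v1")
    .training(lr=0.03)
    .experimental(
        # Each optimizer gets one instance of every scheduler class listed
        # here; all of them are stepped once per training iteration in
        # `after_gradient_based_update()`.
        _torch_lr_scheduler_classes=[
            # Scale the lr by a constant factor for the first 10 iterations.
            functools.partial(
                torch.optim.lr_scheduler.ConstantLR, factor=0.9, total_iters=10
            ),
            # Decay the lr exponentially at every iteration.
            functools.partial(torch.optim.lr_scheduler.ExponentialLR, gamma=0.99),
        ]
    )
)
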
79 changes: 72 additions & 7 deletions rllib/examples/learners/ppo_with_torch_lr_schedulers.py
@@ -11,8 +11,8 @@
How to run this script
----------------------
`python [script file name].py --enable-new-api-stack --lr-const-factor=0.1
--lr-const-iters=10 --lr-exp-decay=0.3`
`python [script file name].py --enable-new-api-stack --lr-const-factor=0.9
--lr-const-iters=10 --lr-exp-decay=0.9`
Use the `--lr-const-factor` to define the factor by which to multiply the
learning rate in the first `--lr-const-iters` iterations. Use the
@@ -49,26 +49,88 @@
+------------------------+------------------------+------------------------+
"""
import functools
import numpy as np
from typing import Optional

from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core import DEFAULT_MODULE_ID
from ray.rllib.core.learner.learner import DEFAULT_OPTIMIZER
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
from ray.rllib.utils.metrics import (
ENV_RUNNER_RESULTS,
EPISODE_RETURN_MEAN,
EVALUATION_RESULTS,
NUM_ENV_STEPS_SAMPLED_LIFETIME,
)
from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.test_utils import add_rllib_example_script_args

torch, _ = try_import_torch()

parser = add_rllib_example_script_args(default_reward=450.0, default_timesteps=200000)

class LRChecker(DefaultCallbacks):
def on_algorithm_init(
self,
*,
algorithm: "Algorithm",
metrics_logger: Optional[MetricsLogger] = None,
**kwargs,
) -> None:
# Store the expected learning rates for each iteration.
self.lr = []
# Retrieve the chosen configuration parameters from the config.
lr_factor = algorithm.config._torch_lr_scheduler_classes[0].keywords["factor"]
lr_total_iters = algorithm.config._torch_lr_scheduler_classes[0].keywords[
"total_iters"
]
lr_gamma = algorithm.config._torch_lr_scheduler_classes[1].keywords["gamma"]
# Compute the learning rates for all iterations up to `lr_const_iters`.
for i in range(1, lr_total_iters + 1):
# The initial learning rate.
lr = algorithm.config.lr
# In the first `lr_const_iters` iterations we multiply by `lr_const_factor`.
if i < lr_total_iters:
lr *= lr_factor
# Finally, we have an exponential decay of `lr_exp_decay`.
lr *= lr_gamma**i
self.lr.append(lr)

def on_train_result(
self,
*,
algorithm: "Algorithm",
metrics_logger: Optional[MetricsLogger] = None,
result: dict,
**kwargs,
) -> None:

# For the first `lr_total_iters + 1` iterations, check whether the expected
# and actual learning rates correspond.
if (
algorithm.training_iteration
<= algorithm.config._torch_lr_scheduler_classes[0].keywords["total_iters"]
):
actual_lr = algorithm.learner_group._learner.get_optimizer(
DEFAULT_MODULE_ID, DEFAULT_OPTIMIZER
).param_groups[0]["lr"]
# Assert the learning rates are close enough.
assert np.isclose(
actual_lr,
self.lr[algorithm.training_iteration - 1],
atol=1e-9,
rtol=1e-9,
)


parser = add_rllib_example_script_args(default_reward=450.0, default_timesteps=250000)
parser.set_defaults(enable_new_api_stack=True)
parser.add_argument(
"--lr-const-factor",
type=float,
default=0.1,
default=0.9,
help="The factor by which the learning rate should be multiplied.",
)
parser.add_argument(
@@ -83,20 +145,20 @@
parser.add_argument(
"--lr-exp-decay",
type=float,
default=0.3,
default=0.99,
help="The rate by which the learning rate should exponentially decay.",
)

if __name__ == "__main__":
# Use `parser` to add your own custom command line options to this script
# and (if needed) use their values toset up `config` below.
# and (if needed) use their values to set up `config` below.
args = parser.parse_args()

config = (
PPOConfig()
.environment("CartPole-v1")
.training(
lr=0.0003,
lr=0.03,
num_sgd_iter=6,
vf_loss_coeff=0.01,
)
@@ -129,6 +191,9 @@
),
]
)
.callbacks(
LRChecker,
)
)

stop = {
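
For reference (not part of this commit), a small, self-contained check of the learning-rate schedule that the `LRChecker` callback above expects with the new defaults (lr=0.03, --lr-const-factor=0.9, --lr-const-iters=10, --lr-exp-decay=0.99). It repeats the same arithmetic as `on_algorithm_init`:

base_lr, factor, total_iters, gamma = 0.03, 0.9, 10, 0.99

expected = []
for i in range(1, total_iters + 1):
    lr = base_lr
    # The constant factor only applies before `total_iters` is reached.
    if i < total_iters:
        lr *= factor
    # The exponential decay applies at every iteration.
    lr *= gamma**i
    expected.append(lr)

# First iteration: 0.03 * 0.9 * 0.99 = 0.02673
print([round(lr, 5) for lr in expected])
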
