From ec9775d86fbf7eb93358d95268e9f62e53f790bd Mon Sep 17 00:00:00 2001
From: simonsays1980
Date: Mon, 4 Nov 2024 16:21:40 +0100
Subject: [PATCH] [RLlib] Fix torch scheduler stepping and reporting. (#48125)

---
 rllib/core/learner/torch/torch_learner.py    | 66 +++++++++++++---
 .../learners/ppo_with_torch_lr_schedulers.py | 79 +++++++++++++++++--
 2 files changed, 128 insertions(+), 17 deletions(-)

diff --git a/rllib/core/learner/torch/torch_learner.py b/rllib/core/learner/torch/torch_learner.py
index 5c46ba913d56..5c6beb4b622d 100644
--- a/rllib/core/learner/torch/torch_learner.py
+++ b/rllib/core/learner/torch/torch_learner.py
@@ -14,7 +14,7 @@
     AlgorithmConfig,
     TorchCompileWhatToCompile,
 )
-from ray.rllib.core.learner.learner import Learner
+from ray.rllib.core.learner.learner import Learner, LR_KEY
 from ray.rllib.core.rl_module.multi_rl_module import (
     MultiRLModule,
     MultiRLModuleSpec,
@@ -32,6 +32,7 @@
 from ray.rllib.utils.annotations import (
     override,
     OverrideToImplementCustomLogic,
+    OverrideToImplementCustomLogic_CallToSuperRecommended,
 )
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.metrics import (
@@ -41,6 +42,7 @@
     NUM_NON_TRAINABLE_PARAMETERS,
     WEIGHTS_SEQ_NO,
 )
+from ray.rllib.utils.numpy import convert_to_numpy
 from ray.rllib.utils.torch_utils import convert_to_torch_tensor, copy_torch_tensors
 from ray.rllib.utils.typing import (
     ModuleID,
@@ -223,7 +225,10 @@ def apply_gradients(self, gradients_dict: ParamDict) -> None:
                 # If we have learning rate schedulers for a module add them, if
                 # necessary.
                 if self._lr_scheduler_classes is not None:
-                    if module_id not in self._lr_schedulers:
+                    if (
+                        module_id not in self._lr_schedulers
+                        or optimizer_name not in self._lr_schedulers[module_id]
+                    ):
                         # Set for each module and optimizer a scheduler.
                         self._lr_schedulers[module_id] = {optimizer_name: []}
                         # If the classes are in a dictionary each module might have
@@ -271,15 +276,56 @@ def apply_gradients(self, gradients_dict: ParamDict) -> None:
                     "`False`."
                 )
 
-            # If the module uses learning rate schedulers, step them here.
-            if module_id in self._lr_schedulers:
-                for scheduler in self._lr_schedulers[module_id][optimizer_name]:
-                    scheduler.step()
+    @OverrideToImplementCustomLogic_CallToSuperRecommended
+    @override(Learner)
+    def after_gradient_based_update(self, *, timesteps: Dict[str, Any]) -> None:
+        """Called after gradient-based updates are completed.
+
+        Should be overridden to implement custom cleanup-, logging-, or non-gradient-
+        based Learner/RLModule update logic after(!) gradient-based updates have been
+        completed.
+
+        Note, for `framework="torch"` users can register
+        `torch.optim.lr_scheduler.LRScheduler` via
+        `AlgorithmConfig._torch_lr_scheduler_classes`. These schedulers need to be
+        stepped here after gradient updates and reported.
+
+        Args:
+            timesteps: Timesteps dict, which must have the key
+                `NUM_ENV_STEPS_SAMPLED_LIFETIME`.
+                # TODO (sven): Make this a more formal structure with its own type.
+        """
+
+        # If we have no `torch.optim.lr_scheduler.LRScheduler` registered call the
+        # `super()`'s method to update RLlib's learning rate schedules.
+        if not self._lr_schedulers:
+            return super().after_gradient_based_update(timesteps=timesteps)
 
-        # If the module uses learning rate schedulers, step them here.
-        if module_id in self._lr_schedulers:
-            for scheduler in self._lr_schedulers[module_id][optimizer_name]:
-                scheduler.step()
+        # Only update this optimizer's lr, if a scheduler has been registered
+        # along with it.
+        for module_id, optimizer_names in self._module_optimizers.items():
+            for optimizer_name in optimizer_names:
+                # If learning rate schedulers are provided step them here. Note,
+                # stepping them in `TorchLearner.apply_gradients` updates the
+                # learning rates during minibatch updates; we want to update
+                # between whole batch updates.
+                if (
+                    module_id in self._lr_schedulers
+                    and optimizer_name in self._lr_schedulers[module_id]
+                ):
+                    for scheduler in self._lr_schedulers[module_id][optimizer_name]:
+                        scheduler.step()
+                optimizer = self.get_optimizer(module_id, optimizer_name)
+                self.metrics.log_value(
+                    # Cut out the module ID from the beginning since it's already
+                    # part of the key sequence: (ModuleID, "[optim name]_lr").
+                    key=(
+                        module_id,
+                        f"{optimizer_name[len(module_id) + 1:]}_{LR_KEY}",
+                    ),
+                    value=convert_to_numpy(self._get_optimizer_lr(optimizer)),
+                    window=1,
+                )
 
     @override(Learner)
     def _get_optimizer_state(self) -> StateDict:
diff --git a/rllib/examples/learners/ppo_with_torch_lr_schedulers.py b/rllib/examples/learners/ppo_with_torch_lr_schedulers.py
index d7f38205de9b..3e109162e018 100644
--- a/rllib/examples/learners/ppo_with_torch_lr_schedulers.py
+++ b/rllib/examples/learners/ppo_with_torch_lr_schedulers.py
@@ -11,8 +11,8 @@
 
 How to run this script
 ----------------------
-`python [script file name].py --enable-new-api-stack --lr-const-factor=0.1
---lr-const-iters=10 --lr-exp-decay=0.3`
+`python [script file name].py --enable-new-api-stack --lr-const-factor=0.9
+--lr-const-iters=10 --lr-exp-decay=0.9`
 
 Use the `--lr-const-factor` to define the facotr by which to multiply the
 learning rate in the first `--lr-const-iters` iterations. Use the
@@ -49,8 +49,14 @@
 +------------------------+------------------------+------------------------+
 """
 import functools
+import numpy as np
+from typing import Optional
 
+from ray.rllib.algorithms.algorithm import Algorithm
+from ray.rllib.algorithms.callbacks import DefaultCallbacks
 from ray.rllib.algorithms.ppo import PPOConfig
+from ray.rllib.core import DEFAULT_MODULE_ID
+from ray.rllib.core.learner.learner import DEFAULT_OPTIMIZER
 from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
 from ray.rllib.utils.metrics import (
     ENV_RUNNER_RESULTS,
@@ -58,17 +64,73 @@
     EVALUATION_RESULTS,
     NUM_ENV_STEPS_SAMPLED_LIFETIME,
 )
+from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.test_utils import add_rllib_example_script_args
 
 torch, _ = try_import_torch()
 
-parser = add_rllib_example_script_args(default_reward=450.0, default_timesteps=200000)
+
+class LRChecker(DefaultCallbacks):
+    def on_algorithm_init(
+        self,
+        *,
+        algorithm: "Algorithm",
+        metrics_logger: Optional[MetricsLogger] = None,
+        **kwargs,
+    ) -> None:
+        # Store the expected learning rates for each iteration.
+        self.lr = []
+        # Retrieve the chosen configuration parameters from the config.
+        lr_factor = algorithm.config._torch_lr_scheduler_classes[0].keywords["factor"]
+        lr_total_iters = algorithm.config._torch_lr_scheduler_classes[0].keywords[
+            "total_iters"
+        ]
+        lr_gamma = algorithm.config._torch_lr_scheduler_classes[1].keywords["gamma"]
+        # Compute the learning rates for all iterations up to `lr_const_iters`.
+        for i in range(1, lr_total_iters + 1):
+            # The initial learning rate.
+            lr = algorithm.config.lr
+            # In the first 10 iterations we multiply by `lr_const_factor`.
+            if i < lr_total_iters:
+                lr *= lr_factor
+            # Finally, we have an exponential decay of `lr_exp_decay`.
+            lr *= lr_gamma**i
+            self.lr.append(lr)
+
+    def on_train_result(
+        self,
+        *,
+        algorithm: "Algorithm",
+        metrics_logger: Optional[MetricsLogger] = None,
+        result: dict,
+        **kwargs,
+    ) -> None:
+
+        # Check for the first `lr_total_iters + 1` iterations, if expected
+        # and actual learning rates correspond.
+        if (
+            algorithm.training_iteration
+            <= algorithm.config._torch_lr_scheduler_classes[0].keywords["total_iters"]
+        ):
+            actual_lr = algorithm.learner_group._learner.get_optimizer(
+                DEFAULT_MODULE_ID, DEFAULT_OPTIMIZER
+            ).param_groups[0]["lr"]
+            # Assert the learning rates are close enough.
+            assert np.isclose(
+                actual_lr,
+                self.lr[algorithm.training_iteration - 1],
+                atol=1e-9,
+                rtol=1e-9,
+            )
+
+
+parser = add_rllib_example_script_args(default_reward=450.0, default_timesteps=250000)
 parser.set_defaults(enable_new_api_stack=True)
 parser.add_argument(
     "--lr-const-factor",
     type=float,
-    default=0.1,
+    default=0.9,
     help="The factor by which the learning rate should be multiplied.",
 )
 parser.add_argument(
@@ -83,20 +145,20 @@
 parser.add_argument(
     "--lr-exp-decay",
     type=float,
-    default=0.3,
+    default=0.99,
     help="The rate by which the learning rate should exponentially decay.",
 )
 
 if __name__ == "__main__":
     # Use `parser` to add your own custom command line options to this script
-    # and (if needed) use their values toset up `config` below.
+    # and (if needed) use their values to set up `config` below.
     args = parser.parse_args()
 
     config = (
         PPOConfig()
         .environment("CartPole-v1")
         .training(
-            lr=0.0003,
+            lr=0.03,
             num_sgd_iter=6,
             vf_loss_coeff=0.01,
         )
@@ -129,6 +191,9 @@
                 ),
             ]
         )
+        .callbacks(
+            LRChecker,
+        )
     )
 
     stop = {
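
For context, a minimal sketch of how the two schedulers that `LRChecker` inspects can be registered on the config. The exact call lives in a part of the example file not shown in the hunks above, so the `.experimental(_torch_lr_scheduler_classes=...)` usage and the `functools.partial` factories below are assumptions based on the keywords the callback reads and on the script's new defaults (0.9, 10, 0.99):

    import functools

    import torch
    from ray.rllib.algorithms.ppo import PPOConfig

    config = (
        PPOConfig()
        .environment("CartPole-v1")
        .training(lr=0.03)
        .experimental(
            # Factories that the TorchLearner calls as `cls(optimizer)`; their
            # keyword arguments are visible via `partial.keywords`, which is what
            # `LRChecker.on_algorithm_init()` relies on.
            _torch_lr_scheduler_classes=[
                # Scale the lr by `factor` for the first `total_iters` scheduler
                # steps (after this patch: one step per training iteration) ...
                functools.partial(
                    torch.optim.lr_scheduler.ConstantLR, factor=0.9, total_iters=10
                ),
                # ... and additionally decay it by `gamma` on every step.
                functools.partial(
                    torch.optim.lr_scheduler.ExponentialLR, gamma=0.99
                ),
            ]
        )
    )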