From e51cec6a6f7d2f9a2506aaa5543daad0d5ad35e9 Mon Sep 17 00:00:00 2001
From: Sven Mika
Date: Thu, 21 Nov 2024 13:38:56 +0100
Subject: [PATCH] [RLlib] APPO enhancements (new API stack) vol 03: Fix target network update setting and logic. (#48802)

---
 rllib/algorithms/appo/appo.py         | 39 ++++++++++++++++++---------
 rllib/algorithms/appo/appo_learner.py | 33 +++++++++--------------
 2 files changed, 39 insertions(+), 33 deletions(-)

diff --git a/rllib/algorithms/appo/appo.py b/rllib/algorithms/appo/appo.py
index 3632ffab954b..b27e96b02d16 100644
--- a/rllib/algorithms/appo/appo.py
+++ b/rllib/algorithms/appo/appo.py
@@ -100,7 +100,6 @@ def __init__(self, algo_class=None):
         # __sphinx_doc_begin__
         # APPO specific settings:
         self.vtrace = True
-        self.use_critic = True
         self.use_gae = True
         self.lambda_ = 1.0
         self.clip_param = 0.4
@@ -120,7 +119,7 @@ def __init__(self, algo_class=None):
         # Override some of IMPALAConfig's default values with APPO-specific values.
         self.num_env_runners = 2
         self.min_time_s_per_iteration = 10
-        self.target_network_update_freq = 1
+        self.target_network_update_freq = 2
         self.broadcast_interval = 1
         self.grad_clip = 40.0
         # Note: Only when using enable_rl_module_and_learner=True can the clipping mode
@@ -151,26 +150,27 @@ def __init__(self, algo_class=None):
 
         # Deprecated keys.
         self.target_update_frequency = DEPRECATED_VALUE
+        self.use_critic = DEPRECATED_VALUE
 
     @override(IMPALAConfig)
     def training(
         self,
         *,
         vtrace: Optional[bool] = NotProvided,
-        use_critic: Optional[bool] = NotProvided,
         use_gae: Optional[bool] = NotProvided,
         lambda_: Optional[float] = NotProvided,
         clip_param: Optional[float] = NotProvided,
         use_kl_loss: Optional[bool] = NotProvided,
         kl_coeff: Optional[float] = NotProvided,
         kl_target: Optional[float] = NotProvided,
-        tau: Optional[float] = NotProvided,
         target_network_update_freq: Optional[int] = NotProvided,
+        tau: Optional[float] = NotProvided,
         target_worker_clipping: Optional[float] = NotProvided,
         circular_buffer_num_batches: Optional[int] = NotProvided,
         circular_buffer_iterations_per_batch: Optional[int] = NotProvided,
         # Deprecated keys.
         target_update_frequency=DEPRECATED_VALUE,
+        use_critic=DEPRECATED_VALUE,
         **kwargs,
     ) -> "APPOConfig":
         """Sets the training related configuration.
@@ -178,8 +178,6 @@ def training(
         Args:
             vtrace: Whether to use V-trace weighted advantages. If false, PPO GAE
                 advantages will be used instead.
-            use_critic: Should use a critic as a baseline (otherwise don't use value
-                baseline; required for using GAE). Only applies if vtrace=False.
             use_gae: If true, use the Generalized Advantage Estimator (GAE)
                 with a value function, see https://arxiv.org/pdf/1506.02438.pdf.
                 Only applies if vtrace=False.
@@ -189,9 +187,18 @@ def training(
             kl_coeff: Coefficient for weighting the KL-loss term.
             kl_target: Target term for the KL-term to reach (via adjusting the
                 `kl_coeff` automatically).
-            tau: The factor by which to update the target policy network towards
-                the current policy network. Can range between 0 and 1.
-                e.g. updated_param = tau * current_param + (1 - tau) * target_param
+            target_network_update_freq: NOTE: This parameter is only applicable on
+                the new API stack. The frequency with which to update the target
+                policy network from the main trained policy network. The metric
+                used is `NUM_ENV_STEPS_TRAINED_LIFETIME` and the unit is `n` (see [1]
+                4.1.1), where: `n = [circular_buffer_num_batches (N)] *
+                [circular_buffer_iterations_per_batch (K)] * [train batch size]`
+                For example, if you set `target_network_update_freq=2`, and N=4, K=2,
+                and `train_batch_size_per_learner=500`, then the target net is updated
+                every 2*4*2*500=8000 trained env steps (every 16 batch updates on each
+                learner).
+                The authors in [1] suggest that this setting is robust to a range of
+                choices (try values between 0.125 and 4).
             target_network_update_freq: The frequency to update the target policy and
                 tune the kl loss coefficients that are used during training. After
                 setting this parameter, the algorithm waits for at least
@@ -199,6 +206,9 @@ def training(
                 on before updating the target networks and tune the kl loss
                 coefficients. NOTE: This parameter is only applicable when using the
                 Learner API (enable_rl_module_and_learner=True).
+            tau: The factor by which to update the target policy network towards
+                the current policy network. Can range between 0 and 1.
+                e.g. updated_param = tau * current_param + (1 - tau) * target_param
             target_worker_clipping: The maximum value for the target-worker-clipping
                 used for computing the IS ratio, described in [1]
                 IS = min(π(i) / π(target), ρ) * (π / π(i))
@@ -220,14 +230,17 @@ def training(
                 new="target_network_update_freq",
                 error=True,
             )
+        if use_critic != DEPRECATED_VALUE:
+            deprecation_warning(
+                old="use_critic",
+                error=True,
+            )
 
         # Pass kwargs onto super's `training()` method.
         super().training(**kwargs)
 
         if vtrace is not NotProvided:
             self.vtrace = vtrace
-        if use_critic is not NotProvided:
-            self.use_critic = use_critic
         if use_gae is not NotProvided:
             self.use_gae = use_gae
         if lambda_ is not NotProvided:
@@ -240,10 +253,10 @@ def training(
             self.kl_coeff = kl_coeff
         if kl_target is not NotProvided:
             self.kl_target = kl_target
-        if tau is not NotProvided:
-            self.tau = tau
         if target_network_update_freq is not NotProvided:
             self.target_network_update_freq = target_network_update_freq
+        if tau is not NotProvided:
+            self.tau = tau
         if target_worker_clipping is not NotProvided:
             self.target_worker_clipping = target_worker_clipping
         if circular_buffer_num_batches is not NotProvided:
diff --git a/rllib/algorithms/appo/appo_learner.py b/rllib/algorithms/appo/appo_learner.py
index 920d7b7ea992..431449893264 100644
--- a/rllib/algorithms/appo/appo_learner.py
+++ b/rllib/algorithms/appo/appo_learner.py
@@ -12,8 +12,9 @@
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.lambda_defaultdict import LambdaDefaultDict
 from ray.rllib.utils.metrics import (
+    ALL_MODULES,
     LAST_TARGET_UPDATE_TS,
-    NUM_ENV_STEPS_SAMPLED_LIFETIME,
+    NUM_ENV_STEPS_TRAINED_LIFETIME,
     NUM_MODULE_STEPS_TRAINED,
     NUM_TARGET_UPDATES,
 )
@@ -86,30 +87,22 @@ def after_gradient_based_update(self, *, timesteps: Dict[str, Any]) -> None:
         """Updates the target Q Networks."""
         super().after_gradient_based_update(timesteps=timesteps)
 
-        timestep = timesteps.get(NUM_ENV_STEPS_SAMPLED_LIFETIME, 0)
-
         # TODO (sven): Maybe we should have a `after_gradient_based_update`
         # method per module?
+        curr_timestep = self.metrics.peek((ALL_MODULES, NUM_ENV_STEPS_TRAINED_LIFETIME))
         for module_id, module in self.module._rl_modules.items():
             config = self.config.get_config_for_module(module_id)
 
-            # TODO (avnish) Using steps trained here instead of sampled ... I'm not sure
-            # why the other implementation uses sampled.
-            # The difference in steps sampled/trained is pretty
-            # much always going to be larger than self.config.num_epochs *
-            # self.config.minibatch_buffer_size unless the number of steps collected
-            # is really small. The thing is that the default rollout fragment length
-            # is 50, so the minibatch buffer size * num_epochs is going to be
-            # have to be 50 to even meet the threshold of having delayed target
-            # updates.
-            # We should instead have the target / kl threshold update be based off
-            # of the train_batch_size * some target update frequency * num_epochs.
             last_update_ts_key = (module_id, LAST_TARGET_UPDATE_TS)
-            if timestep - self.metrics.peek(
-                last_update_ts_key, default=0
-            ) >= config.target_network_update_freq and isinstance(
-                module.unwrapped(), TargetNetworkAPI
+            if isinstance(module.unwrapped(), TargetNetworkAPI) and (
+                curr_timestep - self.metrics.peek(last_update_ts_key, default=0)
+                >= (
+                    config.target_network_update_freq
+                    * config.circular_buffer_num_batches
+                    * config.circular_buffer_iterations_per_batch
+                    * config.total_train_batch_size
+                    / (config.num_learners or 1)
+                )
             ):
                 for (
                     main_net,
@@ -123,7 +116,7 @@ def after_gradient_based_update(self, *, timesteps: Dict[str, Any]) -> None:
                 # Increase lifetime target network update counter by one.
                 self.metrics.log_value((module_id, NUM_TARGET_UPDATES), 1, reduce="sum")
                 # Update the (single-value -> window=1) last updated timestep metric.
-                self.metrics.log_value(last_update_ts_key, timestep, window=1)
+                self.metrics.log_value(last_update_ts_key, curr_timestep, window=1)
 
             if (
                 config.use_kl_loss
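
Note on the new update cadence (the docstring added in the `appo.py` hunk and the threshold check in `appo_learner.py` above): the target networks are now refreshed once the number of trained env steps since the last update reaches `target_network_update_freq * N * K * total_train_batch_size / num_learners`. A minimal standalone sketch of that arithmetic; plain Python, no RLlib import, and the helper name `target_update_interval` is illustrative only, not an RLlib API:

```python
def target_update_interval(
    target_network_update_freq: int,
    circular_buffer_num_batches: int,  # N
    circular_buffer_iterations_per_batch: int,  # K
    total_train_batch_size: int,
    num_learners: int = 1,
) -> float:
    """Returns the number of trained env steps between two target net updates."""
    return (
        target_network_update_freq
        * circular_buffer_num_batches
        * circular_buffer_iterations_per_batch
        * total_train_batch_size
        / (num_learners or 1)
    )


# The docstring example: freq=2, N=4, K=2, train_batch_size_per_learner=500,
# and a single Learner (so total_train_batch_size == 500).
assert target_update_interval(2, 4, 2, 500) == 8000  # every 8000 trained env steps
```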
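
A hedged usage sketch of the post-patch config surface (assumes Ray RLlib at or after this commit; `use_critic` is no longer accepted by `APPOConfig.training()`, and the numeric values mirror the docstring example above):

```python
from ray.rllib.algorithms.appo import APPOConfig

config = (
    APPOConfig()
    .environment("CartPole-v1")
    # The target-update setting only takes effect on the new API stack.
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .training(
        target_network_update_freq=2,            # in units of N * K * (train batch size)
        circular_buffer_num_batches=4,           # N
        circular_buffer_iterations_per_batch=2,  # K
        train_batch_size_per_learner=500,
        tau=1.0,                                 # 1.0 -> hard copy of the main net
    )
)
# With one Learner, the target net is refreshed every 2 * 4 * 2 * 500 = 8000
# trained env steps, i.e. every 16 batch updates (matching the docstring example).
```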
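
For completeness, the `tau` docstring (moved, not changed, by this patch) describes a standard soft update. A toy illustration with scalar weights; the helper `soft_update` is hypothetical, whereas RLlib applies the rule to network weights via the main-net/target-net pairs iterated in the learner hunk above:

```python
def soft_update(current_param: float, target_param: float, tau: float) -> float:
    # updated_param = tau * current_param + (1 - tau) * target_param
    return tau * current_param + (1.0 - tau) * target_param


assert soft_update(1.0, 0.0, tau=1.0) == 1.0  # tau=1.0: target becomes an exact copy
assert soft_update(1.0, 0.0, tau=0.5) == 0.5  # tau=0.5: target moves halfway toward main
```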