diff --git a/rllib/algorithms/appo/appo_learner.py b/rllib/algorithms/appo/appo_learner.py index a1c06a854309..f13fc476d9cf 100644 --- a/rllib/algorithms/appo/appo_learner.py +++ b/rllib/algorithms/appo/appo_learner.py @@ -128,7 +128,7 @@ def after_gradient_based_update(self, *, timesteps: Dict[str, Any]) -> None: @abc.abstractmethod def _update_module_kl_coeff(self, module_id: ModuleID, config: APPOConfig) -> None: - """Dynamically update the KL loss coefficients of each module with. + """Dynamically update the KL loss coefficients of each module. The update is completed using the mean KL divergence between the action distributions current policy and old policy of each module. That action diff --git a/rllib/algorithms/ppo/ppo_learner.py b/rllib/algorithms/ppo/ppo_learner.py index 2b08ff36947c..e6ca7dde6c37 100644 --- a/rllib/algorithms/ppo/ppo_learner.py +++ b/rllib/algorithms/ppo/ppo_learner.py @@ -3,6 +3,7 @@ from ray.rllib.algorithms.ppo.ppo import ( LEARNER_RESULTS_CURR_ENTROPY_COEFF_KEY, + LEARNER_RESULTS_KL_KEY, PPOConfig, ) from ray.rllib.connectors.learner import ( @@ -19,6 +20,7 @@ NUM_ENV_STEPS_SAMPLED_LIFETIME, NUM_MODULE_STEPS_TRAINED, ) +from ray.rllib.utils.numpy import convert_to_numpy from ray.rllib.utils.schedules.scheduler import Scheduler from ray.rllib.utils.typing import ModuleID, TensorType @@ -106,8 +108,16 @@ def after_gradient_based_update( config.use_kl_loss and self.metrics.peek((module_id, NUM_MODULE_STEPS_TRAINED), default=0) > 0 + and (module_id, LEARNER_RESULTS_KL_KEY) in self.metrics ): - self._update_module_kl_coeff(module_id=module_id, config=config) + kl_loss = convert_to_numpy( + self.metrics.peek((module_id, LEARNER_RESULTS_KL_KEY)) + ) + self._update_module_kl_coeff( + module_id=module_id, + config=config, + kl_loss=kl_loss, + ) @abc.abstractmethod def _update_module_kl_coeff( @@ -115,8 +125,9 @@ def _update_module_kl_coeff( *, module_id: ModuleID, config: PPOConfig, + kl_loss: float, ) -> None: - """Dynamically update the KL loss coefficients of each module with. + """Dynamically update the KL loss coefficients of each module. The update is completed using the mean KL divergence between the action distributions current policy and old policy of each module. That action @@ -125,4 +136,6 @@ def _update_module_kl_coeff( Args: module_id: The module whose KL loss coefficient to update. config: The AlgorithmConfig specific to the given `module_id`. + kl_loss: The mean KL loss of the module, computed inside + `compute_loss_for_module()`. """ diff --git a/rllib/algorithms/ppo/tf/ppo_tf_learner.py b/rllib/algorithms/ppo/tf/ppo_tf_learner.py index f355098e23e6..aa176363f7a7 100644 --- a/rllib/algorithms/ppo/tf/ppo_tf_learner.py +++ b/rllib/algorithms/ppo/tf/ppo_tf_learner.py @@ -18,7 +18,6 @@ from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.tf_utils import explained_variance from ray.rllib.utils.annotations import override -from ray.rllib.utils.numpy import convert_to_numpy from ray.rllib.utils.typing import ModuleID, TensorType _, tf, _ = try_import_tf() @@ -151,10 +150,9 @@ def _update_module_kl_coeff( *, module_id: ModuleID, config: PPOConfig, + kl_loss: float, ) -> None: - kl = convert_to_numpy(self.metrics.peek((module_id, LEARNER_RESULTS_KL_KEY))) - - if np.isnan(kl): + if np.isnan(kl_loss): logger.warning( f"KL divergence for Module {module_id} is non-finite, this " "will likely destabilize your model and the training " @@ -168,10 +166,10 @@ def _update_module_kl_coeff( # Update the KL coefficient. 
curr_var = self.curr_kl_coeffs_per_module[module_id] - if kl > 2.0 * config.kl_target: + if kl_loss > 2.0 * config.kl_target: # TODO (Kourosh) why not 2? curr_var.assign(curr_var * 1.5) - elif kl < 0.5 * config.kl_target: + elif kl_loss < 0.5 * config.kl_target: curr_var.assign(curr_var * 0.5) # Log the updated KL-coeff value. diff --git a/rllib/algorithms/ppo/torch/ppo_torch_learner.py b/rllib/algorithms/ppo/torch/ppo_torch_learner.py index 88e7b5737de7..f866165e2243 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_learner.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_learner.py @@ -17,7 +17,6 @@ from ray.rllib.evaluation.postprocessing import Postprocessing from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.numpy import convert_to_numpy from ray.rllib.utils.torch_utils import explained_variance from ray.rllib.utils.typing import ModuleID, TensorType @@ -140,10 +139,9 @@ def _update_module_kl_coeff( *, module_id: ModuleID, config: PPOConfig, + kl_loss: float, ) -> None: - kl = convert_to_numpy(self.metrics.peek((module_id, LEARNER_RESULTS_KL_KEY))) - - if np.isnan(kl): + if np.isnan(kl_loss): logger.warning( f"KL divergence for Module {module_id} is non-finite, this " "will likely destabilize your model and the training " @@ -157,10 +155,10 @@ def _update_module_kl_coeff( # Update the KL coefficient. curr_var = self.curr_kl_coeffs_per_module[module_id] - if kl > 2.0 * config.kl_target: + if kl_loss > 2.0 * config.kl_target: # TODO (Kourosh) why not 2? curr_var.data *= 1.5 - elif kl < 0.5 * config.kl_target: + elif kl_loss < 0.5 * config.kl_target: curr_var.data *= 0.5 # Log the updated KL-coeff value. diff --git a/rllib/connectors/common/batch_individual_items.py b/rllib/connectors/common/batch_individual_items.py index 464ab0bd4d9d..654ad5e65bb8 100644 --- a/rllib/connectors/common/batch_individual_items.py +++ b/rllib/connectors/common/batch_individual_items.py @@ -162,7 +162,7 @@ def __call__( # Single-agent case: There is a dict under `column` mapping # `eps_id` to lists of items: - # Sort by eps_id, concat all these lists, then batch. + # Concat all these lists, then batch. 
elif not self._multi_agent: # TODO: only really need this in non-Learner connector pipeline memorized_map_structure = [] diff --git a/rllib/core/learner/learner.py b/rllib/core/learner/learner.py index b52c6603f629..c49d80ab2c29 100644 --- a/rllib/core/learner/learner.py +++ b/rllib/core/learner/learner.py @@ -26,6 +26,7 @@ LearnerConnectorPipeline, ) from ray.rllib.core import COMPONENT_OPTIMIZER, COMPONENT_RL_MODULE, DEFAULT_MODULE_ID +from ray.rllib.core.rl_module.apis import SelfSupervisedLossAPI from ray.rllib.core.rl_module import validate_module_id from ray.rllib.core.rl_module.multi_rl_module import ( MultiRLModule, @@ -876,12 +877,22 @@ def compute_losses( module_batch = batch[module_id] module_fwd_out = fwd_out[module_id] - loss = self.compute_loss_for_module( - module_id=module_id, - config=self.config.get_config_for_module(module_id), - batch=module_batch, - fwd_out=module_fwd_out, - ) + module = self.module[module_id].unwrapped() + if isinstance(module, SelfSupervisedLossAPI): + loss = module.compute_self_supervised_loss( + learner=self, + module_id=module_id, + config=self.config.get_config_for_module(module_id), + batch=module_batch, + fwd_out=module_fwd_out, + ) + else: + loss = self.compute_loss_for_module( + module_id=module_id, + config=self.config.get_config_for_module(module_id), + batch=module_batch, + fwd_out=module_fwd_out, + ) loss_per_module[module_id] = loss return loss_per_module diff --git a/rllib/core/models/configs.py b/rllib/core/models/configs.py index 19f69bf579bd..a345860213c7 100644 --- a/rllib/core/models/configs.py +++ b/rllib/core/models/configs.py @@ -190,7 +190,7 @@ def output_dims(self): ) # Infer `output_dims` automatically. - return (self.output_layer_dim or self.hidden_layer_dims[-1],) + return (int(self.output_layer_dim or self.hidden_layer_dims[-1]),) def _validate(self, framework: str = "torch"): """Makes sure that settings are valid.""" diff --git a/rllib/core/rl_module/apis/__init__.py b/rllib/core/rl_module/apis/__init__.py index 7ab33dd68f1c..c79562a717e8 100644 --- a/rllib/core/rl_module/apis/__init__.py +++ b/rllib/core/rl_module/apis/__init__.py @@ -1,10 +1,12 @@ from ray.rllib.core.rl_module.apis.inference_only_api import InferenceOnlyAPI +from ray.rllib.core.rl_module.apis.self_supervised_loss_api import SelfSupervisedLossAPI from ray.rllib.core.rl_module.apis.target_network_api import TargetNetworkAPI from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI __all__ = [ "InferenceOnlyAPI", + "SelfSupervisedLossAPI", "TargetNetworkAPI", "ValueFunctionAPI", ] diff --git a/rllib/core/rl_module/apis/self_supervised_loss_api.py b/rllib/core/rl_module/apis/self_supervised_loss_api.py new file mode 100644 index 000000000000..04b697e9d111 --- /dev/null +++ b/rllib/core/rl_module/apis/self_supervised_loss_api.py @@ -0,0 +1,52 @@ +import abc +from typing import Any, Dict, TYPE_CHECKING + +from ray.rllib.utils.typing import ModuleID, TensorType + +if TYPE_CHECKING: + from ray.rllib.algorithms.algorithm_config import AlgorithmConfig + from ray.rllib.core.learner.learner import Learner + + +class SelfSupervisedLossAPI(abc.ABC): + """An API to be implemented by RLModules that bring their own self-supervised loss. + + Learners will call these model's `compute_self_supervised_loss()` method instead of + the Learner's own `compute_loss_for_module()` method. + The call signature is identical to the Learner's `compute_loss_for_module()` method + except of an additional mandatory `learner` kwarg. 
+ """ + + @abc.abstractmethod + def compute_self_supervised_loss( + self, + *, + learner: "Learner", + module_id: ModuleID, + config: "AlgorithmConfig", + batch: Dict[str, Any], + fwd_out: Dict[str, TensorType], + ) -> TensorType: + """Computes the loss for a single module. + + Think of this as computing loss for a single agent. For multi-agent use-cases + that require more complicated computation for loss, consider overriding the + `compute_losses` method instead. + + Args: + learner: The Learner calling this loss method on the RLModule. + module_id: The ID of the RLModule (within a MultiRLModule). + config: The AlgorithmConfig specific to the given `module_id`. + batch: The sample batch for this particular RLModule. + fwd_out: The output of the forward pass for this particular RLModule. + + Returns: + A single total loss tensor. If you have more than one optimizer on the + provided `module_id` and would like to compute gradients separately using + these different optimizers, simply add up the individual loss terms for + each optimizer and return the sum. Also, for recording/logging any + individual loss terms, you can use the `Learner.metrics.log_value( + key=..., value=...)` or `Learner.metrics.log_dict()` APIs. See: + :py:class:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger` for more + information. + """ diff --git a/rllib/examples/connectors/classes/count_based_curiosity.py b/rllib/examples/connectors/classes/count_based_curiosity.py index 37af0ad9bf13..1f865e3a8ae8 100644 --- a/rllib/examples/connectors/classes/count_based_curiosity.py +++ b/rllib/examples/connectors/classes/count_based_curiosity.py @@ -62,7 +62,7 @@ def __call__( self, *, rl_module: RLModule, - data: Any, + batch: Any, episodes: List[EpisodeType], explore: Optional[bool] = None, shared_data: Optional[dict] = None, @@ -89,4 +89,4 @@ def __call__( # timestep/index). sa_episode.set_rewards(new_data=rew, at_indices=i) - return data + return batch diff --git a/rllib/examples/connectors/classes/euclidian_distance_based_curiosity.py b/rllib/examples/connectors/classes/euclidian_distance_based_curiosity.py index 0babff5a33f0..c50a2caae5d7 100644 --- a/rllib/examples/connectors/classes/euclidian_distance_based_curiosity.py +++ b/rllib/examples/connectors/classes/euclidian_distance_based_curiosity.py @@ -73,14 +73,14 @@ def __call__( self, *, rl_module: RLModule, - data: Any, + batch: Any, episodes: List[EpisodeType], explore: Optional[bool] = None, shared_data: Optional[dict] = None, **kwargs, ) -> Any: if self._test > 10: - return data + return batch self._test += 1 # Loop through all episodes and change the reward to # [reward + intrinsic reward] @@ -119,4 +119,4 @@ def __call__( if max_dist_obs is not None: self.obs_buffer.append(max_dist_obs) - return data + return batch diff --git a/rllib/examples/connectors/classes/protobuf_cartpole_observation_decoder.py b/rllib/examples/connectors/classes/protobuf_cartpole_observation_decoder.py index 4fc66385f2dd..2ed4a891afcd 100644 --- a/rllib/examples/connectors/classes/protobuf_cartpole_observation_decoder.py +++ b/rllib/examples/connectors/classes/protobuf_cartpole_observation_decoder.py @@ -46,7 +46,7 @@ def __call__( self, *, rl_module: RLModule, - data: Any, + batch: Any, episodes: List[EpisodeType], explore: Optional[bool] = None, shared_data: Optional[dict] = None, @@ -77,4 +77,4 @@ def __call__( sa_episode.set_observations(new_data=new_obs, at_indices=-1) # Return `data` as-is. 
- return data + return batch diff --git a/rllib/examples/curiosity/intrinsic_curiosity_model_based_curiosity.py b/rllib/examples/curiosity/intrinsic_curiosity_model_based_curiosity.py index 9aab5a31a4ad..ae3f77563d2e 100644 --- a/rllib/examples/curiosity/intrinsic_curiosity_model_based_curiosity.py +++ b/rllib/examples/curiosity/intrinsic_curiosity_model_based_curiosity.py @@ -17,7 +17,6 @@ exploration in sparse rewards environments. For more details, see here: - [1] Curiosity-driven Exploration by Self-supervised Prediction Pathak, Agrawal, Efros, and Darrell - UC Berkeley - ICML 2017. https://arxiv.org/pdf/1705.05363.pdf @@ -74,21 +73,21 @@ """ from collections import defaultdict +from ray import tune +from ray.rllib.algorithms.algorithm_config import AlgorithmConfig from ray.rllib.algorithms.callbacks import DefaultCallbacks from ray.rllib.connectors.env_to_module import FlattenObservations -from ray.rllib.examples.learners.classes.curiosity_dqn_torch_learner import ( - DQNConfigWithCuriosity, +from ray.rllib.examples.learners.classes.intrinsic_curiosity_learners import ( DQNTorchLearnerWithCuriosity, + PPOTorchLearnerWithCuriosity, ) from ray.rllib.core import DEFAULT_MODULE_ID from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec from ray.rllib.core.rl_module.rl_module import RLModuleSpec -from ray.rllib.examples.learners.classes.curiosity_ppo_torch_learner import ( - PPOConfigWithCuriosity, - PPOTorchLearnerWithCuriosity, +from ray.rllib.examples.learners.classes.intrinsic_curiosity_learners import ( + ICM_MODULE_ID, ) from ray.rllib.examples.rl_modules.classes.intrinsic_curiosity_model_rlm import ( - ICM_MODULE_ID, IntrinsicCuriosityModel, ) from ray.rllib.utils.metrics import ( @@ -188,12 +187,9 @@ def on_sample_end( "Curiosity example only implemented for either DQN or PPO! See the " ) - config_class = ( - PPOConfigWithCuriosity if args.algo == "PPO" else DQNConfigWithCuriosity - ) - base_config = ( - config_class() + tune.registry.get_trainable_cls(args.algo) + .get_default_config() .environment( "FrozenLake-v1", env_config={ @@ -218,18 +214,23 @@ def on_sample_end( "max_episode_steps": 22, }, ) - # Use our custom `curiosity` method to set up the PPO/ICM-Learner. - .curiosity( - # Intrinsic reward coefficient. - curiosity_eta=0.05, - # Forward loss weight (vs inverse dynamics loss, which will be `1. - beta`). - # curiosity_beta=0.2, - ) .callbacks(MeasureMaxDistanceToStart) .env_runners( num_envs_per_env_runner=5 if args.algo == "PPO" else 1, env_to_module_connector=lambda env: FlattenObservations(), ) + .training( + learner_config_dict={ + # Intrinsic reward coefficient. + "intrinsic_reward_coeff": 0.05, + # Forward loss weight (vs inverse dynamics loss). Total ICM loss is: + # L(total ICM) = ( + # `forward_loss_weight` * L(forward) + # + (1.0 - `forward_loss_weight`) * L(inverse_dyn) + # ) + "forward_loss_weight": 0.2, + } + ) .rl_module( rl_module_spec=MultiRLModuleSpec( module_specs={ @@ -262,7 +263,7 @@ def on_sample_end( ), # Use a different learning rate for training the ICM. 
algorithm_config_overrides_per_module={ - ICM_MODULE_ID: config_class.overrides(lr=0.0005) + ICM_MODULE_ID: AlgorithmConfig.overrides(lr=0.0005) }, ) ) diff --git a/rllib/examples/inference/policy_inference_after_training_w_connector.py b/rllib/examples/inference/policy_inference_after_training_w_connector.py index 1b89d15aec8c..4fe9520b6e0f 100644 --- a/rllib/examples/inference/policy_inference_after_training_w_connector.py +++ b/rllib/examples/inference/policy_inference_after_training_w_connector.py @@ -190,7 +190,8 @@ def _env_creator(cfg): best_result = results.get_best_result( metric=f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}", mode="max" ) - # Create new Algorithm and restore its state from the last checkpoint. + + # Create RLModule from a checkpoint. rl_module = RLModule.from_checkpoint( os.path.join( best_result.checkpoint.path, diff --git a/rllib/examples/learners/classes/curiosity_dqn_torch_learner.py b/rllib/examples/learners/classes/curiosity_dqn_torch_learner.py deleted file mode 100644 index 7d482110a82b..000000000000 --- a/rllib/examples/learners/classes/curiosity_dqn_torch_learner.py +++ /dev/null @@ -1,12 +0,0 @@ -from ray.rllib.algorithms.dqn.dqn import DQNConfig -from ray.rllib.algorithms.dqn.torch.dqn_rainbow_torch_learner import ( - DQNRainbowTorchLearner, -) -from ray.rllib.examples.learners.classes.curiosity_torch_learner_utils import ( - make_curiosity_config_class, - make_curiosity_learner_class, -) - - -DQNConfigWithCuriosity = make_curiosity_config_class(DQNConfig) -DQNTorchLearnerWithCuriosity = make_curiosity_learner_class(DQNRainbowTorchLearner) diff --git a/rllib/examples/learners/classes/curiosity_ppo_torch_learner.py b/rllib/examples/learners/classes/curiosity_ppo_torch_learner.py deleted file mode 100644 index 13c7d7c1475a..000000000000 --- a/rllib/examples/learners/classes/curiosity_ppo_torch_learner.py +++ /dev/null @@ -1,10 +0,0 @@ -from ray.rllib.algorithms.ppo.ppo import PPOConfig -from ray.rllib.algorithms.ppo.torch.ppo_torch_learner import PPOTorchLearner -from ray.rllib.examples.learners.classes.curiosity_torch_learner_utils import ( - make_curiosity_config_class, - make_curiosity_learner_class, -) - - -PPOConfigWithCuriosity = make_curiosity_config_class(PPOConfig) -PPOTorchLearnerWithCuriosity = make_curiosity_learner_class(PPOTorchLearner) diff --git a/rllib/examples/learners/classes/curiosity_torch_learner_utils.py b/rllib/examples/learners/classes/curiosity_torch_learner_utils.py deleted file mode 100644 index d34819e549f2..000000000000 --- a/rllib/examples/learners/classes/curiosity_torch_learner_utils.py +++ /dev/null @@ -1,117 +0,0 @@ -from typing import Any, Dict - -from ray.rllib.algorithms.algorithm_config import NotProvided -from ray.rllib.connectors.common.add_observations_from_episodes_to_batch import ( - AddObservationsFromEpisodesToBatch, -) -from ray.rllib.connectors.learner.add_next_observations_from_episodes_to_train_batch import ( # noqa - AddNextObservationsFromEpisodesToTrainBatch, -) -from ray.rllib.core import Columns, DEFAULT_MODULE_ID -from ray.rllib.utils.metrics import ALL_MODULES - -ICM_MODULE_ID = "_intrinsic_curiosity_model" - - -def make_curiosity_config_class(config_class): - class _class(config_class): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # Define defaults. - self.curiosity_beta = 0.2 - self.curiosity_eta = 1.0 - - # Allow users to change curiosity settings. 
- def curiosity( - self, - *, - curiosity_beta: float = NotProvided, - curiosity_eta: float = NotProvided, - ): - """Sets the config's curiosity settings. - - Args: - curiosity_beta: The coefficient used for the intrinsic rewards. Overall - rewards are computed as `R = R[extrinsic] + beta * R[intrinsic]`. - curiosity_eta: Fraction of the forward loss (within the total loss term) - vs the inverse dynamics loss. The total loss of the ICM is computed - as: `L = eta * [forward loss] + (1.0 - eta) * [inverse loss]`. - - Returns: - This updated AlgorithmConfig object. - """ - if curiosity_beta is not NotProvided: - self.curiosity_beta = curiosity_beta - if curiosity_eta is not NotProvided: - self.curiosity_eta = curiosity_eta - return self - - return _class - - -def make_curiosity_learner_class(learner_class): - class _class(learner_class): - def build(self): - super().build() - - # Assert, we are only training one policy (RLModule) and we have the ICM - # in our MultiRLModule. - assert ( - len(self.module) == 2 - and DEFAULT_MODULE_ID in self.module - and ICM_MODULE_ID in self.module - ) - - # Prepend a "add-NEXT_OBS-from-episodes-to-train-batch" connector piece - # (right after the corresponding "add-OBS-..." default piece). - if self.config.add_default_connectors_to_learner_pipeline: - self._learner_connector.insert_after( - AddObservationsFromEpisodesToBatch, - AddNextObservationsFromEpisodesToTrainBatch(), - ) - - def compute_losses( - self, - *, - fwd_out: Dict[str, Any], - batch: Dict[str, Any], - ) -> Dict[str, Any]: - # Compute the ICM loss first (so we'll have the chance to change the rewards - # in the batch for the "main" RLModule (before we compute its loss with the - # intrinsic rewards). - icm = self.module[ICM_MODULE_ID] - # Send the exact same batch to the ICM module that we used for the "main" - # RLModule's forward pass. - icm_fwd_out = icm.forward_train(batch=batch[DEFAULT_MODULE_ID]) - # Compute the loss of the ICM module. - icm_loss = icm.compute_loss_for_module( - learner=self, - module_id=ICM_MODULE_ID, - config=self.config.get_config_for_module(ICM_MODULE_ID), - batch=batch[DEFAULT_MODULE_ID], - fwd_out=icm_fwd_out, - ) - # Log the env steps trained counter for the ICM - - # Add intrinsic rewards from ICM's `fwd_out` (multiplied by factor `eta`) - # to "main" module batch's extrinsic rewards. - batch[DEFAULT_MODULE_ID][Columns.REWARDS] += ( - self.config.curiosity_eta * icm_fwd_out[Columns.INTRINSIC_REWARDS] - ) - - # Compute the "main" RLModule's loss. 
- main_loss = self.compute_loss_for_module( - module_id=DEFAULT_MODULE_ID, - config=self.config.get_config_for_module(DEFAULT_MODULE_ID), - batch=batch[DEFAULT_MODULE_ID], - fwd_out=fwd_out[DEFAULT_MODULE_ID], - ) - - return { - DEFAULT_MODULE_ID: main_loss, - ICM_MODULE_ID: icm_loss, - ALL_MODULES: main_loss + icm_loss, - } - - return _class diff --git a/rllib/examples/learners/classes/intrinsic_curiosity_learners.py b/rllib/examples/learners/classes/intrinsic_curiosity_learners.py new file mode 100644 index 000000000000..d28aded45989 --- /dev/null +++ b/rllib/examples/learners/classes/intrinsic_curiosity_learners.py @@ -0,0 +1,164 @@ +from typing import Any, List, Optional + +import gymnasium as gym +import torch + +from ray.rllib.algorithms.dqn.torch.dqn_rainbow_torch_learner import ( + DQNRainbowTorchLearner, +) +from ray.rllib.algorithms.ppo.torch.ppo_torch_learner import PPOTorchLearner +from ray.rllib.connectors.common.add_observations_from_episodes_to_batch import ( + AddObservationsFromEpisodesToBatch, +) +from ray.rllib.connectors.common.numpy_to_tensor import NumpyToTensor +from ray.rllib.connectors.learner.add_next_observations_from_episodes_to_train_batch import ( # noqa + AddNextObservationsFromEpisodesToTrainBatch, +) +from ray.rllib.connectors.connector_v2 import ConnectorV2 +from ray.rllib.core import Columns, DEFAULT_MODULE_ID +from ray.rllib.core.learner.torch.torch_learner import TorchLearner +from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.utils.typing import EpisodeType + +ICM_MODULE_ID = "_intrinsic_curiosity_model" + + +class DQNTorchLearnerWithCuriosity(DQNRainbowTorchLearner): + def build(self) -> None: + super().build() + add_intrinsic_curiosity_connectors(self) + + +class PPOTorchLearnerWithCuriosity(PPOTorchLearner): + def build(self) -> None: + super().build() + add_intrinsic_curiosity_connectors(self) + + +def add_intrinsic_curiosity_connectors(torch_learner: TorchLearner) -> None: + """Adds two connector pieces to the Learner pipeline, needed for ICM training. + + - The `AddNextObservationsFromEpisodesToTrainBatch` connector makes sure the train + batch contains the NEXT_OBS for ICM's forward- and inverse dynamics net training. + - The `IntrinsicCuriosityModelConnector` piece computes intrinsic rewards from the + ICM and adds the results to the extrinsic reward of the main module's train batch. + + Args: + torch_learner: The TorchLearner, to whose Learner pipeline the two ICM connector + pieces should be added. + """ + learner_config_dict = torch_learner.config.learner_config_dict + + # Assert, we are only training one policy (RLModule) and we have the ICM + # in our MultiRLModule. + assert ( + len(torch_learner.module) == 2 + and DEFAULT_MODULE_ID in torch_learner.module + and ICM_MODULE_ID in torch_learner.module + ) + + # Make sure both curiosity loss settings are explicitly set in the + # `learner_config_dict`. + if ( + "forward_loss_weight" not in learner_config_dict + or "intrinsic_reward_coeff" not in learner_config_dict + ): + raise KeyError( + "When using the IntrinsicCuriosityTorchLearner, both `forward_loss_weight` " + " and `intrinsic_reward_coeff` must be part of your config's " + "`learner_config_dict`! Add these values through: `config.training(" + "learner_config_dict={'forward_loss_weight': .., 'intrinsic_reward_coeff': " + "..})`." 
+ ) + + if torch_learner.config.add_default_connectors_to_learner_pipeline: + # Prepend a "add-NEXT_OBS-from-episodes-to-train-batch" connector piece + # (right after the corresponding "add-OBS-..." default piece). + torch_learner._learner_connector.insert_after( + AddObservationsFromEpisodesToBatch, + AddNextObservationsFromEpisodesToTrainBatch(), + ) + # Append the ICM connector, computing intrinsic rewards and adding these to + # the main model's extrinsic rewards. + torch_learner._learner_connector.insert_after( + NumpyToTensor, + IntrinsicCuriosityModelConnector( + intrinsic_reward_coeff=( + torch_learner.config.learner_config_dict["intrinsic_reward_coeff"] + ) + ), + ) + + +class IntrinsicCuriosityModelConnector(ConnectorV2): + """Learner ConnectorV2 piece to compute intrinsic rewards based on an ICM. + + For more details, see here: + [1] Curiosity-driven Exploration by Self-supervised Prediction + Pathak, Agrawal, Efros, and Darrell - UC Berkeley - ICML 2017. + https://arxiv.org/pdf/1705.05363.pdf + + This connector piece: + - requires two RLModules to be present in the MultiRLModule: + DEFAULT_MODULE_ID (the policy model to be trained) and ICM_MODULE_ID (the instrinsic + curiosity architecture). + - must be located toward the end of to your Learner pipeline (after the + `NumpyToTensor` piece) in order to perform a forward pass on the ICM model with the + readily compiled batch and a following forward-loss computation to get the intrinsi + rewards. + - these intrinsic rewards will then be added to the (extrinsic) rewards in the main + model's train batch. + """ + + def __init__( + self, + input_observation_space: Optional[gym.Space] = None, + input_action_space: Optional[gym.Space] = None, + *, + intrinsic_reward_coeff: float, + **kwargs, + ): + """Initializes a CountBasedCuriosity instance. + + Args: + intrinsic_reward_coeff: The weight with which to multiply the intrinsic + reward before adding it to the extrinsic rewards of the main model. + """ + super().__init__(input_observation_space, input_action_space) + + self.intrinsic_reward_coeff = intrinsic_reward_coeff + + def __call__( + self, + *, + rl_module: RLModule, + batch: Any, + episodes: List[EpisodeType], + explore: Optional[bool] = None, + shared_data: Optional[dict] = None, + **kwargs, + ) -> Any: + # Assert that the batch is ready. + assert DEFAULT_MODULE_ID in batch and ICM_MODULE_ID not in batch + assert ( + Columns.OBS in batch[DEFAULT_MODULE_ID] + and Columns.NEXT_OBS in batch[DEFAULT_MODULE_ID] + ) + # TODO (sven): We are performing two forward passes per update right now. + # Once here in the connector (w/o grad) to just get the intrinsic rewards + # and once in the learner to actually compute the ICM loss and update the ICM. + # Maybe we can save one of these, but this would currently harm the DDP-setup + # for multi-GPU training. + with torch.no_grad(): + # Perform ICM forward pass. + fwd_out = rl_module[ICM_MODULE_ID].forward_train(batch[DEFAULT_MODULE_ID]) + + # Add the intrinsic rewards to the main module's extrinsic rewards. + batch[DEFAULT_MODULE_ID][Columns.REWARDS] += ( + self.intrinsic_reward_coeff * fwd_out[Columns.INTRINSIC_REWARDS] + ) + + # Duplicate the batch such that the ICM also has data to learn on. 
+ batch[ICM_MODULE_ID] = batch[DEFAULT_MODULE_ID] + + return batch diff --git a/rllib/examples/rl_modules/classes/intrinsic_curiosity_model_rlm.py b/rllib/examples/rl_modules/classes/intrinsic_curiosity_model_rlm.py index c4a5c57f469b..469b093b2d5f 100644 --- a/rllib/examples/rl_modules/classes/intrinsic_curiosity_model_rlm.py +++ b/rllib/examples/rl_modules/classes/intrinsic_curiosity_model_rlm.py @@ -1,10 +1,10 @@ from typing import Any, Dict, TYPE_CHECKING +import tree # pip install dm_tree + from ray.rllib.core.columns import Columns +from ray.rllib.core.rl_module.apis import SelfSupervisedLossAPI from ray.rllib.core.rl_module.torch import TorchRLModule -from ray.rllib.examples.learners.classes.curiosity_torch_learner_utils import ( # noqa - ICM_MODULE_ID, -) from ray.rllib.models.torch.torch_distributions import TorchCategorical from ray.rllib.models.utils import get_activation_fn from ray.rllib.utils.annotations import override @@ -19,7 +19,7 @@ torch, nn = try_import_torch() -class IntrinsicCuriosityModel(TorchRLModule): +class IntrinsicCuriosityModel(TorchRLModule, SelfSupervisedLossAPI): """An intrinsic curiosity model (ICM) as TorchRLModule for better exploration. For more details, see: @@ -152,15 +152,12 @@ def setup(self): def _forward_train(self, batch, **kwargs): # Push both observations through feature net to get feature vectors (phis). # We cat/batch them here for efficiency reasons (save one forward pass). - phis = self._feature_net( - torch.cat( - [ - batch[Columns.OBS], - batch[Columns.NEXT_OBS], - ], - dim=0, - ) + obs = tree.map_structure( + lambda obs, next_obs: torch.cat([obs, next_obs], dim=0), + batch[Columns.OBS], + batch[Columns.NEXT_OBS], ) + phis = self._feature_net(obs) # Split again to yield 2 individual phi tensors. phi, next_phi = torch.chunk(phis, 2) @@ -198,8 +195,9 @@ def _forward_train(self, batch, **kwargs): def get_train_action_dist_cls(self): return TorchCategorical - @staticmethod - def compute_loss_for_module( + @override(SelfSupervisedLossAPI) + def compute_self_supervised_loss( + self, *, learner: "TorchLearner", module_id: ModuleID, @@ -207,7 +205,7 @@ def compute_loss_for_module( batch: Dict[str, Any], fwd_out: Dict[str, Any], ) -> Dict[str, Any]: - module = learner.module[module_id] + module = learner.module[module_id].unwrapped() # Forward net loss. forward_loss = torch.mean(fwd_out[Columns.INTRINSIC_REWARDS]) @@ -226,8 +224,9 @@ def compute_loss_for_module( # Calculate the ICM loss. total_loss = ( - 1.0 - config.curiosity_beta - ) * inverse_loss + config.curiosity_beta * forward_loss + config.learner_config_dict["forward_loss_weight"] * forward_loss + + (1.0 - config.learner_config_dict["forward_loss_weight"]) * inverse_loss + ) learner.metrics.log_dict( { diff --git a/rllib/utils/metrics/metrics_logger.py b/rllib/utils/metrics/metrics_logger.py index af60b4b1eb43..8c7d9e402ce5 100644 --- a/rllib/utils/metrics/metrics_logger.py +++ b/rllib/utils/metrics/metrics_logger.py @@ -53,6 +53,18 @@ def __init__(self): self._tensor_mode = False self._tensor_keys = set() + def __contains__(self, key: Union[str, Tuple[str, ...]]) -> bool: + """Returns True, if `key` can be found in self.stats. + + Args: + key: The key to find in self.stats. This must be either a str (single, + top-level key) or a tuple of str (nested key). + + Returns: + Whether `key` could be found in self.stats. + """ + return self._key_in_stats(key) + def peek( self, key: Union[str, Tuple[str, ...]],
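Note: a minimal sketch of what a custom PPO Torch Learner can look like after this change. The base `PPOLearner.after_gradient_based_update()` now peeks `(module_id, LEARNER_RESULTS_KL_KEY)` itself (guarded by the new `MetricsLogger.__contains__` added at the bottom of this diff), converts it to numpy, and hands it to `_update_module_kl_coeff()` as the `kl_loss` float. A subclass only consumes that float; the class name below is hypothetical and the update rule mirrors `ppo_torch_learner.py` above.

import numpy as np

from ray.rllib.algorithms.ppo.ppo import PPOConfig
from ray.rllib.algorithms.ppo.torch.ppo_torch_learner import PPOTorchLearner
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import ModuleID


class MyPPOTorchLearner(PPOTorchLearner):
    @override(PPOTorchLearner)
    def _update_module_kl_coeff(
        self,
        *,
        module_id: ModuleID,
        config: PPOConfig,
        kl_loss: float,
    ) -> None:
        # `kl_loss` arrives already peeked from the metrics and converted to
        # numpy by the base class; no `self.metrics.peek()` call is needed here.
        if np.isnan(kl_loss):
            # Skip the update on non-finite KL (the in-tree learners only warn).
            return
        curr_var = self.curr_kl_coeffs_per_module[module_id]
        if kl_loss > 2.0 * config.kl_target:
            curr_var.data *= 1.5
        elif kl_loss < 0.5 * config.kl_target:
            curr_var.data *= 0.5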
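Note: a minimal sketch of the new `SelfSupervisedLossAPI` dispatch added in `learner.py` above. Any RLModule that also subclasses the API brings its own loss, and `Learner.compute_losses()` then skips its `compute_loss_for_module()` for that module. The module below is a hypothetical placeholder (the in-tree example is `IntrinsicCuriosityModel` further down in this diff); `"obs_reconstruction"` is an assumed key, and `setup()`/`_forward_train()` are omitted for brevity.

from typing import Any, Dict

import torch

from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.apis import SelfSupervisedLossAPI
from ray.rllib.core.rl_module.torch import TorchRLModule
from ray.rllib.utils.typing import ModuleID, TensorType


class AuxiliaryReconstructionModule(TorchRLModule, SelfSupervisedLossAPI):
    """Hypothetical RLModule that trains itself via its own self-supervised loss."""

    def compute_self_supervised_loss(
        self,
        *,
        learner,  # The Learner calling this method.
        module_id: ModuleID,
        config,  # The AlgorithmConfig specific to `module_id`.
        batch: Dict[str, Any],
        fwd_out: Dict[str, TensorType],
    ) -> TensorType:
        # "obs_reconstruction" is assumed to be produced by this module's own
        # `_forward_train()` (not shown here).
        reconstruction = fwd_out["obs_reconstruction"]
        loss = torch.mean((reconstruction - batch[Columns.OBS]) ** 2)
        # Individual loss terms can be logged through the calling Learner's metrics.
        learner.metrics.log_value(
            key=(module_id, "reconstruction_loss"), value=loss, window=1
        )
        return loss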
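Note: with the `*ConfigWithCuriosity` wrappers deleted, the two ICM hyperparameters now travel through the standard `learner_config_dict`, which `add_intrinsic_curiosity_connectors()` validates at Learner build time. Below is a condensed configuration sketch using the same values as the example script above; wiring `PPOTorchLearnerWithCuriosity` via `.training(learner_class=...)` is assumed here, since that part of the example is not visible in these hunks, and the `MultiRLModuleSpec` carrying the extra `ICM_MODULE_ID` module (shown in the example above) is omitted.

from ray import tune
from ray.rllib.examples.learners.classes.intrinsic_curiosity_learners import (
    PPOTorchLearnerWithCuriosity,
)

config = (
    tune.registry.get_trainable_cls("PPO")
    .get_default_config()
    .environment("FrozenLake-v1")
    .training(
        # Assumed wiring of the curiosity Learner (see note above).
        learner_class=PPOTorchLearnerWithCuriosity,
        # Both keys are required; a missing one raises a KeyError in
        # `add_intrinsic_curiosity_connectors()`.
        learner_config_dict={
            "intrinsic_reward_coeff": 0.05,
            "forward_loss_weight": 0.2,
        },
    )
)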