[RLlib] RLModule API: SelfSupervisedLossAPI for RLModules that bring their own loss (algo independent). #47581

Merged
2 changes: 1 addition & 1 deletion rllib/algorithms/appo/appo_learner.py
@@ -128,7 +128,7 @@ def after_gradient_based_update(self, *, timesteps: Dict[str, Any]) -> None:

@abc.abstractmethod
def _update_module_kl_coeff(self, module_id: ModuleID, config: APPOConfig) -> None:
"""Dynamically update the KL loss coefficients of each module with.
"""Dynamically update the KL loss coefficients of each module.

The update is completed using the mean KL divergence between the action
distributions current policy and old policy of each module. That action
17 changes: 15 additions & 2 deletions rllib/algorithms/ppo/ppo_learner.py
@@ -3,6 +3,7 @@

from ray.rllib.algorithms.ppo.ppo import (
LEARNER_RESULTS_CURR_ENTROPY_COEFF_KEY,
LEARNER_RESULTS_KL_KEY,
PPOConfig,
)
from ray.rllib.connectors.learner import (
@@ -19,6 +20,7 @@
NUM_ENV_STEPS_SAMPLED_LIFETIME,
NUM_MODULE_STEPS_TRAINED,
)
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.schedules.scheduler import Scheduler
from ray.rllib.utils.typing import ModuleID, TensorType

@@ -106,17 +108,26 @@ def after_gradient_based_update(
config.use_kl_loss
and self.metrics.peek((module_id, NUM_MODULE_STEPS_TRAINED), default=0)
> 0
and (module_id, LEARNER_RESULTS_KL_KEY) in self.metrics
):
self._update_module_kl_coeff(module_id=module_id, config=config)
kl_loss = convert_to_numpy(
self.metrics.peek((module_id, LEARNER_RESULTS_KL_KEY))
)
self._update_module_kl_coeff(
module_id=module_id,
config=config,
kl_loss=kl_loss,
)

@abc.abstractmethod
def _update_module_kl_coeff(
self,
*,
module_id: ModuleID,
config: PPOConfig,
kl_loss: float,
) -> None:
"""Dynamically update the KL loss coefficients of each module with.
"""Dynamically update the KL loss coefficients of each module.

The update is completed using the mean KL divergence between the action
distributions current policy and old policy of each module. That action
@@ -125,4 +136,6 @@ def _update_module_kl_coeff(
Args:
module_id: The module whose KL loss coefficient to update.
config: The AlgorithmConfig specific to the given `module_id`.
kl_loss: The mean KL loss of the module, computed inside
`compute_loss_for_module()`.
"""
10 changes: 4 additions & 6 deletions rllib/algorithms/ppo/tf/ppo_tf_learner.py
@@ -18,7 +18,6 @@
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.tf_utils import explained_variance
from ray.rllib.utils.annotations import override
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.typing import ModuleID, TensorType

_, tf, _ = try_import_tf()
@@ -151,10 +150,9 @@ def _update_module_kl_coeff(
*,
module_id: ModuleID,
config: PPOConfig,
kl_loss: float,
) -> None:
kl = convert_to_numpy(self.metrics.peek((module_id, LEARNER_RESULTS_KL_KEY)))

if np.isnan(kl):
if np.isnan(kl_loss):
logger.warning(
f"KL divergence for Module {module_id} is non-finite, this "
"will likely destabilize your model and the training "
@@ -168,10 +166,10 @@

# Update the KL coefficient.
curr_var = self.curr_kl_coeffs_per_module[module_id]
if kl > 2.0 * config.kl_target:
if kl_loss > 2.0 * config.kl_target:
# TODO (Kourosh) why not 2?
curr_var.assign(curr_var * 1.5)
elif kl < 0.5 * config.kl_target:
elif kl_loss < 0.5 * config.kl_target:
curr_var.assign(curr_var * 0.5)

# Log the updated KL-coeff value.
10 changes: 4 additions & 6 deletions rllib/algorithms/ppo/torch/ppo_torch_learner.py
@@ -17,7 +17,6 @@
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.torch_utils import explained_variance
from ray.rllib.utils.typing import ModuleID, TensorType

@@ -140,10 +139,9 @@ def _update_module_kl_coeff(
*,
module_id: ModuleID,
config: PPOConfig,
kl_loss: float,
) -> None:
kl = convert_to_numpy(self.metrics.peek((module_id, LEARNER_RESULTS_KL_KEY)))

if np.isnan(kl):
if np.isnan(kl_loss):
logger.warning(
f"KL divergence for Module {module_id} is non-finite, this "
"will likely destabilize your model and the training "
@@ -157,10 +155,10 @@

# Update the KL coefficient.
curr_var = self.curr_kl_coeffs_per_module[module_id]
if kl > 2.0 * config.kl_target:
if kl_loss > 2.0 * config.kl_target:
# TODO (Kourosh) why not 2?
curr_var.data *= 1.5
elif kl < 0.5 * config.kl_target:
elif kl_loss < 0.5 * config.kl_target:
curr_var.data *= 0.5

# Log the updated KL-coeff value.
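
Both the TF and the Torch learner above now receive the KL term explicitly through the new `kl_loss` argument and apply the same adaptive rule to it. A framework-free restatement of that rule (an illustrative sketch for reference, not code from this PR):

```python
def adapt_kl_coeff(curr_coeff: float, kl_loss: float, kl_target: float) -> float:
    """Restate the KL-coefficient update rule used by both learners above."""
    if kl_loss > 2.0 * kl_target:
        # Measured KL is far above the target -> penalize divergence more.
        return curr_coeff * 1.5
    elif kl_loss < 0.5 * kl_target:
        # Measured KL is well below the target -> relax the penalty.
        return curr_coeff * 0.5
    # Otherwise, leave the coefficient unchanged.
    return curr_coeff


# Example: with kl_target=0.01, a kl_loss of 0.03 scales the coefficient by
# 1.5x, a kl_loss of 0.004 halves it, and a kl_loss of 0.012 leaves it as is.
```
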
2 changes: 1 addition & 1 deletion rllib/connectors/common/batch_individual_items.py
@@ -162,7 +162,7 @@ def __call__(

# Single-agent case: There is a dict under `column` mapping
# `eps_id` to lists of items:
# Sort by eps_id, concat all these lists, then batch.
# Concat all these lists, then batch.
elif not self._multi_agent:
# TODO: only really need this in non-Learner connector pipeline
memorized_map_structure = []
23 changes: 17 additions & 6 deletions rllib/core/learner/learner.py
@@ -26,6 +26,7 @@
LearnerConnectorPipeline,
)
from ray.rllib.core import COMPONENT_OPTIMIZER, COMPONENT_RL_MODULE, DEFAULT_MODULE_ID
from ray.rllib.core.rl_module.apis import SelfSupervisedLossAPI
from ray.rllib.core.rl_module import validate_module_id
from ray.rllib.core.rl_module.multi_rl_module import (
MultiRLModule,
@@ -876,12 +877,22 @@ def compute_losses(
module_batch = batch[module_id]
module_fwd_out = fwd_out[module_id]

loss = self.compute_loss_for_module(
module_id=module_id,
config=self.config.get_config_for_module(module_id),
batch=module_batch,
fwd_out=module_fwd_out,
)
module = self.module[module_id].unwrapped()
if isinstance(module, SelfSupervisedLossAPI):
loss = module.compute_self_supervised_loss(
learner=self,
module_id=module_id,
config=self.config.get_config_for_module(module_id),
batch=module_batch,
fwd_out=module_fwd_out,
)
else:
loss = self.compute_loss_for_module(
module_id=module_id,
config=self.config.get_config_for_module(module_id),
batch=module_batch,
fwd_out=module_fwd_out,
)
loss_per_module[module_id] = loss

return loss_per_module
2 changes: 1 addition & 1 deletion rllib/core/models/configs.py
@@ -190,7 +190,7 @@ def output_dims(self):
)

# Infer `output_dims` automatically.
return (self.output_layer_dim or self.hidden_layer_dims[-1],)
return (int(self.output_layer_dim or self.hidden_layer_dims[-1]),)

def _validate(self, framework: str = "torch"):
"""Makes sure that settings are valid."""
2 changes: 2 additions & 0 deletions rllib/core/rl_module/apis/__init__.py
@@ -1,10 +1,12 @@
from ray.rllib.core.rl_module.apis.inference_only_api import InferenceOnlyAPI
from ray.rllib.core.rl_module.apis.self_supervised_loss_api import SelfSupervisedLossAPI
from ray.rllib.core.rl_module.apis.target_network_api import TargetNetworkAPI
from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI


__all__ = [
"InferenceOnlyAPI",
"SelfSupervisedLossAPI",
"TargetNetworkAPI",
"ValueFunctionAPI",
]
52 changes: 52 additions & 0 deletions rllib/core/rl_module/apis/self_supervised_loss_api.py
@@ -0,0 +1,52 @@
import abc
from typing import Any, Dict, TYPE_CHECKING

from ray.rllib.utils.typing import ModuleID, TensorType

if TYPE_CHECKING:
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.core.learner.learner import Learner


class SelfSupervisedLossAPI(abc.ABC):
"""An API to be implemented by RLModules that bring their own self-supervised loss.

Learners will call the module's `compute_self_supervised_loss()` method instead of
the Learner's own `compute_loss_for_module()` method.
The call signature is identical to the Learner's `compute_loss_for_module()` method
except for an additional mandatory `learner` kwarg.
"""

@abc.abstractmethod
def compute_self_supervised_loss(
self,
*,
learner: "Learner",
module_id: ModuleID,
config: "AlgorithmConfig",
batch: Dict[str, Any],
fwd_out: Dict[str, TensorType],
) -> TensorType:
"""Computes the loss for a single module.

Think of this as computing loss for a single agent. For multi-agent use-cases
that require a more complicated loss computation, consider overriding the
`compute_losses` method instead.

Args:
learner: The Learner calling this loss method on the RLModule.
module_id: The ID of the RLModule (within a MultiRLModule).
config: The AlgorithmConfig specific to the given `module_id`.
batch: The sample batch for this particular RLModule.
fwd_out: The output of the forward pass for this particular RLModule.

Returns:
A single total loss tensor. If you have more than one optimizer on the
provided `module_id` and would like to compute gradients separately using
these different optimizers, simply add up the individual loss terms for
each optimizer and return the sum. Also, for recording/logging any
individual loss terms, you can use the `Learner.metrics.log_value(
key=..., value=...)` or `Learner.metrics.log_dict()` APIs. See:
:py:class:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger` for more
information.
"""
4 changes: 2 additions & 2 deletions rllib/examples/connectors/classes/count_based_curiosity.py
@@ -62,7 +62,7 @@ def __call__(
self,
*,
rl_module: RLModule,
data: Any,
batch: Any,
episodes: List[EpisodeType],
explore: Optional[bool] = None,
shared_data: Optional[dict] = None,
@@ -89,4 +89,4 @@
# timestep/index).
sa_episode.set_rewards(new_data=rew, at_indices=i)

return data
return batch
@@ -73,14 +73,14 @@ def __call__(
self,
*,
rl_module: RLModule,
data: Any,
batch: Any,
episodes: List[EpisodeType],
explore: Optional[bool] = None,
shared_data: Optional[dict] = None,
**kwargs,
) -> Any:
if self._test > 10:
return data
return batch
self._test += 1
# Loop through all episodes and change the reward to
# [reward + intrinsic reward]
@@ -119,4 +119,4 @@ def __call__(
if max_dist_obs is not None:
self.obs_buffer.append(max_dist_obs)

return data
return batch
@@ -46,7 +46,7 @@ def __call__(
self,
*,
rl_module: RLModule,
data: Any,
batch: Any,
episodes: List[EpisodeType],
explore: Optional[bool] = None,
shared_data: Optional[dict] = None,
@@ -77,4 +77,4 @@
sa_episode.set_observations(new_data=new_obs, at_indices=-1)

# Return `data` as-is.
return data
return batch
@@ -17,7 +17,6 @@
exploration in sparse rewards environments.

For more details, see here:

[1] Curiosity-driven Exploration by Self-supervised Prediction
Pathak, Agrawal, Efros, and Darrell - UC Berkeley - ICML 2017.
https://arxiv.org/pdf/1705.05363.pdf
@@ -74,21 +73,21 @@
"""
from collections import defaultdict

from ray import tune
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.connectors.env_to_module import FlattenObservations
from ray.rllib.examples.learners.classes.curiosity_dqn_torch_learner import (
DQNConfigWithCuriosity,
from ray.rllib.examples.learners.classes.intrinsic_curiosity_learners import (
DQNTorchLearnerWithCuriosity,
PPOTorchLearnerWithCuriosity,
)
from ray.rllib.core import DEFAULT_MODULE_ID
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.examples.learners.classes.curiosity_ppo_torch_learner import (
PPOConfigWithCuriosity,
PPOTorchLearnerWithCuriosity,
from ray.rllib.examples.learners.classes.intrinsic_curiosity_learners import (
ICM_MODULE_ID,
)
from ray.rllib.examples.rl_modules.classes.intrinsic_curiosity_model_rlm import (
ICM_MODULE_ID,
IntrinsicCuriosityModel,
)
from ray.rllib.utils.metrics import (
@@ -188,12 +187,9 @@ def on_sample_end(
"Curiosity example only implemented for either DQN or PPO! See the "
)

config_class = (
PPOConfigWithCuriosity if args.algo == "PPO" else DQNConfigWithCuriosity
)

base_config = (
config_class()
tune.registry.get_trainable_cls(args.algo)
.get_default_config()
.environment(
"FrozenLake-v1",
env_config={
@@ -218,18 +214,23 @@
"max_episode_steps": 22,
},
)
# Use our custom `curiosity` method to set up the PPO/ICM-Learner.
.curiosity(
# Intrinsic reward coefficient.
curiosity_eta=0.05,
# Forward loss weight (vs inverse dynamics loss, which will be `1. - beta`).
# curiosity_beta=0.2,
)
.callbacks(MeasureMaxDistanceToStart)
.env_runners(
num_envs_per_env_runner=5 if args.algo == "PPO" else 1,
env_to_module_connector=lambda env: FlattenObservations(),
)
.training(
learner_config_dict={
# Intrinsic reward coefficient.
"intrinsic_reward_coeff": 0.05,
# Forward loss weight (vs inverse dynamics loss). Total ICM loss is:
(Inline review comment from a collaborator: "Very nice comment!")
# L(total ICM) = (
# `forward_loss_weight` * L(forward)
# + (1.0 - `forward_loss_weight`) * L(inverse_dyn)
# )
"forward_loss_weight": 0.2,
}
)
.rl_module(
rl_module_spec=MultiRLModuleSpec(
module_specs={
@@ -262,7 +263,7 @@ def on_sample_end(
),
# Use a different learning rate for training the ICM.
algorithm_config_overrides_per_module={
ICM_MODULE_ID: config_class.overrides(lr=0.0005)
ICM_MODULE_ID: AlgorithmConfig.overrides(lr=0.0005)
},
)
)
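
For reference, the two `learner_config_dict` entries configured above are what the curiosity Learners read to weight the ICM terms as described in the inline comment. A small illustrative helper showing that weighting (an assumption about usage, not code from this PR):

```python
def icm_total_loss(forward_loss, inverse_dyn_loss, learner_config_dict):
    """Combine the two ICM loss terms per the `forward_loss_weight` comment above."""
    w = learner_config_dict["forward_loss_weight"]  # e.g. 0.2
    return w * forward_loss + (1.0 - w) * inverse_dyn_loss


# With forward_loss_weight=0.2:
#   L(total ICM) = 0.2 * L(forward) + 0.8 * L(inverse_dyn)
# `intrinsic_reward_coeff` (0.05) is the separate coefficient applied to the
# intrinsic (curiosity) reward itself.
```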