-
Notifications
You must be signed in to change notification settings - Fork 6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[RLlib] RLModule API:
SelfSupervisedLossAPI
for RLModules that brin…
…g their own loss (algo independent). (#47581)
- Loading branch information
Showing
20 changed files
with
319 additions
and
207 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,12 @@ | ||
from ray.rllib.core.rl_module.apis.inference_only_api import InferenceOnlyAPI | ||
from ray.rllib.core.rl_module.apis.self_supervised_loss_api import SelfSupervisedLossAPI | ||
from ray.rllib.core.rl_module.apis.target_network_api import TargetNetworkAPI | ||
from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI | ||
|
||
|
||
__all__ = [ | ||
"InferenceOnlyAPI", | ||
"SelfSupervisedLossAPI", | ||
"TargetNetworkAPI", | ||
"ValueFunctionAPI", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import abc | ||
from typing import Any, Dict, TYPE_CHECKING | ||
|
||
from ray.rllib.utils.typing import ModuleID, TensorType | ||
|
||
if TYPE_CHECKING: | ||
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig | ||
from ray.rllib.core.learner.learner import Learner | ||
|
||
|
||
class SelfSupervisedLossAPI(abc.ABC): | ||
"""An API to be implemented by RLModules that bring their own self-supervised loss. | ||
Learners will call these model's `compute_self_supervised_loss()` method instead of | ||
the Learner's own `compute_loss_for_module()` method. | ||
The call signature is identical to the Learner's `compute_loss_for_module()` method | ||
except of an additional mandatory `learner` kwarg. | ||
""" | ||
|
||
@abc.abstractmethod | ||
def compute_self_supervised_loss( | ||
self, | ||
*, | ||
learner: "Learner", | ||
module_id: ModuleID, | ||
config: "AlgorithmConfig", | ||
batch: Dict[str, Any], | ||
fwd_out: Dict[str, TensorType], | ||
) -> TensorType: | ||
"""Computes the loss for a single module. | ||
Think of this as computing loss for a single agent. For multi-agent use-cases | ||
that require more complicated computation for loss, consider overriding the | ||
`compute_losses` method instead. | ||
Args: | ||
learner: The Learner calling this loss method on the RLModule. | ||
module_id: The ID of the RLModule (within a MultiRLModule). | ||
config: The AlgorithmConfig specific to the given `module_id`. | ||
batch: The sample batch for this particular RLModule. | ||
fwd_out: The output of the forward pass for this particular RLModule. | ||
Returns: | ||
A single total loss tensor. If you have more than one optimizer on the | ||
provided `module_id` and would like to compute gradients separately using | ||
these different optimizers, simply add up the individual loss terms for | ||
each optimizer and return the sum. Also, for recording/logging any | ||
individual loss terms, you can use the `Learner.metrics.log_value( | ||
key=..., value=...)` or `Learner.metrics.log_dict()` APIs. See: | ||
:py:class:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger` for more | ||
information. | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.