Skip to content

Commit

Permalink
[RLlib] RLModule API: SelfSupervisedLossAPI for RLModules that bring their own loss (algo independent). (ray-project#47581)
Browse files Browse the repository at this point in the history

Signed-off-by: ujjawal-khare <ujjawal.khare@dream11.com>
  • Loading branch information
sven1977 authored and ujjawal-khare committed Oct 15, 2024
1 parent 6d9bef9 commit 3251d2c
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 0 deletions.
2 changes: 2 additions & 0 deletions rllib/core/rl_module/apis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from ray.rllib.core.rl_module.apis.inference_only_api import InferenceOnlyAPI
from ray.rllib.core.rl_module.apis.self_supervised_loss_api import SelfSupervisedLossAPI
from ray.rllib.core.rl_module.apis.target_network_api import TargetNetworkAPI
from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI


# Public names re-exported by this package; each entry matches one of the
# API classes imported above.
__all__ = [
    "InferenceOnlyAPI",
    "SelfSupervisedLossAPI",
    "TargetNetworkAPI",
    "ValueFunctionAPI",
]
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.apis import SelfSupervisedLossAPI
from ray.rllib.core.rl_module.torch import TorchRLModule
from ray.rllib.models.torch.torch_distributions import TorchCategorical
from ray.rllib.models.utils import get_activation_fn
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
Expand Down Expand Up @@ -187,6 +188,10 @@ def _forward_train(self, batch, **kwargs):

return output

@override(TorchRLModule)
def get_train_action_dist_cls(self):
    """Return the distribution class this module reports for train-time actions.

    Declared as an override of the `TorchRLModule` method of the same name.

    Returns:
        The `TorchCategorical` class itself (not an instance).
    """
    return TorchCategorical

@override(SelfSupervisedLossAPI)
def compute_self_supervised_loss(
self,
Expand Down

0 comments on commit 3251d2c

Please sign in to comment.