[RLlib] Docs do-over (new API stack): Prep. RLModule; introduce `Default[algo]RLModule` classes (rename from `[algo]RLModule`). (#49366)
sven1977 authored Dec 20, 2024
1 parent 47ae84e commit 1e34aa7
Showing 42 changed files with 385 additions and 316 deletions.
2 changes: 1 addition & 1 deletion doc/source/rllib/rllib-catalogs.rst
@@ -217,7 +217,7 @@ Inject your custom model or action distributions into Catalogs
You can make a :py:class:`~ray.rllib.core.models.catalog.Catalog` build custom ``models`` by overriding the Catalog methods that RLModules call to build their ``models``.
Have a look at these lines from the constructor of the :py:class:`~ray.rllib.algorithms.ppo.ppo_torch_rl_module.PPOTorchRLModule` to see how an :py:class:`~ray.rllib.core.rl_module.rl_module.RLModule` uses its Catalog:

.. literalinclude:: ../../../rllib/algorithms/ppo/ppo_rl_module.py
.. literalinclude:: ../../../rllib/algorithms/ppo/default_ppo_rl_module.py
    :language: python
    :start-after: __sphinx_doc_begin__
    :end-before: __sphinx_doc_end__
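As an illustration only (a sketch, not part of the documentation source above): a custom Catalog subclass can be wired into PPO's default RLModule through the ``catalog_class`` field of an ``RLModuleSpec``. The class and environment names below are placeholders, and the exact import path for ``RLModuleSpec`` is an assumption about the installed Ray version.

# Sketch: plug a custom Catalog into PPO's default RLModule via RLModuleSpec.
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.core.rl_module.rl_module import RLModuleSpec


class MyPPOCatalog(PPOCatalog):
    def build_encoder(self, framework: str):
        # Return your own encoder here; this sketch simply falls back to the
        # default encoder that the base Catalog builds from the observation space.
        return super().build_encoder(framework=framework)


config = (
    PPOConfig()
    .environment("CartPole-v1")
    .rl_module(rl_module_spec=RLModuleSpec(catalog_class=MyPPOCatalog))
)

You can then build and train the resulting config as usual.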
16 changes: 8 additions & 8 deletions rllib/BUILD
@@ -1355,39 +1355,39 @@ py_test(

# LearnerGroup
py_test(
    name = "TestLearnerGroupSyncUpdate",
    name = "test_learner_group_async_update",
    main = "core/learner/tests/test_learner_group.py",
    tags = ["team:rllib", "multi_gpu", "exclusive"],
    size = "large",
    srcs = ["core/learner/tests/test_learner_group.py"],
    args = ["TestLearnerGroupSyncUpdate"]
    args = ["TestLearnerGroupAsyncUpdate"]
)

py_test(
    name = "TestLearnerGroupCheckpointRestore",
    name = "test_learner_group_sync_update",
    main = "core/learner/tests/test_learner_group.py",
    tags = ["team:rllib", "multi_gpu", "exclusive"],
    size = "large",
    srcs = ["core/learner/tests/test_learner_group.py"],
    args = ["TestLearnerGroupCheckpointRestore"]
    args = ["TestLearnerGroupSyncUpdate"]
)

py_test(
    name = "TestLearnerGroupAsyncUpdate",
    name = "test_learner_group_checkpoint_restore",
    main = "core/learner/tests/test_learner_group.py",
    tags = ["team:rllib", "multi_gpu", "exclusive"],
    size = "large",
    srcs = ["core/learner/tests/test_learner_group.py"],
    args = ["TestLearnerGroupAsyncUpdate"]
    args = ["TestLearnerGroupCheckpointRestore"]
)

py_test(
    name = "TestLearnerGroupSaveLoadState",
    name = "test_learner_group_save_and_restore_state",
    main = "core/learner/tests/test_learner_group.py",
    tags = ["team:rllib", "multi_gpu", "exclusive"],
    size = "large",
    srcs = ["core/learner/tests/test_learner_group.py"],
    args = ["TestLearnerGroupSaveLoadState"]
    args = ["TestLearnerGroupSaveAndRestoreState"]
)

# Learner
1 change: 1 addition & 0 deletions rllib/algorithms/appo/__init__.py
@@ -5,6 +5,7 @@
__all__ = [
    "APPO",
    "APPOConfig",
    # @OldAPIStack
    "APPOTF1Policy",
    "APPOTF2Policy",
    "APPOTorchPolicy",
53 changes: 9 additions & 44 deletions rllib/algorithms/appo/appo_rl_module.py
@@ -1,46 +1,11 @@
import abc
from typing import Any, Dict, List, Tuple

from ray.rllib.algorithms.ppo.ppo_rl_module import PPORLModule
from ray.rllib.algorithms.appo.appo import OLD_ACTION_DIST_LOGITS_KEY
from ray.rllib.core.learner.utils import make_target_network
from ray.rllib.core.models.base import ACTOR
from ray.rllib.core.models.tf.encoder import ENCODER_OUT
from ray.rllib.core.rl_module.apis import TargetNetworkAPI
from ray.rllib.utils.typing import NetworkType

from ray.rllib.utils.annotations import (
    override,
    OverrideToImplementCustomLogic_CallToSuperRecommended,
# Backward compat import.
from ray.rllib.algorithms.appo.default_appo_rl_module import ( # noqa
    DefaultAPPORLModule as APPORLModule,
)
from ray.rllib.utils.deprecation import deprecation_warning


class APPORLModule(PPORLModule, TargetNetworkAPI, abc.ABC):
    @override(TargetNetworkAPI)
    def make_target_networks(self):
        self._old_encoder = make_target_network(self.encoder)
        self._old_pi = make_target_network(self.pi)

    @override(TargetNetworkAPI)
    def get_target_network_pairs(self) -> List[Tuple[NetworkType, NetworkType]]:
        return [
            (self.encoder, self._old_encoder),
            (self.pi, self._old_pi),
        ]

    @override(TargetNetworkAPI)
    def forward_target(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        old_pi_inputs_encoded = self._old_encoder(batch)[ENCODER_OUT][ACTOR]
        old_action_dist_logits = self._old_pi(old_pi_inputs_encoded)
        return {OLD_ACTION_DIST_LOGITS_KEY: old_action_dist_logits}

    @OverrideToImplementCustomLogic_CallToSuperRecommended
    @override(PPORLModule)
    def get_non_inference_attributes(self) -> List[str]:
        # Get the NON inference-only attributes from the parent class
        # `PPOTorchRLModule`.
        ret = super().get_non_inference_attributes()
        # Add the two (APPO) target networks to it (NOT needed in
        # inference-only mode).
        ret += ["_old_encoder", "_old_pi"]
        return ret
deprecation_warning(
    old="ray.rllib.algorithms.appo.appo_rl_module.APPORLModule",
    new="ray.rllib.algorithms.appo.default_appo_rl_module.DefaultAPPORLModule",
    error=False,
)
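For downstream code, the net effect of this shim is that the old import path keeps working (via the alias above) but now emits a deprecation warning. A minimal illustration of that contract:

# Importing the old name still works but triggers the deprecation warning;
# both names resolve to the same class object.
from ray.rllib.algorithms.appo.appo_rl_module import APPORLModule
from ray.rllib.algorithms.appo.default_appo_rl_module import DefaultAPPORLModule

assert APPORLModule is DefaultAPPORLModule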
55 changes: 55 additions & 0 deletions rllib/algorithms/appo/default_appo_rl_module.py
@@ -0,0 +1,55 @@
import abc
from typing import Any, Dict, List, Tuple

from ray.rllib.algorithms.ppo.default_ppo_rl_module import DefaultPPORLModule
from ray.rllib.algorithms.appo.appo import OLD_ACTION_DIST_LOGITS_KEY
from ray.rllib.core.learner.utils import make_target_network
from ray.rllib.core.models.base import ACTOR
from ray.rllib.core.models.tf.encoder import ENCODER_OUT
from ray.rllib.core.rl_module.apis import TargetNetworkAPI
from ray.rllib.utils.typing import NetworkType

from ray.rllib.utils.annotations import (
    override,
    OverrideToImplementCustomLogic_CallToSuperRecommended,
)


class DefaultAPPORLModule(DefaultPPORLModule, TargetNetworkAPI, abc.ABC):
    """Default RLModule used by APPO, if the user doesn't specify a custom RLModule.

    Users who want to train their RLModules with APPO may implement any RLModule
    (or TorchRLModule) subclass as long as the custom class also implements the
    `ValueFunctionAPI` (see ray.rllib.core.rl_module.apis.value_function_api.py)
    and the `TargetNetworkAPI` (see
    ray.rllib.core.rl_module.apis.target_network_api.py).
    """

    @override(TargetNetworkAPI)
    def make_target_networks(self):
        self._old_encoder = make_target_network(self.encoder)
        self._old_pi = make_target_network(self.pi)

    @override(TargetNetworkAPI)
    def get_target_network_pairs(self) -> List[Tuple[NetworkType, NetworkType]]:
        return [
            (self.encoder, self._old_encoder),
            (self.pi, self._old_pi),
        ]

    @override(TargetNetworkAPI)
    def forward_target(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        old_pi_inputs_encoded = self._old_encoder(batch)[ENCODER_OUT][ACTOR]
        old_action_dist_logits = self._old_pi(old_pi_inputs_encoded)
        return {OLD_ACTION_DIST_LOGITS_KEY: old_action_dist_logits}

    @OverrideToImplementCustomLogic_CallToSuperRecommended
    @override(DefaultPPORLModule)
    def get_non_inference_attributes(self) -> List[str]:
        # Get the NON inference-only attributes from the parent class
        # `DefaultPPORLModule`.
        ret = super().get_non_inference_attributes()
        # Add the two (APPO) target networks to it (NOT needed in
        # inference-only mode).
        ret += ["_old_encoder", "_old_pi"]
        return ret
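As the docstring above notes, APPO can train any custom RLModule that also implements the `ValueFunctionAPI` and the `TargetNetworkAPI`. A heavily simplified sketch of what such a class could look like; the layer sizes, the discrete-action assumption, and the availability of `self.observation_space` / `self.action_space` on the module are assumptions of this example, not part of the commit:

from typing import Any, Dict, List, Tuple

import torch

from ray.rllib.algorithms.appo.appo import OLD_ACTION_DIST_LOGITS_KEY
from ray.rllib.core.columns import Columns
from ray.rllib.core.learner.utils import make_target_network
from ray.rllib.core.rl_module.apis import TargetNetworkAPI, ValueFunctionAPI
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule
from ray.rllib.utils.typing import NetworkType


class MyAPPOCapableRLModule(TorchRLModule, ValueFunctionAPI, TargetNetworkAPI):
    def setup(self):
        # Tiny linear policy and value heads; assumes a flat Box observation
        # space and a Discrete action space.
        obs_dim = self.observation_space.shape[0]
        num_actions = self.action_space.n
        self._pi = torch.nn.Linear(obs_dim, num_actions)
        self._vf = torch.nn.Linear(obs_dim, 1)

    def _forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        return {Columns.ACTION_DIST_INPUTS: self._pi(batch[Columns.OBS])}

    # ValueFunctionAPI: value estimates used for advantage/value targets.
    def compute_values(self, batch: Dict[str, Any], embeddings=None):
        return self._vf(batch[Columns.OBS]).squeeze(-1)

    # TargetNetworkAPI: a frozen copy of the policy head for APPO's loss.
    def make_target_networks(self) -> None:
        self._old_pi = make_target_network(self._pi)

    def get_target_network_pairs(self) -> List[Tuple[NetworkType, NetworkType]]:
        return [(self._pi, self._old_pi)]

    def forward_target(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        return {OLD_ACTION_DIST_LOGITS_KEY: self._old_pi(batch[Columns.OBS])}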
15 changes: 11 additions & 4 deletions rllib/algorithms/appo/torch/appo_torch_rl_module.py
@@ -1,6 +1,13 @@
from ray.rllib.algorithms.appo.appo_rl_module import APPORLModule
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
# Backward compat import.
from ray.rllib.algorithms.appo.torch.default_appo_torch_rl_module import ( # noqa
    DefaultAPPOTorchRLModule as APPOTorchRLModule,
)
from ray.rllib.utils.deprecation import deprecation_warning


class APPOTorchRLModule(PPOTorchRLModule, APPORLModule):
    pass
deprecation_warning(
    old="ray.rllib.algorithms.appo.torch.appo_torch_rl_module.APPOTorchRLModule",
    new="ray.rllib.algorithms.appo.torch.default_appo_torch_rl_module."
    "DefaultAPPOTorchRLModule",
    error=False,
)
8 changes: 8 additions & 0 deletions rllib/algorithms/appo/torch/default_appo_torch_rl_module.py
@@ -0,0 +1,8 @@
from ray.rllib.algorithms.appo.default_appo_rl_module import DefaultAPPORLModule
from ray.rllib.algorithms.ppo.torch.default_ppo_torch_rl_module import (
    DefaultPPOTorchRLModule,
)


class DefaultAPPOTorchRLModule(DefaultPPOTorchRLModule, DefaultAPPORLModule):
    pass
2 changes: 1 addition & 1 deletion rllib/algorithms/bc/__init__.py
@@ -1,6 +1,6 @@
from ray.rllib.algorithms.bc.bc import BCConfig, BC

__all__ = [
"BCConfig",
"BC",
"BCConfig",
]
6 changes: 4 additions & 2 deletions rllib/algorithms/bc/bc.py
@@ -70,10 +70,12 @@ def __init__(self, algo_class=None):
    @override(AlgorithmConfig)
    def get_default_rl_module_spec(self) -> RLModuleSpecType:
        if self.framework_str == "torch":
            from ray.rllib.algorithms.bc.torch.bc_torch_rl_module import BCTorchRLModule
            from ray.rllib.algorithms.bc.torch.default_bc_torch_rl_module import (
                DefaultBCTorchRLModule,
            )

            return RLModuleSpec(
                module_class=BCTorchRLModule,
                module_class=DefaultBCTorchRLModule,
                catalog_class=BCCatalog,
            )
        else:
33 changes: 0 additions & 33 deletions rllib/algorithms/bc/torch/bc_torch_rl_module.py

This file was deleted.

38 changes: 38 additions & 0 deletions rllib/algorithms/bc/torch/default_bc_torch_rl_module.py
@@ -0,0 +1,38 @@
import abc
from typing import Any, Dict

from ray.rllib.core.columns import Columns
from ray.rllib.core.models.base import ENCODER_OUT
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule
from ray.rllib.utils.annotations import override
from ray.util.annotations import DeveloperAPI


@DeveloperAPI(stability="alpha")
class DefaultBCTorchRLModule(TorchRLModule, abc.ABC):
    """The default TorchRLModule used if no custom RLModule is provided.

    Builds an encoder net based on the observation space.
    Builds a pi head based on the action space.
    Passes observations from the input batch through the encoder, then the pi head to
    compute action logits.
    """

    @override(RLModule)
    def setup(self):
        # Build model components (encoder and pi head) from catalog.
        super().setup()
        self._encoder = self.catalog.build_encoder(framework=self.framework)
        self._pi_head = self.catalog.build_pi_head(framework=self.framework)

    @override(TorchRLModule)
    def _forward(self, batch: Dict, **kwargs) -> Dict[str, Any]:
        """Generic BC forward pass (for all phases of training/evaluation)."""
        # Encoder embeddings.
        encoder_outs = self._encoder(batch)
        # Action dist inputs.
        return {
            Columns.ACTION_DIST_INPUTS: self._pi_head(encoder_outs[ENCODER_OUT]),
        }
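Since `BCConfig.get_default_rl_module_spec` (see the `bc.py` hunk above) now points at this class for the torch framework, a plain BC config picks it up automatically. A minimal sketch, with a placeholder offline-data path and environment; exact config knobs may differ by Ray version:

from ray.rllib.algorithms.bc import BCConfig

config = (
    BCConfig()
    .environment("CartPole-v1")  # only used to infer the spaces
    .framework("torch")
    .offline_data(input_="/tmp/cartpole-offline")  # placeholder path
)
algo = config.build()  # builds a DefaultBCTorchRLModule under the hood
result = algo.train()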
3 changes: 2 additions & 1 deletion rllib/algorithms/cql/__init__.py
@@ -3,6 +3,7 @@

__all__ = [
    "CQL",
    "CQLTorchPolicy",
    "CQLConfig",
    # @OldAPIStack
    "CQLTorchPolicy",
]
21 changes: 8 additions & 13 deletions rllib/algorithms/dqn/dqn_rainbow_rl_module.py
@@ -67,19 +67,19 @@ def setup(self):
        # If in a dueling setting, set up the value function head.
        self.vf = self.catalog.build_vf_head(framework=self.framework)

    @override(InferenceOnlyAPI)
    def get_non_inference_attributes(self) -> List[str]:
        return ["_target_encoder", "_target_af"] + (
            ["_target_vf"] if self.uses_dueling else []
        )

    @override(TargetNetworkAPI)
    def make_target_networks(self) -> None:
        self._target_encoder = make_target_network(self.encoder)
        self._target_af = make_target_network(self.af)
        if self.uses_dueling:
            self._target_vf = make_target_network(self.vf)

    @override(InferenceOnlyAPI)
    def get_non_inference_attributes(self) -> List[str]:
        return ["_target_encoder", "_target_af"] + (
            ["_target_vf"] if self.uses_dueling else []
        )

    @override(TargetNetworkAPI)
    def get_target_network_pairs(self) -> List[Tuple[NetworkType, NetworkType]]:
        return [(self.encoder, self._target_encoder), (self.af, self._target_af)] + (
@@ -119,11 +119,6 @@ def forward_target(self, batch: Dict[str, Any]) -> Dict[str, Any]:
            ),
        )

    # TODO (simon): DQN Rainbow does not support RNNs, yet.
    @override(RLModule)
    def get_initial_state(self) -> Any:
        return {}

    @override(RLModule)
    def input_specs_exploration(self) -> SpecType:
        return [Columns.OBS]
@@ -180,8 +175,8 @@ def _qf(self, batch: Dict[str, TensorType]) -> Dict[str, TensorType]:
        Results:
            A dictionary containing the Q-value predictions ("qf_preds")
            and in case of distributional Q-learning in addition to the Q-value
            predictions ("qf_preds") the support atoms ("atoms"), the Q-logits
            and in case of distributional Q-learning - in addition to the Q-value
            predictions ("qf_preds") - the support atoms ("atoms"), the Q-logits
            ("qf_logits"), and the probabilities ("qf_probs").
        """
        # If we have a dueling architecture we have to add the value stream.
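For context on the distributional outputs named in the `_qf` docstring above: the expected Q-values are the probability-weighted sum over the support atoms. A small, self-contained sketch (not RLlib code); tensor names mirror the docstring keys:

import torch

batch_size, num_actions, num_atoms = 4, 2, 51
atoms = torch.linspace(-10.0, 10.0, num_atoms)  # support z_i ("atoms")
qf_logits = torch.randn(batch_size, num_actions, num_atoms)  # ("qf_logits")
qf_probs = torch.softmax(qf_logits, dim=-1)  # ("qf_probs")
qf_preds = (qf_probs * atoms).sum(dim=-1)  # ("qf_preds"): Q(s, a) = sum_i z_i * p_i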
3 changes: 2 additions & 1 deletion rllib/algorithms/impala/__init__.py
@@ -11,8 +11,9 @@
from ray.rllib.algorithms.impala.impala_torch_policy import ImpalaTorchPolicy

__all__ = [
    "IMPALAConfig",
    "IMPALA",
    "IMPALAConfig",
    # @OldAPIStack
    "ImpalaTF1Policy",
    "ImpalaTF2Policy",
    "ImpalaTorchPolicy",
2 changes: 1 addition & 1 deletion rllib/algorithms/impala/impala_learner.py
@@ -222,7 +222,7 @@ def __init__(
        *,
        in_queue: queue.Queue,
        out_queue: deque,
        device: torch.device,
        device: "torch.device",
        metrics_logger: MetricsLogger,
    ):
        super().__init__()
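One note on the quoted annotation above (the motivation is inferred, not stated in the commit): when torch is imported lazily, for example through RLlib's `try_import_torch`, an unquoted `torch.device` annotation is evaluated at function-definition time and breaks if torch is unavailable, whereas the string form is a forward reference that is never evaluated on import. A minimal standalone sketch of the pattern:

try:
    import torch
except ImportError:
    torch = None  # Module must still import cleanly without torch installed.


def to_device(x, device: "torch.device"):
    # The quoted annotation isn't evaluated when the function is defined,
    # so this works even when torch is None at import time.
    return x.to(device)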
1 change: 1 addition & 0 deletions rllib/algorithms/marwil/__init__.py
@@ -11,6 +11,7 @@
__all__ = [
    "MARWIL",
    "MARWILConfig",
    # @OldAPIStack
    "MARWILTF1Policy",
    "MARWILTF2Policy",
    "MARWILTorchPolicy",
2 changes: 1 addition & 1 deletion rllib/algorithms/marwil/marwil.py
@@ -495,7 +495,7 @@ class (multi-/single-learner setup) and evaluation on
        self.eval_env_runner_group.sync_weights(
            # Sync weights from learner_group to all EnvRunners.
            from_worker_or_learner_group=self.learner_group,
            policies=modules_to_update,
            policies=list(modules_to_update),
            inference_only=True,
        )
