-
Notifications
You must be signed in to change notification settings - Fork 6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[RLlib] Docs do-over (new API stack): Prep. RLModule; introduce `Defa…
…ult[algo]RLModule` classes (rename from `[algo]RLModule`). (#49366)
- Loading branch information
Showing
42 changed files
with
385 additions
and
316 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,7 @@ | |
__all__ = [ | ||
"APPO", | ||
"APPOConfig", | ||
# @OldAPIStack | ||
"APPOTF1Policy", | ||
"APPOTF2Policy", | ||
"APPOTorchPolicy", | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,46 +1,11 @@ | ||
import abc | ||
from typing import Any, Dict, List, Tuple | ||
|
||
from ray.rllib.algorithms.ppo.ppo_rl_module import PPORLModule | ||
from ray.rllib.algorithms.appo.appo import OLD_ACTION_DIST_LOGITS_KEY | ||
from ray.rllib.core.learner.utils import make_target_network | ||
from ray.rllib.core.models.base import ACTOR | ||
from ray.rllib.core.models.tf.encoder import ENCODER_OUT | ||
from ray.rllib.core.rl_module.apis import TargetNetworkAPI | ||
from ray.rllib.utils.typing import NetworkType | ||
|
||
from ray.rllib.utils.annotations import ( | ||
override, | ||
OverrideToImplementCustomLogic_CallToSuperRecommended, | ||
# Backward compat import. | ||
from ray.rllib.algorithms.appo.default_appo_rl_module import ( # noqa | ||
DefaultAPPORLModule as APPORLModule, | ||
) | ||
from ray.rllib.utils.deprecation import deprecation_warning | ||
|
||
|
||
class APPORLModule(PPORLModule, TargetNetworkAPI, abc.ABC): | ||
@override(TargetNetworkAPI) | ||
def make_target_networks(self): | ||
self._old_encoder = make_target_network(self.encoder) | ||
self._old_pi = make_target_network(self.pi) | ||
|
||
@override(TargetNetworkAPI) | ||
def get_target_network_pairs(self) -> List[Tuple[NetworkType, NetworkType]]: | ||
return [ | ||
(self.encoder, self._old_encoder), | ||
(self.pi, self._old_pi), | ||
] | ||
|
||
@override(TargetNetworkAPI) | ||
def forward_target(self, batch: Dict[str, Any]) -> Dict[str, Any]: | ||
old_pi_inputs_encoded = self._old_encoder(batch)[ENCODER_OUT][ACTOR] | ||
old_action_dist_logits = self._old_pi(old_pi_inputs_encoded) | ||
return {OLD_ACTION_DIST_LOGITS_KEY: old_action_dist_logits} | ||
|
||
@OverrideToImplementCustomLogic_CallToSuperRecommended | ||
@override(PPORLModule) | ||
def get_non_inference_attributes(self) -> List[str]: | ||
# Get the NON inference-only attributes from the parent class | ||
# `PPOTorchRLModule`. | ||
ret = super().get_non_inference_attributes() | ||
# Add the two (APPO) target networks to it (NOT needed in | ||
# inference-only mode). | ||
ret += ["_old_encoder", "_old_pi"] | ||
return ret | ||
deprecation_warning( | ||
old="ray.rllib.algorithms.appo.appo_rl_module.APPORLModule", | ||
new="ray.rllib.algorithms.appo.default_appo_rl_module.DefaultAPPORLModule", | ||
error=False, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import abc | ||
from typing import Any, Dict, List, Tuple | ||
|
||
from ray.rllib.algorithms.ppo.default_ppo_rl_module import DefaultPPORLModule | ||
from ray.rllib.algorithms.appo.appo import OLD_ACTION_DIST_LOGITS_KEY | ||
from ray.rllib.core.learner.utils import make_target_network | ||
from ray.rllib.core.models.base import ACTOR | ||
from ray.rllib.core.models.tf.encoder import ENCODER_OUT | ||
from ray.rllib.core.rl_module.apis import TargetNetworkAPI | ||
from ray.rllib.utils.typing import NetworkType | ||
|
||
from ray.rllib.utils.annotations import ( | ||
override, | ||
OverrideToImplementCustomLogic_CallToSuperRecommended, | ||
) | ||
|
||
|
||
class DefaultAPPORLModule(DefaultPPORLModule, TargetNetworkAPI, abc.ABC): | ||
"""Default RLModule used by APPO, if user does not specify a custom RLModule. | ||
Users who want to train their RLModules with APPO may implement any RLModule | ||
(or TorchRLModule) subclass as long as the custom class also implements the | ||
`ValueFunctionAPI` (see ray.rllib.core.rl_module.apis.value_function_api.py) | ||
and the `TargetNetworkAPI` (see | ||
ray.rllib.core.rl_module.apis.target_network_api.py). | ||
""" | ||
|
||
@override(TargetNetworkAPI) | ||
def make_target_networks(self): | ||
self._old_encoder = make_target_network(self.encoder) | ||
self._old_pi = make_target_network(self.pi) | ||
|
||
@override(TargetNetworkAPI) | ||
def get_target_network_pairs(self) -> List[Tuple[NetworkType, NetworkType]]: | ||
return [ | ||
(self.encoder, self._old_encoder), | ||
(self.pi, self._old_pi), | ||
] | ||
|
||
@override(TargetNetworkAPI) | ||
def forward_target(self, batch: Dict[str, Any]) -> Dict[str, Any]: | ||
old_pi_inputs_encoded = self._old_encoder(batch)[ENCODER_OUT][ACTOR] | ||
old_action_dist_logits = self._old_pi(old_pi_inputs_encoded) | ||
return {OLD_ACTION_DIST_LOGITS_KEY: old_action_dist_logits} | ||
|
||
@OverrideToImplementCustomLogic_CallToSuperRecommended | ||
@override(DefaultPPORLModule) | ||
def get_non_inference_attributes(self) -> List[str]: | ||
# Get the NON inference-only attributes from the parent class | ||
# `PPOTorchRLModule`. | ||
ret = super().get_non_inference_attributes() | ||
# Add the two (APPO) target networks to it (NOT needed in | ||
# inference-only mode). | ||
ret += ["_old_encoder", "_old_pi"] | ||
return ret |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,13 @@ | ||
from ray.rllib.algorithms.appo.appo_rl_module import APPORLModule | ||
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule | ||
# Backward compat import. | ||
from ray.rllib.algorithms.appo.torch.default_appo_torch_rl_module import ( # noqa | ||
DefaultAPPOTorchRLModule as APPOTorchRLModule, | ||
) | ||
from ray.rllib.utils.deprecation import deprecation_warning | ||
|
||
|
||
class APPOTorchRLModule(PPOTorchRLModule, APPORLModule): | ||
pass | ||
deprecation_warning( | ||
old="ray.rllib.algorithms.appo.torch.appo_torch_rl_module.APPOTorchRLModule", | ||
new="ray.rllib.algorithms.appo.torch.default_appo_torch_rl_module." | ||
"DefaultAPPOTorchRLModule", | ||
error=False, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from ray.rllib.algorithms.appo.default_appo_rl_module import DefaultAPPORLModule | ||
from ray.rllib.algorithms.ppo.torch.default_ppo_torch_rl_module import ( | ||
DefaultPPOTorchRLModule, | ||
) | ||
|
||
|
||
class DefaultAPPOTorchRLModule(DefaultPPOTorchRLModule, DefaultAPPORLModule): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
from ray.rllib.algorithms.bc.bc import BCConfig, BC | ||
|
||
__all__ = [ | ||
"BCConfig", | ||
"BC", | ||
"BCConfig", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import abc | ||
from typing import Any, Dict | ||
|
||
from ray.rllib.core.columns import Columns | ||
from ray.rllib.core.models.base import ENCODER_OUT | ||
from ray.rllib.core.rl_module.rl_module import RLModule | ||
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule | ||
from ray.rllib.utils.annotations import override | ||
from ray.util.annotations import DeveloperAPI | ||
|
||
|
||
@DeveloperAPI(stability="alpha") | ||
class DefaultBCTorchRLModule(TorchRLModule, abc.ABC): | ||
"""The default TorchRLModule used, if no custom RLModule is provided. | ||
Builds an encoder net based on the observation space. | ||
Builds a pi head based on the action space. | ||
Passes observations from the input batch through the encoder, then the pi head to | ||
compute action logits. | ||
""" | ||
|
||
@override(RLModule) | ||
def setup(self): | ||
# Build model components (encoder and pi head) from catalog. | ||
super().setup() | ||
self._encoder = self.catalog.build_encoder(framework=self.framework) | ||
self._pi_head = self.catalog.build_pi_head(framework=self.framework) | ||
|
||
@override(TorchRLModule) | ||
def _forward(self, batch: Dict, **kwargs) -> Dict[str, Any]: | ||
"""Generic BC forward pass (for all phases of training/evaluation).""" | ||
# Encoder embeddings. | ||
encoder_outs = self._encoder(batch) | ||
# Action dist inputs. | ||
return { | ||
Columns.ACTION_DIST_INPUTS: self._pi_head(encoder_outs[ENCODER_OUT]), | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,7 @@ | |
|
||
__all__ = [ | ||
"CQL", | ||
"CQLTorchPolicy", | ||
"CQLConfig", | ||
# @OldAPIStack | ||
"CQLTorchPolicy", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.