[RLlib] New API stack: (Multi)RLModule overhaul vol 03 (Introduce generic `_forward` to further simplify the user experience). (ray-project#47889)

Signed-off-by: ujjawal-khare <ujjawal.khare@dream11.com>
sven1977 authored and ujjawal-khare committed Oct 15, 2024
1 parent 34deedd commit 9301790
Showing 14 changed files with 161 additions and 178 deletions.
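
For context, the point of the generic `_forward` is that a custom RLModule can now implement a single forward method that serves inference, exploration, and training alike, as the BC module below does. A minimal hedged sketch of that user-facing pattern; the class name, the toy network, and the assumption that the module exposes `observation_space`/`action_space` attributes are illustrative, not part of this commit:

```python
from typing import Any, Dict

import torch

from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule
from ray.rllib.utils.annotations import override


class TinyDiscretePolicy(TorchRLModule):
    """Hypothetical module: one `_forward` serves inference, exploration, and training."""

    @override(RLModule)
    def setup(self):
        # Assumes a flat Box observation space and a Discrete action space.
        self._pi = torch.nn.Linear(
            self.observation_space.shape[0], self.action_space.n
        )

    @override(RLModule)
    def _forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        # Returning action-distribution inputs satisfies the default output specs.
        return {Columns.ACTION_DIST_INPUTS: self._pi(batch[Columns.OBS])}
```

The default `_forward_inference()`, `_forward_exploration()`, and `_forward_train()` all fall back to this one method, so nothing else needs to be overridden.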
7 changes: 4 additions & 3 deletions rllib/algorithms/bc/torch/bc_torch_rl_module.py
@@ -11,7 +11,7 @@ class BCTorchRLModule(TorchRLModule):
@override(RLModule)
def setup(self):
# __sphinx_doc_begin__
# Build models from catalog.
# Build models from catalog
self.encoder = self.catalog.build_encoder(framework=self.framework)
self.pi = self.catalog.build_pi_head(framework=self.framework)

@@ -20,12 +20,13 @@ def _forward(self, batch: Dict, **kwargs) -> Dict[str, Any]:
"""Generic BC forward pass (for all phases of training/evaluation)."""
output = {}

# Encoder forward pass.
# State encodings.
encoder_outs = self.encoder(batch)
if Columns.STATE_OUT in encoder_outs:
output[Columns.STATE_OUT] = encoder_outs[Columns.STATE_OUT]

# Actions.
output[Columns.ACTION_DIST_INPUTS] = self.pi(encoder_outs[ENCODER_OUT])
action_logits = self.pi(encoder_outs[ENCODER_OUT])
output[Columns.ACTION_DIST_INPUTS] = action_logits

return output
4 changes: 3 additions & 1 deletion rllib/algorithms/marwil/torch/marwil_torch_learner.py
@@ -63,7 +63,9 @@ def possibly_masked_mean(data_):
# Otherwise, compute advantages.
else:
# cumulative_rewards = batch[Columns.ADVANTAGES]
value_fn_out = fwd_out[Columns.VF_PREDS]
value_fn_out = module.compute_values(
batch, embeddings=fwd_out.get(Columns.EMBEDDINGS)
)
advantages = batch[Columns.VALUE_TARGETS] - value_fn_out
advantages_squared_mean = possibly_masked_mean(torch.pow(advantages, 2.0))

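The learner no longer reads `Columns.VF_PREDS` out of the train forward output; it asks the module (via the `ValueFunctionAPI`) for values, passing along any embeddings cached by the forward pass. A simplified, hedged sketch of that learner-side pattern (masking and the MARWIL-specific weighting are omitted; the helper name is illustrative):

```python
import torch

from ray.rllib.core.columns import Columns


def sketch_value_loss(module, batch, fwd_out):
    # Ask the module (ValueFunctionAPI) for value predictions; pass cached
    # embeddings from the train forward pass, if present, to avoid running
    # the encoder a second time.
    value_fn_out = module.compute_values(
        batch, embeddings=fwd_out.get(Columns.EMBEDDINGS)
    )
    advantages = batch[Columns.VALUE_TARGETS] - value_fn_out
    return torch.mean(torch.pow(advantages, 2.0))
```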
5 changes: 1 addition & 4 deletions rllib/algorithms/ppo/ppo_rl_module.py
@@ -71,10 +71,7 @@ def input_specs_train(self) -> SpecDict:

@override(RLModule)
def output_specs_train(self) -> SpecDict:
return [
Columns.VF_PREDS,
Columns.ACTION_DIST_INPUTS,
]
return [Columns.ACTION_DIST_INPUTS]

@OverrideToImplementCustomLogic_CallToSuperRecommended
@override(InferenceOnlyAPI)
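Dropping `Columns.VF_PREDS` here means value predictions are no longer a guaranteed key of `forward_train()` outputs; learners fetch them through `compute_values()` instead (see the MARWIL change above). If a custom module still wanted to emit them during training, it could widen the spec itself; a hedged sketch (class name and behavior are illustrative, not part of this commit):

```python
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.utils.annotations import override


class PPOWithTrainValuePreds(PPOTorchRLModule):
    """Hypothetical subclass that keeps VF_PREDS in its train-pass outputs."""

    @override(RLModule)
    def output_specs_train(self):
        # Widen the spec again to also require value predictions.
        return [Columns.ACTION_DIST_INPUTS, Columns.VF_PREDS]

    @override(RLModule)
    def _forward_train(self, batch, **kwargs):
        out = super()._forward_train(batch, **kwargs)
        # Derive value predictions from the cached critic embeddings.
        out[Columns.VF_PREDS] = self.compute_values(
            batch, embeddings=out.get(Columns.EMBEDDINGS)
        )
        return out
```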
33 changes: 27 additions & 6 deletions rllib/algorithms/ppo/tests/test_ppo_rl_module.py
@@ -21,6 +21,17 @@
torch, nn = try_import_torch()


def get_expected_module_config(env, model_config_dict, observation_space):
config = RLModuleConfig(
observation_space=observation_space,
action_space=env.action_space,
model_config_dict=model_config_dict,
catalog_class=PPOCatalog,
)

return config


def dummy_torch_ppo_loss(module, batch, fwd_out):
adv = batch[Columns.REWARDS] - module.compute_values(batch)
action_dist_class = module.get_train_action_dist_cls()
@@ -34,12 +45,21 @@ def dummy_torch_ppo_loss(module, batch, fwd_out):
return loss


def _get_ppo_module(env, lstm, observation_space):
return PPOTorchRLModule(
observation_space=observation_space,
action_space=env.action_space,
model_config=DefaultModelConfig(use_lstm=lstm),
catalog_class=PPOCatalog,
def dummy_tf_ppo_loss(module, batch, fwd_out):
adv = batch[Columns.REWARDS] - module.compute_values(batch)
action_dist_class = module.get_train_action_dist_cls()
action_probs = action_dist_class.from_logits(
fwd_out[Columns.ACTION_DIST_INPUTS]
).logp(batch[Columns.ACTIONS])
actor_loss = -tf.reduce_mean(action_probs * adv)
critic_loss = tf.reduce_mean(tf.square(adv))
return actor_loss + critic_loss


def _get_ppo_module(framework, env, lstm, observation_space):
model_config_dict = {"use_lstm": lstm}
config = get_expected_module_config(
env, model_config_dict=model_config_dict, observation_space=observation_space
)


@@ -98,6 +118,7 @@ def test_rollouts(self):

def test_forward_train(self):
# TODO: Add FrozenLake-v1 to cover LSTM case.
frameworks = ["torch", "tf2"]
env_names = ["CartPole-v1", "Pendulum-v1", "ALE/Breakout-v5"]
lstm = [False, True]
config_combinations = [env_names, lstm]
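As a rough illustration of how these helpers fit together, here is a hedged sketch of running one train forward pass plus the dummy torch loss above on a CartPole module. The hand-built batch and random rewards are purely illustrative (the real tests roll out episodes instead), and it assumes `module` was built via `_get_ppo_module`:

```python
import gymnasium as gym
import torch

from ray.rllib.core.columns import Columns


def run_dummy_train_step(module, env_name="CartPole-v1", batch_size=8):
    # Build a tiny fake batch, run the train forward pass, and evaluate the
    # dummy PPO loss defined above on the result.
    env = gym.make(env_name)
    obs = torch.stack(
        [
            torch.as_tensor(env.observation_space.sample(), dtype=torch.float32)
            for _ in range(batch_size)
        ]
    )
    batch = {
        Columns.OBS: obs,
        Columns.ACTIONS: torch.as_tensor(
            [env.action_space.sample() for _ in range(batch_size)]
        ),
        Columns.REWARDS: torch.rand(batch_size),
    }
    fwd_out = module.forward_train(batch)
    return dummy_torch_ppo_loss(module, batch, fwd_out)
```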
28 changes: 5 additions & 23 deletions rllib/algorithms/ppo/torch/ppo_torch_rl_module.py
@@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any, Dict, Optional

from ray.rllib.algorithms.ppo.ppo_rl_module import PPORLModule
from ray.rllib.core.columns import Columns
@@ -30,32 +30,14 @@ def _forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
return output

@override(RLModule)
def _forward_exploration(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
return self._forward_inference(batch)

@override(RLModule)
def _forward_train(self, batch: Dict[str, Any]) -> Dict[str, Any]:
if self.config.inference_only:
raise RuntimeError(
"Trying to train a module that is not a learner module. Set the "
"flag `inference_only=False` when building the module."
)
def _forward_train(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
"""Train forward pass (keep features for possible shared value func. call)."""
output = {}

# Shared encoder.
encoder_outs = self.encoder(batch)
output[Columns.EMBEDDINGS] = encoder_outs[ENCODER_OUT][CRITIC]
if Columns.STATE_OUT in encoder_outs:
output[Columns.STATE_OUT] = encoder_outs[Columns.STATE_OUT]

# Value head.
vf_out = self.vf(encoder_outs[ENCODER_OUT][CRITIC])
# Squeeze out last dim (value function node).
output[Columns.VF_PREDS] = vf_out.squeeze(-1)

# Policy head.
action_logits = self.pi(encoder_outs[ENCODER_OUT][ACTOR])
output[Columns.ACTION_DIST_INPUTS] = action_logits

output[Columns.ACTION_DIST_INPUTS] = self.pi(encoder_outs[ENCODER_OUT][ACTOR])
return output

@override(ValueFunctionAPI)
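The slimmed-down `_forward_train` above caches the critic branch of the encoder output under `Columns.EMBEDDINGS` instead of emitting `VF_PREDS` directly; the module's `ValueFunctionAPI` method (collapsed at the end of this hunk) can then reuse those embeddings. A hedged method-body sketch of what such a `compute_values()` typically looks like, not the literal file contents:

```python
from typing import Any, Dict, Optional

from ray.rllib.core.columns import Columns
from ray.rllib.core.models.base import CRITIC, ENCODER_OUT
from ray.rllib.utils.typing import TensorType


def compute_values(
    self, batch: Dict[str, Any], embeddings: Optional[Any] = None
) -> TensorType:
    # Reuse embeddings cached by `_forward_train()`, if given; otherwise run
    # the (critic branch of the) encoder on the batch.
    if embeddings is None:
        embeddings = self.encoder(batch)[ENCODER_OUT][CRITIC]
    # Value head; squeeze out the last (value-node) dimension.
    return self.vf(embeddings).squeeze(-1)
```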
66 changes: 3 additions & 63 deletions rllib/core/rl_module/multi_rl_module.py
@@ -21,7 +21,6 @@
from ray.rllib.core import COMPONENT_MULTI_RL_MODULE_SPEC
from ray.rllib.core.models.specs.typing import SpecType
from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleSpec
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils import force_list
from ray.rllib.utils.annotations import (
override,
@@ -138,6 +137,7 @@ def __init__(
def setup(self):
"""Sets up the underlying, individual RLModules."""
self._rl_modules = {}
self._check_module_configs(self.config.modules)
# Make sure all individual RLModules have the same framework OR framework=None.
framework = None
for module_id, rl_module_spec in self.rl_module_specs.items():
@@ -404,54 +404,8 @@ def __len__(self) -> int:
"""Returns the number of RLModules within this MultiRLModule."""
return len(self._rl_modules)

The underlying single-agent RLModules will check the input specs.
"""
return []
@override(RLModule)
def _forward_train(
self, batch: MultiAgentBatch, **kwargs
) -> Union[Dict[str, Any], Dict[ModuleID, Dict[str, Any]]]:
"""Runs the forward_train pass.

Args:
batch: The batch of multi-agent data (i.e. mapping from module ids to
individual modules' batches).

Returns:
The output of the forward_train pass for the specified modules.
"""
return self._run_forward_pass("forward_train", batch, **kwargs)
@override(RLModule)
def _forward_inference(
self, batch: MultiAgentBatch, **kwargs
) -> Union[Dict[str, Any], Dict[ModuleID, Dict[str, Any]]]:
"""Runs the forward_inference pass.

Args:
batch: The batch of multi-agent data (i.e. mapping from module ids to
individual modules' batches).

Returns:
The output of the forward_inference pass for the specified modules.
"""
return self._run_forward_pass("forward_inference", batch, **kwargs)
@override(RLModule)
def _forward_exploration(
self, batch: MultiAgentBatch, **kwargs
) -> Union[Dict[str, Any], Dict[ModuleID, Dict[str, Any]]]:
"""Runs the forward_exploration pass.

Args:
batch: The batch of multi-agent data (i.e. mapping from module ids to
individual modules' batches).

Returns:
The output of the forward_exploration pass for the specified modules.
"""
return self._run_forward_pass("forward_exploration", batch, **kwargs)
def __repr__(self) -> str:
return f"MARL({pprint.pformat(self._rl_modules)})"

@override(RLModule)
def get_state(
@@ -588,20 +542,6 @@ def _check_module_configs(cls, module_configs: Dict[ModuleID, Any]):
if not isinstance(module_spec, RLModuleSpec):
raise ValueError(f"Module {module_id} is not a RLModuleSpec object.")

@classmethod
def _check_module_specs(cls, rl_module_specs: Dict[ModuleID, RLModuleSpec]):
"""Checks the individual RLModuleSpecs for validity.

Args:
rl_module_specs: Dict mapping ModuleIDs to the respective RLModuleSpec.

Raises:
ValueError: If any RLModuleSpec is invalid.
"""
for module_id, rl_module_spec in rl_module_specs.items():
if not isinstance(rl_module_spec, RLModuleSpec):
raise ValueError(f"Module {module_id} is not a RLModuleSpec object.")
def _check_module_exists(self, module_id: ModuleID) -> None:
if module_id not in self._rl_modules:
raise KeyError(
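The removed per-phase methods above document the multi-agent calling convention: the batch is a dict mapping ModuleIDs to the individual modules' batches, and the output is keyed the same way. A small hedged sketch of that convention from the caller's side (module IDs and the helper function are illustrative):

```python
from ray.rllib.core.columns import Columns


def multi_agent_inference(multi_rl_module, obs_by_module_id):
    """Run one inference pass through a MultiRLModule.

    `obs_by_module_id` maps ModuleIDs (e.g. "traffic_light_policy") to
    observation tensors/arrays for that sub-module.
    """
    multi_agent_batch = {
        module_id: {Columns.OBS: obs}
        for module_id, obs in obs_by_module_id.items()
    }
    # Output is a dict keyed by the same ModuleIDs; each value holds that
    # sub-module's forward outputs (e.g. Columns.ACTION_DIST_INPUTS).
    return multi_rl_module.forward_inference(multi_agent_batch)
```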
109 changes: 69 additions & 40 deletions rllib/core/rl_module/rl_module.py
@@ -5,13 +5,6 @@

import gymnasium as gym

if TYPE_CHECKING:
from ray.rllib.core.rl_module.multi_rl_module import (
MultiRLModule,
MultiRLModuleSpec,
)
from ray.rllib.core.models.catalog import Catalog

from ray.rllib.core import DEFAULT_MODULE_ID
from ray.rllib.core.columns import Columns
from ray.rllib.core.models.specs.typing import SpecType
@@ -35,9 +28,16 @@
serialize_type,
deserialize_type,
)
from ray.rllib.utils.typing import SampleBatchType, StateDict
from ray.rllib.utils.typing import StateDict
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
from ray.rllib.core.rl_module.multi_rl_module import (
MultiRLModule,
MultiRLModuleSpec,
)
from ray.rllib.core.models.catalog import Catalog


@PublicAPI(stability="alpha")
@dataclass
@@ -550,43 +550,26 @@ def get_train_action_dist_cls(self) -> Type[Distribution]:
def _forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
"""Generic forward pass method, used in all phases of training and evaluation.

If you need a more nuanced distinction between forward passes in the different
phases of training and evaluation, override the following methods instead:
For distinct action computation logic w/o exploration, override the
`self._forward_inference()` method.
For distinct action computation logic with exploration, override the
`self._forward_exploration()` method.
For distinct forward pass logic before loss computation, override the
`self._forward_train()` method.

Args:
batch: The input batch.
**kwargs: Additional keyword arguments.

Returns:
The output of the forward pass.
"""

By default, RLlib assumes that the module is non-recurrent if the initial
state is an empty dict and recurrent otherwise.
This behavior can be overridden by implementing this method.
"""
initial_state = self.get_initial_state()
assert isinstance(initial_state, dict), (
"The initial state of an RLModule must be a dict, but is "
f"{type(initial_state)} instead."
)
return bool(initial_state)

@OverrideToImplementCustomLogic_CallToSuperRecommended
def output_specs_inference(self) -> SpecType:
"""Returns the output specs of the `forward_inference()` method.
Override this method to customize the output specs of the inference call.
The default implementation requires the `forward_inference()` method to return
a dict that has `action_dist` key and its value is an instance of
`Distribution`.
"""
return [Columns.ACTION_DIST_INPUTS]

@OverrideToImplementCustomLogic_CallToSuperRecommended
def output_specs_exploration(self) -> SpecType:
"""Returns the output specs of the `forward_exploration()` method.
Override this method to customize the output specs of the exploration call.
The default implementation requires the `forward_exploration()` method to return
a dict that has `action_dist` key and its value is an instance of
`Distribution`.
"""
return [Columns.ACTION_DIST_INPUTS]

def output_specs_train(self) -> SpecType:
"""Returns the output specs of the forward_train method."""
return {}

@check_input_specs("_input_specs_inference")
@check_output_specs("_output_specs_inference")
def forward_inference(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
"""DO NOT OVERRIDE! Forward-pass during evaluation, called from the sampler.
@@ -616,6 +599,8 @@ def _forward_inference(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
"""
return self._forward(batch, **kwargs)

@check_input_specs("_input_specs_exploration")
@check_output_specs("_output_specs_exploration")
def forward_exploration(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
"""DO NOT OVERRIDE! Forward-pass during exploration, called from the sampler.
@@ -645,6 +630,8 @@ def _forward_exploration(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
"""
return self._forward(batch, **kwargs)

@check_input_specs("_input_specs_train")
@check_output_specs("_output_specs_train")
def forward_train(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
"""DO NOT OVERRIDE! Forward-pass during training called from the learner.
@@ -756,6 +743,48 @@ def get_ctor_args_and_kwargs(self):
}, # **kwargs
)

@OverrideToImplementCustomLogic_CallToSuperRecommended
def output_specs_inference(self) -> SpecType:
"""Returns the output specs of the `forward_inference()` method.
Override this method to customize the output specs of the inference call.
The default implementation requires the `forward_inference()` method to return
a dict that has `action_dist` key and its value is an instance of
`Distribution`.
"""
return [Columns.ACTION_DIST_INPUTS]

@OverrideToImplementCustomLogic_CallToSuperRecommended
def output_specs_exploration(self) -> SpecType:
"""Returns the output specs of the `forward_exploration()` method.
Override this method to customize the output specs of the exploration call.
The default implementation requires the `forward_exploration()` method to return
a dict that has `action_dist` key and its value is an instance of
`Distribution`.
"""
return [Columns.ACTION_DIST_INPUTS]

def output_specs_train(self) -> SpecType:
"""Returns the output specs of the forward_train method."""
return {}

def input_specs_inference(self) -> SpecType:
"""Returns the input specs of the forward_inference method."""
return self._default_input_specs()

def input_specs_exploration(self) -> SpecType:
"""Returns the input specs of the forward_exploration method."""
return self._default_input_specs()

def input_specs_train(self) -> SpecType:
"""Returns the input specs of the forward_train method."""
return self._default_input_specs()

def _default_input_specs(self) -> SpecType:
"""Returns the default input specs."""
return [Columns.OBS]

def as_multi_rl_module(self) -> "MultiRLModule":
"""Returns a multi-agent wrapper around this module."""
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule
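The `_forward` docstring in this file points users at `_forward_inference()`, `_forward_exploration()`, and `_forward_train()` when one phase needs distinct behavior. A hedged sketch of that pattern, building on the `TinyDiscretePolicy` sketch near the top of this page (the noise term is purely illustrative):

```python
from typing import Any, Dict

import torch

from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.utils.annotations import override


class NoisyExplorationModule(TinyDiscretePolicy):  # sketch class from above
    """Hypothetical: shared `_forward`, plus extra noise only during exploration."""

    @override(RLModule)
    def _forward_exploration(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        out = self._forward(batch, **kwargs)
        # Perturb the action-distribution inputs during exploration only;
        # inference and training keep the plain `_forward` outputs.
        out[Columns.ACTION_DIST_INPUTS] = (
            out[Columns.ACTION_DIST_INPUTS]
            + 0.1 * torch.randn_like(out[Columns.ACTION_DIST_INPUTS])
        )
        return out
```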