
[RLlib] Examples folder cleanup: ModelV2 -> RLModule wrapper for migrating to new API stack (by config). #47427

Merged
7 changes: 7 additions & 0 deletions rllib/BUILD
@@ -3501,6 +3501,13 @@ py_test(
srcs = ["examples/rl_modules/classes/mobilenet_rlm.py"],
)

py_test(
name = "examples/rl_modules/migrate_modelv2_to_new_api_stack_by_config",
main = "examples/rl_modules/migrate_modelv2_to_new_api_stack_by_config.py",
tags = ["team:rllib", "examples"],
size = "large",
srcs = ["examples/rl_modules/migrate_modelv2_to_new_api_stack_by_config.py"],
)
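# Note: assuming a working Bazel setup at the Ray repo root, this target can
# presumably be run via:
# bazel test //rllib:examples/rl_modules/migrate_modelv2_to_new_api_stack_by_config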
py_test(
name = "examples/rl_modules/migrate_modelv2_to_new_api_stack_by_policy_checkpoint",
main = "examples/rl_modules/migrate_modelv2_to_new_api_stack_by_policy_checkpoint.py",
177 changes: 149 additions & 28 deletions rllib/examples/rl_modules/classes/modelv2_to_rlm.py
@@ -1,6 +1,8 @@
import pathlib
from typing import Any, Dict

from ray.rllib.core.columns import Columns
import tree
from ray.rllib.core import Columns, DEFAULT_POLICY_ID
from ray.rllib.core.rl_module.apis import ValueFunctionAPI
from ray.rllib.core.rl_module.torch import TorchRLModule
from ray.rllib.models.torch.torch_distributions import (
@@ -19,30 +21,95 @@
)
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch

torch, _ = try_import_torch()


class ModelV2ToRLModule(TorchRLModule, ValueFunctionAPI):
"""An RLModule containing a (old stack) ModelV2, provided by a policy checkpoint."""
"""An RLModule containing a (old stack) ModelV2.

The `ModelV2` may be define either through
- an existing Policy checkpoint
- an existing Algorithm checkpoint (and a policy ID or "default_policy")
- or through an AlgorithmConfig object

The ModelV2 is created in the `setup` and contines to live through the lifetime
of the RLModule.
"""

@override(TorchRLModule)
def setup(self):
super().setup()

# Get the policy checkpoint from the `model_config_dict`.
policy_checkpoint_dir = self.config.model_config_dict.get(
"policy_checkpoint_dir"
)
if policy_checkpoint_dir is None:
raise ValueError(
"The `model_config_dict` of your RLModule must contain a "
"`policy_checkpoint_dir` key pointing to the policy checkpoint "
"directory! You can find this dir under the Algorithm's checkpoint dir "
"in subdirectory: [algo checkpoint dir]/policies/[policy ID, e.g. "
"`default_policy`]."
# Try extracting the policy ID from this RLModule's config dict.
policy_id = self.config.model_config_dict.get("policy_id", DEFAULT_POLICY_ID)

# Try getting the algorithm checkpoint from the `model_config_dict`.
algo_checkpoint_dir = self.config.model_config_dict.get("algo_checkpoint_dir")
if algo_checkpoint_dir:
algo_checkpoint_dir = pathlib.Path(algo_checkpoint_dir)
if not algo_checkpoint_dir.is_dir():
raise ValueError(
"The `model_config_dict` of your RLModule must contain a "
"`algo_checkpoint_dir` key pointing to the algo checkpoint "
"directory! You can find this dir inside the results dir of your "
"experiment. You can then add this path "
"through `config.rl_module(model_config_dict={"
"'algo_checkpoint_dir': [your algo checkpoint dir]})`."
)
policy_checkpoint_dir = algo_checkpoint_dir / "policies" / policy_id
# Try getting the policy checkpoint from the `model_config_dict`.
else:
policy_checkpoint_dir = self.config.model_config_dict.get(
"policy_checkpoint_dir"
)

# Create the ModelV2 from the Policy.
if policy_checkpoint_dir:
policy_checkpoint_dir = pathlib.Path(policy_checkpoint_dir)
if not policy_checkpoint_dir.is_dir():
raise ValueError(
"The `model_config_dict` of your RLModule must contain a "
"`policy_checkpoint_dir` key pointing to the policy checkpoint "
"directory! You can find this dir under the Algorithm's checkpoint "
"dir in subdirectory: [algo checkpoint dir]/policies/[policy ID "
"e.g. `default_policy`]. You can then add this path through `config"
".rl_module(model_config_dict={'policy_checkpoint_dir': "
"[your policy checkpoint dir]})`."
)
# Create a temporary policy object.
policy = TorchPolicyV2.from_checkpoint(policy_checkpoint_dir)
# Create the ModelV2 from scratch using the config.
else:
config = self.config.model_config_dict.get("old_api_stack_algo_config")
if not config:
raise ValueError(
"The `model_config_dict` of your RLModule must contain a "
"`algo_config` key with a AlgorithmConfig object in it that "
"contains all the settings that would be necessary to construct a "
"old API stack Algorithm/Policy/ModelV2! You can add this setting "
"through `config.rl_module(model_config_dict={'algo_config': "
"[your old config]})`."
)
# Get the multi-agent policies dict.
policy_dict, _ = config.get_multi_agent_setup(
spaces={
policy_id: (
self.config.observation_space,
self.config.action_space,
),
},
default_policy_class=config.algo_class.get_default_policy_class(config),
)
config = config.to_dict()
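# The `__policy_id` entry is the (internal) key the old API stack Policy reads
# to learn its own ID (relevant, e.g., in multi-agent setups).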
config["__policy_id"] = policy_id
policy = policy_dict[policy_id].policy_class(
self.config.observation_space,
self.config.action_space,
config,
)

# Create a temporary policy object.
policy = TorchPolicyV2.from_checkpoint(policy_checkpoint_dir)
self._model_v2 = policy.model

# Translate the action dist classes from the old API stack to the new.
@@ -51,42 +118,96 @@ def setup(self):
# Erase the torch policy from memory, so it can be garbage collected.
del policy

@override(TorchRLModule)
def _forward_inference(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
nn_output, state_out = self._model_v2(batch)
# Interpret the NN output as action logits.
output = {Columns.ACTION_DIST_INPUTS: nn_output}
# Add the `state_out` to the `output`, new API stack style.
if state_out:
output[Columns.STATE_OUT] = {}
for i, o in enumerate(state_out):
output[Columns.STATE_OUT][i] = o

return output
return self._forward_pass(batch, inference=True)

@override(TorchRLModule)
def _forward_exploration(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
return self._forward_inference(batch, **kwargs)

@override(TorchRLModule)
def _forward_train(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
out = self._forward_inference(batch, **kwargs)
out = self._forward_pass(batch, inference=False)
out[Columns.ACTION_LOGP] = self._action_dist_class(
out[Columns.ACTION_DIST_INPUTS]
).logp(batch[Columns.ACTIONS])
out[Columns.VF_PREDS] = self._model_v2.value_function()
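# For stateful models, `value_function()` returns a flat [B*T] tensor (the
# batch was folded for the ModelV2 call in `_forward_pass`); reshape it back
# to [B, T] to match the rest of the train batch.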
if Columns.STATE_IN in batch and Columns.SEQ_LENS in batch:
out[Columns.VF_PREDS] = torch.reshape(
out[Columns.VF_PREDS], [len(batch[Columns.SEQ_LENS]), -1]
)
return out

def _forward_pass(self, batch, inference=True):
# Translate states and seq_lens into old API stack formats.
batch = batch.copy()
state_in = batch.pop(Columns.STATE_IN, {})
state_in = [s for i, s in sorted(state_in.items())]
seq_lens = batch.pop(Columns.SEQ_LENS, None)

if state_in:
Collaborator comment: Ah yes, now LSTM is possible :)

if inference and seq_lens is None:
seq_lens = torch.tensor(
[1.0] * state_in[0].shape[0], device=state_in[0].device
)
elif not inference:
assert seq_lens is not None
# Perform the actual ModelV2 forward pass.
# A recurrent ModelV2 adds and removes the time-rank itself (whereas in the
# new API stack, the connector pipelines are responsible for doing this) ->
# We have to remove, then re-add the time rank here to make ModelV2 work.
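# E.g., an obs tensor of shape [B, T, obs_dim] is flattened to [B*T, obs_dim]
# here, and the NN output is expanded back to a time-ranked shape below.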
batch = tree.map_structure(
lambda s: torch.reshape(s, [-1] + list(s.shape[2:])), batch
)
nn_output, state_out = self._model_v2(batch, state_in, seq_lens)
# Put the (single-timestep) time rank back into the NN output (inference).
if state_in:
if inference:
nn_output = tree.map_structure(
lambda s: torch.unsqueeze(s, axis=1), nn_output
)
else:
nn_output = tree.map_structure(
lambda s: torch.reshape(s, [len(seq_lens), -1] + list(s.shape[1:])),
nn_output,
)
# Interpret the NN output as action logits.
output = {Columns.ACTION_DIST_INPUTS: nn_output}
# Add the `state_out` to the `output`, new API stack style.
if state_out:
output[Columns.STATE_OUT] = {}
for i, o in enumerate(state_out):
output[Columns.STATE_OUT][i] = o

return output

@override(ValueFunctionAPI)
def compute_values(self, batch: Dict[str, Any]):
self._model_v2(batch)
return self._model_v2.value_function()
self._forward_pass(batch, inference=False)
v_preds = self._model_v2.value_function()
if Columns.STATE_IN in batch and Columns.SEQ_LENS in batch:
v_preds = torch.reshape(v_preds, [len(batch[Columns.SEQ_LENS]), -1])
return v_preds

@override(TorchRLModule)
def get_inference_action_dist_cls(self):
return self._action_dist_class

@override(TorchRLModule)
def get_exploration_action_dist_cls(self):
return self._action_dist_class

@override(TorchRLModule)
def get_train_action_dist_cls(self):
return self._action_dist_class

@override(TorchRLModule)
def get_initial_state(self):
"""Converts the initial state list of ModelV2 into a dict (new API stack)."""
init_state_list = self._model_v2.get_initial_state()
return {i: s for i, s in enumerate(init_state_list)}

def _translate_dist_class(self, old_dist_class):
map_ = {
OldTorchCategorical: TorchCategorical,
72 changes: 72 additions & 0 deletions rllib/examples/rl_modules/migrate_modelv2_to_new_api_stack_by_config.py
@@ -0,0 +1,72 @@
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core import DEFAULT_POLICY_ID
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.examples.rl_modules.classes.modelv2_to_rlm import ModelV2ToRLModule
from ray.rllib.utils.metrics import (
ENV_RUNNER_RESULTS,
EPISODE_RETURN_MEAN,
)


if __name__ == "__main__":
# Configure an old stack default ModelV2.
config_old_stack = (
PPOConfig()
.environment("CartPole-v1")
.training(
lr=0.0003,
num_sgd_iter=6,
vf_loss_coeff=0.01,
# Change the ModelV2 settings a bit.
model={
"fcnet_hiddens": [32],
"fcnet_activation": "linear",
"use_lstm": True,
"vf_share_layers": True,
},
)
)

# Training with the (configured and wrapped) ModelV2.

# We change the original (old API stack) `config` into a new API stack one:
config_new_stack = (
config_old_stack.copy(copy_frozen=False)
.api_stack(
enable_rl_module_and_learner=True,
enable_env_runner_and_connector_v2=True,
)
.rl_module(
rl_module_spec=RLModuleSpec(
module_class=ModelV2ToRLModule,
model_config_dict={
"policy_id": DEFAULT_POLICY_ID,
"old_api_stack_algo_config": config_old_stack,
},
),
)
)

# Build the new stack algo.
algo_new_stack = config_new_stack.build()

# Train until the target episode return is reached.
min_return_new_stack = 350.0
results = None
passed = False
for i in range(100):
results = algo_new_stack.train()
print(results)
if results[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN] >= min_return_new_stack:
print(
f"Reached episode return of {min_return_new_stack} -> stopping "
"new API stack training."
)
passed = True
break

if not passed:
raise ValueError(
"Continuing training on the new stack did not succeed! Last return: "
f"{results[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN]}"
)
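# A possible follow-up (a sketch, not part of this script): persist the
# migrated module as a new API stack checkpoint for later reuse via
# `algo_new_stack.save(...)`, then release resources with
# `algo_new_stack.stop()`.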