[RLlib] Examples folder cleanup: ModelV2 -> RLModule wrapper for migrating to new API stack. #47425

Merged (7 commits) on Aug 30, 2024
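Overview: the changes below add a ModelV2ToRLModule wrapper class (rllib/examples/rl_modules/classes/modelv2_to_rlm.py), an end-to-end migration example script, and the matching BUILD entries. A condensed usage sketch of the wrapper follows (the full, runnable version is the example script added in this diff); `policy_path` is a placeholder for an existing old-API-stack policy checkpoint directory, i.e. [algo checkpoint dir]/policies/default_policy:

# Condensed sketch only; see the full example script added further below.
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.examples.rl_modules.classes.modelv2_to_rlm import ModelV2ToRLModule

# Placeholder: point this at an existing old-API-stack policy checkpoint dir.
policy_path = "/tmp/my_old_stack_checkpoint/policies/default_policy"

config = (
    PPOConfig()
    # Switch to the new API stack.
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .environment("CartPole-v1")
    # Plug the checkpointed (old stack) ModelV2 into the new stack via the
    # wrapper class added in this PR.
    .rl_module(
        rl_module_spec=RLModuleSpec(
            module_class=ModelV2ToRLModule,
            model_config_dict={"policy_checkpoint_dir": policy_path},
        ),
    )
)
algo = config.build()
print(algo.train())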
51 changes: 23 additions & 28 deletions rllib/BUILD
@@ -3459,20 +3459,20 @@ py_test(
# subdirectory: rl_modules/
# ....................................
py_test(
name = "examples/rl_modules/action_masking_rlm",
main = "examples/rl_modules/action_masking_rlm.py",
name = "examples/rl_modules/action_masking_rl_module",
main = "examples/rl_modules/action_masking_rl_module.py",
tags = ["team:rllib", "examples"],
size = "medium",
srcs = ["examples/rl_modules/action_masking_rlm.py"],
srcs = ["examples/rl_modules/action_masking_rl_module.py"],
args = ["--enable-new-api-stack", "--stop-iters=5"],
)

py_test(
name = "examples/rl_modules/autoregressive_actions_rlm",
main = "examples/rl_modules/autoregressive_actions_rlm.py",
name = "examples/rl_modules/autoregressive_actions_rl_module",
main = "examples/rl_modules/autoregressive_actions_rl_module.py",
tags = ["team:rllib", "examples"],
size = "medium",
srcs = ["examples/rl_modules/autoregressive_actions_rlm.py"],
srcs = ["examples/rl_modules/autoregressive_actions_rl_module.py"],
args = ["--enable-new-api-stack"],
)
py_test(
@@ -3501,6 +3501,13 @@ py_test(
srcs = ["examples/rl_modules/classes/mobilenet_rlm.py"],
)

+ py_test(
+ name = "examples/rl_modules/migrate_modelv2_to_new_api_stack_by_policy_checkpoint",
+ main = "examples/rl_modules/migrate_modelv2_to_new_api_stack_by_policy_checkpoint.py",
+ tags = ["team:rllib", "examples"],
+ size = "large",
+ srcs = ["examples/rl_modules/migrate_modelv2_to_new_api_stack_by_policy_checkpoint.py"],
+ )
py_test(
name = "examples/rl_modules/pretraining_single_agent_training_multi_agent",
main = "examples/rl_modules/pretraining_single_agent_training_multi_agent.py",
@@ -3510,6 +3517,7 @@ py_test(
args = ["--enable-new-api-stack", "--num-agents=2", "--stop-iters-pretraining=5", "--stop-iters=20", "--stop-reward=150.0"],
)

+ #@OldAPIStack
py_test(
name = "examples/autoregressive_action_dist_tf",
main = "examples/autoregressive_action_dist.py",
@@ -3519,6 +3527,7 @@ py_test(
args = ["--as-test", "--framework=tf", "--stop-reward=150", "--num-cpus=4"]
)

+ #@OldAPIStack
py_test(
name = "examples/autoregressive_action_dist_torch",
main = "examples/autoregressive_action_dist.py",
@@ -3528,6 +3537,7 @@ py_test(
args = ["--as-test", "--framework=torch", "--stop-reward=150", "--num-cpus=4"]
)

+ #@OldAPIStack
py_test(
name = "examples/cartpole_lstm_impala_tf2",
main = "examples/cartpole_lstm.py",
@@ -3537,6 +3547,7 @@ py_test(
args = ["--run=IMPALA", "--as-test", "--framework=tf2", "--stop-reward=28", "--num-cpus=4"]
)

+ #@OldAPIStack
py_test(
name = "examples/cartpole_lstm_impala_torch",
main = "examples/cartpole_lstm.py",
@@ -3546,6 +3557,7 @@ py_test(
args = ["--run=IMPALA", "--as-test", "--framework=torch", "--stop-reward=28", "--num-cpus=4"]
)

+ #@OldAPIStack
py_test(
name = "examples/cartpole_lstm_ppo_tf2",
main = "examples/cartpole_lstm.py",
@@ -3555,6 +3567,7 @@ py_test(
args = ["--run=PPO", "--as-test", "--framework=tf2", "--stop-reward=28", "--num-cpus=4"]
)

+ #@OldAPIStack
py_test(
name = "examples/cartpole_lstm_ppo_torch",
main = "examples/cartpole_lstm.py",
@@ -3564,6 +3577,7 @@ py_test(
args = ["--run=PPO", "--as-test", "--framework=torch", "--stop-reward=28", "--num-cpus=4"]
)

+ #@OldAPIStack
py_test(
name = "examples/cartpole_lstm_ppo_torch_with_prev_a_and_r",
main = "examples/cartpole_lstm.py",
@@ -3613,6 +3627,7 @@ py_test(
args = ["--as-test", "--framework=torch", "--stop-reward=6.0"]
)

+ #@OldAPIStack
py_test(
name = "examples/metrics/custom_metrics_and_callbacks",
main = "examples/metrics/custom_metrics_and_callbacks.py",
@@ -3622,6 +3637,7 @@ py_test(
args = ["--stop-iters=2"]
)

+ #@OldAPIStack
py_test(
name = "examples/custom_model_loss_and_metrics_ppo_tf",
main = "examples/custom_model_loss_and_metrics.py",
@@ -3633,6 +3649,7 @@ py_test(
args = ["--run=PPO", "--stop-iters=1", "--framework=tf","--input-files=tests/data/cartpole"]
)

+ #@OldAPIStack
py_test(
name = "examples/custom_model_loss_and_metrics_ppo_torch",
main = "examples/custom_model_loss_and_metrics.py",
@@ -3644,28 +3661,6 @@ py_test(
args = ["--run=PPO", "--framework=torch", "--stop-iters=1", "--input-files=tests/data/cartpole"]
)

- py_test(
- name = "examples/custom_model_loss_and_metrics_pg_tf",
- main = "examples/custom_model_loss_and_metrics.py",
- tags = ["team:rllib", "exclusive", "examples"],
- size = "small",
- # Include the json data file.
- data = ["tests/data/cartpole/small.json"],
- srcs = ["examples/custom_model_loss_and_metrics.py"],
- args = ["--stop-iters=1", "--framework=tf", "--input-files=tests/data/cartpole"]
- )

- py_test(
- name = "examples/custom_model_loss_and_metrics_pg_torch",
- main = "examples/custom_model_loss_and_metrics.py",
- tags = ["team:rllib", "exclusive", "examples"],
- size = "small",
- # Include the json data file.
- data = ["tests/data/cartpole/small.json"],
- srcs = ["examples/custom_model_loss_and_metrics.py"],
- args = ["--framework=torch", "--stop-iters=1", "--input-files=tests/data/cartpole"]
- )

py_test(
name = "examples/custom_recurrent_rnn_tokenizer_repeat_after_me_tf2",
main = "examples/custom_recurrent_rnn_tokenizer.py",
8 changes: 8 additions & 0 deletions rllib/core/rl_module/apis/__init__.py
@@ -0,0 +1,8 @@
from ray.rllib.core.rl_module.apis.target_network_api import TargetNetworkAPI
from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI


__all__ = [
Review comment (Collaborator): Nice!!

"TargetNetworkAPI",
"ValueFunctionAPI",
]
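For context, a minimal sketch of a custom TorchRLModule implementing the newly exported ValueFunctionAPI, following the same pattern as the ModelV2ToRLModule wrapper further below (`compute_values()` returns a value estimate per batch item); class and layer names here are illustrative only:

from typing import Any, Dict

import torch

from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.apis import ValueFunctionAPI
from ray.rllib.core.rl_module.torch import TorchRLModule
from ray.rllib.utils.annotations import override


class TinyValueFunctionRLModule(TorchRLModule, ValueFunctionAPI):
    """Illustrative module: shared encoder with separate policy- and value heads."""

    @override(TorchRLModule)
    def setup(self):
        super().setup()
        obs_dim = self.config.observation_space.shape[0]
        num_actions = self.config.action_space.n
        self._encoder = torch.nn.Linear(obs_dim, 64)
        self._pi_head = torch.nn.Linear(64, num_actions)
        self._vf_head = torch.nn.Linear(64, 1)

    def _forward_inference(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        features = torch.relu(self._encoder(batch[Columns.OBS]))
        return {Columns.ACTION_DIST_INPUTS: self._pi_head(features)}

    def _forward_exploration(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        return self._forward_inference(batch, **kwargs)

    def _forward_train(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        return self._forward_inference(batch, **kwargs)

    # ValueFunctionAPI contract: return value estimates for the given batch.
    def compute_values(self, batch: Dict[str, Any]):
        features = torch.relu(self._encoder(batch[Columns.OBS]))
        return self._vf_head(features).squeeze(-1)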
104 changes: 104 additions & 0 deletions rllib/examples/rl_modules/classes/modelv2_to_rlm.py
@@ -0,0 +1,104 @@
from typing import Any, Dict

from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.apis import ValueFunctionAPI
from ray.rllib.core.rl_module.torch import TorchRLModule
from ray.rllib.models.torch.torch_distributions import (
TorchCategorical,
TorchDiagGaussian,
TorchMultiCategorical,
TorchMultiDistribution,
TorchSquashedGaussian,
)
from ray.rllib.models.torch.torch_action_dist import (
TorchCategorical as OldTorchCategorical,
TorchDiagGaussian as OldTorchDiagGaussian,
TorchMultiActionDistribution as OldTorchMultiActionDistribution,
TorchMultiCategorical as OldTorchMultiCategorical,
TorchSquashedGaussian as OldTorchSquashedGaussian,
)
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
from ray.rllib.utils.annotations import override


class ModelV2ToRLModule(TorchRLModule, ValueFunctionAPI):
"""An RLModule containing a (old stack) ModelV2, provided by a policy checkpoint."""

@override(TorchRLModule)
def setup(self):
super().setup()

# Get the policy checkpoint from the `model_config_dict`.
policy_checkpoint_dir = self.config.model_config_dict.get(
"policy_checkpoint_dir"
)
if policy_checkpoint_dir is None:
raise ValueError(
"The `model_config_dict` of your RLModule must contain a "
"`policy_checkpoint_dir` key pointing to the policy checkpoint "
"directory! You can find this dir under the Algorithm's checkpoint dir "
"in subdirectory: [algo checkpoint dir]/policies/[policy ID, e.g. "
"`default_policy`]."
)

# Create a temporary policy object.
policy = TorchPolicyV2.from_checkpoint(policy_checkpoint_dir)
self._model_v2 = policy.model

# Translate the action dist classes from the old API stack to the new.
self._action_dist_class = self._translate_dist_class(policy.dist_class)

# Erase the torch policy from memory, so it can be garbage collected.
del policy

def _forward_inference(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
nn_output, state_out = self._model_v2(batch)
Review comment (Collaborator): Nice!!! So simple. I wonder if this works fine with any more complicated setup like LSTM or MARWIL
Reply (Contributor, author): Probably not :| but it's just a start ...
Reply (Contributor, author): I have an LSTM example script coming up in another PR (the one that adds the wrapper "from config").

# Interpret the NN output as action logits.
output = {Columns.ACTION_DIST_INPUTS: nn_output}
# Add the `state_out` to the `output`, new API stack style.
if state_out:
output[Columns.STATE_OUT] = {}
for i, o in enumerate(state_out):
output[Columns.STATE_OUT][i] = o

return output

def _forward_exploration(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
return self._forward_inference(batch, **kwargs)

def _forward_train(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
out = self._forward_inference(batch, **kwargs)
out[Columns.ACTION_LOGP] = self._action_dist_class(
out[Columns.ACTION_DIST_INPUTS]
).logp(batch[Columns.ACTIONS])
out[Columns.VF_PREDS] = self._model_v2.value_function()
return out

def compute_values(self, batch: Dict[str, Any]):
self._model_v2(batch)
return self._model_v2.value_function()

def get_inference_action_dist_cls(self):
return self._action_dist_class

def get_exploration_action_dist_cls(self):
return self._action_dist_class

def get_train_action_dist_cls(self):
return self._action_dist_class

def _translate_dist_class(self, old_dist_class):
map_ = {
OldTorchCategorical: TorchCategorical,
OldTorchDiagGaussian: TorchDiagGaussian,
OldTorchMultiActionDistribution: TorchMultiDistribution,
OldTorchMultiCategorical: TorchMultiCategorical,
OldTorchSquashedGaussian: TorchSquashedGaussian,
}
if old_dist_class not in map_:
raise ValueError(
f"ModelV2ToRLModule does NOT support {old_dist_class} action "
f"distributions yet!"
)

return map_[old_dist_class]
120 changes: 120 additions & 0 deletions rllib/examples/rl_modules/migrate_modelv2_to_new_api_stack_by_policy_checkpoint.py
@@ -0,0 +1,120 @@
import pathlib

import gymnasium as gym
import numpy as np
import torch

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.rl_module import RLModuleConfig, RLModuleSpec
from ray.rllib.examples.rl_modules.classes.modelv2_to_rlm import ModelV2ToRLModule
from ray.rllib.utils.metrics import (
ENV_RUNNER_RESULTS,
EPISODE_RETURN_MEAN,
)
from ray.rllib.utils.spaces.space_utils import batch


if __name__ == "__main__":
# Configure and train an old stack default ModelV2.
config = (
PPOConfig()
# Old API stack.
.api_stack(
enable_env_runner_and_connector_v2=False,
enable_rl_module_and_learner=False,
)
.environment("CartPole-v1")
.training(
lr=0.0003,
num_sgd_iter=6,
vf_loss_coeff=0.01,
)
)
algo_old_stack = config.build()

min_return_old_stack = 100.0
while True:
results = algo_old_stack.train()
print(results)
if results[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN] >= min_return_old_stack:
print(
f"Reached episode return of {min_return_old_stack} -> stopping "
"old API stack training."
)
break

checkpoint = algo_old_stack.save()
policy_path = (
pathlib.Path(checkpoint.checkpoint.path) / "policies" / "default_policy"
)
assert policy_path.is_dir()
algo_old_stack.stop()

print("done")

# Move the old API stack (trained) ModelV2 into the new API stack's RLModule.
# Run a simple CartPole inference experiment.
env = gym.make("CartPole-v1", render_mode="human")
rl_module = ModelV2ToRLModule(
config=RLModuleConfig(
observation_space=env.observation_space,
action_space=env.action_space,
model_config_dict={"policy_checkpoint_dir": policy_path},
),
)

obs, _ = env.reset()
env.render()
done = False
episode_return = 0.0
while not done:
output = rl_module.forward_inference({"obs": torch.from_numpy(batch([obs]))})
action_logits = output["action_dist_inputs"].detach().numpy()[0]
action = np.argmax(action_logits)
obs, reward, terminated, truncated, _ = env.step(action)
done = terminated or truncated
episode_return += reward
env.render()

print(f"Ran episode with trained ModelV2: return={episode_return}")

# Continue training with the (checkpointed) ModelV2.

# We change the original (old API stack) `config` into a new API stack one:
config = config.api_stack(
enable_rl_module_and_learner=True,
enable_env_runner_and_connector_v2=True,
).rl_module(
rl_module_spec=RLModuleSpec(
module_class=ModelV2ToRLModule,
model_config_dict={"policy_checkpoint_dir": policy_path},
),
)

# Build the new stack algo.
algo_new_stack = config.build()

# Train until a higher return.
min_return_new_stack = 450.0
passed = False
for i in range(50):
results = algo_new_stack.train()
print(results)
# Make sure that the model's weights from the old API stack training
# were properly transferred to the new API RLModule wrapper. Thus, even
# after only one iteration of new stack training, we already expect the
# return to be higher than it was at the end of the old stack training.
assert results[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN] >= min_return_old_stack
if results[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN] >= min_return_new_stack:
print(
f"Reached episode return of {min_return_new_stack} -> stopping "
"new API stack training."
)
passed = True
break

if not passed:
raise ValueError(
"Continuing training on the new stack did not succeed! Last return: "
f"{results[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN]}"
)
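A possible follow-up once the new-stack training loop above passes (sketch only, not part of the example script): persist the algorithm that now trains the wrapped ModelV2 via standard Algorithm checkpointing, and restore it later.

from ray.rllib.algorithms.algorithm import Algorithm

# Sketch: assumes `algo_new_stack` from the script above is still alive.
checkpoint = algo_new_stack.save()
algo_new_stack.stop()

# The restored algorithm carries the ModelV2ToRLModule wrapper as its RLModule,
# so training (or evaluation) can simply continue from here.
restored_algo = Algorithm.from_checkpoint(checkpoint.checkpoint.path)
print(restored_algo.train())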
2 changes: 0 additions & 2 deletions rllib/tuned_examples/bc/cartpole_recording.py
@@ -31,11 +31,9 @@
}
)
.training(
- gamma=0.99,
Review comment (Collaborator): Ah thanks. Forgot these ...

lr=0.0003,
num_sgd_iter=6,
vf_loss_coeff=0.01,
- use_kl_loss=True,
)
.evaluation(
evaluation_num_env_runners=1,