[RLlib] Examples folder cleanup: ModelV2 -> RLModule wrapper for migrating to new API stack. #47425
@@ -0,0 +1,8 @@
from ray.rllib.core.rl_module.apis.target_network_api import TargetNetworkAPI
from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI


__all__ = [
    "TargetNetworkAPI",
    "ValueFunctionAPI",
]
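Aside (not part of the diff): the `ValueFunctionAPI` exported here is what lets other components discover whether a module can produce value estimates. A minimal, hypothetical usage sketch of such a capability check, assuming a module that implements `compute_values()` the way the wrapper below does:

# Hypothetical usage sketch (not from this PR): check a module's capabilities
# before asking it for value-function predictions.
from ray.rllib.core.rl_module.apis import ValueFunctionAPI

def maybe_compute_values(rl_module, batch):
    # Only modules that opted into the ValueFunctionAPI expose `compute_values()`.
    if isinstance(rl_module, ValueFunctionAPI):
        return rl_module.compute_values(batch)
    return None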
@@ -0,0 +1,104 @@
from typing import Any, Dict

from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.apis import ValueFunctionAPI
from ray.rllib.core.rl_module.torch import TorchRLModule
from ray.rllib.models.torch.torch_distributions import (
    TorchCategorical,
    TorchDiagGaussian,
    TorchMultiCategorical,
    TorchMultiDistribution,
    TorchSquashedGaussian,
)
from ray.rllib.models.torch.torch_action_dist import (
    TorchCategorical as OldTorchCategorical,
    TorchDiagGaussian as OldTorchDiagGaussian,
    TorchMultiActionDistribution as OldTorchMultiActionDistribution,
    TorchMultiCategorical as OldTorchMultiCategorical,
    TorchSquashedGaussian as OldTorchSquashedGaussian,
)
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
from ray.rllib.utils.annotations import override


class ModelV2ToRLModule(TorchRLModule, ValueFunctionAPI):
    """An RLModule containing an (old stack) ModelV2, provided by a policy checkpoint."""

    @override(TorchRLModule)
    def setup(self):
        super().setup()

        # Get the policy checkpoint from the `model_config_dict`.
        policy_checkpoint_dir = self.config.model_config_dict.get(
            "policy_checkpoint_dir"
        )
        if policy_checkpoint_dir is None:
            raise ValueError(
                "The `model_config_dict` of your RLModule must contain a "
                "`policy_checkpoint_dir` key pointing to the policy checkpoint "
                "directory! You can find this dir under the Algorithm's checkpoint dir "
                "in subdirectory: [algo checkpoint dir]/policies/[policy ID, e.g. "
                "`default_policy`]."
            )

        # Create a temporary policy object.
        policy = TorchPolicyV2.from_checkpoint(policy_checkpoint_dir)
        self._model_v2 = policy.model

        # Translate the action dist classes from the old API stack to the new.
        self._action_dist_class = self._translate_dist_class(policy.dist_class)

        # Erase the torch policy from memory, so it can be garbage collected.
        del policy

    def _forward_inference(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        nn_output, state_out = self._model_v2(batch)
Review thread on the `self._model_v2(batch)` call:
"Nice!!! So simple. I wonder if this works fine with any more complicated setup like LSTM or MARWIL"
"Probably not :| but it's just a start ..."
"I have a LSTM example script coming up in another PR (the one that adds the wrapper 'from config')."
(A hedged sketch of one possible LSTM-state extension follows this file's code below.)
        # Interpret the NN output as action logits.
        output = {Columns.ACTION_DIST_INPUTS: nn_output}
        # Add the `state_out` to the `output`, new API stack style.
        if state_out:
            output[Columns.STATE_OUT] = {}
            for i, o in enumerate(state_out):
                output[Columns.STATE_OUT][i] = o

        return output

    def _forward_exploration(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        return self._forward_inference(batch, **kwargs)

    def _forward_train(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        out = self._forward_inference(batch, **kwargs)
        out[Columns.ACTION_LOGP] = self._action_dist_class(
            out[Columns.ACTION_DIST_INPUTS]
        ).logp(batch[Columns.ACTIONS])
        out[Columns.VF_PREDS] = self._model_v2.value_function()
        return out

    def compute_values(self, batch: Dict[str, Any]):
        self._model_v2(batch)
        return self._model_v2.value_function()

    def get_inference_action_dist_cls(self):
        return self._action_dist_class

    def get_exploration_action_dist_cls(self):
        return self._action_dist_class

    def get_train_action_dist_cls(self):
        return self._action_dist_class

    def _translate_dist_class(self, old_dist_class):
        map_ = {
            OldTorchCategorical: TorchCategorical,
            OldTorchDiagGaussian: TorchDiagGaussian,
            OldTorchMultiActionDistribution: TorchMultiDistribution,
            OldTorchMultiCategorical: TorchMultiCategorical,
            OldTorchSquashedGaussian: TorchSquashedGaussian,
        }
        if old_dist_class not in map_:
            raise ValueError(
                f"ModelV2ToRLModule does NOT support {old_dist_class} action "
                f"distributions yet!"
            )

        return map_[old_dist_class]
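Following up on the review thread above about LSTM support: below is a rough, untested sketch of how `_forward_inference` could be extended to pass recurrent state through to the wrapped ModelV2, assuming the old stack's `model(input_dict, state, seq_lens)` call convention. This is not part of the PR; the author mentions a proper LSTM example coming in a follow-up PR.

    # Hypothetical extension (not in this PR): thread recurrent state through
    # the wrapped ModelV2 for LSTM-style models.
    def _forward_inference(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        # New API stack batches carry state under Columns.STATE_IN (a dict).
        state_in = list(batch.get(Columns.STATE_IN, {}).values())
        # Old stack ModelV2s are called as (input_dict, state, seq_lens).
        nn_output, state_out = self._model_v2(batch, state_in, batch.get("seq_lens"))
        output = {Columns.ACTION_DIST_INPUTS: nn_output}
        if state_out:
            output[Columns.STATE_OUT] = dict(enumerate(state_out))
        return output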
@@ -0,0 +1,120 @@
import pathlib

import gymnasium as gym
import numpy as np
import torch

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.rl_module import RLModuleConfig, RLModuleSpec
from ray.rllib.examples.rl_modules.classes.modelv2_to_rlm import ModelV2ToRLModule
from ray.rllib.utils.metrics import (
    ENV_RUNNER_RESULTS,
    EPISODE_RETURN_MEAN,
)
from ray.rllib.utils.spaces.space_utils import batch


if __name__ == "__main__":
    # Configure and train an old stack default ModelV2.
    config = (
        PPOConfig()
        # Old API stack.
        .api_stack(
            enable_env_runner_and_connector_v2=False,
            enable_rl_module_and_learner=False,
        )
        .environment("CartPole-v1")
        .training(
            lr=0.0003,
            num_sgd_iter=6,
            vf_loss_coeff=0.01,
        )
    )
    algo_old_stack = config.build()

    min_return_old_stack = 100.0
    while True:
        results = algo_old_stack.train()
        print(results)
        if results[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN] >= min_return_old_stack:
            print(
                f"Reached episode return of {min_return_old_stack} -> stopping "
                "old API stack training."
            )
            break

    checkpoint = algo_old_stack.save()
    policy_path = (
        pathlib.Path(checkpoint.checkpoint.path) / "policies" / "default_policy"
    )
    assert policy_path.is_dir()
    algo_old_stack.stop()

    print("done")

    # Move the old API stack (trained) ModelV2 into the new API stack's RLModule.
    # Run a simple CartPole inference experiment.
    env = gym.make("CartPole-v1", render_mode="human")
    rl_module = ModelV2ToRLModule(
        config=RLModuleConfig(
            observation_space=env.observation_space,
            action_space=env.action_space,
            model_config_dict={"policy_checkpoint_dir": policy_path},
        ),
    )

    obs, _ = env.reset()
    env.render()
    done = False
    episode_return = 0.0
    while not done:
        output = rl_module.forward_inference({"obs": torch.from_numpy(batch([obs]))})
        action_logits = output["action_dist_inputs"].detach().numpy()[0]
        action = np.argmax(action_logits)
        obs, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        episode_return += reward
        env.render()

    print(f"Ran episode with trained ModelV2: return={episode_return}")

    # Continue training with the (checkpointed) ModelV2.

    # We change the original (old API stack) `config` into a new API stack one:
    config = config.api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    ).rl_module(
        rl_module_spec=RLModuleSpec(
            module_class=ModelV2ToRLModule,
            model_config_dict={"policy_checkpoint_dir": policy_path},
        ),
    )

    # Build the new stack algo.
    algo_new_stack = config.build()

    # Train until a higher return.
    min_return_new_stack = 450.0
    passed = False
    for i in range(50):
        results = algo_new_stack.train()
        print(results)
        # Make sure that the model's weights from the old API stack training
        # were properly transferred to the new API RLModule wrapper. Thus, even
        # after only one iteration of new stack training, we already expect the
        # return to be higher than it was at the end of the old stack training.
        assert results[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN] >= min_return_old_stack
        if results[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN] >= min_return_new_stack:
            print(
                f"Reached episode return of {min_return_new_stack} -> stopping "
                "new API stack training."
            )
            passed = True
            break

    if not passed:
        raise ValueError(
            "Continuing training on the new stack did not succeed! Last return: "
            f"{results[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN]}"
        )
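Side note on the inference loop in this example (not part of the diff): it picks actions greedily via `np.argmax`. If one wanted stochastic, exploration-style behavior from the same logits, a minimal sketch using plain `torch.distributions` (nothing RLlib-specific) could look like this:

    # Hypothetical alternative to the greedy argmax above: sample an action from
    # the categorical distribution defined by the policy's logits.
    logits = output["action_dist_inputs"][0]
    action = int(torch.distributions.Categorical(logits=logits).sample())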
@@ -31,11 +31,9 @@
        }
    )
    .training(
        gamma=0.99,
Review comment on the `gamma=0.99` line: "Ah thanks. Forgot these ..."
        lr=0.0003,
        num_sgd_iter=6,
        vf_loss_coeff=0.01,
        use_kl_loss=True,
    )
    .evaluation(
        evaluation_num_env_runners=1,
Review comment: "Nice!!"