[RLlib; new API stack by default] Switch on new API stack by default for SAC and DQN. #47217

Merged
Changes from all commits · 32 commits
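In short: after this change, DQN (and SAC) configs build on the new API stack unless they opt out explicitly, so the doc-code examples below that still rely on old-stack APIs now pin the old stack. A minimal sketch of both directions, based only on the calls that appear in the diffs below (not an authoritative migration guide):

from ray.rllib.algorithms.dqn import DQNConfig

# After this PR, a plain DQN config builds on the new API stack by default.
new_stack_config = DQNConfig().environment("CartPole-v1")

# Doc examples that still use old-stack APIs (Policy, ModelV2, replay_buffer_config
# "type" entries) opt back into the old stack explicitly:
old_stack_config = (
    DQNConfig()
    .api_stack(
        enable_rl_module_and_learner=False,
        enable_env_runner_and_connector_v2=False,
    )
    .environment("CartPole-v1")
)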
4 changes: 2 additions & 2 deletions doc/source/rllib/doc_code/checkpoints.py
@@ -4,12 +4,12 @@
 import tempfile

 from ray.rllib.algorithms.algorithm import Algorithm
-from ray.rllib.algorithms.dqn import DQNConfig
+from ray.rllib.algorithms.ppo import PPOConfig
 from ray.rllib.utils.checkpoints import convert_to_msgpack_checkpoint


 # Base config used for both pickle-based checkpoint and msgpack-based one.
-config = DQNConfig().environment("CartPole-v1")
+config = PPOConfig().environment("CartPole-v1").env_runners(num_env_runners=0)
 # Build algorithm object.
 algo1 = config.build()
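For reference, a sketch of the round trip this doc example performs (pickle-based checkpoint, then msgpack conversion). The save/convert calls are assumed from the imports above; exact signatures can differ slightly between Ray releases.

import tempfile

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.utils.checkpoints import convert_to_msgpack_checkpoint

config = PPOConfig().environment("CartPole-v1").env_runners(num_env_runners=0)
algo1 = config.build()

# Write a regular (pickle-based) checkpoint, then convert it to a msgpack-based
# one, which is meant to be restorable across Ray/Python versions.
pickle_dir = tempfile.mkdtemp()
msgpack_dir = tempfile.mkdtemp()
algo1.save(pickle_dir)
convert_to_msgpack_checkpoint(pickle_dir, msgpack_dir)

algo1.stop()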
36 changes: 30 additions & 6 deletions doc/source/rllib/doc_code/replay_buffer_demo.py
@@ -15,13 +15,31 @@


 # __sphinx_doc_replay_buffer_type_specification__begin__
-config = DQNConfig().training(replay_buffer_config={"type": ReplayBuffer})
+config = (
+    DQNConfig()
+    .api_stack(
+        enable_env_runner_and_connector_v2=False, enable_rl_module_and_learner=False
+    )
+    .training(replay_buffer_config={"type": ReplayBuffer})
+)

-another_config = DQNConfig().training(replay_buffer_config={"type": "ReplayBuffer"})
+another_config = (
+    DQNConfig()
+    .api_stack(
+        enable_env_runner_and_connector_v2=False, enable_rl_module_and_learner=False
+    )
+    .training(replay_buffer_config={"type": "ReplayBuffer"})
+)


-yet_another_config = DQNConfig().training(
-    replay_buffer_config={"type": "ray.rllib.utils.replay_buffers.ReplayBuffer"}
+yet_another_config = (
+    DQNConfig()
+    .api_stack(
+        enable_env_runner_and_connector_v2=False, enable_rl_module_and_learner=False
+    )
+    .training(
+        replay_buffer_config={"type": "ray.rllib.utils.replay_buffers.ReplayBuffer"}
+    )
+)

 validate_buffer_config(config)
@@ -75,13 +93,16 @@ def sample(

 config = (
     DQNConfig()
-    .training(replay_buffer_config={"type": LessSampledReplayBuffer})
+    .api_stack(
+        enable_env_runner_and_connector_v2=False, enable_rl_module_and_learner=False
+    )
     .environment(env="CartPole-v1")
+    .training(replay_buffer_config={"type": LessSampledReplayBuffer})
 )

 tune.Tuner(
     "DQN",
-    param_space=config.to_dict(),
+    param_space=config,
     run_config=air.RunConfig(
         stop={"training_iteration": 1},
     ),
@@ -127,6 +148,9 @@ def sample(
 # __sphinx_doc_replay_buffer_advanced_usage_underlying_buffers__begin__
 config = (
     DQNConfig()
+    .api_stack(
+        enable_env_runner_and_connector_v2=False, enable_rl_module_and_learner=False
+    )
     .training(
         replay_buffer_config={
             "type": "MultiAgentReplayBuffer",
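The pattern pinned to the old API stack above is a custom ReplayBuffer subclass passed through replay_buffer_config. A minimal sketch of that pattern; the class name and the sample() override body here are illustrative, not the exact LessSampledReplayBuffer from the doc file.

from ray.rllib.algorithms.dqn import DQNConfig
from ray.rllib.utils.replay_buffers.replay_buffer import ReplayBuffer


class CappedReplayBuffer(ReplayBuffer):
    # Illustrative override: never return more than 16 items per sample() call.
    def sample(self, num_items, **kwargs):
        return super().sample(min(num_items, 16), **kwargs)


config = (
    DQNConfig()
    # The doc example keeps this on the old API stack (the opt-out added by this PR).
    .api_stack(
        enable_env_runner_and_connector_v2=False, enable_rl_module_and_learner=False
    )
    .environment("CartPole-v1")
    .training(replay_buffer_config={"type": CappedReplayBuffer})
)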
49 changes: 35 additions & 14 deletions doc/source/rllib/doc_code/training.py
@@ -29,23 +29,31 @@
 # __query_action_dist_start__
 # Get a reference to the policy
 import numpy as np
+import torch

 from ray.rllib.algorithms.dqn import DQNConfig

 algo = (
     DQNConfig()
+    .api_stack(
+        enable_rl_module_and_learner=False, enable_env_runner_and_connector_v2=False
+    )
     .environment("CartPole-v1")
-    .framework("tf2")
     .env_runners(num_env_runners=0)
-    .build()
-)
+    .training(
+        replay_buffer_config={
+            "type": "MultiAgentPrioritizedReplayBuffer",
+        }
+    )
+).build()
 # <ray.rllib.algorithms.ppo.PPO object at 0x7fd020186384>

 policy = algo.get_policy()
 # <ray.rllib.policy.eager_tf_policy.PPOTFPolicy_eager object at 0x7fd020165470>

 # Run a forward pass to get model output logits. Note that complex observations
 # must be preprocessed as in the above code block.
-logits, _ = policy.model({"obs": np.array([[0.1, 0.2, 0.3, 0.4]])})
+logits, _ = policy.model({"obs": torch.from_numpy(np.array([[0.1, 0.2, 0.3, 0.4]]))})
 # (<tf.Tensor: id=1274, shape=(1, 2), dtype=float32, numpy=...>, [])

 # Compute action distribution given logits
@@ -57,14 +65,14 @@
 # Query the distribution for samples, sample logps
 dist.sample()
 # <tf.Tensor: id=661, shape=(1,), dtype=int64, numpy=..>
-dist.logp([1])
+dist.logp(torch.tensor([1]))
 # <tf.Tensor: id=1298, shape=(1,), dtype=float32, numpy=...>

 # Get the estimated values for the most recent forward pass
 policy.model.value_function()
 # <tf.Tensor: id=670, shape=(1,), dtype=float32, numpy=...>

-policy.model.base_model.summary()
+print(policy.model)
 """
 Model: "model"
 _____________________________________________________________________
@@ -95,23 +103,36 @@
 # __get_q_values_dqn_start__
 # Get a reference to the model through the policy
 import numpy as np
+import torch

 from ray.rllib.algorithms.dqn import DQNConfig

-algo = DQNConfig().environment("CartPole-v1").framework("tf2").build()
+algo = (
+    DQNConfig()
+    .api_stack(
+        enable_rl_module_and_learner=False, enable_env_runner_and_connector_v2=False
+    )
+    .environment("CartPole-v1")
+    .training(
+        replay_buffer_config={
+            "type": "MultiAgentPrioritizedReplayBuffer",
+        }
+    )
+).build()
 model = algo.get_policy().model
 # <ray.rllib.models.catalog.FullyConnectedNetwork_as_DistributionalQModel ...>

 # List of all model variables
-model.variables()
+list(model.parameters())

 # Run a forward pass to get base model output. Note that complex observations
 # must be preprocessed. An example of preprocessing is
 # examples/offline_rl/saving_experiences.py
-model_out = model({"obs": np.array([[0.1, 0.2, 0.3, 0.4]])})
+model_out = model({"obs": torch.from_numpy(np.array([[0.1, 0.2, 0.3, 0.4]]))})
 # (<tf.Tensor: id=832, shape=(1, 256), dtype=float32, numpy=...)

 # Access the base Keras models (all default models have a base)
-model.base_model.summary()
+print(model)
 """
 Model: "model"
 _______________________________________________________________________
@@ -132,16 +153,16 @@
 """

 # Access the Q value model (specific to DQN)
-print(model.get_q_value_distributions(model_out)[0])
+print(model.get_q_value_distributions(model_out[0])[0])
 # tf.Tensor([[ 0.13023682 -0.36805138]], shape=(1, 2), dtype=float32)
 # ^ exact numbers may differ due to randomness

-model.q_value_head.summary()
+print(model.advantage_module)

 # Access the state value model (specific to DQN)
-print(model.get_state_value(model_out))
+print(model.get_state_value(model_out[0]))
 # tf.Tensor([[0.09381643]], shape=(1, 1), dtype=float32)
 # ^ exact number may differ due to randomness

-model.state_value_head.summary()
+print(model.value_module)
 # __get_q_values_dqn_end__
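For reference, the updated __get_q_values_dqn__ block above, consolidated into one old-stack sketch. It mirrors the diff, except that the observation dtype is made explicit (float32) for the torch model; treat it as illustrative rather than the exact doc file contents.

import numpy as np
import torch

from ray.rllib.algorithms.dqn import DQNConfig

algo = (
    DQNConfig()
    .api_stack(
        enable_rl_module_and_learner=False, enable_env_runner_and_connector_v2=False
    )
    .environment("CartPole-v1")
    .training(replay_buffer_config={"type": "MultiAgentPrioritizedReplayBuffer"})
).build()
model = algo.get_policy().model

# Forward pass through the base torch model. The model call returns an
# (output, state) tuple, which is why the DQN heads below take model_out[0].
obs = torch.from_numpy(np.array([[0.1, 0.2, 0.3, 0.4]], dtype=np.float32))
model_out = model({"obs": obs})

print(model.get_q_value_distributions(model_out[0])[0])  # Q-value distributions.
print(model.get_state_value(model_out[0]))               # State-value estimate.
print(model.advantage_module)                            # Advantage head (torch).
print(model.value_module)                                # Value head (torch).

algo.stop()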