Merge branch 'master' into update-docstring-lstm-encoders
simonsays1980 committed Sep 27, 2024
2 parents fa7ac42 + 6b44557 commit 370dde3
Showing 124 changed files with 1,394 additions and 5,697 deletions.
1 change: 1 addition & 0 deletions doc/BUILD
@@ -467,6 +467,7 @@ doctest(
"source/rllib/rllib-sample-collection.rst",
],
),
data = ["//rllib:cartpole-v1_large"],
tags = ["team:rllib"],
)

4 changes: 2 additions & 2 deletions doc/source/rllib/doc_code/checkpoints.py
@@ -4,12 +4,12 @@
import tempfile

from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.dqn import DQNConfig
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.utils.checkpoints import convert_to_msgpack_checkpoint


# Base config used for both pickle-based checkpoint and msgpack-based one.
config = DQNConfig().environment("CartPole-v1")
config = PPOConfig().environment("CartPole-v1").env_runners(num_env_runners=0)
# Build algorithm object.
algo1 = config.build()

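A minimal usage sketch of what this doc snippet now builds, assuming the APIs imported in the file (PPOConfig, Algorithm, convert_to_msgpack_checkpoint) and that Algorithm.save() accepts a target directory; it is an illustration, not the file's full contents:

import tempfile

from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.utils.checkpoints import convert_to_msgpack_checkpoint

# Same base config as in the diff: single-process PPO on CartPole-v1.
config = PPOConfig().environment("CartPole-v1").env_runners(num_env_runners=0)
algo1 = config.build()

with tempfile.TemporaryDirectory() as pickle_dir, tempfile.TemporaryDirectory() as msgpack_dir:
    # Write a classic (pickle-based) checkpoint ...
    algo1.save(pickle_dir)
    # ... convert it into a msgpack-based checkpoint ...
    convert_to_msgpack_checkpoint(pickle_dir, msgpack_dir)
    # ... and restore from the converted checkpoint to verify it round-trips.
    algo2 = Algorithm.from_checkpoint(msgpack_dir)
    algo1.stop()
    algo2.stop()
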
36 changes: 30 additions & 6 deletions doc/source/rllib/doc_code/replay_buffer_demo.py
@@ -15,13 +15,31 @@


# __sphinx_doc_replay_buffer_type_specification__begin__
config = DQNConfig().training(replay_buffer_config={"type": ReplayBuffer})
config = (
DQNConfig()
.api_stack(
enable_env_runner_and_connector_v2=False, enable_rl_module_and_learner=False
)
.training(replay_buffer_config={"type": ReplayBuffer})
)

another_config = DQNConfig().training(replay_buffer_config={"type": "ReplayBuffer"})
another_config = (
DQNConfig()
.api_stack(
enable_env_runner_and_connector_v2=False, enable_rl_module_and_learner=False
)
.training(replay_buffer_config={"type": "ReplayBuffer"})
)


yet_another_config = DQNConfig().training(
replay_buffer_config={"type": "ray.rllib.utils.replay_buffers.ReplayBuffer"}
yet_another_config = (
DQNConfig()
.api_stack(
enable_env_runner_and_connector_v2=False, enable_rl_module_and_learner=False
)
.training(
replay_buffer_config={"type": "ray.rllib.utils.replay_buffers.ReplayBuffer"}
)
)

validate_buffer_config(config)
@@ -75,13 +93,16 @@ def sample(

config = (
DQNConfig()
.training(replay_buffer_config={"type": LessSampledReplayBuffer})
.api_stack(
enable_env_runner_and_connector_v2=False, enable_rl_module_and_learner=False
)
.environment(env="CartPole-v1")
.training(replay_buffer_config={"type": LessSampledReplayBuffer})
)

tune.Tuner(
"DQN",
param_space=config.to_dict(),
param_space=config,
run_config=air.RunConfig(
stop={"training_iteration": 1},
),
@@ -127,6 +148,9 @@ def sample(
# __sphinx_doc_replay_buffer_advanced_usage_underlying_buffers__begin__
config = (
DQNConfig()
.api_stack(
enable_env_runner_and_connector_v2=False, enable_rl_module_and_learner=False
)
.training(
replay_buffer_config={
"type": "MultiAgentReplayBuffer",
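Taken together, the changes in this demo follow one pattern: replay-buffer customization stays on the old API stack, so each config first disables the new stack via .api_stack(...) and only then sets replay_buffer_config under .training(...). A condensed sketch of that pattern, mirroring the diff (the Tuner call is the same one shown above):

from ray import air, tune
from ray.rllib.algorithms.dqn import DQNConfig
from ray.rllib.utils.replay_buffers import ReplayBuffer

config = (
    DQNConfig()
    # Replay-buffer customization is an old-API-stack feature, so switch
    # the new stack off first, as every snippet in this file now does.
    .api_stack(
        enable_rl_module_and_learner=False,
        enable_env_runner_and_connector_v2=False,
    )
    .environment("CartPole-v1")
    .training(replay_buffer_config={"type": ReplayBuffer})
)

# The Tuner receives the AlgorithmConfig object itself rather than
# config.to_dict(), matching the change above.
tune.Tuner(
    "DQN",
    param_space=config,
    run_config=air.RunConfig(stop={"training_iteration": 1}),
).fit()
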
49 changes: 35 additions & 14 deletions doc/source/rllib/doc_code/training.py
@@ -29,23 +29,31 @@
# __query_action_dist_start__
# Get a reference to the policy
import numpy as np
import torch

from ray.rllib.algorithms.dqn import DQNConfig

algo = (
DQNConfig()
.api_stack(
enable_rl_module_and_learner=False, enable_env_runner_and_connector_v2=False
)
.environment("CartPole-v1")
.framework("tf2")
.env_runners(num_env_runners=0)
.build()
)
.training(
replay_buffer_config={
"type": "MultiAgentPrioritizedReplayBuffer",
}
)
).build()
# <ray.rllib.algorithms.ppo.PPO object at 0x7fd020186384>

policy = algo.get_policy()
# <ray.rllib.policy.eager_tf_policy.PPOTFPolicy_eager object at 0x7fd020165470>

# Run a forward pass to get model output logits. Note that complex observations
# must be preprocessed as in the above code block.
logits, _ = policy.model({"obs": np.array([[0.1, 0.2, 0.3, 0.4]])})
logits, _ = policy.model({"obs": torch.from_numpy(np.array([[0.1, 0.2, 0.3, 0.4]]))})
# (<tf.Tensor: id=1274, shape=(1, 2), dtype=float32, numpy=...>, [])

# Compute action distribution given logits
@@ -57,14 +65,14 @@
# Query the distribution for samples, sample logps
dist.sample()
# <tf.Tensor: id=661, shape=(1,), dtype=int64, numpy=..>
dist.logp([1])
dist.logp(torch.tensor([1]))
# <tf.Tensor: id=1298, shape=(1,), dtype=float32, numpy=...>

# Get the estimated values for the most recent forward pass
policy.model.value_function()
# <tf.Tensor: id=670, shape=(1,), dtype=float32, numpy=...>

policy.model.base_model.summary()
print(policy.model)
"""
Model: "model"
_____________________________________________________________________
@@ -95,23 +103,36 @@
# __get_q_values_dqn_start__
# Get a reference to the model through the policy
import numpy as np
import torch

from ray.rllib.algorithms.dqn import DQNConfig

algo = DQNConfig().environment("CartPole-v1").framework("tf2").build()
algo = (
DQNConfig()
.api_stack(
enable_rl_module_and_learner=False, enable_env_runner_and_connector_v2=False
)
.environment("CartPole-v1")
.training(
replay_buffer_config={
"type": "MultiAgentPrioritizedReplayBuffer",
}
)
).build()
model = algo.get_policy().model
# <ray.rllib.models.catalog.FullyConnectedNetwork_as_DistributionalQModel ...>

# List of all model variables
model.variables()
list(model.parameters())

# Run a forward pass to get base model output. Note that complex observations
# must be preprocessed. An example of preprocessing is
# examples/offline_rl/saving_experiences.py
model_out = model({"obs": np.array([[0.1, 0.2, 0.3, 0.4]])})
model_out = model({"obs": torch.from_numpy(np.array([[0.1, 0.2, 0.3, 0.4]]))})
# (<tf.Tensor: id=832, shape=(1, 256), dtype=float32, numpy=...)

# Access the base Keras models (all default models have a base)
model.base_model.summary()
print(model)
"""
Model: "model"
_______________________________________________________________________
@@ -132,16 +153,16 @@
"""

# Access the Q value model (specific to DQN)
print(model.get_q_value_distributions(model_out)[0])
print(model.get_q_value_distributions(model_out[0])[0])
# tf.Tensor([[ 0.13023682 -0.36805138]], shape=(1, 2), dtype=float32)
# ^ exact numbers may differ due to randomness

model.q_value_head.summary()
print(model.advantage_module)

# Access the state value model (specific to DQN)
print(model.get_state_value(model_out))
print(model.get_state_value(model_out[0]))
# tf.Tensor([[0.09381643]], shape=(1, 1), dtype=float32)
# ^ exact number may differ due to randomness

model.state_value_head.summary()
print(model.value_module)
# __get_q_values_dqn_end__
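Both snippets in this file move from the TF2 framework to the default Torch framework on the old API stack: observations go in as torch tensors, model internals are inspected with print(...) instead of Keras .summary() calls, and the DQN configs gain a prioritized multi-agent replay buffer. A condensed sketch of the resulting DQN flow, assuming the old-stack model methods used in the diff (get_q_value_distributions, get_state_value); the explicit float32 dtype is added here for illustration only:

import numpy as np
import torch

from ray.rllib.algorithms.dqn import DQNConfig

algo = (
    DQNConfig()
    .api_stack(
        enable_rl_module_and_learner=False,
        enable_env_runner_and_connector_v2=False,
    )
    .environment("CartPole-v1")
    .training(replay_buffer_config={"type": "MultiAgentPrioritizedReplayBuffer"})
).build()

model = algo.get_policy().model

# Forward pass with a torch tensor of shape [batch, obs_dim].
obs = torch.from_numpy(np.array([[0.1, 0.2, 0.3, 0.4]], dtype=np.float32))
model_out = model({"obs": obs})

# DQN-specific heads: Q-value distributions and the state-value estimate.
print(model.get_q_value_distributions(model_out[0])[0])
print(model.get_state_value(model_out[0]))
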
2 changes: 2 additions & 0 deletions python/ray/tests/test_minimal_install.py
@@ -35,6 +35,8 @@ def test_correct_python_version():


class MockBaseModel:
model_fields = {}

def __init__(self, *args, **kwargs):
pass

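The mock gains an empty model_fields dict because callers written against pydantic 2 read Model.model_fields as a class attribute (see the python/ray/util/state/common.py change below), and the minimal-install test stubs pydantic out entirely. A small illustration of the attribute access being mocked; the field_names helper is hypothetical:

class MockBaseModel:
    # Pydantic 2 exposes field metadata as a class attribute; mirror it with
    # an empty dict so attribute access on the mock does not fail.
    model_fields = {}

    def __init__(self, *args, **kwargs):
        pass


def field_names(model_cls) -> list:
    # Hypothetical caller: works for real pydantic 2 models and the mock alike.
    return list(model_cls.model_fields.keys())


print(field_names(MockBaseModel))  # -> []
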
7 changes: 5 additions & 2 deletions python/ray/tests/test_state_api.py
@@ -2232,11 +2232,13 @@ def test_list_get_jobs(shutdown_only):
)

def verify():
job_data = list_jobs()[0]
job_data = list_jobs(detail=True)[0]
print(job_data)
job_id_from_api = job_data["submission_id"]
assert job_data["status"] == "SUCCEEDED"
assert job_id == job_id_from_api
assert job_data["start_time"] > 0
assert job_data["end_time"] > 0
return True

wait_for_condition(verify)
@@ -2257,10 +2259,11 @@ def f():
run_string_as_driver(script)

def verify():
jobs = list_jobs(filters=[("type", "=", "DRIVER")])
jobs = list_jobs(filters=[("type", "=", "DRIVER")], detail=True)
assert len(jobs) == 2, "1 test driver + 1 script run above"
for driver_job in jobs:
assert driver_job["driver_info"] is not None
assert driver_job["start_time"] > 0

sub_jobs = list_jobs(filters=[("type", "=", "SUBMISSION")])
assert len(sub_jobs) == 1
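The test now asks for detailed job records because start_time, end_time, and driver_info are detail-only columns after the common.py change below. A minimal sketch of the same call pattern, assuming list_jobs returns dict-like records as the test's subscript access implies:

from ray.util.state import list_jobs

# Detailed listing: includes detail-only columns such as start_time,
# end_time, and driver_info.
for job in list_jobs(detail=True):
    assert job["start_time"] > 0

# Filtered, detailed listing: only driver-type jobs.
driver_jobs = list_jobs(filters=[("type", "=", "DRIVER")], detail=True)
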
18 changes: 10 additions & 8 deletions python/ray/util/state/common.py
@@ -282,7 +282,7 @@ def list_columns(cls, detail: bool = True) -> List[str]:
@classmethod
def columns(cls) -> Set[str]:
"""Return a set of all columns."""
return set(cls.list_columns())
return set(cls.list_columns(detail=True))

@classmethod
def filterable_columns(cls) -> Set[str]:
@@ -556,7 +556,7 @@ def humanify(cls, state: dict) -> dict:
return state

@classmethod
def list_columns(cls, detail: bool = False) -> List[str]:
def list_columns(cls, detail: bool = True) -> List[str]:
if not detail:
return [
"job_id",
@@ -568,7 +568,7 @@
"error_type",
"driver_info",
]
if isinstance(JobDetails, object):
if JobDetails is None:
# We don't have pydantic in the dashboard. This is because
# we call this method at module import time, so we need to
# check if the class is a pydantic model.
@@ -577,9 +577,9 @@
# TODO(aguo): Once we only support pydantic 2, we can remove this if check.
# In pydantic 2.0, `__fields__` has been renamed to `model_fields`.
return (
JobDetails.model_fields
list(JobDetails.model_fields.keys())
if hasattr(JobDetails, "model_fields")
else JobDetails.__fields__
else list(JobDetails.__fields__.keys())
)

def asdict(self):
@@ -1667,17 +1667,19 @@ def remove_ansi_escape_codes(text: str) -> str:
return re.sub(r"\x1b[^m]*m", "", text)


def dict_to_state(d: Dict, state_source: StateResource) -> StateSchema:
def dict_to_state(d: Dict, state_resource: StateResource) -> StateSchema:

"""Convert a dict to a state schema.
Args:
d: a dict to convert.
state_schema: a schema to convert to.
state_resource: the state resource to convert to.
Returns:
A state schema.
"""
try:
return resource_to_schema(state_source)(**d)
return resource_to_schema(state_resource)(**d)

except Exception as e:
raise RayStateApiException(f"Failed to convert {d} to StateSchema: {e}") from e
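The pydantic shim in list_columns boils down to one pattern: return a plain list of field names, preferring model_fields (pydantic 2) and falling back to __fields__ (pydantic 1), and bail out when JobDetails could not be imported at all. A standalone sketch of that pattern; the helper name is hypothetical:

from typing import List, Optional, Type


def pydantic_field_names(model_cls: Optional[Type]) -> List[str]:
    """Return field names for a pydantic 1 or 2 model class, or [] if unavailable."""
    if model_cls is None:
        # E.g. the dashboard extras (and thus the pydantic models) are not installed.
        return []
    # Pydantic 2 renamed `__fields__` to `model_fields`; prefer the new name.
    fields = getattr(model_cls, "model_fields", None)
    if fields is None:
        fields = model_cls.__fields__
    return list(fields.keys())
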