Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RLlib] New ConnectorV2 API #03: Introduce actual ConnectorV2 API. (#41074) #41212

Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
57e79f9
wip
sven1977 Nov 16, 2023
99d9019
wip
sven1977 Nov 17, 2023
d3dca2f
wip
sven1977 Nov 17, 2023
009a7fd
Merge branch 'master' of https://github.com/ray-project/ray into env_…
sven1977 Nov 30, 2023
b84b544
Merge branch 'master' of https://github.com/ray-project/ray into env_…
sven1977 Dec 14, 2023
b0b3c37
LINT
sven1977 Dec 14, 2023
4df7dfe
wip
sven1977 Dec 14, 2023
1de7ebb
wip
sven1977 Dec 14, 2023
a9acbee
wip
sven1977 Dec 14, 2023
5fe97e1
LINT
sven1977 Dec 14, 2023
213f0d1
LINT
sven1977 Dec 14, 2023
50b7fc6
Merge branch 'master' of https://github.com/ray-project/ray into env_…
sven1977 Dec 14, 2023
90e9c34
Merge branch 'master' of https://github.com/ray-project/ray into env_…
sven1977 Dec 15, 2023
91b4399
Merge branch 'master' into env_runner_support_connectors_03_connector…
sven1977 Dec 16, 2023
3102238
merge
sven1977 Dec 18, 2023
7618d52
Merge remote-tracking branch 'origin/env_runner_support_connectors_03…
sven1977 Dec 18, 2023
4958597
wip
sven1977 Dec 18, 2023
c40f5b0
Merge branch 'master' of https://github.com/ray-project/ray into env_…
sven1977 Dec 19, 2023
bdf803d
wip
sven1977 Dec 19, 2023
2649e70
wip
sven1977 Dec 21, 2023
34f8827
Merge branch 'master' of https://github.com/ray-project/ray into env_…
sven1977 Dec 21, 2023
b58ad31
wip
sven1977 Dec 21, 2023
7bc0ac6
wip
sven1977 Dec 21, 2023
8b8cf06
Merge branch 'master' of https://github.com/ray-project/ray into env_…
sven1977 Dec 21, 2023
f7dde73
wip
sven1977 Dec 21, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion rllib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -747,7 +747,7 @@ py_test(


# --------------------------------------------------------------------
# Connector tests
# Connector(V1) tests
# rllib/connector/
#
# Tag: connector
Expand All @@ -774,6 +774,20 @@ py_test(
srcs = ["connectors/tests/test_agent.py"]
)

# --------------------------------------------------------------------
# ConnectorV2 tests
# rllib/connector/
#
# Tag: connector_v2
# --------------------------------------------------------------------

py_test(
name = "connectors/tests/test_connector_v2",
tags = ["team:rllib", "connector_v2"],
size = "small",
srcs = ["connectors/tests/test_connector_v2.py"]
)

# --------------------------------------------------------------------
# Env tests
# rllib/env/
Expand Down
3 changes: 2 additions & 1 deletion rllib/algorithms/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,8 @@ def setup(self, config: AlgorithmConfig) -> None:
)

# Only when using RolloutWorkers: Update also the worker set's
# `should_module_be_updated_fn` (analogous to is_policy_to_train).
# `is_policy_to_train` (analogous to LearnerGroup's
# `should_module_be_updated_fn`).
# Note that with the new EnvRunner API in combination with the new stack,
# this information only needs to be kept in the LearnerGroup and not on the
# EnvRunners anymore.
Expand Down
143 changes: 143 additions & 0 deletions rllib/algorithms/algorithm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,10 @@

if TYPE_CHECKING:
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.connectors.connector_v2 import ConnectorV2
from ray.rllib.core.learner import Learner
from ray.rllib.core.learner.learner_group import LearnerGroup
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.evaluation.episode import Episode as OldEpisode

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -327,6 +329,8 @@ def __init__(self, algo_class=None):
self.num_envs_per_worker = 1
self.create_env_on_local_worker = False
self.enable_connectors = True
self._env_to_module_connector = None
self._module_to_env_connector = None
# TODO (sven): Rename into `sample_timesteps` (or `sample_duration`
# and `sample_duration_unit` (replacing batch_mode), like we do it
# in the evaluation config).
Expand Down Expand Up @@ -374,6 +378,7 @@ def __init__(self, algo_class=None):
except AttributeError:
pass

self._learner_connector = None
self.optimizer = {}
self.max_requests_in_flight_per_sampler_worker = 2
self._learner_class = None
Expand Down Expand Up @@ -1152,6 +1157,121 @@ class directly. Note that this arg can also be specified via
logger_creator=self.logger_creator,
)

def build_env_to_module_connector(self, env):
    """Builds the env-to-module connector pipeline for the given (vector) env.

    Starts from the user-provided `env_to_module_connector` callable (if any),
    then always appends RLlib's default env->module connector piece at the end.

    Args:
        env: The environment whose `single_observation_space` and
            `single_action_space` seed the pipeline's input spaces.

    Returns:
        An `EnvToModulePipeline`, unless the user callable already returned a
        complete pipeline (or any non-ConnectorV2, non-list value), in which
        case that value is returned unchanged.
    """
    from ray.rllib.connectors.env_to_module import (
        EnvToModulePipeline,
        DefaultEnvToModule,
    )

    user_pieces = []
    if self._env_to_module_connector is not None:
        custom = self._env_to_module_connector(env)

        from ray.rllib.connectors.connector_v2 import ConnectorV2

        if isinstance(custom, ConnectorV2) and not isinstance(
            custom, EnvToModulePipeline
        ):
            # Single custom piece -> wrap into a list for the pipeline ctor.
            user_pieces = [custom]
        elif isinstance(custom, (list, tuple)):
            user_pieces = list(custom)
        else:
            # Already a full pipeline (or something else entirely):
            # pass it through as-is.
            return custom

    pipeline = EnvToModulePipeline(
        connectors=user_pieces,
        input_observation_space=env.single_observation_space,
        input_action_space=env.single_action_space,
        env=env,
    )
    # RLlib's default piece always runs last, consuming the (possibly
    # transformed) spaces produced by the user pieces before it.
    pipeline.append(
        DefaultEnvToModule(
            input_observation_space=pipeline.observation_space,
            input_action_space=pipeline.action_space,
            env=env,
        )
    )
    return pipeline

def build_module_to_env_connector(self, env):
    """Builds the module-to-env connector pipeline for the given (vector) env.

    Starts from the user-provided `module_to_env_connector` callable (if any),
    then always appends RLlib's default module->env connector piece (which
    also applies the configured action normalization/clipping) at the end.

    Args:
        env: The environment whose `single_observation_space` and
            `single_action_space` seed the pipeline's input spaces.

    Returns:
        A `ModuleToEnvPipeline`, unless the user callable already returned a
        complete pipeline (or any non-ConnectorV2, non-list value), in which
        case that value is returned unchanged.
    """
    from ray.rllib.connectors.module_to_env import (
        ModuleToEnvPipeline,
        DefaultModuleToEnv,
    )

    user_pieces = []
    if self._module_to_env_connector is not None:
        custom = self._module_to_env_connector(env)

        from ray.rllib.connectors.connector_v2 import ConnectorV2

        if isinstance(custom, ConnectorV2) and not isinstance(
            custom, ModuleToEnvPipeline
        ):
            # Single custom piece -> wrap into a list for the pipeline ctor.
            user_pieces = [custom]
        elif isinstance(custom, (list, tuple)):
            user_pieces = list(custom)
        else:
            # Already a full pipeline (or something else entirely):
            # pass it through as-is.
            return custom

    pipeline = ModuleToEnvPipeline(
        connectors=user_pieces,
        input_observation_space=env.single_observation_space,
        input_action_space=env.single_action_space,
        env=env,
    )
    # RLlib's default piece always runs last; it honors the config's
    # action normalization/clipping settings.
    pipeline.append(
        DefaultModuleToEnv(
            input_observation_space=pipeline.observation_space,
            input_action_space=pipeline.action_space,
            env=env,
            normalize_actions=self.normalize_actions,
            clip_actions=self.clip_actions,
        )
    )
    return pipeline

def build_learner_connector(self, input_observation_space, input_action_space):
    """Builds the Learner connector pipeline from the given input spaces.

    Starts from the user-provided `learner_connector` callable (if any), then
    always appends RLlib's default learner connector piece at the end.

    Args:
        input_observation_space: The observation space the pipeline starts
            from (passed through to the user callable).
        input_action_space: The action space the pipeline starts from
            (passed through to the user callable).

    Returns:
        A `LearnerConnectorPipeline`, unless the user callable already
        returned a complete pipeline (or any non-ConnectorV2, non-list
        value), in which case that value is returned unchanged.
    """
    from ray.rllib.connectors.learner import (
        LearnerConnectorPipeline,
        DefaultLearnerConnector,
    )

    user_pieces = []
    if self._learner_connector is not None:
        custom = self._learner_connector(input_observation_space, input_action_space)

        from ray.rllib.connectors.connector_v2 import ConnectorV2

        if isinstance(custom, ConnectorV2) and not isinstance(
            custom, LearnerConnectorPipeline
        ):
            # Single custom piece -> wrap into a list for the pipeline ctor.
            user_pieces = [custom]
        elif isinstance(custom, (list, tuple)):
            user_pieces = list(custom)
        else:
            # Already a full pipeline (or something else entirely):
            # pass it through as-is.
            return custom

    pipeline = LearnerConnectorPipeline(
        connectors=user_pieces,
        input_observation_space=input_observation_space,
        input_action_space=input_action_space,
    )
    # RLlib's default piece always runs last, consuming the (possibly
    # transformed) spaces produced by the user pieces before it.
    pipeline.append(
        DefaultLearnerConnector(
            input_observation_space=pipeline.observation_space,
            input_action_space=pipeline.action_space,
        )
    )
    return pipeline

def build_learner_group(
self,
*,
Expand Down Expand Up @@ -1605,6 +1725,12 @@ def rollouts(
create_env_on_local_worker: Optional[bool] = NotProvided,
sample_collector: Optional[Type[SampleCollector]] = NotProvided,
enable_connectors: Optional[bool] = NotProvided,
env_to_module_connector: Optional[
Callable[[EnvType], "ConnectorV2"]
] = NotProvided,
module_to_env_connector: Optional[
Callable[[EnvType, "RLModule"], "ConnectorV2"]
] = NotProvided,
use_worker_filter_stats: Optional[bool] = NotProvided,
update_worker_filter_stats: Optional[bool] = NotProvided,
rollout_fragment_length: Optional[Union[int, str]] = NotProvided,
Expand Down Expand Up @@ -1650,6 +1776,11 @@ def rollouts(
enable_connectors: Use connector based environment runner, so that all
preprocessing of obs and postprocessing of actions are done in agent
and action connectors.
env_to_module_connector: A callable taking an Env as input arg and returning
an env-to-module ConnectorV2 (might be a pipeline) object.
module_to_env_connector: A callable taking an Env and an RLModule as input
args and returning a module-to-env ConnectorV2 (might be a pipeline)
object.
use_worker_filter_stats: Whether to use the workers in the WorkerSet to
update the central filters (held by the local worker). If False, stats
from the workers will not be used and discarded.
Expand Down Expand Up @@ -1737,6 +1868,10 @@ def rollouts(
self.create_env_on_local_worker = create_env_on_local_worker
if enable_connectors is not NotProvided:
self.enable_connectors = enable_connectors
if env_to_module_connector is not NotProvided:
self._env_to_module_connector = env_to_module_connector
if module_to_env_connector is not NotProvided:
self._module_to_env_connector = module_to_env_connector
if use_worker_filter_stats is not NotProvided:
self.use_worker_filter_stats = use_worker_filter_stats
if update_worker_filter_stats is not NotProvided:
Expand Down Expand Up @@ -1855,6 +1990,9 @@ def training(
optimizer: Optional[dict] = NotProvided,
max_requests_in_flight_per_sampler_worker: Optional[int] = NotProvided,
learner_class: Optional[Type["Learner"]] = NotProvided,
learner_connector: Optional[
Callable[[gym.Space, gym.Space], "ConnectorV2"]
] = NotProvided,
# Deprecated arg.
_enable_learner_api: Optional[bool] = NotProvided,
) -> "AlgorithmConfig":
Expand Down Expand Up @@ -1916,6 +2054,9 @@ def training(
in your experiment of timesteps.
learner_class: The `Learner` class to use for (distributed) updating of the
RLModule. Only used when `_enable_new_api_stack=True`.
learner_connector: A callable taking an env observation space and an env
action space as inputs and returning a learner ConnectorV2 (might be
a pipeline) object.
Returns:
This updated AlgorithmConfig object.
Expand Down Expand Up @@ -1960,6 +2101,8 @@ def training(
)
if learner_class is not NotProvided:
self._learner_class = learner_class
if learner_connector is not NotProvided:
self._learner_connector = learner_connector

return self

Expand Down
Empty file.
136 changes: 136 additions & 0 deletions rllib/connectors/common/frame_stacking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
import numpy as np
from typing import Any, List, Optional

import gymnasium as gym
import tree # pip install dm_tree

from ray.rllib.connectors.connector_v2 import ConnectorV2
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.spaces.space_utils import batch
from ray.rllib.utils.typing import EpisodeType


class _FrameStackingConnector(ConnectorV2):
    """A connector piece that stacks the previous n observations into one."""

    def __init__(
        self,
        *,
        # Base class constructor args.
        input_observation_space: gym.Space,
        input_action_space: gym.Space,
        # Specific framestacking args.
        num_frames: int = 1,
        as_learner_connector: bool = False,
        **kwargs,
    ):
        """Initializes a _FrameStackingConnector instance.

        Args:
            num_frames: The number of observation frames to stack up (into a single
                observation) for the RLModule's forward pass.
            as_learner_connector: Whether this connector is part of a Learner connector
                pipeline, as opposed to an env-to-module pipeline. In the Learner
                case, episodes are expected to already be finalized (numpy'ized)
                and framestacking is performed via strided views over each
                episode's full observation array.
        """
        super().__init__(
            input_observation_space=input_observation_space,
            input_action_space=input_action_space,
            **kwargs,
        )

        self.num_frames = num_frames
        self.as_learner_connector = as_learner_connector

        # Some assumptions: Space is box AND last dim (the stacking one) is 1.
        assert isinstance(self.observation_space, gym.spaces.Box)
        assert self.observation_space.shape[-1] == 1

        # Change our observation space according to the given stacking settings:
        # the trailing size-1 dim is replaced by a size-`num_frames` dim.
        self.observation_space = gym.spaces.Box(
            low=np.repeat(self.observation_space.low, repeats=self.num_frames, axis=-1),
            high=np.repeat(
                self.observation_space.high, repeats=self.num_frames, axis=-1
            ),
            shape=list(self.observation_space.shape)[:-1] + [self.num_frames],
            dtype=self.observation_space.dtype,
        )

    @override(ConnectorV2)
    def __call__(
        self,
        *,
        rl_module: RLModule,
        input_: Optional[Any],
        episodes: List[EpisodeType],
        explore: Optional[bool] = None,
        persistent_data: Optional[dict] = None,
        **kwargs,
    ) -> Any:
        # This is a data-in-data-out connector, so we expect `input_` to be a dict
        # with: key=column name, e.g. "obs" and value=[data to be processed by
        # RLModule]. We will add to `input_` the last n observations.
        observations = []

        # Learner connector pipeline. Episodes have been finalized/numpy'ized.
        if self.as_learner_connector:
            for episode in episodes:

                def _map_fn(s):
                    # Squeeze out last dim.
                    s = np.squeeze(s, axis=-1)
                    # Calculate new shape and strides
                    new_shape = (len(episode), self.num_frames) + s.shape[1:]
                    new_strides = (s.strides[0],) + s.strides
                    # Create a strided view of the array.
                    # NOTE(review): `as_strided` returns a view into `s`; the
                    # result must not be written to, or the underlying episode
                    # data would be corrupted.
                    return np.lib.stride_tricks.as_strided(
                        s, shape=new_shape, strides=new_strides
                    )

                # Get all observations from the episode in one np array (except for
                # the very last one, which is the final observation not needed for
                # learning).
                observations.append(
                    tree.map_structure(
                        _map_fn,
                        episode.get_observations(
                            indices=slice(-self.num_frames + 1, len(episode)),
                            neg_indices_left_of_zero=True,
                            fill=0.0,
                        ),
                    )
                )

            # Move stack-dimension to the end and concatenate along batch axis.
            input_[SampleBatch.OBS] = tree.map_structure(
                lambda *s: np.transpose(np.concatenate(s, axis=0), axes=[0, 2, 3, 1]),
                *observations,
            )

        # Env-to-module pipeline. Episodes still operate on lists.
        else:
            for episode in episodes:
                assert not episode.is_finalized
                # Get the list of observations to stack.
                obs_stack = episode.get_observations(
                    indices=slice(-self.num_frames, None),
                    fill=0.0,
                )
                # Observation components are (w, h, 1)
                # -> stack to (w, h, [num_frames], 1), then squeeze out last dim to get
                # (w, h, [num_frames]).
                stacked_obs = tree.map_structure(
                    lambda *s: np.squeeze(np.stack(s, axis=2), axis=-1),
                    *obs_stack,
                )
                observations.append(stacked_obs)

            input_[SampleBatch.OBS] = batch(observations)

        return input_
Loading
Loading