-
Notifications
You must be signed in to change notification settings - Fork 5.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RLlib] New ConnectorV2 API #03: Introduce actual ConnectorV2
API. (#41074)
#41212
Merged
sven1977
merged 25 commits into
ray-project:master
from
sven1977:env_runner_support_connectors_03_connectorv2_api
Dec 21, 2023
Merged
Changes from 19 commits
Commits
Show all changes
25 commits
Select commit
Hold shift + click to select a range
57e79f9
wip
sven1977 99d9019
wip
sven1977 d3dca2f
wip
sven1977 009a7fd
Merge branch 'master' of https://github.com/ray-project/ray into env_…
sven1977 b84b544
Merge branch 'master' of https://github.com/ray-project/ray into env_…
sven1977 b0b3c37
LINT
sven1977 4df7dfe
wip
sven1977 1de7ebb
wip
sven1977 a9acbee
wip
sven1977 5fe97e1
LINT
sven1977 213f0d1
LINT
sven1977 50b7fc6
Merge branch 'master' of https://github.com/ray-project/ray into env_…
sven1977 90e9c34
Merge branch 'master' of https://github.com/ray-project/ray into env_…
sven1977 91b4399
Merge branch 'master' into env_runner_support_connectors_03_connector…
sven1977 3102238
merge
sven1977 7618d52
Merge remote-tracking branch 'origin/env_runner_support_connectors_03…
sven1977 4958597
wip
sven1977 c40f5b0
Merge branch 'master' of https://github.com/ray-project/ray into env_…
sven1977 bdf803d
wip
sven1977 2649e70
wip
sven1977 34f8827
Merge branch 'master' of https://github.com/ray-project/ray into env_…
sven1977 b58ad31
wip
sven1977 7bc0ac6
wip
sven1977 8b8cf06
Merge branch 'master' of https://github.com/ray-project/ray into env_…
sven1977 f7dde73
wip
sven1977 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
import numpy as np | ||
from typing import Any, List, Optional | ||
|
||
import gymnasium as gym | ||
import tree # pip install dm_tree | ||
|
||
from ray.rllib.connectors.connector_v2 import ConnectorV2 | ||
from ray.rllib.core.rl_module.rl_module import RLModule | ||
from ray.rllib.policy.sample_batch import SampleBatch | ||
from ray.rllib.utils.annotations import override | ||
from ray.rllib.utils.spaces.space_utils import batch | ||
from ray.rllib.utils.typing import EpisodeType | ||
|
||
|
||
class _FrameStackingConnector(ConnectorV2): | ||
"""A connector piece that stacks the previous n observations into one.""" | ||
|
||
def __init__( | ||
self, | ||
*, | ||
# Base class constructor args. | ||
input_observation_space: gym.Space, | ||
input_action_space: gym.Space, | ||
# Specific framestacking args. | ||
num_frames: int = 1, | ||
as_learner_connector: bool = False, | ||
**kwargs, | ||
): | ||
"""Initializes a _FrameStackingConnector instance. | ||
|
||
Args: | ||
num_frames: The number of observation frames to stack up (into a single | ||
observation) for the RLModule's forward pass. | ||
as_preprocessor: Whether this connector should simply postprocess the | ||
received observations from the env and store these directly in the | ||
episode object. In this mode, the connector can only be used in | ||
an `EnvToModulePipeline` and it will act as a classic | ||
RLlib framestacking postprocessor. | ||
as_learner_connector: Whether this connector is part of a Learner connector | ||
pipeline, as opposed to an env-to-module pipeline. | ||
""" | ||
super().__init__( | ||
input_observation_space=input_observation_space, | ||
input_action_space=input_action_space, | ||
**kwargs, | ||
) | ||
|
||
self.num_frames = num_frames | ||
self.as_learner_connector = as_learner_connector | ||
|
||
# Some assumptions: Space is box AND last dim (the stacking one) is 1. | ||
assert isinstance(self.observation_space, gym.spaces.Box) | ||
assert self.observation_space.shape[-1] == 1 | ||
|
||
# Change our observation space according to the given stacking settings. | ||
self.observation_space = gym.spaces.Box( | ||
low=np.repeat(self.observation_space.low, repeats=self.num_frames, axis=-1), | ||
high=np.repeat( | ||
self.observation_space.high, repeats=self.num_frames, axis=-1 | ||
), | ||
shape=list(self.observation_space.shape)[:-1] + [self.num_frames], | ||
dtype=self.observation_space.dtype, | ||
) | ||
|
||
@override(ConnectorV2) | ||
def __call__( | ||
self, | ||
*, | ||
rl_module: RLModule, | ||
input_: Optional[Any], | ||
episodes: List[EpisodeType], | ||
explore: Optional[bool] = None, | ||
persistent_data: Optional[dict] = None, | ||
**kwargs, | ||
) -> Any: | ||
# This is a data-in-data-out connector, so we expect `input_` to be a dict | ||
# with: key=column name, e.g. "obs" and value=[data to be processed by | ||
# RLModule]. We will add to `input_` the last n observations. | ||
observations = [] | ||
|
||
# Learner connector pipeline. Episodes have been finalized/numpy'ized. | ||
if self.as_learner_connector: | ||
for episode in episodes: | ||
|
||
def _map_fn(s): | ||
# Squeeze out last dim. | ||
s = np.squeeze(s, axis=-1) | ||
# Calculate new shape and strides | ||
new_shape = (len(episode), self.num_frames) + s.shape[1:] | ||
new_strides = (s.strides[0],) + s.strides | ||
# Create a strided view of the array. | ||
return np.lib.stride_tricks.as_strided( | ||
s, shape=new_shape, strides=new_strides | ||
) | ||
|
||
# Get all observations from the episode in one np array (except for | ||
# the very last one, which is the final observation not needed for | ||
# learning). | ||
observations.append( | ||
tree.map_structure( | ||
_map_fn, | ||
episode.get_observations( | ||
indices=slice(-self.num_frames + 1, len(episode)), | ||
neg_indices_left_of_zero=True, | ||
fill=0.0, | ||
), | ||
) | ||
) | ||
|
||
# Move stack-dimension to the end and concatenate along batch axis. | ||
input_[SampleBatch.OBS] = tree.map_structure( | ||
lambda *s: np.transpose(np.concatenate(s, axis=0), axes=[0, 2, 3, 1]), | ||
*observations, | ||
) | ||
|
||
# Env-to-module pipeline. Episodes still operate on lists. | ||
else: | ||
for episode in episodes: | ||
assert not episode.is_finalized | ||
# Get the list of observations to stack. | ||
obs_stack = episode.get_observations( | ||
indices=slice(-self.num_frames, None), | ||
fill=0.0, | ||
) | ||
# Observation components are (w, h, 1) | ||
# -> stack to (w, h, [num_frames], 1), then squeeze out last dim to get | ||
# (w, h, [num_frames]). | ||
stacked_obs = tree.map_structure( | ||
lambda *s: np.squeeze(np.stack(s, axis=2), axis=-1), | ||
*obs_stack, | ||
) | ||
observations.append(stacked_obs) | ||
|
||
input_[SampleBatch.OBS] = batch(observations) | ||
|
||
return input_ |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not defined in the signature?