Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RLlib; Offline RL] Support writing and reading composite spaces samples. #47046

4 changes: 4 additions & 0 deletions rllib/algorithms/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,6 +871,10 @@ def setup(self, config: AlgorithmConfig) -> None:
else:
self.offline_data.learner_handles = [self.learner_group._learner]

# Provide the `OfflineData` instance with space information. It might
# need it for reading recorded experiences.
self.offline_data.spaces = self.env_runner_group.get_spaces()

# Run `on_algorithm_init` callback after initialization is done.
self.callbacks.on_algorithm_init(algorithm=self, metrics_logger=self.metrics)

Expand Down
13 changes: 11 additions & 2 deletions rllib/algorithms/algorithm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import ray
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.core import DEFAULT_MODULE_ID
from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module import validate_module_id
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
Expand Down Expand Up @@ -431,6 +432,8 @@ def __init__(self, algo_class: Optional[type] = None):
self.input_read_method = "read_parquet"
self.input_read_method_kwargs = {}
self.input_read_schema = {}
self.input_compress_columns = [Columns.OBS, Columns.NEXT_OBS]
self.input_spaces_jsonable = True
self.map_batches_kwargs = {}
self.iter_batches_kwargs = {}
self.prelearner_class = None
Expand All @@ -442,7 +445,7 @@ def __init__(self, algo_class: Optional[type] = None):
self.shuffle_buffer_size = 0
self.output = None
self.output_config = {}
self.output_compress_columns = ["obs", "new_obs"]
self.output_compress_columns = [Columns.OBS, Columns.NEXT_OBS]
self.output_max_file_size = 64 * 1024 * 1024
self.output_max_rows_per_file = None
self.output_write_method = "write_parquet"
Expand Down Expand Up @@ -2385,6 +2388,7 @@ def offline_data(
input_read_method: Optional[Union[str, Callable]] = NotProvided,
input_read_method_kwargs: Optional[Dict] = NotProvided,
input_read_schema: Optional[Dict[str, str]] = NotProvided,
input_compress_columns: Optional[List[str]] = NotProvided,
map_batches_kwargs: Optional[Dict] = NotProvided,
iter_batches_kwargs: Optional[Dict] = NotProvided,
prelearner_class: Optional[Type] = NotProvided,
Expand All @@ -2396,7 +2400,7 @@ def offline_data(
shuffle_buffer_size: Optional[int] = NotProvided,
output: Optional[str] = NotProvided,
output_config: Optional[Dict] = NotProvided,
output_compress_columns: Optional[bool] = NotProvided,
output_compress_columns: Optional[List[str]] = NotProvided,
simonsays1980 marked this conversation as resolved.
Show resolved Hide resolved
output_max_file_size: Optional[float] = NotProvided,
output_max_rows_per_file: Optional[int] = NotProvided,
output_write_method: Optional[str] = NotProvided,
Expand Down Expand Up @@ -2437,6 +2441,9 @@ def offline_data(
schema used is `ray.rllib.offline.offline_data.SCHEMA`. If your data set
contains already the names in this schema, no `input_read_schema` is
needed.
input_compress_columns: What input columns are compressed with LZ4 in the
simonsays1980 marked this conversation as resolved.
Show resolved Hide resolved
input data. If data is stored in `RLlib`'s `SingleAgentEpisode` (
simonsays1980 marked this conversation as resolved.
Show resolved Hide resolved
`MultiAgentEpisode` not supported, yet).
map_batches_kwargs: `kwargs` for the `map_batches` method. These will be
passed into the `ray.data.Dataset.map_batches` method when sampling
without checking. If no arguments passed in the default arguments `{
Expand Down Expand Up @@ -2528,6 +2535,8 @@ def offline_data(
self.input_read_method_kwargs = input_read_method_kwargs
if input_read_schema is not NotProvided:
self.input_read_schema = input_read_schema
if input_compress_columns is not NotProvided:
self.input_compress_columns = input_compress_columns
if map_batches_kwargs is not NotProvided:
self.map_batches_kwargs = map_batches_kwargs
if iter_batches_kwargs is not NotProvided:
Expand Down
2 changes: 2 additions & 0 deletions rllib/offline/offline_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def sample(
fn_constructor_kwargs={
"config": self.config,
"learner": self.learner_handles[0],
"spaces": self.spaces["__env__"],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Can we merge the other (__env__ constant) PR into this one and then push again. This way, we won't forget to change this here.

},
batch_size=num_samples,
**self.map_batches_kwargs,
Expand Down Expand Up @@ -106,6 +107,7 @@ def sample(
fn_constructor_kwargs={
"config": self.config,
"learner": self.learner_handles,
"spaces": self.spaces["__env__"],
"locality_hints": self.locality_hints,
"module_spec": self.module_spec,
"module_state": module_state,
Expand Down
27 changes: 21 additions & 6 deletions rllib/offline/offline_env_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from ray.rllib.env.single_agent_episode import SingleAgentEpisode
from ray.rllib.utils.annotations import override
from ray.rllib.utils.compression import pack_if_needed
from ray.rllib.utils.spaces.space_utils import to_jsonable_if_needed
from ray.rllib.utils.typing import EpisodeType

logger = logging.Logger(__file__)
Expand Down Expand Up @@ -209,6 +210,8 @@ def _map_episodes_to_data(self, samples: List[EpisodeType]) -> None:
samples: List of episodes to be converted.
"""
# Loop through all sampled episodes.
obs_space = self.env.observation_space
action_space = self.env.action_space
for sample in samples:
# Loop through all items of the episode.
for i in range(len(sample)):
Expand All @@ -217,18 +220,30 @@ def _map_episodes_to_data(self, samples: List[EpisodeType]) -> None:
Columns.AGENT_ID: sample.agent_id,
Columns.MODULE_ID: sample.module_id,
# Compress observations, if requested.
Columns.OBS: pack_if_needed(sample.get_observations(i))
Columns.OBS: pack_if_needed(
to_jsonable_if_needed(sample.get_observations(i), obs_space)
)
if Columns.OBS in self.output_compress_columns
else sample.get_observations(i),
else obs_space.to_jsonable_if_needed(
sample.get_observations(i), obs_space
),
# Compress actions, if requested.
Columns.ACTIONS: pack_if_needed(sample.get_actions(i))
Columns.ACTIONS: pack_if_needed(
to_jsonable_if_needed(sample.get_actions(i), action_space)
)
if Columns.OBS in self.output_compress_columns
else sample.get_actions(i),
else action_space.to_jsonable_if_needed(
sample.get_actions(i), action_space
),
Columns.REWARDS: sample.get_rewards(i),
# Compress next observations, if requested.
Columns.NEXT_OBS: pack_if_needed(sample.get_observations(i + 1))
Columns.NEXT_OBS: pack_if_needed(
to_jsonable_if_needed(sample.get_observations(i + 1), obs_space)
)
if Columns.OBS in self.output_compress_columns
else sample.get_observations(i + 1),
else obs_space.to_jsonable_if_needed(
sample.get_observations(i + 1), obs_space
),
Columns.TERMINATEDS: False
if i < len(sample) - 1
else sample.is_terminated,
Expand Down
73 changes: 65 additions & 8 deletions rllib/offline/offline_prelearner.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import gymnasium as gym
import numpy as np
import random
import ray
from ray.actor import ActorHandle
from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING
from typing import Any, Dict, List, Optional, Union, Tuple, TYPE_CHECKING

from ray.rllib.core.columns import Columns
from ray.rllib.core.learner import Learner
Expand All @@ -15,6 +16,7 @@
OverrideToImplementCustomLogic_CallToSuperRecommended,
)
from ray.rllib.utils.compression import unpack_if_needed
from ray.rllib.utils.spaces.space_utils import from_jsonable_if_needed
from ray.rllib.utils.typing import EpisodeType, ModuleID

if TYPE_CHECKING:
Expand Down Expand Up @@ -80,6 +82,7 @@ def __init__(
self,
config: "AlgorithmConfig",
learner: Union[Learner, list[ActorHandle]],
spaces: Optional[Tuple[gym.Space, gym.Space]] = None,
locality_hints: Optional[list] = None,
module_spec: Optional[MultiRLModuleSpec] = None,
module_state: Optional[Dict[ModuleID, Any]] = None,
Expand Down Expand Up @@ -116,10 +119,12 @@ def __init__(
# Build the module from spec. Note, this will be a MultiRLModule.
self._module = module_spec.build()
self._module.set_state(module_state)

self.spaces = spaces or (None, None)
simonsays1980 marked this conversation as resolved.
Show resolved Hide resolved
# Build the learner connector pipeline.
self._learner_connector = self.config.build_learner_connector(
input_observation_space=None,
input_action_space=None,
input_observation_space=self.spaces[0],
input_action_space=self.spaces[1],
)
# Cache the policies to be trained to update weights only for these.
self._policies_to_train = self.config.policies_to_train
Expand All @@ -132,7 +137,13 @@ def __init__(
def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, List[EpisodeType]]:
# Map the batch to episodes.
episodes = self._map_to_episodes(
self._is_multi_agent, batch, schema=SCHEMA | self.config.input_read_schema
self._is_multi_agent,
batch,
schema=SCHEMA | self.config.input_read_schema,
finalize=False,
input_compress_columns=self.config.input_compress_columns,
observation_space=self.spaces[0],
action_space=self.spaces[1],
)
# TODO (simon): Make synching work. Right now this becomes blocking or never
# receives weights. Learners appear to be non accessable via other actors.
Expand Down Expand Up @@ -208,9 +219,25 @@ def _map_to_episodes(
batch: Dict[str, np.ndarray],
schema: Dict[str, str] = SCHEMA,
finalize: bool = False,
input_compress_columns: Optional[List[str]] = None,
observation_space: gym.Space = None,
action_space: gym.Space = None,
) -> Dict[str, List[EpisodeType]]:
"""Maps a batch of data to episodes."""

# Set to empty list, if `None`.
input_compress_columns = input_compress_columns or []

# If spaces are given, we can use the space-specific
# conversion method to convert space samples.
if observation_space and action_space:
convert = from_jsonable_if_needed
# Otherwise we use an identity function.
else:

def convert(sample, space):
return sample

episodes = []
# TODO (simon): Give users possibility to provide a custom schema.
for i, obs in enumerate(batch[schema[Columns.OBS]]):
Expand Down Expand Up @@ -240,17 +267,39 @@ def _map_to_episodes(
episode = SingleAgentEpisode(
id_=batch[schema[Columns.EPS_ID]][i],
agent_id=agent_id,
# Observations might be (a) serialized and/or (b) converted
# to a JSONable (when a composite space was used). We unserialize
# and then reconvert from JSONable to space sample.
observations=[
unpack_if_needed(obs),
unpack_if_needed(batch[schema[Columns.NEXT_OBS]][i]),
convert(unpack_if_needed(obs), observation_space)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if observation_space is None here? Could that happen? Or should we make the arg non-optional?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could happen and is in fact okay. As it uses from_jsonable_if_needed the conversion simply does not take place, if the input space is None. We need them only, if we want to convert from JSONable data types to a composite space type.

if Columns.OBS in input_compress_columns
else convert(obs, observation_space),
convert(
unpack_if_needed(batch[schema[Columns.NEXT_OBS]][i]),
observation_space,
)
if Columns.NEXT_OBS in input_compress_columns
else convert(
batch[schema[Columns.NEXT_OBS]][i], observation_space
),
],
infos=[
{},
batch[schema[Columns.INFOS]][i]
if schema[Columns.INFOS] in batch
else {},
],
actions=[batch[schema[Columns.ACTIONS]][i]],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated question on the above:

  • Why is the first info always {}?
  • I think else {}, would yield a tuple, correct? So the resulting final list would be: [{}, ({},)]. Maybe I'm wrong, but can you check this?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great question, because there is only a single INFO column in all the offline data. As we need two in the Episode we need to fill in a default and this is filled in at timestep zero.

YOur second point is valid - this could lead to a tuple, even though the comma should be part of the list ... I check this. Thanks1

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, yeah, makes sense. We don't have a NEXT_INFOS.

Hmm, I wonder whether this could be a general problem in some strange use cases :)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was also thinking about it, but could not came of with cases that need an info at ts=0.

# Actions might be (a) serialized and/or (b) converted to a JSONable
# (when a composite space was used). We unserializer and then
# reconvert from JSONable to space sample.
actions=[
convert(
unpack_if_needed(batch[schema[Columns.ACTIONS]][i]),
action_space,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same for action_space. What if it's None (not provided by caller)?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See my comment above ;)

)
if Columns.ACTIONS in input_compress_columns
else convert(batch[schema[Columns.ACTIONS]][i], action_space)
],
rewards=[batch[schema[Columns.REWARDS]][i]],
terminated=batch[
schema[Columns.TERMINATEDS]
Expand All @@ -264,8 +313,16 @@ def _map_to_episodes(
# t_started=batch[Columns.T if Columns.T in batch else
# "unroll_id"][i][0],
# TODO (simon): Single-dimensional columns are not supported.
# Extra model outputs might be serialized. We unserialize them here
# if needed.
# TODO (simon): Check, if we need here also reconversion from
# JSONable in case of composite spaces.
extra_model_outputs={
k: [v[i]]
k: [
unpack_if_needed(v[i])
if k in input_compress_columns
else v[i]
]
for k, v in batch.items()
if (k not in schema and k not in schema.values())
},
Expand Down
2 changes: 1 addition & 1 deletion rllib/tuned_examples/bc/cartpole_bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
# To increase learning speed with multiple learners,
# increase the learning rate correspondingly.
lr=0.0008 * max(1, args.num_gpus**0.5),
train_batch_size_per_learner=2000,
train_batch_size_per_learner=256,
)
)

Expand Down
3 changes: 2 additions & 1 deletion rllib/tuned_examples/bc/cartpole_recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,9 @@
)
.offline_data(
output="local:///tmp/cartpole/",
output_write_episodes=True,
output_write_episodes=False,
output_max_rows_per_file=1000,
output_compress_columns=["obs", "new_obs", "actions"],
simonsays1980 marked this conversation as resolved.
Show resolved Hide resolved
)
)

Expand Down
74 changes: 74 additions & 0 deletions rllib/utils/spaces/space_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import gymnasium as gym
from gymnasium.spaces import Tuple, Dict
from gymnasium.core import ActType, ObsType
simonsays1980 marked this conversation as resolved.
Show resolved Hide resolved
import numpy as np
from ray.rllib.utils.annotations import DeveloperAPI
import tree # pip install dm_tree
Expand Down Expand Up @@ -43,6 +44,79 @@ def get_original_space(space: gym.Space) -> gym.Space:
return space


@DeveloperAPI
def is_composite_space(space: gym.Space) -> bool:
    """Returns True, if the space is composite.

    Note, we follow here the glossary of `gymnasium` by which any space
    that holds other spaces is defined as being 'composite'.

    Args:
        space: The space to be checked for being composed of other spaces.

    Returns:
        True, if the space is composed of other spaces, otherwise False.
    """
    # `Dict`, `Graph`, `Sequence`, and `Tuple` are the composite spaces
    # defined by `gymnasium`. Using `isinstance` (instead of an exact
    # `type(...) in [...]` check) also covers user-defined subclasses of
    # these spaces, which are composite as well.
    return isinstance(
        space,
        (gym.spaces.Dict, gym.spaces.Graph, gym.spaces.Sequence, gym.spaces.Tuple),
    )


@DeveloperAPI
def to_jsonable_if_needed(
    sample: Union[ActType, ObsType], space: gym.Space
) -> Union[ActType, ObsType, List]:
    """Converts a space sample to a JSONable data type, if the space is composite.

    Checks, if the space is composite and converts the sample to a JSONable
    struct in this case. Otherwise returns the sample as-is.

    Args:
        sample: Any action or observation type possible in `gymnasium`.
        space: Any space defined in `gymnasium.spaces`.

    Returns:
        The `sample` as-is, if the `space` is not composite, otherwise the
        sample converted to a JSONable data type via `space.to_jsonable`.
    """

    if is_composite_space(space):
        # `Space.to_jsonable` expects a batch of samples, so wrap the single
        # sample in a list before converting.
        return space.to_jsonable([sample])
    else:
        return sample


@DeveloperAPI
def from_jsonable_if_needed(
    sample: Union[ActType, ObsType], space: gym.Space
) -> Union[ActType, ObsType, List]:
    """Converts a JSONable struct back to a space sample, if the space is composite.

    Checks, if the space is composite and in this case converts the JSONable
    struct back into an actual sample of `space`. Otherwise returns the
    sample as-is.

    Args:
        sample: Any action or observation type possible in `gymnasium`, or a
            JSONable data type produced by `to_jsonable_if_needed`.
        space: Any space defined in `gymnasium.spaces`.

    Returns:
        The `sample` as-is, if the `space` is not composite, otherwise the
        JSONable struct converted back into an actual `space` sample.
    """

    if is_composite_space(space):
        # `Space.from_jsonable` returns a batch; unwrap the single sample.
        return space.from_jsonable(sample)[0]
    else:
        return sample


@DeveloperAPI
def flatten_space(space: gym.Space) -> List[gym.Space]:
"""Flattens a gym.Space into its primitive components.
Expand Down