[RLlib; Offline RL] Store episodes in state form. #47294

Merged

Changes from 5 commits

20 commits
7f7decb
Added 'get_state' and 'from_state' to 'InfiniteLookbackBuffer' and mo…
simonsays1980 Aug 23, 2024
9cd30ae
Added serialization for 'SingleAgentEpisode' objects via 'msgpack'. A…
simonsays1980 Aug 26, 2024
7c5f67a
Merge branch 'master' into offline-store-episodes-in-state-form
simonsays1980 Aug 26, 2024
37d37a0
Added 'msgpack' and 'msgpack_numpy' to the 'rllib-test-requirements.txt'.
simonsays1980 Aug 26, 2024
13e4d71
Merge branch 'master' into offline-store-episodes-in-state-form
simonsays1980 Aug 27, 2024
8fb8a38
Changed some code in response to @sven1977's review.
simonsays1980 Aug 27, 2024
b1a924a
Merge branch 'master' into offline-store-episodes-in-state-form
simonsays1980 Aug 27, 2024
54b5d2d
Modified comment.
simonsays1980 Aug 30, 2024
ffeb244
Added 'msgpack' and 'msgpack_numpy' to the 'autodoc_mock_imports'.
simonsays1980 Aug 30, 2024
0e9260f
Merged master.
simonsays1980 Aug 30, 2024
bcb9fb7
Fixed name error for 'msgpack-numpy' in requirements.
simonsays1980 Aug 30, 2024
cdfea4a
Merge branch 'master' into offline-store-episodes-in-state-form
simonsays1980 Sep 2, 2024
0426889
Added 'msgpack' and 'msgpack-numpy' to the 'setup.py'.
simonsays1980 Sep 4, 2024
5d6a213
Merge branch 'master' into offline-store-episodes-in-state-form
simonsays1980 Sep 6, 2024
c6a862d
Removed dependencies on 'msgpack_numpy' from 'setup.py'. Modified imp…
simonsays1980 Sep 6, 2024
8ca8411
Merge branch 'master' into offline-store-episodes-in-state-form
simonsays1980 Sep 9, 2024
53ba4f5
fix
sven1977 Sep 9, 2024
bc79bc8
LINT
sven1977 Sep 9, 2024
2cb5a7c
fix
sven1977 Sep 9, 2024
9be3266
Merged master.
simonsays1980 Sep 10, 2024
4 changes: 4 additions & 0 deletions python/requirements/ml/rllib-test-requirements.txt
@@ -44,3 +44,7 @@ h5py==3.10.0

# Requirements for rendering.
moviepy

# Requirements for offline data recording and reading
msgpack
msgpack_numpy
82 changes: 61 additions & 21 deletions rllib/env/single_agent_episode.py
@@ -11,6 +11,7 @@
from ray.rllib.core.columns import Columns
from ray.rllib.env.utils.infinite_lookback_buffer import InfiniteLookbackBuffer
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.serialization import gym_space_from_dict, gym_space_to_dict
from ray.rllib.utils.typing import AgentID, ModuleID
from ray.util.annotations import PublicAPI

@@ -323,7 +324,7 @@ def __init__(
if isinstance(v, InfiniteLookbackBuffer):
self.extra_model_outputs[k] = v
else:
# We cannot use the defaultdict's own constructore here as this would
# We cannot use the defaultdict's own constructor here as this would
# auto-set the lookback buffer to 0 (there is no data passed to that
# constructor). Then, when we manually have to set the data property,
# the lookback buffer would still be (incorrectly) 0.
@@ -1695,32 +1696,44 @@ def get_state(self) -> Dict[str, Any]:
"""Returns the pickable state of an episode.

The data in the episode is stored into a dictionary. Note that episodes
can also be generated from states (see `self.from_state()`).
can also be generated from states (see `SingleAgentEpisode.from_state()`).

Returns:
A dict containing all the data from the episode.
"""
infos = self.infos.get_state()
infos["data"] = np.array([info if info else None for info in infos["data"]])
return {
"id_": self.id_,
"agent_id": self.agent_id,
"module_id": self.module_id,
"multi_agent_episode_id": self.multi_agent_episode_id,
# TODO (simon): Check, if we need to have a `get_state` method for
# `InfiniteLookbackBuffer` and call it here.
"observations": self.observations,
"actions": self.actions,
"rewards": self.rewards,
"infos": self.infos,
"extra_model_outputs": self.extra_model_outputs,
# Note, all data is stored in `InfiniteLookbackBuffer`s.
"observations": self.observations.get_state(),
"actions": self.actions.get_state(),
"rewards": self.rewards.get_state(),
"infos": self.infos.get_state(),
"extra_model_outputs": {
k: v.get_state() if v else v
for k, v in self.extra_model_outputs.items()
}
if len(self.extra_model_outputs) > 0
else None,
"is_terminated": self.is_terminated,
"is_truncated": self.is_truncated,
"t_started": self.t_started,
"t": self.t,
"_observation_space": self._observation_space,
"_action_space": self._action_space,
"_observation_space": gym_space_to_dict(self._observation_space)
if self._observation_space
else self._observation_space,
"_action_space": gym_space_to_dict(self._action_space)
if self._action_space
else self._action_space,
"_start_time": self._start_time,
"_last_step_time": self._last_step_time,
"_temporary_timestep_data": self._temporary_timestep_data,
"_temporary_timestep_data": dict(self._temporary_timestep_data)
if len(self._temporary_timestep_data) > 0
else None,
}

@staticmethod
@@ -1739,21 +1752,48 @@ def from_state(state: Dict[str, Any]) -> "SingleAgentEpisode":
episode.agent_id = state["agent_id"]
episode.module_id = state["module_id"]
episode.multi_agent_episode_id = state["multi_agent_episode_id"]
episode.observations = state["observations"]
episode.actions = state["actions"]
episode.rewards = state["rewards"]
episode.infos = state["infos"]
episode.extra_model_outputs = state["extra_model_outputs"]
# Convert data back to `InfiniteLookbackBuffer`s.
episode.observations = InfiniteLookbackBuffer.from_state(state["observations"])
episode.actions = InfiniteLookbackBuffer.from_state(state["actions"])
episode.rewards = InfiniteLookbackBuffer.from_state(state["rewards"])
episode.infos = InfiniteLookbackBuffer.from_state(state["infos"])
episode.extra_model_outputs = (
defaultdict(
functools.partial(
InfiniteLookbackBuffer, lookback=episode.observations.lookback
),
{
k: InfiniteLookbackBuffer.from_state(v)
for k, v in state["extra_model_outputs"].items()
},
)
if state["extra_model_outputs"]
else defaultdict(
functools.partial(
InfiniteLookbackBuffer, lookback=episode.observations.lookback
),
)
)
episode.is_terminated = state["is_terminated"]
episode.is_truncated = state["is_truncated"]
episode.t_started = state["t_started"]
episode.t = state["t"]
episode._observation_space = state["_observation_space"]
episode._action_space = state["_action_space"]
# Convert the spaces back from their serialized (dict) representations.
episode._observation_space = (
gym_space_from_dict(state["_observation_space"])
if state["_observation_space"]
else state["_observation_space"]
)
episode._action_space = (
gym_space_from_dict(state["_action_space"])
if state["_action_space"]
else state["_action_space"]
)
episode._start_time = state["_start_time"]
episode._last_step_time = state["_last_step_time"]
episode._temporary_timestep_data = state["_temporary_timestep_data"]

episode._temporary_timestep_data = defaultdict(
list, state["_temporary_timestep_data"] or {}
)
# Validate the episode.
episode.validate()

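For context, a minimal round-trip sketch of the new state API. The dummy data is illustrative only, and the `add_env_reset`/`add_env_step` calls assume the new API stack's episode interface:

```python
from ray.rllib.env.single_agent_episode import SingleAgentEpisode

# Build a tiny episode with dummy data (illustrative only).
episode = SingleAgentEpisode()
episode.add_env_reset(observation=0)
episode.add_env_step(observation=1, action=0, reward=1.0)

# Serialize to a plain, picklable dict of primitives and buffer states ...
state = episode.get_state()
# ... and reconstruct an equivalent episode from it.
restored = SingleAgentEpisode.from_state(state)
assert restored.id_ == episode.id_ and restored.t == episode.t
```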
74 changes: 73 additions & 1 deletion rllib/env/utils/infinite_lookback_buffer.py
@@ -1,10 +1,11 @@
from typing import Any, List, Optional, Union
from typing import Any, Dict, List, Optional, Union

import gymnasium as gym
import numpy as np
import tree # pip install dm_tree

from ray.rllib.utils.numpy import LARGE_INTEGER, one_hot, one_hot_multidiscrete
from ray.rllib.utils.serialization import gym_space_from_dict, gym_space_to_dict
from ray.rllib.utils.spaces.space_utils import (
batch,
get_dummy_batch_for_space,
@@ -34,6 +35,77 @@ def __init__(
self.space_struct = None
self.space = space

def __eq__(
Contributor

dumb question: why do we need this method?

Collaborator Author

Just for testing purposes, to compare instances before writing and after reading them back.

self,
other: "InfiniteLookbackBuffer",
) -> bool:
"""Compares two `InfiniteLookbackBuffers.

Args:
other: Another object. If it is another `InfiniteLookbackBuffer`
instance, all attributes are compared.

Returns:
`True`, if `other` is an `InfiniteLookbackBuffer` instance and all
attributes are identical. Otherwise, returns `False`.
"""
if isinstance(other, InfiniteLookbackBuffer):
if (
self.data == other.data
and self.lookback == other.lookback
and self.finalized == other.finalized
and self.space_struct == other.space_struct
and self.space == other.space
):
return True
return False

def get_state(self) -> Dict[str, Any]:
"""Returns the pickable state of a buffer.

The data in the buffer is stored into a dictionary. Note that
buffers can also be generated from pickable states (see
`InfiniteLookbackBuffer.from_state`)

Returns:
A dict containing all the data and metadata from the buffer.
"""
return {
"data": self.data,
"lookback": self.lookback,
"finalized": self.finalized,
"space_struct": gym_space_to_dict(self.space_struct)
if self.space_struct
else self.space_struct,
"space": gym_space_to_dict(self.space) if self.space else self.space,
}

@staticmethod
def from_state(state: Dict[str, Any]) -> "InfiniteLookbackBuffer":
"""Creates a new `InfiniteLookbackBuffer` from a state dict.

Args:
state: The state dict, as returned by `InfiniteLookbackBuffer.get_state`.

Returns:
A new `InfiniteLookbackBuffer` instance with the data and metadata
from the state dict.
"""
buffer = InfiniteLookbackBuffer()
buffer.data = state["data"]
buffer.lookback = state["lookback"]
buffer.finalized = state["finalized"]
buffer.space_struct = (
gym_space_from_dict(state["space_struct"])
if state["space_struct"]
else state["space_struct"]
)
buffer.space = (
gym_space_from_dict(state["space"]) if state["space"] else state["space"]
)

return buffer

def append(self, item) -> None:
"""Appends the given item to the end of this buffer."""
if self.finalized:
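As the review thread above notes, `__eq__` exists mainly for testing. A short sketch of the intended round-trip check (the constructor arguments are assumed from the buffer's `__init__`):

```python
from ray.rllib.env.utils.infinite_lookback_buffer import InfiniteLookbackBuffer

# A buffer whose first two items form the lookback window.
buffer = InfiniteLookbackBuffer(data=[0, 1, 2, 3], lookback=2)

# `get_state`/`from_state` should round-trip losslessly; `__eq__` compares
# data, lookback, finalized flag, and spaces.
restored = InfiniteLookbackBuffer.from_state(buffer.get_state())
assert restored == buffer
```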
15 changes: 14 additions & 1 deletion rllib/offline/offline_env_runner.py
@@ -1,4 +1,6 @@
import logging
import msgpack
import msgpack_numpy as m
import ray

from pathlib import Path
@@ -13,6 +15,7 @@
from ray.rllib.utils.compression import pack_if_needed
from ray.rllib.utils.spaces.space_utils import to_jsonable_if_needed
from ray.rllib.utils.typing import EpisodeType
from ray.util.debug import log_once

logger = logging.Logger(__file__)

@@ -124,7 +127,17 @@ def sample(

# Add data to the buffers.
if self.output_write_episodes:
self._samples.extend(samples)
if log_once("msgpack"):
logger.info(
"Packing episodes with `msgpack` and encode array with "
"`msg_pack-numpy` for serialization. This is needed for "
"recording episodes."
)
# Note, we serialize episodes with `msgpack` and `msgpack_numpy` to
# ensure version compatibility.
self._samples.extend(
[msgpack.packb(eps.get_state(), default=m.encode) for eps in samples]
)
else:
self._map_episodes_to_data(samples)

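Serializing with `msgpack` plus `msgpack_numpy` (rather than pickle) keeps recorded files readable across Ray versions. A standalone sketch of the packing call used above, with an illustrative state dict:

```python
import msgpack
import msgpack_numpy as m
import numpy as np

# `msgpack` alone cannot serialize numpy arrays; passing `m.encode` as the
# `default` hook converts them into msgpack-compatible dicts.
state = {"data": np.arange(4), "lookback": 0}  # Illustrative state dict.
packed: bytes = msgpack.packb(state, default=m.encode)
```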
9 changes: 8 additions & 1 deletion rllib/offline/offline_prelearner.py
@@ -1,4 +1,6 @@
import gymnasium as gym
import msgpack
import msgpack_numpy as m
import numpy as np
import random
import ray
@@ -143,7 +145,12 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, List[EpisodeType]]

# If we read in episodes directly, we deserialize them from their states.
if self.input_read_episodes:
episodes = batch["item"].tolist()
episodes = [
SingleAgentEpisode.from_state(
msgpack.unpackb(state, object_hook=m.decode)
)
for state in batch["item"]
]
# Otherwise we map the batch to episodes.
else:
episodes = self._map_to_episodes(
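The matching read path, sketched as a standalone helper (the name `row_to_episode` is hypothetical):

```python
import msgpack
import msgpack_numpy as m

from ray.rllib.env.single_agent_episode import SingleAgentEpisode


def row_to_episode(packed: bytes) -> SingleAgentEpisode:
    # Decode numpy arrays via `m.decode`, then rebuild the episode.
    state = msgpack.unpackb(packed, object_hook=m.decode)
    return SingleAgentEpisode.from_state(state)
```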
11 changes: 10 additions & 1 deletion rllib/offline/tests/test_offline_env_runner.py
@@ -1,3 +1,5 @@
import msgpack
import msgpack_numpy as m
import pathlib
import shutil
import unittest
@@ -81,7 +83,14 @@ def test_offline_env_runner_record_episodes(self):
# Assert the dataset has only 100 rows (each row containing an episode).
self.assertEqual(offline_data.data.count(), 100)
# Take a single row and ensure it's a `SingleAgentEpisode` instance.
self.assertIsInstance(offline_data.data.take(1)[0]["item"], SingleAgentEpisode)
self.assertIsInstance(
SingleAgentEpisode.from_state(
msgpack.unpackb(
offline_data.data.take(1)[0]["item"], object_hook=m.decode
)
),
SingleAgentEpisode,
)
# The batch now contains episodes (in a numpy.NDArray).
episodes = offline_data.data.take_batch(100)["item"]
# The batch should contain 100 episodes (not 100 env steps).