Skip to content

Commit

Permalink
[RLlib; Offline RL] Store episodes in state form. (ray-project#47294)
Browse files: browse the repository at this point in the history.
Signed-off-by: ujjawal-khare <ujjawal.khare@dream11.com>
  • Loading branch information
simonsays1980 authored and ujjawal-khare committed Oct 15, 2024
1 parent cbe6687 commit 53e641a
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 10 deletions.
19 changes: 10 additions & 9 deletions rllib/env/utils/infinite_lookback_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,12 @@ def get_state(self) -> Dict[str, Any]:
A dict containing all the data and metadata from the buffer.
"""
return {
"data": to_jsonable_if_needed(self.data, self.space)
if self.space
else self.data,
"data": self.data,
"lookback": self.lookback,
"finalized": self.finalized,
"space_struct": gym_space_to_dict(self.space_struct)
if self.space_struct
else self.space_struct,
"space": gym_space_to_dict(self.space) if self.space else self.space,
}

Expand All @@ -93,16 +94,16 @@ def from_state(state: Dict[str, Any]) -> None:
from the state dict.
"""
buffer = InfiniteLookbackBuffer()
buffer.data = state["data"]
buffer.lookback = state["lookback"]
buffer.finalized = state["finalized"]
buffer.space = gym_space_from_dict(state["space"]) if state["space"] else None
buffer.space_struct = (
get_base_struct_from_space(buffer.space) if buffer.space else None
gym_space_from_dict(state["space_struct"])
if state["space_struct"]
else state["space_struct"]
)
buffer.data = (
from_jsonable_if_needed(state["data"], buffer.space)
if buffer.space
else state["data"]
buffer.space = (
gym_space_from_dict(state["space"]) if state["space"] else state["space"]
)

return buffer
Expand Down
12 changes: 11 additions & 1 deletion rllib/offline/offline_prelearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,17 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, List[EpisodeType]]

# If we directly read in episodes we just convert to list.
if self.input_read_episodes:
episodes = batch["item"].tolist()
# Import `msgpack` for decoding.
import msgpack
import msgpack_numpy as mnp

# Read the episodes and decode them.
episodes = [
SingleAgentEpisode.from_state(
msgpack.unpackb(state, object_hook=mnp.decode)
)
for state in batch["item"]
]
# Else, if we have old stack `SampleBatch`es.
elif self.input_read_sample_batches:
episodes = OfflinePreLearner._map_sample_batch_to_episode(
Expand Down

0 comments on commit 53e641a

Please sign in to comment.