[RLlib; Offline RL] Allow incomplete SampleBatch data and fully compressed observations. (ray-project#48699)

Signed-off-by: hjiang <dentinyhao@gmail.com>
simonsays1980 authored and dentiny committed Dec 7, 2024
1 parent 1b112dc commit ba354ae
Showing 2 changed files with 81 additions and 45 deletions.
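For context, the "fully compressed observations" in the title refers to a whole batch column being packed into a single string, as opposed to each entry being packed individually. A minimal sketch of the two layouts using RLlib's compression helpers (array shape invented; `pack`/`unpack` require the lz4 package):

    import numpy as np
    from ray.rllib.utils.compression import pack, unpack

    obs = np.random.random((5, 4)).astype(np.float32)  # 5 timesteps, 4 features

    # Layout 1: the entire column compressed into one string.
    fully_packed = pack(obs)         # -> str (base64-encoded LZ4 blob)
    restored = unpack(fully_packed)  # -> np.ndarray of shape (5, 4)

    # Layout 2: each entry compressed individually; the column stays an array.
    entry_packed = np.array([pack(row) for row in obs])
    restored_rows = [unpack(entry) for entry in entry_packed]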
124 changes: 80 additions & 44 deletions rllib/offline/offline_prelearner.py
@@ -1,10 +1,10 @@
import gymnasium as gym
import logging
import numpy as np
import random
import uuid

from typing import Any, Dict, List, Optional, Union, Set, Tuple, TYPE_CHECKING

import ray
from ray.actor import ActorHandle
from ray.rllib.core.columns import Columns
from ray.rllib.core.learner import Learner
@@ -86,8 +86,8 @@ def __init__(
self,
config: "AlgorithmConfig",
learner: Union[Learner, list[ActorHandle]],
locality_hints: Optional[List[str]] = None,
spaces: Optional[Tuple[gym.Space, gym.Space]] = None,
locality_hints: Optional[list] = None,
module_spec: Optional[MultiRLModuleSpec] = None,
module_state: Optional[Dict[ModuleID, Any]] = None,
):
@@ -103,24 +103,6 @@ def __init__(
self._module = self._learner._module
# Otherwise we have remote `Learner`s.
else:
# TODO (simon): Check with the data team how to get at
# initialization the data block location.
node_id = ray.get_runtime_context().get_node_id()
# Shuffle indices such that not each data block syncs weights
# with the same learner in case there are multiple learners
# on the same node like the `PreLearner`.
indices = list(range(len(locality_hints)))
random.shuffle(indices)
locality_hints = [locality_hints[i] for i in indices]
learner = [learner[i] for i in indices]
# Choose a learner from the same node.
for i, hint in enumerate(locality_hints):
if hint == node_id:
self._learner = learner[i]
# If no learner has been chosen, there is none on the same node.
if not self._learner:
# Then choose a learner randomly.
self._learner = learner[random.randint(0, len(learner) - 1)]
self.learner_is_remote = True
# Build the module from spec. Note, this will be a MultiRLModule.
self._module = module_spec.build()
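The block removed above implemented locality-aware learner selection. Condensed into a standalone helper (hypothetical name, not part of the diff), the removed pattern was roughly:

    import random
    import ray

    def pick_local_learner(learners, locality_hints):
        # Prefer a learner actor on the same node as this worker; shuffle
        # first so colocated workers don't all sync with the same learner.
        node_id = ray.get_runtime_context().get_node_id()
        paired = list(zip(learners, locality_hints))
        random.shuffle(paired)
        for learner, hint in paired:
            if hint == node_id:
                return learner
        # No learner on this node: fall back to a random one.
        return random.choice(learners)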
@@ -525,21 +507,83 @@ def _map_sample_batch_to_episode(
# TODO (simon): Add support for multi-agent episodes.
raise NotImplementedError
else:
# Unpack observations, if needed.
obs = (
unpack_if_needed(obs.tolist())
if schema[Columns.OBS] in input_compress_columns
else obs.tolist()
)
# Append the last `new_obs` to get the correct length of observations.
obs.append(
unpack_if_needed(batch[schema[Columns.NEXT_OBS]][i][-1])
if schema[Columns.OBS] in input_compress_columns
else batch[schema[Columns.NEXT_OBS]][i][-1]
)
# Unpack observations, if needed. Note, observations could
# be compressed either in their entirety (the complete batch
# column) or individually (each column entry).
if isinstance(obs, str):
# Decompress the observations if we have a string, i.e.
# observations are compressed in their entirety.
obs = unpack_if_needed(obs)
# Convert to a list of per-timestep arrays, the input
# format `SingleAgentEpisode` expects.
obs = [obs[i, ...] for i in range(obs.shape[0])]
# Otherwise, observations are compressed only entry-wise
# inside the batch column (if at all).
elif isinstance(obs, np.ndarray):
# Unpack observations if they are compressed; otherwise,
# simply convert to a list, which is the format
# `SingleAgentEpisode` expects.
obs = (
unpack_if_needed(obs.tolist())
if schema[Columns.OBS] in input_compress_columns
else obs.tolist()
)
else:
raise TypeError(
    f"Unknown observation type: {type(obs)}. When mapping "
    "from old recorded `SampleBatch`es, batched "
    "observations must be either of type `np.ndarray` "
    "or, if the column is compressed, of type `str`."
)
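As a standalone helper (hypothetical, mirroring the dispatch above), the two cases look roughly like this:

    import numpy as np
    from ray.rllib.utils.compression import unpack_if_needed

    def to_obs_list(obs, column_is_compressed):
        if isinstance(obs, str):
            # The whole column is one packed string: decompress, then split
            # into per-timestep arrays as `SingleAgentEpisode` expects.
            arr = unpack_if_needed(obs)
            return [arr[t, ...] for t in range(arr.shape[0])]
        if isinstance(obs, np.ndarray):
            # At most the individual entries are packed.
            return (
                unpack_if_needed(obs.tolist())
                if column_is_compressed
                else obs.tolist()
            )
        raise TypeError(f"Unknown observation type: {type(obs)}.")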

if schema[Columns.NEXT_OBS] in batch:
# Append the last `new_obs` to get the correct length of
# observations.
obs.append(
unpack_if_needed(batch[schema[Columns.NEXT_OBS]][i][-1])
if schema[Columns.OBS] in input_compress_columns
else batch[schema[Columns.NEXT_OBS]][i][-1]
)
else:
# Otherwise we duplicate the last observation.
obs.append(obs[-1])

# Check whether the batch provides `terminateds`,
# `truncateds`, or a legacy `done` column.
if (
schema[Columns.TRUNCATEDS] in batch
and schema[Columns.TERMINATEDS] in batch
):
truncated = batch[schema[Columns.TRUNCATEDS]][i][-1]
terminated = batch[schema[Columns.TERMINATEDS]][i][-1]
elif (
schema[Columns.TRUNCATEDS] in batch
and schema[Columns.TERMINATEDS] not in batch
):
truncated = batch[schema[Columns.TRUNCATEDS]][i][-1]
terminated = False
elif (
schema[Columns.TRUNCATEDS] not in batch
and schema[Columns.TERMINATEDS] in batch
):
terminated = batch[schema[Columns.TERMINATEDS]][i][-1]
truncated = False
elif "done" in batch:
terminated = batch["done"][i][-1]
truncated = False
# Otherwise, if neither `terminateds` nor `truncateds` nor
# `done` is given, we consider the episode terminated.
else:
terminated = True
truncated = False
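The four branches reduce to a precedence rule: explicit `terminateds`/`truncateds` flags win, then a legacy `done`, and a batch with none of them is treated as terminated. A hypothetical condensation (helper name invented):

    def resolve_done_flags(batch, i, terminateds="terminateds", truncateds="truncateds"):
        if terminateds in batch or truncateds in batch:
            terminated = bool(batch[terminateds][i][-1]) if terminateds in batch else False
            truncated = bool(batch[truncateds][i][-1]) if truncateds in batch else False
        elif "done" in batch:
            terminated, truncated = bool(batch["done"][i][-1]), False
        else:
            # No signal at all: treat the episode as terminated.
            terminated, truncated = True, False
        return terminated, truncated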

# Create a `SingleAgentEpisode`.
episode = SingleAgentEpisode(
id_=str(batch[schema[Columns.EPS_ID]][i][0]),
# If the recorded episode has an ID, use it; otherwise
# generate a new one.
id_=str(batch[schema[Columns.EPS_ID]][i][0])
if schema[Columns.EPS_ID] in batch
else uuid.uuid4().hex,
agent_id=agent_id,
observations=obs,
infos=(
@@ -554,16 +598,8 @@ def _map_sample_batch_to_episode(
else batch[schema[Columns.ACTIONS]][i]
),
rewards=batch[schema[Columns.REWARDS]][i],
terminated=(
any(batch[schema[Columns.TERMINATEDS]][i])
if schema[Columns.TERMINATEDS] in batch
else any(batch["dones"][i])
),
truncated=(
any(batch[schema[Columns.TRUNCATEDS]][i])
if schema[Columns.TRUNCATEDS] in batch
else False
),
terminated=terminated,
truncated=truncated,
# TODO (simon): Results in zero-length episodes in connector.
# t_started=batch[Columns.T if Columns.T in batch else
# "unroll_id"][i][0],
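Taken together, the mapping yields one `SingleAgentEpisode` per `SampleBatch` row, with observations carrying exactly one more entry than actions and rewards. A minimal construction sketch (all values invented):

    import uuid
    import numpy as np
    from ray.rllib.env.single_agent_episode import SingleAgentEpisode

    T = 3
    obs = [np.zeros(4, dtype=np.float32) for _ in range(T + 1)]  # T + 1 observations
    episode = SingleAgentEpisode(
        id_=uuid.uuid4().hex,     # the fallback used when the batch has no eps_id
        observations=obs,
        actions=[0, 1, 0],        # T actions ...
        rewards=[1.0, 0.5, 2.0],  # ... and T rewards
        terminated=True,
        truncated=False,
        len_lookback_buffer=0,
    )
    assert len(episode) == T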
2 changes: 1 addition & 1 deletion rllib/offline/tests/test_offline_data.py
@@ -124,7 +124,7 @@ def test_sample_multiple_learners(self):
num_samples=10, return_iterator=2, num_shards=2
)
self.assertIsInstance(batch, list)
# Ensure we have indeed two such `SStreamSplitDataIterator` instances.
# Ensure we have indeed two such `StreamSplitDataIterator` instances.
self.assertEqual(len(batch), 2)
from ray.data._internal.iterator.stream_split_iterator import (
StreamSplitDataIterator,
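The fixed comment refers to Ray Data's streaming split: requesting two shards yields two `StreamSplitDataIterator`s over the same dataset. A rough sketch of the mechanism the test exercises (toy dataset; the tie to `OfflineData.sample(..., return_iterator=2, num_shards=2)` is as implied by the test):

    import ray
    from ray.data._internal.iterator.stream_split_iterator import (
        StreamSplitDataIterator,
    )

    ds = ray.data.range(10)
    # streaming_split() returns one DataIterator per shard.
    iterators = ds.streaming_split(n=2)
    assert all(isinstance(it, StreamSplitDataIterator) for it in iterators)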
