[RLlib - Offline RL] Allow incomplete SampleBatch data and fully compressed observations. #48699

Merged
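The changes below let the `OfflinePreLearner` accept recorded `SampleBatch` data that lacks columns such as `new_obs`, `terminateds`/`truncateds`, or `eps_id`, and observation columns that were compressed either in their entirety (one compressed string per column) or entry by entry. A minimal sketch of the two compression layouts, assuming `pack` and `unpack_if_needed` from `ray.rllib.utils.compression`; the toy shapes and variable names are illustrative only, not part of the PR:

import numpy as np
from ray.rllib.utils.compression import pack, unpack_if_needed

# Toy per-episode observation column: 5 timesteps of 4-dim observations.
episode_obs = np.random.random((5, 4)).astype(np.float32)

# Case 1: the entire column is one compressed string. Unpacking restores
# the full array, which is then split into per-timestep arrays (the list
# format `SingleAgentEpisode` expects).
whole_column = pack(episode_obs)
obs = unpack_if_needed(whole_column)
obs_list = [obs[t, ...] for t in range(obs.shape[0])]

# Case 2: each entry is compressed individually (or not at all);
# `unpack_if_needed` is a no-op on uncompressed entries.
per_entry = [pack(o) for o in episode_obs]
obs_list = [unpack_if_needed(o) for o in per_entry]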
124 changes: 80 additions & 44 deletions rllib/offline/offline_prelearner.py
@@ -1,10 +1,10 @@
import gymnasium as gym
import logging
import numpy as np
import random
import uuid

from typing import Any, Dict, List, Optional, Union, Set, Tuple, TYPE_CHECKING

import ray
from ray.actor import ActorHandle
from ray.rllib.core.columns import Columns
from ray.rllib.core.learner import Learner
@@ -86,8 +86,8 @@ def __init__(
self,
config: "AlgorithmConfig",
learner: Union[Learner, list[ActorHandle]],
locality_hints: Optional[List[str]] = None,
spaces: Optional[Tuple[gym.Space, gym.Space]] = None,
locality_hints: Optional[list] = None,
module_spec: Optional[MultiRLModuleSpec] = None,
module_state: Optional[Dict[ModuleID, Any]] = None,
):
@@ -103,24 +103,6 @@ def __init__(
self._module = self._learner._module
# Otherwise we have remote `Learner`s.
else:
# TODO (simon): Check with the data team how to get at
# initialization the data block location.
node_id = ray.get_runtime_context().get_node_id()
# Shuffle indices such that not each data block syncs weights
# with the same learner in case there are multiple learners
# on the same node like the `PreLearner`.
indices = list(range(len(locality_hints)))
random.shuffle(indices)
locality_hints = [locality_hints[i] for i in indices]
learner = [learner[i] for i in indices]
# Choose a learner from the same node.
for i, hint in enumerate(locality_hints):
if hint == node_id:
self._learner = learner[i]
# If no learner has been chosen, there is none on the same node.
if not self._learner:
# Then choose a learner randomly.
self._learner = learner[random.randint(0, len(learner) - 1)]
self.learner_is_remote = True
# Build the module from spec. Note, this will be a MultiRLModule.
self._module = module_spec.build()
@@ -525,21 +507,83 @@ def _map_sample_batch_to_episode(
# TODO (simon): Add support for multi-agent episodes.
raise NotImplementedError
else:
# Unpack observations, if needed.
obs = (
unpack_if_needed(obs.tolist())
if schema[Columns.OBS] in input_compress_columns
else obs.tolist()
)
# Append the last `new_obs` to get the correct length of observations.
obs.append(
unpack_if_needed(batch[schema[Columns.NEXT_OBS]][i][-1])
if schema[Columns.OBS] in input_compress_columns
else batch[schema[Columns.NEXT_OBS]][i][-1]
)
# Unpack observations, if needed. Note, observations can be
# compressed either in their entirety (the complete batch
# column) or individually (each column entry).
if isinstance(obs, str):
# Decompress the observations if we have a string, i.e.
# observations are compressed in their entirety.
obs = unpack_if_needed(obs)
# Convert to a list of arrays. This is needed as input by
# the `SingleAgentEpisode`.
obs = [obs[i, ...] for i in range(obs.shape[0])]
# Otherwise, observations are compressed only inside the
# batch column (if at all).
elif isinstance(obs, np.ndarray):
# Unpack observations if they are compressed; otherwise,
# simply convert to a list, as required by the
# `SingleAgentEpisode`.
obs = (
unpack_if_needed(obs.tolist())
if schema[Columns.OBS] in input_compress_columns
else obs.tolist()
)
else:
raise TypeError(
f"Unknown observation type: {type(obs)}. When mapping "
"from old recorded `SampleBatches` batched "
"observations should be either of type `np.array` "
"or - if the column is compressed - of `str` type."
)

if schema[Columns.NEXT_OBS] in batch:
# Append the last `new_obs` to get the correct length of
# observations.
obs.append(
unpack_if_needed(batch[schema[Columns.NEXT_OBS]][i][-1])
if schema[Columns.OBS] in input_compress_columns
else batch[schema[Columns.NEXT_OBS]][i][-1]
)
else:
# Otherwise we duplicate the last observation.
obs.append(obs[-1])

# Check whether `terminated`, `truncated`, or legacy `done`
# flags are present in the batch.
if (
schema[Columns.TRUNCATEDS] in batch
and schema[Columns.TERMINATEDS] in batch
):
truncated = batch[schema[Columns.TRUNCATEDS]][i][-1]
terminated = batch[schema[Columns.TERMINATEDS]][i][-1]
elif (
schema[Columns.TRUNCATEDS] in batch
and schema[Columns.TERMINATEDS] not in batch
):
truncated = batch[schema[Columns.TRUNCATEDS]][i][-1]
terminated = False
elif (
schema[Columns.TRUNCATEDS] not in batch
and schema[Columns.TERMINATEDS] in batch
):
terminated = batch[schema[Columns.TERMINATEDS]][i][-1]
truncated = False
elif "done" in batch:
terminated = batch["done"][i][-1]
truncated = False
# Otherwise, if neither `terminated` nor `truncated` nor `done`
# is given, we consider the episode terminated.
else:
terminated = True
truncated = False

# Create a `SingleAgentEpisode`.
episode = SingleAgentEpisode(
id_=str(batch[schema[Columns.EPS_ID]][i][0]),
# If the recorded episode has an ID, we use it;
# otherwise we generate a new one.
id_=str(batch[schema[Columns.EPS_ID]][i][0])
if schema[Columns.EPS_ID] in batch
else uuid.uuid4().hex,
agent_id=agent_id,
observations=obs,
infos=(
@@ -554,16 +598,8 @@
else batch[schema[Columns.ACTIONS]][i]
),
rewards=batch[schema[Columns.REWARDS]][i],
terminated=(
any(batch[schema[Columns.TERMINATEDS]][i])
if schema[Columns.TERMINATEDS] in batch
else any(batch["dones"][i])
),
truncated=(
any(batch[schema[Columns.TRUNCATEDS]][i])
if schema[Columns.TRUNCATEDS] in batch
else False
),
terminated=terminated,
truncated=truncated,
# TODO (simon): Results in zero-length episodes in connector.
# t_started=batch[Columns.T if Columns.T in batch else
# "unroll_id"][i][0],
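For batches that carry no `terminateds`/`truncateds` columns, the mapping above falls back to a legacy `done` column and, if that is missing too, treats the episode as terminated. A condensed, stand-alone sketch of that precedence (the helper name `resolve_done_flags` and the plain-dict row layout are assumptions for illustration; the PR inlines this logic in `_map_sample_batch_to_episode`):

from typing import Dict, List, Tuple

def resolve_done_flags(row: Dict[str, List]) -> Tuple[bool, bool]:
    # Return (terminated, truncated) for one recorded episode row.
    terminateds = row.get("terminateds")
    truncateds = row.get("truncateds")
    if terminateds is not None or truncateds is not None:
        # New-style columns win; a missing counterpart defaults to False.
        terminated = bool(terminateds[-1]) if terminateds is not None else False
        truncated = bool(truncateds[-1]) if truncateds is not None else False
    elif "done" in row:
        # Legacy `SampleBatch`es carry only a single `done` flag.
        terminated = bool(row["done"][-1])
        truncated = False
    else:
        # No flag at all: consider the episode terminated.
        terminated, truncated = True, False
    return terminated, truncated

# Example: a legacy batch with only a `done` column.
assert resolve_done_flags({"done": [0, 0, 1]}) == (True, False)

Collapsing the four `elif` branches into presence checks keeps the precedence explicit: new-style flags first, then the legacy `done`, then the terminated default.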
2 changes: 1 addition & 1 deletion rllib/offline/tests/test_offline_data.py
@@ -124,7 +124,7 @@ def test_sample_multiple_learners(self):
num_samples=10, return_iterator=2, num_shards=2
)
self.assertIsInstance(batch, list)
# Ensure we have indeed two such `SStreamSplitDataIterator` instances.
# Ensure we have indeed two such `StreamSplitDataIterator` instances.
self.assertEqual(len(batch), 2)
from ray.data._internal.iterator.stream_split_iterator import (
StreamSplitDataIterator,