Skip to content

Commit

Permalink
[RLlib; Offline RL] - Enable reading old-stack SampleBatch data in …
Browse files Browse the repository at this point in the history
…new stack Offline RL. (ray-project#47359)

Signed-off-by: ujjawal-khare <ujjawal.khare@dream11.com>
  • Loading branch information
simonsays1980 authored and ujjawal-khare committed Oct 15, 2024
1 parent 063d004 commit 9daf658
Show file tree
Hide file tree
Showing 6 changed files with 189 additions and 97 deletions.
14 changes: 1 addition & 13 deletions rllib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1697,19 +1697,7 @@ py_test(
data = [
"tests/data/cartpole/cartpole-v1_large",
"tests/data/cartpole/large.json",
]
)

py_test(
name = "test_offline_prelearner",
tags = ["team:rllib", "offline"],
size = "small",
srcs = ["offline/tests/test_offline_prelearner.py"],
# Include the offline data files.
data = [
"tests/data/cartpole/cartpole-v1_large",
"tests/data/cartpole/large.json",
]
],
)

# --------------------------------------------------------------------
Expand Down
25 changes: 5 additions & 20 deletions rllib/algorithms/algorithm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,8 +449,6 @@ def __init__(self, algo_class: Optional[type] = None):
self.input_read_schema = {}
self.input_read_episodes = False
self.input_read_sample_batches = False
self.input_filesystem = None
self.input_filesystem_kwargs = {}
self.input_compress_columns = [Columns.OBS, Columns.NEXT_OBS]
self.input_spaces_jsonable = True
self.materialize_data = False
Expand Down Expand Up @@ -2555,8 +2553,6 @@ def offline_data(
input_read_schema: Optional[Dict[str, str]] = NotProvided,
input_read_episodes: Optional[bool] = NotProvided,
input_read_sample_batches: Optional[bool] = NotProvided,
input_filesystem: Optional[str] = NotProvided,
input_filesystem_kwargs: Optional[Dict] = NotProvided,
input_compress_columns: Optional[List[str]] = NotProvided,
materialize_data: Optional[bool] = NotProvided,
materialize_mapped_data: Optional[bool] = NotProvided,
Expand Down Expand Up @@ -2624,8 +2620,8 @@ def offline_data(
inside of RLlib's schema. The other format is a columnar format and is
agnostic to the RL framework used. Use the latter format, if you are
unsure when to use the data or in which RL framework. The default is
to read column data, i.e. False. `input_read_episodes` and
`input_read_sample_batches` cannot be True at the same time. See
to read column data, i.e. `False`. `input_read_episodes` and
`input_read_sample_batches` cannot be `True` at the same time. See
also `output_write_episodes` to define the output data format when
recording.
input_read_sample_batches: Whether offline data is stored in RLlib's old
Expand All @@ -2634,16 +2630,9 @@ def offline_data(
data needs extra transforms and might not concatenate episode chunks
contained in different `SampleBatch`es in the data. If possible avoid
to read `SampleBatch`es and convert them in a controlled form into
RLlib's `EpisodeType` (i.e. `SingleAgentEpisode` or
`MultiAgentEpisode`). The default is False. `input_read_episodes`
and `input_read_sample_batches` cannot be True at the same time.
input_filesystem: A cloud filesystem to handle access to cloud storage when
reading experiences. Should be either "gcs" for Google Cloud Storage,
"s3" for AWS S3 buckets, or "abs" for Azure Blob Storage.
input_filesystem_kwargs: A dictionary holding the kwargs for the filesystem
given by `input_filesystem`. See `gcsfs.GCSFilesystem` for GCS,
`pyarrow.fs.S3FileSystem`, for S3, and `ablfs.AzureBlobFilesystem` for
ABS filesystem arguments.
RLlib's `EpisodeType`s (i.e. `SingleAgentEpisode` or
`MultiAgentEpisode`). The default is `False`. `input_read_episodes`
and `input_read_sample_batches` cannot be `True` at the same time.
input_compress_columns: What input columns are compressed with LZ4 in the
input data. If data is stored in RLlib's `SingleAgentEpisode` (
`MultiAgentEpisode` not supported, yet). Note the providing
Expand Down Expand Up @@ -2770,10 +2759,6 @@ def offline_data(
self.input_read_episodes = input_read_episodes
if input_read_sample_batches is not NotProvided:
self.input_read_sample_batches = input_read_sample_batches
if input_filesystem is not NotProvided:
self.input_filesystem = input_filesystem
if input_filesystem_kwargs is not NotProvided:
self.input_filesystem_kwargs = input_filesystem_kwargs
if input_compress_columns is not NotProvided:
self.input_compress_columns = input_compress_columns
if materialize_data is not NotProvided:
Expand Down
49 changes: 31 additions & 18 deletions rllib/offline/offline_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(self, config: AlgorithmConfig):
else Path(config.input_)
)
# Use `read_parquet` as default data read method.
self.data_read_method = self.config.input_read_method
self.data_read_method = config.input_read_method
# Override default arguments for the data read method.
self.data_read_method_kwargs = (
self.default_read_method_kwargs | self.config.input_read_method_kwargs
Expand Down Expand Up @@ -114,23 +114,36 @@ def sample(
return_iterator: bool = False,
num_shards: int = 1,
):
# Materialize the mapped data, if necessary. This runs for all the
# data the `OfflinePreLearner` logic and maps them to `MultiAgentBatch`es.
# TODO (simon, sven): This would never update the module nor
# the connectors. If this is needed we have to check, if we give
# (a) only an iterator and let the learner and OfflinePreLearner
# communicate through the object storage. This only works when
# not materializing.
# (b) Rematerialize the data every couple of iterations. This
# is costly.
if not self.data_is_mapped:
# Constructor `kwargs` for the `OfflinePreLearner`.
fn_constructor_kwargs = {
"config": self.config,
"learner": self.learner_handles[0],
"spaces": self.spaces[INPUT_ENV_SPACES],
}
# If we have multiple learners, add to the constructor `kwargs`.
if (
not return_iterator
or return_iterator
and num_shards <= 1
and not self.batch_iterator
):
# If no iterator should be returned, or if we want to return a single
# batch iterator, we instantiate the batch iterator once, here.
# TODO (simon, sven): The iterator depends on the `num_samples`, i.e.
# sampling later with a different batch size would need a
# reinstantiation of the iterator.

self.batch_iterator = self.data.map_batches(
self.prelearner_class,
fn_constructor_kwargs={
"config": self.config,
"learner": self.learner_handles[0],
"spaces": self.spaces[INPUT_ENV_SPACES],
},
batch_size=num_samples,
**self.map_batches_kwargs,
).iter_batches(
batch_size=num_samples,
**self.iter_batches_kwargs,
)

# Do we want to return an iterator or a single batch?
if return_iterator:
# In case of multiple shards, we return multiple
# `StreamingSplitIterator` instances.
if num_shards > 1:
# Call here the learner to get an up-to-date module state.
# TODO (simon): This is a workaround as long as learners cannot
Expand Down
138 changes: 110 additions & 28 deletions rllib/offline/offline_prelearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,42 +168,16 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, List[EpisodeType]]

# If we directly read in episodes we just convert to list.
if self.input_read_episodes:
# Import `msgpack` for decoding.
import msgpack
import msgpack_numpy as mnp

# Read the episodes and decode them.
episodes = [
SingleAgentEpisode.from_state(
msgpack.unpackb(state, object_hook=mnp.decode)
)
for state in batch["item"]
]
self.episode_buffer.add(episodes)
episodes = self.episode_buffer.sample(
num_items=self.config.train_batch_size_per_learner,
# TODO (simon): This can be removed as soon as DreamerV3 has been
# cleaned up, i.e. can use episode samples for training.
sample_episodes=True,
finalize=True,
)
episodes = batch["item"].tolist()
# Else, if we have old stack `SampleBatch`es.
elif self.input_read_sample_batches:
episodes = OfflinePreLearner._map_sample_batch_to_episode(
self._is_multi_agent,
batch,
finalize=True,
finalize=False,
schema=SCHEMA | self.config.input_read_schema,
input_compress_columns=self.config.input_compress_columns,
)["episodes"]
self.episode_buffer.add(episodes)
episodes = self.episode_buffer.sample(
num_items=self.config.train_batch_size_per_learner,
# TODO (simon): This can be removed as soon as DreamerV3 has been
# cleaned up, i.e. can use episode samples for training.
sample_episodes=True,
finalize=True,
)
# Otherwise we map the batch to episodes.
else:
episodes = self._map_to_episodes(
Expand Down Expand Up @@ -532,3 +506,111 @@ def _map_sample_batch_to_episode(
episodes.append(episode)
# Note, `map_batches` expects a `Dict` as return value.
return {"episodes": episodes}

def _map_sample_batch_to_episode(
    is_multi_agent: bool,
    batch: Dict[str, Union[list, np.ndarray]],
    schema: Dict[str, str] = SCHEMA,
    finalize: bool = False,
    input_compress_columns: Optional[List[str]] = None,
) -> Dict[str, List[EpisodeType]]:
    """Maps an old stack `SampleBatch` to new stack episodes.

    Each row `i` of the `ray.data` batch is treated as one `SampleBatch` and is
    converted into exactly one `SingleAgentEpisode`. The observations of a row
    are extended by the last entry of the row's next-observation column, so
    that the episode holds one more observation than actions/rewards.

    Args:
        is_multi_agent: Whether the data stems from a multi-agent setup.
            Multi-agent conversion is not implemented.
        batch: A `ray.data` batch (a dict) mapping column names to per-row
            sequences.
        schema: Maps RLlib `Columns` names to the column names used in `batch`.
        finalize: If True, `finalize()` is called on each created episode.
        input_compress_columns: Names of columns whose values are passed
            through `unpack_if_needed`. `None` is treated as an empty list.

    Returns:
        A dict with the single key "episodes" holding the list of created
        episodes (`map_batches` expects a `Dict` as return value).

    Raises:
        NotImplementedError: If `is_multi_agent` is True.
    """
    # TODO (simon): Add support for multi-agent episodes. Note, the old stack
    # uses "agent_index" instead of "agent_id"; check also what happens with
    # the module ID.
    if is_multi_agent:
        # The original code only evaluated the bare name `NotImplementedError`
        # (a no-op) and then continued with an unbound/stale `episode`.
        raise NotImplementedError(
            "Converting old stack `SampleBatch` data to multi-agent episodes "
            "is not supported, yet."
        )

    # Set `input_compress_columns` to an empty `list` if `None`.
    input_compress_columns = input_compress_columns or []

    # TODO (simon): Check, if needed. It could possibly happen that a batch
    # contains data from different episodes. Merging and resplitting the batch
    # would then be the solution.
    # Check, if batch comes actually from multiple episodes.
    # episode_begin_indices = np.where(np.diff(np.hstack(batch["eps_id"])) != 0) + 1

    # Define a container to collect episodes.
    episodes = []
    # Loop over `SampleBatch`es in the `ray.data` batch (a dict).
    for i, obs in enumerate(batch[schema[Columns.OBS]]):
        # Unpack observations, if needed.
        obs = (
            unpack_if_needed(obs.tolist())
            if schema[Columns.OBS] in input_compress_columns
            else obs.tolist()
        )
        # Append the last `new_obs` to get the correct length of observations.
        # NOTE(review): Unpacking of the next observation is keyed on the OBS
        # column (not NEXT_OBS) being listed as compressed - confirm intended.
        obs.append(
            unpack_if_needed(batch[schema[Columns.NEXT_OBS]][i][-1])
            if schema[Columns.OBS] in input_compress_columns
            else batch[schema[Columns.NEXT_OBS]][i][-1]
        )
        # Create a `SingleAgentEpisode`.
        episode = SingleAgentEpisode(
            id_=batch[schema[Columns.EPS_ID]][i][0],
            agent_id=None,
            observations=obs,
            infos=(
                batch[schema[Columns.INFOS]][i]
                if schema[Columns.INFOS] in batch
                # One distinct dict per observation; `[{}] * n` would alias a
                # single dict object across all timesteps.
                else [{} for _ in range(len(obs))]
            ),
            # Actions might be (a) serialized. We unserialize them here.
            actions=(
                unpack_if_needed(batch[schema[Columns.ACTIONS]][i])
                if Columns.ACTIONS in input_compress_columns
                else batch[schema[Columns.ACTIONS]][i]
            ),
            rewards=batch[schema[Columns.REWARDS]][i],
            # Fall back to the "dones" column if there is no terminateds
            # column in the batch.
            terminated=(
                any(batch[schema[Columns.TERMINATEDS]][i])
                if schema[Columns.TERMINATEDS] in batch
                else any(batch["dones"][i])
            ),
            truncated=(
                any(batch[schema[Columns.TRUNCATEDS]][i])
                if schema[Columns.TRUNCATEDS] in batch
                else False
            ),
            # TODO (simon): Results in zero-length episodes in connector.
            # t_started=batch[Columns.T if Columns.T in batch else
            # "unroll_id"][i][0],
            # TODO (simon): Single-dimensional columns are not supported.
            # Extra model outputs might be serialized. We unserialize them here
            # if needed.
            # TODO (simon): Check, if we need here also reconversion from
            # JSONable in case of composite spaces.
            # Every column that is neither a schema key, nor a schema value,
            # nor one of the listed old stack bookkeeping columns is passed
            # on as an extra model output.
            extra_model_outputs={
                k: unpack_if_needed(v[i]) if k in input_compress_columns else v[i]
                for k, v in batch.items()
                if (
                    k not in schema
                    and k not in schema.values()
                    and k not in ["dones", "agent_index", "type"]
                )
            },
            len_lookback_buffer=0,
        )
        # Finalize, if necessary.
        # TODO (simon, sven): Check, if we should convert all data to lists
        # before. Right now only obs are lists.
        if finalize:
            episode.finalize()
        episodes.append(episode)
    # Note, `map_batches` expects a `Dict` as return value.
    return {"episodes": episodes}
58 changes: 41 additions & 17 deletions rllib/offline/tests/test_offline_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,23 +38,47 @@ def test_offline_data_load(self):
def test_sample_single_learner(self):
"""Tests using sampling using a single learner."""

# Create a simple config.
config = (
BCConfig()
.environment("CartPole-v1")
.api_stack(
enable_env_runner_and_connector_v2=True,
enable_rl_module_and_learner=True,
)
.offline_data(
input_=[self.data_path],
dataset_num_iters_per_learner=1,
)
.learners(
num_learners=0,
)
.training(
train_batch_size_per_learner=256,
config = AlgorithmConfig().offline_data(
input_=[self.data_path],
)

offline_data = OfflineData(config)

batch = offline_data.data.take_batch(batch_size=10)
episodes = OfflinePreLearner._map_to_episodes(False, batch)["episodes"]

self.assertTrue(len(episodes) == 10)
self.assertTrue(isinstance(episodes[0], SingleAgentEpisode))

def test_offline_convert_from_old_sample_batch_to_episodes(self):
    """Tests converting old stack `SampleBatch` JSON data to episodes.

    Reads ten rows from the old stack CartPole JSON file via `read_json` and
    checks that each row is mapped to one `SingleAgentEpisode`.
    """
    # The data directory lives two levels above this test file's directory.
    base_path = Path(__file__).parents[2]
    sample_batch_data_path = base_path / "tests/data/cartpole/large.json"
    config = AlgorithmConfig().offline_data(
        input_=["local://" + sample_batch_data_path.as_posix()],
        input_read_method="read_json",
        input_read_sample_batches=True,
    )

    offline_data = OfflineData(config)

    batch = offline_data.data.take_batch(batch_size=10)
    episodes = OfflinePreLearner._map_sample_batch_to_episode(False, batch)[
        "episodes"
    ]

    # Dedicated assertions report the actual values on failure, unlike
    # `assertTrue(<bool expression>)`.
    self.assertEqual(len(episodes), 10)
    self.assertIsInstance(episodes[0], SingleAgentEpisode)

def test_sample(self):

config = AlgorithmConfig().offline_data(input_=[self.data_path])

offline_data = OfflineData(config)

batch_iterator = offline_data.data.map_batches(
functools.partial(
OfflinePreLearner._map_to_episodes, offline_data.is_multi_agent
)
)

Expand Down
2 changes: 1 addition & 1 deletion rllib/tuned_examples/bc/cartpole_bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@
)

stop = {
f"{EVALUATION_RESULTS}/{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": 350.0,
f"{EVALUATION_RESULTS}/{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": 120.0,
TRAINING_ITERATION: 350,
}

Expand Down

0 comments on commit 9daf658

Please sign in to comment.