Skip to content

Commit

Permalink
[RLlib; Offline RL] - Enable reading old-stack SampleBatch data in …
Browse files Browse the repository at this point in the history
…new stack Offline RL. (ray-project#47359)

Signed-off-by: ujjawal-khare <ujjawal.khare@dream11.com>
  • Loading branch information
simonsays1980 authored and ujjawal-khare committed Oct 12, 2024
1 parent 3557f98 commit a97f121
Show file tree
Hide file tree
Showing 8 changed files with 182 additions and 13 deletions.
1 change: 1 addition & 0 deletions rllib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1867,6 +1867,7 @@ py_test(
srcs = ["offline/tests/test_offline_data.py"],
data = [
"tests/data/cartpole/cartpole-v1_large",
"tests/data/cartpole/large.json",
],
)

Expand Down
19 changes: 17 additions & 2 deletions rllib/algorithms/algorithm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,7 @@ def __init__(self, algo_class: Optional[type] = None):
self.input_read_method_kwargs = {}
self.input_read_schema = {}
self.input_read_episodes = False
self.input_read_sample_batches = False
self.input_compress_columns = [Columns.OBS, Columns.NEXT_OBS]
self.input_spaces_jsonable = True
self.map_batches_kwargs = {}
Expand Down Expand Up @@ -2404,6 +2405,7 @@ def offline_data(
input_read_method_kwargs: Optional[Dict] = NotProvided,
input_read_schema: Optional[Dict[str, str]] = NotProvided,
input_read_episodes: Optional[bool] = NotProvided,
input_read_sample_batches: Optional[bool] = NotProvided,
input_compress_columns: Optional[List[str]] = NotProvided,
map_batches_kwargs: Optional[Dict] = NotProvided,
iter_batches_kwargs: Optional[Dict] = NotProvided,
Expand Down Expand Up @@ -2465,8 +2467,19 @@ def offline_data(
inside of RLlib's schema. The other format is a columnar format and is
agnostic to the RL framework used. Use the latter format, if you are
unsure when to use the data or in which RL framework. The default is
to read column data, i.e. `False`. See also `output_write_episodes`
to define the output data format when recording.
to read column data, i.e. `False`. `input_read_episodes` and
`inpuit_read_sample_batches` cannot be `True` at the same time. See
also `output_write_episodes` to define the output data format when
recording.
input_read_sample_batches: Whether offline data is stored in RLlib's old
stack `SampleBatch` type. This is usually the case for older data
recorded with RLlib in JSON line format. Reading in `SampleBatch`
data needs extra transforms and might not concatenate episode chunks
contained in different `SampleBatch`es in the data. If possible avoid
to read `SampleBatch`es and convert them in a controlled form into
RLlib`s `EpisodeType`s (i.e. `SingleAgentEpisode` or
`MultiAgentEpisode`). The default is `False`. `input_read_episodes`
and `inpuit_read_sample_batches` cannot be `True` at the same time.
input_compress_columns: What input columns are compressed with LZ4 in the
input data. If data is stored in `RLlib`'s `SingleAgentEpisode` (
`MultiAgentEpisode` not supported, yet). Note,
Expand Down Expand Up @@ -2566,6 +2579,8 @@ def offline_data(
self.input_read_schema = input_read_schema
if input_read_episodes is not NotProvided:
self.input_read_episodes = input_read_episodes
if input_read_sample_batches is not NotProvided:
self.input_read_sample_batches = input_read_sample_batches
if input_compress_columns is not NotProvided:
self.input_compress_columns = input_compress_columns
if map_batches_kwargs is not NotProvided:
Expand Down
9 changes: 9 additions & 0 deletions rllib/algorithms/marwil/marwil_offline_prelearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,15 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, MultiAgentBatch]:
# If we directly read in episodes we just convert to list.
if self.input_read_episodes:
episodes = batch["item"].tolist()
# Else, if we have old stack `SampleBatch`es.
elif self.input_read_sample_batches:
episodes = OfflinePreLearner._map_sample_batch_to_episode(
self._is_multi_agent,
batch,
finalize=False,
schema=SCHEMA | self.config.input_read_schema,
input_compress_columns=self.config.input_compress_columns,
)["episodes"]
# Otherwise we ap the batch to episodes.
else:
# Map the batch to episodes.
Expand Down
6 changes: 4 additions & 2 deletions rllib/offline/offline_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.core import COMPONENT_RL_MODULE
from ray.rllib.env import INPUT_ENV_SPACES
from ray.rllib.offline.offline_prelearner import OfflinePreLearner
from ray.rllib.utils.annotations import (
ExperimentalAPI,
Expand All @@ -24,7 +25,7 @@ def __init__(self, config: AlgorithmConfig):
self.path = (
config.input_ if isinstance(config.input_, list) else Path(config.input_)
)
# Use `read_json` as default data read method.
# Use `read_parquet` as default data read method.
self.data_read_method = config.input_read_method
# Override default arguments for the data read method.
self.data_read_method_kwargs = (
Expand Down Expand Up @@ -72,12 +73,13 @@ def sample(
# TODO (simon, sven): The iterator depends on the `num_samples`, i.e.abs
# sampling later with a different batch size would need a
# reinstantiation of the iterator.

self.batch_iterator = self.data.map_batches(
self.prelearner_class,
fn_constructor_kwargs={
"config": self.config,
"learner": self.learner_handles[0],
"spaces": self.spaces["__env__"],
"spaces": self.spaces[INPUT_ENV_SPACES],
},
batch_size=num_samples,
**self.map_batches_kwargs,
Expand Down
132 changes: 127 additions & 5 deletions rllib/offline/offline_prelearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def __init__(

self.config = config
self.input_read_episodes = self.config.input_read_episodes
self.input_read_sample_batches = self.config.input_read_sample_batches
# We need this learner to run the learner connector pipeline.
# If it is a `Learner` instance, the `Learner` is local.
if isinstance(learner, Learner):
Expand Down Expand Up @@ -144,7 +145,16 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, List[EpisodeType]]
# If we directly read in episodes we just convert to list.
if self.input_read_episodes:
episodes = batch["item"].tolist()
# Otherwise we ap the batch to episodes.
# Else, if we have old stack `SampleBatch`es.
elif self.input_read_sample_batches:
episodes = OfflinePreLearner._map_sample_batch_to_episode(
self._is_multi_agent,
batch,
finalize=False,
schema=SCHEMA | self.config.input_read_schema,
input_compress_columns=self.config.input_compress_columns,
)["episodes"]
# Otherwise we map the batch to episodes.
else:
episodes = self._map_to_episodes(
self._is_multi_agent,
Expand Down Expand Up @@ -227,7 +237,7 @@ def _should_module_be_updated(self, module_id, multi_agent_batch=None):
@staticmethod
def _map_to_episodes(
is_multi_agent: bool,
batch: Dict[str, np.ndarray],
batch: Dict[str, Union[list, np.ndarray]],
schema: Dict[str, str] = SCHEMA,
finalize: bool = False,
input_compress_columns: Optional[List[str]] = None,
Expand Down Expand Up @@ -271,7 +281,7 @@ def convert(sample, space):

if is_multi_agent:
# TODO (simon): Add support for multi-agent episodes.
pass
NotImplementedError
else:
# Build a single-agent episode with a single row of the batch.
episode = SingleAgentEpisode(
Expand All @@ -288,7 +298,7 @@ def convert(sample, space):
unpack_if_needed(batch[schema[Columns.NEXT_OBS]][i]),
observation_space,
)
if Columns.NEXT_OBS in input_compress_columns
if Columns.OBS in input_compress_columns
else convert(
batch[schema[Columns.NEXT_OBS]][i], observation_space
),
Expand Down Expand Up @@ -334,7 +344,11 @@ def convert(sample, space):
else v[i]
]
for k, v in batch.items()
if (k not in schema and k not in schema.values())
if (
k not in schema
and k not in schema.values()
and k not in ["dones", "agent_index", "type"]
)
},
len_lookback_buffer=0,
)
Expand All @@ -344,3 +358,111 @@ def convert(sample, space):
episodes.append(episode)
# Note, `map_batches` expects a `Dict` as return value.
return {"episodes": episodes}

def _map_sample_batch_to_episode(
is_multi_agent: bool,
batch: Dict[str, Union[list, np.ndarray]],
schema: Dict[str, str] = SCHEMA,
finalize: bool = False,
input_compress_columns: Optional[List[str]] = None,
) -> Dict[str, List[EpisodeType]]:
"""Maps an old stack `SampleBatch` to new stack episodes."""

# Set `input_compress_columns` to an empty `list` if `None`.
input_compress_columns = input_compress_columns or []

# TODO (simon): CHeck, if needed. It could possibly happen that a batch contains
# data from different episodes. Merging and resplitting the batch would then
# be the solution.
# Check, if batch comes actually from multiple episodes.
# episode_begin_indices = np.where(np.diff(np.hstack(batch["eps_id"])) != 0) + 1

# Define a container to collect episodes.
episodes = []
# Loop over `SampleBatch`es in the `ray.data` batch (a dict).
for i, obs in enumerate(batch[schema[Columns.OBS]]):

# If multi-agent we need to extract the agent ID.
# TODO (simon): Check, what happens with the module ID.
if is_multi_agent:
agent_id = (
# The old stack uses "agent_index" instead of "agent_id".
batch[schema["agent_index"]][i][0]
if schema["agent_index"] in batch
else None
)
else:
agent_id = None

if is_multi_agent:
# TODO (simon): Add support for multi-agent episodes.
NotImplementedError
else:
# Unpack observations, if needed.
obs = (
unpack_if_needed(obs.tolist())
if schema[Columns.OBS] in input_compress_columns
else obs.tolist()
)
# Append the last `new_obs` to get the correct length of observations.
obs.append(
unpack_if_needed(batch[schema[Columns.NEXT_OBS]][i][-1])
if schema[Columns.OBS] in input_compress_columns
else batch[schema[Columns.NEXT_OBS]][i][-1]
)
# Create a `SingleAgentEpisode`.
episode = SingleAgentEpisode(
id_=batch[schema[Columns.EPS_ID]][i][0],
agent_id=agent_id,
observations=obs,
infos=(
batch[schema[Columns.INFOS]][i]
if schema[Columns.INFOS] in batch
else [{}] * len(obs)
),
# Actions might be (a) serialized. We unserialize them here.
actions=(
unpack_if_needed(batch[schema[Columns.ACTIONS]][i])
if Columns.ACTIONS in input_compress_columns
else batch[schema[Columns.ACTIONS]][i]
),
rewards=batch[schema[Columns.REWARDS]][i],
terminated=(
any(batch[schema[Columns.TERMINATEDS]][i])
if schema[Columns.TERMINATEDS] in batch
else any(batch["dones"][i])
),
truncated=(
any(batch[schema[Columns.TRUNCATEDS]][i])
if schema[Columns.TRUNCATEDS] in batch
else False
),
# TODO (simon): Results in zero-length episodes in connector.
# t_started=batch[Columns.T if Columns.T in batch else
# "unroll_id"][i][0],
# TODO (simon): Single-dimensional columns are not supported.
# Extra model outputs might be serialized. We unserialize them here
# if needed.
# TODO (simon): Check, if we need here also reconversion from
# JSONable in case of composite spaces.
extra_model_outputs={
k: unpack_if_needed(v[i])
if k in input_compress_columns
else v[i]
for k, v in batch.items()
if (
k not in schema
and k not in schema.values()
and k not in ["dones", "agent_index", "type"]
)
},
len_lookback_buffer=0,
)
# Finalize, if necessary.
# TODO (simon, sven): Check, if we should convert all data to lists
# before. Right now only obs are lists.
if finalize:
episode.finalize()
episodes.append(episode)
# Note, `map_batches` expects a `Dict` as return value.
return {"episodes": episodes}
20 changes: 20 additions & 0 deletions rllib/offline/tests/test_offline_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,26 @@ def test_offline_convert_to_episodes(self):
self.assertTrue(len(episodes) == 10)
self.assertTrue(isinstance(episodes[0], SingleAgentEpisode))

def test_offline_convert_from_old_sample_batch_to_episodes(self):

base_path = Path(__file__).parents[2]
sample_batch_data_path = base_path / "tests/data/cartpole/large.json"
config = AlgorithmConfig().offline_data(
input_=["local://" + sample_batch_data_path.as_posix()],
input_read_method="read_json",
input_read_sample_batches=True,
)

offline_data = OfflineData(config)

batch = offline_data.data.take_batch(batch_size=10)
episodes = OfflinePreLearner._map_sample_batch_to_episode(False, batch)[
"episodes"
]

self.assertTrue(len(episodes) == 10)
self.assertTrue(isinstance(episodes[0], SingleAgentEpisode))

def test_sample(self):

config = AlgorithmConfig().offline_data(input_=[self.data_path])
Expand Down
4 changes: 2 additions & 2 deletions rllib/tuned_examples/bc/cartpole_bc.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from pathlib import Path

from ray.air.constants import TRAINING_ITERATION
from ray.rllib.algorithms.bc import BCConfig
from ray.rllib.utils.metrics import (
ENV_RUNNER_RESULTS,
EPISODE_RETURN_MEAN,
EVALUATION_RESULTS,
TRAINING_ITERATION_TIMER,
)
from ray.rllib.utils.test_utils import (
add_rllib_example_script_args,
Expand Down Expand Up @@ -75,7 +75,7 @@

stop = {
f"{EVALUATION_RESULTS}/{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": 120.0,
TRAINING_ITERATION_TIMER: 350.0,
TRAINING_ITERATION: 350,
}

if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions rllib/tuned_examples/bc/pendulum_bc.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from pathlib import Path

from ray.air.constants import TRAINING_ITERATION
from ray.rllib.algorithms.bc import BCConfig
from ray.rllib.utils.metrics import (
ENV_RUNNER_RESULTS,
EPISODE_RETURN_MEAN,
EVALUATION_RESULTS,
TRAINING_ITERATION_TIMER,
)
from ray.rllib.utils.test_utils import (
add_rllib_example_script_args,
Expand Down Expand Up @@ -61,7 +61,7 @@

stop = {
f"{EVALUATION_RESULTS}/{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": -200.0,
TRAINING_ITERATION_TIMER: 350.0,
TRAINING_ITERATION: 350,
}

if __name__ == "__main__":
Expand Down

0 comments on commit a97f121

Please sign in to comment.