[RLlib; Offline RL] RLUnplugged example on new API stack. #46792
@@ -0,0 +1,257 @@
"""
Review comment: Should we call this file simply:
Reply: I think we should add a hint that it is pong data from rl_unplugged. What do you think?
schema={
Review comment: Is this needed here? If this is an explanation of the schema, maybe move this down into the config section (where we explain the schema translation)?
    a_t: int64,
    r_t: float,
    episode_return: float,
    o_tp1: list<item: binary>,
    episode_id: int64,
    a_tp1: int64,
    o_t: list<item: binary>,
    d_t: float
}
"""

from google.cloud import storage
import gymnasium as gym
import io
import numpy as np
from pathlib import Path
from PIL import Image
import tree
from typing import Optional

from ray.data.datasource.tfrecords_datasource import TFXReadOptions
from ray.rllib.algorithms.bc import BCConfig
from ray.rllib.connectors.connector_v2 import ConnectorV2
from ray.rllib.core.columns import Columns
from ray.rllib.env.wrappers.atari_wrappers import wrap_atari_for_new_api_stack
from ray.rllib.utils.annotations import override
from ray.rllib.utils.test_utils import (
    add_rllib_example_script_args,
)
from ray import tune


class DecodeObservations(ConnectorV2):
    def __init__(
Review comment: Add a (small) docstring here that explains what this connector does.
        self,
        input_observation_space: Optional[gym.Space] = None,
        input_action_space: Optional[gym.Space] = None,
        *,
        multi_agent: bool = False,
        as_learner_connector: bool = True,
        **kwargs,
    ):
        """Decodes observations from PNG to numpy arrays.

        Note, `rl_unplugged`'s stored observations are framestacked with
        four frames per observation. This connector therefore returns
        decoded observations of shape `(84, 84, 4)`.

        Args:
            multi_agent: Whether this is a connector operating on a multi-agent
                observation space mapping AgentIDs to individual agents' observations.
            as_learner_connector: Whether this connector is part of a Learner connector
                pipeline, as opposed to an env-to-module pipeline.
        """
        super().__init__(
            input_observation_space=input_observation_space,
            input_action_space=input_action_space,
            **kwargs,
        )

        self._multi_agent = multi_agent
Review comment: Do we need these options?
Reply: No, we don't. It was not clear to me yet that this does not need to be set when in single-agent mode.
        self._as_learner_connector = as_learner_connector

    @override(ConnectorV2)
    def __call__(
        self,
        *,
        rl_module,
        data,
        episodes,
        explore=None,
        shared_data=None,
        **kwargs,
    ):

        for sa_episode in self.single_agent_episode_iterator(
            episodes, agents_that_stepped_only=False
        ):
            # Map encoded PNGs into arrays of shape (84, 84, 4).
            def _map_fn(s):
                return np.concatenate(
                    [
                        np.array(Image.open(io.BytesIO(s[i]))).reshape(84, 84, 1)
                        for i in range(4)
                    ],
                    axis=2,
                )

            # Add the observations for t.
            self.add_n_batch_items(
                batch=data,
                column=Columns.OBS,
                items_to_add=tree.map_structure(
                    _map_fn,
                    sa_episode.get_observations(slice(0, len(sa_episode))),
                ),
                num_items=len(sa_episode),
                single_agent_episode=sa_episode,
            )
            # Add the observations for t+1.
            self.add_n_batch_items(
                batch=data,
                column=Columns.NEXT_OBS,
                items_to_add=tree.map_structure(
                    _map_fn,
                    sa_episode.get_observations(slice(1, len(sa_episode) + 1)),
                ),
                num_items=len(sa_episode),
                single_agent_episode=sa_episode,
            )

        return data
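As a standalone illustration of what `_map_fn` above does per frame, here is a minimal sketch that round-trips one synthetic 84x84 PNG (not part of the PR):

import io

import numpy as np
from PIL import Image

# Encode a dummy grayscale frame to PNG bytes, as rl_unplugged stores them.
buf = io.BytesIO()
Image.fromarray(np.zeros((84, 84), dtype=np.uint8)).save(buf, format="PNG")
png_bytes = buf.getvalue()

# Decode back into the (84, 84, 1) slice that `_map_fn` concatenates four of.
decoded = np.array(Image.open(io.BytesIO(png_bytes))).reshape(84, 84, 1)
assert decoded.shape == (84, 84, 1)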


# Make the learner connector.
def _make_learner_connector(observation_space, action_space):
    return DecodeObservations()
Review comment: Let's get rid of this by doing it in the config below:
Reply: I like this more! Thanks!


# Wrap the environment used in evaluation into `RLlib`'s Atari wrapper
# that automatically stacks frames and converts to the dimension used
# in the collection of the `rl_unplugged` data.
def _env_creator(cfg):
    return wrap_atari_for_new_api_stack(
        gym.make("ALE/Pong-v5", **cfg),
        # Perform frame-stacking through ConnectorV2 API.
Review comment: Wait, if we say `framestack=4` here, then we do NOT perform frame-stacking through ConnectorV2 here. -> Change the comment by removing the ConnectorV2 statement. Note: at this point, the gain in performance when using the connector is minimal, so it can probably be neglected for simplicity (easier for the user to do framestacking in the env wrapper).
Reply: Thanks for the hint. Will remove this comment.
        framestack=4,
        dim=84,
    )
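A quick sanity check that the wrapper output matches the decoded offline observations (a sketch; assumes `ale-py` and the Pong ROM are installed):

env = _env_creator({})
# Expect (84, 84, 4), matching what `DecodeObservations` produces.
print(env.observation_space.shape)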


# Register the wrapped environment to `tune`.
tune.register_env("WrappedALE/Pong-v5", _env_creator)

parser = add_rllib_example_script_args()
# Use `parser` to add your own custom command line options to this script
# and (if needed) use their values to set up `config` below.
args = parser.parse_args()

# We only use the Atari game `Pong` here. Users can choose other Atari
# games and set the name here. This will download the `TFRecords` dataset from GCS.
game = "Pong"
# There are many run numbers; we choose the first one for demonstration. This
# can be chosen by users. To use all data, use a list of file paths (see
# `num_shards`) and its usage further below.
run_number = 1
# num_shards = 1
Review comment: Should this be commented out?
Reply: This is actually a number that is now hard-coded into the path. We have many runs and, for each, multiple shards in the bucket - each is a file. For the example I use only a single one of these files.

# Make the temporary directory for the downloaded data.
tmp_path = "/tmp/atari"
Path(tmp_path).joinpath(game).mkdir(exist_ok=True, parents=True)
Review comment: super nit: should we use the "/" op of pathlib.Path here?
Reply: Yeah, let's do it, this is better to read.
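A sketch of the suggested pathlib form:

# Equivalent to the `joinpath` call above, using pathlib's "/" operator.
(Path(tmp_path) / game).mkdir(exist_ok=True, parents=True)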

destination_file_name = f"{tmp_path}/{game}/run_{run_number}-00000-of-00001"

# If the file is not yet downloaded, download it here.
if not Path(destination_file_name).exists():
    # Define the bucket and source file.
    bucket_name = "rl_unplugged"
    source_blob_name = f"atari/{game}/run_{run_number}-00000-of-00100"

    # Download the data from the bucket.
    storage_client = storage.Client.create_anonymous_client()
    bucket = storage_client.bucket(bucket_name)
    blob = bucket.blob(source_blob_name)
    blob.download_to_filename(destination_file_name)
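To use more of the data, as mentioned further above, several shards can be pulled; a hedged sketch (the shard count of 2 is hypothetical, and the blob naming follows the `-of-00100` pattern above):

from pathlib import Path

from google.cloud import storage

bucket = storage.Client.create_anonymous_client().bucket("rl_unplugged")
shard_files = []
for i in range(2):  # Hypothetical shard count; each run has 100 shards.
    shard_name = f"run_1-{i:05d}-of-00100"
    local_file = f"/tmp/atari/Pong/{shard_name}"
    if not Path(local_file).exists():
        bucket.blob(f"atari/Pong/{shard_name}").download_to_filename(local_file)
    shard_files.append(local_file)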

# Define the config for Behavior Cloning.
config = (
    BCConfig()
    .environment(
        env="WrappedALE/Pong-v5",
        clip_rewards=True,
    )
    # Use the new API stack that directly makes use of `ray.data`.
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    # Evaluate online in the actual environment.
    .evaluation(
        evaluation_interval=3,
        evaluation_num_env_runners=1,
        evaluation_duration=5,
        evaluation_parallel_to_training=True,
    )
    .learners(
        num_learners=args.num_gpus if args.num_gpus > 1 else 0,
        num_gpus_per_learner=1,
    )
    # Note, the `input_` argument is the major argument for the
    # new offline API. Via the `input_read_method_kwargs`, the
    # arguments for the `ray.data.Dataset` read method can be
    # configured. The read method needs at least as many blocks
    # as remote learners.
    .offline_data(
        input_=destination_file_name,
        input_read_method="read_tfrecords",
        input_read_method_kwargs={
            # Note, `TFRecords` datasets in `rl_unplugged` are GZIP
            # compressed and Arrow needs to decompress them.
            "arrow_open_stream_args": {"compression": "gzip"},
            # Use enough read blocks to scale well.
            "override_num_blocks": 20,
            # TFX improves performance extensively. `tfx-bsl` needs to be
            # installed for this.
            "tfx_read_options": TFXReadOptions(
                batch_size=2000,
            ),
        },
        # `rl_unplugged`'s data schema is different from the one used
        # internally in `RLlib`. Define the schema here so it can be used
        # when transforming column data to episodes.
        input_read_schema={
            Columns.EPS_ID: "episode_id",
            Columns.OBS: "o_t",
            Columns.ACTIONS: "a_t",
            Columns.REWARDS: "r_t",
            Columns.NEXT_OBS: "o_tp1",
            Columns.TERMINATEDS: "d_t",
        },
        # Increase the parallelism in transforming batches, such that while
        # training, new batches are transformed while others are used in updating.
        map_batches_kwargs={"concurrency": max(args.num_gpus * 20, 20)},
        # When iterating over batches in the dataset, prefetch at least 20
        # batches per learner. Increase this for scaling out more.
        iter_batches_kwargs={
            "prefetch_batches": max(args.num_gpus * 20, 20),
            "local_shuffle_buffer_size": None,
        },
    )
    .training(
        # To increase learning speed with multiple learners,
        # increase the learning rate correspondingly.
        lr=0.0008 * max(1, args.num_gpus**0.5),
        train_batch_size_per_learner=2000,
        # Use the learner connector defined above to decode observations.
        learner_connector=_make_learner_connector,
    )
    .rl_module(
        model_config_dict={
            "vf_share_layers": True,
            "conv_filters": [[16, 4, 2], [32, 4, 2], [64, 4, 2], [128, 4, 2]],
            "conv_activation": "relu",
            "post_fcnet_hiddens": [256],
            "uses_new_env_runners": True,
        }
    )
)
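If the multi-shard download sketched further above is used, `input_` should also accept a list of local files instead of a single path (hedged; `shard_files` is the hypothetical list from that sketch):

config = config.offline_data(
    input_=shard_files,
)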

# TODO (simon): Change to use the `run_rllib_example` function as soon as tuned.
Review comment: nit: TODO? Or still WIP?
Reply: WIP :) It is still not running smoothly, and for debugging I am not using tune.
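For reference, the Tuner-based run the TODO points toward could look roughly like this (a sketch; the stopping criterion is an assumption):

from ray import train, tune

tune.Tuner(
    "BC",
    param_space=config,
    run_config=train.RunConfig(stop={"training_iteration": 10}),
).fit()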

algo = config.build()

for i in range(10):
    print(f"Iteration: {i + 1}")
    results = algo.train()
    print(results)
Review comment: Ah, yes, this was missing :)