[RLlib] Add APPO Atari (Pong) release test. (ray-project#48681)
Signed-off-by: JP-sDEV <jon.pablo80@gmail.com>
sven1977 authored and JP-sDEV committed Nov 14, 2024
1 parent 342dfa7 commit 63a634d
Showing 3 changed files with 145 additions and 30 deletions.
36 changes: 35 additions & 1 deletion release/release_tests.yaml
@@ -2693,6 +2693,40 @@
# Learning and benchmarking tests
# ----------------------------------------------------------

+# --------------------------
+# APPO
+# --------------------------
+- name: rllib_learning_tests_pong_appo_torch
+  group: RLlib tests
+  working_dir: rllib_tests
+
+  stable: true
+
+  frequency: nightly
+  team: rllib
+  cluster:
+    byod:
+      type: gpu
+      post_build_script: byod_rllib_test.sh
+      runtime_env:
+        - RLLIB_TEST_NO_JAX_IMPORT=1
+        - LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/ray/.mujoco/mujoco210/bin
+    cluster_compute: 2gpus_64cpus.yaml
+
+  run:
+    timeout: 1800
+    script: python learning_tests/tuned_examples/appo/pong_appo.py --enable-new-api-stack --num-learners=0 --num-env-runners=46 --stop-reward=19.5 --as-release-test
+
+  alert: default
+
+  variations:
+    - __suffix__: aws
+    - __suffix__: gce
+      env: gce
+      frequency: manual
+      cluster:
+        cluster_compute: 2gpus_64cpus_gce.yaml

# --------------------------
# DreamerV3
# --------------------------
@@ -2731,7 +2765,7 @@
# --------------------------
# IMPALA
# --------------------------
-- name: rllib_learning_tests_impala_ppo_torch
+- name: rllib_learning_tests_pong_impala_torch
  group: RLlib tests
  working_dir: rllib_tests

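For orientation on what the new `script:` entry above exercises: the flags are consumed by RLlib's example-script helpers and end up as ordinary config settings on the tuned example shown further below. The snippet is only a minimal, hand-written sketch of that mapping, assuming the usual flag semantics (`--num-learners=0` keeps the single Learner in the driver process, `--num-env-runners=46` spawns 46 remote EnvRunner actors, `--stop-reward=19.5` is the mean-episode-return threshold); it is not the release harness itself.

from ray.rllib.algorithms.appo import APPOConfig

# Hypothetical equivalent of the CLI flags in the `script:` line above.
config = (
    APPOConfig()
    .environment("ale_py:ALE/Pong-v5")
    .learners(num_learners=0)         # Learner colocated with the driver (uses one GPU).
    .env_runners(num_env_runners=46)  # 46 sampling actors on the 2-GPU / 64-CPU cluster.
)
# The release test then trains until the reported mean episode return reaches 19.5
# or the `run.timeout` of 1800 seconds is exceeded.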
55 changes: 26 additions & 29 deletions rllib/algorithms/impala/impala.py
@@ -146,8 +146,7 @@ def __init__(self, algo_class=None):
        self.broadcast_interval = 1
        self.num_aggregation_workers = 0
        self.num_gpu_loader_threads = 8
-        # IMPALA takes care of its own EnvRunner (weights, connector, counters)
-        # synching.
+        # IMPALA takes care of its own EnvRunner (weights, connector, metrics) synching.
        self._dont_auto_sync_env_runner_states = True

        self.grad_clip = 40.0
@@ -650,15 +649,21 @@ def training_step(self):
            value=len(data_packages_for_learner_group),
        )
        rl_module_state = None
-        last_good_learner_results = None
        num_learner_group_results_received = 0

        for batch_ref_or_episode_list_ref in data_packages_for_learner_group:
+            return_state = (
+                self.metrics.peek(
+                    NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS,
+                    default=0,
+                )
+                >= self.config.broadcast_interval
+            )
            if self.config.num_aggregation_workers:
                learner_results = self.learner_group.update_from_batch(
                    batch=batch_ref_or_episode_list_ref,
                    async_update=do_async_updates,
-                    return_state=True,
+                    return_state=return_state,
                    timesteps={
                        NUM_ENV_STEPS_SAMPLED_LIFETIME: self.metrics.peek(
                            (ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME),
@@ -673,7 +678,7 @@ def training_step(self):
                learner_results = self.learner_group.update_from_episodes(
                    episodes=batch_ref_or_episode_list_ref,
                    async_update=do_async_updates,
-                    return_state=True,
+                    return_state=return_state,
                    timesteps={
                        NUM_ENV_STEPS_SAMPLED_LIFETIME: self.metrics.peek(
                            (ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME),
@@ -684,6 +689,13 @@ def training_step(self):
                    minibatch_size=self.config.minibatch_size,
                    shuffle_batch_per_epoch=self.config.shuffle_batch_per_epoch,
                )
+            # TODO (sven): Rename this metric into a more fitting name: ex.
+            # `NUM_LEARNER_UPDATED_SINCE_LAST_WEIGHTS_SYNC`
+            self.metrics.log_value(
+                NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS,
+                1,
+                reduce="sum",
+            )
            if not do_async_updates:
                learner_results = [learner_results]

@@ -699,7 +711,6 @@ def training_step(self):
                    stats_dicts=results_from_n_learners,
                    key=LEARNER_RESULTS,
                )
-                last_good_learner_results = results_from_n_learners
        self.metrics.log_value(
            key=MEAN_NUM_LEARNER_GROUP_RESULTS_RECEIVED,
            value=num_learner_group_results_received,
@@ -711,31 +722,17 @@ def training_step(self):
        # Figure out, whether we should sync/broadcast the (remote) EnvRunner states.
        # Note: `learner_results` is a List of n (num async calls) Lists of m
        # (num Learner workers) ResultDicts each.
-        if last_good_learner_results:
-            # TODO (sven): Rename this metric into a more fitting name: ex.
-            # `NUM_LEARNER_UPDATED_SINCE_LAST_WEIGHTS_SYNC`
-            self.metrics.log_value(
-                NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS, 1, reduce="sum"
-            )
-            # Merge available EnvRunner states into local worker's EnvRunner state.
-            # Broadcast merged EnvRunner state AND new model weights back to all remote
-            # EnvRunners that - in this call - had returned samples.
-            if (
-                self.metrics.peek(
-                    NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS
-                )
-                >= self.config.broadcast_interval
-            ):
-                self.metrics.set_value(
-                    NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS, 0
-                )
-                self.metrics.log_value(NUM_SYNCH_WORKER_WEIGHTS, 1, reduce="sum")
-                with self.metrics.log_time((TIMERS, SYNCH_WORKER_WEIGHTS_TIMER)):
-                    self.env_runner_group.sync_env_runner_states(
-                        config=self.config,
-                        connector_states=connector_states,
-                        rl_module_state=rl_module_state,
-                    )
+        if rl_module_state is not None:
+            self.metrics.set_value(
+                NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS, 0
+            )
+            self.metrics.log_value(NUM_SYNCH_WORKER_WEIGHTS, 1, reduce="sum")
+            with self.metrics.log_time((TIMERS, SYNCH_WORKER_WEIGHTS_TIMER)):
+                self.env_runner_group.sync_env_runner_states(
+                    config=self.config,
+                    connector_states=connector_states,
+                    rl_module_state=rl_module_state,
+                )

    def _sample_and_get_connector_states(self):
        def _remote_sample_get_state_and_metrics(_worker):
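The net effect of the `impala.py` changes above: each learner-group update bumps a counter, the Learner is asked to return its RLModule state only once that counter reaches `broadcast_interval`, and the counter is reset whenever a state actually comes back and gets broadcast to the EnvRunners. The sketch below is a condensed, stand-alone illustration of that gating pattern using a plain counter instead of RLlib's `MetricsLogger`; the class name `_WeightSyncGate` is made up for illustration.

class _WeightSyncGate:
    """Toy stand-in for the counter-based weight-sync gating shown above."""

    def __init__(self, broadcast_interval: int):
        self.broadcast_interval = broadcast_interval
        # Mirrors NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS.
        self.calls_since_last_sync = 0

    def request_state(self) -> bool:
        # `return_state` in the diff: only ask the Learner for weights once
        # enough update calls have accumulated.
        return self.calls_since_last_sync >= self.broadcast_interval

    def after_update(self, rl_module_state) -> bool:
        # Each update increments the counter (log_value(..., 1, reduce="sum")).
        self.calls_since_last_sync += 1
        if rl_module_state is not None:
            # Weights were returned -> reset the counter and broadcast
            # (set_value(..., 0) + sync_env_runner_states() in the real code).
            self.calls_since_last_sync = 0
            return True
        return False


# With broadcast_interval=5 (as in pong_appo.py below), most updates skip the
# weight transfer; a sync only happens once the counter reaches the interval.
gate = _WeightSyncGate(broadcast_interval=5)
for step in range(12):
    state = {"weights": step} if gate.request_state() else None
    synced = gate.after_update(state)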
84 changes: 84 additions & 0 deletions rllib/tuned_examples/appo/pong_appo.py
@@ -0,0 +1,84 @@
import gymnasium as gym

from ray.rllib.algorithms.appo import APPOConfig
from ray.rllib.connectors.env_to_module.frame_stacking import FrameStackingEnvToModule
from ray.rllib.connectors.learner.frame_stacking import FrameStackingLearner
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
from ray.rllib.env.wrappers.atari_wrappers import wrap_atari_for_new_api_stack
from ray.rllib.utils.test_utils import add_rllib_example_script_args
from ray.tune.registry import register_env

parser = add_rllib_example_script_args(
    default_reward=20.0,
    default_timesteps=10000000,
)
parser.set_defaults(
    enable_new_api_stack=True,
    env="ale_py:ALE/Pong-v5",
)
args = parser.parse_args()


def _make_env_to_module_connector(env):
    return FrameStackingEnvToModule(num_frames=4)


def _make_learner_connector(input_observation_space, input_action_space):
    return FrameStackingLearner(num_frames=4)


def _env_creator(cfg):
    return wrap_atari_for_new_api_stack(
        gym.make(args.env, **cfg, **{"render_mode": "rgb_array"}),
        dim=64,
        framestack=None,
    )


register_env("env", _env_creator)


config = (
    APPOConfig()
    .environment(
        "env",
        env_config={
            # Make analogous to old v4 + NoFrameskip.
            "frameskip": 1,
            "full_action_space": False,
            "repeat_action_probability": 0.0,
        },
        clip_rewards=True,
    )
    .env_runners(
        env_to_module_connector=_make_env_to_module_connector,
        num_envs_per_env_runner=2,
        max_requests_in_flight_per_env_runner=1,
    )
    .training(
        learner_connector=_make_learner_connector,
        train_batch_size_per_learner=500,
        grad_clip=30.0,
        grad_clip_by="global_norm",
        lr=0.0009 * ((args.num_learners or 1) ** 0.5),
        vf_loss_coeff=1.0,
        entropy_coeff=[[0, 0.05], [3000000, 0.0]],  # <- crucial parameter to finetune
        # Only update connector states and model weights every n training_step calls.
        broadcast_interval=5,
        learner_queue_size=1,
    )
    .rl_module(
        model_config=DefaultModelConfig(
            vf_share_layers=True,
            conv_filters=[(16, 4, 2), (32, 4, 2), (64, 4, 2), (128, 4, 2)],
            conv_activation="relu",
            head_fcnet_hiddens=[256],
        )
    )
)


if __name__ == "__main__":
    from ray.rllib.utils.test_utils import run_rllib_example_script_experiment

    run_rllib_example_script_experiment(config, args)
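Two settings in the `.training()` block above are easy to misread. `entropy_coeff=[[0, 0.05], [3000000, 0.0]]` is a timestep-based schedule (assuming RLlib's usual piecewise-linear interpolation between the listed points), and `lr=0.0009 * ((args.num_learners or 1) ** 0.5)` applies square-root learning-rate scaling with the number of Learner workers. The helper functions below are a small stand-alone sketch of that arithmetic, not RLlib code; the names are illustrative only.

def entropy_coeff_at(ts: int, schedule=((0, 0.05), (3_000_000, 0.0))) -> float:
    """Linearly interpolate the entropy coefficient at env step `ts`."""
    (t0, v0), (t1, v1) = schedule
    if ts <= t0:
        return v0
    if ts >= t1:
        return v1
    frac = (ts - t0) / (t1 - t0)
    return v0 + frac * (v1 - v0)


def scaled_lr(num_learners: int, base_lr: float = 0.0009) -> float:
    """Square-root LR scaling; `num_learners or 1` treats the local-Learner case (0) as 1."""
    return base_lr * ((num_learners or 1) ** 0.5)


print(entropy_coeff_at(1_500_000))  # 0.025 -- halfway through the anneal
print(scaled_lr(0), scaled_lr(4))   # 0.0009, 0.0018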
