[RLlib] APPO enhancements (new API stack) vol 01: Add circular buffer (ray-project#48798)

Signed-off-by: Connor Sanders <connor@elastiflow.com>
sven1977 authored and jecsand838 committed Dec 4, 2024
1 parent 7f3e56b commit fe9da76
Showing 7 changed files with 231 additions and 84 deletions.
39 changes: 20 additions & 19 deletions rllib/BUILD
@@ -164,31 +164,32 @@ py_test(
tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_discrete", "torch_only"],
size = "large",
srcs = ["tuned_examples/appo/cartpole_appo.py"],
args = ["--as-test", "--enable-new-api-stack", "--num-learners=1"]
)
py_test(
name = "learning_tests_cartpole_appo_gpu",
main = "tuned_examples/appo/cartpole_appo.py",
tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core", "gpu"],
size = "large",
srcs = ["tuned_examples/appo/cartpole_appo.py"],
args = ["--as-test", "--enable-new-api-stack", "--num-learners=0", "--num-gpus-per-learner=1"]
args = ["--as-test", "--num-learners=1", "--num-cpus=8", "--num-env-runners=6"]
)
# TODO (sven): For some weird reason, this test runs extremely slow on the CI (not on cluster, not locally) -> taking this out for now ...
# py_test(
# name = "learning_tests_cartpole_appo_gpu",
# main = "tuned_examples/appo/cartpole_appo.py",
# tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core", "gpu"],
# size = "large",
# srcs = ["tuned_examples/appo/cartpole_appo.py"],
# args = ["--as-test", "--num-learners=0", "--num-gpus-per-learner=1", "--num-cpus=7", "--num-env-runners=6"]
# )
py_test(
name = "learning_tests_cartpole_appo_multi_cpu",
main = "tuned_examples/appo/cartpole_appo.py",
tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core"],
size = "large",
srcs = ["tuned_examples/appo/cartpole_appo.py"],
args = ["--as-test", "--enable-new-api-stack", "--num-learners=2"]
args = ["--as-test", "--num-learners=2", "--num-cpus=9", "--num-env-runners=6"]
)
py_test(
name = "learning_tests_cartpole_appo_multi_gpu",
main = "tuned_examples/appo/cartpole_appo.py",
tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core", "multi_gpu"],
size = "large",
srcs = ["tuned_examples/appo/cartpole_appo.py"],
args = ["--as-test", "--enable-new-api-stack", "--num-learners=2", "--num-gpus-per-learner=1"]
args = ["--as-test", "--num-learners=2", "--num-gpus-per-learner=1", "--num-cpus=7", "--num-env-runners=6"]
)
# MultiAgentCartPole
py_test(
@@ -197,31 +198,31 @@ py_test(
tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_discrete", "torch_only"],
size = "large",
srcs = ["tuned_examples/appo/multi_agent_cartpole_appo.py"],
args = ["--as-test", "--enable-new-api-stack", "--num-agents=2", "--num-learners=1"]
args = ["--as-test", "--num-agents=2", "--num-learners=1", "--num-cpus=8", "--num-env-runners=6"]
)
py_test(
name = "learning_tests_multi_agent_cartpole_appo_gpu",
main = "tuned_examples/appo/multi_agent_cartpole_appo.py",
tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core", "gpu"],
size = "large",
srcs = ["tuned_examples/appo/multi_agent_cartpole_appo.py"],
args = ["--as-test", "--enable-new-api-stack", "--num-agents=2", "--num-learners=0", "--num-gpus-per-learner=1", "--num-cpus=6"]
args = ["--as-test", "--num-agents=2", "--num-learners=0", "--num-gpus-per-learner=1", "--num-cpus=7", "--num-env-runners=6"]
)
py_test(
name = "learning_tests_multi_agent_cartpole_appo_multi_cpu",
main = "tuned_examples/appo/multi_agent_cartpole_appo.py",
tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core"],
size = "large",
srcs = ["tuned_examples/appo/multi_agent_cartpole_appo.py"],
args = ["--as-test", "--enable-new-api-stack", "--num-agents=2", "--num-learners=2", "--num-cpus=7"]
args = ["--as-test", "--num-agents=2", "--num-learners=2", "--num-cpus=9", "--num-env-runners=6"]
)
py_test(
name = "learning_tests_multi_agent_cartpole_appo_multi_gpu",
main = "tuned_examples/appo/multi_agent_cartpole_appo.py",
tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core", "multi_gpu"],
size = "large",
srcs = ["tuned_examples/appo/multi_agent_cartpole_appo.py"],
args = ["--as-test", "--enable-new-api-stack", "--num-agents=2", "--num-learners=2", "--num-gpus-per-learner=1", "--num-cpus=7"]
args = ["--as-test", "--num-agents=2", "--num-learners=2", "--num-gpus-per-learner=1", "--num-cpus=7", "--num-env-runners=6"]
)
# StatelessCartPole
py_test(
@@ -230,31 +231,31 @@ py_test(
tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core"],
size = "large",
srcs = ["tuned_examples/appo/stateless_cartpole_appo.py"],
args = ["--as-test", "--enable-new-api-stack", "--num-learners=1"]
args = ["--as-test", "--num-learners=1", "--num-cpus=8", "--num-env-runners=6"]
)
py_test(
name = "learning_tests_stateless_cartpole_appo_gpu",
main = "tuned_examples/appo/stateless_cartpole_appo.py",
tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core", "gpu"],
size = "large",
srcs = ["tuned_examples/appo/stateless_cartpole_appo.py"],
args = ["--as-test", "--enable-new-api-stack", "--num-agents=2", "--num-learners=0", "--num-gpus-per-learner=1"]
args = ["--as-test", "--num-agents=2", "--num-learners=0", "--num-gpus-per-learner=1", "--num-cpus=7", "--num-env-runners=6"]
)
py_test(
name = "learning_tests_stateless_cartpole_appo_multi_cpu",
main = "tuned_examples/appo/stateless_cartpole_appo.py",
tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core"],
size = "large",
srcs = ["tuned_examples/appo/stateless_cartpole_appo.py"],
args = ["--as-test", "--enable-new-api-stack", "--num-learners=2"]
args = ["--as-test", "--num-learners=2", "--num-cpus=9", "--num-env-runners=6"]
)
py_test(
name = "learning_tests_stateless_cartpole_appo_multi_gpu",
main = "tuned_examples/appo/stateless_cartpole_appo.py",
tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core", "multi_gpu"],
size = "large",
srcs = ["tuned_examples/appo/stateless_cartpole_appo.py"],
args = ["--as-test", "--enable-new-api-stack", "--num-learners=2", "--num-gpus-per-learner=1"]
args = ["--as-test", "--num-learners=2", "--num-gpus-per-learner=1", "--num-cpus=7", "--num-env-runners=6"]
)
# MultiAgentStatelessCartPole
# py_test(
80 changes: 67 additions & 13 deletions rllib/algorithms/appo/appo.py
@@ -1,13 +1,13 @@
"""
Asynchronous Proximal Policy Optimization (APPO)
================================================
"""Asynchronous Proximal Policy Optimization (APPO)
This file defines the distributed Algorithm class for the asynchronous version
of proximal policy optimization (APPO).
See `appo_[tf|torch]_policy.py` for the definition of the policy loss.
The algorithm is described in [1] (under the name of "IMPACT"):
Detailed documentation:
https://docs.ray.io/en/master/rllib-algorithms.html#appo
[1] IMPACT: Importance Weighted Asynchronous Architectures with Clipped Target Networks.
Luo et al. 2020
https://arxiv.org/pdf/1912.00167
"""

from typing import Optional, Type
@@ -108,18 +108,19 @@ def __init__(self, algo_class=None):
self.kl_coeff = 1.0
self.kl_target = 0.01
self.target_worker_clipping = 2.0
# TODO (sven): Activate once v-trace sequences in non-RNN batch are solved.
# If we switch this on right now, the shuffling would destroy the rollout
# sequences (non-zero-padded!) needed in the batch for v-trace.
# self.shuffle_batch_per_epoch = True

# Circular replay buffer settings.
# Used in [1] for discrete action tasks:
# `circular_buffer_num_batches=4` and `circular_buffer_iterations_per_batch=2`
# For cont. action tasks:
# `circular_buffer_num_batches=16` and `circular_buffer_iterations_per_batch=20`
self.circular_buffer_num_batches = 4
self.circular_buffer_iterations_per_batch = 2

# Override some of IMPALAConfig's default values with APPO-specific values.
self.num_env_runners = 2
self.min_time_s_per_iteration = 10
self.target_network_update_freq = 1
self.learner_queue_size = 16
self.learner_queue_timeout = 300
self.max_sample_requests_in_flight_per_worker = 2
self.broadcast_interval = 1
self.grad_clip = 40.0
# Note: Only when using enable_rl_module_and_learner=True can the clipping mode
Expand All @@ -145,6 +146,8 @@ def __init__(self, algo_class=None):
self.minibatch_buffer_size = 1 # @OldAPIStack
self.replay_proportion = 0.0 # @OldAPIStack
self.replay_buffer_num_slots = 100 # @OldAPIStack
self.learner_queue_size = 16 # @OldAPIStack
self.learner_queue_timeout = 300 # @OldAPIStack

# Deprecated keys.
self.target_update_frequency = DEPRECATED_VALUE
@@ -164,6 +167,8 @@ def training(
tau: Optional[float] = NotProvided,
target_network_update_freq: Optional[int] = NotProvided,
target_worker_clipping: Optional[float] = NotProvided,
circular_buffer_num_batches: Optional[int] = NotProvided,
circular_buffer_iterations_per_batch: Optional[int] = NotProvided,
# Deprecated keys.
target_update_frequency=DEPRECATED_VALUE,
**kwargs,
@@ -197,6 +202,14 @@ def training(
target_worker_clipping: The maximum value for the target-worker-clipping
used for computing the IS ratio, described in [1]
IS = min(π(i) / π(target), ρ) * (π / π(i))
circular_buffer_num_batches: The number of train batches that fit
into the circular buffer. Each such train batch can be sampled for
training at most `circular_buffer_iterations_per_batch` times.
circular_buffer_iterations_per_batch: The number of times any train
batch in the circular buffer can be sampled for training. A batch is
evicted from the buffer either when it is the oldest batch in the
buffer and a new batch is added, OR when it reaches this maximum
number of times it may be sampled.
Returns:
This updated AlgorithmConfig object.
@@ -233,9 +246,50 @@ def training(
self.target_network_update_freq = target_network_update_freq
if target_worker_clipping is not NotProvided:
self.target_worker_clipping = target_worker_clipping
if circular_buffer_num_batches is not NotProvided:
self.circular_buffer_num_batches = circular_buffer_num_batches
if circular_buffer_iterations_per_batch is not NotProvided:
self.circular_buffer_iterations_per_batch = (
circular_buffer_iterations_per_batch
)

return self

@override(IMPALAConfig)
def validate(self) -> None:
super().validate()

# On new API stack, circular buffer should be used, not `minibatch_buffer_size`.
if self.enable_rl_module_and_learner:
if self.minibatch_buffer_size != 1 or self.replay_proportion != 0.0:
raise ValueError(
"`minibatch_buffer_size/replay_proportion` not valid on new API "
"stack with APPO! "
"Use `circular_buffer_num_batches` for the number of train batches "
"in the circular buffer. To change the maximum number of times "
"any batch may be sampled, set "
"`circular_buffer_iterations_per_batch`."
)
if self.num_multi_gpu_tower_stacks != 1:
raise ValueError(
"`num_multi_gpu_tower_stacks` not supported on new API stack with "
"APPO! In order to train on multi-GPU, use "
"`config.learners(num_learners=[number of GPUs], "
"num_gpus_per_learner=1)`. To scale the throughput of batch-to-GPU-"
"pre-loading on each of your `Learners`, set "
"`num_gpu_loader_threads` to a higher number (recommended values: "
"1-8)."
)
if self.learner_queue_size != 16:
raise ValueError(
"`learner_queue_size` not supported on new API stack with "
"APPO! In order set the size of the circular buffer (which acts as "
"a 'learner queue'), use "
"`config.training(circular_buffer_num_batches=..)`. To change the "
"maximum number of times any batch may be sampled, set "
"`config.training(circular_buffer_iterations_per_batch=..)`."
)

@override(IMPALAConfig)
def get_default_learner_class(self):
if self.framework_str == "torch":
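For reference, the two new `training()` options added above slot into a config like any other APPO setting. A minimal sketch, not part of this commit: the environment and learner counts are illustrative, and the values shown are the continuous-action settings from [1] (the APPO defaults already match the discrete-action settings, N=4 and K=2).

from ray.rllib.algorithms.appo.appo import APPOConfig

config = (
    APPOConfig()
    # Illustrative environment; not part of this commit.
    .environment("Pendulum-v1")
    .training(
        # N: number of train batches held in the circular buffer.
        circular_buffer_num_batches=16,
        # K: max number of times each batch may be sampled before eviction.
        circular_buffer_iterations_per_batch=20,
    )
    # Multi-GPU on the new API stack, per the validate() message above.
    .learners(num_learners=2, num_gpus_per_learner=1)
)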
6 changes: 6 additions & 0 deletions rllib/algorithms/appo/appo_learner.py
@@ -2,6 +2,7 @@
from typing import Any, Dict, Optional

from ray.rllib.algorithms.appo.appo import APPOConfig
from ray.rllib.algorithms.appo.utils import CircularBuffer
from ray.rllib.algorithms.impala.impala_learner import IMPALALearner
from ray.rllib.core.learner.learner import Learner
from ray.rllib.core.learner.utils import update_target_network
@@ -28,6 +29,11 @@ class APPOLearner(IMPALALearner):

@override(IMPALALearner)
def build(self):
self._learner_thread_in_queue = CircularBuffer(
num_batches=self.config.circular_buffer_num_batches,
iterations_per_batch=self.config.circular_buffer_iterations_per_batch,
)

super().build()

# Make target networks.
89 changes: 88 additions & 1 deletion rllib/algorithms/appo/utils.py
@@ -1,12 +1,99 @@
"""
[1] IMPACT: Importance Weighted Asynchronous Architectures with Clipped Target Networks.
Luo et al. 2020
https://arxiv.org/pdf/1912.00167
"""
from collections import deque
import random
import threading
import time

from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.annotations import OldAPIStack


POLICY_SCOPE = "func"
TARGET_POLICY_SCOPE = "target_func"


# TODO (sven): Deprecate once APPO and IMPALA fully on RLModules/Learner APIs.
class CircularBuffer:
"""A circular batch-wise buffer as described in [1] for APPO.
The buffer holds at most N batches, which are sampled at random (uniformly).
If full and a new batch is added, the oldest batch is discarded. Also, each batch
currently in the buffer can be sampled at most K times (after which it is also
discarded).
"""

def __init__(self, num_batches: int, iterations_per_batch: int):
# N from the paper (buffer size).
self.num_batches = num_batches
# K ("replay coefficient") from the paper.
self.iterations_per_batch = iterations_per_batch

self._buffer = deque(maxlen=self.num_batches)
self._lock = threading.Lock()

# The number of valid (not expired) entries in this buffer.
self._num_valid_batches = 0

def add(self, batch):
dropped_entry = None
dropped_ts = 0

# Add the batch and its k=0 counter to the deque.
with self._lock:
len_ = len(self._buffer)
if len_ == self.num_batches:
dropped_entry = self._buffer[0]
self._buffer.append([batch, 0])
self._num_valid_batches += 1

# A valid entry (w/ a batch whose k has not reached K yet) was dropped.
if dropped_entry is not None and dropped_entry[0] is not None:
dropped_ts += dropped_entry[0].env_steps() * (
self.iterations_per_batch - dropped_entry[1]
)
self._num_valid_batches -= 1

return dropped_ts

def sample(self):
k = entry = batch = None

while True:
# Only initially, the buffer may be empty -> Just wait for some time.
if len(self) == 0:
time.sleep(0.001)
continue
# Sample a random buffer index.
with self._lock:
entry = self._buffer[random.randint(0, len(self._buffer) - 1)]
batch, k = entry
# Ignore batches that have already been invalidated.
if batch is not None:
break

# Increase k += 1 for this batch.
assert k is not None
entry[1] += 1

# This batch has been exhausted (k == K) -> Invalidate it in the buffer.
if k == self.iterations_per_batch - 1:
entry[0] = None
entry[1] = None
self._num_valid_batches -= 1

# Return the sampled batch.
return batch

def __len__(self) -> int:
"""Returns the number of actually valid (non-expired) batches in the buffer."""
return self._num_valid_batches


@OldAPIStack
def make_appo_models(policy) -> ModelV2:
"""Builds model and target model for APPO.
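To illustrate the buffer's semantics outside of RLlib, here is a hypothetical stand-alone usage sketch (not part of the commit). The `FakeBatch` stub is an assumption for illustration; in practice the buffer holds the Learner's train batches, and `env_steps()` is the only method `add()` calls on them.

from ray.rllib.algorithms.appo.utils import CircularBuffer


class FakeBatch:
    """Hypothetical stand-in; env_steps() is the only method add() calls on a batch."""

    def env_steps(self) -> int:
        return 50


buf = CircularBuffer(num_batches=2, iterations_per_batch=2)  # N=2, K=2

buf.add(FakeBatch())
buf.add(FakeBatch())
assert len(buf) == 2  # two valid (not yet exhausted) batches

# A third add() evicts the oldest entry. The return value is the number of env
# steps that were dropped before that batch could be sampled its K=2 times.
dropped = buf.add(FakeBatch())
assert dropped == 100 and len(buf) == 2

# Each remaining batch may be sampled at most K=2 times. sample() busy-waits
# only while the buffer is still empty, so it returns immediately here.
for _ in range(4):
    batch = buf.sample()
# Both batches have now been sampled K times each and are invalidated.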
