diff --git a/rllib/BUILD b/rllib/BUILD index e2ec7386ae0a..5886d865a052 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -164,23 +164,24 @@ py_test( tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_discrete", "torch_only"], size = "large", srcs = ["tuned_examples/appo/cartpole_appo.py"], - args = ["--as-test", "--enable-new-api-stack", "--num-learners=1"] -) -py_test( - name = "learning_tests_cartpole_appo_gpu", - main = "tuned_examples/appo/cartpole_appo.py", - tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core", "gpu"], - size = "large", - srcs = ["tuned_examples/appo/cartpole_appo.py"], - args = ["--as-test", "--enable-new-api-stack", "--num-learners=0", "--num-gpus-per-learner=1"] + args = ["--as-test", "--num-learners=1", "--num-cpus=8", "--num-env-runners=6"] ) +# TODO (sven): For some weird reason, this test runs extremely slow on the CI (not on cluster, not locally) -> taking this out for now ... +# py_test( +# name = "learning_tests_cartpole_appo_gpu", +# main = "tuned_examples/appo/cartpole_appo.py", +# tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core", "gpu"], +# size = "large", +# srcs = ["tuned_examples/appo/cartpole_appo.py"], +# args = ["--as-test", "--num-learners=0", "--num-gpus-per-learner=1", "--num-cpus=7", "--num-env-runners=6"] +# ) py_test( name = "learning_tests_cartpole_appo_multi_cpu", main = "tuned_examples/appo/cartpole_appo.py", tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core"], size = "large", srcs = ["tuned_examples/appo/cartpole_appo.py"], - args = ["--as-test", "--enable-new-api-stack", "--num-learners=2"] + args = ["--as-test", "--num-learners=2", "--num-cpus=9", "--num-env-runners=6"] ) py_test( name = "learning_tests_cartpole_appo_multi_gpu", @@ -188,7 +189,7 @@ py_test( tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core", "multi_gpu"], size = "large", srcs = ["tuned_examples/appo/cartpole_appo.py"], - args = ["--as-test", "--enable-new-api-stack", "--num-learners=2", "--num-gpus-per-learner=1"] + args = ["--as-test", "--num-learners=2", "--num-gpus-per-learner=1", "--num-cpus=7", "--num-env-runners=6"] ) # MultiAgentCartPole py_test( @@ -197,7 +198,7 @@ py_test( tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_discrete", "torch_only"], size = "large", srcs = ["tuned_examples/appo/multi_agent_cartpole_appo.py"], - args = ["--as-test", "--enable-new-api-stack", "--num-agents=2", "--num-learners=1"] + args = ["--as-test", "--num-agents=2", "--num-learners=1", "--num-cpus=8", "--num-env-runners=6"] ) py_test( name = "learning_tests_multi_agent_cartpole_appo_gpu", @@ -205,7 +206,7 @@ py_test( tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core", "gpu"], size = "large", srcs = ["tuned_examples/appo/multi_agent_cartpole_appo.py"], - args = ["--as-test", "--enable-new-api-stack", "--num-agents=2", "--num-learners=0", "--num-gpus-per-learner=1", "--num-cpus=6"] + args = ["--as-test", "--num-agents=2", "--num-learners=0", "--num-gpus-per-learner=1", "--num-cpus=7", "--num-env-runners=6"] ) py_test( name = "learning_tests_multi_agent_cartpole_appo_multi_cpu", @@ -213,7 +214,7 @@ py_test( tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core"], size = "large", srcs = ["tuned_examples/appo/multi_agent_cartpole_appo.py"], - args = ["--as-test", "--enable-new-api-stack", "--num-agents=2", "--num-learners=2", "--num-cpus=7"] + args = ["--as-test", "--num-agents=2", "--num-learners=2", "--num-cpus=9", "--num-env-runners=6"] ) py_test( name = "learning_tests_multi_agent_cartpole_appo_multi_gpu", @@ -221,7 +222,7 @@ py_test( tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core", "multi_gpu"], size = "large", srcs = ["tuned_examples/appo/multi_agent_cartpole_appo.py"], - args = ["--as-test", "--enable-new-api-stack", "--num-agents=2", "--num-learners=2", "--num-gpus-per-learner=1", "--num-cpus=7"] + args = ["--as-test", "--num-agents=2", "--num-learners=2", "--num-gpus-per-learner=1", "--num-cpus=7", "--num-env-runners=6"] ) # StatelessCartPole py_test( @@ -230,7 +231,7 @@ py_test( tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core"], size = "large", srcs = ["tuned_examples/appo/stateless_cartpole_appo.py"], - args = ["--as-test", "--enable-new-api-stack", "--num-learners=1"] + args = ["--as-test", "--num-learners=1", "--num-cpus=8", "--num-env-runners=6"] ) py_test( name = "learning_tests_stateless_cartpole_appo_gpu", @@ -238,7 +239,7 @@ py_test( tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core", "gpu"], size = "large", srcs = ["tuned_examples/appo/stateless_cartpole_appo.py"], - args = ["--as-test", "--enable-new-api-stack", "--num-agents=2", "--num-learners=0", "--num-gpus-per-learner=1"] + args = ["--as-test", "--num-agents=2", "--num-learners=0", "--num-gpus-per-learner=1", "--num-cpus=7", "--num-env-runners=6"] ) py_test( name = "learning_tests_stateless_cartpole_appo_multi_cpu", @@ -246,7 +247,7 @@ py_test( tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core"], size = "large", srcs = ["tuned_examples/appo/stateless_cartpole_appo.py"], - args = ["--as-test", "--enable-new-api-stack", "--num-learners=2"] + args = ["--as-test", "--num-learners=2", "--num-cpus=9", "--num-env-runners=6"] ) py_test( name = "learning_tests_stateless_cartpole_appo_multi_gpu", @@ -254,7 +255,7 @@ py_test( tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core", "multi_gpu"], size = "large", srcs = ["tuned_examples/appo/stateless_cartpole_appo.py"], - args = ["--as-test", "--enable-new-api-stack", "--num-learners=2", "--num-gpus-per-learner=1"] + args = ["--as-test", "--num-learners=2", "--num-gpus-per-learner=1", "--num-cpus=7", "--num-env-runners=6"] ) # MultiAgentStatelessCartPole # py_test( diff --git a/rllib/algorithms/appo/appo.py b/rllib/algorithms/appo/appo.py index 1640cf4b5338..3632ffab954b 100644 --- a/rllib/algorithms/appo/appo.py +++ b/rllib/algorithms/appo/appo.py @@ -1,13 +1,13 @@ -""" -Asynchronous Proximal Policy Optimization (APPO) -================================================ +"""Asynchronous Proximal Policy Optimization (APPO) -This file defines the distributed Algorithm class for the asynchronous version -of proximal policy optimization (APPO). -See `appo_[tf|torch]_policy.py` for the definition of the policy loss. +The algorithm is described in [1] (under the name of "IMPACT"): Detailed documentation: https://docs.ray.io/en/master/rllib-algorithms.html#appo + +[1] IMPACT: Importance Weighted Asynchronous Architectures with Clipped Target Networks. +Luo et al. 2020 +https://arxiv.org/pdf/1912.00167 """ from typing import Optional, Type @@ -108,18 +108,19 @@ def __init__(self, algo_class=None): self.kl_coeff = 1.0 self.kl_target = 0.01 self.target_worker_clipping = 2.0 - # TODO (sven): Activate once v-trace sequences in non-RNN batch are solved. - # If we switch this on right now, the shuffling would destroy the rollout - # sequences (non-zero-padded!) needed in the batch for v-trace. - # self.shuffle_batch_per_epoch = True + + # Circular replay buffer settings. + # Used in [1] for discrete action tasks: + # `circular_buffer_num_batches=4` and `circular_buffer_iterations_per_batch=2` + # For cont. action tasks: + # `circular_buffer_num_batches=16` and `circular_buffer_iterations_per_batch=20` + self.circular_buffer_num_batches = 4 + self.circular_buffer_iterations_per_batch = 2 # Override some of IMPALAConfig's default values with APPO-specific values. self.num_env_runners = 2 self.min_time_s_per_iteration = 10 self.target_network_update_freq = 1 - self.learner_queue_size = 16 - self.learner_queue_timeout = 300 - self.max_sample_requests_in_flight_per_worker = 2 self.broadcast_interval = 1 self.grad_clip = 40.0 # Note: Only when using enable_rl_module_and_learner=True can the clipping mode @@ -145,6 +146,8 @@ def __init__(self, algo_class=None): self.minibatch_buffer_size = 1 # @OldAPIStack self.replay_proportion = 0.0 # @OldAPIStack self.replay_buffer_num_slots = 100 # @OldAPIStack + self.learner_queue_size = 16 # @OldAPIStack + self.learner_queue_timeout = 300 # @OldAPIStack # Deprecated keys. self.target_update_frequency = DEPRECATED_VALUE @@ -164,6 +167,8 @@ def training( tau: Optional[float] = NotProvided, target_network_update_freq: Optional[int] = NotProvided, target_worker_clipping: Optional[float] = NotProvided, + circular_buffer_num_batches: Optional[int] = NotProvided, + circular_buffer_iterations_per_batch: Optional[int] = NotProvided, # Deprecated keys. target_update_frequency=DEPRECATED_VALUE, **kwargs, @@ -197,6 +202,14 @@ def training( target_worker_clipping: The maximum value for the target-worker-clipping used for computing the IS ratio, described in [1] IS = min(π(i) / π(target), ρ) * (π / π(i)) + circular_buffer_num_batches: The number of train batches that fit + into the circular buffer. Each such train batch can be sampled for + training max. `circular_buffer_iterations_per_batch` times. + circular_buffer_iterations_per_batch: The number of times any train + batch in the circular buffer can be sampled for training. A batch gets + evicted from the buffer either if it's the oldest batch in the buffer + and a new batch is added OR if the batch reaches this max. number of + being sampled. Returns: This updated AlgorithmConfig object. @@ -233,9 +246,50 @@ def training( self.target_network_update_freq = target_network_update_freq if target_worker_clipping is not NotProvided: self.target_worker_clipping = target_worker_clipping + if circular_buffer_num_batches is not NotProvided: + self.circular_buffer_num_batches = circular_buffer_num_batches + if circular_buffer_iterations_per_batch is not NotProvided: + self.circular_buffer_iterations_per_batch = ( + circular_buffer_iterations_per_batch + ) return self + @override(IMPALAConfig) + def validate(self) -> None: + super().validate() + + # On new API stack, circular buffer should be used, not `minibatch_buffer_size`. + if self.enable_rl_module_and_learner: + if self.minibatch_buffer_size != 1 or self.replay_proportion != 0.0: + raise ValueError( + "`minibatch_buffer_size/replay_proportion` not valid on new API " + "stack with APPO! " + "Use `circular_buffer_num_batches` for the number of train batches " + "in the circular buffer. To change the maximum number of times " + "any batch may be sampled, set " + "`circular_buffer_iterations_per_batch`." + ) + if self.num_multi_gpu_tower_stacks != 1: + raise ValueError( + "`num_multi_gpu_tower_stacks` not supported on new API stack with " + "APPO! In order to train on multi-GPU, use " + "`config.learners(num_learners=[number of GPUs], " + "num_gpus_per_learner=1)`. To scale the throughput of batch-to-GPU-" + "pre-loading on each of your `Learners`, set " + "`num_gpu_loader_threads` to a higher number (recommended values: " + "1-8)." + ) + if self.learner_queue_size != 16: + raise ValueError( + "`learner_queue_size` not supported on new API stack with " + "APPO! In order set the size of the circular buffer (which acts as " + "a 'learner queue'), use " + "`config.training(circular_buffer_num_batches=..)`. To change the " + "maximum number of times any batch may be sampled, set " + "`config.training(circular_buffer_iterations_per_batch=..)`." + ) + @override(IMPALAConfig) def get_default_learner_class(self): if self.framework_str == "torch": diff --git a/rllib/algorithms/appo/appo_learner.py b/rllib/algorithms/appo/appo_learner.py index 7b4cf2b14d8f..920d7b7ea992 100644 --- a/rllib/algorithms/appo/appo_learner.py +++ b/rllib/algorithms/appo/appo_learner.py @@ -2,6 +2,7 @@ from typing import Any, Dict, Optional from ray.rllib.algorithms.appo.appo import APPOConfig +from ray.rllib.algorithms.appo.utils import CircularBuffer from ray.rllib.algorithms.impala.impala_learner import IMPALALearner from ray.rllib.core.learner.learner import Learner from ray.rllib.core.learner.utils import update_target_network @@ -28,6 +29,11 @@ class APPOLearner(IMPALALearner): @override(IMPALALearner) def build(self): + self._learner_thread_in_queue = CircularBuffer( + num_batches=self.config.circular_buffer_num_batches, + iterations_per_batch=self.config.circular_buffer_iterations_per_batch, + ) + super().build() # Make target networks. diff --git a/rllib/algorithms/appo/utils.py b/rllib/algorithms/appo/utils.py index cbd2efe82161..9a4f1e66d0a9 100644 --- a/rllib/algorithms/appo/utils.py +++ b/rllib/algorithms/appo/utils.py @@ -1,12 +1,99 @@ +""" +[1] IMPACT: Importance Weighted Asynchronous Architectures with Clipped Target Networks. +Luo et al. 2020 +https://arxiv.org/pdf/1912.00167 +""" +from collections import deque +import random +import threading +import time + from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.utils.annotations import OldAPIStack POLICY_SCOPE = "func" TARGET_POLICY_SCOPE = "target_func" -# TODO (sven): Deprecate once APPO and IMPALA fully on RLModules/Learner APIs. +class CircularBuffer: + """A circular batch-wise buffer as described in [1] for APPO. + + The buffer holds at most N batches, which are sampled at random (uniformly). + If full and a new batch is added, the oldest batch is discarded. Also, each batch + currently in the buffer can be sampled at most K times (after which it is also + discarded). + """ + + def __init__(self, num_batches: int, iterations_per_batch: int): + # N from the paper (buffer size). + self.num_batches = num_batches + # K ("replay coefficient") from the paper. + self.iterations_per_batch = iterations_per_batch + + self._buffer = deque(maxlen=self.num_batches) + self._lock = threading.Lock() + + # The number of valid (not expired) entries in this buffer. + self._num_valid_batches = 0 + + def add(self, batch): + dropped_entry = None + dropped_ts = 0 + + # Add buffer and k=0 information to the deque. + with self._lock: + len_ = len(self._buffer) + if len_ == self.num_batches: + dropped_entry = self._buffer[0] + self._buffer.append([batch, 0]) + self._num_valid_batches += 1 + + # A valid entry (w/ a batch whose k has not been reach K yet) was dropped. + if dropped_entry is not None and dropped_entry[0] is not None: + dropped_ts += dropped_entry[0].env_steps() * ( + self.iterations_per_batch - dropped_entry[1] + ) + self._num_valid_batches -= 1 + + return dropped_ts + + def sample(self): + k = entry = batch = None + + while True: + # Only initially, the buffer may be empty -> Just wait for some time. + if len(self) == 0: + time.sleep(0.001) + continue + # Sample a random buffer index. + with self._lock: + entry = self._buffer[random.randint(0, len(self._buffer) - 1)] + batch, k = entry + # Ignore batches that have already been invalidated. + if batch is not None: + break + + # Increase k += 1 for this batch. + assert k is not None + entry[1] += 1 + + # This batch has been exhausted (k == K) -> Invalidate it in the buffer. + if k == self.iterations_per_batch - 1: + entry[0] = None + entry[1] = None + self._num_valid_batches += 1 + + # Return the sampled batch. + return batch + + def __len__(self) -> int: + """Returns the number of actually valid (non-expired) batches in the buffer.""" + return self._num_valid_batches + + +@OldAPIStack def make_appo_models(policy) -> ModelV2: """Builds model and target model for APPO. diff --git a/rllib/algorithms/impala/impala_learner.py b/rllib/algorithms/impala/impala_learner.py index c38315d543b7..1929f9f010d6 100644 --- a/rllib/algorithms/impala/impala_learner.py +++ b/rllib/algorithms/impala/impala_learner.py @@ -3,11 +3,12 @@ import queue import threading import time -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Union import tree # pip install dm_tree import ray +from ray.rllib.algorithms.appo.utils import CircularBuffer from ray.rllib.algorithms.impala.impala import LEARNER_RESULTS_CURR_ENTROPY_COEFF_KEY from ray.rllib.core.columns import Columns from ray.rllib.core.learner.learner import Learner @@ -71,7 +72,7 @@ def build(self) -> None: ): self._learner_connector.prepend(AddOneTsToEpisodesAndTruncate()) # Leave all batches on the CPU (they'll be moved to the GPU, if applicable, - # by the n GPU loader threads. + # by the n GPU loader threads). numpy_to_tensor_connector = self._learner_connector[NumpyToTensor][0] numpy_to_tensor_connector._device = "cpu" # TODO (sven): Provide API? @@ -80,7 +81,9 @@ def build(self) -> None: # on the "update queue" for the actual RLModule forward pass and loss # computations. self._gpu_loader_in_queue = queue.Queue() - self._learner_thread_in_queue = deque(maxlen=self.config.learner_queue_size) + # Default is to have a learner thread. + if not hasattr(self, "_learner_thread_in_queue"): + self._learner_thread_in_queue = deque(maxlen=self.config.learner_queue_size) self._learner_thread_out_queue = queue.Queue() # Create and start the GPU loader thread(s). @@ -103,9 +106,6 @@ def build(self) -> None: in_queue=self._learner_thread_in_queue, out_queue=self._learner_thread_out_queue, metrics_logger=self.metrics, - num_epochs=self.config.num_epochs, - minibatch_size=self.config.minibatch_size, - shuffle_batch_per_epoch=self.config.shuffle_batch_per_epoch, ) self._learner_thread.start() @@ -115,13 +115,6 @@ def update_from_episodes( episodes: List[EpisodeType], *, timesteps: Dict[str, Any], - # TODO (sven): Deprecate these in favor of config attributes for only those - # algos that actually need (and know how) to do minibatching. - minibatch_size: Optional[int] = None, - num_epochs: int = 1, - shuffle_batch_per_epoch: bool = False, - num_total_minibatches: int = 0, - reduce_fn=None, # Deprecated args. **kwargs, ) -> ResultDict: self.metrics.set_value( @@ -175,15 +168,25 @@ def update_from_episodes( self._gpu_loader_in_queue.qsize(), ) else: - # Enqueue to Learner thread's in-queue. - _LearnerThread.enqueue( - self._learner_thread_in_queue, - MultiAgentBatch( - {mid: SampleBatch(b) for mid, b in batch.items()}, - env_steps=env_steps, - ), - self.metrics, + ma_batch = MultiAgentBatch( + {mid: SampleBatch(b) for mid, b in batch.items()}, + env_steps=env_steps, ) + # Add the batch directly to the circular buffer. + if isinstance(self._learner_thread_in_queue, CircularBuffer): + ts_dropped = self._learner_thread_in_queue.add(ma_batch) + self.metrics.log_value( + (ALL_MODULES, LEARNER_THREAD_ENV_STEPS_DROPPED), + ts_dropped, + reduce="sum", + ) + else: + # Enqueue to Learner thread's in-queue. + _LearnerThread.enqueue( + self._learner_thread_in_queue, + ma_batch, + self.metrics, + ) # Return all queued result dicts thus far (after reducing over them). results = {} @@ -263,8 +266,17 @@ def _step(self) -> None: policy_batches={mid: SampleBatch(b) for mid, b in batch_on_gpu.items()}, env_steps=env_steps, ) - # Enqueue to Learner thread's in-queue. - _LearnerThread.enqueue(self._out_queue, ma_batch_on_gpu, self.metrics) + + if isinstance(self._out_queue, CircularBuffer): + ts_dropped = self._out_queue.add(ma_batch_on_gpu) + self.metrics.log_value( + (ALL_MODULES, LEARNER_THREAD_ENV_STEPS_DROPPED), + ts_dropped, + reduce="sum", + ) + else: + # Enqueue to Learner thread's in-queue. + _LearnerThread.enqueue(self._out_queue, ma_batch_on_gpu, self.metrics) class _LearnerThread(threading.Thread): @@ -275,9 +287,6 @@ def __init__( in_queue: deque, out_queue: queue.Queue, metrics_logger, - num_epochs, - minibatch_size, - shuffle_batch_per_epoch, ): super().__init__() self.daemon = True @@ -285,13 +294,9 @@ def __init__( self.stopped = False self._update_method = update_method - self._in_queue: deque = in_queue + self._in_queue: Union[deque, CircularBuffer] = in_queue self._out_queue: queue.Queue = out_queue - self._num_epochs = num_epochs - self._minibatch_size = minibatch_size - self._shuffle_batch_per_epoch = shuffle_batch_per_epoch - def run(self) -> None: while not self.stopped: self.step() @@ -299,14 +304,19 @@ def run(self) -> None: def step(self): # Get a new batch from the GPU-data (deque.pop -> newest item first). with self.metrics.log_time((ALL_MODULES, LEARNER_THREAD_IN_QUEUE_WAIT_TIMER)): - if not self._in_queue: - time.sleep(0.001) - return - # Consume from the left (oldest batches first). - # If we consumed from the right, we would run into the danger of learning - # from newer batches (left side) most times, BUT sometimes grabbing a - # really old batches (right area of deque). - ma_batch_on_gpu = self._in_queue.popleft() + # Get a new batch from the GPU-data (learner queue OR circular buffer). + if isinstance(self._in_queue, CircularBuffer): + ma_batch_on_gpu = self._in_queue.sample() + else: + # Queue is empty: Sleep a tiny bit to avoid CPU-thrashing. + if not self._in_queue: + time.sleep(0.001) + return + # Consume from the left (oldest batches first). + # If we consumed from the right, we would run into the danger of + # learning from newer batches (left side) most times, BUT sometimes + # grabbing older batches (right area of deque). + ma_batch_on_gpu = self._in_queue.popleft() # Call the update method on the batch. with self.metrics.log_time((ALL_MODULES, LEARNER_THREAD_UPDATE_TIMER)): @@ -321,9 +331,6 @@ def step(self): (ALL_MODULES, NUM_ENV_STEPS_SAMPLED_LIFETIME), default=0 ) }, - num_epochs=self._num_epochs, - minibatch_size=self._minibatch_size, - shuffle_batch_per_epoch=self._shuffle_batch_per_epoch, ) # We have to deepcopy the results dict, b/c we must avoid having a returned # Stats object sit in the queue and getting a new (possibly even tensor) diff --git a/rllib/core/learner/learner.py b/rllib/core/learner/learner.py index c26cd6a22a94..2b58743a52df 100644 --- a/rllib/core/learner/learner.py +++ b/rllib/core/learner/learner.py @@ -1350,15 +1350,6 @@ def _update_from_batch_or_episodes( {next(iter(self.module.keys())): batch}, env_steps=len(batch) ) - # TODO (sven): Remove this leftover hack here for the situation in which we - # did not go through the learner connector. - # Options: - # a) Either also pass given batches through the learner connector (even if - # episodes is None). (preferred solution) - # b) Get rid of the option to pass in a batch altogether. - # if episodes is None: - # batch = self._convert_batch_type(batch) - # Check the MultiAgentBatch, whether our RLModule contains all ModuleIDs # found in this batch. If not, throw an error. unknown_module_ids = set(batch.policy_batches.keys()) - set(self.module.keys()) diff --git a/rllib/tuned_examples/appo/cartpole_appo.py b/rllib/tuned_examples/appo/cartpole_appo.py index 0af651b6c607..a85a9120ba2a 100644 --- a/rllib/tuned_examples/appo/cartpole_appo.py +++ b/rllib/tuned_examples/appo/cartpole_appo.py @@ -16,6 +16,7 @@ APPOConfig() .environment("CartPole-v1") .training( + circular_buffer_iterations_per_batch=2, vf_loss_coeff=0.05, entropy_coeff=0.0, )