diff --git a/doc/source/rllib/package_ref/learner.rst b/doc/source/rllib/package_ref/learner.rst index b7fcfa681902..56178065a45b 100644 --- a/doc/source/rllib/package_ref/learner.rst +++ b/doc/source/rllib/package_ref/learner.rst @@ -72,7 +72,8 @@ Performing Updates :nosignatures: :toctree: doc/ - Learner.update + Learner.update_from_batch + Learner.update_from_episodes Learner._update Learner.additional_update Learner.additional_update_for_module diff --git a/doc/source/rllib/rllib-learner.rst b/doc/source/rllib/rllib-learner.rst index 7ef5d642b140..c67b54c02d2b 100644 --- a/doc/source/rllib/rllib-learner.rst +++ b/doc/source/rllib/rllib-learner.rst @@ -229,19 +229,23 @@ Updates .. testcode:: - # This is a blocking update - results = learner_group.update(DUMMY_BATCH) + # This is a blocking update. + results = learner_group.update_from_batch(batch=DUMMY_BATCH) # This is a non-blocking update. The results are returned in a future - # call to `async_update` - _ = learner_group.async_update(DUMMY_BATCH) + # call to `update_from_batch(..., async_update=True)` + _ = learner_group.update_from_batch(batch=DUMMY_BATCH, async_update=True) # Artificially wait for async request to be done to get the results - # in the next call to `LearnerGroup.async_update()`. + # in the next call to + # `LearnerGroup.update_from_batch(..., async_update=True)`. time.sleep(5) - results = learner_group.async_update(DUMMY_BATCH) + results = learner_group.update_from_batch( + batch=DUMMY_BATCH, async_update=True + ) # `results` is a list of results dict. The items in the list represent the different - # remote results from the different calls to `async_update()`. + # remote results from the different calls to + # `update_from_batch(..., async_update=True)`. assert len(results) > 0 # Each item is a results dict, already reduced over the n Learner workers. assert isinstance(results[0], dict), results[0] @@ -256,8 +260,8 @@ Updates .. testcode:: - # This is a blocking update. - result = learner.update(DUMMY_BATCH) + # This is a blocking update (given a training batch). + result = learner.update_from_batch(batch=DUMMY_BATCH) # This is an additional non-gradient based update. learner_group.additional_update(**ADDITIONAL_UPDATE_KWARGS) diff --git a/rllib/algorithms/algorithm.py b/rllib/algorithms/algorithm.py index adc253241ffb..008280e61729 100644 --- a/rllib/algorithms/algorithm.py +++ b/rllib/algorithms/algorithm.py @@ -5,12 +5,12 @@ import functools import gymnasium as gym import importlib +import importlib.metadata import json import logging import numpy as np import os from packaging import version -import importlib.metadata import re import tempfile import time @@ -1607,7 +1607,7 @@ def training_step(self) -> ResultDict: # TODO: (sven) rename MultiGPUOptimizer into something more # meaningful. 
if self.config._enable_new_api_stack: - train_results = self.learner_group.update(train_batch) + train_results = self.learner_group.update_from_batch(batch=train_batch) elif self.config.get("simple_optimizer") is True: train_results = train_one_step(self, train_batch) else: diff --git a/rllib/algorithms/appo/tests/test_appo_learner.py b/rllib/algorithms/appo/tests/test_appo_learner.py index fe15b84e02c1..f8188c58496c 100644 --- a/rllib/algorithms/appo/tests/test_appo_learner.py +++ b/rllib/algorithms/appo/tests/test_appo_learner.py @@ -97,7 +97,7 @@ def test_appo_loss(self): env=algo.workers.local_worker().env ) learner_group.set_weights(algo.get_weights()) - learner_group.update(train_batch.as_multi_agent()) + learner_group.update_from_batch(batch=train_batch.as_multi_agent()) algo.stop() diff --git a/rllib/algorithms/bc/bc.py b/rllib/algorithms/bc/bc.py index 569f00632dc7..29e2693f159c 100644 --- a/rllib/algorithms/bc/bc.py +++ b/rllib/algorithms/bc/bc.py @@ -171,7 +171,7 @@ def training_step(self) -> ResultDict: self._counters[NUM_ENV_STEPS_SAMPLED] += train_batch.env_steps() # Updating the policy. - train_results = self.learner_group.update(train_batch) + train_results = self.learner_group.update_from_batch(batch=train_batch) # Synchronize weights. # As the results contain for each policy the loss and in addition the diff --git a/rllib/algorithms/dreamerv3/dreamerv3.py b/rllib/algorithms/dreamerv3/dreamerv3.py index a3cf4eb9c4f4..7df306c8540c 100644 --- a/rllib/algorithms/dreamerv3/dreamerv3.py +++ b/rllib/algorithms/dreamerv3/dreamerv3.py @@ -606,8 +606,8 @@ def training_step(self) -> ResultDict: ) # Perform the actual update via our learner group. - train_results = self.learner_group.update( - SampleBatch(sample).as_multi_agent(), + train_results = self.learner_group.update_from_batch( + batch=SampleBatch(sample).as_multi_agent(), reduce_fn=self._reduce_results, ) self._counters[NUM_AGENT_STEPS_TRAINED] += replayed_steps diff --git a/rllib/algorithms/dreamerv3/utils/summaries.py b/rllib/algorithms/dreamerv3/utils/summaries.py index 4cc17de41c0a..2e670edbf411 100644 --- a/rllib/algorithms/dreamerv3/utils/summaries.py +++ b/rllib/algorithms/dreamerv3/utils/summaries.py @@ -133,7 +133,8 @@ def report_predicted_vs_sampled_obs( Continues: Compute MSE (sampled vs predicted). Args: - results: The results dict that was returned by `LearnerGroup.update()`. + results: The results dict that was returned by + `LearnerGroup.update_from_batch()`. sample: The sampled data (dict) from the replay buffer. Already tf-tensor converted. batch_size_B: The batch size (B). This is the number of trajectories sampled diff --git a/rllib/algorithms/impala/impala.py b/rllib/algorithms/impala/impala.py index 0f29ba3939d7..cc2c2ea6f079 100644 --- a/rllib/algorithms/impala/impala.py +++ b/rllib/algorithms/impala/impala.py @@ -946,24 +946,18 @@ def learn_on_processed_samples(self) -> ResultDict: self.batches_to_place_on_learner.clear() # If there are no learner workers and learning is directly on the driver # Then we can't do async updates, so we need to block. 
- blocking = self.config.num_learner_workers == 0 + async_update = self.config.num_learner_workers > 0 results = [] for batch in batches: - if blocking: - result = self.learner_group.update( - batch, - reduce_fn=_reduce_impala_results, - num_iters=self.config.num_sgd_iter, - minibatch_size=self.config.minibatch_size, - ) + result = self.learner_group.update_from_batch( + batch=batch, + async_update=async_update, + reduce_fn=_reduce_impala_results, + num_iters=self.config.num_sgd_iter, + minibatch_size=self.config.minibatch_size, + ) + if not async_update: results = [result] - else: - results = self.learner_group.async_update( - batch, - reduce_fn=_reduce_impala_results, - num_iters=self.config.num_sgd_iter, - minibatch_size=self.config.minibatch_size, - ) for r in results: self._counters[NUM_ENV_STEPS_TRAINED] += r[ALL_MODULES].pop( @@ -973,14 +967,14 @@ def learn_on_processed_samples(self) -> ResultDict: NUM_AGENT_STEPS_TRAINED ) - self._counters.update(self.learner_group.get_in_queue_stats()) + self._counters.update(self.learner_group.get_stats()) # If there are results, reduce-mean over each individual value and return. if results: return tree.map_structure(lambda *x: np.mean(x), *results) # Nothing on the queue -> Don't send requests to learner group - # or no results ready (from previous `self.learner_group.update()` calls) for - # reducing. + # or no results ready (from previous `self.learner_group.update_from_batch()` + # calls) for reducing. return {} def place_processed_samples_on_learner_thread_queue(self) -> None: diff --git a/rllib/algorithms/impala/tests/test_impala_learner.py b/rllib/algorithms/impala/tests/test_impala_learner.py index b22d25553e54..59159f7ff4f7 100644 --- a/rllib/algorithms/impala/tests/test_impala_learner.py +++ b/rllib/algorithms/impala/tests/test_impala_learner.py @@ -94,7 +94,7 @@ def test_impala_loss(self): env=algo.workers.local_worker().env ) learner_group.set_weights(algo.get_weights()) - learner_group.update(train_batch.as_multi_agent()) + learner_group.update_from_batch(batch=train_batch.as_multi_agent()) algo.stop() diff --git a/rllib/algorithms/ppo/ppo.py b/rllib/algorithms/ppo/ppo.py index 9f9605312e2e..498baa73a5bc 100644 --- a/rllib/algorithms/ppo/ppo.py +++ b/rllib/algorithms/ppo/ppo.py @@ -424,8 +424,8 @@ def training_step(self) -> ResultDict: if self.config._enable_new_api_stack: # TODO (Kourosh) Clearly define what train_batch_size # vs. sgd_minibatch_size and num_sgd_iter is in the config. - train_results = self.learner_group.update( - train_batch, + train_results = self.learner_group.update_from_batch( + batch=train_batch, minibatch_size=self.config.sgd_minibatch_size, num_iters=self.config.num_sgd_iter, ) diff --git a/rllib/algorithms/ppo/tests/test_ppo_learner.py b/rllib/algorithms/ppo/tests/test_ppo_learner.py index 8659cc5da0f2..2ebac016d67f 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_learner.py +++ b/rllib/algorithms/ppo/tests/test_ppo_learner.py @@ -101,7 +101,7 @@ def test_loss(self): # Load the algo weights onto the learner_group. 
learner_group.set_weights(algo.get_weights()) - learner_group.update(train_batch.as_multi_agent()) + learner_group.update_from_batch(batch=train_batch.as_multi_agent()) algo.stop() diff --git a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py index bb832a57be03..37fcd66fac64 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py @@ -225,7 +225,7 @@ def get_value(): assert init_std == 0.0, init_std batch = compute_gae_for_sample_batch(policy, PENDULUM_FAKE_BATCH.copy()) batch = policy._lazy_tensor_dict(batch) - algo.learner_group.update(batch.as_multi_agent()) + algo.learner_group.update_from_batch(batch=batch.as_multi_agent()) # Check the variable is updated. post_std = get_value() diff --git a/rllib/core/learner/learner.py b/rllib/core/learner/learner.py index 801cc9f83847..fb76293b2a18 100644 --- a/rllib/core/learner/learner.py +++ b/rllib/core/learner/learner.py @@ -48,6 +48,7 @@ from ray.rllib.utils.schedules.scheduler import Scheduler from ray.rllib.utils.serialization import serialize_type from ray.rllib.utils.typing import ( + EpisodeType, LearningRateOrSchedule, ModuleID, Optimizer, @@ -1097,27 +1098,26 @@ def additional_update_for_module( return results - def update( + def update_from_batch( self, batch: MultiAgentBatch, *, - minibatch_size: Optional[int] = None, - num_iters: int = 1, reduce_fn: Callable[[List[Dict[str, Any]]], ResultDict] = ( _reduce_mean_results ), + # TODO (sven): Deprecate these in favor of config attributes for only those + # algos that actually need (and know how) to do minibatching. + minibatch_size: Optional[int] = None, + num_iters: int = 1, ) -> Union[Dict[str, Any], List[Dict[str, Any]]]: - """Do `num_iters` minibatch updates given the original batch. + """Do `num_iters` minibatch updates given a train batch. - Given a batch of episodes you can use this method to take more - than one backward pass on the batch. The same minibatch_size and num_iters - will be used for all module ids in MultiAgentRLModule. + You can use this method to take more than one backward pass on the batch. + The same `minibatch_size` and `num_iters` will be used for all module ids in + MultiAgentRLModule. Args: - batch: A batch of data. - minibatch_size: The size of the minibatch to use for each update. - num_iters: The number of complete passes over all the sub-batches - in the input multi-agent batch. + batch: A batch of training data to update from. reduce_fn: reduce_fn: A function to reduce the results from a list of minibatch updates. This can be any arbitrary function that takes a list of dictionaries and returns a single dictionary. For example you @@ -1125,81 +1125,97 @@ def update( example for metrics) or be more selective about you want to report back to the algorithm's training_step. If None is passed, the results will not get reduced. + minibatch_size: The size of the minibatch to use for each update. + num_iters: The number of complete passes over all the sub-batches + in the input multi-agent batch. + Returns: A dictionary of results, in numpy format or a list of such dictionaries in case `reduce_fn` is None and we have more than one minibatch pass. 
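For orientation, a minimal calling sketch of the reworked `Learner.update_from_batch()` API described above (the `learner` and `train_batch` names are placeholders and not part of this patch):

    # One blocking pass over the full batch; returns a single (mean-reduced)
    # results dict.
    results = learner.update_from_batch(batch=train_batch)

    # Several passes over minibatches; with `reduce_fn=None`, a list with one
    # results dict per minibatch update is returned instead of one reduced dict.
    results_per_minibatch = learner.update_from_batch(
        batch=train_batch,
        minibatch_size=128,
        num_iters=4,
        reduce_fn=None,
    )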
""" - self._check_is_built() - - missing_module_ids = set(batch.policy_batches.keys()) - set(self.module.keys()) - if len(missing_module_ids) > 0: - raise ValueError( - "Batch contains module ids that are not in the learner: " - f"{missing_module_ids}" - ) + return self._update_from_batch_or_episodes( + batch=batch, + episodes=None, + reduce_fn=reduce_fn, + minibatch_size=minibatch_size, + num_iters=num_iters, + ) - if num_iters < 1: - # We must do at least one pass on the batch for training. - raise ValueError("`num_iters` must be >= 1") + def update_from_episodes( + self, + episodes: List[EpisodeType], + *, + reduce_fn: Callable[[List[Dict[str, Any]]], ResultDict] = ( + _reduce_mean_results + ), + # TODO (sven): Deprecate these in favor of config attributes for only those + # algos that actually need (and know how) to do minibatching. + minibatch_size: Optional[int] = None, + num_iters: int = 1, + ) -> Union[Dict[str, Any], List[Dict[str, Any]]]: + """Do `num_iters` minibatch updates given a list of episodes. - if minibatch_size: - batch_iter = MiniBatchCyclicIterator - elif num_iters > 1: - # `minibatch_size` was not set but `num_iters` > 1. - # Under the old training stack, users could do multiple sgd passes - # over a batch without specifying a minibatch size. We enable - # this behavior here by setting the minibatch size to be the size - # of the batch (e.g. 1 minibatch of size batch.count) - minibatch_size = batch.count - batch_iter = MiniBatchCyclicIterator - else: - # `minibatch_size` and `num_iters` are not set by the user. - batch_iter = MiniBatchDummyIterator + You can use this method to take more than one backward pass on the batch. + The same `minibatch_size` and `num_iters` will be used for all module ids in + MultiAgentRLModule. - results = [] - # Convert input batch into a tensor batch (MultiAgentBatch) on the correct - # device (e.g. GPU). We move the batch already here to avoid having to move - # every single minibatch that is created in the `batch_iter` below. - batch = self._convert_batch_type(batch) - batch = self._set_slicing_by_batch_id(batch, value=True) - - for tensor_minibatch in batch_iter(batch, minibatch_size, num_iters): - # Make the actual in-graph/traced `_update` call. This should return - # all tensor values (no numpy). - nested_tensor_minibatch = NestedDict(tensor_minibatch.policy_batches) - ( - fwd_out, - loss_per_module, - metrics_per_module, - ) = self._update(nested_tensor_minibatch) + Args: + episodes: An list of episode objects to update from. + reduce_fn: reduce_fn: A function to reduce the results from a list of + minibatch updates. This can be any arbitrary function that takes a + list of dictionaries and returns a single dictionary. For example you + can either take an average (default) or concatenate the results (for + example for metrics) or be more selective about you want to report back + to the algorithm's training_step. If None is passed, the results will + not get reduced. + minibatch_size: The size of the minibatch to use for each update. + num_iters: The number of complete passes over all the sub-batches + in the input multi-agent batch. - result = self.compile_results( - batch=tensor_minibatch, - fwd_out=fwd_out, - loss_per_module=loss_per_module, - metrics_per_module=defaultdict(dict, **metrics_per_module), - ) - self._check_result(result) - # TODO (sven): Figure out whether `compile_metrics` should be forced - # to return all numpy/python data, then we can skip this conversion - # step here. 
- results.append(convert_to_numpy(result)) + Returns: + A dictionary of results, in numpy format or a list of such dictionaries in + case `reduce_fn` is None and we have more than one minibatch pass. + """ + return self._update_from_batch_or_episodes( + batch=None, + episodes=episodes, + reduce_fn=reduce_fn, + minibatch_size=minibatch_size, + num_iters=num_iters, + ) - self._set_slicing_by_batch_id(batch, value=False) + @OverrideToImplementCustomLogic + def _preprocess_train_data( + self, + *, + batch: Optional[MultiAgentBatch] = None, + episodes: Optional[List[EpisodeType]] = None, + ) -> Tuple[Optional[MultiAgentBatch], Optional[List[EpisodeType]]]: + """Allows custom preprocessing of batch/episode data before the actual update. + + The higher level order, in which this method is called from within + `Learner.update(batch, episodes)` is: + * batch, episodes = self._preprocess_train_data(batch, episodes) + * batch = self._learner_connector(batch, episodes) + * results = self._update(batch) + + The default implementation does not do any processing and is a mere pass + through. However, specific algorithms should override this method to implement + their specific training data preprocessing needs. It is possible to perform + preliminary RLModule forward passes (besides the main "forward_train()" call + during `self._update`) in this method and custom algorithms might also want to + use this Learner's `self._learner_connector` to prepare the data + (batch/episodes) for such extra forward calls. - # Reduce results across all minibatches, if necessary. + Args: + batch: An optional batch of training data to preprocess. + episodes: An optional list of episodes objects to preprocess. - # If we only have one result anyways, then the user will not expect a list - # to be reduced here (and might not provide a `reduce_fn` therefore) -> - # Return single results dict. - if len(results) == 1: - return results[0] - # If no `reduce_fn` provided, return list of results dicts. - elif reduce_fn is None: - return results - # Pass list of results dicts through `reduce_fn` and return a single results - # dict. - return reduce_fn(results) + Returns: + A tuple consisting of the processed `batch` and the processed list of + `episodes`. + """ + return batch, episodes @OverrideToImplementCustomLogic @abc.abstractmethod @@ -1284,6 +1300,105 @@ def get_optimizer_state(self) -> Dict[str, Any]: """ raise NotImplementedError + def _update_from_batch_or_episodes( + self, + *, + # TODO (sven): We should allow passing in a single agent batch here + # as well for simplicity. + batch: Optional[MultiAgentBatch] = None, + episodes: Optional[List[EpisodeType]] = None, + reduce_fn: Callable[[List[Dict[str, Any]]], ResultDict] = ( + _reduce_mean_results + ), + # TODO (sven): Deprecate these in favor of config attributes for only those + # algos that actually need (and know how) to do minibatching. + minibatch_size: Optional[int] = None, + num_iters: int = 1, + ) -> Union[Dict[str, Any], List[Dict[str, Any]]]: + self._check_is_built() + + # If a (multi-agent) batch is provided, check, whether our RLModule + # contains all ModuleIDs found in this batch. If not, throw an error. + if batch is not None: + unknown_module_ids = set(batch.policy_batches.keys()) - set( + self.module.keys() + ) + if len(unknown_module_ids) > 0: + raise ValueError( + "Batch contains module ids that are not in the learner: " + f"{unknown_module_ids}" + ) + + if num_iters < 1: + # We must do at least one pass on the batch for training. 
+ raise ValueError("`num_iters` must be >= 1") + + # Call the train data preprocessor. + batch, episodes = self._preprocess_train_data(batch=batch, episodes=episodes) + + # TODO (sven): Insert a call to the Learner ConnectorV2 pipeline here, providing + # it both `batch` and `episode` for further custom processing before the + # actual `Learner._update()` call. + + if minibatch_size: + batch_iter = MiniBatchCyclicIterator + elif num_iters > 1: + # `minibatch_size` was not set but `num_iters` > 1. + # Under the old training stack, users could do multiple sgd passes + # over a batch without specifying a minibatch size. We enable + # this behavior here by setting the minibatch size to be the size + # of the batch (e.g. 1 minibatch of size batch.count) + minibatch_size = batch.count + batch_iter = MiniBatchCyclicIterator + else: + # `minibatch_size` and `num_iters` are not set by the user. + batch_iter = MiniBatchDummyIterator + + results = [] + # Convert input batch into a tensor batch (MultiAgentBatch) on the correct + # device (e.g. GPU). We move the batch already here to avoid having to move + # every single minibatch that is created in the `batch_iter` below. + batch = self._convert_batch_type(batch) + batch = self._set_slicing_by_batch_id(batch, value=True) + + for tensor_minibatch in batch_iter(batch, minibatch_size, num_iters): + # Make the actual in-graph/traced `_update` call. This should return + # all tensor values (no numpy). + nested_tensor_minibatch = NestedDict(tensor_minibatch.policy_batches) + ( + fwd_out, + loss_per_module, + metrics_per_module, + ) = self._update(nested_tensor_minibatch) + + result = self.compile_results( + batch=tensor_minibatch, + fwd_out=fwd_out, + loss_per_module=loss_per_module, + metrics_per_module=defaultdict(dict, **metrics_per_module), + ) + self._check_result(result) + # TODO (sven): Figure out whether `compile_results` should be forced + # to return all numpy/python data, then we can skip this conversion + # step here. + results.append(convert_to_numpy(result)) + + self._set_slicing_by_batch_id(batch, value=False) + + # Reduce results across all minibatches, if necessary. + + # If we only have one result anyways, then the user will not expect a list + # to be reduced here (and might not provide a `reduce_fn` therefore) -> + # Return single results dict. + if len(results) == 1: + return results[0] + # If no `reduce_fn` provided, return list of results dicts. + elif reduce_fn is None: + return results + # Pass list of results dicts through `reduce_fn` and return a single results + # dict. 
+ return reduce_fn(results) + def _set_slicing_by_batch_id( self, batch: MultiAgentBatch, *, value: bool ) -> MultiAgentBatch: diff --git a/rllib/core/learner/learner_group.py b/rllib/core/learner/learner_group.py index d072db47819c..a812f3334963 100644 --- a/rllib/core/learner/learner_group.py +++ b/rllib/core/learner/learner_group.py @@ -1,4 +1,4 @@ -from collections import defaultdict, deque +from collections import defaultdict, Counter from functools import partial import pathlib from typing import ( @@ -9,13 +9,13 @@ Optional, Set, Type, - TYPE_CHECKING, Union, ) import uuid import ray from ray.rllib.algorithms.algorithm_config import AlgorithmConfig +from ray.rllib.core.learner.learner import Learner from ray.rllib.core.learner.reduce_result_dict_fn import _reduce_mean_results from ray.rllib.core.rl_module.rl_module import ( SingleAgentRLModuleSpec, @@ -24,9 +24,13 @@ from ray.rllib.policy.sample_batch import MultiAgentBatch from ray.rllib.utils.actor_manager import FaultTolerantActorManager from ray.rllib.utils.deprecation import Deprecated, deprecation_warning -from ray.rllib.utils.minibatch_utils import ShardBatchIterator +from ray.rllib.utils.minibatch_utils import ( + ShardBatchIterator, + ShardEpisodesIterator, +) from ray.rllib.utils.numpy import convert_to_numpy from ray.rllib.utils.typing import ( + EpisodeType, ModuleID, ResultDict, RLModuleSpec, @@ -37,11 +41,7 @@ from ray.util.annotations import PublicAPI -if TYPE_CHECKING: - from ray.rllib.core.learner.learner import Learner - - -def _get_backend_config(learner_class: Type["Learner"]) -> str: +def _get_backend_config(learner_class: Type[Learner]) -> str: if learner_class.framework == "torch": from ray.train.torch import TorchConfig @@ -71,7 +71,8 @@ def _default_should_module_be_updated_fn( class LearnerGroup: """Coordinator of n (possibly remote) Learner workers. - Each Learner worker + Each Learner worker has a copy of the RLModule, the loss function(s), and + one or more optimizers. """ def __init__( @@ -134,14 +135,13 @@ def __init__( self._should_module_be_updated_fn = _default_should_module_be_updated_fn # How many timesteps had to be dropped due to a full input queue? - self._in_queue_ts_dropped = 0 + self._ts_dropped = 0 # A single local Learner. if not self.is_remote: self._learner = learner_class(config=config, module_spec=module_spec) self._learner.build() self._worker_manager = None - self._in_queue = [] # N remote Learner workers. else: backend_config = _get_backend_config(learner_class) @@ -182,17 +182,21 @@ def __init__( # an async algo, remove this restriction entirely. max_remote_requests_in_flight_per_actor=3, ) - # This is a list of the tags for asynchronous update requests that are - # inflight, and is used for grouping together the results of requests - # that were sent to the workers at the same time. - self._inflight_request_tags: Set[str] = set() - self._in_queue = deque(maxlen=max_queue_len) + # Counters for the tags for asynchronous update requests that are + # in-flight. Used for keeping track of and grouping together the results of + # requests that were sent to the workers at the same time.
+ self._update_request_tags = Counter() + self._additional_update_request_tags = Counter() - def get_in_queue_stats(self) -> Dict[str, Any]: + def get_stats(self) -> Dict[str, Any]: """Returns the current stats for the input queue for this learner group.""" return { - "learner_group_queue_size": len(self._in_queue), - "learner_group_queue_ts_dropped": self._in_queue_ts_dropped, + "learner_group_ts_dropped": self._ts_dropped, + "actor_manager_num_outstanding_async_reqs": ( + 0 + if self.is_local + else self._worker_manager.num_outstanding_async_reqs() + ), } @property @@ -203,187 +207,228 @@ def is_remote(self) -> bool: def is_local(self) -> bool: return not self.is_remote - def update( + def update_from_batch( self, batch: MultiAgentBatch, *, - minibatch_size: Optional[int] = None, - num_iters: int = 1, + async_update: bool = False, reduce_fn: Optional[Callable[[List[Dict[str, Any]]], ResultDict]] = ( _reduce_mean_results ), - ) -> Union[Dict[str, Any], List[Dict[str, Any]]]: - """Do one or more gradient based updates to the Learner(s) based on given data. + # TODO (sven): Deprecate the following args. They should be extracted from the + # LearnerHyperparameters of those specific algorithms that actually require + # these settings. + minibatch_size: Optional[int] = None, + num_iters: int = 1, + ) -> Union[Dict[str, Any], List[Dict[str, Any]], List[List[Dict[str, Any]]]]: + """Performs gradient based update(s) on the Learner(s), based on given batch. Args: - batch: The data batch to use for the update. - minibatch_size: The minibatch size to use for the update. - num_iters: The number of complete passes over all the sub-batches in the - input multi-agent batch. + batch: A data batch to use for the update. If there are more + than one Learner workers, the batch is split amongst these and one + shard is sent to each Learner. + async_update: Whether the update request(s) to the Learner workers should be + sent asynchronously. If True, will return NOT the results from the + update on the given data, but all results from prior asynchronous update + requests that have not been returned thus far. reduce_fn: An optional callable to reduce the results from a list of the Learner actors into a single result. This can be any arbitrary function that takes a list of dictionaries and returns a single dictionary. For - example you can either take an average (default) or concatenate the + example, you can either take an average (default) or concatenate the results (for example for metrics) or be more selective about you want to report back to the algorithm's training_step. If None is passed, the results will not get reduced. + minibatch_size: The minibatch size to use for the update. + num_iters: The number of complete passes over all the sub-batches in the + input multi-agent batch. Returns: - A dictionary with the reduced results of the updates from the Learner(s) or - a list of dictionaries of results from the updates from the Learner(s). + If `async_update` is False, a dictionary with the reduced results of the + updates from the Learner(s) or a list of dictionaries of results from the + updates from the Learner(s). + If `async_update` is True, a list of list of dictionaries of results, where + the outer list corresponds to separate previous calls to this method, and + the inner list corresponds to the results from each Learner(s). Or if the + results are reduced, a list of dictionaries of the reduced results from each + call to async_update that is ready. 
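As a rough sketch of the asynchronous path described above (assuming a `learner_group` built with `config.num_learner_workers > 0` and a placeholder `sampled_batch`; neither name is part of this patch):

    # Send the batch shards to the Learner workers without blocking.
    ready_results = learner_group.update_from_batch(
        batch=sampled_batch, async_update=True
    )
    # `ready_results` only contains results from *previous* async requests that
    # have completed by now; it may be empty right after the first call.
    for reduced_result in ready_results:
        ...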
""" - # Construct a multi-agent batch with only those modules in it that should # be updated. + # TODO (sven): Move this filtering of input data into individual Learners. + # It might be that the postprocessing of batch/episodes on each Learner + # requires the non-trainable modules' data. train_batch = {} for module_id in batch.policy_batches.keys(): if self.should_module_be_updated_fn(module_id, batch): train_batch[module_id] = batch.policy_batches[module_id] train_batch = MultiAgentBatch(train_batch, batch.count) - if self.is_local: - results = [ - self._learner.update( - train_batch, - minibatch_size=minibatch_size, - num_iters=num_iters, - reduce_fn=reduce_fn, - ) - ] - else: - - def _learner_update(learner, minibatch): - return learner.update( - minibatch, - minibatch_size=minibatch_size, - num_iters=num_iters, - reduce_fn=reduce_fn, - ) - - results = self._get_results( - self._worker_manager.foreach_actor( - [ - partial(_learner_update, minibatch=minibatch) - for minibatch in ShardBatchIterator(batch, len(self._workers)) - ] - ) - ) - - # TODO(sven): Move reduce_fn to the training_step - if reduce_fn is None: - return results - else: - return reduce_fn(results) + return self._update( + batch=train_batch, + episodes=None, + async_update=async_update, + reduce_fn=reduce_fn, + minibatch_size=minibatch_size, + num_iters=num_iters, + ) - def async_update( + def update_from_episodes( self, - batch: MultiAgentBatch, + episodes: List[EpisodeType], *, - minibatch_size: Optional[int] = None, - num_iters: int = 1, + async_update: bool = False, reduce_fn: Optional[Callable[[List[Dict[str, Any]]], ResultDict]] = ( _reduce_mean_results ), - ) -> Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]: - """Asnychronously do gradient based updates to the Learner(s) with `batch`. + # TODO (sven): Deprecate the following args. They should be extracted from the + # LearnerHyperparameters of those specific algorithms that actually require + # these settings. + minibatch_size: Optional[int] = None, + num_iters: int = 1, + ) -> Union[Dict[str, Any], List[Dict[str, Any]], List[List[Dict[str, Any]]]]: + """Performs gradient based update(s) on the Learner(s), based on given episodes. Args: - batch: The data batch to use for the update. + episodes: A list of Episodes to process and perform the update + for. If there are more than one Learner workers, the list of episodes + is split amongst these and one list shard is sent to each Learner. + async_update: Whether the update request(s) to the Learner workers should be + sent asynchronously. If True, will return NOT the results from the + update on the given data, but all results from prior asynchronous update + requests that have not been returned thus far. minibatch_size: The minibatch size to use for the update. num_iters: The number of complete passes over all the sub-batches in the input multi-agent batch. reduce_fn: An optional callable to reduce the results from a list of the Learner actors into a single result. This can be any arbitrary function that takes a list of dictionaries and returns a single dictionary. For - example you can either take an average (default) or concatenate the + example, you can either take an average (default) or concatenate the results (for example for metrics) or be more selective about you want to report back to the algorithm's training_step. If None is passed, the results will not get reduced. 
Returns: - A list of list of dictionaries of results, where the outer list - corresponds to separate calls to `async_update`, and the inner - list corresponds to the results from each Learner(s). Or if the results - are reduced, a list of dictionaries of the reduced results from each + If async_update is False, a dictionary with the reduced results of the + updates from the Learner(s) or a list of dictionaries of results from the + updates from the Learner(s). + If async_update is True, a list of lists of dictionaries of results, where + the outer list corresponds to separate previous calls to this method, and + the inner list corresponds to the results from each Learner(s). Or if the + results are reduced, a list of dictionaries of the reduced results from each call to async_update that is ready. """ - if self.is_local: - raise ValueError( - "Cannot call `async_update` when running in local mode with " - "num_workers=0." - ) - else: - if minibatch_size is not None: - minibatch_size //= len(self._workers) + return self._update( + batch=None, + episodes=episodes, + async_update=async_update, + reduce_fn=reduce_fn, + minibatch_size=minibatch_size, + num_iters=num_iters, + ) + + def _update( + self, + *, + batch: Optional[MultiAgentBatch] = None, + episodes: Optional[List[EpisodeType]] = None, + async_update: bool = False, + reduce_fn: Optional[Callable[[List[Dict[str, Any]]], ResultDict]] = ( + _reduce_mean_results + ), + minibatch_size: Optional[int] = None, + num_iters: int = 1, + ) -> Union[Dict[str, Any], List[Dict[str, Any]], List[List[Dict[str, Any]]]]: - def _learner_update(learner, minibatch): - return learner.update( - minibatch, + # Define function to be called on all Learner actors (or the local learner). + def _learner_update(learner: Learner, batch_shard=None, episodes_shard=None): + if batch_shard is not None: + return learner.update_from_batch( + batch=batch_shard, + reduce_fn=reduce_fn, minibatch_size=minibatch_size, num_iters=num_iters, + ) + else: + return learner.update_from_episodes( + episodes=episodes_shard, reduce_fn=reduce_fn, + minibatch_size=minibatch_size, + num_iters=num_iters, ) - # Queue the new batches. - # If queue is full, kick out the oldest item (and thus add its - # length to the "dropped ts" counter). - if len(self._in_queue) == self._in_queue.maxlen: - self._in_queue_ts_dropped += len(self._in_queue[0]) - - self._in_queue.append(batch) + if self.is_local: + if async_update: + raise ValueError( + "Cannot call `update_from_batch(async_update=True)` when running in" + " local mode! Try setting `config.num_learner_workers > 0`." + ) - # Retrieve all ready results (kicked off by prior calls to this method). - results = self._worker_manager.fetch_ready_async_reqs( - tags=list(self._inflight_request_tags) - ) - # Only if there are no more requests in-flight on any of the learners, - # we can send in one new batch for sharding and parallel learning. - if self._worker_manager_ready(): - count = 0 - # TODO (sven): This probably works even without any restriction - # (allowing for any arbitrary number of requests in-flight). Test with - # 3 first, then with unlimited, and if both show the same behavior on - # an async algo, remove this restriction entirely. - while len(self._in_queue) > 0 and count < 3: - # Pull a single batch from the queue (from the left side, meaning: - # use the oldest one first).
- update_tag = str(uuid.uuid4()) - self._inflight_request_tags.add(update_tag) - batch = self._in_queue.popleft() - self._worker_manager.foreach_actor_async( - [ - partial(_learner_update, minibatch=minibatch) - for minibatch in ShardBatchIterator( - batch, len(self._workers) - ) - ], - tag=update_tag, + results = [ + _learner_update( + learner=self._learner, + batch_shard=batch, + episodes_shard=episodes, + ) + ] + else: + if episodes is None: + partials = [ + partial(_learner_update, batch_shard=batch_shard) + for batch_shard in ShardBatchIterator(batch, len(self._workers)) + ] + else: + partials = [ + partial(_learner_update, episodes_shard=episodes_shard) + for episodes_shard in ShardEpisodesIterator( + episodes, len(self._workers) + ) + ] + + if async_update: + # Retrieve all ready results (kicked off by prior calls to this method). + results = None + if self._update_request_tags: + results = self._worker_manager.fetch_ready_async_reqs( + tags=list(self._update_request_tags) ) - count += 1 - # NOTE: There is a strong assumption here that the requests launched to - # learner workers will return at the same time, since they are have a - # barrier inside of themselves for gradient aggregation. Therefore results - # should be a list of lists where each inner list should be the length of - # the number of learner workers, if results from an non-blocking update are - # ready. - results = self._get_async_results(results) + update_tag = str(uuid.uuid4()) + + num_sent_requests = self._worker_manager.foreach_actor_async( + partials, tag=update_tag + ) + + if num_sent_requests: + self._update_request_tags[update_tag] = num_sent_requests + + # Some requests were dropped, record lost ts/data. + if num_sent_requests != len(self._workers): + # assert num_sent_requests == 0, num_sent_requests + factor = 1 - (num_sent_requests / len(self._workers)) + if episodes is None: + self._ts_dropped += factor * len(batch) + else: + self._ts_dropped += factor * sum(len(e) for e in episodes) + # NOTE: There is a strong assumption here that the requests launched to + # learner workers will return at the same time, since they have a + # barrier inside of themselves for gradient aggregation. Therefore + # results should be a list of lists where each inner list should be the + # length of the number of learner workers, if results from a + # non-blocking update are ready. + results = self._get_async_results(results) - # TODO(sven): Move reduce_fn to the training_step - if reduce_fn is None: - return results else: - return [reduce_fn(r) for r in results] - - def _worker_manager_ready(self): - # TODO (sven): This probably works even without any restriction (allowing for - # any arbitrary number of requests in-flight). Test with 3 first, then with - # unlimited, and if both show the same behavior on an async algo, remove - # this method entirely. - return ( - self._worker_manager.num_outstanding_async_reqs() - <= self._worker_manager.num_actors() * 2 - ) + results = self._get_results( + self._worker_manager.foreach_actor(partials) + ) + + # TODO (sven): Move reduce_fn to the training_step + if reduce_fn is None: + return results + elif not async_update: + return reduce_fn(results) + else: + return [reduce_fn(r) for r in results] def _get_results(self, results): processed_results = [] @@ -403,19 +448,34 @@ def _get_async_results(self, results): for same tags.
""" + if results is None: + return [] + unprocessed_results = defaultdict(list) for result in results: result_or_error = result.get() if result.ok: - assert ( - result.tag - ), "Cannot call _get_async_results on untagged async requests." - unprocessed_results[result.tag].append(result_or_error) + tag = result.tag + if not tag: + raise RuntimeError( + "Cannot call `LearnerGroup._get_async_results()` on untagged " + "async requests!" + ) + unprocessed_results[tag].append(result_or_error) + + if tag in self._update_request_tags: + self._update_request_tags[tag] -= 1 + if self._update_request_tags[tag] == 0: + del self._update_request_tags[tag] + else: + assert tag in self._additional_update_request_tags + self._additional_update_request_tags[tag] -= 1 + if self._additional_update_request_tags[tag] == 0: + del self._additional_update_request_tags[tag] + else: raise result_or_error - for tag in unprocessed_results.keys(): - self._inflight_request_tags.remove(tag) return list(unprocessed_results.values()) def additional_update( @@ -519,7 +579,7 @@ def set_weights(self, weights: Dict[str, Any]) -> None: """Set the weights of the MultiAgentRLModule maintained by each Learner. The weights don't have to include all the modules in the MARLModule. - This way the weights of only some of the Agents can be set. + This way the weights of only some of the Agents can be set. Args: weights: The weights to set each RLModule in the MARLModule to. @@ -552,6 +612,7 @@ def get_state(self) -> Dict[str, Any]: lambda w: w.get_state(), remote_actor_ids=[worker] ) learner_state = self._get_results(results)[0] + return { "learner_state": learner_state, "should_module_be_updated_fn": self.should_module_be_updated_fn, @@ -776,16 +837,14 @@ def load_module_state( # so we should not load any modules in the MARLModule checkpoint that are # also in the RLModule checkpoints. if modules_to_load: - if any( - module_id in modules_to_load - for module_id in rl_module_ckpt_dirs.keys() - ): - raise ValueError( - f"module_id {module_id} was specified in both " - "modules_to_load and rl_module_ckpt_dirs. Please only " - "specify a module to be loaded only once, either in " - "modules_to_load or rl_module_ckpt_dirs, but not both." - ) + for module_id in rl_module_ckpt_dirs.keys(): + if module_id in modules_to_load: + raise ValueError( + f"module_id {module_id} was specified in both " + "`modules_to_load` AND `rl_module_ckpt_dirs`! " + "Specify a module to be loaded either in `modules_to_load` " + "or `rl_module_ckpt_dirs`, but not in both." + ) else: modules_to_load = module_keys - set(rl_module_ckpt_dirs.keys()) @@ -943,6 +1002,18 @@ def __del__(self): if not self._is_shut_down: self.shutdown() + @Deprecated(new="LearnerGroup.update_from_batch(async=False)", error=False) + def update(self, *args, **kwargs): + # Just in case, we would like to revert this API retirement, we can do so + # easily. + return self._update(*args, **kwargs, async_update=False) + + @Deprecated(new="LearnerGroup.update_from_batch(async=True)", error=False) + def async_update(self, *args, **kwargs): + # Just in case, we would like to revert this API retirement, we can do so + # easily. 
+ return self._update(*args, **kwargs, async_update=True) + @Deprecated(new="LearnerGroup.set_should_module_be_updated_fn()", error=True) def set_is_module_trainable(self, *args, **kwargs): pass diff --git a/rllib/core/learner/tests/test_learner.py b/rllib/core/learner/tests/test_learner.py index 48419f0a05c5..ea731e9995d8 100644 --- a/rllib/core/learner/tests/test_learner.py +++ b/rllib/core/learner/tests/test_learner.py @@ -44,7 +44,7 @@ def test_end_to_end_update(self): min_loss = float("inf") for iter_i in range(1000): batch = reader.next() - results = learner.update(batch.as_multi_agent()) + results = learner.update_from_batch(batch=batch.as_multi_agent()) loss = results[ALL_MODULES][Learner.TOTAL_LOSS_KEY] min_loss = min(loss, min_loss) diff --git a/rllib/core/learner/tests/test_learner_group.py b/rllib/core/learner/tests/test_learner_group.py index 45ff1086157c..2574b44308e1 100644 --- a/rllib/core/learner/tests/test_learner_group.py +++ b/rllib/core/learner/tests/test_learner_group.py @@ -84,8 +84,8 @@ def local_training_helper(self, fw, scaling_mode) -> None: reader = get_cartpole_dataset_reader(batch_size=500) batch = reader.next() batch = batch.as_multi_agent() - learner_update = local_learner.update(batch) - learner_group_update = learner_group.update(batch) + learner_update = local_learner.update_from_batch(batch=batch) + learner_group_update = learner_group.update_from_batch(batch=batch) check(learner_update, learner_group_update) new_module_id = "test_module" @@ -109,12 +109,12 @@ def local_training_helper(self, fw, scaling_mode) -> None: # the optimizer state is not initialized fully until the first time that # training is completed. A call to get state before that won't contain the # optimizer state. So we do a dummy update here to initialize the optimizer - local_learner.update(ma_batch) - learner_group.update(ma_batch) + local_learner.update_from_batch(batch=ma_batch) + learner_group.update_from_batch(batch=ma_batch) check(local_learner.get_state(), learner_group.get_state()["learner_state"]) - local_learner_results = local_learner.update(ma_batch) - learner_group_results = learner_group.update(ma_batch) + local_learner_results = local_learner.update_from_batch(batch=ma_batch) + learner_group_results = learner_group.update_from_batch(batch=ma_batch) check(local_learner_results, learner_group_results) @@ -189,7 +189,9 @@ def test_update_multigpu(self): min_loss = float("inf") for iter_i in range(1000): batch = reader.next() - results = learner_group.update(batch.as_multi_agent(), reduce_fn=None) + results = learner_group.update_from_batch( + batch=batch.as_multi_agent(), reduce_fn=None + ) loss = np.mean( [res[ALL_MODULES][Learner.TOTAL_LOSS_KEY] for res in results] @@ -229,7 +231,9 @@ def test_add_remove_module(self): batch = reader.next() # update once with the default policy - results = learner_group.update(batch.as_multi_agent(), reduce_fn=None) + results = learner_group.update_from_batch( + batch=batch.as_multi_agent(), reduce_fn=None + ) module_ids_before_add = {DEFAULT_POLICY_ID} new_module_id = "test_module" @@ -239,8 +243,8 @@ def test_add_remove_module(self): ) # do training that includes the test_module - results = learner_group.update( - MultiAgentBatch( + results = learner_group.update_from_batch( + batch=MultiAgentBatch( {new_module_id: batch, DEFAULT_POLICY_ID: batch}, batch.count ), reduce_fn=None, @@ -260,7 +264,9 @@ def test_add_remove_module(self): learner_group.remove_module(module_id=new_module_id) # run training without the test_module - results = 
learner_group.update(batch.as_multi_agent(), reduce_fn=None) + results = learner_group.update_from_batch( + batch=batch.as_multi_agent(), reduce_fn=None + ) self._check_multi_worker_weights(results) @@ -415,7 +421,7 @@ def test_load_module_state_errors(self): module_1.save_to_checkpoint(tmpdir2) with self.assertRaisesRegex( (ValueError,), - ".*modules_to_load and rl_module_ckpt_dirs. Please only.*", + ".*`modules_to_load` AND `rl_module_ckpt_dirs`!.*", ): # check that loading marl modules and specifing a module id to # be loaded using modules_to_load and rl_module_ckpt_dirs raises @@ -462,7 +468,9 @@ def test_save_load_state(self): initial_learner_group_weights = initial_learner_group.get_weights() # do a single update - initial_learner_group.update(batch.as_multi_agent(), reduce_fn=None) + initial_learner_group.update_from_batch( + batch=batch.as_multi_agent(), reduce_fn=None + ) # checkpoint the learner state after 1 update for later comparison learner_after_1_update_checkpoint_dir = tempfile.TemporaryDirectory().name @@ -476,8 +484,8 @@ def test_save_load_state(self): new_learner_group.load_state(learner_after_1_update_checkpoint_dir) # do another update - results_with_break = new_learner_group.update( - batch.as_multi_agent(), reduce_fn=None + results_with_break = new_learner_group.update_from_batch( + batch=batch.as_multi_agent(), reduce_fn=None ) weights_after_1_update_with_break = new_learner_group.get_weights() new_learner_group.shutdown() @@ -487,9 +495,9 @@ def test_save_load_state(self): learner_group = config.build_learner_group(env=env) learner_group.load_state(initial_learner_checkpoint_dir) check(learner_group.get_weights(), initial_learner_group_weights) - learner_group.update(batch.as_multi_agent(), reduce_fn=None) - results_without_break = learner_group.update( - batch.as_multi_agent(), reduce_fn=None + learner_group.update_from_batch(batch.as_multi_agent(), reduce_fn=None) + results_without_break = learner_group.update_from_batch( + batch=batch.as_multi_agent(), reduce_fn=None ) weights_after_1_update_without_break = learner_group.get_weights() learner_group.shutdown() @@ -531,10 +539,12 @@ def test_async_update(self): timer_sync = _Timer() timer_async = _Timer() with timer_sync: - learner_group.update(batch.as_multi_agent(), reduce_fn=None) + learner_group.update_from_batch( + batch=batch.as_multi_agent(), async_update=False, reduce_fn=None + ) with timer_async: - result_async = learner_group.async_update( - batch.as_multi_agent(), reduce_fn=None + result_async = learner_group.update_from_batch( + batch=batch.as_multi_agent(), async_update=True, reduce_fn=None ) # ideally the the first async update will return nothing, and an easy # way to check that is if the time for an async update call is faster @@ -545,8 +555,8 @@ def test_async_update(self): iter_i = 0 while True: batch = reader.next() - async_results = learner_group.async_update( - batch.as_multi_agent(), reduce_fn=None + async_results = learner_group.update_from_batch( + batch.as_multi_agent(), async_update=True, reduce_fn=None ) if not async_results: continue diff --git a/rllib/utils/minibatch_utils.py b/rllib/utils/minibatch_utils.py index 805050a0ca9e..ffee2c3c3ad5 100644 --- a/rllib/utils/minibatch_utils.py +++ b/rllib/utils/minibatch_utils.py @@ -1,8 +1,10 @@ import math +from typing import List from ray.rllib.policy.sample_batch import MultiAgentBatch, concat_samples -from ray.rllib.utils.annotations import DeveloperAPI from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.annotations 
import DeveloperAPI +from ray.rllib.utils.typing import EpisodeType @DeveloperAPI @@ -153,3 +155,62 @@ def __iter__(self): # TODO (Avnish): int(batch_size) ? How should we shard MA batches really? new_batch = MultiAgentBatch(batch_to_send, int(batch_size)) yield new_batch + + +@DeveloperAPI +class ShardEpisodesIterator: + """Iterator for sharding a list of Episodes into num_shards sub-lists of Episodes. + + Args: + episodes: The input list of Episodes. + num_shards: The number of shards to split the episodes into. + + Yields: + A sub-list of Episodes of size roughly `len(episodes) / num_shards`. The yielded + sublists might have slightly different total sums of episode lengths, in order + to not have to drop even a single timestep. + """ + + def __init__(self, episodes: List[EpisodeType], num_shards: int): + self._episodes = sorted(episodes, key=len, reverse=True) + self._num_shards = num_shards + self._total_length = sum(len(e) for e in episodes) + self._target_lengths = [0 for _ in range(self._num_shards)] + remaining_length = self._total_length + for s in range(self._num_shards): + len_ = remaining_length // (num_shards - s) + self._target_lengths[s] = len_ + remaining_length -= len_ + + def __iter__(self): + sublists = [[] for _ in range(self._num_shards)] + lengths = [0 for _ in range(self._num_shards)] + episode_index = 0 + + while episode_index < len(self._episodes): + episode = self._episodes[episode_index] + min_index = lengths.index(min(lengths)) + + if lengths[min_index] + len(episode) <= self._target_lengths[min_index]: + # Add the whole episode if it fits within the target length + sublists[min_index].append(episode) + lengths[min_index] += len(episode) + episode_index += 1 + else: + # Otherwise, slice the episode + remaining_length = self._target_lengths[min_index] - lengths[min_index] + if remaining_length > 0: + slice_part, remaining_part = ( + episode[:remaining_length], + episode[remaining_length:], + ) + sublists[min_index].append(slice_part) + lengths[min_index] += len(slice_part) + self._episodes[episode_index] = remaining_part + else: + assert remaining_length == 0 + sublists[min_index].append(episode) + episode_index += 1 + + for sublist in sublists: + yield sublist diff --git a/rllib/utils/tests/test_minibatch_utils.py b/rllib/utils/tests/test_minibatch_utils.py index 0256e9ffab31..fe36470f6c36 100644 --- a/rllib/utils/tests/test_minibatch_utils.py +++ b/rllib/utils/tests/test_minibatch_utils.py @@ -3,7 +3,10 @@ from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.minibatch_utils import MiniBatchCyclicIterator +from ray.rllib.utils.minibatch_utils import ( + MiniBatchCyclicIterator, + ShardEpisodesIterator, +) from ray.rllib.utils.test_utils import check tf1, tf, tfv = try_import_tf() @@ -113,6 +116,56 @@ def test_minibatch_cyclic_iterator(self): check(iteration_counter, expected_iteration_counter) print(f"iteration_counter: {iteration_counter}") + def test_shard_episodes_iterator(self): + class DummyEpisode: + def __init__(self, length): + self.length = length + # Dummy data to represent the episode content. 
+ self.data = [0] * length + + def __len__(self): + return self.length + + def __getitem__(self, key): + assert isinstance(key, slice) + # Create a new Episode object with the sliced length + return DummyEpisode(len(self.data[key])) + + def __repr__(self): + return f"{(type(self).__name__)}({self.length})" + + # Create a list of episodes with varying lengths + episode_lens = [10, 21, 3, 4, 35, 41, 5, 15, 44] + + episodes = [DummyEpisode(len_) for len_ in episode_lens] + + # Number of shards + num_shards = 3 + # Create the iterator + iterator = ShardEpisodesIterator(episodes, num_shards) + # Iterate and collect the results + shards = list(iterator) + # The sharder should try to split as few times as possible. In our + # case here, only the len=4 episode is split into 1 and 3. All other + # episodes are kept as-is. Yet, the resulting sub-lists all have a total + # size of either 59 or 60 timesteps. + check([len(e) for e in shards[0]], [44, 10, 5]) # 59 + check([len(e) for e in shards[1]], [41, 15, 3]) # 59 + check([len(e) for e in shards[2]], [35, 21, 1, 3]) # 60 + + # Different number of shards. + num_shards = 4 + # Create the iterator. + iterator = ShardEpisodesIterator(episodes, num_shards) + # Iterate and collect the results + shards = list(iterator) + # The sharder should try to split as few times as possible, keeping + # as many episodes as-is (w/o splitting). + check([len(e) for e in shards[0]], [44]) # 44 + check([len(e) for e in shards[1]], [41, 3]) # 44 + check([len(e) for e in shards[2]], [35, 10]) # 45 + check([len(e) for e in shards[3]], [21, 15, 5, 1, 3]) # 45 + if __name__ == "__main__": import pytest
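As a complement to the checks above, a rough standalone sketch of driving the new sharding utility directly (the `episodes` list is a placeholder for any objects that support `len()` and slicing, for example RLlib's episode classes; it is not defined in this patch):

    from ray.rllib.utils.minibatch_utils import ShardEpisodesIterator

    # `episodes` is assumed to already exist (e.g. sampled episode objects).
    for shard in ShardEpisodesIterator(episodes, num_shards=2):
        # Each shard is a list of (possibly sliced) episodes whose summed length
        # is roughly half of the total timestep count across all `episodes`.
        print([len(e) for e in shard], sum(len(e) for e in shard))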