[RLlib] New ConnectorV2 API #04: Changes to Learner/LearnerGroup API to allow updating from Episodes. #41235
```diff
@@ -1228,13 +1228,16 @@ def additional_update_for_module(

     def update(
         self,
-        batch: MultiAgentBatch,
         *,
-        minibatch_size: Optional[int] = None,
-        num_iters: int = 1,
+        batch: Optional[MultiAgentBatch] = None,
+        episodes: Optional[List[EpisodeType]] = None,
         reduce_fn: Callable[[List[Mapping[str, Any]]], ResultDict] = (
             _reduce_mean_results
         ),
+        # TODO (sven): Deprecate these in favor of learner hyperparams for only
+        # those algos that actually need to do minibatching.
+        minibatch_size: Optional[int] = None,
+        num_iters: int = 1,
     ) -> Union[Mapping[str, Any], List[Mapping[str, Any]]]:
         """Do `num_iters` minibatch updates given the original batch.
```
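As a sanity check on the new contract (either `batch` or `episodes` must be provided), here is a minimal, self-contained stand-in — not RLlib code — that mimics the argument validation a caller of the new `Learner.update()` signature can expect:

```python
from typing import Any, List, Optional

# Hypothetical stand-in (NOT RLlib code) mimicking the argument contract of the
# new `Learner.update()` signature shown in the diff above.
def update(
    *,
    batch: Optional[Any] = None,
    episodes: Optional[List[Any]] = None,
    num_iters: int = 1,
) -> str:
    # At least one of `batch` or `episodes` must be provided.
    if batch is None and episodes is None:
        raise ValueError("Either `batch` or `episodes` must be provided!")
    if num_iters < 1:
        # We must do at least one pass over the data for training.
        raise ValueError("`num_iters` must be >= 1")
    # Report which code path a real Learner would take.
    return "episodes" if episodes is not None else "batch"

# The two call patterns:
assert update(batch={"default_policy": {}}) == "batch"
assert update(episodes=["episode_1", "episode_2"]) == "episodes"
```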
```diff
@@ -1254,23 +1257,30 @@ def update(
             example for metrics) or be more selective about what you want to
             report back to the algorithm's training_step. If None is passed, the
             results will not get reduced.

         Returns:
             A dictionary of results, in numpy format, or a list of such
             dictionaries in case `reduce_fn` is None and we have more than one
             minibatch pass.
         """
         self._check_is_built()

-        missing_module_ids = set(batch.policy_batches.keys()) - set(self.module.keys())
-        if len(missing_module_ids) > 0:
-            raise ValueError(
-                "Batch contains module ids that are not in the learner: "
-                f"{missing_module_ids}"
-            )
+        if batch is not None:
+            unknown_module_ids = (
+                set(batch.policy_batches.keys()) - set(self.module.keys())
+            )
+            if len(unknown_module_ids) > 0:
+                raise ValueError(
+                    "Batch contains module ids that are not in the learner: "
+                    f"{unknown_module_ids}"
+                )

         if num_iters < 1:
             # We must do at least one pass on the batch for training.
             raise ValueError("`num_iters` must be >= 1")

+        # Call the train data preprocessor.
+        batch, episodes = self._preprocess_train_data(batch=batch, episodes=episodes)

         if minibatch_size:
             batch_iter = MiniBatchCyclicIterator
         elif num_iters > 1:
```

**Review comment** (on the `if batch is not None:` block): In the alternative design (two update methods), we could avoid these rather ugly if-blocks then.
```diff
@@ -1309,12 +1319,12 @@ def update(
                 metrics_per_module=defaultdict(dict, **metrics_per_module),
             )
             self._check_result(result)
-            # TODO (sven): Figure out whether `compile_metrics` should be forced
+            # TODO (sven): Figure out whether `compile_results` should be forced
             # to return all numpy/python data, then we can skip this conversion
             # step here.
             results.append(convert_to_numpy(result))

-        batch = self._set_slicing_by_batch_id(batch, value=False)
+        self._set_slicing_by_batch_id(batch, value=False)

         # Reduce results across all minibatches, if necessary.
```

**Review comment** (on the `compile_results` line): typo
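Since `reduce_fn` receives the list of per-minibatch result dicts and returns a single result dict, a custom reducer is straightforward. A hedged sketch (plain Python, no RLlib imports; the default `_reduce_mean_results` averages numeric values instead) of a concatenating reducer, as the docstring suggests for metrics:

```python
from typing import Any, Dict, List

def concat_results(results: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Example custom `reduce_fn`: collect each key's values into a list.

    RLlib's default (`_reduce_mean_results`) averages numeric values instead;
    passing `reduce_fn=None` skips reduction entirely and returns the raw list.
    """
    reduced: Dict[str, Any] = {}
    for result in results:
        for key, value in result.items():
            reduced.setdefault(key, []).append(value)
    return reduced

# E.g. three minibatch passes, each returning a loss metric:
out = concat_results([{"total_loss": 0.5}, {"total_loss": 0.4}, {"total_loss": 0.3}])
assert out == {"total_loss": [0.5, 0.4, 0.3]}
```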
```diff
@@ -1330,6 +1340,34 @@ def update(
             # dict.
             return reduce_fn(results)

+    @OverrideToImplementCustomLogic
+    def _preprocess_train_data(self, *, batch, episodes) -> Tuple[Any, Any]:
+        """Allows custom preprocessing of batch/episode data before the actual update.
+
+        The higher-level order in which this method is called from within
+        `Learner.update(batch, episodes)` is:
+        * _preprocess_train_data(batch, episodes)
+        * _learner_connector(batch, episodes)
+        * _update_from_batch(batch)
+
+        The default implementation does not do any processing and is a mere
+        pass-through. However, specific algorithms should override this method
+        to implement their specific training data preprocessing needs. It is
+        possible to perform separate forward passes (besides the main
+        `forward_train()` one during `_update_from_batch`) in this method, and
+        custom algorithms might also want to use this Learner's
+        `self._learner_connector` to prepare the data (batch/episodes) for such
+        an extra forward call.
+
+        Args:
+            batch: A data batch to preprocess.
+            episodes: A list of episodes to preprocess.
+
+        Returns:
+            A tuple consisting of the processed `batch` and the processed list
+            of `episodes`.
+        """
+        return batch, episodes
+
     @OverrideToImplementCustomLogic
     @abc.abstractmethod
     def _update(
```

**Review comment** (on `@OverrideToImplementCustomLogic`): If there is any neural network inference, does it happen here or in the connector?

**Reply:** Good question! The answer is: sometimes both. For example: if you have some preprocessing needs for your training data (no matter whether episodes or batches), then you might want to do some preprocessing on this data (e.g. clip rewards, extend episodes by one artificial timestep for v-trace or GAE) and then perform a pre-forward pass through your network (e.g. to get the value estimates). For that pre-forward pass, you'll need to call your connector first to make sure this batch has all custom-required data formats (e.g. LSTM zero-padding). Only after all these preprocessing steps, you will be able to continue with the regular …

**Review comment** (on `_preprocess_train_data`): Not sure this should be private or public?
```diff
@@ -24,8 +24,8 @@
 from ray.rllib.core.learner.learner import LearnerSpec
 from ray.rllib.policy.sample_batch import MultiAgentBatch
 from ray.rllib.utils.actor_manager import FaultTolerantActorManager
-from ray.rllib.utils.minibatch_utils import ShardBatchIterator
-from ray.rllib.utils.typing import ResultDict
+from ray.rllib.utils.minibatch_utils import ShardBatchIterator, ShardEpisodesIterator
+from ray.rllib.utils.typing import EpisodeType, ResultDict
 from ray.rllib.utils.numpy import convert_to_numpy
 from ray.train._internal.backend_executor import BackendExecutor
 from ray.tune.utils.file_transfer import sync_dir_between_nodes
```
```diff
@@ -146,25 +146,38 @@ def is_local(self) -> bool:

     def update(
         self,
-        batch: MultiAgentBatch,
         *,
-        minibatch_size: Optional[int] = None,
-        num_iters: int = 1,
+        batch: Optional[MultiAgentBatch] = None,
+        episodes: Optional[List[EpisodeType]] = None,
         reduce_fn: Optional[Callable[[List[Mapping[str, Any]]], ResultDict]] = (
             _reduce_mean_results
         ),
+        # TODO (sven): Deprecate the following args. They should be extracted
+        # from the LearnerHyperparameters of those specific algorithms that
+        # actually require these settings.
+        minibatch_size: Optional[int] = None,
+        num_iters: int = 1,
     ) -> Union[Mapping[str, Any], List[Mapping[str, Any]]]:
         """Do one or more gradient based updates to the Learner(s) based on given data.

         Args:
-            batch: The data batch to use for the update.
+            batch: The (optional) data batch to use for the update. If there is
+                more than one Learner worker, the batch is split amongst these
+                and one shard is sent to each Learner. If `batch` is not
+                provided, the user must provide the `episodes` arg. Sending both
+                `batch` and `episodes` is also allowed.
+            episodes: The (optional) list of Episodes to process and perform the
+                update for. If there is more than one Learner worker, the list
+                of episodes is split amongst these and one list shard is sent to
+                each Learner. If `episodes` is not provided, the user must
+                provide the `batch` arg. Sending both `batch` and `episodes` is
+                also allowed.
             minibatch_size: The minibatch size to use for the update.
             num_iters: The number of complete passes over all the sub-batches in
                 the input multi-agent batch.
             reduce_fn: An optional callable to reduce the results from a list of
                 the Learner actors into a single result. This can be any
                 arbitrary function that takes a list of dictionaries and returns
-                a single dictionary. For example you can either take an average
+                a single dictionary. For example, you can either take an average
                 (default) or concatenate the results (for example for metrics)
                 or be more selective about what you want to report back to the
                 algorithm's training_step. If None is passed, the results will
                 not get reduced.
```

**Review comment** (on `batch: Optional[MultiAgentBatch] = None`): same discussion as above.
```diff
@@ -175,60 +188,98 @@ def update(
         """

         # Construct a multi-agent batch with only the trainable modules.
-        train_batch = {}
-        for module_id in batch.policy_batches.keys():
-            if self._is_module_trainable(module_id, batch):
-                train_batch[module_id] = batch.policy_batches[module_id]
-        train_batch = MultiAgentBatch(train_batch, batch.count)
+        # TODO (sven): Move this into individual Learners. It might be that
+        # batch/episodes postprocessing on each Learner requires the
+        # non-trainable modules' data.
+        train_batch = None
+        if batch is not None:
+            train_batch = {}
+            for module_id in batch.policy_batches.keys():
+                if self._is_module_trainable(module_id, batch):
+                    train_batch[module_id] = batch.policy_batches[module_id]
+            train_batch = MultiAgentBatch(train_batch, batch.count)

         if self.is_local:
+            assert batch is not None or episodes is not None
             results = [
                 self._learner.update(
-                    train_batch,
+                    batch=train_batch,
+                    episodes=episodes,
                     minibatch_size=minibatch_size,
                     num_iters=num_iters,
                     reduce_fn=reduce_fn,
                 )
             ]
         else:

-            def _learner_update(learner, minibatch):
+            def _learner_update(learner: Learner, batch_shard=None, episodes_shard=None):
                 return learner.update(
-                    minibatch,
+                    batch=batch_shard,
+                    episodes=episodes_shard,
                     minibatch_size=minibatch_size,
                     num_iters=num_iters,
                     reduce_fn=reduce_fn,
                 )

-            results = self._get_results(
-                self._worker_manager.foreach_actor(
-                    [
-                        partial(_learner_update, minibatch=minibatch)
-                        for minibatch in ShardBatchIterator(batch, len(self._workers))
-                    ]
-                )
-            )
+            # Only batch provided, split it up into n shards.
+            if episodes is None:
+                assert batch is not None
+                results = self._get_results(
+                    self._worker_manager.foreach_actor(
+                        [
+                            partial(_learner_update, batch_shard=batch_shard)
+                            for batch_shard in ShardBatchIterator(train_batch, len(self._workers))
+                        ]
+                    )
+                )
+            elif batch is None:
+                assert episodes is not None
+                results = self._get_results(
+                    self._worker_manager.foreach_actor(
+                        [
+                            partial(_learner_update, episodes_shard=episodes_shard)
+                            for episodes_shard in ShardEpisodesIterator(episodes, len(self._workers))
+                        ]
+                    )
+                )
+            # TODO (sven): Implement the case in which both batch and episodes
+            # might already be provided (or figure out whether this makes sense
+            # at all).
+            else:
+                raise NotImplementedError

-        # TODO(sven): Move reduce_fn to the training_step
+        # TODO (sven): Move reduce_fn to the training_step
         if reduce_fn is None:
             return results
         else:
            return reduce_fn(results)
```

**Review comment** (on the TODO about providing both batch and episodes): If we have mutually exclusive methods: …
```diff
     def async_update(
         self,
-        batch: MultiAgentBatch,
         *,
-        minibatch_size: Optional[int] = None,
-        num_iters: int = 1,
+        batch: Optional[MultiAgentBatch] = None,
+        episodes: Optional[List[EpisodeType]] = None,
         reduce_fn: Optional[Callable[[List[Mapping[str, Any]]], ResultDict]] = (
             _reduce_mean_results
         ),
+        # TODO (sven): Deprecate the following args. They should be extracted
+        # from the LearnerHyperparameters of those specific algorithms that
+        # actually require these settings.
+        minibatch_size: Optional[int] = None,
+        num_iters: int = 1,
     ) -> Union[List[Mapping[str, Any]], List[List[Mapping[str, Any]]]]:
         """Asynchronously do gradient based updates to the Learner(s) with `batch`.

         Args:
-            batch: The data batch to use for the update.
+            batch: The (optional) data batch to use for the update. If there is
+                more than one Learner worker, the batch is split amongst these
+                and one shard is sent to each Learner. If `batch` is not
+                provided, the user must provide the `episodes` arg. Sending both
+                `batch` and `episodes` is also allowed.
+            episodes: The (optional) list of Episodes to process and perform the
+                update for. If there is more than one Learner worker, the list
+                of episodes is split amongst these and one list shard is sent to
+                each Learner. If `episodes` is not provided, the user must
+                provide the `batch` arg. Sending both `batch` and `episodes` is
+                also allowed.
             minibatch_size: The minibatch size to use for the update.
             num_iters: The number of complete passes over all the sub-batches in
                 the input multi-agent batch.
```
```diff
@@ -1,8 +1,10 @@
 import math
 from typing import List

 from ray.rllib.policy.sample_batch import MultiAgentBatch, concat_samples
+from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.utils.annotations import DeveloperAPI
+from ray.rllib.utils.typing import EpisodeType


 @DeveloperAPI
```
```diff
@@ -153,3 +155,33 @@ def __iter__(self):
         # TODO (Avnish): int(batch_size)? How should we shard MA batches really?
         new_batch = MultiAgentBatch(batch_to_send, int(batch_size))
         yield new_batch
+
+
+@DeveloperAPI
+class ShardEpisodesIterator:
+    """Iterator for sharding a list of Episodes into num_shards sub-lists of Episodes.
+
+    Args:
+        episodes: The input list of Episodes.
+        num_shards: The number of shards to split the episodes into.
+
+    Yields:
+        A sub-list of Episodes of size roughly `len(episodes) / num_shards`.
+    """
+
+    def __init__(self, episodes: List[EpisodeType], num_shards: int):
+        self._episodes = sorted(episodes, key=len, reverse=True)
+        self._num_shards = num_shards
+
+    def __iter__(self):
+        # Initialize sub-lists and their total lengths.
+        sublists = [[] for _ in range(self._num_shards)]
+        lengths = [0 for _ in range(self._num_shards)]
+
+        for episode in self._episodes:
+            # Find the sub-list with the minimum total length and add the
+            # episode to it.
+            min_index = lengths.index(min(lengths))
+            sublists[min_index].append(episode)
+            lengths[min_index] += len(episode)
+
+        for sublist in sublists:
+            yield sublist
```

**Review comment** (on `class ShardEpisodesIterator`): can we have a unittest for this?

**Reply:** done
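The sharding strategy above is greedy longest-first balancing: episodes are sorted by descending length and each one is assigned to the shard with the smallest total number of timesteps so far. A stand-alone sketch of the same logic (strings stand in for episodes, since only `len()` matters) shows the roughly balanced shard sizes:

```python
from typing import List, Sequence

def shard_episodes(episodes: Sequence, num_shards: int) -> List[list]:
    """Stand-alone sketch of the `ShardEpisodesIterator` logic above.

    Longest-episode-first greedy assignment: each episode goes to the shard
    with the smallest total length so far, keeping per-shard timestep counts
    roughly balanced (anything supporting `len()` works as an "episode").
    """
    sublists = [[] for _ in range(num_shards)]
    lengths = [0] * num_shards
    for episode in sorted(episodes, key=len, reverse=True):
        # Find the sub-list with the minimum total length and add the episode.
        min_index = lengths.index(min(lengths))
        sublists[min_index].append(episode)
        lengths[min_index] += len(episode)
    return sublists

# Episodes of lengths 5, 4, 3, 2: the greedy rule yields two shards of 7
# timesteps each (5+2 and 4+3).
shards = shard_episodes(["aaaaa", "bbbb", "ccc", "dd"], num_shards=2)
assert sorted(sum(len(e) for e in s) for s in shards) == [7, 7]
```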
**Review comment:** Happy to discuss the alternative to provide two different (mutually exclusive?) methods that the user/algo can decide to call: `update_from_batch` (for algos that do NOT require episode processing, such as DQN) or `update_from_episodes` (for algos that require a view on the sampled episodes for e.g. vf-bootstrapping, v-trace, etc.).

**Reply:** I like it if the two methods are separated. I don't think there would be a case where a specific algorithm's learner would have both methods implemented, i.e. DQN will only implement `update_from_batch`, and PPO would only implement `update_from_episodes`. This is much, much cleaner than mixing both into one function. The user will have to deal with less cognitive load if they are separated.

**Reply:** I separated them in the LearnerGroup and Learner APIs.

**Reply:** Also, I think it's nicer to have the `async_update` bool option as an extra arg (instead of a separate method) for better consistency and less code bloat.
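A rough sketch of what the discussed split could look like (method names taken from the thread; the class bodies are hypothetical placeholders, not RLlib's actual implementation):

```python
import abc
from typing import Any, Dict, List

class Learner(abc.ABC):
    """Sketch of the split API discussed above (NOT RLlib's actual code)."""

    def update_from_batch(self, batch: Any, **kwargs) -> Dict[str, Any]:
        # Batch-based algos (e.g. DQN) implement only this method.
        raise NotImplementedError

    def update_from_episodes(self, episodes: List[Any], **kwargs) -> Dict[str, Any]:
        # Episode-based algos (e.g. PPO, which needs the sampled episodes for
        # vf-bootstrapping / GAE) implement only this method.
        raise NotImplementedError

class DQNLearner(Learner):
    def update_from_batch(self, batch, **kwargs):
        return {"total_loss": 0.0}  # placeholder for the real TD-loss update

class PPOLearner(Learner):
    def update_from_episodes(self, episodes, **kwargs):
        return {"total_loss": 0.0}  # placeholder for the real PPO update

# Each learner exposes exactly one of the two update entry points:
assert DQNLearner().update_from_batch(batch={}) == {"total_loss": 0.0}
assert PPOLearner().update_from_episodes(episodes=[]) == {"total_loss": 0.0}
```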