[RLlib] New ConnectorV2 API #04: Changes to Learner/LearnerGroup API to allow updating from Episodes. #41235
Changes from 10 commits
```diff
@@ -48,6 +48,7 @@
 from ray.rllib.utils.schedules.scheduler import Scheduler
 from ray.rllib.utils.serialization import serialize_type
 from ray.rllib.utils.typing import (
+    EpisodeType,
     LearningRateOrSchedule,
     ModuleID,
     Optimizer,
```
```diff
@@ -1099,13 +1100,18 @@ def additional_update_for_module(

     def update(
         self,
-        batch: MultiAgentBatch,
         *,
-        minibatch_size: Optional[int] = None,
-        num_iters: int = 1,
+        # TODO (sven): We should allow passing in a single agent batch here
+        # as well for simplicity.
+        batch: Optional[MultiAgentBatch] = None,
+        episodes: Optional[List[EpisodeType]] = None,
+        reduce_fn: Callable[[List[Dict[str, Any]]], ResultDict] = (
+            _reduce_mean_results
+        ),
+        # TODO (sven): Deprecate these in favor of config attributes for only those
+        # algos that actually need (and know how) to do minibatching.
+        minibatch_size: Optional[int] = None,
+        num_iters: int = 1,
     ) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
         """Do `num_iters` minibatch updates given the original batch.
```
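To illustrate the signature change above: `batch` and `episodes` both become optional keyword arguments, and the caller supplies one or the other. The following is a minimal sketch of one plausible calling contract (exactly one of the two must be given — the PR discussion below notes they may be mutually exclusive). `ToyBatch` and `ToyLearner` are hypothetical stand-ins, not RLlib classes:

```python
from typing import Any, Dict, List, Optional


class ToyBatch:
    """Hypothetical stand-in for RLlib's MultiAgentBatch (per-policy batches)."""

    def __init__(self, policy_batches: Dict[str, list]):
        self.policy_batches = policy_batches


class ToyLearner:
    """Illustrates the batch-or-episodes calling convention of the new `update()`."""

    def update(
        self,
        *,
        batch: Optional[ToyBatch] = None,
        episodes: Optional[List[Any]] = None,
    ) -> Dict[str, Any]:
        # Assumed contract: exactly one of `batch` / `episodes` is provided.
        if (batch is None) == (episodes is None):
            raise ValueError("Provide exactly one of `batch` or `episodes`.")
        if batch is not None:
            return {"source": "batch", "num_policies": len(batch.policy_batches)}
        return {"source": "episodes", "num_episodes": len(episodes)}
```

Algorithms that train from replay buffers would pass `batch=...`, while algorithms that need the raw sampled episodes would pass `episodes=...`.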
```diff
@@ -1114,34 +1120,50 @@ def update(
         will be used for all module ids in MultiAgentRLModule.

         Args:
-            batch: A batch of data.
-            minibatch_size: The size of the minibatch to use for each update.
-            num_iters: The number of complete passes over all the sub-batches
-                in the input multi-agent batch.
+            batch: An optional batch of training data. If None, the `episodes` arg
+                must be provided.
+            episodes: An optional list of episode objects. If None, the `batch` arg
+                must be provided.
+            reduce_fn: A function to reduce the results from a list of minibatch
+                updates. This can be any arbitrary function that takes a list of
+                dictionaries and returns a single dictionary. For example, you can
+                either take an average (default) or concatenate the results (for
+                example for metrics) or be more selective about what you want to
+                report back to the algorithm's training_step. If None is passed,
+                the results will not get reduced.
+            minibatch_size: The size of the minibatch to use for each update.
+            num_iters: The number of complete passes over all the sub-batches
+                in the input multi-agent batch.

         Returns:
             A dictionary of results, in numpy format or a list of such dictionaries in
             case `reduce_fn` is None and we have more than one minibatch pass.
         """
         self._check_is_built()

-        missing_module_ids = set(batch.policy_batches.keys()) - set(self.module.keys())
-        if len(missing_module_ids) > 0:
-            raise ValueError(
-                "Batch contains module ids that are not in the learner: "
-                f"{missing_module_ids}"
-            )
+        # If a (multi-agent) batch is provided, check whether our RLModule
+        # contains all ModuleIDs found in this batch. If not, throw an error.
+        if batch is not None:
+            unknown_module_ids = set(batch.policy_batches.keys()) - set(
+                self.module.keys()
+            )
+            if len(unknown_module_ids) > 0:
+                raise ValueError(
+                    "Batch contains module ids that are not in the learner: "
+                    f"{unknown_module_ids}"
+                )

         if num_iters < 1:
             # We must do at least one pass on the batch for training.
             raise ValueError("`num_iters` must be >= 1")

+        # Call the train data preprocessor.
+        batch, episodes = self._preprocess_train_data(batch=batch, episodes=episodes)
+
+        # TODO (sven): Insert a call to the Learner ConnectorV2 pipeline here, providing
+        # it both `batch` and `episode` for further custom processing before the
+        # actual `Learner._update()` call.
+
         if minibatch_size:
             batch_iter = MiniBatchCyclicIterator
         elif num_iters > 1:
```

Review comment (on the `if batch is not None:` check): In the alternative design (two update methods), we could avoid these rather ugly if-blocks.
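The tail of the hunk above dispatches on `minibatch_size`: with a minibatch size set, sub-batches are iterated cyclically (via `MiniBatchCyclicIterator`); otherwise the full batch is repeated `num_iters` times. A toy sketch of that dispatch, using a plain list in place of a real batch (the generator below is an illustration, not RLlib's actual iterator):

```python
from typing import Iterator, List, Optional


def iterate_minibatches(
    batch: List[int], minibatch_size: Optional[int], num_iters: int
) -> Iterator[List[int]]:
    """Toy analog of the minibatch dispatch in `update()`."""
    if num_iters < 1:
        # Mirrors the check in the diff: at least one training pass is required.
        raise ValueError("`num_iters` must be >= 1")
    if minibatch_size:
        # Minibatching: each of the `num_iters` passes walks all sub-batches.
        for _ in range(num_iters):
            for start in range(0, len(batch), minibatch_size):
                yield batch[start : start + minibatch_size]
    else:
        # No minibatching: each iteration sees the entire batch.
        for _ in range(num_iters):
            yield batch
```

For example, a batch of 4 items with `minibatch_size=2` and `num_iters=1` yields two sub-batches; with `minibatch_size=None` and `num_iters=3` it yields the whole batch three times.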
```diff
@@ -1180,7 +1202,7 @@ def update(
                 metrics_per_module=defaultdict(dict, **metrics_per_module),
             )
             self._check_result(result)
-            # TODO (sven): Figure out whether `compile_metrics` should be forced
+            # TODO (sven): Figure out whether `compile_results` should be forced
             # to return all numpy/python data, then we can skip this conversion
             # step here.
             results.append(convert_to_numpy(result))
```

Review comment (on the changed TODO line): typo.
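The per-minibatch result dicts collected above are finally either reduced via `reduce_fn` (the default `_reduce_mean_results` averages them) or returned as a raw list when `reduce_fn=None`, as described in the `update()` docstring. A toy sketch of that behavior (`reduce_mean_results` and `finalize` are illustrative names, not RLlib functions):

```python
from typing import Callable, Dict, List, Optional, Union


def reduce_mean_results(results: List[Dict[str, float]]) -> Dict[str, float]:
    """Toy analog of `_reduce_mean_results`: average each metric across
    the per-minibatch result dicts."""
    return {
        key: sum(r[key] for r in results) / len(results)
        for key in results[0]
    }


def finalize(
    results: List[Dict[str, float]],
    reduce_fn: Optional[Callable[[List[Dict[str, float]]], Dict[str, float]]],
) -> Union[Dict[str, float], List[Dict[str, float]]]:
    # With `reduce_fn=None`, the caller gets the raw per-minibatch list back.
    return results if reduce_fn is None else reduce_fn(results)
```

This matches the docstring's "dictionary of results ... or a list of such dictionaries in case `reduce_fn` is None".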
```diff
@@ -1201,6 +1223,39 @@ def update(
             # dict.
             return reduce_fn(results)

+    @OverrideToImplementCustomLogic
+    def _preprocess_train_data(
+        self,
+        *,
+        batch: Optional[MultiAgentBatch] = None,
+        episodes: Optional[List[EpisodeType]] = None,
+    ) -> Tuple[Optional[MultiAgentBatch], Optional[List[EpisodeType]]]:
+        """Allows custom preprocessing of batch/episode data before the actual update.
+
+        The higher-level order in which this method is called from within
+        `Learner.update(batch, episodes)` is:
+        * batch, episodes = self._preprocess_train_data(batch, episodes)
+        * batch = self._learner_connector(batch, episodes)
+        * results = self._update(batch)
+
+        The default implementation does not do any processing and is a mere
+        pass-through. However, specific algorithms should override this method to
+        implement their specific training data preprocessing needs. It is possible
+        to perform preliminary RLModule forward passes (besides the main
+        "forward_train()" call during `self._update`) in this method, and custom
+        algorithms might also want to use this Learner's `self._learner_connector`
+        to prepare the data (batch/episodes) for such extra forward calls.
+
+        Args:
+            batch: An optional batch of training data to preprocess.
+            episodes: An optional list of episode objects to preprocess.
+
+        Returns:
+            A tuple consisting of the processed `batch` and the processed list of
+            `episodes`.
+        """
+        return batch, episodes
+
     @OverrideToImplementCustomLogic
     @abc.abstractmethod
     def _update(
```

Review comment (on `_preprocess_train_data`): If there is any neural network inference, does it happen here or in the connector?

Reply: Good question! The answer is: sometimes both. For example: if you have some preprocessing needs for your training data (no matter whether episodes or batches), then you might want to do some preprocessing on this data (e.g. clip rewards, extend episodes by one artificial timestep for v-trace or GAE) and then perform a pre-forward pass through your network (e.g. to get the value estimates). For that pre-forward pass, you'll need to call your connector first to make sure this batch has all custom-required data formats (e.g. LSTM zero-padding). Only after all these preprocessing steps will you be able to continue with the regular `Learner._update()` logic.
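The docstring above says the default `_preprocess_train_data` is a pure pass-through and that algorithms override it for custom needs such as reward clipping. A minimal sketch of such an override, using hypothetical toy classes rather than RLlib's `Learner`/`EpisodeType`:

```python
from typing import Any, List, Optional, Tuple


class ToyEpisode:
    """Hypothetical episode holding only a reward sequence."""

    def __init__(self, rewards: List[float]):
        self.rewards = rewards


class BaseToyLearner:
    """Mirrors the default pass-through `_preprocess_train_data` above."""

    def _preprocess_train_data(
        self,
        *,
        batch: Optional[Any] = None,
        episodes: Optional[List[ToyEpisode]] = None,
    ) -> Tuple[Optional[Any], Optional[List[ToyEpisode]]]:
        return batch, episodes


class RewardClippingLearner(BaseToyLearner):
    """Example override: clip episode rewards to [-1, 1] before the update,
    one of the preprocessing needs mentioned in the review reply above."""

    def _preprocess_train_data(self, *, batch=None, episodes=None):
        if episodes is not None:
            for ep in episodes:
                ep.rewards = [max(-1.0, min(1.0, r)) for r in ep.rewards]
        return batch, episodes
```

`update()` would call this hook first, then hand the (possibly modified) batch/episodes to the learner connector and finally to `_update()`.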
Review comment: Happy to discuss the alternative to provide two different (mutually exclusive?) methods that the user/algo can decide to call: `update_from_batch` (for algos that do NOT require episode processing, such as DQN) or `update_from_episodes` (for algos that require a view on the sampled episodes for e.g. vf-bootstrapping, v-trace, etc.).

Reply: I like it if the two methods are separated. I don't think there would be a case where a specific algorithm's learner would have both methods implemented, i.e. DQN will only implement `update_from_batch`, and PPO would only implement `update_from_episodes`. This is much, much cleaner than mixing both into one function. The user will have to deal with less cognitive load if they are separated.

Reply: I separated them in the LearnerGroup and Learner APIs.

Reply: Also, I think it's nicer to have the `async_update` bool option as an extra arg (instead of a separate method) for better consistency and less code bloat.
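The alternative design discussed above (two separate entry points instead of one `update(batch=..., episodes=...)`) can be sketched as follows. All class and method bodies here are illustrative toys, assuming only what the thread states: each algorithm implements exactly one of the two methods.

```python
from typing import Any, Dict, List


class SplitAPILearner:
    """Sketch of the two-method alternative: `update_from_batch` vs.
    `update_from_episodes`. Names follow the review thread; bodies are toys."""

    def update_from_batch(self, batch: Dict[str, list]) -> Dict[str, Any]:
        raise NotImplementedError

    def update_from_episodes(self, episodes: List[Any]) -> Dict[str, Any]:
        raise NotImplementedError


class DQNStyleLearner(SplitAPILearner):
    # Off-policy algos like DQN train straight from (replay-buffer) batches
    # and need no episode processing.
    def update_from_batch(self, batch):
        return {"trained_on": "batch", "size": len(next(iter(batch.values())))}


class PPOStyleLearner(SplitAPILearner):
    # On-policy algos like PPO need the episode view first, e.g. for
    # vf-bootstrapping / v-trace, before building the train batch.
    def update_from_episodes(self, episodes):
        return {"trained_on": "episodes", "num_episodes": len(episodes)}
```

This makes the "ugly if-blocks" from the single-method design unnecessary: each subclass only sees the input type it actually supports, and calling the wrong method fails loudly.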