[RLlib] New ConnectorV2 API #04: Changes to Learner/LearnerGroup API to allow updating from Episodes. #41235

Merged
Changes from 10 commits
8 changes: 4 additions & 4 deletions doc/source/rllib/rllib-learner.rst
@@ -230,16 +230,16 @@ Updates
.. testcode::

# This is a blocking update
results = learner_group.update(DUMMY_BATCH)
results = learner_group.update(batch=DUMMY_BATCH)

# This is a non-blocking update. The results are returned in a future
# call to `async_update`
_ = learner_group.async_update(DUMMY_BATCH)
_ = learner_group.async_update(batch=DUMMY_BATCH)

# Artificially wait for async request to be done to get the results
# in the next call to `LearnerGroup.async_update()`.
time.sleep(5)
results = learner_group.async_update(DUMMY_BATCH)
results = learner_group.async_update(batch=DUMMY_BATCH)
# `results` is a list of result dicts. The items in the list represent the different
# remote results from the different calls to `async_update()`.
assert len(results) > 0
@@ -257,7 +257,7 @@ Updates
.. testcode::

# This is a blocking update.
result = learner.update(DUMMY_BATCH)
result = learner.update(batch=DUMMY_BATCH)

# This is an additional non-gradient based update.
learner_group.additional_update(**ADDITIONAL_UPDATE_KWARGS)
4 changes: 2 additions & 2 deletions rllib/algorithms/algorithm.py
@@ -5,12 +5,12 @@
import functools
import gymnasium as gym
import importlib
import importlib.metadata
import json
import logging
import numpy as np
import os
from packaging import version
import importlib.metadata
import re
import tempfile
import time
@@ -1613,7 +1613,7 @@ def training_step(self) -> ResultDict:
# TODO: (sven) rename MultiGPUOptimizer into something more
# meaningful.
if self.config._enable_new_api_stack:
train_results = self.learner_group.update(train_batch)
train_results = self.learner_group.update(batch=train_batch)
elif self.config.get("simple_optimizer") is True:
train_results = train_one_step(self, train_batch)
else:
2 changes: 1 addition & 1 deletion rllib/algorithms/bc/bc.py
@@ -171,7 +171,7 @@ def training_step(self) -> ResultDict:
self._counters[NUM_ENV_STEPS_SAMPLED] += train_batch.env_steps()

# Updating the policy.
train_results = self.learner_group.update(train_batch)
train_results = self.learner_group.update(batch=train_batch)

# Synchronize weights.
# As the results contain for each policy the loss and in addition the
2 changes: 1 addition & 1 deletion rllib/algorithms/dreamerv3/dreamerv3.py
@@ -607,7 +607,7 @@ def training_step(self) -> ResultDict:

# Perform the actual update via our learner group.
train_results = self.learner_group.update(
SampleBatch(sample).as_multi_agent(),
batch=SampleBatch(sample).as_multi_agent(),
reduce_fn=self._reduce_results,
)
self._counters[NUM_AGENT_STEPS_TRAINED] += replayed_steps
2 changes: 1 addition & 1 deletion rllib/algorithms/impala/impala.py
@@ -951,7 +951,7 @@ def learn_on_processed_samples(self) -> ResultDict:
for batch in batches:
if blocking:
result = self.learner_group.update(
batch,
batch=batch,
reduce_fn=_reduce_impala_results,
num_iters=self.config.num_sgd_iter,
minibatch_size=self.config.minibatch_size,
2 changes: 1 addition & 1 deletion rllib/algorithms/impala/tests/test_impala_learner.py
@@ -94,7 +94,7 @@ def test_impala_loss(self):
env=algo.workers.local_worker().env
)
learner_group.set_weights(algo.get_weights())
learner_group.update(train_batch.as_multi_agent())
learner_group.update(batch=train_batch.as_multi_agent())

algo.stop()

2 changes: 1 addition & 1 deletion rllib/algorithms/ppo/ppo.py
@@ -425,7 +425,7 @@ def training_step(self) -> ResultDict:
# TODO (Kourosh) Clearly define what train_batch_size
# vs. sgd_minibatch_size and num_sgd_iter is in the config.
train_results = self.learner_group.update(
train_batch,
batch=train_batch,
minibatch_size=self.config.sgd_minibatch_size,
num_iters=self.config.num_sgd_iter,
)
2 changes: 1 addition & 1 deletion rllib/algorithms/ppo/tests/test_ppo_learner.py
@@ -101,7 +101,7 @@ def test_loss(self):

# Load the algo weights onto the learner_group.
learner_group.set_weights(algo.get_weights())
learner_group.update(train_batch.as_multi_agent())
learner_group.update(batch=train_batch.as_multi_agent())

algo.stop()

2 changes: 1 addition & 1 deletion rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py
@@ -225,7 +225,7 @@ def get_value():
assert init_std == 0.0, init_std
batch = compute_gae_for_sample_batch(policy, PENDULUM_FAKE_BATCH.copy())
batch = policy._lazy_tensor_dict(batch)
algo.learner_group.update(batch.as_multi_agent())
algo.learner_group.update(batch=batch.as_multi_agent())

# Check the variable is updated.
post_std = get_value()
81 changes: 68 additions & 13 deletions rllib/core/learner/learner.py
@@ -48,6 +48,7 @@
from ray.rllib.utils.schedules.scheduler import Scheduler
from ray.rllib.utils.serialization import serialize_type
from ray.rllib.utils.typing import (
EpisodeType,
LearningRateOrSchedule,
ModuleID,
Optimizer,
@@ -1099,13 +1100,18 @@ def additional_update_for_module(

def update(
self,
batch: MultiAgentBatch,
*,
minibatch_size: Optional[int] = None,
num_iters: int = 1,
# TODO (sven): We should allow passing in a single agent batch here
# as well for simplicity.
batch: Optional[MultiAgentBatch] = None,
Contributor Author:

Happy to discuss the alternative to provide two different (mutually exclusive?) methods that the user/algo can decide to call: update_from_batch (for algos that do NOT require episode processing, such as DQN) or update_from_episodes (for algos that require a view on the sampled episodes, e.g. for vf-bootstrapping, v-trace, etc.).

Contributor:

I like it if the two methods are separated. I don't think there would be a case where a specific algorithm's learner would have both methods implemented, i.e. DQN would only implement update_from_batch, and PPO would only implement update_from_episodes. This is much, much cleaner than mixing both into one function. The user will have to deal with less cognitive load if they are separated.

Contributor Author:

I separated them in the LearnerGroup and Learner APIs:

  • update_from_batch(async=False|True)
  • update_from_episodes(async=False|True)

Contributor Author:

Also, I think it's nicer to have the async_update bool option as an extra arg (instead of a separate method) for better consistency and less code bloat.
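To make the alternative discussed in this thread concrete, the following is a minimal, purely illustrative sketch: the method names `update_from_batch` / `update_from_episodes` and the `async_update` flag come from the comments above, while the exact signatures are assumptions rather than the merged API.

from typing import Any, Dict, List, Optional


class LearnerGroupSketch:
    """Illustrative stub only; not RLlib's actual LearnerGroup."""

    def update_from_batch(
        self,
        batch: "MultiAgentBatch",
        *,
        async_update: bool = False,
        minibatch_size: Optional[int] = None,
        num_iters: int = 1,
    ) -> Dict[str, Any]:
        # For algos that do NOT require episode processing (e.g. DQN).
        # With async_update=True, this returns the results of previously
        # scheduled async requests instead of blocking on this call.
        return {}

    def update_from_episodes(
        self,
        episodes: List["EpisodeType"],
        *,
        async_update: bool = False,
    ) -> Dict[str, Any]:
        # For algos that need a view on the sampled episodes
        # (e.g. for vf-bootstrapping or v-trace).
        return {}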

episodes: Optional[List[EpisodeType]] = None,
reduce_fn: Callable[[List[Dict[str, Any]]], ResultDict] = (
_reduce_mean_results
),
# TODO (sven): Deprecate these in favor of config attributes for only those
# algos that actually need (and know how) to do minibatching.
minibatch_size: Optional[int] = None,
num_iters: int = 1,
) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
"""Do `num_iters` minibatch updates given the original batch.

@@ -1114,34 +1120,50 @@
will be used for all module ids in MultiAgentRLModule.

Args:
batch: A batch of data.
minibatch_size: The size of the minibatch to use for each update.
num_iters: The number of complete passes over all the sub-batches
in the input multi-agent batch.
batch: An optional batch of training data. If None, the `episodes` arg
must be provided.
episodes: An optional list of episode objects. If None, the `batch` arg
must be provided.
reduce_fn: A function to reduce the results from a list of
minibatch updates. This can be any arbitrary function that takes a
list of dictionaries and returns a single dictionary. For example, you
can either take an average (the default) or concatenate the results
(e.g. for metrics), or be more selective about what you want to report
back to the algorithm's training_step. If None is passed, the results
will not get reduced.
minibatch_size: The size of the minibatch to use for each update.
num_iters: The number of complete passes over all the sub-batches
in the input multi-agent batch.

Returns:
A dictionary of results, in numpy format or a list of such dictionaries in
case `reduce_fn` is None and we have more than one minibatch pass.
"""
self._check_is_built()

missing_module_ids = set(batch.policy_batches.keys()) - set(self.module.keys())
if len(missing_module_ids) > 0:
raise ValueError(
"Batch contains module ids that are not in the learner: "
f"{missing_module_ids}"
# If a (multi-agent) batch is provided, check whether our RLModule
# contains all ModuleIDs found in this batch. If not, throw an error.
if batch is not None:
Contributor Author:

In the alternative design (two update methods), we could then avoid these rather ugly if-blocks.

unknown_module_ids = set(batch.policy_batches.keys()) - set(
self.module.keys()
)
if len(unknown_module_ids) > 0:
raise ValueError(
"Batch contains module ids that are not in the learner: "
f"{unknown_module_ids}"
)

if num_iters < 1:
# We must do at least one pass on the batch for training.
raise ValueError("`num_iters` must be >= 1")

# Call the train data preprocessor.
batch, episodes = self._preprocess_train_data(batch=batch, episodes=episodes)

# TODO (sven): Insert a call to the Learner ConnectorV2 pipeline here, providing
# it both `batch` and `episode` for further custom processing before the
# actual `Learner._update()` call.

if minibatch_size:
batch_iter = MiniBatchCyclicIterator
elif num_iters > 1:
@@ -1180,7 +1202,7 @@ def update(
metrics_per_module=defaultdict(dict, **metrics_per_module),
)
self._check_result(result)
# TODO (sven): Figure out whether `compile_metrics` should be forced
# TODO (sven): Figure out whether `compile_results` should be forced
Contributor Author:

Typo fix: `compile_metrics` → `compile_results`.

# to return all numpy/python data, then we can skip this conversion
# step here.
results.append(convert_to_numpy(result))
@@ -1201,6 +1223,39 @@
# dict.
return reduce_fn(results)
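As a usage sketch of the updated `update()` signature above: `learner`, `train_batch`, and `sampled_episodes` below are placeholder names (a built Learner, a MultiAgentBatch, and a list of episodes), and the custom `reduce_fn` is just one example of a reducer other than the default mean.

from typing import Any, Dict, List

def keep_last(results: List[Dict[str, Any]]) -> Dict[str, Any]:
    # Custom reduce_fn: report only the stats of the final minibatch pass
    # instead of averaging over all passes.
    return results[-1]

# Update from a (multi-agent) batch, running 4 passes of 256-sized minibatches.
results = learner.update(
    batch=train_batch,
    reduce_fn=keep_last,
    minibatch_size=256,
    num_iters=4,
)

# Alternatively, update directly from a list of sampled episodes.
results = learner.update(episodes=sampled_episodes)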

@OverrideToImplementCustomLogic
Contributor:

If there is any neural network inference, does it happen here or in the connector?

Contributor Author:

Good question! The answer is: sometimes both.

For example: if you have preprocessing needs for your training data (no matter whether episodes or batches), you might want to transform this data first (e.g. clip rewards, or extend episodes by one artificial timestep for v-trace or GAE) and then perform a pre-forward pass through your network (e.g. to get the value estimates). For that pre-forward pass, you'll need to call your connector first to make sure the batch has all custom-required data formats (e.g. LSTM zero-padding). Only after all these preprocessing steps can you continue with the regular forward_train + loss + ... procedure.

def _preprocess_train_data(
self,
*,
batch: Optional[MultiAgentBatch] = None,
episodes: Optional[List[EpisodeType]] = None,
) -> Tuple[Optional[MultiAgentBatch], Optional[List[EpisodeType]]]:
"""Allows custom preprocessing of batch/episode data before the actual update.

The high-level order in which this method is called from within
`Learner.update(batch, episodes)` is:
* batch, episodes = self._preprocess_train_data(batch, episodes)
* batch = self._learner_connector(batch, episodes)
* results = self._update(batch)

The default implementation does not do any processing and is a mere
pass-through. However, specific algorithms should override this method to
implement their specific training-data preprocessing needs. It is possible to
perform preliminary RLModule forward passes (besides the main "forward_train()"
call during `self._update`) in this method, and custom algorithms might also
want to use this Learner's `self._learner_connector` to prepare the data
(batch/episodes) for such extra forward calls.

Args:
batch: An optional batch of training data to preprocess.
episodes: An optional list of episode objects to preprocess.

Returns:
A tuple consisting of the processed `batch` and the processed list of
`episodes`.
"""
return batch, episodes
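To illustrate the comment thread above, here is a hypothetical override along those lines (reward clipping on the batch path); the class name and the clipping range are illustrative assumptions, and the episode-extension / value-bootstrapping variant mentioned above would hook in at the same place.

import numpy as np

from ray.rllib.core.learner.learner import Learner
from ray.rllib.policy.sample_batch import SampleBatch


class RewardClippingLearner(Learner):
    # Illustrative only; in practice this would be your algorithm's
    # framework-specific Learner subclass.
    def _preprocess_train_data(self, *, batch=None, episodes=None):
        # Batch path: clip rewards to [-1.0, 1.0] before the Learner connector
        # and `_update()` see the data.
        if batch is not None:
            for module_batch in batch.policy_batches.values():
                module_batch[SampleBatch.REWARDS] = np.clip(
                    module_batch[SampleBatch.REWARDS], -1.0, 1.0
                )
        # Episode path: passed through untouched here. An algorithm needing
        # value bootstrapping (GAE, v-trace) could instead extend each episode
        # by one artificial timestep and run a value-function forward pass.
        return batch, episodes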

@OverrideToImplementCustomLogic
@abc.abstractmethod
def _update(