[RLlib] New ConnectorV2 API #04: Changes to Learner/LearnerGroup API to allow updating from Episodes. #41235

Merged
2 changes: 1 addition & 1 deletion rllib/algorithms/algorithm.py
@@ -1636,7 +1636,7 @@ def training_step(self) -> ResultDict:
if self.config._enable_new_api_stack:
is_module_trainable = self.workers.local_worker().is_policy_to_train
self.learner_group.set_is_module_trainable(is_module_trainable)
train_results = self.learner_group.update(train_batch)
train_results = self.learner_group.update(batch=train_batch)
elif self.config.get("simple_optimizer") is True:
train_results = train_one_step(self, train_batch)
else:
2 changes: 1 addition & 1 deletion rllib/algorithms/bc/bc.py
@@ -171,7 +171,7 @@ def training_step(self) -> ResultDict:
# Updating the policy.
is_module_trainable = self.workers.local_worker().is_policy_to_train
self.learner_group.set_is_module_trainable(is_module_trainable)
train_results = self.learner_group.update(train_batch)
train_results = self.learner_group.update(batch=train_batch)

# Synchronize weights.
# As the results contain for each policy the loss and in addition the
2 changes: 1 addition & 1 deletion rllib/algorithms/dreamerv3/dreamerv3.py
@@ -629,7 +629,7 @@ def training_step(self) -> ResultDict:

# Perform the actual update via our learner group.
train_results = self.learner_group.update(
SampleBatch(sample).as_multi_agent(),
batch=SampleBatch(sample).as_multi_agent(),
reduce_fn=self._reduce_results,
)
self._counters[NUM_AGENT_STEPS_TRAINED] += replayed_steps
2 changes: 1 addition & 1 deletion rllib/algorithms/impala/impala.py
@@ -951,7 +951,7 @@ def learn_on_processed_samples(self) -> ResultDict:
for batch in batches:
if blocking:
result = self.learner_group.update(
batch,
batch=batch,
reduce_fn=_reduce_impala_results,
num_iters=self.config.num_sgd_iter,
minibatch_size=self.config.minibatch_size,
2 changes: 1 addition & 1 deletion rllib/algorithms/impala/tests/test_impala_learner.py
@@ -106,7 +106,7 @@ def test_impala_loss(self):
learner_group_config.num_learner_workers = 0
learner_group = learner_group_config.build()
learner_group.set_weights(algo.get_weights())
learner_group.update(train_batch.as_multi_agent())
learner_group.update(batch=train_batch.as_multi_agent())

algo.stop()

2 changes: 1 addition & 1 deletion rllib/algorithms/ppo/ppo.py
@@ -484,7 +484,7 @@ def training_step(self) -> ResultDict:
is_module_trainable = self.workers.local_worker().is_policy_to_train
self.learner_group.set_is_module_trainable(is_module_trainable)
train_results = self.learner_group.update(
train_batch,
batch=train_batch,
minibatch_size=self.config.sgd_minibatch_size,
num_iters=self.config.num_sgd_iter,
)
2 changes: 1 addition & 1 deletion rllib/algorithms/ppo/tests/test_ppo_learner.py
@@ -109,7 +109,7 @@ def test_loss(self):

# Load the algo weights onto the learner_group.
learner_group.set_weights(algo.get_weights())
learner_group.update(train_batch.as_multi_agent())
learner_group.update(batch=train_batch.as_multi_agent())

algo.stop()

58 changes: 48 additions & 10 deletions rllib/core/learner/learner.py
@@ -1228,13 +1228,16 @@ def additional_update_for_module(

def update(
self,
batch: MultiAgentBatch,
*,
minibatch_size: Optional[int] = None,
num_iters: int = 1,
batch: Optional[MultiAgentBatch] = None,
Contributor Author:
Happy to discuss the alternative to provide two different (mutually exclusive?) methods that the user/algo can decide to call: update_from_batch (for algos that do NOT require episode processing, such as DQN) or update_from_episodes (for algos that require a view on the sampled episodes for e.g. vf-bootstrapping, vtrace, etc.).

Contributor:
I like it if the two methods are separated. I don't think there would be a case where a specific algorithm's learner would have both methods implemented, i.e. DQN would only implement update_from_batch, and PPO would only implement update_from_episodes. This is much, much cleaner than mixing both into one function. The user will have to deal with less cognitive load if they are separated.

Contributor Author:
I separated them in the LearnerGroup and Learner APIs:

  • update_from_batch(async=False|True)
  • update_from_episodes(async=False|True)
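For illustration, here is a rough sketch of what this split could look like on the LearnerGroup, including the async_update flag discussed further down in this thread; all signatures are assumptions drawn from this discussion, not the final API:

from typing import Any, List, Mapping, Optional, Union

class LearnerGroupSketch:
    """Sketch only; argument names and defaults are assumptions from this thread."""

    def update_from_batch(
        self,
        batch,  # MultiAgentBatch
        *,
        async_update: bool = False,
        minibatch_size: Optional[int] = None,
        num_iters: int = 1,
    ) -> Union[Mapping[str, Any], List[Mapping[str, Any]]]:
        """For algos training directly from a pre-built batch, e.g. DQN."""
        ...

    def update_from_episodes(
        self,
        episodes: List[Any],
        *,
        async_update: bool = False,
        minibatch_size: Optional[int] = None,
        num_iters: int = 1,
    ) -> Union[Mapping[str, Any], List[Mapping[str, Any]]]:
        """For algos needing the sampled episodes, e.g. for vf-bootstrapping/v-trace."""
        ...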

Contributor Author:
Also, I think it's nicer to have the async_update bool option as an extra arg (instead of a separate method) for better consistency and less code bloat.

episodes: Optional[List[EpisodeType]] = None,
reduce_fn: Callable[[List[Mapping[str, Any]]], ResultDict] = (
_reduce_mean_results
),
# TODO (sven): Deprecate these in favor of learner hyperparams for only those
# algos that actually need to do minibatching.
minibatch_size: Optional[int] = None,
num_iters: int = 1,
) -> Union[Mapping[str, Any], List[Mapping[str, Any]]]:
"""Do `num_iters` minibatch updates given the original batch.

@@ -1254,23 +1257,30 @@ def update(
example for metrics) or be more selective about what you want to report back
to the algorithm's training_step. If None is passed, the results will
not get reduced.

Returns:
A dictionary of results, in numpy format or a list of such dictionaries in
case `reduce_fn` is None and we have more than one minibatch pass.
"""
self._check_is_built()

missing_module_ids = set(batch.policy_batches.keys()) - set(self.module.keys())
if len(missing_module_ids) > 0:
raise ValueError(
"Batch contains module ids that are not in the learner: "
f"{missing_module_ids}"
if batch is not None:
Contributor Author:
In the alternative design (two update methods), we could then avoid these rather ugly if-blocks.

unknown_module_ids = (
set(batch.policy_batches.keys()) - set(self.module.keys())
)
if len(unknown_module_ids) > 0:
raise ValueError(
"Batch contains module ids that are not in the learner: "
f"{unknown_module_ids}"
)

if num_iters < 1:
# We must do at least one pass on the batch for training.
raise ValueError("`num_iters` must be >= 1")

# Call the train data preprocessor.
batch, episodes = self._preprocess_train_data(batch=batch, episodes=episodes)

if minibatch_size:
batch_iter = MiniBatchCyclicIterator
elif num_iters > 1:
@@ -1309,12 +1319,12 @@ def update(
metrics_per_module=defaultdict(dict, **metrics_per_module),
)
self._check_result(result)
# TODO (sven): Figure out whether `compile_metrics` should be forced
# TODO (sven): Figure out whether `compile_results` should be forced
Contributor Author:
typo

# to return all numpy/python data, then we can skip this conversion
# step here.
results.append(convert_to_numpy(result))

batch = self._set_slicing_by_batch_id(batch, value=False)
self._set_slicing_by_batch_id(batch, value=False)
Contributor Author:
batch never used.


# Reduce results across all minibatches, if necessary.

@@ -1330,6 +1340,34 @@
# dict.
return reduce_fn(results)
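As a usage note on reduce_fn (see the docstring above): a minimal sketch of a custom reducer that concatenates the per-minibatch results into lists instead of averaging them; the commented-out call at the end uses placeholder objects:

from typing import Any, Dict, List, Mapping

def concat_results(results: List[Mapping[str, Any]]) -> Dict[str, Any]:
    # Collect each key's values from all minibatch passes into a list.
    out: Dict[str, Any] = {}
    for res in results:
        for key, value in res.items():
            out.setdefault(key, []).append(value)
    return out

# learner.update(batch=train_batch, reduce_fn=concat_results, num_iters=4)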

@OverrideToImplementCustomLogic
Contributor:
If there is any neural network inference, does it happen here or in the connector?

Contributor Author:
Good question! The answer is: sometimes both.

For example: if you have some preprocessing needs for your training data (whether episodes or batches), you might want to transform that data first (e.g. clip rewards, extend episodes by one artificial timestep for v-trace or GAE) and then perform a pre-forward pass through your network (e.g. to get the value estimates). For that pre-forward pass, you'll need to call your connector first to make sure the batch has all custom-required data formats (e.g. LSTM zero-padding). Only after all these preprocessing steps can you continue with the regular forward_train + loss + ... procedure.

def _preprocess_train_data(self, *, batch, episodes) -> Tuple[Any, Any]:
Contributor Author:
Not sure whether this should be private or public?

"""Allows custom preprocessing of batch/episode data before the actual update.

The high-level order in which this method is called from within
`Learner.update(batch, episodes)` is:
* _preprocess_train_data(batch, episodes)
* _learner_connector(batch, episodes)
* _update_from_batch(batch)

The default implementation does not do any processing and is a mere pass-through.
However, specific algorithms should override this method to implement their
specific training data preprocessing needs. It is possible to perform separate
forward passes (besides the main "forward_train()" one during
`_update_from_batch`) in this method and custom algorithms might also want to
use this Learner's `self._learner_connector` to prepare the data (batch/episodes)
for such an extra forward call.

Args:
batch: A data batch to preprocess.
episodes: A list of episodes to preprocess.

Returns:
A tuple consisting of the processed `batch` and the processed list of
`episodes`.
"""
return batch, episodes
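To make the preprocessing discussion above concrete, here is a rough, hypothetical override; the connector and extra forward-pass calls below are assumptions for illustration, not the actual ConnectorV2 API:

from ray.rllib.core.learner.learner import Learner

class MyAlgoLearner(Learner):  # sketch only
    def _preprocess_train_data(self, *, batch, episodes):
        # Step 1: algorithm-specific transformations on the raw data, e.g.
        # reward clipping or appending one artificial bootstrap timestep per
        # episode (omitted here).

        # Step 2 (assumed calls): let the learner connector build a
        # forward-pass-ready batch (zero-padding etc.), then run an extra
        # forward pass, e.g. to obtain value estimates for GAE or v-trace.
        if episodes is not None:
            batch = self._learner_connector(episodes=episodes, batch=batch)
            extra_fwd_out = self.module.forward_train(batch)
            # ... compute advantages/targets from `extra_fwd_out` and add
            # them to `batch` here ...

        # Step 3: hand the enriched data back to the regular update flow
        # (forward_train + loss + optimizer step).
        return batch, episodes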

@OverrideToImplementCustomLogic
@abc.abstractmethod
def _update(
105 changes: 78 additions & 27 deletions rllib/core/learner/learner_group.py
@@ -24,8 +24,8 @@
from ray.rllib.core.learner.learner import LearnerSpec
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils.actor_manager import FaultTolerantActorManager
from ray.rllib.utils.minibatch_utils import ShardBatchIterator
from ray.rllib.utils.typing import ResultDict
from ray.rllib.utils.minibatch_utils import ShardBatchIterator, ShardEpisodesIterator
from ray.rllib.utils.typing import EpisodeType, ResultDict
from ray.rllib.utils.numpy import convert_to_numpy
from ray.train._internal.backend_executor import BackendExecutor
from ray.tune.utils.file_transfer import sync_dir_between_nodes
@@ -146,25 +146,38 @@ def is_local(self) -> bool:

def update(
self,
batch: MultiAgentBatch,
*,
minibatch_size: Optional[int] = None,
num_iters: int = 1,
batch: Optional[MultiAgentBatch] = None,
Contributor Author:
same discussion as above.

episodes: Optional[List[EpisodeType]] = None,
reduce_fn: Optional[Callable[[List[Mapping[str, Any]]], ResultDict]] = (
_reduce_mean_results
),
# TODO (sven): Deprecate the following args. They should be extracted from the
# LearnerHyperparameters of those specific algorithms that actually require
# these settings.
minibatch_size: Optional[int] = None,
num_iters: int = 1,
) -> Union[Mapping[str, Any], List[Mapping[str, Any]]]:
"""Do one or more gradient based updates to the Learner(s) based on given data.

Args:
batch: The data batch to use for the update.
batch: The (optional) data batch to use for the update. If there is more
than one Learner worker, the batch is split amongst them and one
shard is sent to each Learner. If `batch` is not provided, the user
must provide the `episodes` arg. Sending both `batch` and `episodes`
is also allowed.
episodes: The (optional) list of Episodes to process and perform the update
for. If there is more than one Learner worker, the list of episodes
is split amongst them and one list shard is sent to each Learner.
If `episodes` is not provided, the user must provide the `batch` arg.
Sending both `batch` and `episodes` is also allowed.
minibatch_size: The minibatch size to use for the update.
num_iters: The number of complete passes over all the sub-batches in the
input multi-agent batch.
reduce_fn: An optional callable to reduce the results from a list of the
Learner actors into a single result. This can be any arbitrary function
that takes a list of dictionaries and returns a single dictionary. For
example you can either take an average (default) or concatenate the
example, you can either take an average (default) or concatenate the
results (for example for metrics) or be more selective about what you want to
report back to the algorithm's training_step. If None is passed, the
results will not get reduced.
@@ -175,60 +188,98 @@ def update(
"""

# Construct a multi-agent batch with only the trainable modules.
train_batch = {}
for module_id in batch.policy_batches.keys():
if self._is_module_trainable(module_id, batch):
train_batch[module_id] = batch.policy_batches[module_id]
train_batch = MultiAgentBatch(train_batch, batch.count)
# TODO (sven): Move this into individual Learners. It might be that
# batch/episodes postprocessing on each Learner requires the non-trainable
# modules' data.
train_batch = None
if batch is not None:
train_batch = {}
for module_id in batch.policy_batches.keys():
if self._is_module_trainable(module_id, batch):
train_batch[module_id] = batch.policy_batches[module_id]
train_batch = MultiAgentBatch(train_batch, batch.count)

if self.is_local:
assert batch is not None or episodes is not None
results = [
self._learner.update(
train_batch,
batch=train_batch,
episodes=episodes,
minibatch_size=minibatch_size,
num_iters=num_iters,
reduce_fn=reduce_fn,
)
]
else:

def _learner_update(learner, minibatch):
def _learner_update(learner: Learner, batch_shard=None, episodes_shard=None):
return learner.update(
minibatch,
batch=batch_shard,
episodes=episodes_shard,
minibatch_size=minibatch_size,
num_iters=num_iters,
reduce_fn=reduce_fn,
)

results = self._get_results(
self._worker_manager.foreach_actor(
[
partial(_learner_update, minibatch=minibatch)
for minibatch in ShardBatchIterator(batch, len(self._workers))
]
# Only batch provided, split it up into n shards.
if episodes is None:
assert batch is not None
results = self._get_results(
self._worker_manager.foreach_actor(
[
partial(_learner_update, batch_shard=batch_shard)
for batch_shard in ShardBatchIterator(train_batch, len(self._workers))
]
)
)
)
elif batch is None:
assert episodes is not None
results = self._get_results(
self._worker_manager.foreach_actor(
[
partial(_learner_update, episodes_shard=episodes_shard)
for episodes_shard in ShardEpisodesIterator(episodes, len(self._workers))
]
)
)
# TODO (sven): Implement the case in which both batch and episodes might
Contributor Author:
If we have mutually exclusive methods: update_from_batch and update_from_episodes, this case (both batch AND episodes provided by user) would not exist anyways.

# already be provided (or figure out whether this makes sense at all).
else:
raise NotImplementedError

# TODO(sven): Move reduce_fn to the training_step
# TODO (sven): Move reduce_fn to the training_step
if reduce_fn is None:
return results
else:
return reduce_fn(results)
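For reference, a minimal sketch of how call sites use the new keyword-only signature; learner_group, train_batch, and sampled_episodes are placeholders:

# Batch-based update (e.g. PPO's training_step above):
results = learner_group.update(
    batch=train_batch,  # MultiAgentBatch
    minibatch_size=128,
    num_iters=1,
)

# Episode-based update; the episode list is sharded across the Learner
# workers via ShardEpisodesIterator (see rllib/utils/minibatch_utils.py):
results = learner_group.update(
    episodes=sampled_episodes,  # List[EpisodeType]
)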

def async_update(
self,
batch: MultiAgentBatch,
*,
minibatch_size: Optional[int] = None,
num_iters: int = 1,
batch: Optional[MultiAgentBatch] = None,
episodes: Optional[List[EpisodeType]] = None,
reduce_fn: Optional[Callable[[List[Mapping[str, Any]]], ResultDict]] = (
_reduce_mean_results
),
# TODO (sven): Deprecate the following args. They should be extracted from the
# LearnerHyperparameters of those specific algorithms that actually require
# these settings.
minibatch_size: Optional[int] = None,
num_iters: int = 1,
) -> Union[List[Mapping[str, Any]], List[List[Mapping[str, Any]]]]:
"""Asnychronously do gradient based updates to the Learner(s) with `batch`.

Args:
batch: The data batch to use for the update.
batch: The (optional) data batch to use for the update. If there is more
than one Learner worker, the batch is split amongst them and one
shard is sent to each Learner. If `batch` is not provided, the user
must provide the `episodes` arg. Sending both `batch` and `episodes`
is also allowed.
episodes: The (optional) list of Episodes to process and perform the update
for. If there is more than one Learner worker, the list of episodes
is split amongst them and one list shard is sent to each Learner.
If `episodes` is not provided, the user must provide the `batch` arg.
Sending both `batch` and `episodes` is also allowed.
minibatch_size: The minibatch size to use for the update.
num_iters: The number of complete passes over all the sub-batches in the
input multi-agent batch.
34 changes: 33 additions & 1 deletion rllib/utils/minibatch_utils.py
@@ -1,8 +1,10 @@
import math
from typing import List

from ray.rllib.policy.sample_batch import MultiAgentBatch, concat_samples
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.typing import EpisodeType


@DeveloperAPI
@@ -153,3 +155,33 @@ def __iter__(self):
# TODO (Avnish): int(batch_size) ? How should we shard MA batches really?
new_batch = MultiAgentBatch(batch_to_send, int(batch_size))
yield new_batch


@DeveloperAPI
class ShardEpisodesIterator:
Contributor:
Can we have a unit test for this?

Contributor Author:
done

"""Iterator for sharding a list of Episodes into num_shards sub-lists of Episodes.

Args:
episodes: The input list of Episodes.
num_shards: The number of shards to split the episodes into.

Yields:
A sub-list of Episodes of size roughly `len(episodes) / num_shards`.
"""
def __init__(self, episodes: List[EpisodeType], num_shards: int):
self._episodes = sorted(episodes, key=len, reverse=True)
self._num_shards = num_shards

def __iter__(self):
# Initialize sub-lists and their total lengths
sublists = [[] for _ in range(self._num_shards)]
lengths = [0 for _ in range(self._num_shards)]

for episode in self._episodes:
# Find the sub-list with the minimum total length and add the episode to it.
min_index = lengths.index(min(lengths))
sublists[min_index].append(episode)
lengths[min_index] += len(episode)

for sublist in sublists:
yield sublist
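A small usage sketch for the iterator above, using a toy len()-able stand-in for episodes (constructing real RLlib episode objects is beside the point here):

class FakeEpisode:
    """Toy stand-in; only the length matters for sharding."""
    def __init__(self, n: int):
        self.n = n
    def __len__(self) -> int:
        return self.n

episodes = [FakeEpisode(n) for n in (50, 40, 30, 20, 10)]
for shard in ShardEpisodesIterator(episodes, num_shards=2):
    # Longest-first greedy assignment keeps total timesteps roughly balanced:
    # one shard gets lengths [50, 20, 10], the other [40, 30].
    print([len(e) for e in shard])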
3 changes: 1 addition & 2 deletions rllib/utils/numpy.py
@@ -140,8 +140,7 @@ def convert_to_numpy(
"""

if reduce_floats != DEPRECATED_VALUE:
deprecation_warning(old="reduce_floats", new="reduce_types", error=True)
reduce_type = reduce_floats
deprecation_warning(old="reduce_floats", new="reduce_type", error=True)

# The mapping function used to numpyize torch/tf Tensors (and move them
# to the CPU beforehand).