[RLlib] New ConnectorV2 API #04: Changes to Learner/LearnerGroup API to allow updating from Episodes. #41235

Merged
3 changes: 2 additions & 1 deletion doc/source/rllib/package_ref/learner.rst
@@ -72,7 +72,8 @@ Performing Updates
:nosignatures:
:toctree: doc/

Learner.update
Learner.update_from_batch
Learner.update_from_episodes
Learner._update
Learner.additional_update
Learner.additional_update_for_module
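For orientation, a minimal sketch contrasting the two new Learner entry points listed above, written in the style of the docs' testcode snippets. `DUMMY_BATCH` (a MultiAgentBatch) and `EPISODES` (a list of episode objects) are placeholders assumed to be defined elsewhere, and the `episodes=` keyword name is an assumption mirroring the batch variant.

    # Assumes a built `learner` plus placeholder data `DUMMY_BATCH` / `EPISODES`.
    # Gradient update from a batch (replaces the old `Learner.update()`).
    batch_results = learner.update_from_batch(batch=DUMMY_BATCH)
    # Gradient update directly from episodes (new in this PR); the exact
    # keyword name is an assumption based on the batch variant.
    episode_results = learner.update_from_episodes(episodes=EPISODES)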
22 changes: 13 additions & 9 deletions doc/source/rllib/rllib-learner.rst
@@ -229,19 +229,23 @@ Updates

.. testcode::

# This is a blocking update
results = learner_group.update(DUMMY_BATCH)
# This is a blocking update.
results = learner_group.update_from_batch(batch=DUMMY_BATCH)

# This is a non-blocking update. The results are returned in a future
# call to `async_update`
_ = learner_group.async_update(DUMMY_BATCH)
# call to `update_from_batch(..., async_update=True)`
_ = learner_group.update_from_batch(batch=DUMMY_BATCH, async_update=True)

# Artificially wait for async request to be done to get the results
# in the next call to `LearnerGroup.async_update()`.
# in the next call to
# `LearnerGroup.update_from_batch(..., async_update=True)`.
time.sleep(5)
results = learner_group.async_update(DUMMY_BATCH)
results = learner_group.update_from_batch(
batch=DUMMY_BATCH, async_update=True
)
# `results` is a list of results dict. The items in the list represent the different
# remote results from the different calls to `async_update()`.
# remote results from the different calls to
# `update_from_batch(..., async_update=True)`.
assert len(results) > 0
# Each item is a results dict, already reduced over the n Learner workers.
assert isinstance(results[0], dict), results[0]
@@ -256,8 +260,8 @@ Updates

.. testcode::

# This is a blocking update.
result = learner.update(DUMMY_BATCH)
# This is a blocking update (given a training batch).
result = learner.update_from_batch(batch=DUMMY_BATCH)

# This is an additional non-gradient based update.
learner_group.additional_update(**ADDITIONAL_UPDATE_KWARGS)
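The doc snippet above only covers the batch-based path. The PR's headline feature is the episode-based counterpart on the LearnerGroup; a hypothetical sketch, assuming the keyword names mirror `update_from_batch()`:

    # Assumes `learner_group` and a list `EPISODES` of episode objects exist;
    # keyword names are assumptions mirroring the batch variant.
    results = learner_group.update_from_episodes(episodes=EPISODES)
    # Non-blocking variant, analogous to the async batch update above.
    _ = learner_group.update_from_episodes(episodes=EPISODES, async_update=True)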
4 changes: 2 additions & 2 deletions rllib/algorithms/algorithm.py
@@ -5,12 +5,12 @@
import functools
import gymnasium as gym
import importlib
import importlib.metadata
import json
import logging
import numpy as np
import os
from packaging import version
import importlib.metadata
import re
import tempfile
import time
@@ -1607,7 +1607,7 @@ def training_step(self) -> ResultDict:
# TODO: (sven) rename MultiGPUOptimizer into something more
# meaningful.
if self.config._enable_new_api_stack:
train_results = self.learner_group.update(train_batch)
train_results = self.learner_group.update_from_batch(batch=train_batch)
elif self.config.get("simple_optimizer") is True:
train_results = train_one_step(self, train_batch)
else:
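For context, a sketch of how a minimal new-API-stack `training_step()` override would use the renamed call and then push weights back to the sampling workers. The sampling helper and the `sync_weights()` keyword shown here are assumptions for illustration, not part of this PR.

    from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
    from ray.rllib.utils.typing import ResultDict

    def training_step(self) -> ResultDict:
        # Sample from the rollout workers (helper assumed available as in OSS RLlib).
        train_batch = synchronous_parallel_sample(worker_set=self.workers)
        # New-stack update path introduced by this PR.
        train_results = self.learner_group.update_from_batch(
            batch=train_batch.as_multi_agent()
        )
        # Push the updated module weights back to the sampling workers
        # (keyword name is an assumption).
        self.workers.sync_weights(from_worker_or_learner_group=self.learner_group)
        return train_results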
2 changes: 1 addition & 1 deletion rllib/algorithms/appo/tests/test_appo_learner.py
@@ -97,7 +97,7 @@ def test_appo_loss(self):
env=algo.workers.local_worker().env
)
learner_group.set_weights(algo.get_weights())
learner_group.update(train_batch.as_multi_agent())
learner_group.update_from_batch(batch=train_batch.as_multi_agent())

algo.stop()

2 changes: 1 addition & 1 deletion rllib/algorithms/bc/bc.py
@@ -171,7 +171,7 @@ def training_step(self) -> ResultDict:
self._counters[NUM_ENV_STEPS_SAMPLED] += train_batch.env_steps()

# Updating the policy.
train_results = self.learner_group.update(train_batch)
train_results = self.learner_group.update_from_batch(batch=train_batch)

# Synchronize weights.
# As the results contain for each policy the loss and in addition the
4 changes: 2 additions & 2 deletions rllib/algorithms/dreamerv3/dreamerv3.py
@@ -606,8 +606,8 @@ def training_step(self) -> ResultDict:
)

# Perform the actual update via our learner group.
train_results = self.learner_group.update(
SampleBatch(sample).as_multi_agent(),
train_results = self.learner_group.update_from_batch(
batch=SampleBatch(sample).as_multi_agent(),
reduce_fn=self._reduce_results,
)
self._counters[NUM_AGENT_STEPS_TRAINED] += replayed_steps
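The DreamerV3 hunk forwards a `reduce_fn` to the new call. A rough sketch of the assumed contract (a callable that collapses the list of per-Learner result dicts into one dict), mirroring the mean-reduction already used in the IMPALA code further down:

    import numpy as np
    import tree  # dm_tree

    def mean_reduce(results):
        # Assumed contract: `results` is a list of per-Learner result dicts;
        # average every leaf across Learners into a single dict.
        return tree.map_structure(lambda *leaves: np.mean(leaves), *results)

    # Usage (sketch): learner_group.update_from_batch(batch=..., reduce_fn=mean_reduce)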
3 changes: 2 additions & 1 deletion rllib/algorithms/dreamerv3/utils/summaries.py
@@ -133,7 +133,8 @@ def report_predicted_vs_sampled_obs(
Continues: Compute MSE (sampled vs predicted).

Args:
results: The results dict that was returned by `LearnerGroup.update()`.
results: The results dict that was returned by
`LearnerGroup.update_from_batch()`.
sample: The sampled data (dict) from the replay buffer. Already tf-tensor
converted.
batch_size_B: The batch size (B). This is the number of trajectories sampled
30 changes: 12 additions & 18 deletions rllib/algorithms/impala/impala.py
@@ -946,24 +946,18 @@ def learn_on_processed_samples(self) -> ResultDict:
self.batches_to_place_on_learner.clear()
# If there are no learner workers and learning is directly on the driver
# Then we can't do async updates, so we need to block.
blocking = self.config.num_learner_workers == 0
async_update = self.config.num_learner_workers > 0
results = []
for batch in batches:
if blocking:
result = self.learner_group.update(
batch,
reduce_fn=_reduce_impala_results,
num_iters=self.config.num_sgd_iter,
minibatch_size=self.config.minibatch_size,
)
result = self.learner_group.update_from_batch(
batch=batch,
async_update=async_update,
reduce_fn=_reduce_impala_results,
num_iters=self.config.num_sgd_iter,
minibatch_size=self.config.minibatch_size,
)
if not async_update:
results = [result]
else:
results = self.learner_group.async_update(
batch,
reduce_fn=_reduce_impala_results,
num_iters=self.config.num_sgd_iter,
minibatch_size=self.config.minibatch_size,
)

for r in results:
self._counters[NUM_ENV_STEPS_TRAINED] += r[ALL_MODULES].pop(
@@ -973,14 +967,14 @@
NUM_AGENT_STEPS_TRAINED
)

self._counters.update(self.learner_group.get_in_queue_stats())
self._counters.update(self.learner_group.get_stats())
# If there are results, reduce-mean over each individual value and return.
if results:
return tree.map_structure(lambda *x: np.mean(x), *results)

# Nothing on the queue -> Don't send requests to learner group
# or no results ready (from previous `self.learner_group.update()` calls) for
# reducing.
# or no results ready (from previous `self.learner_group.update_from_batch()`
# calls) for reducing.
return {}

def place_processed_samples_on_learner_thread_queue(self) -> None:
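The net effect of the IMPALA hunk is a single call site whose mode is picked by `async_update`. A condensed paraphrase of the pattern (names taken from the hunk above; `reduce_fn` omitted for brevity):

    # Learn on the driver (blocking) when there are no remote Learner workers,
    # otherwise submit a non-blocking request.
    async_update = config.num_learner_workers > 0
    result = learner_group.update_from_batch(
        batch=batch,
        async_update=async_update,
        num_iters=config.num_sgd_iter,
        minibatch_size=config.minibatch_size,
    )
    # Blocking mode returns a single reduced results dict; async mode returns a
    # list of results from previously completed requests.
    results = result if async_update else [result]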
2 changes: 1 addition & 1 deletion rllib/algorithms/impala/tests/test_impala_learner.py
@@ -94,7 +94,7 @@ def test_impala_loss(self):
env=algo.workers.local_worker().env
)
learner_group.set_weights(algo.get_weights())
learner_group.update(train_batch.as_multi_agent())
learner_group.update_from_batch(batch=train_batch.as_multi_agent())

algo.stop()

4 changes: 2 additions & 2 deletions rllib/algorithms/ppo/ppo.py
@@ -424,8 +424,8 @@ def training_step(self) -> ResultDict:
if self.config._enable_new_api_stack:
# TODO (Kourosh) Clearly define what train_batch_size
# vs. sgd_minibatch_size and num_sgd_iter is in the config.
train_results = self.learner_group.update(
train_batch,
train_results = self.learner_group.update_from_batch(
batch=train_batch,
minibatch_size=self.config.sgd_minibatch_size,
num_iters=self.config.num_sgd_iter,
)
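In the PPO hunk, the legacy `sgd_minibatch_size` / `num_sgd_iter` settings map onto the generic `minibatch_size` / `num_iters` arguments of the new call, so one call covers the whole SGD phase of a training iteration. A sketch with purely illustrative numbers:

    # Sketch: run roughly `num_iters` passes over `train_batch`, chopping it
    # into minibatches of `minibatch_size` for each SGD step.
    train_results = learner_group.update_from_batch(
        batch=train_batch,
        minibatch_size=128,
        num_iters=30,
    )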
2 changes: 1 addition & 1 deletion rllib/algorithms/ppo/tests/test_ppo_learner.py
@@ -101,7 +101,7 @@ def test_loss(self):

# Load the algo weights onto the learner_group.
learner_group.set_weights(algo.get_weights())
learner_group.update(train_batch.as_multi_agent())
learner_group.update_from_batch(batch=train_batch.as_multi_agent())

algo.stop()

2 changes: 1 addition & 1 deletion rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py
@@ -225,7 +225,7 @@ def get_value():
assert init_std == 0.0, init_std
batch = compute_gae_for_sample_batch(policy, PENDULUM_FAKE_BATCH.copy())
batch = policy._lazy_tensor_dict(batch)
algo.learner_group.update(batch.as_multi_agent())
algo.learner_group.update_from_batch(batch=batch.as_multi_agent())

# Check the variable is updated.
post_std = get_value()