[RLlib] New ConnectorV2 API #4: Changes to Learner/LearnerGroup API to allow updating from Episodes. (#41235)
sven1977 authored Jan 10, 2024
1 parent 65478d4 commit 806701e
Showing 18 changed files with 611 additions and 301 deletions.
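At a glance, this commit replaces the old `LearnerGroup.update()` / `LearnerGroup.async_update()` pair with a single `update_from_batch()` method (asynchronous behavior is now selected via an `async_update` flag) and adds `update_from_episodes()` as a second entry point for updating directly from collected episodes. A minimal before/after sketch of the surface as it appears in the hunks below; `train_batch` and `episode_list` are placeholder names, and both the existence of a `LearnerGroup.update_from_episodes()` mirror and its `episodes=` keyword are assumptions based on the commit title (only `Learner.update_from_episodes` appears in the reference diff below):

# Old API (pre-#41235):
#   results = learner_group.update(train_batch)
#   _ = learner_group.async_update(train_batch)

# New API:
results = learner_group.update_from_batch(batch=train_batch)  # blocking
_ = learner_group.update_from_batch(batch=train_batch, async_update=True)  # non-blocking
results = learner_group.update_from_episodes(episodes=episode_list)  # assumed mirror of Learner.update_from_episodes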
3 changes: 2 additions & 1 deletion doc/source/rllib/package_ref/learner.rst
@@ -72,7 +72,8 @@ Performing Updates
:nosignatures:
:toctree: doc/

- Learner.update
+ Learner.update_from_batch
+ Learner.update_from_episodes
Learner._update
Learner.additional_update
Learner.additional_update_for_module
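The API reference now documents two public update entry points on `Learner`. A usage sketch with `DUMMY_BATCH` / `DUMMY_EPISODES` as placeholder inputs; the `episodes=` keyword is an assumed mirror of `batch=` and is not shown anywhere in this diff:

# Blocking update from a (multi-agent) training batch.
result = learner.update_from_batch(batch=DUMMY_BATCH)

# Blocking update directly from a list of episodes (assumed signature).
result = learner.update_from_episodes(episodes=DUMMY_EPISODES)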
22 changes: 13 additions & 9 deletions doc/source/rllib/rllib-learner.rst
@@ -229,19 +229,23 @@ Updates

.. testcode::

- # This is a blocking update
- results = learner_group.update(DUMMY_BATCH)
+ # This is a blocking update.
+ results = learner_group.update_from_batch(batch=DUMMY_BATCH)

# This is a non-blocking update. The results are returned in a future
- # call to `async_update`
- _ = learner_group.async_update(DUMMY_BATCH)
+ # call to `update_from_batch(..., async_update=True)`
+ _ = learner_group.update_from_batch(batch=DUMMY_BATCH, async_update=True)

# Artificially wait for async request to be done to get the results
- # in the next call to `LearnerGroup.async_update()`.
+ # in the next call to
+ # `LearnerGroup.update_from_batch(..., async_update=True)`.
time.sleep(5)
- results = learner_group.async_update(DUMMY_BATCH)
+ results = learner_group.update_from_batch(
+     batch=DUMMY_BATCH, async_update=True
+ )
# `results` is a list of results dict. The items in the list represent the different
- # remote results from the different calls to `async_update()`.
+ # remote results from the different calls to
+ # `update_from_batch(..., async_update=True)`.
assert len(results) > 0
# Each item is a results dict, already reduced over the n Learner workers.
assert isinstance(results[0], dict), results[0]
@@ -256,8 +260,8 @@ Updates

.. testcode::

- # This is a blocking update.
- result = learner.update(DUMMY_BATCH)
+ # This is a blocking update (given a training batch).
+ result = learner.update_from_batch(batch=DUMMY_BATCH)

# This is an additional non-gradient based update.
learner_group.additional_update(**ADDITIONAL_UPDATE_KWARGS)
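As the updated doc snippet notes, results of an asynchronous request are only handed back by a later `update_from_batch(..., async_update=True)` call. Instead of the fixed `time.sleep(5)` above, a caller can poll until results arrive; a minimal sketch (the polling interval is illustrative, and note that, as in the doc example, each poll also enqueues `DUMMY_BATCH` again):

import time

# Kick off a non-blocking update.
_ = learner_group.update_from_batch(batch=DUMMY_BATCH, async_update=True)

# Poll: each subsequent async call returns the remote results that have
# completed since the previous call (possibly an empty list).
results = []
while not results:
    time.sleep(0.1)
    results = learner_group.update_from_batch(
        batch=DUMMY_BATCH, async_update=True
    )
# Each item is a result dict, already reduced over the n Learner workers.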
4 changes: 2 additions & 2 deletions rllib/algorithms/algorithm.py
@@ -5,12 +5,12 @@
import functools
import gymnasium as gym
import importlib
+ import importlib.metadata
import json
import logging
import numpy as np
import os
from packaging import version
- import importlib.metadata
import re
import tempfile
import time
@@ -1607,7 +1607,7 @@ def training_step(self) -> ResultDict:
# TODO: (sven) rename MultiGPUOptimizer into something more
# meaningful.
if self.config._enable_new_api_stack:
- train_results = self.learner_group.update(train_batch)
+ train_results = self.learner_group.update_from_batch(batch=train_batch)
elif self.config.get("simple_optimizer") is True:
train_results = train_one_step(self, train_batch)
else:
2 changes: 1 addition & 1 deletion rllib/algorithms/appo/tests/test_appo_learner.py
@@ -97,7 +97,7 @@ def test_appo_loss(self):
env=algo.workers.local_worker().env
)
learner_group.set_weights(algo.get_weights())
- learner_group.update(train_batch.as_multi_agent())
+ learner_group.update_from_batch(batch=train_batch.as_multi_agent())

algo.stop()

2 changes: 1 addition & 1 deletion rllib/algorithms/bc/bc.py
@@ -171,7 +171,7 @@ def training_step(self) -> ResultDict:
self._counters[NUM_ENV_STEPS_SAMPLED] += train_batch.env_steps()

# Updating the policy.
- train_results = self.learner_group.update(train_batch)
+ train_results = self.learner_group.update_from_batch(batch=train_batch)

# Synchronize weights.
# As the results contain for each policy the loss and in addition the
4 changes: 2 additions & 2 deletions rllib/algorithms/dreamerv3/dreamerv3.py
@@ -606,8 +606,8 @@ def training_step(self) -> ResultDict:
)

# Perform the actual update via our learner group.
- train_results = self.learner_group.update(
-     SampleBatch(sample).as_multi_agent(),
+ train_results = self.learner_group.update_from_batch(
+     batch=SampleBatch(sample).as_multi_agent(),
reduce_fn=self._reduce_results,
)
self._counters[NUM_AGENT_STEPS_TRAINED] += replayed_steps
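DreamerV3 forwards its own `reduce_fn` to `update_from_batch()`. A sketch of what such a reduction callback can look like, averaging every metric across the list of per-Learner result dicts; this mirrors the `tree.map_structure(... np.mean ...)` reduction used in the IMPALA code further down and is not DreamerV3's actual `_reduce_results` implementation:

import numpy as np
import tree  # dm_tree

def _reduce_results(results):
    # `results`: list of result dicts (one per Learner worker); average each leaf.
    return tree.map_structure(lambda *x: np.mean(x), *results)

train_results = learner_group.update_from_batch(
    batch=SampleBatch(sample).as_multi_agent(),
    reduce_fn=_reduce_results,
)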
3 changes: 2 additions & 1 deletion rllib/algorithms/dreamerv3/utils/summaries.py
@@ -133,7 +133,8 @@ def report_predicted_vs_sampled_obs(
Continues: Compute MSE (sampled vs predicted).
Args:
- results: The results dict that was returned by `LearnerGroup.update()`.
+ results: The results dict that was returned by
+     `LearnerGroup.update_from_batch()`.
sample: The sampled data (dict) from the replay buffer. Already tf-tensor
converted.
batch_size_B: The batch size (B). This is the number of trajectories sampled
30 changes: 12 additions & 18 deletions rllib/algorithms/impala/impala.py
@@ -946,24 +946,18 @@ def learn_on_processed_samples(self) -> ResultDict:
self.batches_to_place_on_learner.clear()
# If there are no learner workers and learning is directly on the driver
# Then we can't do async updates, so we need to block.
- blocking = self.config.num_learner_workers == 0
+ async_update = self.config.num_learner_workers > 0
results = []
for batch in batches:
- if blocking:
-     result = self.learner_group.update(
-         batch,
-         reduce_fn=_reduce_impala_results,
-         num_iters=self.config.num_sgd_iter,
-         minibatch_size=self.config.minibatch_size,
-     )
+ result = self.learner_group.update_from_batch(
+     batch=batch,
+     async_update=async_update,
+     reduce_fn=_reduce_impala_results,
+     num_iters=self.config.num_sgd_iter,
+     minibatch_size=self.config.minibatch_size,
+ )
+ if not async_update:
results = [result]
- else:
-     results = self.learner_group.async_update(
-         batch,
-         reduce_fn=_reduce_impala_results,
-         num_iters=self.config.num_sgd_iter,
-         minibatch_size=self.config.minibatch_size,
-     )

for r in results:
self._counters[NUM_ENV_STEPS_TRAINED] += r[ALL_MODULES].pop(
@@ -973,14 +967,14 @@ def learn_on_processed_samples(self) -> ResultDict:
NUM_AGENT_STEPS_TRAINED
)

- self._counters.update(self.learner_group.get_in_queue_stats())
+ self._counters.update(self.learner_group.get_stats())
# If there are results, reduce-mean over each individual value and return.
if results:
return tree.map_structure(lambda *x: np.mean(x), *results)

# Nothing on the queue -> Don't send requests to learner group
- # or no results ready (from previous `self.learner_group.update()` calls) for
- # reducing.
+ # or no results ready (from previous `self.learner_group.update_from_batch()`
+ # calls) for reducing.
return {}

def place_processed_samples_on_learner_thread_queue(self) -> None:
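The `learn_on_processed_samples()` rewrite above collapses the old `if blocking: update(...) else: async_update(...)` branching into one `update_from_batch()` call whose `async_update` flag is derived from the number of Learner workers. Restated with indentation for readability (a sketch of the new hunk, not a verbatim copy of the file):

async_update = self.config.num_learner_workers > 0
results = []
for batch in batches:
    result = self.learner_group.update_from_batch(
        batch=batch,
        async_update=async_update,
        reduce_fn=_reduce_impala_results,
        num_iters=self.config.num_sgd_iter,
        minibatch_size=self.config.minibatch_size,
    )
    # Blocking path: wrap the single, already reduced result dict; async
    # path: results of earlier requests are returned by later calls.
    if not async_update:
        results = [result]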
2 changes: 1 addition & 1 deletion rllib/algorithms/impala/tests/test_impala_learner.py
@@ -94,7 +94,7 @@ def test_impala_loss(self):
env=algo.workers.local_worker().env
)
learner_group.set_weights(algo.get_weights())
- learner_group.update(train_batch.as_multi_agent())
+ learner_group.update_from_batch(batch=train_batch.as_multi_agent())

algo.stop()

4 changes: 2 additions & 2 deletions rllib/algorithms/ppo/ppo.py
@@ -424,8 +424,8 @@ def training_step(self) -> ResultDict:
if self.config._enable_new_api_stack:
# TODO (Kourosh) Clearly define what train_batch_size
# vs. sgd_minibatch_size and num_sgd_iter is in the config.
- train_results = self.learner_group.update(
-     train_batch,
+ train_results = self.learner_group.update_from_batch(
+     batch=train_batch,
minibatch_size=self.config.sgd_minibatch_size,
num_iters=self.config.num_sgd_iter,
)
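On the new API stack, PPO now drives minibatching and epoch iteration through `update_from_batch()` itself. A minimal sketch of the call with the two config values the hunk wires through, using `learner_group` / `config` directly rather than `self.` (exact epoch/minibatch semantics are still being clarified per the TODO above):

# Iterate over `train_batch` in minibatches of `sgd_minibatch_size` for
# `num_sgd_iter` iterations (roughly, SGD epochs over the batch).
train_results = learner_group.update_from_batch(
    batch=train_batch,
    minibatch_size=config.sgd_minibatch_size,
    num_iters=config.num_sgd_iter,
)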
2 changes: 1 addition & 1 deletion rllib/algorithms/ppo/tests/test_ppo_learner.py
@@ -101,7 +101,7 @@ def test_loss(self):

# Load the algo weights onto the learner_group.
learner_group.set_weights(algo.get_weights())
- learner_group.update(train_batch.as_multi_agent())
+ learner_group.update_from_batch(batch=train_batch.as_multi_agent())

algo.stop()

2 changes: 1 addition & 1 deletion rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py
@@ -225,7 +225,7 @@ def get_value():
assert init_std == 0.0, init_std
batch = compute_gae_for_sample_batch(policy, PENDULUM_FAKE_BATCH.copy())
batch = policy._lazy_tensor_dict(batch)
- algo.learner_group.update(batch.as_multi_agent())
+ algo.learner_group.update_from_batch(batch=batch.as_multi_agent())

# Check the variable is updated.
post_std = get_value()
(Diffs for the remaining changed files are not shown here.)
