[RLlib] Remove 2nd Learner ConnectorV2 pass from PPO (add new GAE Connector piece). Fix: "State-connector" would use `seq_len=20`. (#47401)
sven1977 authored Aug 30, 2024
1 parent 3c950a1 commit d1f21a5
Showing 35 changed files with 578 additions and 549 deletions.
6 changes: 1 addition & 5 deletions rllib/algorithms/algorithm_config.py
@@ -1082,11 +1082,7 @@ def build_learner_connector(
# Append all other columns handling.
pipeline.append(AddColumnsFromEpisodesToTrainBatch())
# Append STATE_IN/STATE_OUT (and time-rank) handler.
pipeline.append(
AddStatesFromEpisodesToBatch(
as_learner_connector=True, max_seq_len=self.model.get("max_seq_len")
)
)
pipeline.append(AddStatesFromEpisodesToBatch(as_learner_connector=True))
# If multi-agent -> Map from AgentID-based data to ModuleID based data.
if self.is_multi_agent():
pipeline.append(
5 changes: 4 additions & 1 deletion rllib/algorithms/impala/impala_learner.py
@@ -63,7 +63,10 @@ def build(self) -> None:
# Extend all episodes by one artificial timestep to allow the value function net
# to compute the bootstrap values (and add a mask to the batch to know which
# slots to mask out).
if self.config.add_default_connectors_to_learner_pipeline:
if (
self._learner_connector is not None
and self.config.add_default_connectors_to_learner_pipeline
):
self._learner_connector.prepend(AddOneTsToEpisodesAndTruncate())

# Create and start the GPU-loader thread. It picks up train-ready batches from
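To illustrate the comment above (independent of RLlib internals): appending one artificial timestep per episode lets a single value-net forward pass over the elongated episode produce both the per-step value predictions and the bootstrap value, while a loss mask keeps the artificial step out of the loss. A toy NumPy sketch; all names here are illustrative, not RLlib APIs:

import numpy as np

T = 5                                        # real episode length
obs = np.random.randn(T + 1, 4)              # last row: repeated final observation (artificial ts)
vf_preds = obs.sum(axis=1)                   # stand-in for one value-net forward pass over T+1 steps

values = vf_preds[:-1]                       # V(s_0) ... V(s_{T-1}), used in the loss
bootstrap_value = vf_preds[-1]               # V(s_T), only used to bootstrap the returns
loss_mask = np.array([True] * T + [False])   # the artificial step is masked out of the loss

assert values.shape == (T,) and loss_mask.sum() == T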
1 change: 0 additions & 1 deletion rllib/algorithms/impala/impala_torch_policy.py
@@ -172,7 +172,6 @@ class VTraceOptimizer:
def __init__(self):
pass

@override(TorchPolicyV2)
def optimizer(
self,
) -> Union[List["torch.optim.Optimizer"], "torch.optim.Optimizer"]:
10 changes: 9 additions & 1 deletion rllib/algorithms/ppo/ppo.py
@@ -240,7 +240,15 @@ def training(
baseline; required for using GAE).
use_gae: If true, use the Generalized Advantage Estimator (GAE)
with a value function, see https://arxiv.org/pdf/1506.02438.pdf.
lambda_: The GAE (lambda) parameter.
lambda_: The lambda parameter for Generalized Advantage Estimation (GAE).
Defines the exponential weight used between actually measured rewards
vs value function estimates over multiple time steps. Specifically,
`lambda_` balances short-term, low-variance estimates with longer-term,
high-variance returns. A `lambda_` of 0.0 makes the GAE rely only on
immediate rewards (and vf predictions from there on, reducing variance,
but increasing bias), while a `lambda_` of 1.0 only incorporates vf
predictions at the truncation points of the given episodes or episode
chunks (reducing bias but increasing variance).
use_kl_loss: Whether to use the KL-term in the loss function.
kl_coeff: Initial coefficient for KL divergence.
kl_target: Target value for KL divergence.
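For reference, the estimator the new `lambda_` docstring describes is standard GAE (Schulman et al., 2015), written here as a quick reminder:

\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)

\hat{A}_t^{GAE(\gamma,\lambda)} = \sum_{l=0}^{T-t-1} (\gamma \lambda)^l \, \delta_{t+l}

With `lambda_=0.0` this collapses to \hat{A}_t = \delta_t (bootstrap from the value function after a single step: low variance, higher bias); with `lambda_=1.0` it becomes the full discounted return minus V(s_t), bootstrapping only at the truncation point of the episode (chunk): low bias, higher variance.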
213 changes: 26 additions & 187 deletions rllib/algorithms/ppo/ppo_learner.py
@@ -1,33 +1,26 @@
import abc
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict

from ray.rllib.algorithms.ppo.ppo import (
LEARNER_RESULTS_CURR_ENTROPY_COEFF_KEY,
PPOConfig,
)
from ray.rllib.core.columns import Columns
from ray.rllib.connectors.learner import (
AddOneTsToEpisodesAndTruncate,
GeneralAdvantageEstimation,
)
from ray.rllib.core.learner.learner import Learner
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.utils.annotations import (
override,
OverrideToImplementCustomLogic,
OverrideToImplementCustomLogic_CallToSuperRecommended,
)
from ray.rllib.utils.lambda_defaultdict import LambdaDefaultDict
from ray.rllib.utils.metrics import (
NUM_ENV_STEPS_SAMPLED_LIFETIME,
NUM_MODULE_STEPS_TRAINED,
)
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.postprocessing.value_predictions import compute_value_targets
from ray.rllib.utils.postprocessing.episodes import (
add_one_ts_to_episodes_and_truncate,
remove_last_ts_from_data,
remove_last_ts_from_episodes_and_restore_truncateds,
)
from ray.rllib.utils.postprocessing.zero_padding import unpad_data_if_necessary
from ray.rllib.utils.schedules.scheduler import Scheduler
from ray.rllib.utils.typing import EpisodeType, ModuleID, TensorType
from ray.rllib.utils.typing import ModuleID, TensorType


class PPOLearner(Learner):
@@ -57,156 +50,27 @@ def build(self) -> None:
)
)

@override(Learner)
def _update_from_batch_or_episodes(
self,
*,
batch=None,
episodes=None,
**kwargs,
):
# First perform GAE computation on the entirety of the given train data (all
# episodes).
if self.config.enable_env_runner_and_connector_v2:
batch, episodes = self._compute_gae_from_episodes(episodes=episodes)

# Now that GAE (advantages and value targets) have been added to the train
# batch, we can proceed normally (calling super method) with the update step.
return super()._update_from_batch_or_episodes(
batch=batch,
episodes=episodes,
**kwargs,
)

def _compute_gae_from_episodes(
self,
*,
episodes: Optional[List[EpisodeType]] = None,
) -> Tuple[Optional[Dict[str, Any]], Optional[List[EpisodeType]]]:
"""Computes GAE advantages (and value targets) given a list of episodes.
Note that the episodes may be SingleAgent- or MultiAgentEpisodes and may be
episode chunks (not starting from reset or ending prematurely).
The GAE computation here is performed in a very efficient way via elongating
all given episodes by 1 artificial timestep (last obs, actions, states, etc..
repeated, last reward=0.0, etc..), then creating a forward batch from this data
using the connector, pushing the resulting batch through the value function,
thereby extracting the bootstrap values (at the artificially added time steps)
and all other value predictions (all other timesteps) and then reducing the
batch and episode lengths again accordingly.
"""
if not episodes:
raise ValueError(
"`PPOLearner._compute_gae_from_episodes()` must have the `episodes` "
"arg provided! Otherwise, GAE/advantage computation can't be performed."
)

batch = {}

sa_episodes_list = list(
self._learner_connector.single_agent_episode_iterator(
episodes, agents_that_stepped_only=False
)
)
# Make all episodes one ts longer in order to just have a single batch
# (and distributed forward pass) for both vf predictions AND the bootstrap
# vf computations.
orig_truncateds_of_sa_episodes = add_one_ts_to_episodes_and_truncate(
sa_episodes_list
)

# Call the learner connector (on the artificially elongated episodes)
# in order to get the batch to pass through the module for vf (and
# bootstrapped vf) computations.
batch_for_vf = self._learner_connector(
rl_module=self.module,
batch={},
episodes=episodes,
shared_data={},
)
# Perform the value model's forward pass.
vf_preds = convert_to_numpy(self._compute_values(batch_for_vf))

for module_id, module_vf_preds in vf_preds.items():
# Collect new (single-agent) episode lengths.
episode_lens_plus_1 = [
len(e)
for e in sa_episodes_list
if e.module_id is None or e.module_id == module_id
]

# Remove all zero-padding again, if applicable, for the upcoming
# GAE computations.
module_vf_preds = unpad_data_if_necessary(
episode_lens_plus_1, module_vf_preds
)
# Compute value targets.
module_value_targets = compute_value_targets(
values=module_vf_preds,
rewards=unpad_data_if_necessary(
episode_lens_plus_1,
convert_to_numpy(batch_for_vf[module_id][Columns.REWARDS]),
),
terminateds=unpad_data_if_necessary(
episode_lens_plus_1,
convert_to_numpy(batch_for_vf[module_id][Columns.TERMINATEDS]),
),
truncateds=unpad_data_if_necessary(
episode_lens_plus_1,
convert_to_numpy(batch_for_vf[module_id][Columns.TRUNCATEDS]),
),
gamma=self.config.gamma,
lambda_=self.config.lambda_,
)

# Remove the extra timesteps again from vf_preds and value targets. Now that
# the GAE computation is done, we don't need this last timestep anymore in
# any of our data.
module_vf_preds, module_value_targets = remove_last_ts_from_data(
episode_lens_plus_1,
module_vf_preds,
module_value_targets,
)
module_advantages = module_value_targets - module_vf_preds
# Drop vf-preds, not needed in loss. Note that in the PPORLModule, vf-preds
# are recomputed with each `forward_train` call anyway.
# Standardize advantages (used for more stable and better weighted
# policy gradient computations).
module_advantages = (module_advantages - module_advantages.mean()) / max(
1e-4, module_advantages.std()
)

# Restructure ADVANTAGES and VALUE_TARGETS in a way that the Learner
# connector can properly re-batch these new fields.
batch_pos = 0
for eps in sa_episodes_list:
if eps.module_id is not None and eps.module_id != module_id:
continue
len_ = len(eps) - 1
self._learner_connector.add_n_batch_items(
batch=batch,
column=Postprocessing.ADVANTAGES,
items_to_add=module_advantages[batch_pos : batch_pos + len_],
num_items=len_,
single_agent_episode=eps,
)
self._learner_connector.add_n_batch_items(
batch=batch,
column=Postprocessing.VALUE_TARGETS,
items_to_add=module_value_targets[batch_pos : batch_pos + len_],
num_items=len_,
single_agent_episode=eps,
)
batch_pos += len_

# Remove the extra (artificial) timesteps again at the end of all episodes.
remove_last_ts_from_episodes_and_restore_truncateds(
sa_episodes_list,
orig_truncateds_of_sa_episodes,
)

return batch, episodes

# Extend all episodes by one artificial timestep to allow the value function net
# to compute the bootstrap values (and add a mask to the batch to know which
# slots to mask out).
if (
self._learner_connector is not None
and self.config.add_default_connectors_to_learner_pipeline
):
# Before anything, add one ts to each episode (and record this in the loss
# mask, so that the computations at this extra ts are not used to compute
# the loss).
self._learner_connector.prepend(AddOneTsToEpisodesAndTruncate())
# At the end of the pipeline (when the batch is already completed), add the
# GAE connector, which performs a vf forward pass, then the GAE
# computations, and puts the results (advantages, value targets)
# directly back in the batch. This is then the batch used for
# `forward_train` and `compute_losses`.
self._learner_connector.append(
GeneralAdvantageEstimation(
gamma=self.config.gamma, lambda_=self.config.lambda_
)
)

@override(Learner)
def remove_module(self, module_id: ModuleID, **kwargs):
@@ -245,31 +109,6 @@ def after_gradient_based_update(
):
self._update_module_kl_coeff(module_id=module_id, config=config)

@OverrideToImplementCustomLogic
def _compute_values(
self,
batch_for_vf: Dict[str, Any],
) -> Union[TensorType, Dict[str, Any]]:
"""Computes the value function predictions for the module being optimized.
This method must be overridden by multiagent-specific algorithm learners to
specify the specific value computation logic. If the algorithm is single agent
(or independent multi-agent), there should be no need to override this method.
Args:
batch_for_vf: The multi-agent batch (mapping ModuleIDs to module data) to
be used for value function predictions.
Returns:
A dictionary mapping module IDs to individual value function prediction
tensors.
"""
return {
module_id: self.module[module_id].unwrapped().compute_values(module_batch)
for module_id, module_batch in batch_for_vf.items()
if self.should_module_be_updated(module_id, batch_for_vf)
}

@abc.abstractmethod
def _update_module_kl_coeff(
self,
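As a compact reference for the PPOLearner changes above, the following NumPy sketch shows the per-episode computation that the appended `GeneralAdvantageEstimation` connector is responsible for (simplified to a single non-truncated episode; this is an illustrative sketch, not the actual RLlib implementation):

import numpy as np

def gae_advantages_and_value_targets(rewards, vf_preds, bootstrap_value, gamma=0.99, lambda_=0.95):
    """rewards, vf_preds: shape (T,); bootstrap_value: V(s_T) from the artificial extra timestep."""
    T = len(rewards)
    values = np.append(vf_preds, bootstrap_value)   # length T+1
    advantages = np.zeros(T, dtype=np.float64)
    gae = 0.0
    for t in reversed(range(T)):
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        gae = delta + gamma * lambda_ * gae
        advantages[t] = gae
    value_targets = advantages + vf_preds           # regression targets for the value function loss
    return advantages, value_targets

The connector writes results like these back into the train batch (advantages and value targets as their own columns), so that `forward_train` and `compute_losses` can use them directly and the second learner-connector pass that `PPOLearner._compute_gae_from_episodes()` previously performed is no longer needed.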
3 changes: 2 additions & 1 deletion rllib/algorithms/ppo/ppo_rl_module.py
@@ -8,6 +8,7 @@
from ray.rllib.core.columns import Columns
from ray.rllib.core.models.configs import RecurrentEncoderConfig
from ray.rllib.core.models.specs.specs_dict import SpecDict
from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.models.distributions import Distribution
from ray.rllib.utils.annotations import ExperimentalAPI, override
@@ -16,7 +17,7 @@


@ExperimentalAPI
class PPORLModule(RLModule, abc.ABC):
class PPORLModule(RLModule, ValueFunctionAPI, abc.ABC):
def setup(self):
# __sphinx_doc_begin__
catalog = self.config.get_catalog()
2 changes: 1 addition & 1 deletion rllib/algorithms/ppo/tests/test_ppo_with_env_runner.py
@@ -93,7 +93,7 @@ def test_ppo_compilation_and_schedule_mixins(self):

num_iterations = 2

# TODO (Kourosh) Bring back "FrozenLake-v1"
# TODO (sven) Bring back "FrozenLake-v1"
for env in [
# "CliffWalking-v0",
"CartPole-v1",
9 changes: 4 additions & 5 deletions rllib/algorithms/ppo/torch/ppo_torch_rl_module.py
@@ -121,18 +121,17 @@ def _forward_train(self, batch: Dict[str, Any]) -> Dict[str, Any]:

return output

# TODO (sven): Try to move entire GAE computation into PPO's loss function (similar
# to IMPALA's v-trace architecture). This would also get rid of the second
# Connector pass currently necessary.
@override(ValueFunctionAPI)
def compute_values(self, batch: Dict[str, Any]) -> TensorType:
# Separate vf-encoder.
if hasattr(self.encoder, "critic_encoder"):
batch_ = batch
if self.is_stateful():
# The recurrent encoders expect a `(state_in, h)` key in the
# input dict while the key returned is `(state_in, critic, h)`.
batch[Columns.STATE_IN] = batch[Columns.STATE_IN][CRITIC]
encoder_outs = self.encoder.critic_encoder(batch)[ENCODER_OUT]
batch_ = batch.copy()
batch_[Columns.STATE_IN] = batch[Columns.STATE_IN][CRITIC]
encoder_outs = self.encoder.critic_encoder(batch_)[ENCODER_OUT]
# Shared encoder.
else:
encoder_outs = self.encoder(batch)[ENCODER_OUT][CRITIC]
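The `compute_values` change above swaps an in-place mutation of the caller's `batch` for a shallow copy before the critic's state is substituted in. A minimal, RLlib-free sketch of the difference (dict keys here are illustrative):

def buggy(batch):
    # Overwrites the caller's dict entry: later users of `batch` see only the critic state.
    batch["state_in"] = batch["state_in"]["critic"]
    return batch

def fixed(batch):
    # Shallow-copy first, then overwrite the key only in the copy.
    batch_ = batch.copy()
    batch_["state_in"] = batch["state_in"]["critic"]
    return batch_

b = {"state_in": {"actor": "h_actor", "critic": "h_critic"}}
buggy(b)
assert b["state_in"] == "h_critic"            # caller's batch was silently rewritten

f = {"state_in": {"actor": "h_actor", "critic": "h_critic"}}
fixed(f)
assert f["state_in"] == {"actor": "h_actor", "critic": "h_critic"}  # caller's batch untouched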
@@ -41,7 +41,7 @@ class AddObservationsFromEpisodesToBatch(ConnectorV2):
]
This ConnectorV2:
- Operates on a list of Episode objects.
- Operates on a list of Episode objects (single- or multi-agent).
- Gets the most recent observation(s) from all the given episodes and adds them
to the batch under construction (as a list of individual observations).
- Does NOT alter any observations (or other data) in the given episodes.