rlax: Upstream Muesli utilities to rlax.
We now provide methods for constructing the clipped MPO (CMPO) policy targets used as part of the Muesli agent loss. These CMPO targets are in expectation proportional to: `prior(a|s) * exp(clip(norm(Q(s, a))))` where the prior is computed by the actor policy head, and the Q values are computed using the learned model's reward and value heads.

See "Muesli: Combining Improvements in Policy Optimization" by Hessel et al. (https://arxiv.org/pdf/2104.06159.pdf) for more details.

PiperOrigin-RevId: 493987878
suryabhupa authored and RLaxDev committed Dec 13, 2022
1 parent 44ef3f0 commit fe8b3f5
Showing 3 changed files with 207 additions and 1 deletion.
14 changes: 14 additions & 0 deletions docs/api.rst
@@ -229,6 +229,7 @@ Policy Optimization
.. autosummary::

clipped_surrogate_pg_loss
cmpo_policy_targets
constant_policy_targets
dpg_loss
entropy_loss
@@ -238,6 +239,7 @@
qpg_loss
rm_loss
rpg_loss
sampled_cmpo_policy_targets
sampled_policy_distillation_loss
zero_policy_targets

@@ -247,6 +249,18 @@ Clipped Surrogate PG Loss
.. autofunction:: clipped_surrogate_pg_loss


CMPO Policy Targets
~~~~~~~~~~~~~~~~~~~

.. autofunction:: cmpo_policy_targets


Sampled CMPO Policy Targets
~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: sampled_cmpo_policy_targets


Compute Parametric KL Penalty and Dual Loss
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

4 changes: 4 additions & 0 deletions rlax/__init__.py
@@ -88,8 +88,10 @@
from rlax._src.policy_gradients import qpg_loss
from rlax._src.policy_gradients import rm_loss
from rlax._src.policy_gradients import rpg_loss
from rlax._src.policy_targets import cmpo_policy_targets
from rlax._src.policy_targets import constant_policy_targets
from rlax._src.policy_targets import PolicyTarget
from rlax._src.policy_targets import sampled_cmpo_policy_targets
from rlax._src.policy_targets import sampled_policy_distillation_loss
from rlax._src.policy_targets import zero_policy_targets
from rlax._src.pop_art import art
@@ -159,6 +161,7 @@
"categorical_td_learning",
"clip_gradient",
"clipped_surrogate_pg_loss",
"cmpo_policy_targets",
"compose_tx",
"conditional_update",
"constant_policy_targets",
@@ -230,6 +233,7 @@
"rpg_loss",
"sample_start_indices",
"sampled_policy_distillation_loss",
"sampled_cmpo_policy_targets",
"sarsa",
"sarsa_lambda",
"sigmoid",
190 changes: 189 additions & 1 deletion rlax/_src/policy_targets.py
@@ -12,14 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to construct and learn from policy targets."""
"""Construct and learn from policy targets. Used by Muesli-based agents."""

import functools

import chex
import distrax
import jax
import jax.numpy as jnp
from rlax._src import base


@chex.dataclass(frozen=True)
@@ -106,3 +107,190 @@ def sampled_policy_distillation_loss(
  # We average over the samples, over time and batch, and, if the actions are
  # a continuous vector, also over the action dimensions.
  return -jnp.mean(weights * jnp.maximum(log_probs, min_logp))


def cmpo_policy_targets(
prior_distribution,
embeddings,
rng_key,
baseline_value,
q_provider,
advantage_normalizer,
*,
num_actions,
min_target_advantage=-jnp.inf,
max_target_advantage=1.0,
kl_weight=1.0,
) -> PolicyTarget:
"""Policy targets for Clipped MPO.
The policy targets are in-expectation proportional to:
`prior(a|s) * exp(clip(norm(Q(s, a))))`
See "Muesli: Combining Improvements in Policy Optimization" by Hessel et al.
(https://arxiv.org/pdf/2104.06159.pdf).
Args:
prior_distribution: the prior policy distribution.
embeddings: embeddings for the `q_provider`.
rng_key: a JAX pseudo random number generator key.
baseline_value: the baseline for `advantage_normalizer`.
q_provider: a fn to compute q values.
advantage_normalizer: a fn to normalise advantages.
*,
num_actions: The total number of discrete actions.
min_target_advantage: The minimum advantage of a policy target.
max_target_advantage: The max advantage of a policy target.
kl_weight: The coefficient for the KL regularizer.
Returns:
the clipped MPO policy targets.
"""
# Expecting shape [B].
chex.assert_rank(baseline_value, 1)
rng_key, query_rng_key = jax.random.split(rng_key)
del rng_key

# Producing all actions with shape [num_actions, B].
batch_size, = baseline_value.shape
actions = jnp.broadcast_to(
jnp.expand_dims(jnp.arange(num_actions, dtype=jnp.int32), axis=-1),
[num_actions, batch_size])

# Using vmap over the num_actions in axis=0.
def _query_q(actions):
return q_provider(
        # Using the same rng_key for all of the action samples.
rng_key=query_rng_key,
action=actions,
embeddings=embeddings)
qvalues = jax.vmap(_query_q)(actions)

# Using the same advantage normalization as for policy gradients.
raw_advantage = advantage_normalizer(
returns=qvalues, baseline_value=baseline_value)
clipped_advantage = jnp.clip(
raw_advantage, min_target_advantage,
max_target_advantage)

# Construct and normalise the weights.
log_prior = prior_distribution.log_prob(actions)
weights = softmax_policy_target_normalizer(
log_prior + clipped_advantage / kl_weight)
policy_targets = PolicyTarget(actions=actions, weights=weights)
return policy_targets


def sampled_cmpo_policy_targets(
prior_distribution,
embeddings,
rng_key,
baseline_value,
q_provider,
advantage_normalizer,
*,
num_actions=2,
min_target_advantage=-jnp.inf,
max_target_advantage=1.0,
kl_weight=1.0,
) -> PolicyTarget:
"""Policy targets for sampled CMPO.
As in CMPO the policy targets are in-expectation proportional to:
`prior(a|s) * exp(clip(norm(Q(s, a))))`
However we only sample a subset of the actions, this allows to scale to
large discrete action spaces and to continuous actions.
See "Muesli: Combining Improvements in Policy Optimization" by Hessel et al.
(https://arxiv.org/pdf/2104.06159.pdf).
Args:
prior_distribution: the prior policy distribution.
embeddings: embeddings for the `q_provider`.
rng_key: a JAX pseudo random number generator key.
baseline_value: the baseline for `advantage_normalizer`.
q_provider: a fn to compute q values.
advantage_normalizer: a fn to normalise advantages.
*,
num_actions: The number of actions to expand on each step.
min_target_advantage: The minimum advantage of a policy target.
max_target_advantage: The max advantage of a policy target.
kl_weight: The coefficient for the KL regularizer.
Returns:
the sampled clipped MPO policy targets.
"""
# Expecting shape [B].
chex.assert_rank(baseline_value, 1)
query_rng_key, action_key = jax.random.split(rng_key)
del rng_key

# Sampling the actions from the prior.
actions = prior_distribution.sample(
seed=action_key, sample_shape=[num_actions])

  # Using vmap over the sampled actions in axis=0.
  def _query_q(actions):
    return q_provider(
        # Using the same rng_key for all of the action samples.
        rng_key=query_rng_key,
action=actions,
embeddings=embeddings)
qvalues = jax.vmap(_query_q)(actions)

# Using the same advantage normalization as for policy gradients.
raw_advantage = advantage_normalizer(
returns=qvalues, baseline_value=baseline_value)
clipped_advantage = jnp.clip(
raw_advantage, min_target_advantage, max_target_advantage)

  # The expected normalized weight is 1.0. The weights would be exactly
  # normalized if the baseline_value were the log of the expected weight, i.e.
  # if the baseline_value were log(sum_a(prior(a|s) * exp(Q(s, a)/c))).
weights = jnp.exp(clipped_advantage / kl_weight)

  # The weights are tiled if using multi-dimensional continuous actions.
  # It is OK to use multi-dimensional continuous actions inside Q(s, a),
  # because the action is sampled from the joint distribution and the
  # weight is not based on the per-dimension probabilities.
log_prob = prior_distribution.log_prob(actions)
weights = jnp.broadcast_to(
base.lhs_broadcast(weights, log_prob), log_prob.shape)
return PolicyTarget(actions=actions, weights=weights)


def softmax_policy_target_normalizer(log_weights):
"""Returns self-normalized weights.
The self-normalizing weights introduce a significant bias,
if computing the average weight from a small number of samples.
Args:
log_weights: log unnormalized weights, shape `[num_targets, ...]`.
Returns:
Weights divided by average weight from sample. Weights sum to `num_targets`.
"""
num_targets = log_weights.shape[0]
return num_targets * jax.nn.softmax(log_weights, axis=0)


def loo_policy_target_normalizer(log_weights):
"""A leave-one-out normalizer.
Args:
log_weights: log unnormalized weights, shape `[num_targets, ...]`.
Returns:
Weights divided by a consistent estimate of the average weight. The weights
are not guaranteed to sum to `num_targets`.
"""
num_targets = log_weights.shape[0]
weights = jnp.exp(log_weights)
# Using a safe consistent estimator of the average weight, independently of
# the numerator.
  # The unnormalized weights are already approximately normalized by the
  # baseline_value, so we use `1` as the initial estimate of the average weight.
avg_weight = (
1 + jnp.sum(weights, axis=0, keepdims=True) - weights) / num_targets
return weights / avg_weight
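
For reference, a quick numeric illustration (not part of the diff) of how the two normalizers above differ on the same unnormalized weights:

```python
import jax
import jax.numpy as jnp

log_w = jnp.log(jnp.array([[1.0], [2.0], [3.0]]))  # shape [num_targets=3, 1].

# softmax_policy_target_normalizer divides by the sample average weight, so
# the result sums to num_targets: [[0.5], [1.0], [1.5]].
softmax_w = 3 * jax.nn.softmax(log_w, axis=0)

# loo_policy_target_normalizer divides each weight by an estimate of the
# average weight that excludes that weight (with `1` as the initial
# estimate): [[0.5], [1.2], [2.25]]. The result need not sum to num_targets.
w = jnp.exp(log_w)
loo_w = w / ((1 + jnp.sum(w, axis=0, keepdims=True) - w) / 3)
```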


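A minimal usage sketch of the new `sampled_cmpo_policy_targets` entry point. The `q_provider` and `advantage_normalizer` below are hypothetical stand-ins; a Muesli agent would query its learned model's reward and value heads and use a proper advantage normalizer:

```python
import distrax
import jax
import jax.numpy as jnp
import rlax

batch_size, num_sampled, num_discrete = 4, 16, 10
rng_key = jax.random.PRNGKey(0)

# Hypothetical stand-ins for the agent's components.
prior_distribution = distrax.Categorical(
    logits=jnp.zeros((batch_size, num_discrete)))  # uniform prior policy.
embeddings = jnp.zeros((batch_size, 32))           # state embeddings, [B, D].
baseline_value = jnp.zeros((batch_size,))          # e.g. value head output, [B].

def q_provider(rng_key, action, embeddings):
  # A dummy Q(s, a); a real one would use the model's reward and value heads.
  del rng_key
  return jnp.sum(embeddings, axis=-1) + action.astype(jnp.float32)

def advantage_normalizer(returns, baseline_value):
  # An unscaled advantage; Muesli also divides by a running scale estimate.
  return returns - baseline_value

targets = rlax.sampled_cmpo_policy_targets(
    prior_distribution, embeddings, rng_key, baseline_value,
    q_provider, advantage_normalizer, num_actions=num_sampled)
# targets.actions: [num_sampled, B]; targets.weights: matching weights that
# can be consumed by e.g. `sampled_policy_distillation_loss`.
```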