Skip to content

Commit

Permalink
Depend on the jax substrate of tensorflow_probability explicitly.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 625006545
Change-Id: Ib189d2bdd39687d8aaf6acb124ef72932613a474
  • Loading branch information
ThomasColthurst authored and Copybara-Service committed Apr 15, 2024
1 parent 440df89 commit aa42e1c
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 13 deletions.
3 changes: 1 addition & 2 deletions acme/agents/jax/ppo/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,8 @@
import jax
import jax.numpy as jnp
import numpy as np
import tensorflow_probability
import tensorflow_probability.substrates.jax as tfp

tfp = tensorflow_probability.substrates.jax
tfd = tfp.distributions

EntropyFn = Callable[
Expand Down
26 changes: 17 additions & 9 deletions acme/jax/losses/mpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,9 @@

import jax
import jax.numpy as jnp
import tensorflow_probability
import tensorflow_probability.substrates.jax as tfp

tfp = tensorflow_probability.substrates.jax
tfd = tensorflow_probability.substrates.jax.distributions
tfd = tfp.distributions

_MPO_FLOAT_EPSILON = 1e-8
_MIN_LOG_TEMPERATURE = -18.0
Expand Down Expand Up @@ -242,13 +241,19 @@ def __call__(
diff_out_of_bound = actions - jnp.clip(actions, -1.0, 1.0)
cost_out_of_bound = -jnp.linalg.norm(diff_out_of_bound, axis=-1)

penalty_normalized_weights, loss_penalty_temperature = compute_weights_and_temperature_loss(
cost_out_of_bound, self._epsilon_penalty, penalty_temperature)
penalty_normalized_weights, loss_penalty_temperature = (
compute_weights_and_temperature_loss(
cost_out_of_bound, self._epsilon_penalty, penalty_temperature
)
)

# Only needed for diagnostics: Compute estimated actualized KL between the
# non-parametric and current target policies.
penalty_kl_nonparametric = compute_nonparametric_kl_from_normalized_weights(
penalty_normalized_weights)
penalty_kl_nonparametric = (
compute_nonparametric_kl_from_normalized_weights(
penalty_normalized_weights
)
)

# Combine normalized weights.
normalized_weights += penalty_normalized_weights
Expand Down Expand Up @@ -283,8 +288,11 @@ def __call__(
# Compute the alpha-weighted KL-penalty and dual losses to adapt the alphas.
loss_kl_mean, loss_alpha_mean = compute_parametric_kl_penalty_and_dual_loss(
kl_mean, alpha_mean, self._epsilon_mean)
loss_kl_stddev, loss_alpha_stddev = compute_parametric_kl_penalty_and_dual_loss(
kl_stddev, alpha_stddev, self._epsilon_stddev)
loss_kl_stddev, loss_alpha_stddev = (
compute_parametric_kl_penalty_and_dual_loss(
kl_stddev, alpha_stddev, self._epsilon_stddev
)
)

# Combine losses.
loss_policy = loss_policy_mean + loss_policy_stddev
Expand Down
4 changes: 2 additions & 2 deletions acme/jax/networks/multiplexers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
from acme.jax import utils
import haiku as hk
import jax.numpy as jnp
import tensorflow_probability
import tensorflow_probability.substrates.jax as tfp

tfd = tensorflow_probability.substrates.jax.distributions
tfd = tfp.distributions
ModuleOrArrayTransform = Union[hk.Module, Callable[[jnp.ndarray], jnp.ndarray]]


Expand Down

0 comments on commit aa42e1c

Please sign in to comment.