diff --git a/acme/agents/jax/ail/losses.py b/acme/agents/jax/ail/losses.py index 8d68dbd474..bcf08e79b5 100644 --- a/acme/agents/jax/ail/losses.py +++ b/acme/agents/jax/ail/losses.py @@ -21,10 +21,10 @@ from acme.jax import networks as networks_lib import jax import jax.numpy as jnp -import tensorflow_probability as tfp +import tensorflow_probability.substrates.jax as tfp import tree -tfp = tfp.experimental.substrates.jax + tfd = tfp.distributions # The loss is a function taking the discriminator, its state, the demo