diff --git a/acme/jax/running_statistics_test.py b/acme/jax/running_statistics_test.py index 2119515737..64b8105b64 100644 --- a/acme/jax/running_statistics_test.py +++ b/acme/jax/running_statistics_test.py @@ -21,7 +21,7 @@ from acme import specs from acme.jax import running_statistics import jax -from jax.config import config as jax_config +from jax import config as jax_config import jax.numpy as jnp import numpy as np import tree