From 76af6448e368e3168534972c41b9607a7a63b906 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 10 Aug 2023 05:34:15 -0700 Subject: [PATCH] [JAX] Replace uses of `jnp.array` in types with `jnp.ndarray`. `jnp.array` is a function, not a type: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.array.html so it never makes sense to use `jnp.array` in a type annotation. Presumably the intent was to write `jnp.ndarray` aka `jax.Array`. Change uses of `jnp.array` to `jnp.ndarray`. PiperOrigin-RevId: 555454626 Change-Id: I089d1c5a0988f2b608fce41cc345c86f17b8957c --- acme/agents/jax/ppo/normalization.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/acme/agents/jax/ppo/normalization.py b/acme/agents/jax/ppo/normalization.py index 1413f047fd..52d5b50dc8 100644 --- a/acme/agents/jax/ppo/normalization.py +++ b/acme/agents/jax/ppo/normalization.py @@ -132,9 +132,12 @@ def init() -> EMAMeanStdNormalizerParams: biased_second_moment=second_moment, ) - def _normalize_leaf(x: jnp.array, ema_counter: jnp.int32, - biased_first_moment: jnp.array, - biased_second_moment: jnp.array) -> jnp.ndarray: + def _normalize_leaf( + x: jnp.ndarray, + ema_counter: jnp.int32, + biased_first_moment: jnp.ndarray, + biased_second_moment: jnp.ndarray, + ) -> jnp.ndarray: zero_debias = 1. / (1. - jnp.power(tau, ema_counter)) mean = biased_first_moment * zero_debias second_moment = biased_second_moment * zero_debias