Skip to content

Commit

Permalink
[JAX] Replace uses of jnp.array in types with jnp.ndarray.
Browse files Browse the repository at this point in the history
`jnp.array` is a function, not a type:
https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.array.html
so it never makes sense to use `jnp.array` in a type annotation.

Presumably the intent was to write `jnp.ndarray` aka `jax.Array`. Change uses of `jnp.array` to `jnp.ndarray`.

PiperOrigin-RevId: 555454626
Change-Id: I089d1c5a0988f2b608fce41cc345c86f17b8957c
  • Loading branch information
hawkinsp authored and Copybara-Service committed Aug 10, 2023
1 parent 95b7cfe commit 76af644
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions acme/agents/jax/ppo/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,12 @@ def init() -> EMAMeanStdNormalizerParams:
biased_second_moment=second_moment,
)

def _normalize_leaf(x: jnp.array, ema_counter: jnp.int32,
biased_first_moment: jnp.array,
biased_second_moment: jnp.array) -> jnp.ndarray:
def _normalize_leaf(
x: jnp.ndarray,
ema_counter: jnp.int32,
biased_first_moment: jnp.ndarray,
biased_second_moment: jnp.ndarray,
) -> jnp.ndarray:
zero_debias = 1. / (1. - jnp.power(tau, ema_counter))
mean = biased_first_moment * zero_debias
second_moment = biased_second_moment * zero_debias
Expand Down

0 comments on commit 76af644

Please sign in to comment.