From 22dfe41e95a1151ffe1faf2cb91d9a7f30711726 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 14 Aug 2023 07:52:53 -0700 Subject: [PATCH] [JAX] Fix incorrect type annotations. An upcoming change to JAX will teach pytype more accurate types for functions in the jax.numpy module. This reveals a number of type errors in downstream users of JAX. In particular, pytype is able to infer `jax.Array` accurately as a type in many more cases. PiperOrigin-RevId: 556788994 --- tf2jax/_src/ops.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tf2jax/_src/ops.py b/tf2jax/_src/ops.py index 1f17766..a0b829a 100644 --- a/tf2jax/_src/ops.py +++ b/tf2jax/_src/ops.py @@ -1189,7 +1189,7 @@ def _matrix_diag(proto): def _func( diagonals: jnp.ndarray, - k: jnp.ndarray, + k: int, num_rows: jnp.ndarray, num_cols: jnp.ndarray, padding_value: jnp.ndarray, @@ -1408,7 +1408,7 @@ def _pad_v2(proto): def _func( inputs: jnp.ndarray, - padding: jnp.ndarray, + padding: Union[Sequence[Sequence[int]], Sequence[int], int], constant_values: jnp.ndarray, ) -> jnp.ndarray: return jnp.pad(inputs, pad_width=padding, constant_values=constant_values) @@ -1907,8 +1907,8 @@ def _func( key: jnp.ndarray, counter: jnp.ndarray, alg: jnp.ndarray, - minval: jnp.ndarray = jnp.iinfo(jax_dtype).min, - maxval: jnp.ndarray = jnp.iinfo(jax_dtype).max, + minval: Union[jnp.ndarray, int] = jnp.iinfo(jax_dtype).min, + maxval: Union[jnp.ndarray, int] = jnp.iinfo(jax_dtype).max, ) -> jnp.ndarray: del counter, alg # TODO(b/266553394) combine key and counter? return jax.random.randint(