Skip to content

Commit

Permalink
[JAX] Fix incorrect type annotations.
Browse files Browse the repository at this point in the history
An upcoming change to JAX will teach pytype more accurate types for functions in the jax.numpy module. This reveals a number of type errors in downstream users of JAX. In particular, pytype is able to infer `jax.Array` accurately as a type in many more cases.

PiperOrigin-RevId: 556788994
  • Loading branch information
hawkinsp authored and TF2JAXDev committed Aug 14, 2023
1 parent 1a7a628 commit 22dfe41
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions tf2jax/_src/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1189,7 +1189,7 @@ def _matrix_diag(proto):

def _func(
diagonals: jnp.ndarray,
k: jnp.ndarray,
k: int,
num_rows: jnp.ndarray,
num_cols: jnp.ndarray,
padding_value: jnp.ndarray,
Expand Down Expand Up @@ -1408,7 +1408,7 @@ def _pad_v2(proto):

def _func(
inputs: jnp.ndarray,
padding: jnp.ndarray,
padding: Union[Sequence[Sequence[int]], Sequence[int], int],
constant_values: jnp.ndarray,
) -> jnp.ndarray:
return jnp.pad(inputs, pad_width=padding, constant_values=constant_values)
Expand Down Expand Up @@ -1907,8 +1907,8 @@ def _func(
key: jnp.ndarray,
counter: jnp.ndarray,
alg: jnp.ndarray,
minval: jnp.ndarray = jnp.iinfo(jax_dtype).min,
maxval: jnp.ndarray = jnp.iinfo(jax_dtype).max,
minval: Union[jnp.ndarray, int] = jnp.iinfo(jax_dtype).min,
maxval: Union[jnp.ndarray, int] = jnp.iinfo(jax_dtype).max,
) -> jnp.ndarray:
del counter, alg # TODO(b/266553394) combine key and counter?
return jax.random.randint(
Expand Down

0 comments on commit 22dfe41

Please sign in to comment.