Skip to content

Commit

Permalink
Respond to feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Sep 6, 2022
1 parent b50dcaf commit 8a39949
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions docs/jep/12049-type-annotations.md
Original file line number Diff line number Diff line change
Expand Up @@ -207,12 +207,13 @@ An advantage of this approach is that it is that the API is already familiar to
- `NDArray`, which is a generic array type used only for type annotations
- `np.ndarray`, which is the array type used for instance checks, and also works for type annotations. In Python 3.9, can use constructs of the form `np.ndarray[shape, dtype]`, but this is currently poorly documented and it's unclear how well supported that is at the moment.

JAX currently implements {class}`jax.numpy.ndarray` for use with runtime `isinstance` checks. It uses a metaclass override to ensure that tracers also will pass `isinstance(tracer, jnp.ndarray)`, but it does not currently have any mechanism to ensure that tracers will be valid for `jnp.ndarray` annotations. If we were to follow numpy's approach to type annotation, we could do something like the following:
JAX currently implements {class}`jax.numpy.ndarray` for use with runtime `isinstance` checks. It uses a metaclass override to ensure that tracers also will pass `isinstance(tracer, jnp.ndarray)`, but it does not currently have any mechanism to ensure that tracers will be valid for `jnp.ndarray` annotations. If we were to follow numpy's approach to type annotation, changes might look something like this:

- add `jax.typing.NDArray` for use with type annotations
- add `TYPE_CHECKING` logic to ensure that `jnp.ndarray` can be used as an annotation for tracers; and possibly add class-level `__getitem__` to match numpy's shape and dtype specification.
- add `TYPE_CHECKING` logic to ensure that the `jnp.ndarray` object can also be used as an annotation for tracers; and possibly add class-level `__getitem__` to match numpy's shape and dtype specification.

A potential point of confusion with this is that JAX arrays are not actually of type `ndarray`, but rather `DeviceArray` or `ShardedDeviceArray` (soon to be unified under a single `jax.Array` type).
Following numpy's lead could result in confusion: we'd have `jax.Array`, `jax.numpy.ndarray`, and `jax.typing.NDArray`, each of which is useful in a particular subset of cases but not others (type identity, isinstance checks, type annotations, and tracer-compatible versions of all the above).
A potential point of confusion with this is that JAX arrays are not actually of type `jnp.ndarray`, but rather `DeviceArray` or `ShardedDeviceArray` (soon to be unified under a single `jax.Array` type; see {jax-issue}`#12016`).
Following numpy's lead could result in confusion: we'd have `jax.Array`, `jax.numpy.ndarray`, and `jax.typing.NDArray`, each of which is useful in a particular subset of cases but not others (type identity, isinstance checks, type annotations, and tracer-compatible versions of the above).
Despite the familiarity of numpy's API choices, the `ndarray` / `NDArray` / `Array` trichotomy may cause too much confusion

#### Choosing our own path: Unification
Expand All @@ -222,18 +223,15 @@ Python itself is slowly moving to a world of unifying instance and annotation ty
With this in mind, JAX could instead choose to use its (eventual) `jax.Array` type directly for both annotation and instance checks. For handling tracers within type annotations, we could use a construct like the following:
```python
if TYPE_CHECKING:
Array = Union[jax.Array, jax.Tracer]
Array = Union[jax._src.array.Array, jax.Tracer]
else:
Array = jax._src.array.Array
```

For handling instance checks, we could use the same metaclass override for `jax.Array` that we currently do in the case of `jnp.ndarray`. And if we would like to support more granular shape/dtype-specific annotations in the future, we could support constructions like `jax.Array[(3, 4), int]` via Python 3.9 class-level `__getitem__`, similar to how `list[int]` and `dict[str, int]` work for Python 3.9 built-in types.

Another advantage of this unification route (using `jax.Array` as a generic array annotation) is that it fits well with the approach used by the `jaxtyping` library, which uses `jaxtyping.Array[...]` as its core annotation type.
For handling instance checks, we could use the same metaclass override for `jax.Array` that we currently do in the case of `jnp.ndarray`. And if we would like to support more granular shape/dtype-specific annotations in the future, this would set us up to follow the conventions being developed in the [`jaxtyping`](https://github.com/google/jaxtyping/) project. Because `jaxtyping` uses `jaxtyping.Array[...]` as its core annotation type, unifying under `jax.Array` makes any potential future integration with that project more natural.

Given these advantages, it seems like the unified `jax.Array` is the better option compared to splitting annotation and instance logic between `jax.Array`, `jax.typing.NDArray`, and `jnp.ndarray`.


### Implementation Plan

To move forward with type annotations, we will do the following:
Expand Down

0 comments on commit 8a39949

Please sign in to comment.