Skip to content

Commit

Permalink
Reword some things
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Aug 16, 2022
1 parent a118c6f commit c4ae178
Showing 1 changed file with 16 additions and 7 deletions.
23 changes: 16 additions & 7 deletions docs/jep/11859-type-annotations.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,20 +132,29 @@ With this framing (Level 1/2/3) and JAX-specific challenges in mind, we can begi

For JAX type annotation, we have the following goals:

1. We would like to support full, *Level 1, 2, and 3* type annotation. In particular, this means that we should have restrictive type annotations on both inputs and outputs to public API functions.
1. We would like to support full, *Level 1, 2, and 3* type annotation as far as possible. In particular, this means that we should have restrictive type annotations on both inputs and outputs to public API functions.

2. When functions are decorated by jax transformations like `jit`, `vmap`, `grad`, etc. JAX will **strip all annotations**. The reason for this is that without the mechanisms of [PEP 612](https://peps.python.org/pep-0612/) there is no good way to do otherwise, and `ParamSpec` will not be available for use until Python 3.10.
2. In order to not add undue development friction (due to the internal/external CI differences), we would like to be conservative in the type annotation constructs we use: in particular, when it comes to recently-introduced mechanisms such as `ParamSpec` (PEP [PEP 612](https://peps.python.org/pep-0612/),), we would like to wait until support in mypy and other tools stabilizes before relying on them.
One impact of this is that for the time being, when functions are decorated by jax transformations like `jit`, `vmap`, `grad`, etc. JAX will **strip all annotations**. This is because `ParamSpec` is still only partially supported; the PEP is slated for Python 3.10 (though it can be used before that via [typing-extensions](https://github.com/python/typing_extensions) and at the time of this writing mypy has a laundry-list of incompatibilities with the `ParamSpec`-based annotations (see [`ParamSpec` mypy bug tracker](https://github.com/python/mypy/issues?q=is%3Aissue+is%3Aopen++label%3Atopic-paramspec+)).
We will revisit this question in the future once support for such features stabilizes.

3. We will design JAX type annotations to annotate the **intent** of APIs, rather than the implementation, so that the annotations become useful to communicate the contract of the API. This means that at times inputs that are valid at runtime may not be recognized as valid by the static type checker (a simple example is an arbitrary iterator passed in place of a shape). Inputs to JAX functions should be typed as permissively as is reasonable: for example, while shapes are typically tuples, functions that accept a shape should accept arbitrary sequences. Similarly, functions that accept a dtype need not require an instance of class `np.dtype`, but rather any dtype-convertible object. This might include strings, built-in scalar types, or dtype-adjacent classes such as `np.float64` and `jnp.float64`. In order to make this as uniform as possible across the package, we will add a {mod}`jax.typing` module with common type specifications, starting with broad categories such as:
3. JAX type annotations shoudl in general indicate the **intent** of APIs, rather than the implementation, so that the annotations become useful to communicate the contract of the API. This means that at times inputs that are valid at runtime may not be recognized as valid by the static type checker (a simple example is an arbitrary iterator passed in place of a shape in some function implementations).

4. Inputs to JAX functions and methods should be typed as permissively as is reasonable: for example, while shapes are typically tuples, functions that accept a shape should accept arbitrary sequences. Similarly, functions that accept a dtype need not require an instance of class `np.dtype`, but rather any dtype-convertible object. This might include strings, built-in scalar types, or dtype-adjacent classes such as `np.float64` and `jnp.float64`. In order to make this as uniform as possible across the package, we will add a {mod}`jax.typing` module with common type specifications, starting with broad categories such as:

- `NDArray`
- `ArrayLike`
- `DtypeLike`
- `ShapeLike`
- etc.

Note that these will in general be simpler than the equivalent protocols used in {mod}`numpy.typing`. For example, in the case of `DtypeLike`, JAX does not support structured dtypes, so JAX can use a simpler implementation.
Note that these will in general be simpler than the equivalent protocols used in {mod}`numpy.typing`. For example, in the case of `DtypeLike`, JAX does not support structured dtypes, so JAX can use a simpler implementation. Similarly, in `ArrayLike`, JAX generally does not support list or tuple inputs in most places, so the type definition will be simpler than the numpy analog.

5. Conversely, outputs of functions and methods should be typed as strictly as possible: for example, for a JAX function that returns an array, the output should be annotated with `jnp.ndarray` rather than `ArrayLike`. Functions returning a dtype should always be annotated `np.dtype`, and functions returning a shape should always be `Tuple[int]` or a strictly-typed NamedShape equivalent. For this purpose, we will implement in {mod}`jax.typing` several strictly-typed analogs of the permissive types mentioned above, namely:

4. Function outputs should be typed as strictly as possible: for example, for a function that returns an array, the output should be annotated with `jnp.ndarray` rather than `ArrayLike`. Functions returning a dtype should always be annotated `np.dtype`, and functions returning a shape should always be `Tuple[int]` or a strictly-typed NamedShape equivalent.
- `NDArray` (perhaps this could be equivalent to `jnp.ndarray`?)
- `DType` (perhaps this could be simply `np.dtype`?)
- `Shape`
- `NamedShape`
- etc.

5. Aside from common typing protocols gathered in `jax.typing`, we should avoid constructing overly-complex protocols for arguments passed to API functions, and instead use simple unions such as `Union[simple_type, Any]` in the case that the full type specification cannot be simply specified. This is a comprimise that achieces the goals of Level 1 and 2 annotations, while punting on Level 3 for complicated APIs.
6. Aside from common typing protocols gathered in `jax.typing`, we should err on the side of simplicity, and avoid constructing overly-complex protocols for arguments passed to API functions, and instead use simple unions such as `Union[simple_type, Any]` in the case that the full type specification of the API cannot be succinctly specified. This is a comprimise that achieves the goals of Level 1 and 2 annotations, while punting on Level 3 in favor of simplicity.

0 comments on commit c4ae178

Please sign in to comment.