diff --git a/jax/_src/api.py b/jax/_src/api.py index c31ea71745d1..c61afaca5952 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -173,7 +173,7 @@ def jit( ``in_shardings`` or an error is raised, and the compiled computation has input shardings corresponding to ``in_shardings``. If not provided, the compiled computation's input shardings are inferred from argument - sharings. + shardings. out_shardings: optional, a :py:class:`Sharding` or pytree with :py:class:`Sharding` leaves and structure that is a tree prefix of the output of ``fun``. If provided, it has the same effect as applying