From 0eed28a01047bbc8a987c6095fd4fc10635c1b55 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 6 May 2024 04:59:23 -0700 Subject: [PATCH] Fix a typo in jax.jit docstring --- jax/_src/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index c31ea71745d1..c61afaca5952 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -173,7 +173,7 @@ def jit( ``in_shardings`` or an error is raised, and the compiled computation has input shardings corresponding to ``in_shardings``. If not provided, the compiled computation's input shardings are inferred from argument - sharings. + shardings. out_shardings: optional, a :py:class:`Sharding` or pytree with :py:class:`Sharding` leaves and structure that is a tree prefix of the output of ``fun``. If provided, it has the same effect as applying