JAX brings automatic differentiation and the XLA compiler together through a NumPy-like API for high-performance machine learning research on accelerators like GPUs and TPUs.
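The combination described above can be illustrated with a minimal sketch, assuming only that JAX is installed. It writes a function with the NumPy-like `jax.numpy` API, differentiates it with `jax.grad`, and compiles the gradient with XLA through `jax.jit`. The quadratic `loss` is a hypothetical example chosen for illustration.

```python
import jax
import jax.numpy as jnp

def loss(w):
    # Hypothetical loss: sum of squares, written with NumPy-like jnp calls.
    return jnp.sum(w ** 2)

# jax.grad returns a new function computing d(loss)/dw;
# jax.jit compiles that function with XLA for the available accelerator.
grad_loss = jax.jit(jax.grad(loss))

w = jnp.array([1.0, 2.0, 3.0])
g = grad_loss(w)  # analytically, the gradient of sum(w**2) is 2*w
```

Both transformations are composable, so the same pattern extends to batching with `jax.vmap`.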
This is a curated list of awesome projects built with Equinox and JAX, along with other resources. Contributions are welcome!
Official examples can be found in the
Models:
- PaLM-jax - Implementation of the specific Transformer architecture from PaLM (Scaling Language Modeling with Pathways) in JAX, using the Equinox framework.
- mistral_jax - A port of the Mistral-7B model to JAX.
Projects and Packages:
- levanter - Legible, scalable, reproducible foundation models with named tensors and JAX.
- eqxvision - A Python package of computer vision models for the Equinox ecosystem.
- haliax - Named tensors for legible deep learning in JAX.
- diffrax - Numerical differential equation solvers.
- lineax - Linear solvers in JAX and Equinox.
- optimistix - Nonlinear optimisation (root-finding, least squares, ...) in JAX and Equinox.
- sympy2jax - Turn SymPy expressions into trainable JAX expressions.
- flowMC - Normalizing-flow-enhanced sampling package for probabilistic inference in JAX.
- flowjax - Distributions and normalizing flows in JAX.
- traceax - Stochastic trace estimation using JAX.
- galax - Galactic and gravitational dynamics in Python (+ GPU and autodiff).
- coordinax - Coordinates in JAX.
- unxt - Unitful quantities in JAX.
- statedict2pytree - Converts a PyTorch model into a JAX PyTree.
Always useful:
- jaxtyping - Type annotations for shape/dtype of arrays.
Deep learning:
- Optax - First-order gradient (SGD, Adam, ...) optimisers.
- Orbax - Checkpointing (async/multi-host/multi-device).
Scientific computing:
- BlackJAX - Probabilistic+Bayesian sampling.
- PySR - Symbolic regression. (Non-JAX honourable mention!)
Awesome JAX:
- Awesome JAX - A longer list of other JAX projects.
Contributions welcome! Read the contribution guidelines first.
Repository inspired by: