jax-ml/jax-ai-stack

jax-ai-stack

This is a work-in-progress meta-package for the AI/ML stack built on top of the jax package. It is intended to hold tests, documentation, and installation instructions that span multiple packages in the JAX ecosystem.

Installing the stack

The stack can be installed with the following command:

pip install jax-ai-stack

This pins specific versions of the component projects, chosen because the integration tests in this repository verify that they work correctly together. Packages include:

  • JAX: the core JAX package, which includes array operations and program transformations like jit, vmap, grad, etc.
  • flax: build neural networks with JAX.
  • ml_dtypes: NumPy dtype extensions for machine learning.
  • optax: gradient processing and optimization in JAX.
  • orbax: checkpointing and persistence utilities for JAX.

Optional packages

Additionally, there are optional packages you can install with pip extras. The following command:

pip install "jax-ai-stack[grain]"

will install a compatible version of the grain data loader.

Similarly, the following command:

pip install "jax-ai-stack[tfds]"

will install a compatible version of tensorflow and tensorflow-datasets.
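To see which component versions jax-ai-stack pins, and whether optional extras such as grain or tensorflow-datasets are actually present, you can read the package's declared requirements with the standard library's importlib.metadata. This is a minimal sketch, assuming the stack was installed with pip into the current environment; the helper names are hypothetical, and the name extraction is a simplified regex rather than a full PEP 508 parser.

```python
import re
from importlib import metadata

# Leading distribution name of a PEP 508 requirement string (simplified).
_NAME = re.compile(r"^\s*([A-Za-z0-9][A-Za-z0-9._-]*)")


def requirement_name(req: str) -> str:
    """Extract the name from a string like 'grain==1.0; extra == "grain"'."""
    match = _NAME.match(req)
    if match is None:
        raise ValueError(f"unparseable requirement: {req!r}")
    return match.group(1)


def installed_version(name: str):
    """Return the installed version of a distribution, or None if absent."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def report(dist: str = "jax-ai-stack") -> dict:
    """Map each requirement declared by `dist` to its installed version."""
    requirements = metadata.requires(dist) or []
    names = [requirement_name(r) for r in requirements]
    return {name: installed_version(name) for name in names}


if __name__ == "__main__":
    try:
        for name, version in report().items():
            # Requirements gated on an extra you did not install show as absent.
            print(f"{name:25} {version or 'not installed'}")
    except metadata.PackageNotFoundError:
        print("jax-ai-stack is not installed in this environment")
```

Because the requirement list comes from the installed package's own metadata, the report reflects the pins of whichever jax-ai-stack release you installed rather than a hard-coded list.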