This is a work-in-progress meta-package for the AI/ML stack built on top of the jax package. It is intended as a location for tests, documentation, and installation instructions that cover multiple packages in the JAX ecosystem.
The stack can be installed with the following command:
pip install jax-ai-stack
This pins specific versions of the component projects that the integration tests in this repository have verified to work correctly together. Packages include:
- JAX: the core JAX package, which includes array operations and program transformations like jit, vmap, grad, etc.
- flax: build neural networks with JAX.
- ml_dtypes: NumPy dtype extensions for machine learning.
- optax: gradient processing and optimization in JAX.
- orbax: checkpointing and persistence utilities for JAX.
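To illustrate the three transformations named for the core JAX package, here is a minimal sketch. It assumes only that jax is installed. The linear-model loss and the sample data are hypothetical and chosen purely for illustration. It composes grad (differentiation), vmap (batching) and jit (XLA compilation) on a single function.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    # Hypothetical example: squared error of a linear model on one example
    return (jnp.dot(x, w) - y) ** 2


# grad differentiates with respect to the first argument (w) by default
grad_loss = jax.grad(loss)

# vmap maps grad_loss over a batch of (x, y) pairs while sharing one w
batched_grad = jax.vmap(grad_loss, in_axes=(None, 0, 0))

# jit compiles the batched gradient function with XLA
fast_batched_grad = jax.jit(batched_grad)

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.array([1.0, 1.0, 1.0])

# One gradient row per example in the batch, shape (3, 2)
grads = fast_batched_grad(w, xs, ys)
print(grads)
```

Because the transformations compose, the same loss function can be reused unchanged for single examples, batches, and compiled code.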
Additionally, there are optional packages you can install with pip extras.
The following command:
pip install jax-ai-stack[grain]
will install a compatible version of the grain data loader.
Similarly, the following command:
pip install jax-ai-stack[tfds]
will install a compatible version of tensorflow and tensorflow-datasets.
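To check which versions the pins actually resolved to in your environment, you could use a small script like the one below. It is a sketch that uses only the standard library's importlib.metadata. The distribution names are assumptions, not taken from this repository's pin list. In particular, orbax is assumed to be installed as "orbax-checkpoint", and the extras are assumed to be "grain", "tensorflow" and "tensorflow-datasets". Missing packages are reported rather than raising.

```python
from importlib.metadata import PackageNotFoundError, version

# Assumed distribution names; adjust if the pinned names differ
CORE = ["jax", "flax", "ml_dtypes", "optax", "orbax-checkpoint"]
OPTIONAL = ["grain", "tensorflow", "tensorflow-datasets"]  # from the [grain] / [tfds] extras


def installed_versions(names):
    """Map each distribution name to its installed version string, or None if absent."""
    result = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions(["jax-ai-stack", *CORE, *OPTIONAL]).items():
        print(f"{name:22} {ver or 'not installed'}")
```

Reporting None instead of failing means the same script works whether or not the optional extras were installed.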