Before you start, you should have some experience with PyTorch or TensorFlow. JAX can be a difficult library to learn, but it's worth it.
Getting started with training a model in JAX is easy: simply fit it. If you are familiar with PyTorch Lightning or TensorFlow Keras, you will love this library. It is a simple, lightweight library for training your model with JAX in a few lines of code.
- jax (jax, jaxlib)
- flax to define your model
- optax for optimizer, learning rate schedule, and loss function
- orbax to save the checkpoints
- tqdm to show the progress bar
- tensorboardX to log the training process
Fork this repository or copy the
file to your project.
It's a template for training your model. You only need to modify three key parts of the code in your script.
import jax, flax, optax, orbax
import jax.numpy as jnp  # needed for jnp.ones below
from fit import fit, lr_schedule, TrainState
# prepare your dataset
train_ds, test_ds = your_dataset()
# lr schedule
lr_fn = lr_schedule(
# key 1: your model
model = YourModel()
# init key and model
key = jax.random.PRNGKey(0)
x = jnp.ones((1, 28, 28, 1)) # MNIST example input
var = model.init(key, x, train=True)
state = TrainState.create(
# your training step, the template in the next section
def loss_fn():
# key 2: your loss function
return state, loss_dict, opt_state
# your evaluation step
def eval_step():
# key 3: your evaluation function
return acc
fit(state, train_ds, test_ds,
Let's start with a simple example, training a model on the MNIST dataset. First, import the fit
module in your training script.
from fit import *
Before training, you need to define your model, loss function, and evaluation function. Let's start with the model.
The following is a very simple example of a model. The setup
function is used to define the model structure, and the __call__
function defines the forward pass of the model.
class Model(nn.Module):
    def setup(self):
        self.conv1 = nn.Conv(features=16, kernel_size=(3, 3))
        self.dense1 = nn.Dense(features=10)
    # train=False for evaluation mode
    # if you use dropout or batch normalization
    # I bet you will use it
    def __call__(self, x, train=False):
        # simple conv + bn + relu + fully connected layer
        x = self.conv1(x)
        x = nn.BatchNorm(use_running_average=not train)(x)
        x = nn.relu(x)
        # dropout layer
        x = nn.Dropout(rate=0.5)(x, deterministic=not train)
        # flatten
        x = x.reshape((x.shape[0], -1))
        x = self.dense1(x)
        return x
Then you only need to consider two things: the loss function and the evaluation function.
Let's focus on the loss_fn function first. We start with PyTorch-style pseudocode, which makes the JAX version of loss_fn easier to understand.
def loss_fn():
    loss = criterion(logits, labels)
    return loss
Easy, right? Let's continue with the JAX version.
def loss_fn(logits, labels):
loss = optax.softmax_cross_entropy(
jax.nn.one_hot(labels, 10)
# put the losses you want to log to tensorboard
loss_dict = {'loss': loss}
return loss, loss_dict
Note that your loss function should return the total loss value and a dictionary of the values you want to log to TensorBoard.
Now, let's move on to the evaluation function, again starting with PyTorch-style pseudocode.
def eval_step():
    true_x, true_y = data
    pred_y = model(true_x)
    # your metric function such as accuracy in pytorch
    acc = metric(pred_y, true_y)
    return acc
In PyTorch, you call model.eval() to switch the model to evaluation mode, because dropout and batch normalization layers behave differently in training and evaluation. In JAX, you pass the train=False argument to the apply_fn function instead. Note that your model must behave differently in training and evaluation mode if you use a dropout or batch normalization layer; see the __call__ function in the Model section.
The evaluation step is similar to the train_step function and only requires the state object and the batch.
def eval_step(state: TrainState, batch):
    x, y = batch
    logits = state.apply_fn({
        'params': state.params,
        'batch_stats': state.batch_stats,
    }, x, train=False)
    acc = jnp.equal(jnp.argmax(logits, -1), y).mean()
    return acc
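The accuracy line can be checked in isolation on a toy batch. The sketch below uses hand-written logits instead of a real model; the `accuracy` helper name is hypothetical.

```python
import jax.numpy as jnp


def accuracy(logits, labels):
    # fraction of samples whose highest-scoring class matches the label
    return jnp.equal(jnp.argmax(logits, -1), labels).mean()


logits = jnp.array([
    [0.1, 2.0, 0.3],  # predicts class 1
    [1.5, 0.2, 0.1],  # predicts class 0
    [0.0, 0.1, 3.0],  # predicts class 2
    [0.9, 0.8, 0.1],  # predicts class 0
])
labels = jnp.array([1, 0, 2, 1])

acc = accuracy(logits, labels)  # three of four predictions are correct
```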
Prepare your dataset and data loaders for training and evaluation. You can use TensorFlow Datasets or torchvision datasets to load any dataset you want. Here is an example of loading the MNIST dataset.
ds = tfds.load("mnist", split="train", as_supervised=True)
train_ds = ds.take(50000).map(lambda x, y: (x / 255, y))
ds = torchvision.datasets.MNIST(
root="data", train=True, download=True,
train_ds =, batch_size=32, shuffle=True)
By the way, lr_schedule
is used to create the learning rate function, which is required by the TrainState
object. You can define your own learning rate function, or use the default one:
lr_fn = lr_schedule(base_lr=1e-3,
Furthermore, you can define your own chainable update transformations; see the optax
library for more information.
state = TrainState.create(
# chainable update transformations
Finally, call the fit
function to start training.
fit(state, train_ds, test_ds,
# evaluate the model every N epochs (default 1)
# log name for tensorboard
# hyperparameters for the training process
# such as batch size, learning rate, etc.
# it's optional for you
'batch_size': 32,
'lr': 1e-3,
You can open TensorBoard to monitor the training process and check the loss and accuracy metrics.
What's the @jax.jit decorator?
It's a decorator that compiles a function with XLA the first time it is called, so that later calls run as a single optimized program on CPU, GPU, or TPU. If you want to speed up the training process, especially for your own loss function and evaluation function, you can add the @jax.jit decorator to them.
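A small sketch of what compilation means in practice: the Python body of a jitted function runs only while JAX traces it, and the compiled result is reused for later calls with the same input shape and dtype. The `trace_log` list and the `scaled_sum` function are hypothetical and exist only to make the tracing visible.

```python
import jax
import jax.numpy as jnp

trace_log = []  # records the input shapes seen while tracing


@jax.jit
def scaled_sum(x):
    # this append is a Python side effect, so it only runs during tracing
    trace_log.append(x.shape)
    return jnp.sum(x * 2.0)


first = scaled_sum(jnp.ones(3))   # traced and compiled for shape (3,)
second = scaled_sum(jnp.ones(3))  # reuses the compiled function
third = scaled_sum(jnp.ones(4))   # new shape, so it is traced again
```

This is also why functions under `@jax.jit` should avoid side effects such as printing or mutating global state.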
What are the batch state and the dropout key?
The batch state stores the batch normalization statistics (the running mean and variance), and the dropout key is the PRNG key used to generate the random mask for the dropout layer.
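To make the role of the dropout key concrete, here is a hand-rolled dropout mask in plain JAX. It is only an illustration of how a key determines the mask, not how Flax or `fit` implement dropout; the rate and shapes are arbitrary.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
# split off a dedicated key for dropout so the main key can be reused later
key, dropout_key = jax.random.split(key)

rate = 0.5
x = jnp.ones((4, 8))

# each element is kept with probability 1 - rate
keep = jax.random.bernoulli(dropout_key, p=1.0 - rate, shape=x.shape)
# kept values are rescaled so the expected activation stays the same
dropped = jnp.where(keep, x / (1.0 - rate), 0.0)

# JAX randomness is explicit: the same key always gives the same mask
keep_again = jax.random.bernoulli(dropout_key, p=1.0 - rate, shape=x.shape)
```

Because the mask depends only on the key, a fresh key is needed at every training step to get a different mask.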