Skip to content

Commit

Permalink
Merge pull request #4070 from google:nnx-improve-mnist-tutorial
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 651195488
  • Loading branch information
Flax Authors committed Jul 10, 2024
2 parents 359ec7f + ebfe115 commit 1b58348
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 195 deletions.
208 changes: 55 additions & 153 deletions docs/nnx/mnist_tutorial.ipynb

Large diffs are not rendered by default.

63 changes: 21 additions & 42 deletions docs/nnx/mnist_tutorial.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ import tensorflow as tf # TensorFlow operations
tf.random.set_seed(0) # set random seed for reproducibility
num_epochs = 10
train_steps = 1200
eval_every = 200
batch_size = 32
train_ds: tf.data.Dataset = tfds.load('mnist', split='train')
Expand All @@ -63,11 +64,9 @@ test_ds = test_ds.map(
) # normalize test set
# create shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from
train_ds = train_ds.repeat(num_epochs).shuffle(1024)
train_ds = train_ds.repeat().shuffle(1024)
# group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency
train_ds = train_ds.batch(batch_size, drop_remainder=True).prefetch(1)
# create shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from
test_ds = test_ds.shuffle(1024)
train_ds = train_ds.batch(batch_size, drop_remainder=True).take(train_steps).prefetch(1)
# group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency
test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1)
```
Expand Down Expand Up @@ -117,7 +116,7 @@ nnx.display(y)

## 4. Create Optimizer and Metrics

In NNX, we create an `Optimizer` object to manage the model's parameters and apply gradients during training. `Optimizer` receives the model parameters and an `optax` optimizer that will define the update rules. Additionally, we'll define a `MultiMetric` object to keep track of the `Accuracy` and the `Average` loss.
In NNX, we create an `Optimizer` object to manage the model's parameters and apply gradients during training. `Optimizer` receives the model's reference so it can update its parameters, and an `optax` optimizer to define the update rules. Additionally, we'll define a `MultiMetric` object to keep track of the `Accuracy` and the `Average` loss.

```{code-cell} ipython3
import optax
Expand All @@ -134,9 +133,9 @@ metrics = nnx.MultiMetric(
nnx.display(optimizer)
```

## 5. Training step
## 5. Define step functions

We define a loss function using cross entropy loss (see more details in [`optax.softmax_cross_entropy_with_integer_labels()`](https://optax.readthedocs.io/en/latest/api/losses.html#optax.softmax_cross_entropy_with_integer_labels)) that our model will optimize over. In addition to the loss, the logits are also outputted since they will be used to calculate the accuracy metric during training and testing.
We define a loss function using cross entropy loss (see more details in [`optax.softmax_cross_entropy_with_integer_labels()`](https://optax.readthedocs.io/en/latest/api/losses.html#optax.softmax_cross_entropy_with_integer_labels)) that our model will optimize over. In addition to the loss, the logits are also outputted since they will be used to calculate the accuracy metric during training and testing. During training, we'll use `nnx.value_and_grad` to compute the gradients and update the model's parameters using the optimizer. During both training and testing, the loss and logits are used to calculate the metrics.

```{code-cell} ipython3
def loss_fn(model: CNN, batch):
Expand All @@ -145,49 +144,31 @@ def loss_fn(model: CNN, batch):
logits=logits, labels=batch['label']
).mean()
return loss, logits
```

Next, we create the training step function. This function takes the `model` and a data `batch` and does the following:
* Computes the loss, logits and gradients with respect to the loss function using `nnx.value_and_grad`.
* Updates training accuracy using the loss, logits, and batch labels.
* Updates model parameters via the optimizer by applying the gradient updates.

```{code-cell} ipython3
@nnx.jit
def train_step(model: CNN, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
"""Train for a single step."""
grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grads = grad_fn(model, batch)
metrics.update(loss=loss, logits=logits, labels=batch['label'])
optimizer.update(grads)
```

The [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) decorator traces the `train_step` function for just-in-time compilation with
[XLA](https://www.tensorflow.org/xla), optimizing performance on
hardware accelerators. `nnx.jit` is similar to [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit),
except it can transforms functions that contain NNX objects as inputs and outputs.
metrics.update(loss=loss, logits=logits, labels=batch['label']) # inplace updates
optimizer.update(grads) # inplace updates
## 6. Evaluation step

Create a separate function to calculate loss and accuracy metrics for the test batch, since this will be outside the `train_step` function. Loss is determined using the `optax.softmax_cross_entropy_with_integer_labels` function, since we're reusing the loss function defined earlier.

```{code-cell} ipython3
@nnx.jit
def eval_step(model: CNN, metrics: nnx.MultiMetric, batch):
loss, logits = loss_fn(model, batch)
metrics.update(loss=loss, logits=logits, labels=batch['label'])
metrics.update(loss=loss, logits=logits, labels=batch['label']) # inplace updates
```

## 7. Seed randomness
The [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) decorator traces the `train_step` function for just-in-time compilation with
[XLA](https://www.tensorflow.org/xla), optimizing performance on
hardware accelerators. `nnx.jit` is similar to [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit),
except it can transforms functions that contain NNX objects as inputs and outputs.

For reproducible dataset shuffling (using `tf.data.Dataset.shuffle`), set the TF random seed.
**NOTE**: in the above code we performed serveral inplace updates to the model, optimizer, and metrics, and we did not explicitely return the state updates. This is because NNX transforms respect reference semantics for NNX objects, and will propagate the state updates of the objects passed as input arguments. This is a key feature of NNX that allows for a more concise and readable code.

```{code-cell} ipython3
tf.random.set_seed(0)
```
+++

## 8. Train and Evaluate
## 6. Train and Evaluate

Now we train a model using batches of data for 10 epochs, evaluate its performance
on the test set after each epoch, and log the training and testing metrics (loss and
Expand All @@ -196,8 +177,6 @@ accuracy) throughout the process. Typically this leads to a model with around 99
```{code-cell} ipython3
:outputId: 258a2c76-2c8f-4a9e-d48b-dde57c342a87
num_steps_per_epoch = train_ds.cardinality().numpy() // num_epochs
metrics_history = {
'train_loss': [],
'train_accuracy': [],
Expand All @@ -212,7 +191,7 @@ for step, batch in enumerate(train_ds.as_numpy_iterator()):
# - the training loss and accuracy batch metrics
train_step(model, optimizer, metrics, batch)
if (step + 1) % num_steps_per_epoch == 0: # one training epoch has passed
if step > 0 and (step % eval_every == 0 or step == train_steps - 1): # one training epoch has passed
# Log training metrics
for metric, value in metrics.compute().items(): # compute metrics
metrics_history[f'train_{metric}'].append(value) # record metrics
Expand All @@ -228,18 +207,18 @@ for step, batch in enumerate(train_ds.as_numpy_iterator()):
metrics.reset() # reset metrics for next training epoch
print(
f"train epoch: {(step+1) // num_steps_per_epoch}, "
f"[train] step: {step}, "
f"loss: {metrics_history['train_loss'][-1]}, "
f"accuracy: {metrics_history['train_accuracy'][-1] * 100}"
)
print(
f"test epoch: {(step+1) // num_steps_per_epoch}, "
f"[test] step: {step}, "
f"loss: {metrics_history['test_loss'][-1]}, "
f"accuracy: {metrics_history['test_accuracy'][-1] * 100}"
)
```

## 9. Visualize Metrics
## 7. Visualize Metrics

Use Matplotlib to create plots for loss and accuracy.

Expand Down

0 comments on commit 1b58348

Please sign in to comment.