Skip to content

Commit

Permalink
Add optional grain dependency & test (#9)
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Aug 23, 2024
1 parent 1115693 commit a9f2473
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 6 deletions.
10 changes: 4 additions & 6 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ on:
permissions:
contents: read # to fetch code

# TODO(jakevdp): add testing on macOS-11 and windows-2019 when grain supports them,
# or alternatively run a subset of tests without grain.
jobs:
built-latest:
name: Latest packages (${{ matrix.os }} Python ${{ matrix.python-version }})
Expand All @@ -20,9 +22,6 @@ jobs:
matrix:
os: ["ubuntu-latest"]
python-version: ["3.10", "3.11", "3.12"]
include:
- os: windows-2019
python-version: "3.11"

steps:
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
Expand All @@ -35,8 +34,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[dev,tfds]
pip install -U jax flax optax orbax tensorflow tensorflow_datasets
pip install -U jax flax grain optax orbax tensorflow tensorflow_datasets pytest pytest-xdist
- name: Run tests
run: |
pytest -n auto jax_ml_stack
Expand All @@ -60,7 +58,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[dev,tfds]
pip install .[dev,tfds,grain]
- name: Run tests
run: |
pytest -n auto jax_ml_stack
47 changes: 47 additions & 0 deletions jax_ml_stack/tests/test_nnx_with_tfds.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import functools
import unittest
from flax import nnx
import grain.python as grain
import numpy as np
import optax
import tensorflow_datasets as tfds
Expand Down Expand Up @@ -69,6 +70,52 @@ def loss_fn(model, batch):
_, grads = grad_fn(model, batch)
optimizer.update(grads)

def test_nnx_with_tfds_plus_grain(self):
  """Smoke-test an NNX training loop fed by a Grain loader over TFDS MNIST.

  Wraps `tfds.data_source` in a `grain.DataLoader`, applies one per-record
  transform, and performs one optimizer update per loaded record. There are
  no assertions: the test passes as long as the pipeline runs end to end
  without raising.
  """
  data_source = tfds.data_source('mnist', split='train')

  # Keep the test small: five records, one epoch, shuffled with a fixed seed.
  sampler = grain.IndexSampler(
      num_records=5,
      num_epochs=1,
      shard_options=grain.NoSharding(),
      shuffle=True,
      seed=0,
  )

  class DownSample(grain.MapTransform):
    """Per-record transform that adds a mean-centered copy of the image.

    NOTE(review): despite the name, `map` does not resize anything, and
    `shape` is stored but never read.
    """

    shape: tuple[int, int]

    def __init__(self, shape: tuple[int, int]):
      self.shape = shape

    def map(self, element: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
      image = element['image']
      # Written under a new key; the original 'image' entry is kept as-is.
      element['image_scaled'] = image - image.mean()
      return element

  operations = [DownSample((16, 16))]

  loader = grain.DataLoader(
      data_source=data_source,
      operations=operations,
      sampler=sampler,
      worker_count=0,  # Raise above 0 to load with multiple worker processes.
  )

  model = CNN(rngs=nnx.Rngs(0))
  optimizer = nnx.Optimizer(model, optax.adamw(learning_rate=0.005))

  def loss_fn(model, batch):
    """Return (mean cross-entropy loss, logits) for one loaded record."""
    logits = model(batch['image_scaled'])
    loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=np.ravel(batch['label'])
    ).mean()
    return loss, logits

  # The transformed gradient function does not depend on the batch, so build
  # it once instead of on every loop iteration.
  grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
  for batch in loader:
    _, grads = grad_fn(model, batch)
    optimizer.update(grads)


if __name__ == '__main__':
unittest.main()
8 changes: 8 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ keywords = []
# Core runtime dependencies, pinned to exact versions. Grain is intentionally
# not listed here: it is installed through the optional `grain` extra so that
# a plain install does not require it.
dependencies = [
    "jax==0.4.31",
    "flax==0.8.5",
    "optax==0.2.3",
    "orbax==0.1.9",
]
Expand All @@ -33,11 +34,18 @@ dev = [
"pytest",
"pytest-xdist",
]

# TensorFlow datasets is an extra because it has a large install footprint.
tfds = [
"tensorflow==2.17.0",
"tensorflow_datasets==4.9.6",
]

# Grain is an extra because as of v0.2.0 it has no macOS wheels.
grain = [
"grain==0.2.0",
]

[tool.pyink]
# Formatting configuration to follow Google style-guide
line-length = 80
Expand Down

0 comments on commit a9f2473

Please sign in to comment.