From a9f24732e10cdcbf95890558cdcb509dacb25419 Mon Sep 17 00:00:00 2001 From: Jake Vanderplas Date: Fri, 23 Aug 2024 14:37:42 -0700 Subject: [PATCH] Add optional grain dependency & test (#9) --- .github/workflows/test.yaml | 10 ++--- jax_ml_stack/tests/test_nnx_with_tfds.py | 47 ++++++++++++++++++++++++ pyproject.toml | 8 ++++ 3 files changed, 59 insertions(+), 6 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 36006d7..9d30d65 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -12,6 +12,8 @@ on: permissions: contents: read # to fetch code +# TODO(jakevdp): add testing on macOS-11 and windows-2019 when grain supports them, +# or alternatively run a subset of tests without grain. jobs: built-latest: name: Latest packages (${{ matrix.os }} Python ${{ matrix.python-version }}) @@ -20,9 +22,6 @@ jobs: matrix: os: ["ubuntu-latest"] python-version: ["3.10", "3.11", "3.12"] - include: - - os: windows-2019 - python-version: "3.11" steps: - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 @@ -35,8 +34,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install .[dev,tfds] - pip install -U jax flax optax orbax tensorflow tensorflow_datasets + pip install -U jax flax grain optax orbax tensorflow tensorflow_datasets pytest pytest-xdist - name: Run tests run: | pytest -n auto jax_ml_stack @@ -60,7 +58,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install .[dev,tfds] + pip install .[dev,tfds,grain] - name: Run tests run: | pytest -n auto jax_ml_stack diff --git a/jax_ml_stack/tests/test_nnx_with_tfds.py b/jax_ml_stack/tests/test_nnx_with_tfds.py index 1230b8f..4675b66 100644 --- a/jax_ml_stack/tests/test_nnx_with_tfds.py +++ b/jax_ml_stack/tests/test_nnx_with_tfds.py @@ -16,6 +16,7 @@ import functools import unittest from flax import nnx +import grain.python as grain import numpy as np 
import optax import tensorflow_datasets as tfds @@ -69,6 +70,52 @@ def loss_fn(model, batch): _, grads = grad_fn(model, batch) optimizer.update(grads) + def test_nnx_with_tfds_plus_grain(self): + data_source = tfds.data_source('mnist', split='train') + + sampler = grain.IndexSampler( + num_records=5, + num_epochs=1, + shard_options=grain.NoSharding(), + shuffle=True, + seed=0, + ) + + class DownSample(grain.MapTransform): + shape: tuple[int, int] + + def __init__(self, shape: tuple[int, int]): + self.shape = shape + + def map(self, element: dict[str, np.ndarray]) -> dict[str, np.ndarray]: + image = element['image'] + element['image_scaled'] = image - image.mean() + return element + + operations = [DownSample((16, 16))] + + loader = grain.DataLoader( + data_source=data_source, + operations=operations, + sampler=sampler, + worker_count=0, # 0 loads in the main process; raise to use worker processes + ) + + model = CNN(rngs=nnx.Rngs(0)) + optimizer = nnx.Optimizer(model, optax.adamw(learning_rate=0.005)) + + def loss_fn(model, batch): + logits = model(batch['image_scaled']) + loss = optax.softmax_cross_entropy_with_integer_labels( + logits=logits, labels=np.ravel(batch['label']) + ).mean() + return loss, logits + + for batch in loader: + grad_fn = nnx.value_and_grad(loss_fn, has_aux=True) + _, grads = grad_fn(model, batch) + optimizer.update(grads) + if __name__ == '__main__': unittest.main() diff --git a/pyproject.toml b/pyproject.toml index 7549e2f..099af66 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ keywords = [] dependencies = [ "jax==0.4.31", "flax==0.8.5", + "grain==0.2.0", "optax==0.2.3", "orbax==0.1.9", ] @@ -33,11 +34,18 @@ dev = [ "pytest", "pytest-xdist", ] + +# TensorFlow datasets is an extra because it has a large install footprint. tfds = [ "tensorflow==2.17.0", "tensorflow_datasets==4.9.6", ] +# Grain is an extra because as of v0.2.0 it has no macOS wheels.
+grain = [ + "grain==0.2.0", +] + [tool.pyink] # Formatting configuration to follow Google style-guide line-length = 80