-
Notifications
You must be signed in to change notification settings - Fork 648
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
422 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,251 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "a1b37dff", | ||
"metadata": {}, | ||
"source": [ | ||
"# NNX Demo" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "e8099a6f", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from functools import partial\n", | ||
"import jax\n", | ||
"from jax import random, numpy as jnp\n", | ||
"from flax.experimental import nnx" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "bcc5cffe", | ||
"metadata": {}, | ||
"source": [ | ||
"### [1] NNX is Pythonic" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "d99b73af", | ||
"metadata": { | ||
"outputId": "d8ef66d5-6866-4d5c-94c2-d22512bfe718" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"\n", | ||
"class Block(nnx.Module):\n", | ||
" def __init__(self, din, dout, x, *, rngs):\n", | ||
" self.linear = nnx.Linear(din, dout, rngs=rngs,\n", | ||
" kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal() , ('data', 'mp')))\n", | ||
" self.bn = nnx.BatchNorm(dout, rngs=rngs)\n", | ||
"\n", | ||
" def __call__(self, x, *, train: bool):\n", | ||
" x = self.linear(x)\n", | ||
" x = self.bn(x, use_running_average=not train)\n", | ||
" x = nnx.relu(x)\n", | ||
" return x\n", | ||
"\n", | ||
"\n", | ||
"class MLP(nnx.Module):\n", | ||
" def __init__(self, nlayers, dim, *, rngs): # explicit RNG threading\n", | ||
" self.blocks = [\n", | ||
" Block(dim, dim, rngs=rngs) for _ in range(nlayers)\n", | ||
" ]\n", | ||
" self.count = Count(0) # stateful variables are defined as attributes\n", | ||
"\n", | ||
" def __call__(self, x, *, train: bool):\n", | ||
" self.count += 1 # in-place stateful updates\n", | ||
" for block in self.blocks:\n", | ||
" x = block(x, train=train)\n", | ||
" return x\n", | ||
"\n", | ||
"class Count(nnx.Variable): # custom Variable types define the \"collections\"\n", | ||
" pass\n", | ||
"\n", | ||
"model = MLP(5, 4, rngs=nnx.Rngs(0)) # no special `init` method\n", | ||
"y = model(jnp.ones((2, 4)), train=False) # call methods directly\n", | ||
"\n", | ||
"print(f'{model = }')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "523aa27c", | ||
"metadata": {}, | ||
"source": [ | ||
"Because NNX Modules contain their own state, they are very easily to inspect:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "6f278ec4", | ||
"metadata": { | ||
"outputId": "10a46b0f-2993-4677-c26d-36a4ddf33449" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"print(f'{model.count = }')\n", | ||
"print(f'{model.blocks[0].linear.kernel = }')\n", | ||
"# print(f'{model.blocks.sdf.kernel = }') # typesafe inspection" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "95f389f2", | ||
"metadata": {}, | ||
"source": [ | ||
"### [2] Model Surgery is Intuitive" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "96f61108", | ||
"metadata": { | ||
"outputId": "e6f86be8-3537-4c48-f471-316ee0fb6c45" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# Module sharing\n", | ||
"model.blocks[1] = model.blocks[3]\n", | ||
"# Weight tying\n", | ||
"model.blocks[0].linear.variables.kernel = model.blocks[-1].linear.variables.kernel\n", | ||
"# Monkey patching\n", | ||
"def my_optimized_layer(x, *, train: bool): return x\n", | ||
"model.blocks[2] = my_optimized_layer\n", | ||
"\n", | ||
"y = model(jnp.ones((2, 4)), train=False) # still works\n", | ||
"print(f'{y.shape = }')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "aca5a6cd", | ||
"metadata": {}, | ||
"source": [ | ||
"### [3] Interacting with JAX is easy" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "c166dcc7", | ||
"metadata": { | ||
"outputId": "9a3f378b-739e-4f45-9968-574651200ede" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"state, static = model.split()\n", | ||
"\n", | ||
"# state is a dictionary-like JAX pytree\n", | ||
"print(f'{state = }'[:500] + '\\n...')\n", | ||
"\n", | ||
"# static is also a JAX pytree, but containing no data, just metadata\n", | ||
"print(f'\\n{static = }'[:300] + '\\n...')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "9f03e3af", | ||
"metadata": { | ||
"outputId": "0007d357-152a-449e-bcb9-b1b5a91d2d8d" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"state, static = model.split()\n", | ||
"\n", | ||
"@jax.jit\n", | ||
"def forward(static: nnx.GraphDef, state: nnx.State, x: jax.Array):\n", | ||
" model = static.merge(state)\n", | ||
" y = model(x, train=True)\n", | ||
" state, _ = model.split()\n", | ||
" return y, state\n", | ||
"\n", | ||
"x = jnp.ones((2, 4))\n", | ||
"y, state = forward(static,state, x)\n", | ||
"\n", | ||
"model.update(state)\n", | ||
"\n", | ||
"print(f'{y.shape = }')\n", | ||
"print(f'{model.count = }')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "9e23dbb4", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"params, batch_stats, counts, static = model.split(nnx.Param, nnx.BatchStat, Count)\n", | ||
"\n", | ||
"@jax.jit\n", | ||
"def forward(static: nnx.GraphDef, params, batch_stats, counts, x: jax.Array):\n", | ||
" model = static.merge(params, batch_stats, counts)\n", | ||
" y = model(x, train=True)\n", | ||
" params, batch_stats, counts, _ = model.split(nnx.Param, nnx.BatchStat, Count)\n", | ||
" return y, params, batch_stats, counts\n", | ||
"\n", | ||
"x = jnp.ones((2, 4))\n", | ||
"y, params, batch_stats, counts = forward(static, params, batch_stats, counts, x)\n", | ||
"\n", | ||
"model.update(params, batch_stats, counts)\n", | ||
"\n", | ||
"print(f'{y.shape = }')\n", | ||
"print(f'{model.count = }')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "2461bfe8", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"class Parent(nnx.Module):\n", | ||
"\n", | ||
" def __init__(self, model: MLP):\n", | ||
" self.model = model\n", | ||
"\n", | ||
" def __call__(self, x, *, train: bool):\n", | ||
"\n", | ||
" params, batch_stats, counts, static = self.model.split(nnx.Param, nnx.BatchStat, Count)\n", | ||
"\n", | ||
" @jax.jit\n", | ||
" def forward(static: nnx.GraphDef, params, batch_stats, counts, x: jax.Array):\n", | ||
" model = static.merge(params, batch_stats, counts)\n", | ||
" y = model(x, train=True)\n", | ||
" params, batch_stats, counts, _ = model.split(nnx.Param, nnx.BatchStat, Count)\n", | ||
" return y, params, batch_stats, counts\n", | ||
"\n", | ||
" y, params, batch_stats, counts = forward(static, params, batch_stats, counts, x)\n", | ||
"\n", | ||
" self.model.update(params, batch_stats, counts)\n", | ||
"\n", | ||
" return y\n", | ||
"\n", | ||
"parent = Parent(model)\n", | ||
"\n", | ||
"y = parent(jnp.ones((2, 4)), train=False)\n", | ||
"\n", | ||
"print(f'{y.shape = }')\n", | ||
"print(f'{parent.model.count = }')" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"jupytext": { | ||
"formats": "ipynb,md:myst" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
Oops, something went wrong.