Skip to content

Commit

Permalink
[nnx] add demo.ipynb
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Feb 6, 2024
1 parent d9585e0 commit 13c273d
Show file tree
Hide file tree
Showing 2 changed files with 422 additions and 0 deletions.
251 changes: 251 additions & 0 deletions flax/experimental/nnx/docs/demo.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "a1b37dff",
"metadata": {},
"source": [
"# NNX Demo"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e8099a6f",
"metadata": {},
"outputs": [],
"source": [
"from functools import partial\n",
"import jax\n",
"from jax import random, numpy as jnp\n",
"from flax.experimental import nnx"
]
},
{
"cell_type": "markdown",
"id": "bcc5cffe",
"metadata": {},
"source": [
"### [1] NNX is Pythonic"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d99b73af",
"metadata": {
"outputId": "d8ef66d5-6866-4d5c-94c2-d22512bfe718"
},
"outputs": [],
"source": [
"\n",
"class Block(nnx.Module):\n",
"  \"\"\"Linear -> BatchNorm -> ReLU block.\n",
"\n",
"  Args:\n",
"    din: input feature size of the linear layer.\n",
"    dout: output feature size of the linear layer and of the BatchNorm.\n",
"    rngs: RNG container forwarded to the sub-layers.\n",
"  \"\"\"\n",
"\n",
"  def __init__(self, din, dout, *, rngs):\n",
"    # The signature has to match how MLP builds blocks: `Block(dim, dim, rngs=rngs)`.\n",
"    self.linear = nnx.Linear(\n",
"      din, dout, rngs=rngs,\n",
"      kernel_init=nnx.with_partitioning(\n",
"        nnx.initializers.lecun_normal(), ('data', 'mp')),\n",
"    )\n",
"    self.bn = nnx.BatchNorm(dout, rngs=rngs)\n",
"\n",
"  def __call__(self, x, *, train: bool):\n",
"    \"\"\"Applies the block; BatchNorm uses its running averages when `train` is False.\"\"\"\n",
"    x = self.linear(x)\n",
"    x = self.bn(x, use_running_average=not train)\n",
"    x = nnx.relu(x)\n",
"    return x\n",
"\n",
"\n",
"class MLP(nnx.Module):\n",
"  \"\"\"A stack of `nlayers` Blocks of width `dim` that counts its own calls.\"\"\"\n",
"\n",
"  def __init__(self, nlayers, dim, *, rngs):  # explicit RNG threading\n",
"    self.blocks = [\n",
"      Block(dim, dim, rngs=rngs) for _ in range(nlayers)\n",
"    ]\n",
"    self.count = Count(0)  # stateful variables are defined as attributes\n",
"\n",
"  def __call__(self, x, *, train: bool):\n",
"    self.count += 1  # in-place stateful updates\n",
"    for block in self.blocks:\n",
"      x = block(x, train=train)\n",
"    return x\n",
"\n",
"class Count(nnx.Variable):  # custom Variable types define the \"collections\"\n",
"  pass\n",
"\n",
"model = MLP(5, 4, rngs=nnx.Rngs(0))  # no special `init` method\n",
"y = model(jnp.ones((2, 4)), train=False)  # call methods directly\n",
"\n",
"print(f'{model = }')"
]
},
{
"cell_type": "markdown",
"id": "523aa27c",
"metadata": {},
"source": [
"Because NNX Modules contain their own state, they are very easy to inspect:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6f278ec4",
"metadata": {
"outputId": "10a46b0f-2993-4677-c26d-36a4ddf33449"
},
"outputs": [],
"source": [
"# State lives on the module objects, so plain attribute access reaches it.\n",
"print(f'{model.count = }')\n",
"print(f'{model.blocks[0].linear.kernel = }')\n",
"# Left commented out on purpose: `model.blocks` is a Python list, so `.sdf` does\n",
"# not exist and running the next line would raise an AttributeError.\n",
"# print(f'{model.blocks.sdf.kernel = }') # typesafe inspection"
]
},
{
"cell_type": "markdown",
"id": "95f389f2",
"metadata": {},
"source": [
"### [2] Model Surgery is Intuitive"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "96f61108",
"metadata": {
"outputId": "e6f86be8-3537-4c48-f471-316ee0fb6c45"
},
"outputs": [],
"source": [
"# Module sharing: slot 1 now refers to the same object as slot 3\n",
"model.blocks[1] = model.blocks[3]\n",
"\n",
"# Weight tying: the first block reuses the last block's kernel variable\n",
"model.blocks[0].linear.variables.kernel = model.blocks[-1].linear.variables.kernel\n",
"\n",
"# Monkey patching: MLP only calls `block(x, train=train)`, so any callable\n",
"# with that signature can stand in for a Block\n",
"def my_optimized_layer(x, *, train: bool):\n",
"  \"\"\"Identity replacement for a Block; `train` is accepted but unused.\"\"\"\n",
"  return x\n",
"\n",
"model.blocks[2] = my_optimized_layer\n",
"\n",
"# the modified model is still callable\n",
"y = model(jnp.ones((2, 4)), train=False)\n",
"print(f'{y.shape = }')"
]
},
{
"cell_type": "markdown",
"id": "aca5a6cd",
"metadata": {},
"source": [
"### [3] Interacting with JAX is easy"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c166dcc7",
"metadata": {
"outputId": "9a3f378b-739e-4f45-9968-574651200ede"
},
"outputs": [],
"source": [
"# split() returns the model's state together with a static description of it\n",
"state, static = model.split()\n",
"\n",
"# state is a dictionary-like JAX pytree\n",
"# (only the first 500 characters of its repr are printed)\n",
"print(f'{state = }'[:500] + '\\n...')\n",
"\n",
"# static is also a JAX pytree, but containing no data, just metadata\n",
"# (only the first 300 characters of its repr are printed)\n",
"print(f'\\n{static = }'[:300] + '\\n...')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9f03e3af",
"metadata": {
"outputId": "0007d357-152a-449e-bcb9-b1b5a91d2d8d"
},
"outputs": [],
"source": [
"state, static = model.split()\n",
"\n",
"@jax.jit\n",
"def forward(static: nnx.GraphDef, state: nnx.State, x: jax.Array):\n",
"  \"\"\"Rebuilds the model from (static, state), runs it in train mode on `x`,\n",
"  and returns the output together with the model's state after the call.\"\"\"\n",
"  module = static.merge(state)\n",
"  out = module(x, train=True)\n",
"  new_state, _ = module.split()\n",
"  return out, new_state\n",
"\n",
"x = jnp.ones((2, 4))\n",
"y, state = forward(static, state, x)\n",
"\n",
"# write the state returned by the jitted call back into the live model\n",
"model.update(state)\n",
"\n",
"print(f'{y.shape = }')\n",
"print(f'{model.count = }')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9e23dbb4",
"metadata": {},
"outputs": [],
"source": [
"# split the state into one group per Variable type\n",
"params, batch_stats, counts, static = model.split(nnx.Param, nnx.BatchStat, Count)\n",
"\n",
"@jax.jit\n",
"def forward(static: nnx.GraphDef, params, batch_stats, counts, x: jax.Array):\n",
"  \"\"\"Rebuilds the model from its three state groups, runs it in train mode on\n",
"  `x`, and returns the output followed by the three groups after the call.\"\"\"\n",
"  module = static.merge(params, batch_stats, counts)\n",
"  out = module(x, train=True)\n",
"  new_params, new_batch_stats, new_counts, _ = module.split(\n",
"    nnx.Param, nnx.BatchStat, Count)\n",
"  return out, new_params, new_batch_stats, new_counts\n",
"\n",
"x = jnp.ones((2, 4))\n",
"y, params, batch_stats, counts = forward(static, params, batch_stats, counts, x)\n",
"\n",
"# write the returned groups back into the live model\n",
"model.update(params, batch_stats, counts)\n",
"\n",
"print(f'{y.shape = }')\n",
"print(f'{model.count = }')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2461bfe8",
"metadata": {},
"outputs": [],
"source": [
"class Parent(nnx.Module):\n",
"  \"\"\"Wraps an MLP and runs its forward pass through a jitted function.\"\"\"\n",
"\n",
"  def __init__(self, model: MLP):\n",
"    self.model = model\n",
"\n",
"  def __call__(self, x, *, train: bool):\n",
"    \"\"\"Runs the wrapped model on `x` with the given `train` flag and writes\n",
"    the state returned by the jitted call back into the wrapped model.\"\"\"\n",
"\n",
"    params, batch_stats, counts, static = self.model.split(nnx.Param, nnx.BatchStat, Count)\n",
"\n",
"    # `forward` is defined anew on every call, so the Python bool `train` can\n",
"    # simply be captured by the closure instead of being passed as an argument.\n",
"    @jax.jit\n",
"    def forward(static: nnx.GraphDef, params, batch_stats, counts, x: jax.Array):\n",
"      model = static.merge(params, batch_stats, counts)\n",
"      y = model(x, train=train)  # honor the caller's flag (was hard-coded to True)\n",
"      params, batch_stats, counts, _ = model.split(nnx.Param, nnx.BatchStat, Count)\n",
"      return y, params, batch_stats, counts\n",
"\n",
"    y, params, batch_stats, counts = forward(static, params, batch_stats, counts, x)\n",
"\n",
"    self.model.update(params, batch_stats, counts)\n",
"\n",
"    return y\n",
"\n",
"parent = Parent(model)\n",
"\n",
"y = parent(jnp.ones((2, 4)), train=False)\n",
"\n",
"print(f'{y.shape = }')\n",
"print(f'{parent.model.count = }')"
]
}
],
"metadata": {
"jupytext": {
"formats": "ipynb,md:myst"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading

0 comments on commit 13c273d

Please sign in to comment.