diff --git a/docs/nnx/nnx_basics.ipynb b/docs/nnx/nnx_basics.ipynb
index 1df062735c..ac7f33e1de 100644
--- a/docs/nnx/nnx_basics.ipynb
+++ b/docs/nnx/nnx_basics.ipynb
@@ -19,61 +19,20 @@
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": null,
"metadata": {
"tags": [
"skip-execution"
]
},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Requirement already satisfied: flax in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (0.8.5)\n",
- "Requirement already satisfied: penzai in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (0.1.3)\n",
- "Requirement already satisfied: numpy>=1.22 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from flax) (1.26.4)\n",
- "Requirement already satisfied: jax>=0.4.27 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from flax) (0.4.31.dev20240621+0428a1509)\n",
- "Requirement already satisfied: msgpack in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from flax) (1.0.8)\n",
- "Requirement already satisfied: optax in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from flax) (0.2.2)\n",
- "Requirement already satisfied: orbax-checkpoint in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from flax) (0.5.20)\n",
- "Requirement already satisfied: tensorstore in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from flax) (0.1.63)\n",
- "Requirement already satisfied: rich>=11.1 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from flax) (13.7.1)\n",
- "Requirement already satisfied: typing-extensions>=4.2 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from flax) (4.12.2)\n",
- "Requirement already satisfied: PyYAML>=5.4.1 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from flax) (6.0.1)\n",
- "Requirement already satisfied: absl-py>=1.4.0 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from penzai) (2.1.0)\n",
- "Requirement already satisfied: equinox>=0.11.3 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from penzai) (0.11.4)\n",
- "Requirement already satisfied: ordered_set>=4.1.0 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from penzai) (4.1.0)\n",
- "Requirement already satisfied: jaxtyping>=0.2.20 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from equinox>=0.11.3->penzai) (0.2.31)\n",
- "Requirement already satisfied: jaxlib<=0.4.31,>=0.4.30 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from jax>=0.4.27->flax) (0.4.30)\n",
- "Requirement already satisfied: ml-dtypes>=0.2.0 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from jax>=0.4.27->flax) (0.4.0)\n",
- "Requirement already satisfied: opt-einsum in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from jax>=0.4.27->flax) (3.3.0)\n",
- "Requirement already satisfied: scipy>=1.9 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from jax>=0.4.27->flax) (1.14.0)\n",
- "Requirement already satisfied: markdown-it-py>=2.2.0 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from rich>=11.1->flax) (2.2.0)\n",
- "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from rich>=11.1->flax) (2.18.0)\n",
- "Requirement already satisfied: chex>=0.1.86 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from optax->flax) (0.1.86)\n",
- "Requirement already satisfied: etils[epath,epy] in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from orbax-checkpoint->flax) (1.7.0)\n",
- "Requirement already satisfied: nest_asyncio in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from orbax-checkpoint->flax) (1.6.0)\n",
- "Requirement already satisfied: protobuf in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from orbax-checkpoint->flax) (3.20.3)\n",
- "Requirement already satisfied: toolz>=0.9.0 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from chex>=0.1.86->optax->flax) (0.12.1)\n",
- "Requirement already satisfied: typeguard==2.13.3 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from jaxtyping>=0.2.20->equinox>=0.11.3->penzai) (2.13.3)\n",
- "Requirement already satisfied: mdurl~=0.1 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from markdown-it-py>=2.2.0->rich>=11.1->flax) (0.1.2)\n",
- "Requirement already satisfied: fsspec in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from etils[epath,epy]->orbax-checkpoint->flax) (2024.6.0)\n",
- "Requirement already satisfied: importlib_resources in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from etils[epath,epy]->orbax-checkpoint->flax) (6.4.0)\n",
- "Requirement already satisfied: zipp in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from etils[epath,epy]->orbax-checkpoint->flax) (3.19.2)\n",
- "\n",
- "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m24.0\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.1.1\u001b[0m\n",
- "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"! pip install -U flax penzai"
]
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@@ -103,7 +62,7 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -136,7 +95,7 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 3,
"metadata": {},
"outputs": [
{
@@ -149,7 +108,7 @@
{
"data": {
"text/html": [
- "
(Loading...)
"
+ "(Loading...)
"
],
"text/plain": [
""
@@ -187,7 +146,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 4,
"metadata": {},
"outputs": [
{
@@ -237,13 +196,13 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "(Loading...)
"
+ "(Loading...)
"
],
"text/plain": [
""
@@ -296,13 +255,13 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "(Loading...)
"
+ "(Loading...)
"
],
"text/plain": [
""
@@ -360,7 +319,7 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 7,
"metadata": {},
"outputs": [
{
@@ -424,7 +383,7 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 8,
"metadata": {},
"outputs": [
{
@@ -437,7 +396,7 @@
{
"data": {
"text/html": [
- "(Loading...)
"
+ "(Loading...)
"
],
"text/plain": [
""
@@ -495,13 +454,13 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "(Loading...)
"
+ "(Loading...)
"
],
"text/plain": [
""
@@ -544,13 +503,13 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "(Loading...)
"
+ "(Loading...)
"
],
"text/plain": [
""
@@ -562,7 +521,7 @@
{
"data": {
"text/html": [
- "(Loading...)
"
+ "(Loading...)
"
],
"text/plain": [
""
@@ -593,7 +552,7 @@
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": 11,
"metadata": {},
"outputs": [
{
@@ -664,13 +623,13 @@
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "(Loading...)
"
+ "(Loading...)
"
],
"text/plain": [
""
@@ -682,7 +641,7 @@
{
"data": {
"text/html": [
- "(Loading...)
"
+ "(Loading...)
"
],
"text/plain": [
""
@@ -710,7 +669,7 @@
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
@@ -719,6 +678,86 @@
"# update with multiple States\n",
"nnx.update(model, params, counts)"
]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Using Modules as Pytrees\n",
+ "\n",
+ "Even though `nnx.split` and `nnx.merge` can be used to interact with any JAX\n",
+ "API, they are not always the most convenient way to do so as they introduce\n",
+ "some syntactic overhead. `Module`s and other `Object`-derived types can be\n",
+ "registered as PyTrees via the `unsafe_pytree` class argument for convenience.\n",
+ "This allows you to pass Modules directly to JAX functions without having to \n",
+ "split them first."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class Block(nnx.Module, unsafe_pytree=True): # <== 👀 unsafe_pytree\n",
+ " def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):\n",
+ " self.linear = Linear(din, dout, rngs=rngs)\n",
+ " self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)\n",
+ "\n",
+ " def __call__(self, x: jax.Array):\n",
+ " return nnx.gelu(self.dropout(self.linear(x)))\n",
+ " \n",
+ "model = Block(3, 5, rngs=nnx.Rngs(0))\n",
+ "\n",
+ "@jax.jit # regular jax.jit!\n",
+ "def forward(model: Block, x: jax.Array):\n",
+ " y = model(x)\n",
+ " return y, model # manually propagate state updates\n",
+ "\n",
+ "y, model = forward(model, jnp.ones((1, 3)))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**WARNING**: The reason the features is called `unsafe` is because NNX's \n",
+ "reference semantics are broken by JAX's referential transparency, this \n",
+ "is specially problematic when there is shared state between NNX graph nodes \n",
+ "as reference identity is lost. Use `unsafe_pytree` only when there's only \n",
+ "a single top-level object or when top-level object have no shared state\n",
+ "between them."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Before: ma.shared is mb.shared = True\n",
+ "After: ma.shared is mb.shared = False\n"
+ ]
+ }
+ ],
+ "source": [
+ "class Foo(nnx.Module, unsafe_pytree=True):\n",
+ " def __init__(self, shared):\n",
+ " self.shared = shared\n",
+ "\n",
+ "shared = nnx.Linear(3, 5, rngs=nnx.Rngs(0))\n",
+ "ma, mb = Foo(shared), Foo(shared)\n",
+ "\n",
+ "print(f'Before: {ma.shared is mb.shared = }')\n",
+ "\n",
+ "# flatten + unflatten\n",
+ "ma, mb = jax.tree.map(lambda x: x, (ma, mb))\n",
+ "\n",
+ "print(f'After: {ma.shared is mb.shared = }')"
+ ]
}
],
"metadata": {
diff --git a/docs/nnx/nnx_basics.md b/docs/nnx/nnx_basics.md
index ca838042a4..101ffdfba6 100644
--- a/docs/nnx/nnx_basics.md
+++ b/docs/nnx/nnx_basics.md
@@ -378,3 +378,54 @@ model = nnx.merge(graphdef, params, counts)
# update with multiple States
nnx.update(model, params, counts)
```
+
+## Using Modules as Pytrees
+
+Even though `nnx.split` and `nnx.merge` can be used to interact with any JAX
+API, they are not always the most convenient way to do so as they introduce
+some syntactic overhead. `Module`s and other `Object`-derived types can be
+registered as PyTrees via the `unsafe_pytree` class argument for convenience.
+This allows you to pass Modules directly to JAX functions without having to
+split them first.
+
+```{code-cell} ipython3
+class Block(nnx.Module, unsafe_pytree=True): # <== 👀 unsafe_pytree
+ def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
+ self.linear = Linear(din, dout, rngs=rngs)
+ self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
+
+ def __call__(self, x: jax.Array):
+ return nnx.gelu(self.dropout(self.linear(x)))
+
+model = Block(3, 5, rngs=nnx.Rngs(0))
+
+@jax.jit # regular jax.jit!
+def forward(model: Block, x: jax.Array):
+ y = model(x)
+ return y, model # manually propagate state updates
+
+y, model = forward(model, jnp.ones((1, 3)))
+```
+
+**WARNING**: The reason the features is called `unsafe` is because NNX's
+reference semantics are broken by JAX's referential transparency, this
+is specially problematic when there is shared state between NNX graph nodes
+as reference identity is lost. Use `unsafe_pytree` only when there's only
+a single top-level object or when top-level object have no shared state
+between them.
+
+```{code-cell} ipython3
+class Foo(nnx.Module, unsafe_pytree=True):
+ def __init__(self, shared):
+ self.shared = shared
+
+shared = nnx.Linear(3, 5, rngs=nnx.Rngs(0))
+ma, mb = Foo(shared), Foo(shared)
+
+print(f'Before: {ma.shared is mb.shared = }')
+
+# flatten + unflatten
+ma, mb = jax.tree.map(lambda x: x, (ma, mb))
+
+print(f'After: {ma.shared is mb.shared = }')
+```
diff --git a/flax/nnx/nnx/compat/module.py b/flax/nnx/nnx/compat/module.py
index 808d699daf..d9a11cd271 100644
--- a/flax/nnx/nnx/compat/module.py
+++ b/flax/nnx/nnx/compat/module.py
@@ -199,8 +199,8 @@ def is_initializing(self) -> bool:
return self._object__state._initializing
- def __init_subclass__(cls, experimental_pytree: bool = False) -> None:
- super().__init_subclass__(experimental_pytree)
+ def __init_subclass__(cls, unsafe_pytree: bool = False) -> None:
+ super().__init_subclass__(unsafe_pytree=unsafe_pytree)
cls = dataclasses.dataclass(repr=False)(cls)
diff --git a/flax/nnx/nnx/module.py b/flax/nnx/nnx/module.py
index 13292bcffb..848680190e 100644
--- a/flax/nnx/nnx/module.py
+++ b/flax/nnx/nnx/module.py
@@ -392,10 +392,18 @@ def eval(self, **attributes):
raise_if_not_found=False,
)
- def __init_subclass__(cls, experimental_pytree: bool = False) -> None:
+ def __init_subclass__(cls, unsafe_pytree: bool = False) -> None:
+ """
+ Args:
+ unsafe_pytree: If True, the Module subclass will be
+ registered as a pytree node with JAX. This breaks reference
+ semantics and should be used with caution, however it can be
+ useful to use Modules with vanillay JAX transformations. See
+ `Using Modules as PyTrees `__.
+ """
super().__init_subclass__()
- if experimental_pytree:
+ if unsafe_pytree:
jtu.register_pytree_with_keys(
cls,
partial(_module_flatten, with_keys=True),
diff --git a/flax/nnx/tests/graph_utils_test.py b/flax/nnx/tests/graph_utils_test.py
index 52ebcba756..a879b17e6d 100644
--- a/flax/nnx/tests/graph_utils_test.py
+++ b/flax/nnx/tests/graph_utils_test.py
@@ -404,7 +404,7 @@ class SimpleModule(nnx.Module):
pass
-class SimplePyTreeModule(nnx.Module, experimental_pytree=True):
+class SimplePyTreeModule(nnx.Module, unsafe_pytree=True):
pass
diff --git a/flax/nnx/tests/module_test.py b/flax/nnx/tests/module_test.py
index f627d32337..c5454fc281 100644
--- a/flax/nnx/tests/module_test.py
+++ b/flax/nnx/tests/module_test.py
@@ -477,7 +477,7 @@ def __init__(self, din, dout, *, rngs: nnx.Rngs):
class TestModulePytree:
def test_tree_map(self):
- class Foo(nnx.Module, experimental_pytree=True):
+ class Foo(nnx.Module, unsafe_pytree=True):
def __init__(self):
self.node = nnx.Param(1)
self.graphdef = 1
@@ -490,7 +490,7 @@ def __init__(self):
assert m.graphdef == 1
def test_static(self):
- class C(nnx.Module, experimental_pytree=True):
+ class C(nnx.Module, unsafe_pytree=True):
def __init__(self, x):
self.x = x
diff --git a/flax/training/checkpoints.py b/flax/training/checkpoints.py
index b4e19f7099..ab2004966d 100644
--- a/flax/training/checkpoints.py
+++ b/flax/training/checkpoints.py
@@ -31,7 +31,7 @@
)
from collections.abc import Callable, Iterable
-from etils import epath
+from etils import epath # type: ignore[import-untyped]
import jax
import orbax.checkpoint as ocp
from absl import logging
diff --git a/pyproject.toml b/pyproject.toml
index 8507626bb9..6293d9e14c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -149,6 +149,8 @@ filterwarnings = [
"ignore:.*Deprecated call to.*pkg_resources.declare_namespace.*:DeprecationWarning",
# jax.xla_computation is deprecated but TF still uses it.
"ignore:.*jax.xla_computation is deprecated.*:DeprecationWarning",
+ # FutureWarning: The key path API is deprecated and will be removed in a future version
+ "ignore:.*The key path API is deprecated and will be removed in a future version.*:FutureWarning",
]
[tool.coverage.report]