diff --git a/docs/nnx/nnx_basics.ipynb b/docs/nnx/nnx_basics.ipynb index 1df062735c..ac7f33e1de 100644 --- a/docs/nnx/nnx_basics.ipynb +++ b/docs/nnx/nnx_basics.ipynb @@ -19,61 +19,20 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": { "tags": [ "skip-execution" ] }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Requirement already satisfied: flax in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (0.8.5)\n", - "Requirement already satisfied: penzai in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (0.1.3)\n", - "Requirement already satisfied: numpy>=1.22 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from flax) (1.26.4)\n", - "Requirement already satisfied: jax>=0.4.27 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from flax) (0.4.31.dev20240621+0428a1509)\n", - "Requirement already satisfied: msgpack in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from flax) (1.0.8)\n", - "Requirement already satisfied: optax in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from flax) (0.2.2)\n", - "Requirement already satisfied: orbax-checkpoint in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from flax) (0.5.20)\n", - "Requirement already satisfied: tensorstore in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from flax) (0.1.63)\n", - "Requirement already satisfied: rich>=11.1 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from flax) (13.7.1)\n", - "Requirement already satisfied: typing-extensions>=4.2 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from flax) (4.12.2)\n", - "Requirement already satisfied: PyYAML>=5.4.1 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from flax) (6.0.1)\n", - "Requirement already satisfied: absl-py>=1.4.0 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from penzai) 
(2.1.0)\n", - "Requirement already satisfied: equinox>=0.11.3 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from penzai) (0.11.4)\n", - "Requirement already satisfied: ordered_set>=4.1.0 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from penzai) (4.1.0)\n", - "Requirement already satisfied: jaxtyping>=0.2.20 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from equinox>=0.11.3->penzai) (0.2.31)\n", - "Requirement already satisfied: jaxlib<=0.4.31,>=0.4.30 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from jax>=0.4.27->flax) (0.4.30)\n", - "Requirement already satisfied: ml-dtypes>=0.2.0 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from jax>=0.4.27->flax) (0.4.0)\n", - "Requirement already satisfied: opt-einsum in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from jax>=0.4.27->flax) (3.3.0)\n", - "Requirement already satisfied: scipy>=1.9 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from jax>=0.4.27->flax) (1.14.0)\n", - "Requirement already satisfied: markdown-it-py>=2.2.0 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from rich>=11.1->flax) (2.2.0)\n", - "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from rich>=11.1->flax) (2.18.0)\n", - "Requirement already satisfied: chex>=0.1.86 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from optax->flax) (0.1.86)\n", - "Requirement already satisfied: etils[epath,epy] in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from orbax-checkpoint->flax) (1.7.0)\n", - "Requirement already satisfied: nest_asyncio in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from orbax-checkpoint->flax) (1.6.0)\n", - "Requirement already satisfied: protobuf in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from orbax-checkpoint->flax) (3.20.3)\n", - 
"Requirement already satisfied: toolz>=0.9.0 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from chex>=0.1.86->optax->flax) (0.12.1)\n", - "Requirement already satisfied: typeguard==2.13.3 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from jaxtyping>=0.2.20->equinox>=0.11.3->penzai) (2.13.3)\n", - "Requirement already satisfied: mdurl~=0.1 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from markdown-it-py>=2.2.0->rich>=11.1->flax) (0.1.2)\n", - "Requirement already satisfied: fsspec in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from etils[epath,epy]->orbax-checkpoint->flax) (2024.6.0)\n", - "Requirement already satisfied: importlib_resources in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from etils[epath,epy]->orbax-checkpoint->flax) (6.4.0)\n", - "Requirement already satisfied: zipp in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from etils[epath,epy]->orbax-checkpoint->flax) (3.19.2)\n", - "\n", - "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m24.0\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.1.1\u001b[0m\n", - "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n" - ] - } - ], + "outputs": [], "source": [ "! pip install -U flax penzai" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -103,7 +62,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -136,7 +95,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -149,7 +108,7 @@ { "data": { "text/html": [ - "
(Loading...)
" + "
(Loading...)
" ], "text/plain": [ "" @@ -187,7 +146,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -237,13 +196,13 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
(Loading...)
" + "
(Loading...)
" ], "text/plain": [ "" @@ -296,13 +255,13 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
(Loading...)
" + "
(Loading...)
" ], "text/plain": [ "" @@ -360,7 +319,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -424,7 +383,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -437,7 +396,7 @@ { "data": { "text/html": [ - "
(Loading...)
" + "
(Loading...)
" ], "text/plain": [ "" @@ -495,13 +454,13 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
(Loading...)
" + "
(Loading...)
" ], "text/plain": [ "" @@ -544,13 +503,13 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
(Loading...)
" + "
(Loading...)
" ], "text/plain": [ "" @@ -562,7 +521,7 @@ { "data": { "text/html": [ - "
(Loading...)
" + "
(Loading...)
" ], "text/plain": [ "" @@ -593,7 +552,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -664,13 +623,13 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
(Loading...)
" + "
(Loading...)
" ], "text/plain": [ "" @@ -682,7 +641,7 @@ { "data": { "text/html": [ - "
(Loading...)
" + "
(Loading...)
" ], "text/plain": [ "" @@ -710,7 +669,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -719,6 +678,86 @@ "# update with multiple States\n", "nnx.update(model, params, counts)" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Using Modules as Pytrees\n", + "\n", + "Even though `nnx.split` and `nnx.merge` can be used to interact with any JAX\n", + "API, they are not always the most convenient way to do so as they introduce\n", + "some syntactic overhead. `Module`s and other `Object`-derived types can be\n", + "registered as PyTrees via the `unsafe_pytree` class argument for convenience.\n", + "This allows you to pass Modules directly to JAX functions without having to \n", + "split them first." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "class Block(nnx.Module, unsafe_pytree=True): # <== 👀 unsafe_pytree\n", + " def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):\n", + " self.linear = Linear(din, dout, rngs=rngs)\n", + " self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)\n", + "\n", + " def __call__(self, x: jax.Array):\n", + " return nnx.gelu(self.dropout(self.linear(x)))\n", + " \n", + "model = Block(3, 5, rngs=nnx.Rngs(0))\n", + "\n", + "@jax.jit # regular jax.jit!\n", + "def forward(model: Block, x: jax.Array):\n", + " y = model(x)\n", + " return y, model # manually propagate state updates\n", + "\n", + "y, model = forward(model, jnp.ones((1, 3)))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**WARNING**: The reason the feature is called `unsafe` is because NNX's \n", + "reference semantics are broken by JAX's referential transparency, this \n", + "is especially problematic when there is shared state between NNX graph nodes \n", + "as reference identity is lost. 
Use `unsafe_pytree` only when there's only \n", + "a single top-level object or when top-level objects have no shared state\n", + "between them." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Before: ma.shared is mb.shared = True\n", + "After: ma.shared is mb.shared = False\n" + ] + } + ], + "source": [ + "class Foo(nnx.Module, unsafe_pytree=True):\n", + " def __init__(self, shared):\n", + " self.shared = shared\n", + "\n", + "shared = nnx.Linear(3, 5, rngs=nnx.Rngs(0))\n", + "ma, mb = Foo(shared), Foo(shared)\n", + "\n", + "print(f'Before: {ma.shared is mb.shared = }')\n", + "\n", + "# flatten + unflatten\n", + "ma, mb = jax.tree.map(lambda x: x, (ma, mb))\n", + "\n", + "print(f'After: {ma.shared is mb.shared = }')" + ] } ], "metadata": { diff --git a/docs/nnx/nnx_basics.md b/docs/nnx/nnx_basics.md index ca838042a4..101ffdfba6 100644 --- a/docs/nnx/nnx_basics.md +++ b/docs/nnx/nnx_basics.md @@ -378,3 +378,54 @@ model = nnx.merge(graphdef, params, counts) # update with multiple States nnx.update(model, params, counts) ``` + +## Using Modules as Pytrees + +Even though `nnx.split` and `nnx.merge` can be used to interact with any JAX +API, they are not always the most convenient way to do so as they introduce +some syntactic overhead. `Module`s and other `Object`-derived types can be +registered as PyTrees via the `unsafe_pytree` class argument for convenience. +This allows you to pass Modules directly to JAX functions without having to +split them first. + +```{code-cell} ipython3 +class Block(nnx.Module, unsafe_pytree=True): # <== 👀 unsafe_pytree + def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): + self.linear = Linear(din, dout, rngs=rngs) + self.dropout = nnx.Dropout(rate=0.1, rngs=rngs) + + def __call__(self, x: jax.Array): + return nnx.gelu(self.dropout(self.linear(x))) + +model = Block(3, 5, rngs=nnx.Rngs(0)) + +@jax.jit # regular jax.jit! 
+def forward(model: Block, x: jax.Array): + y = model(x) + return y, model # manually propagate state updates + +y, model = forward(model, jnp.ones((1, 3))) +``` + +**WARNING**: The reason the feature is called `unsafe` is because NNX's +reference semantics are broken by JAX's referential transparency, this +is especially problematic when there is shared state between NNX graph nodes +as reference identity is lost. Use `unsafe_pytree` only when there's only +a single top-level object or when top-level objects have no shared state +between them. + +```{code-cell} ipython3 +class Foo(nnx.Module, unsafe_pytree=True): + def __init__(self, shared): + self.shared = shared + +shared = nnx.Linear(3, 5, rngs=nnx.Rngs(0)) +ma, mb = Foo(shared), Foo(shared) + +print(f'Before: {ma.shared is mb.shared = }') + +# flatten + unflatten +ma, mb = jax.tree.map(lambda x: x, (ma, mb)) + +print(f'After: {ma.shared is mb.shared = }') +``` diff --git a/flax/nnx/nnx/compat/module.py b/flax/nnx/nnx/compat/module.py index 808d699daf..d9a11cd271 100644 --- a/flax/nnx/nnx/compat/module.py +++ b/flax/nnx/nnx/compat/module.py @@ -199,8 +199,8 @@ def is_initializing(self) -> bool: return self._object__state._initializing - def __init_subclass__(cls, experimental_pytree: bool = False) -> None: - super().__init_subclass__(experimental_pytree) + def __init_subclass__(cls, unsafe_pytree: bool = False) -> None: + super().__init_subclass__(unsafe_pytree=unsafe_pytree) cls = dataclasses.dataclass(repr=False)(cls) diff --git a/flax/nnx/nnx/module.py b/flax/nnx/nnx/module.py index 13292bcffb..848680190e 100644 --- a/flax/nnx/nnx/module.py +++ b/flax/nnx/nnx/module.py @@ -392,10 +392,18 @@ def eval(self, **attributes): raise_if_not_found=False, ) - def __init_subclass__(cls, experimental_pytree: bool = False) -> None: + def __init_subclass__(cls, unsafe_pytree: bool = False) -> None: + """ + Args: + unsafe_pytree: If True, the Module subclass will be + registered as a pytree node with JAX. 
This breaks reference + semantics and should be used with caution, however it can be + useful to use Modules with vanilla JAX transformations. See + `Using Modules as PyTrees `__. + """ super().__init_subclass__() - if experimental_pytree: + if unsafe_pytree: jtu.register_pytree_with_keys( cls, partial(_module_flatten, with_keys=True), diff --git a/flax/nnx/tests/graph_utils_test.py b/flax/nnx/tests/graph_utils_test.py index 52ebcba756..a879b17e6d 100644 --- a/flax/nnx/tests/graph_utils_test.py +++ b/flax/nnx/tests/graph_utils_test.py @@ -404,7 +404,7 @@ class SimpleModule(nnx.Module): pass -class SimplePyTreeModule(nnx.Module, experimental_pytree=True): +class SimplePyTreeModule(nnx.Module, unsafe_pytree=True): pass diff --git a/flax/nnx/tests/module_test.py b/flax/nnx/tests/module_test.py index f627d32337..c5454fc281 100644 --- a/flax/nnx/tests/module_test.py +++ b/flax/nnx/tests/module_test.py @@ -477,7 +477,7 @@ def __init__(self, din, dout, *, rngs: nnx.Rngs): class TestModulePytree: def test_tree_map(self): - class Foo(nnx.Module, experimental_pytree=True): + class Foo(nnx.Module, unsafe_pytree=True): def __init__(self): self.node = nnx.Param(1) self.graphdef = 1 @@ -490,7 +490,7 @@ def __init__(self): assert m.graphdef == 1 def test_static(self): - class C(nnx.Module, experimental_pytree=True): + class C(nnx.Module, unsafe_pytree=True): def __init__(self, x): self.x = x diff --git a/flax/training/checkpoints.py b/flax/training/checkpoints.py index b4e19f7099..ab2004966d 100644 --- a/flax/training/checkpoints.py +++ b/flax/training/checkpoints.py @@ -31,7 +31,7 @@ ) from collections.abc import Callable, Iterable -from etils import epath +from etils import epath # type: ignore[import-untyped] import jax import orbax.checkpoint as ocp from absl import logging diff --git a/pyproject.toml b/pyproject.toml index 8507626bb9..6293d9e14c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -149,6 +149,8 @@ filterwarnings = [ "ignore:.*Deprecated call 
to.*pkg_resources.declare_namespace.*:DeprecationWarning", # jax.xla_computation is deprecated but TF still uses it. "ignore:.*jax.xla_computation is deprecated.*:DeprecationWarning", + # FutureWarning: The key path API is deprecated and will be removed in a future version + "ignore:.*The key path API is deprecated and will be removed in a future version.*:FutureWarning", ] [tool.coverage.report]