Skip to content

Commit

Permalink
[nnx] stabilize unsafe_pytree
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Jul 5, 2024
1 parent 0fb1777 commit d2ca0ef
Show file tree
Hide file tree
Showing 8 changed files with 173 additions and 73 deletions.
169 changes: 104 additions & 65 deletions docs/nnx/nnx_basics.ipynb

Large diffs are not rendered by default.

51 changes: 51 additions & 0 deletions docs/nnx/nnx_basics.md
Original file line number Diff line number Diff line change
Expand Up @@ -378,3 +378,54 @@ model = nnx.merge(graphdef, params, counts)
# update with multiple States
nnx.update(model, params, counts)
```

## Using Modules as Pytrees

Even though `nnx.split` and `nnx.merge` can be used to interact with any JAX
API, they are not always the most convenient way to do so as they introduce
some syntactic overhead. `Module`s and other `Object`-derived types can be
registered as PyTrees via the `unsafe_pytree` class argument for convenience.
This allows you to pass Modules directly to JAX functions without having to
split them first.

```{code-cell} ipython3
class Block(nnx.Module, unsafe_pytree=True): # <== 👀 unsafe_pytree
def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
self.linear = Linear(din, dout, rngs=rngs)
self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
def __call__(self, x: jax.Array):
return nnx.gelu(self.dropout(self.linear(x)))
model = Block(3, 5, rngs=nnx.Rngs(0))
@jax.jit # regular jax.jit!
def forward(model: Block, x: jax.Array):
y = model(x)
return y, model # manually propagate state updates
y, model = forward(model, jnp.ones((1, 3)))
```

**WARNING**: The reason the features is called `unsafe` is because NNX's
reference semantics are broken by JAX's referential transparency, this
is specially problematic when there is shared state between NNX graph nodes
as reference identity is lost. Use `unsafe_pytree` only when there's only
a single top-level object or when top-level object have no shared state
between them.

```{code-cell} ipython3
class Foo(nnx.Module, unsafe_pytree=True):
def __init__(self, shared):
self.shared = shared
shared = nnx.Linear(3, 5, rngs=nnx.Rngs(0))
ma, mb = Foo(shared), Foo(shared)
print(f'Before: {ma.shared is mb.shared = }')
# flatten + unflatten
ma, mb = jax.tree.map(lambda x: x, (ma, mb))
print(f'After: {ma.shared is mb.shared = }')
```
4 changes: 2 additions & 2 deletions flax/nnx/nnx/compat/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,8 @@ def is_initializing(self) -> bool:

return self._object__state._initializing

def __init_subclass__(cls, experimental_pytree: bool = False) -> None:
super().__init_subclass__(experimental_pytree)
def __init_subclass__(cls, unsafe_pytree: bool = False) -> None:
super().__init_subclass__(unsafe_pytree=unsafe_pytree)

cls = dataclasses.dataclass(repr=False)(cls)

Expand Down
12 changes: 10 additions & 2 deletions flax/nnx/nnx/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,10 +392,18 @@ def eval(self, **attributes):
raise_if_not_found=False,
)

def __init_subclass__(cls, experimental_pytree: bool = False) -> None:
def __init_subclass__(cls, unsafe_pytree: bool = False) -> None:
"""
Args:
unsafe_pytree: If True, the Module subclass will be
registered as a pytree node with JAX. This breaks reference
semantics and should be used with caution, however it can be
useful to use Modules with vanillay JAX transformations. See
`Using Modules as PyTrees <https://flax.readthedocs.io/en/latest/nnx/nnx_basics.html#using-modules-as-pytrees>`__.
"""
super().__init_subclass__()

if experimental_pytree:
if unsafe_pytree:
jtu.register_pytree_with_keys(
cls,
partial(_module_flatten, with_keys=True),
Expand Down
2 changes: 1 addition & 1 deletion flax/nnx/tests/graph_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ class SimpleModule(nnx.Module):
pass


class SimplePyTreeModule(nnx.Module, experimental_pytree=True):
class SimplePyTreeModule(nnx.Module, unsafe_pytree=True):
pass


Expand Down
4 changes: 2 additions & 2 deletions flax/nnx/tests/module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ def __init__(self, din, dout, *, rngs: nnx.Rngs):

class TestModulePytree:
def test_tree_map(self):
class Foo(nnx.Module, experimental_pytree=True):
class Foo(nnx.Module, unsafe_pytree=True):
def __init__(self):
self.node = nnx.Param(1)
self.graphdef = 1
Expand All @@ -490,7 +490,7 @@ def __init__(self):
assert m.graphdef == 1

def test_static(self):
class C(nnx.Module, experimental_pytree=True):
class C(nnx.Module, unsafe_pytree=True):
def __init__(self, x):
self.x = x

Expand Down
2 changes: 1 addition & 1 deletion flax/training/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
)
from collections.abc import Callable, Iterable

from etils import epath
from etils import epath # type: ignore[import-untyped]
import jax
import orbax.checkpoint as ocp
from absl import logging
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ filterwarnings = [
"ignore:.*Deprecated call to.*pkg_resources.declare_namespace.*:DeprecationWarning",
# jax.xla_computation is deprecated but TF still uses it.
"ignore:.*jax.xla_computation is deprecated.*:DeprecationWarning",
# FutureWarning: The key path API is deprecated and will be removed in a future version
"ignore:.*The key path API is deprecated and will be removed in a future version.*:FutureWarning",
]

[tool.coverage.report]
Expand Down

0 comments on commit d2ca0ef

Please sign in to comment.