diff --git a/flax/experimental/nnx/__init__.py b/flax/experimental/nnx/__init__.py index 9446c62484..a7b0919108 100644 --- a/flax/experimental/nnx/__init__.py +++ b/flax/experimental/nnx/__init__.py @@ -29,9 +29,9 @@ from .nnx.dataclasses import ( dataclass as dataclass, field as field, - node_field as node_field, + treenode_field as treenode_field, param_field as param_field, - var_field as var_field, + variable_field as variable_field, ) from .nnx.errors import TraceContextError as TraceContextError from .nnx.helpers import ( diff --git a/flax/experimental/nnx/nnx/dataclasses.py b/flax/experimental/nnx/nnx/dataclasses.py index 4bf780a0f6..a888f5b616 100644 --- a/flax/experimental/nnx/nnx/dataclasses.py +++ b/flax/experimental/nnx/nnx/dataclasses.py @@ -17,7 +17,7 @@ import typing_extensions as tpe -from flax.experimental.nnx.nnx import variables +from flax.experimental.nnx.nnx import pytreelib, variables from flax.experimental import nnx A = tp.TypeVar("A") @@ -44,7 +44,7 @@ def field( ) -def node_field( +def treenode_field( *, default: tp.Any = dataclasses.MISSING, default_factory: tp.Any = dataclasses.MISSING, @@ -59,10 +59,10 @@ def node_field( else: metadata = dict(metadata) - if "nnx_box_fn" in metadata: - raise ValueError("'nnx_box_fn' found in metadata") + if "nnx_variable_constructor" in metadata: + raise ValueError("'nnx_variable_constructor' found in metadata") - metadata["nnx_box_fn"] = lambda value: variables.Variable(value) + metadata["nnx_variable_constructor"] = lambda value: pytreelib.TreeNode(value) return field( default=default, @@ -75,7 +75,7 @@ def node_field( ) -def var_field( +def variable_field( variable_type: tp.Type[variables.Variable[tp.Any]], *, default: tp.Any = dataclasses.MISSING, @@ -91,10 +91,10 @@ def var_field( else: metadata = dict(metadata) - if "nnx_box_fn" in metadata: - raise ValueError("'nnx_box_fn' found in metadata") + if "nnx_variable_constructor" in metadata: + raise ValueError("'nnx_variable_constructor' found in metadata") - metadata["nnx_box_fn"] = lambda value: variable_type(value) + metadata["nnx_variable_constructor"] = lambda value: variable_type(value) return field( default=default, @@ -117,7 +117,7 @@ def param_field( compare: bool = True, metadata: tp.Optional[tp.Mapping[str, tp.Any]] = None, ) -> tp.Any: - return var_field( + return variable_field( variables.Param, default=default, default_factory=default_factory, @@ -150,8 +150,8 @@ def dataclass( @tpe.dataclass_transform( field_specifiers=( field, - node_field, - var_field, + treenode_field, + variable_field, param_field, ) ) diff --git a/flax/experimental/nnx/nnx/module.py b/flax/experimental/nnx/nnx/module.py index 4472bd3b1f..2f91583489 100644 --- a/flax/experimental/nnx/nnx/module.py +++ b/flax/experimental/nnx/nnx/module.py @@ -24,6 +24,7 @@ import numpy as np from flax.experimental.nnx.nnx import errors, ids, partitioning, reprlib, tracers, variables as variableslib +from flax.experimental.nnx.nnx.contextlib import Context from flax.experimental.nnx.nnx.variables import Variable, Sharding from flax.experimental.nnx.nnx.state import State @@ -45,6 +46,13 @@ def __call__(self, accessor: "DelayedAccessor", /, *args, **kwargs) -> tp.Any: ... +@tp.runtime_checkable +class _HasSetup(tp.Protocol): + + def setup(self) -> None: + ... + + @dataclasses.dataclass class CallableProxy: _proxy_context: _ProxyContext @@ -273,21 +281,29 @@ class ModuleMeta(ABCMeta): def __call__(self, *args: Any, **kwargs: Any) -> Any: return self._meta_call(*args, **kwargs) - def _meta_call(self: tp.Type[M], *args, **kwargs) -> M: - module = self.__new__(self, *args, **kwargs) + def _meta_call(cls: tp.Type[M], *args, **kwargs) -> M: + module = cls.__new__(cls, *args, **kwargs) vars(module)["_module__state"] = ModuleState() module.__init__(*args, **kwargs) if dataclasses.is_dataclass(module): + if isinstance(module, _HasSetup): + module.setup() + assert isinstance(module, Module) + for field in dataclasses.fields(module): - if "nnx_box_fn" not in field.metadata: + value = vars(module)[field.name] + # set Context instances to None + if isinstance(value, Context): + vars(module)[field.name] = None continue - container_fn = field.metadata["nnx_box_fn"] - value = vars(module)[field.name] - value = container_fn(value) - vars(module)[field.name] = value + if "nnx_variable_constructor" not in field.metadata: + continue + + variable_constructor = field.metadata["nnx_variable_constructor"] + vars(module)[field.name] = variable_constructor(value) return module diff --git a/flax/experimental/nnx/nnx/pytreelib.py b/flax/experimental/nnx/nnx/pytreelib.py index 330f04ad26..d021b42908 100644 --- a/flax/experimental/nnx/nnx/pytreelib.py +++ b/flax/experimental/nnx/nnx/pytreelib.py @@ -70,10 +70,10 @@ def call(cls: tp.Type[P], *args: tp.Any, **kwargs: tp.Any) -> P: if dataclasses.is_dataclass(obj): assert isinstance(obj, Pytree) for field in dataclasses.fields(obj): - if "nnx_box_fn" not in field.metadata: + if "nnx_variable_constructor" not in field.metadata: continue - container_fn = field.metadata["nnx_box_fn"] + container_fn = field.metadata["nnx_variable_constructor"] value = vars(obj)[field.name] value = container_fn(value) vars(obj)[field.name] = value diff --git a/flax/experimental/nnx/tests/test_module.py b/flax/experimental/nnx/tests/test_module.py index 48b2c577a2..9c80293c76 100644 --- a/flax/experimental/nnx/tests/test_module.py +++ b/flax/experimental/nnx/tests/test_module.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import dataclasses from typing import Any, TypeVar import jax @@ -477,9 +478,9 @@ def test_basic(self): @nnx.dataclass class Foo(nnx.Module): a: int - b: int = nnx.node_field() + b: int = nnx.treenode_field() c: int = nnx.param_field() - d: int = nnx.var_field(nnx.BatchStat) + d: int = nnx.variable_field(nnx.BatchStat) e: int f: int @@ -495,7 +496,7 @@ class Foo(nnx.Module): state, moduledef = m.split() assert len(state) == 4 - assert state.variables["b"] == nnx.Variable(2) + assert state.variables["b"] == nnx.TreeNode(2) assert state.variables["c"] == nnx.Param(3) assert state.variables["d"] == nnx.BatchStat(4) assert state.variables["f"] == nnx.Variable(6) @@ -503,12 +504,48 @@ class Foo(nnx.Module): def test_no_override(self): @nnx.dataclass class Foo(nnx.Module): - a: int = nnx.node_field() + a: int = nnx.treenode_field() with pytest.raises(ValueError, match="is not compatible with return type"): _m = Foo(a=nnx.Param(1)) - _m = Foo(a=nnx.Variable(1)) + _m = Foo(a=nnx.TreeNode(1)) + + def test_context_none_after_init(self): + @dataclasses.dataclass + class DFoo(nnx.Module): + din: int + dout: int + ctx: nnx.Context + + def __post_init__(self): + self.bar = nnx.Linear(self.din, self.dout, ctx=self.ctx) + + def __call__(self, x): + return self.bar(x) + + m = DFoo(1, 1, ctx=nnx.context(0)) + + assert hasattr(m, "bar") + assert m.ctx is None + + def test_setup_is_called(self): + @dataclasses.dataclass + class DFoo(nnx.Module): + din: int + dout: int + ctx: nnx.Context + + def setup(self): + self.bar = nnx.Linear(self.din, self.dout, ctx=self.ctx) + + def __call__(self, x): + return self.bar(x) + + m = DFoo(1, 1, ctx=nnx.context(0)) + + assert hasattr(m, "bar") + assert m.ctx is None class TestModuleDef: diff --git a/flax/experimental/nnx/tests/test_pytree.py b/flax/experimental/nnx/tests/test_pytree.py index 1071e73138..1b2fd01629 100644 --- a/flax/experimental/nnx/tests/test_pytree.py +++ b/flax/experimental/nnx/tests/test_pytree.py @@ -51,7 +51,7 @@ def __init__(self, y) -> None: def test_immutable_pytree_dataclass(self): @nnx.dataclass(frozen=True) class Foo(nnx.Pytree): - y: int = nnx.node_field() + y: int = nnx.treenode_field() x: int = nnx.field(default=2) pytree = Foo(y=3) @@ -73,7 +73,7 @@ class Foo(nnx.Pytree): def test_jit(self): @nnx.dataclass class Foo(nnx.Pytree): - a: int = nnx.node_field() + a: int = nnx.treenode_field() b: int = nnx.field() module = Foo(a=1, b=2) @@ -94,7 +94,7 @@ def __init__(self, a, b): @nnx.dataclass class Foo(nnx.Pytree): bar: Bar - c: int = nnx.node_field() + c: int = nnx.treenode_field() d: int = nnx.field() foo: Foo = Foo(bar=Bar(a=1, b=2), c=3, d=4) @@ -140,14 +140,14 @@ def __init__(self, x: T): def test_key_paths(self): @nnx.dataclass class Bar(nnx.Pytree): - a: int = nnx.node_field(default=1) + a: int = nnx.treenode_field(default=1) b: int = nnx.field(default=2) @nnx.dataclass class Foo(nnx.Pytree): - x: int = nnx.node_field(default=3) + x: int = nnx.treenode_field(default=3) y: int = nnx.field(default=4) - z: Bar = nnx.node_field(default_factory=Bar) + z: Bar = nnx.treenode_field(default_factory=Bar) foo = Foo() @@ -167,12 +167,12 @@ class Foo(nnx.Pytree): def test_dataclass_inheritance(self): @nnx.dataclass class A(nnx.Pytree): - a: int = nnx.node_field(default=1) + a: int = nnx.treenode_field(default=1) b: int = nnx.field(default=2) @nnx.dataclass class B(A): - c: int = nnx.node_field(default=3) + c: int = nnx.treenode_field(default=3) pytree = B() leaves = jax.tree_util.tree_leaves(pytree) @@ -252,7 +252,7 @@ def __init__(self, x): def test_pytree_dataclass(self): @nnx.dataclass class Foo(nnx.Pytree, mutable=True): - y: int = nnx.node_field() + y: int = nnx.treenode_field() x: int = nnx.field(default=2) pytree: Foo = Foo(y=3)