Skip to content

Commit

Permalink
improve Module dataclass support
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Oct 26, 2023
1 parent bee3fb9 commit c9bc86b
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 37 deletions.
4 changes: 2 additions & 2 deletions flax/experimental/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@
from .nnx.dataclasses import (
dataclass as dataclass,
field as field,
node_field as node_field,
treenode_field as treenode_field,
param_field as param_field,
var_field as var_field,
variable_field as variable_field,
)
from .nnx.errors import TraceContextError as TraceContextError
from .nnx.helpers import (
Expand Down
24 changes: 12 additions & 12 deletions flax/experimental/nnx/nnx/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import typing_extensions as tpe

from flax.experimental.nnx.nnx import variables
from flax.experimental.nnx.nnx import pytreelib, variables
from flax.experimental import nnx

A = tp.TypeVar("A")
Expand All @@ -44,7 +44,7 @@ def field(
)


def node_field(
def treenode_field(
*,
default: tp.Any = dataclasses.MISSING,
default_factory: tp.Any = dataclasses.MISSING,
Expand All @@ -59,10 +59,10 @@ def node_field(
else:
metadata = dict(metadata)

if "nnx_box_fn" in metadata:
raise ValueError("'nnx_box_fn' found in metadata")
if "nnx_variable_constructor" in metadata:
raise ValueError("'nnx_variable_constructor' found in metadata")

metadata["nnx_box_fn"] = lambda value: variables.Variable(value)
metadata["nnx_variable_constructor"] = lambda value: pytreelib.TreeNode(value)

return field(
default=default,
Expand All @@ -75,7 +75,7 @@ def node_field(
)


def var_field(
def variable_field(
variable_type: tp.Type[variables.Variable[tp.Any]],
*,
default: tp.Any = dataclasses.MISSING,
Expand All @@ -91,10 +91,10 @@ def var_field(
else:
metadata = dict(metadata)

if "nnx_box_fn" in metadata:
raise ValueError("'nnx_box_fn' found in metadata")
if "nnx_variable_constructor" in metadata:
raise ValueError("'nnx_variable_constructor' found in metadata")

metadata["nnx_box_fn"] = lambda value: variable_type(value)
metadata["nnx_variable_constructor"] = lambda value: variable_type(value)

return field(
default=default,
Expand All @@ -117,7 +117,7 @@ def param_field(
compare: bool = True,
metadata: tp.Optional[tp.Mapping[str, tp.Any]] = None,
) -> tp.Any:
return var_field(
return variable_field(
variables.Param,
default=default,
default_factory=default_factory,
Expand Down Expand Up @@ -150,8 +150,8 @@ def dataclass(
@tpe.dataclass_transform(
field_specifiers=(
field,
node_field,
var_field,
treenode_field,
variable_field,
param_field,
)
)
Expand Down
30 changes: 23 additions & 7 deletions flax/experimental/nnx/nnx/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import numpy as np

from flax.experimental.nnx.nnx import errors, ids, partitioning, reprlib, tracers, variables as variableslib
from flax.experimental.nnx.nnx.contextlib import Context
from flax.experimental.nnx.nnx.variables import Variable, Sharding
from flax.experimental.nnx.nnx.state import State

Expand All @@ -45,6 +46,13 @@ def __call__(self, accessor: "DelayedAccessor", /, *args, **kwargs) -> tp.Any:
...


@tp.runtime_checkable
class _HasSetup(tp.Protocol):

def setup(self) -> None:
...


@dataclasses.dataclass
class CallableProxy:
_proxy_context: _ProxyContext
Expand Down Expand Up @@ -273,21 +281,29 @@ class ModuleMeta(ABCMeta):
def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self._meta_call(*args, **kwargs)

def _meta_call(self: tp.Type[M], *args, **kwargs) -> M:
module = self.__new__(self, *args, **kwargs)
def _meta_call(cls: tp.Type[M], *args, **kwargs) -> M:
module = cls.__new__(cls, *args, **kwargs)
vars(module)["_module__state"] = ModuleState()
module.__init__(*args, **kwargs)

if dataclasses.is_dataclass(module):
if isinstance(module, _HasSetup):
module.setup()

assert isinstance(module, Module)

for field in dataclasses.fields(module):
if "nnx_box_fn" not in field.metadata:
value = vars(module)[field.name]
# set Context instances to None
if isinstance(value, Context):
vars(module)[field.name] = None
continue

container_fn = field.metadata["nnx_box_fn"]
value = vars(module)[field.name]
value = container_fn(value)
vars(module)[field.name] = value
if "nnx_variable_constructor" not in field.metadata:
continue

variable_constructor = field.metadata["nnx_variable_constructor"]
vars(module)[field.name] = variable_constructor(value)

return module

Expand Down
4 changes: 2 additions & 2 deletions flax/experimental/nnx/nnx/pytreelib.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,10 @@ def call(cls: tp.Type[P], *args: tp.Any, **kwargs: tp.Any) -> P:
if dataclasses.is_dataclass(obj):
assert isinstance(obj, Pytree)
for field in dataclasses.fields(obj):
if "nnx_box_fn" not in field.metadata:
if "nnx_variable_constructor" not in field.metadata:
continue

container_fn = field.metadata["nnx_box_fn"]
container_fn = field.metadata["nnx_variable_constructor"]
value = vars(obj)[field.name]
value = container_fn(value)
vars(obj)[field.name] = value
Expand Down
47 changes: 42 additions & 5 deletions flax/experimental/nnx/tests/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
from typing import Any, TypeVar

import jax
Expand Down Expand Up @@ -477,9 +478,9 @@ def test_basic(self):
@nnx.dataclass
class Foo(nnx.Module):
a: int
b: int = nnx.node_field()
b: int = nnx.treenode_field()
c: int = nnx.param_field()
d: int = nnx.var_field(nnx.BatchStat)
d: int = nnx.variable_field(nnx.BatchStat)
e: int
f: int

Expand All @@ -495,20 +496,56 @@ class Foo(nnx.Module):
state, moduledef = m.split()

assert len(state) == 4
assert state.variables["b"] == nnx.Variable(2)
assert state.variables["b"] == nnx.TreeNode(2)
assert state.variables["c"] == nnx.Param(3)
assert state.variables["d"] == nnx.BatchStat(4)
assert state.variables["f"] == nnx.Variable(6)

def test_no_override(self):
@nnx.dataclass
class Foo(nnx.Module):
a: int = nnx.node_field()
a: int = nnx.treenode_field()

with pytest.raises(ValueError, match="is not compatible with return type"):
_m = Foo(a=nnx.Param(1))

_m = Foo(a=nnx.Variable(1))
_m = Foo(a=nnx.TreeNode(1))

def test_context_none_after_init(self):
@dataclasses.dataclass
class DFoo(nnx.Module):
din: int
dout: int
ctx: nnx.Context

def __post_init__(self):
self.bar = nnx.Linear(self.din, self.dout, ctx=self.ctx)

def __call__(self, x):
return self.bar(x)

m = DFoo(1, 1, ctx=nnx.context(0))

assert hasattr(m, "bar")
assert m.ctx is None

def test_setup_is_called(self):
@dataclasses.dataclass
class DFoo(nnx.Module):
din: int
dout: int
ctx: nnx.Context

def setup(self):
self.bar = nnx.Linear(self.din, self.dout, ctx=self.ctx)

def __call__(self, x):
return self.bar(x)

m = DFoo(1, 1, ctx=nnx.context(0))

assert hasattr(m, "bar")
assert m.ctx is None


class TestModuleDef:
Expand Down
18 changes: 9 additions & 9 deletions flax/experimental/nnx/tests/test_pytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(self, y) -> None:
def test_immutable_pytree_dataclass(self):
@nnx.dataclass(frozen=True)
class Foo(nnx.Pytree):
y: int = nnx.node_field()
y: int = nnx.treenode_field()
x: int = nnx.field(default=2)

pytree = Foo(y=3)
Expand All @@ -73,7 +73,7 @@ class Foo(nnx.Pytree):
def test_jit(self):
@nnx.dataclass
class Foo(nnx.Pytree):
a: int = nnx.node_field()
a: int = nnx.treenode_field()
b: int = nnx.field()

module = Foo(a=1, b=2)
Expand All @@ -94,7 +94,7 @@ def __init__(self, a, b):
@nnx.dataclass
class Foo(nnx.Pytree):
bar: Bar
c: int = nnx.node_field()
c: int = nnx.treenode_field()
d: int = nnx.field()

foo: Foo = Foo(bar=Bar(a=1, b=2), c=3, d=4)
Expand Down Expand Up @@ -140,14 +140,14 @@ def __init__(self, x: T):
def test_key_paths(self):
@nnx.dataclass
class Bar(nnx.Pytree):
a: int = nnx.node_field(default=1)
a: int = nnx.treenode_field(default=1)
b: int = nnx.field(default=2)

@nnx.dataclass
class Foo(nnx.Pytree):
x: int = nnx.node_field(default=3)
x: int = nnx.treenode_field(default=3)
y: int = nnx.field(default=4)
z: Bar = nnx.node_field(default_factory=Bar)
z: Bar = nnx.treenode_field(default_factory=Bar)

foo = Foo()

Expand All @@ -167,12 +167,12 @@ class Foo(nnx.Pytree):
def test_dataclass_inheritance(self):
@nnx.dataclass
class A(nnx.Pytree):
a: int = nnx.node_field(default=1)
a: int = nnx.treenode_field(default=1)
b: int = nnx.field(default=2)

@nnx.dataclass
class B(A):
c: int = nnx.node_field(default=3)
c: int = nnx.treenode_field(default=3)

pytree = B()
leaves = jax.tree_util.tree_leaves(pytree)
Expand Down Expand Up @@ -252,7 +252,7 @@ def __init__(self, x):
def test_pytree_dataclass(self):
@nnx.dataclass
class Foo(nnx.Pytree, mutable=True):
y: int = nnx.node_field()
y: int = nnx.treenode_field()
x: int = nnx.field(default=2)

pytree: Foo = Foo(y=3)
Expand Down

0 comments on commit c9bc86b

Please sign in to comment.