diff --git a/flax/experimental/nnx/__init__.py b/flax/experimental/nnx/__init__.py index f129c90c74..0a67c0676d 100644 --- a/flax/experimental/nnx/__init__.py +++ b/flax/experimental/nnx/__init__.py @@ -20,7 +20,7 @@ from .nnx import compatibility as compatibility from .nnx import graph_utils as graph_utils -from .nnx.errors import TraceContextError as TraceContextError +from .nnx import errors as errors from .nnx.filterlib import All as All from .nnx.filterlib import Not as Not from .nnx.graph_utils import GraphDef as GraphDef diff --git a/flax/experimental/nnx/nnx/compatibility.py b/flax/experimental/nnx/nnx/compatibility.py index e46664eac1..9622a0adda 100644 --- a/flax/experimental/nnx/nnx/compatibility.py +++ b/flax/experimental/nnx/nnx/compatibility.py @@ -65,7 +65,9 @@ def __init__( self.module = module _rngs = ( - {name: stream.key for name, stream in rngs._rngs.items()} if rngs else {} + {name: stream.key.raw_value for name, stream in rngs._rngs.items()} + if rngs + else {} ) # rename default to params if 'params' not in _rngs and 'default' in _rngs: @@ -83,7 +85,9 @@ def __call__( self, *args: Any, rngs: tp.Optional[Rngs] = None, **kwargs: Any ) -> Any: _rngs = ( - {name: stream.key for name, stream in rngs._rngs.items()} if rngs else {} + {name: stream.key.value for name, stream in rngs._rngs.items()} + if rngs + else {} ) variables = { diff --git a/flax/experimental/nnx/nnx/graph_utils.py b/flax/experimental/nnx/nnx/graph_utils.py index 2a3a18a2b5..125c40f3ec 100644 --- a/flax/experimental/nnx/nnx/graph_utils.py +++ b/flax/experimental/nnx/nnx/graph_utils.py @@ -38,8 +38,6 @@ DelayedAccessor, ) from flax.experimental.nnx.nnx.state import State, StateLeaf, is_state_leaf -from flax.experimental.nnx.nnx.rnglib import Rngs -from flax.experimental.nnx.nnx.state import State from flax.experimental.nnx.nnx.variables import EMPTY, Empty, Variable from flax.typing import Path, PathParts @@ -1098,12 +1096,6 @@ def _maybe_insert(x): # --------------------------------------------------------- -@tp.runtime_checkable -class _HasSetup(tp.Protocol): - def setup(self) -> None: - ... - - class ModuleState(reprlib.Representable): __slots__ = ('_trace_state', '_id') @@ -1127,29 +1119,16 @@ def __nnx_repr__(self): class GraphNodeMeta(ABCMeta): if not tp.TYPE_CHECKING: - def __call__(self, *args: Any, **kwargs: Any) -> Any: - return self._meta_call(*args, **kwargs) - - def _meta_call(cls: tp.Type[G], *args, **kwargs) -> G: - node = cls.__new__(cls, *args, **kwargs) - vars(node)['_graph_node__state'] = ModuleState() - node.__init__(*args, **kwargs) + def __call__(cls, *args: Any, **kwargs: Any) -> Any: + return _graph_node_meta_call(cls, *args, **kwargs) - if dataclasses.is_dataclass(node): - if isinstance(node, _HasSetup): - node.setup() - assert isinstance(node, GraphNode) +def _graph_node_meta_call(cls: tp.Type[G], *args, **kwargs) -> G: + node = cls.__new__(cls, *args, **kwargs) + vars(node)['_graph_node__state'] = ModuleState() + node.__init__(*args, **kwargs) - for field in dataclasses.fields(node): - if not field.init: - continue - value = vars(node)[field.name] - # set Rngs instances to None - if isinstance(value, Rngs): - vars(node)[field.name] = None - - return node + return node class GraphNode(reprlib.Representable, metaclass=GraphNodeMeta): @@ -1162,13 +1141,15 @@ def __setattr__(self, name: str, value: Any) -> None: self._setattr(name, value) def _setattr(self, name: str, value: tp.Any) -> None: - if not self._graph_node__state.trace_state.is_valid(): - raise errors.TraceContextError( - 'Cannot mutate GraphNode from different trace level' - ) - + self.check_valid_context( + f"Cannot mutate '{type(self).__name__}' from different trace level" + ) object.__setattr__(self, name, value) + def check_valid_context(self, error_msg: str) -> None: + if not self._graph_node__state.trace_state.is_valid(): + raise errors.TraceContextError(error_msg) + def __deepcopy__(self: G, memo=None) -> G: state, graphdef, _ = graph_utils.graph_flatten(self) graphdef = deepcopy(graphdef) diff --git a/flax/experimental/nnx/nnx/module.py b/flax/experimental/nnx/nnx/module.py index c851d3599f..b6ef989587 100644 --- a/flax/experimental/nnx/nnx/module.py +++ b/flax/experimental/nnx/nnx/module.py @@ -14,6 +14,7 @@ from __future__ import annotations +import dataclasses import typing as tp from functools import partial @@ -26,7 +27,7 @@ graph_utils, ) from flax.experimental.nnx.nnx import variables as variableslib -from flax.experimental.nnx.nnx.graph_utils import GraphDef +from flax.experimental.nnx.nnx.graph_utils import GraphDef, GraphNodeMeta from flax.experimental.nnx.nnx.proxy_caller import ( ApplyCaller, CallableProxy, @@ -47,8 +48,40 @@ tuple_reduce = lambda xs, x: xs + (x,) tuple_init = lambda: () +@tp.runtime_checkable +class _HasSetup(tp.Protocol): + def setup(self) -> None: + ... + + +class ModuleMeta(GraphNodeMeta): + if not tp.TYPE_CHECKING: + + def __call__(cls, *args: Any, **kwargs: Any) -> Any: + return _module_meta_call(cls, *args, **kwargs) + + +def _module_meta_call(cls: tp.Type[M], *args, **kwargs) -> M: + module: M = GraphNodeMeta.__call__(cls, *args, **kwargs) + + if dataclasses.is_dataclass(module): + if isinstance(module, _HasSetup): + module.setup() + + assert isinstance(module, Module) + + for field in dataclasses.fields(module): + if not field.init: + continue + value = vars(module)[field.name] + # set Rngs instances to None + if isinstance(value, Rngs): + vars(module)[field.name] = None + + return module + -class Module(graph_utils.GraphNode): +class Module(graph_utils.GraphNode, metaclass=ModuleMeta): @classmethod def init(cls: type[M], *args, **kwargs) -> tuple[State, GraphDef[M]]: return cls(*args, **kwargs).split() diff --git a/flax/experimental/nnx/nnx/nn/stochastic.py b/flax/experimental/nnx/nnx/nn/stochastic.py index 5df48aab37..3bad4e0aec 100644 --- a/flax/experimental/nnx/nnx/nn/stochastic.py +++ b/flax/experimental/nnx/nnx/nn/stochastic.py @@ -14,6 +14,7 @@ from typing import Optional, Sequence +import jax import jax.numpy as jnp from jax import lax, random @@ -46,7 +47,7 @@ def __call__( *, deterministic: Optional[bool] = None, rngs: Optional[rnglib.Rngs] = None, - ): + ) -> jax.Array: """Applies a random dropout mask to the input. Args: diff --git a/flax/experimental/nnx/nnx/rnglib.py b/flax/experimental/nnx/nnx/rnglib.py index 70307b67d4..cbce0d270e 100644 --- a/flax/experimental/nnx/nnx/rnglib.py +++ b/flax/experimental/nnx/nnx/rnglib.py @@ -32,8 +32,11 @@ import typing as tp import jax +import jax.numpy as jnp -from flax.experimental.nnx.nnx import errors, filterlib, tracers +from flax.experimental.nnx.nnx.variables import Variable +from flax.experimental.nnx.nnx import filterlib +from flax.experimental.nnx.nnx.graph_utils import GraphNode Counts = list[int] AxesValue = tp.Union[int, None] @@ -46,20 +49,31 @@ class Missing: MISSING = Missing() +class RngState(Variable[jax.Array]): + pass + @dataclasses.dataclass -class RngStream: - key: jax.Array # dynamic - count: int # static +class RngStream(GraphNode): + def __init__( + self, + key: jax.Array, + count: jax.Array, + ): + self.key = RngState(key) + self.count = RngState(count) def __post_init__(self): if not isinstance(self.key, jax.Array): raise TypeError(f'key must be a jax.Array, got {type(self.key)}') def make_rng(self) -> jax.Array: - count = self.count - self.count += 1 - return jax.random.fold_in(self.key, count) + self.check_valid_context( + "Cannot call 'make_rng' from a different trace level" + ) + key = jax.random.fold_in(self.key.value, self.count.value) + self.count.value += 1 + return key def fork(self, pattern: Pattern) -> jax.Array: if pattern is None: @@ -70,8 +84,8 @@ def fork(self, pattern: Pattern) -> jax.Array: num_splits = pattern else: num_splits = tuple(x if x is not None else 1 for x in pattern) - key = jax.random.split(self.key, num_splits) - self.count += 1 + key = jax.random.split(self.key.value, num_splits) + self.count.value += 1 return key @@ -83,9 +97,7 @@ def fork(self, pattern: Pattern) -> jax.Array: ] -class Rngs(tp.Mapping[str, tp.Callable[[], jax.Array]]): - __slots__ = ('_trace_state', '_rngs', '_counts') - +class Rngs(GraphNode, tp.Mapping[str, tp.Callable[[], jax.Array]]): def __init__( self, default: RngValue | RngDict | None = None, @@ -101,17 +113,12 @@ def __init__( self._rngs = { name: RngStream( key=jax.random.key(value) if isinstance(value, int) else value, - count=0, + count=jnp.array(0, dtype=jnp.uint32), ) for name, value in rngs.items() } - self._trace_state = tracers.TraceState() def _make_rng(self, name: str, error_type: Exception) -> jax.Array: - if not self.is_valid(): - raise errors.TraceContextError( - 'Cannot use Rngs from a different trace level' - ) if name not in self._rngs: if 'default' not in self._rngs: raise error_type(f"No RNG named {name!r} or 'default' found in Rngs.") @@ -144,8 +151,6 @@ def replace(self, **kwargs: tp.Union[int, jax.Array, RngStream]) -> 'Rngs': rngs.update(kwargs) return Rngs(**rngs) - def is_valid(self) -> bool: - return self._trace_state.is_valid() def fork( self, @@ -153,11 +158,6 @@ def fork( /, **patterns: Pattern, ) -> ForkedKeys: - if not self.is_valid(): - raise errors.TraceContextError( - 'Cannot use Rngs from a different trace level' - ) - filter_patterns: list[tuple[filterlib.Filter, Pattern]] if isinstance(_default, dict): # merge default and patterns diff --git a/flax/experimental/nnx/nnx/transforms.py b/flax/experimental/nnx/nnx/transforms.py index 716c3d1188..34f5dd9391 100644 --- a/flax/experimental/nnx/nnx/transforms.py +++ b/flax/experimental/nnx/nnx/transforms.py @@ -45,8 +45,7 @@ spmd, variables, ) -from flax.experimental.nnx.nnx.graph_utils import GraphNodeMeta -from flax.experimental.nnx.nnx.module import GraphDef, Module +from flax.experimental.nnx.nnx.module import GraphDef, Module, ModuleMeta from flax.experimental.nnx.nnx.proxy_caller import ( CallableProxy, DelayedAccessor, @@ -212,7 +211,7 @@ def get_jit_kwargs(self) -> dict[str, tp.Any]: return kwargs -class JITMeta(GraphNodeMeta): +class JITMeta(ModuleMeta): def __call__( self, module_constructor: tp.Callable[..., M], @@ -286,9 +285,6 @@ def jitted_fn( (args, kwargs), input_graph_nodes ) - if 'rngs' in kwargs: - kwargs['rngs'] = rnglib.Rngs(kwargs['rngs']) - out = f(*args, **kwargs) out, output_graph_nodes = graph_utils.extract_graph_nodes(out) @@ -312,9 +308,6 @@ def jit_apply( args: tuple[tp.Any, ...], kwargs: dict[str, tp.Any], ) -> tp.Any: - if 'rngs' in kwargs and isinstance(rngs := kwargs['rngs'], rnglib.Rngs): - kwargs['rngs'] = rngs.fork() - (args, kwargs), input_graph_nodes = graph_utils.extract_graph_nodes( (args, kwargs) ) @@ -441,24 +434,12 @@ def jit( ) jitted_fn = get_jitted_fn(f, options) - if is_init: - - @functools.wraps(f) - def jit_init_wrapper(*args, **kwargs): - _check_args(args) - jit_apply(options, jitted_fn, args, kwargs) - - wrapper = jit_init_wrapper - wrapper.inner = jitted_fn - else: - - @functools.wraps(f) - def jit_apply_wrapper(*args, **kwargs): - _check_args(args) - return jit_apply(options, jitted_fn, args, kwargs) + @functools.wraps(f) + def jit_apply_wrapper(*args, **kwargs): + return jit_apply(options, jitted_fn, args, kwargs) - wrapper = jit_apply_wrapper - wrapper.inner = jitted_fn + wrapper = jit_apply_wrapper + wrapper.inner = jitted_fn return wrapper # type: ignore @@ -478,7 +459,7 @@ class GradOptions: return_value: bool -class GradMeta(GraphNodeMeta): +class GradMeta(ModuleMeta): def __call__( self, module_constructor: tp.Callable[..., M], @@ -734,7 +715,7 @@ class ScanOptions: scan_output: bool -class ScanMeta(GraphNodeMeta): +class ScanMeta(ModuleMeta): def __call__( self, module_constructor: tp.Callable[..., M], @@ -1203,7 +1184,7 @@ def scan_apply_wrapper( # ------------------------------- -class RematMeta(GraphNodeMeta): +class RematMeta(ModuleMeta): def __call__( self, module_constructor: tp.Callable[..., M], @@ -1387,7 +1368,7 @@ class VmapOptions: vmap_metadata: tp.Mapping[str, tp.Any] -class VmapMeta(GraphNodeMeta): +class VmapMeta(ModuleMeta): def __call__( self, module_constructor: tp.Callable[..., M], diff --git a/flax/experimental/nnx/nnx/variables.py b/flax/experimental/nnx/nnx/variables.py index 0d7741f4dd..2d1589d3e7 100644 --- a/flax/experimental/nnx/nnx/variables.py +++ b/flax/experimental/nnx/nnx/variables.py @@ -227,8 +227,8 @@ def __setattr__(self, name: str, value: Any) -> None: def _setattr(self, name: str, value: tp.Any): if not self._trace_state.is_valid(): - raise ValueError( - 'Cannot mutate Variable from a different trace level' + raise nnx.errors.TraceContextError( + f'Cannot mutate {type(self).__name__} from a different trace level' ) object.__setattr__(self, name, value) diff --git a/flax/experimental/nnx/tests/test_module.py b/flax/experimental/nnx/tests/test_module.py index ae3f9ec6de..26e445401b 100644 --- a/flax/experimental/nnx/tests/test_module.py +++ b/flax/experimental/nnx/tests/test_module.py @@ -41,8 +41,8 @@ def test_trace_level(self): @jax.jit def f(): with pytest.raises( - nnx.TraceContextError, - match='Cannot mutate GraphNode from different trace level', + nnx.errors.TraceContextError, + match="Cannot mutate 'Dict' from different trace level", ): m.a = 2 diff --git a/flax/experimental/nnx/tests/test_rngs.py b/flax/experimental/nnx/tests/test_rngs.py index 805c181f3d..210b742a07 100644 --- a/flax/experimental/nnx/tests/test_rngs.py +++ b/flax/experimental/nnx/tests/test_rngs.py @@ -38,16 +38,16 @@ def test_fallback_error_no_default(self): def test_rng_stream(self): key0 = jax.random.key(0) rngs = nnx.Rngs(params=key0) - assert rngs._rngs['params'].count == 0 + assert rngs._rngs['params'].count.value == 0 key1 = rngs.params() - assert rngs._rngs['params'].count == 1 - assert rngs._rngs['params'].key is key0 + assert rngs._rngs['params'].count.value == 1 + assert rngs._rngs['params'].key.value is key0 assert not jnp.allclose(key0, key1) key2 = rngs.params() - assert rngs._rngs['params'].count == 2 - assert rngs._rngs['params'].key is key0 + assert rngs._rngs['params'].count.value == 2 + assert rngs._rngs['params'].key.value is key0 assert not jnp.allclose(key1, key2) def test_rng_fork(self): @@ -55,7 +55,7 @@ def test_rng_fork(self): rngs1 = nnx.Rngs(params=key0) rngs2 = nnx.Rngs(rngs1.fork()) - assert rngs2._rngs['params'].count == 0 + assert rngs2._rngs['params'].count.value == 0 key1 = rngs1.params() key2 = rngs2.params() @@ -68,8 +68,8 @@ def test_rng_trace_level_constraints(self): @jax.jit def f(): with pytest.raises( - nnx.TraceContextError, - match='Cannot use Rngs from a different trace level', + nnx.errors.TraceContextError, + match="Cannot call 'make_rng' from a different trace level", ): rngs.params() @@ -78,8 +78,8 @@ def f(): @jax.jit def f(): with pytest.raises( - nnx.TraceContextError, - match='Cannot use Rngs from a different trace level', + nnx.errors.TraceContextError, + match="Cannot call 'make_rng' from a different trace level", ): rngs.fork() @@ -96,8 +96,8 @@ def g(): assert isinstance(rngs1, nnx.Rngs) with pytest.raises( - nnx.TraceContextError, - match='Cannot use Rngs from a different trace level', + nnx.errors.TraceContextError, + match="Cannot call 'make_rng' from a different trace level", ): rngs1.params() @@ -181,3 +181,34 @@ def test_rng_stream_pytree(self): assert set(keys.keys()) == set(keys2.keys()) assert set(keys.splits.keys()) == set(keys2.splits.keys()) assert set(keys.broadcasts.keys()) == set(keys2.broadcasts.keys()) + + def test_jit_updates(self): + class Foo(nnx.Module): + def __init__(self, not_rngs): + rngs = not_rngs + self.linear = nnx.Linear(2, 2, rngs=rngs) + self.dropout = nnx.Dropout(0.5, deterministic=False) + + def __call__(self, x, rngs): + x = self.linear(x) + x = self.dropout(x, rngs=rngs) + return x + + rngs = nnx.Rngs(0) + m = Foo(rngs) + + # +1 for the Linear kernel, +1 for the Linear bias + assert rngs._rngs['default'].count.value == 2 + + @nnx.jit + def f(m: Foo, x: jax.Array, not_rngs: nnx.Rngs): + rngs = not_rngs + x = m(x, rngs) + x = m(x, rngs) + return x + + x = jnp.ones((2, 2)) + x = f(m, x, rngs) + + # +1 for the Dropout mask + assert rngs._rngs['default'].count.value == 4