Skip to content

Commit

Permalink
[nnx] Rngs and RngStream inherit from GraphNode
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Mar 30, 2024
1 parent 514c111 commit 453f38b
Show file tree
Hide file tree
Showing 10 changed files with 141 additions and 110 deletions.
2 changes: 1 addition & 1 deletion flax/experimental/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from .nnx import compatibility as compatibility
from .nnx import graph_utils as graph_utils
from .nnx.errors import TraceContextError as TraceContextError
from .nnx import errors as errors
from .nnx.filterlib import All as All
from .nnx.filterlib import Not as Not
from .nnx.graph_utils import GraphDef as GraphDef
Expand Down
8 changes: 6 additions & 2 deletions flax/experimental/nnx/nnx/compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ def __init__(
self.module = module

_rngs = (
{name: stream.key for name, stream in rngs._rngs.items()} if rngs else {}
{name: stream.key.raw_value for name, stream in rngs._rngs.items()}
if rngs
else {}
)
# rename default to params
if 'params' not in _rngs and 'default' in _rngs:
Expand All @@ -83,7 +85,9 @@ def __call__(
self, *args: Any, rngs: tp.Optional[Rngs] = None, **kwargs: Any
) -> Any:
_rngs = (
{name: stream.key for name, stream in rngs._rngs.items()} if rngs else {}
{name: stream.key.value for name, stream in rngs._rngs.items()}
if rngs
else {}
)

variables = {
Expand Down
47 changes: 14 additions & 33 deletions flax/experimental/nnx/nnx/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@
DelayedAccessor,
)
from flax.experimental.nnx.nnx.state import State, StateLeaf, is_state_leaf
from flax.experimental.nnx.nnx.rnglib import Rngs
from flax.experimental.nnx.nnx.state import State
from flax.experimental.nnx.nnx.variables import EMPTY, Empty, Variable
from flax.typing import Path, PathParts

Expand Down Expand Up @@ -1098,12 +1096,6 @@ def _maybe_insert(x):
# ---------------------------------------------------------


@tp.runtime_checkable
class _HasSetup(tp.Protocol):
def setup(self) -> None:
...


class ModuleState(reprlib.Representable):
__slots__ = ('_trace_state', '_id')

Expand All @@ -1127,29 +1119,16 @@ def __nnx_repr__(self):
class GraphNodeMeta(ABCMeta):
if not tp.TYPE_CHECKING:

def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self._meta_call(*args, **kwargs)

def _meta_call(cls: tp.Type[G], *args, **kwargs) -> G:
node = cls.__new__(cls, *args, **kwargs)
vars(node)['_graph_node__state'] = ModuleState()
node.__init__(*args, **kwargs)
def __call__(cls, *args: Any, **kwargs: Any) -> Any:
return _graph_node_meta_call(cls, *args, **kwargs)

if dataclasses.is_dataclass(node):
if isinstance(node, _HasSetup):
node.setup()

assert isinstance(node, GraphNode)
def _graph_node_meta_call(cls: tp.Type[G], *args, **kwargs) -> G:
node = cls.__new__(cls, *args, **kwargs)
vars(node)['_graph_node__state'] = ModuleState()
node.__init__(*args, **kwargs)

for field in dataclasses.fields(node):
if not field.init:
continue
value = vars(node)[field.name]
# set Rngs instances to None
if isinstance(value, Rngs):
vars(node)[field.name] = None

return node
return node


class GraphNode(reprlib.Representable, metaclass=GraphNodeMeta):
Expand All @@ -1162,13 +1141,15 @@ def __setattr__(self, name: str, value: Any) -> None:
self._setattr(name, value)

def _setattr(self, name: str, value: tp.Any) -> None:
if not self._graph_node__state.trace_state.is_valid():
raise errors.TraceContextError(
'Cannot mutate GraphNode from different trace level'
)

self.check_valid_context(
f"Cannot mutate '{type(self).__name__}' from different trace level"
)
object.__setattr__(self, name, value)

def check_valid_context(self, error_msg: str) -> None:
if not self._graph_node__state.trace_state.is_valid():
raise errors.TraceContextError(error_msg)

def __deepcopy__(self: G, memo=None) -> G:
state, graphdef, _ = graph_utils.graph_flatten(self)
graphdef = deepcopy(graphdef)
Expand Down
37 changes: 35 additions & 2 deletions flax/experimental/nnx/nnx/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from __future__ import annotations

import dataclasses
import typing as tp
from functools import partial

Expand All @@ -26,7 +27,7 @@
graph_utils,
)
from flax.experimental.nnx.nnx import variables as variableslib
from flax.experimental.nnx.nnx.graph_utils import GraphDef
from flax.experimental.nnx.nnx.graph_utils import GraphDef, GraphNodeMeta
from flax.experimental.nnx.nnx.proxy_caller import (
ApplyCaller,
CallableProxy,
Expand All @@ -47,8 +48,40 @@
tuple_reduce = lambda xs, x: xs + (x,)
tuple_init = lambda: ()

@tp.runtime_checkable
class _HasSetup(tp.Protocol):
def setup(self) -> None:
...


class ModuleMeta(GraphNodeMeta):
  """Metaclass that routes instance construction to ``_module_meta_call``."""

  # ``__call__`` is only defined at runtime, not while type checking --
  # presumably so static checkers keep using each class's own ``__init__``
  # signature for constructor calls (TODO confirm).
  if not tp.TYPE_CHECKING:

    def __call__(cls, *args: Any, **kwargs: Any) -> Any:
      return _module_meta_call(cls, *args, **kwargs)


def _module_meta_call(cls: tp.Type[M], *args, **kwargs) -> M:
  """Construct a ``cls`` instance and apply dataclass-specific post-processing.

  The instance is built by ``GraphNodeMeta.__call__``. If the result is a
  dataclass, then:

  * its ``setup()`` method is called when it defines one, and
  * every ``init`` field whose current value is an ``Rngs`` instance is
    reset to ``None``.

  Args:
    cls: the class being instantiated.
    *args: positional arguments forwarded to the constructor.
    **kwargs: keyword arguments forwarded to the constructor.

  Returns:
    The newly constructed instance.
  """
  module: M = GraphNodeMeta.__call__(cls, *args, **kwargs)

  if dataclasses.is_dataclass(module):
    if isinstance(module, _HasSetup):
      module.setup()

    assert isinstance(module, Module)

    # ``dataclasses.fields`` raises TypeError for non-dataclasses, so this
    # loop must stay inside the ``is_dataclass`` branch.
    for field in dataclasses.fields(module):
      if not field.init:
        continue
      value = vars(module)[field.name]
      # set Rngs instances to None
      if isinstance(value, Rngs):
        vars(module)[field.name] = None

  return module


class Module(graph_utils.GraphNode):
class Module(graph_utils.GraphNode, metaclass=ModuleMeta):
@classmethod
def init(cls: type[M], *args, **kwargs) -> tuple[State, GraphDef[M]]:
return cls(*args, **kwargs).split()
Expand Down
3 changes: 2 additions & 1 deletion flax/experimental/nnx/nnx/nn/stochastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from typing import Optional, Sequence

import jax
import jax.numpy as jnp
from jax import lax, random

Expand Down Expand Up @@ -46,7 +47,7 @@ def __call__(
*,
deterministic: Optional[bool] = None,
rngs: Optional[rnglib.Rngs] = None,
):
) -> jax.Array:
"""Applies a random dropout mask to the input.
Args:
Expand Down
50 changes: 25 additions & 25 deletions flax/experimental/nnx/nnx/rnglib.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,11 @@
import typing as tp

import jax
import jax.numpy as jnp

from flax.experimental.nnx.nnx import errors, filterlib, tracers
from flax.experimental.nnx.nnx.variables import Variable
from flax.experimental.nnx.nnx import filterlib
from flax.experimental.nnx.nnx.graph_utils import GraphNode

Counts = list[int]
AxesValue = tp.Union[int, None]
Expand All @@ -46,20 +49,31 @@ class Missing:

MISSING = Missing()

class RngState(Variable[jax.Array]):
  """``Variable`` subtype holding a ``jax.Array``.

  ``RngStream`` wraps both its ``key`` and its ``count`` in this type.
  """


@dataclasses.dataclass
class RngStream:
key: jax.Array # dynamic
count: int # static
class RngStream(GraphNode):
def __init__(
self,
key: jax.Array,
count: jax.Array,
):
self.key = RngState(key)
self.count = RngState(count)

def __post_init__(self):
if not isinstance(self.key, jax.Array):
raise TypeError(f'key must be a jax.Array, got {type(self.key)}')

def make_rng(self) -> jax.Array:
count = self.count
self.count += 1
return jax.random.fold_in(self.key, count)
self.check_valid_context(
"Cannot call 'make_rng' from a different trace level"
)
key = jax.random.fold_in(self.key.value, self.count.value)
self.count.value += 1
return key

def fork(self, pattern: Pattern) -> jax.Array:
if pattern is None:
Expand All @@ -70,8 +84,8 @@ def fork(self, pattern: Pattern) -> jax.Array:
num_splits = pattern
else:
num_splits = tuple(x if x is not None else 1 for x in pattern)
key = jax.random.split(self.key, num_splits)
self.count += 1
key = jax.random.split(self.key.value, num_splits)
self.count.value += 1
return key


Expand All @@ -83,9 +97,7 @@ def fork(self, pattern: Pattern) -> jax.Array:
]


class Rngs(tp.Mapping[str, tp.Callable[[], jax.Array]]):
__slots__ = ('_trace_state', '_rngs', '_counts')

class Rngs(GraphNode, tp.Mapping[str, tp.Callable[[], jax.Array]]):
def __init__(
self,
default: RngValue | RngDict | None = None,
Expand All @@ -101,17 +113,12 @@ def __init__(
self._rngs = {
name: RngStream(
key=jax.random.key(value) if isinstance(value, int) else value,
count=0,
count=jnp.array(0, dtype=jnp.uint32),
)
for name, value in rngs.items()
}
self._trace_state = tracers.TraceState()

def _make_rng(self, name: str, error_type: Exception) -> jax.Array:
if not self.is_valid():
raise errors.TraceContextError(
'Cannot use Rngs from a different trace level'
)
if name not in self._rngs:
if 'default' not in self._rngs:
raise error_type(f"No RNG named {name!r} or 'default' found in Rngs.")
Expand Down Expand Up @@ -144,20 +151,13 @@ def replace(self, **kwargs: tp.Union[int, jax.Array, RngStream]) -> 'Rngs':
rngs.update(kwargs)
return Rngs(**rngs)

def is_valid(self) -> bool:
return self._trace_state.is_valid()

def fork(
self,
_default: Pattern | dict[filterlib.Filter, Pattern] | Missing = MISSING,
/,
**patterns: Pattern,
) -> ForkedKeys:
if not self.is_valid():
raise errors.TraceContextError(
'Cannot use Rngs from a different trace level'
)

filter_patterns: list[tuple[filterlib.Filter, Pattern]]
if isinstance(_default, dict):
# merge default and patterns
Expand Down
Loading

0 comments on commit 453f38b

Please sign in to comment.