diff --git a/flax/experimental/nnx/nnx/filterlib.py b/flax/experimental/nnx/nnx/filterlib.py index dfb3e4792d..6bb3d23be5 100644 --- a/flax/experimental/nnx/nnx/filterlib.py +++ b/flax/experimental/nnx/nnx/filterlib.py @@ -14,7 +14,7 @@ import builtins import dataclasses -from flax.typing import Path +from flax.typing import PathParts import typing as tp if tp.TYPE_CHECKING: @@ -22,7 +22,7 @@ else: ellipsis = tp.Any -Predicate = tp.Callable[[Path, tp.Any], bool] +Predicate = tp.Callable[[PathParts, tp.Any], bool] FilterLiteral = tp.Union[type, str, Predicate, bool, ellipsis, None] Filter = tp.Union[FilterLiteral, tuple[FilterLiteral, ...], list[FilterLiteral]] @@ -48,17 +48,17 @@ def to_predicate(filter: Filter) -> Predicate: @dataclasses.dataclass class AtPath: - path: str + str_key: str - def __call__(self, path: Path, x: tp.Any): - return self.path == path + def __call__(self, path: PathParts, x: tp.Any): + return self.str_key in path @dataclasses.dataclass class OfType: type: type - def __call__(self, path: Path, x: tp.Any): + def __call__(self, path: PathParts, x: tp.Any): return isinstance(x, self.type) @@ -68,7 +68,7 @@ def __init__(self, *filters: Filter): to_predicate(collection_filter) for collection_filter in filters ) - def __call__(self, path: Path, x: tp.Any): + def __call__(self, path: PathParts, x: tp.Any): return any(predicate(path, x) for predicate in self.predicates) @@ -78,7 +78,7 @@ def __init__(self, *filters: Filter): to_predicate(collection_filter) for collection_filter in filters ) - def __call__(self, path: Path, x: tp.Any): + def __call__(self, path: PathParts, x: tp.Any): return all(predicate(path, x) for predicate in self.predicates) @@ -86,15 +86,15 @@ class Not: def __init__(self, collection_filter: Filter): self.predicate = to_predicate(collection_filter) - def __call__(self, path: Path, x: tp.Any): + def __call__(self, path: PathParts, x: tp.Any): return not self.predicate(path, x) class Everything: - def __call__(self, path: Path, x: tp.Any): + def __call__(self, path: PathParts, x: tp.Any): return True class Nothing: - def __call__(self, path: Path, x: tp.Any): + def __call__(self, path: PathParts, x: tp.Any): return False diff --git a/flax/experimental/nnx/nnx/graph_utils.py b/flax/experimental/nnx/nnx/graph_utils.py index b398b4eba2..a171ed4b39 100644 --- a/flax/experimental/nnx/nnx/graph_utils.py +++ b/flax/experimental/nnx/nnx/graph_utils.py @@ -37,11 +37,16 @@ CallableProxy, DelayedAccessor, ) -from flax.experimental.nnx.nnx.state import State, StateLeaf, is_state_leaf +from flax.experimental.nnx.nnx.state import ( + FlatState, + State, + StateLeaf, + is_state_leaf, +) from flax.experimental.nnx.nnx.rnglib import Rngs from flax.experimental.nnx.nnx.state import State from flax.experimental.nnx.nnx.variables import EMPTY, Empty, Variable -from flax.typing import Path, PathParts +from flax.typing import PathParts, Key A = tp.TypeVar('A') B = tp.TypeVar('B') @@ -129,28 +134,28 @@ def __str__(self) -> str: @dataclasses.dataclass(frozen=True) class NodeImplBase(tp.Generic[Node, Leaf, AuxData]): type: type - flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[str, Leaf]], AuxData]] + flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[Key, Leaf]], AuxData]] - def node_dict(self, node: Node) -> dict[str, Leaf]: + def node_dict(self, node: Node) -> dict[Key, Leaf]: nodes, _ = self.flatten(node) return dict(nodes) @dataclasses.dataclass(frozen=True) class GraphNodeImpl(NodeImplBase[Node, Leaf, AuxData]): - set_key: tp.Callable[[Node, str, Leaf], None] - pop_key: tp.Callable[[Node, str], Leaf] + set_key: tp.Callable[[Node, Key, Leaf], None] + pop_key: tp.Callable[[Node, Key], Leaf] create_empty: tp.Callable[[AuxData], Node] clear: tp.Callable[[Node, AuxData], None] - def init(self, node: Node, items: tuple[tuple[str, Leaf], ...]): + def init(self, node: Node, items: tuple[tuple[Key, Leaf], ...]): for key, value in items: self.set_key(node, key, value) @dataclasses.dataclass(frozen=True) class PytreeNodeImpl(NodeImplBase[Node, Leaf, AuxData]): - unflatten: tp.Callable[[tuple[tuple[str, Leaf], ...], AuxData], Node] + unflatten: tp.Callable[[tuple[tuple[Key, Leaf], ...], AuxData], Node] NodeImpl = tp.Union[ @@ -160,9 +165,9 @@ class PytreeNodeImpl(NodeImplBase[Node, Leaf, AuxData]): def register_graph_node_type( type: type, - flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[str, Leaf]], AuxData]], - set_key: tp.Callable[[Node, str, Leaf], None], - pop_key: tp.Callable[[Node, str], Leaf], + flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[Key, Leaf]], AuxData]], + set_key: tp.Callable[[Node, Key, Leaf], None], + pop_key: tp.Callable[[Node, Key], Leaf], create_empty: tp.Callable[[AuxData], Node], clear: tp.Callable[[Node, AuxData], None], ): @@ -243,7 +248,7 @@ def __repr__(self) -> str: @dataclasses.dataclass(repr=False) class _MappingRepr(reprlib.Representable): - mapping: tp.Mapping[str, tp.Any] + mapping: tp.Mapping[Key, tp.Any] def __nnx_repr__(self): yield reprlib.Object(type='', value_sep=': ', start='{', end='}') @@ -263,7 +268,7 @@ def __init__( self, type: tp.Type[Variable[tp.Any]], index: int, - metadata: dict[str, tp.Any], + metadata: dict[Key, tp.Any], ): self._type = type self._index = index @@ -332,10 +337,10 @@ def __init__( self, type: tp.Type[Node], index: int, - attributes: tuple[str, ...], - subgraphs: tp.Iterable[tuple[str, tp.Union['GraphDef[tp.Any]', int]]], - static_fields: tp.Iterable[tuple[str, tp.Any]], - variables: tp.Iterable[tuple[str, VariableDef | int]], + attributes: tuple[Key, ...], + subgraphs: tp.Iterable[tuple[Key, tp.Union['GraphDef[tp.Any]', int]]], + static_fields: tp.Iterable[tuple[Key, tp.Any]], + variables: tp.Iterable[tuple[Key, VariableDef | int]], metadata: tp.Any, ): self._type: type[Node] = type @@ -433,10 +438,10 @@ def _graphdef_unflatten( metadata: tuple[ tp.Type[Node], int, - tuple[str, ...], - tuple[tuple[str, GraphDef[Node] | int], ...], - tuple[tuple[str, tp.Any], ...], - tuple[tuple[str, Variable[Empty] | int], ...], + tuple[Key, ...], + tuple[tuple[Key, GraphDef[Node] | int], ...], + tuple[tuple[Key, tp.Any], ...], + tuple[tuple[Key, Variable[Empty] | int], ...], tp.Any, ], _, @@ -454,7 +459,7 @@ def graph_flatten( /, ) -> tuple[GraphDef[Node], State, tp.Mapping[tp.Any, Index]]: ref_to_index = RefMap[tp.Any, Index]() - flat_state: dict[Path, StateLeaf] = {} + flat_state: dict[PathParts, StateLeaf] = {} graphdef = _graph_flatten((), ref_to_index, flat_state, x) assert not isinstance(graphdef, int) return graphdef, State.from_flat_path(flat_state), ref_to_index @@ -463,7 +468,7 @@ def graph_flatten( def _graph_flatten( path: PathParts, ref_to_index: RefMap[tp.Any, Index], - flat_state: dict[Path, StateLeaf], + flat_state: dict[PathParts, StateLeaf], node: Node, ) -> GraphDef[Node] | int: if not is_node(node): @@ -481,33 +486,26 @@ def _graph_flatten( else: index = -1 - subgraphs: list[tuple[str, tp.Union[GraphDef[Node], int]]] = [] - static_fields: list[tuple[str, tp.Any]] = [] - variables: list[tuple[str, VariableDef | int]] = [] + subgraphs: list[tuple[Key, tp.Union[GraphDef[Node], int]]] = [] + static_fields: list[tuple[Key, tp.Any]] = [] + variables: list[tuple[Key, VariableDef | int]] = [] values, metadata = node_impl.flatten(node) for key, value in values: - if not isinstance(key, str): - raise TypeError( - f'Node (of type {type(node).__name__}) has a key of non-string ' - f'type {type(key).__name__}.' - ) if is_node(value): graphdef = _graph_flatten((*path, key), ref_to_index, flat_state, value) subgraphs.append((key, graphdef)) elif isinstance(value, Variable): - str_path = '/'.join((*path, key)) if value in ref_to_index: variables.append((key, ref_to_index[value])) else: - flat_state[str_path] = value.copy() + flat_state[(*path, key)] = value.copy() variable_index = ref_to_index[value] = len(ref_to_index) variables.append( (key, VariableDef.from_variable(value, variable_index)) ) elif is_state_leaf(value): - str_path = '/'.join((*path, key)) - flat_state[str_path] = value + flat_state[(*path, key)] = value else: static_fields.append((key, value)) @@ -702,16 +700,16 @@ def graph_pop( id_to_index: dict[int, Index] = {} path_parts: PathParts = () predicates = tuple(filterlib.to_predicate(filter) for filter in filters) - states = tuple({} for _ in predicates) - _graph_pop(node, id_to_index, path_parts, states, predicates) - return tuple(State(x) for x in states) + flat_states: tuple[FlatState, ...] = tuple({} for _ in predicates) + _graph_pop(node, id_to_index, path_parts, flat_states, predicates) + return tuple(State.from_flat_path(flat_state) for flat_state in flat_states) def _graph_pop( node: tp.Any, id_to_index: dict[int, Index], path_parts: PathParts, - states: tuple[dict[Path, tp.Any], ...], + flat_states: tuple[FlatState, ...], predicates: tuple[filterlib.Predicate, ...], ) -> None: if not is_node(node): @@ -726,17 +724,19 @@ def _graph_pop( for name, value in node_dict.items(): if is_node(value): - _graph_pop(value, id_to_index, (*path_parts, name), states, predicates) + _graph_pop( + value, id_to_index, (*path_parts, name), flat_states, predicates + ) continue elif not is_state_leaf(value): continue elif id(value) in id_to_index: continue - path = '/'.join((*path_parts, name)) + node_path = (*path_parts, name) node_impl = get_node_impl(node) - for state, predicate in zip(states, predicates): - if predicate(path, value): + for state, predicate in zip(flat_states, predicates): + if predicate(node_path, value): if isinstance(node_impl, PytreeNodeImpl): raise ValueError( f'Cannot pop key {name!r} from node of type {type(node).__name__}' @@ -745,7 +745,7 @@ def _graph_pop( node_impl.pop_key(node, name) if isinstance(value, Variable): value = value.copy() - state[path] = value + state[node_path] = value break else: # NOTE: should we raise an error here? @@ -1009,7 +1009,7 @@ def clone(node: Node) -> Node: return static.merge(state) -def iter_nodes(node: tp.Any) -> tp.Iterator[tuple[Path, tp.Any]]: +def iter_nodes(node: tp.Any) -> tp.Iterator[tuple[PathParts, tp.Any]]: visited: set[int] = set() path_parts: PathParts = () yield from _iter_nodes(node, visited, path_parts) @@ -1017,14 +1017,13 @@ def iter_nodes(node: tp.Any) -> tp.Iterator[tuple[Path, tp.Any]]: def _iter_nodes( node: tp.Any, visited: set[int], path_parts: PathParts -) -> tp.Iterator[tuple[Path, tp.Any]]: +) -> tp.Iterator[tuple[PathParts, tp.Any]]: if not is_node(node): return if id(node) in visited: return visited.add(id(node)) - path = '/'.join(path_parts) - yield path, node + yield path_parts, node node_impl = get_node_impl(node) node_dict = node_impl.node_dict(node) for key, value in node_dict.items(): @@ -1204,6 +1203,29 @@ def __nnx_repr__(self): if clear_seen: CONTEXT.seen_modules_repr = None + def graph_node_keys(self) -> tp.Iterable[Key]: + return sorted(key for key in vars(self) if key != '_graph_node__state') + + def graph_node_get_value(self, key: Key) -> tp.Any: + if not isinstance(key, str): + raise KeyError(f'Invalid key: {key!r}') + return vars(self)[key] + + def graph_node_set_key(self, key: Key, value: tp.Any) -> None: + if not isinstance(key, str): + raise KeyError(f'Invalid key: {key!r}') + vars(self)[key] = value + + def graph_node_pop_key(self, key: Key) -> tp.Any: + if not isinstance(key, str): + raise KeyError(f'Invalid key: {key!r}') + return vars(self).pop(key) + + def graph_node_has_key(self, key: Key) -> bool: + if not isinstance(key, str): + raise KeyError(f'Invalid key: {key!r}') + return key in vars(self) + def __init_subclass__(cls) -> None: super().__init_subclass__() @@ -1220,26 +1242,24 @@ def __init_subclass__(cls) -> None: # Graph Definition def _graph_node_flatten(node: GraphNode): nodes = tuple( - (name, value) - for name, value in sorted(vars(node).items()) - if name != '_graph_node__state' + (key, node.graph_node_get_value(key)) for key in node.graph_node_keys() ) return nodes, type(node) -def _graph_node_set_key(node: GraphNode, name: str, value: tp.Any): +def _graph_node_set_key(node: GraphNode, key: Key, value: tp.Any): if ( - hasattr(node, name) - and isinstance(variable := getattr(node, name), Variable) + node.graph_node_has_key(key) + and isinstance(variable := node.graph_node_get_value(key), Variable) and isinstance(value, Variable) ): variable.copy_from(value) else: - setattr(node, name, value) + node.graph_node_set_key(key, value) def _graph_node_pop_key(node: GraphNode, name: str): - return vars(node).pop(name) + return node.graph_node_pop_key(name) def _graph_node_create_empty(cls: tp.Type[G]) -> G: @@ -1266,13 +1286,13 @@ def is_pytree_node(x: tp.Any) -> bool: return not jax.tree_util.all_leaves([x]) -def _key_path_to_str(key: tp.Any) -> str: +def _key_path_to_key(key: tp.Any) -> Key: if isinstance(key, jax.tree_util.SequenceKey): - return str(key.idx) + return key.idx elif isinstance( key, (jax.tree_util.DictKey, jax.tree_util.FlattenedIndexKey) ): - return str(key.key) + return key.key elif isinstance(key, jax.tree_util.GetAttrKey): return key.name else: @@ -1283,13 +1303,13 @@ def _flatten_pytree(pytree: tp.Any): leaves, treedef = jax.tree_util.tree_flatten_with_path( pytree, is_leaf=lambda x: x is not pytree ) - nodes = tuple((_key_path_to_str(path[0]), value) for path, value in leaves) + nodes = tuple((_key_path_to_key(path[0]), value) for path, value in leaves) return nodes, treedef def _unflatten_pytree( - nodes: tuple[tuple[str, tp.Any], ...], treedef: jax.tree_util.PyTreeDef + nodes: tuple[tuple[Key, tp.Any], ...], treedef: jax.tree_util.PyTreeDef ): pytree = treedef.unflatten(value for _, value in nodes) return pytree diff --git a/flax/experimental/nnx/nnx/helpers.py b/flax/experimental/nnx/nnx/helpers.py index 0b4c86dd5b..098cde03d5 100644 --- a/flax/experimental/nnx/nnx/helpers.py +++ b/flax/experimental/nnx/nnx/helpers.py @@ -34,6 +34,7 @@ import jax.numpy as jnp import optax +from flax.experimental.nnx.nnx.graph_utils import Key from flax.experimental.nnx.nnx.module import GraphDef, Module from flax.experimental.nnx.nnx.proxy_caller import ApplyCaller from flax.experimental.nnx.nnx.rnglib import Rngs @@ -105,6 +106,35 @@ def __iter__(self) -> tp.Iterator[A]: def __len__(self) -> int: return self._length + def graph_node_keys(self) -> tp.Iterable[Key]: + elem_keys = [ + int(key) for key in super().graph_node_keys() if key != '_length' + ] + elem_keys.sort() + yield from elem_keys + yield '_length' + + def graph_node_get_value(self, key: Key) -> tp.Any: + if isinstance(key, int): + key = str(key) # type: ignore + return super().graph_node_get_value(key) + + def graph_node_set_key(self, key: Key, value: tp.Any) -> None: + if isinstance(key, int): + key = str(key) # type: ignore + return super().graph_node_set_key(key, value) + + def graph_node_pop_key(self, key: Key) -> tp.Any: + if isinstance(key, int): + key = str(key) # type: ignore + return super().graph_node_pop_key(key) + + def graph_node_has_key(self, key: Key) -> bool: + if isinstance(key, int): + key = str(key) # type: ignore + return super().graph_node_has_key(key) + + class Sequential(List): def __call__(self, *args, rngs: tp.Optional[Rngs] = None, **kwargs) -> tp.Any: output: tp.Any = None diff --git a/flax/experimental/nnx/nnx/module.py b/flax/experimental/nnx/nnx/module.py index fb507e9eee..ae92cdce48 100644 --- a/flax/experimental/nnx/nnx/module.py +++ b/flax/experimental/nnx/nnx/module.py @@ -35,7 +35,7 @@ from flax.experimental.nnx.nnx.rnglib import Rngs from flax.experimental.nnx.nnx.state import State from flax.experimental.nnx.nnx.variables import Variable -from flax.typing import Path +from flax.typing import Path, PathParts A = tp.TypeVar('A') B = tp.TypeVar('B') @@ -265,7 +265,7 @@ def sow( reduced_value = reduce_fn(init_fn(), value) setattr(self, name, variable_type(reduced_value)) - def modules(self) -> tp.Iterator[tuple[Path, Module]]: + def modules(self) -> tp.Iterator[tuple[PathParts, Module]]: for path, value in graph_utils.iter_nodes(self): if isinstance(value, Module): yield path, value diff --git a/flax/experimental/nnx/nnx/rnglib.py b/flax/experimental/nnx/nnx/rnglib.py index 70307b67d4..99776ac47b 100644 --- a/flax/experimental/nnx/nnx/rnglib.py +++ b/flax/experimental/nnx/nnx/rnglib.py @@ -183,7 +183,8 @@ def fork( for name, stream in self._rngs.items(): for predicate, pattern in predicate_pattern: - if predicate(name, stream): + stream_path = (name,) + if predicate(stream_path, stream): fork = stream.fork(pattern) if pattern is None: broadcasts[name] = fork diff --git a/flax/experimental/nnx/nnx/state.py b/flax/experimental/nnx/nnx/state.py index 1aca5bb7ba..f287391a45 100644 --- a/flax/experimental/nnx/nnx/state.py +++ b/flax/experimental/nnx/nnx/state.py @@ -37,13 +37,12 @@ from flax import traverse_util from flax.experimental.nnx.nnx import filterlib, reprlib from flax.experimental.nnx.nnx.variables import Variable -from flax.typing import Path +from flax.typing import Key, PathParts A = tp.TypeVar('A') -Key = str StateLeaf = tp.Union[Variable[tp.Any], np.ndarray, jax.Array] -FlatState = dict[Path, StateLeaf] +FlatState = dict[PathParts, StateLeaf] def is_state_leaf(x: tp.Any) -> tpe.TypeGuard[StateLeaf]: @@ -81,9 +80,7 @@ def __init__( def raw_mapping(self) -> dict[Key, dict[str, tp.Any] | tp.Any]: return self._mapping - def __getitem__(self, key: Key | int) -> State | StateLeaf: - if isinstance(key, int): - key = str(key) + def __getitem__(self, key: Key) -> State | StateLeaf: value = self._mapping[key] if is_state_leaf(value): return value @@ -94,10 +91,7 @@ def __getattr__(self, key: Key) -> State | StateLeaf: raise AttributeError(f"No attribute '{key}' in State") return self[key] - def __setitem__(self, key: Key | int, value: State | StateLeaf) -> None: - if isinstance(key, int): - key = str(key) - + def __setitem__(self, key: Key, value: State | StateLeaf) -> None: if isinstance(value, State): self._mapping[key] = value._mapping else: @@ -122,12 +116,12 @@ def __nnx_repr__(self): v = NestedStateRepr(v) yield reprlib.Attr(repr(k), v) - def flat_state(self) -> dict[Key, Variable[Variable]]: - return traverse_util.flatten_dict(self._mapping, sep='/') # type: ignore + def flat_state(self) -> dict[PathParts, StateLeaf]: + return traverse_util.flatten_dict(self._mapping) # type: ignore @classmethod def from_flat_path(cls, flat_state: FlatState, /) -> State: - nested_state = traverse_util.unflatten_dict(flat_state, sep='/') + nested_state = traverse_util.unflatten_dict(flat_state) return cls(nested_state) @tp.overload @@ -226,7 +220,6 @@ def __sub__(self, other: 'State') -> 'State': return State.from_flat_path(diff) - def _state_flatten_with_keys(x: State): items = sorted(x._mapping.items(), key=lambda item: item[0]) children = tuple((jtu.DictKey(key), value) for key, value in items) @@ -234,8 +227,8 @@ def _state_flatten_with_keys(x: State): def _state_unflatten( - static: tuple[Path, ...], - leaves: tuple[Variable, ...] | tuple[dict[str, Variable]], + static: tuple[Key, ...], + leaves: tuple[StateLeaf, ...] | tuple[dict[Key, StateLeaf]], ): return State(zip(static, leaves)) diff --git a/flax/experimental/nnx/tests/nn/test_attention.py b/flax/experimental/nnx/tests/nn/test_attention.py index f312db7321..6ffca788ce 100644 --- a/flax/experimental/nnx/tests/nn/test_attention.py +++ b/flax/experimental/nnx/tests/nn/test_attention.py @@ -69,11 +69,12 @@ def __call__(self, x, sow_weights=False): _ = module(x, True) intermediates = module.pop(nnx.Intermediate) - assert intermediates['attention_layers/0/attention_weights'].raw_value[ + # assert intermediates['attention_layers/0/attention_weights'].raw_value[ + assert intermediates['attention_layers'][0]['attention_weights'].raw_value[ 0 ].shape == (4, 8, 6, 6) - assert 'attention_layers/1/attention_weights' not in intermediates - assert intermediates['attention_layers/2/attention_weights'].raw_value[ + assert 1 not in intermediates['attention_layers'] + assert intermediates['attention_layers'][2]['attention_weights'].raw_value[ 0 ].shape == (4, 8, 6, 6) diff --git a/flax/experimental/nnx/tests/test_graph_utils.py b/flax/experimental/nnx/tests/test_graph_utils.py index 85a30f4c5e..80ad57a9fe 100644 --- a/flax/experimental/nnx/tests/test_graph_utils.py +++ b/flax/experimental/nnx/tests/test_graph_utils.py @@ -27,8 +27,8 @@ def test_flatten(self): static, state, ref_idx = nnx.graph_utils.graph_flatten(g) - state['0']['b'].raw_value = 2 - state['3'].raw_value = 4 + state[0]['b'].raw_value = 2 + state[3].raw_value = 4 assert len(ref_idx) == 2 assert a['b'] in ref_idx @@ -69,7 +69,7 @@ def test_update_dynamic(self): static, state, _ = nnx.graph_utils.graph_flatten(g) - state['0']['b'].raw_value = 3 + state[0]['b'].raw_value = 3 nnx.graph_utils.graph_update_dynamic(g, state) assert g[0]['b'].raw_value == 3 @@ -125,12 +125,12 @@ def test_module_list(self): static, state, _ = nnx.graph_utils.graph_flatten(ls) - assert state['0']['kernel'].raw_value.shape == (2, 2) - assert state['0']['bias'].raw_value.shape == (2,) - assert state['1']['scale'].raw_value.shape == (2,) - assert state['1']['bias'].raw_value.shape == (2,) - assert state['1']['mean'].raw_value.shape == (2,) - assert state['1']['var'].raw_value.shape == (2,) + assert state[0]['kernel'].raw_value.shape == (2, 2) + assert state[0]['bias'].raw_value.shape == (2,) + assert state[1]['scale'].raw_value.shape == (2,) + assert state[1]['bias'].raw_value.shape == (2,) + assert state[1]['mean'].raw_value.shape == (2,) + assert state[1]['var'].raw_value.shape == (2,) def test_shared_variables(self): v = nnx.Param(1) diff --git a/flax/experimental/nnx/tests/test_module.py b/flax/experimental/nnx/tests/test_module.py index 42832a308c..199b038cfc 100644 --- a/flax/experimental/nnx/tests/test_module.py +++ b/flax/experimental/nnx/tests/test_module.py @@ -649,11 +649,11 @@ def __init__(self, *, rngs: nnx.Rngs): modules = list(module.modules()) assert len(modules) == 3 - assert modules[0][0] == '' + assert modules[0][0] == () assert isinstance(modules[0][1], Foo) - assert modules[1][0] == 'submodules/0/a' + assert modules[1][0] == ('submodules', 0, 'a') assert isinstance(modules[1][1], nnx.Linear) - assert modules[2][0] == 'submodules/1/b' + assert modules[2][0] == ('submodules', 1, 'b') assert isinstance(modules[2][1], nnx.Conv) def test_array_in_module(self): diff --git a/flax/experimental/nnx/tests/test_partitioning.py b/flax/experimental/nnx/tests/test_partitioning.py index 53d03d824b..7ed98c1e7b 100644 --- a/flax/experimental/nnx/tests/test_partitioning.py +++ b/flax/experimental/nnx/tests/test_partitioning.py @@ -33,11 +33,11 @@ def test_partition(self): assert len(rest) == 1 # check params - assert params['a']['0'].raw_value == m.a[0].value + assert params['a'][0].raw_value == m.a[0].value assert params['b'].raw_value == m.b.value # check rest - assert rest['a']['1'].raw_value == m.a[1].value + assert rest['a'][1].raw_value == m.a[1].value m2 = graphdef.merge(params, rest) @@ -152,8 +152,8 @@ def test_get_paritition(self): assert vars(m.a)['0'] is not vars(m)['b'] state = m.extract(nnx.Variable) - assert state['a']['0'].raw_value == m.a[0].value - assert state['a']['1'].raw_value == m.a[1].value + assert state['a'][0].raw_value == m.a[0].value + assert state['a'][1].raw_value == m.a[1].value assert state['b'].raw_value == m.b.value assert state.b is not state.a[0] assert len(state.flat_state()) == 3 diff --git a/flax/experimental/nnx/tests/test_transforms.py b/flax/experimental/nnx/tests/test_transforms.py index 46d541cac0..4a079e0979 100644 --- a/flax/experimental/nnx/tests/test_transforms.py +++ b/flax/experimental/nnx/tests/test_transforms.py @@ -340,10 +340,10 @@ def f(m: nnx.Dict): assert m.a[0] is m.b assert isinstance(grads, nnx.State) - assert grads['a']['0'].raw_value == 2.0 - assert isinstance(grads.a['0'], nnx.Variable) - assert grads['a']['1'].raw_value == 1.0 - assert isinstance(grads.a['1'], nnx.Variable) + assert grads['a'][0].raw_value == 2.0 + assert isinstance(grads.a[0], nnx.Variable) + assert grads['a'][1].raw_value == 1.0 + assert isinstance(grads.a[1], nnx.Variable) assert len(grads.flat_state()) == 2 m.update(grads) @@ -371,8 +371,8 @@ def f(m: nnx.Dict): grads = f(m) assert isinstance(grads, nnx.State) - assert grads['a']['0'].raw_value == 1.0 - assert isinstance(grads.a['0'], nnx.Param) + assert grads['a'][0].raw_value == 1.0 + assert isinstance(grads.a[0], nnx.Param) assert len(grads) == 2 m.update(grads) @@ -399,8 +399,8 @@ def f(m: nnx.Dict): grads = f(m) assert isinstance(grads, nnx.State) - assert grads['a']['1'].raw_value == 1.0 - assert isinstance(grads.a['1'], nnx.BatchStat) + assert grads['a'][1].raw_value == 1.0 + assert isinstance(grads.a[1], nnx.BatchStat) assert len(grads) == 1 m.update(grads) diff --git a/flax/typing.py b/flax/typing.py index 70cd5c3f01..d566a31c8a 100644 --- a/flax/typing.py +++ b/flax/typing.py @@ -17,8 +17,10 @@ Callable, Dict, Generic, + Hashable, Mapping, Optional, + Protocol, Sequence, Tuple, TypeVar, @@ -38,9 +40,16 @@ RNGSequences = Dict[str, PRNGKey] Dtype = Union[jax.typing.DTypeLike, Any] Shape = Sequence[int] +K = TypeVar('K') + + +class Key(Hashable, Protocol): + def __lt__(self: K, value: K, /) -> bool: + ... + Path = str -PathParts = Tuple[str, ...] +PathParts = Tuple[Key, ...] Leaf = Any