From 69519f28419a37d02b5a6d1e9d64d8e6af98272e Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Tue, 2 Apr 2024 16:32:50 +0100 Subject: [PATCH] [nnx] remove pytreelib --- flax/experimental/nnx/__init__.py | 2 - flax/experimental/nnx/nnx/pytreelib.py | 269 --------------------- flax/experimental/nnx/nnx/spmd.py | 12 +- flax/experimental/nnx/tests/test_module.py | 6 +- flax/experimental/nnx/tests/test_pytree.py | 265 -------------------- 5 files changed, 8 insertions(+), 546 deletions(-) delete mode 100644 flax/experimental/nnx/nnx/pytreelib.py delete mode 100644 flax/experimental/nnx/tests/test_pytree.py diff --git a/flax/experimental/nnx/__init__.py b/flax/experimental/nnx/__init__.py index 87d9f45880..9a8bb1e742 100644 --- a/flax/experimental/nnx/__init__.py +++ b/flax/experimental/nnx/__init__.py @@ -74,8 +74,6 @@ from .nnx.nn.normalization import LayerNorm as LayerNorm from .nnx.nn.normalization import RMSNorm as RMSNorm from .nnx.nn.stochastic import Dropout as Dropout -from .nnx.pytreelib import Pytree as Pytree -from .nnx.pytreelib import TreeNode as TreeNode from .nnx.rnglib import Rngs as Rngs from .nnx.rnglib import RngStream as RngStream from .nnx.rnglib import RngState as RngState diff --git a/flax/experimental/nnx/nnx/pytreelib.py b/flax/experimental/nnx/nnx/pytreelib.py deleted file mode 100644 index 1305e9f701..0000000000 --- a/flax/experimental/nnx/nnx/pytreelib.py +++ /dev/null @@ -1,269 +0,0 @@ -# Copyright 2024 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import contextlib -import dataclasses -import importlib.util -import inspect -import typing as tp -from abc import ABCMeta -from copy import copy -from functools import partial -from types import MappingProxyType - -import jax -import numpy as np - -from flax.experimental.nnx.nnx import module as modulelib -from flax.experimental.nnx.nnx import reprlib, variables -from flax.experimental.nnx.nnx.state import State - -A = tp.TypeVar('A') -P = tp.TypeVar('P', bound='Pytree') - - -class TreeNode(variables.Variable[A]): - pass - - -@contextlib.contextmanager -def _mutable(obj: P) -> tp.Iterator[None]: - vars(obj)['_pytree__is_mutable'] = True - try: - yield - finally: - del vars(obj)['_pytree__is_mutable'] - - -@contextlib.contextmanager -def _initializing(obj: P) -> tp.Iterator[None]: - vars(obj)['_pytree__initializing'] = True - try: - yield - finally: - del vars(obj)['_pytree__initializing'] - - -class PytreeMeta(ABCMeta): - if not tp.TYPE_CHECKING: - - def __call__(cls: tp.Type[P], *args: tp.Any, **kwargs: tp.Any) -> P: - return cls.call(*args, **kwargs) - - def call(cls: tp.Type[P], *args: tp.Any, **kwargs: tp.Any) -> P: - obj: P = cls.__new__(cls, *args, **kwargs) - vars(obj)['_pytree__sorted_fields'] = ['_pytree__sorted_fields'] - - with _mutable(obj), _initializing(obj): - obj.__init__(*args, **kwargs) - - vars(obj)['_pytree__sorted_fields'] = sorted(vars(obj)) - - return obj - - -class Pytree(reprlib.Representable, metaclass=PytreeMeta): - _pytree__is_mutable: bool - _pytree__class_is_mutable: bool - _pytree__sorted_fields: tp.Tuple[str, ...] - - if not tp.TYPE_CHECKING: - - def __setattr__(self, name: str, value: tp.Any) -> None: - self._setattr(name, value) - - def _setattr(self: P, name: str, value: tp.Any): - vars_dict = vars(self) - if '_pytree__initializing' in vars_dict: - pass - elif name not in vars_dict: - raise AttributeError(r'Cannot add new fields to an initialized Pytree') - elif ( - '_pytree__is_mutable' not in vars_dict - and not self._pytree__class_is_mutable - ): - raise AttributeError( - f'{type(self)} is immutable, trying to update field {name}' - ) - - if isinstance(value, (jax.Array, np.ndarray, State)): - raise ValueError( - f"Trying to assign a '{type(value).__name__}' to the Module" - f" attribute '{name}'. This is not supported. Non-hashable " - 'objects are not valid static state in JAX. Please wrap ' - 'the value in a Variable type instead.' - ) - vars_dict[name] = value - - def __init_subclass__(cls, mutable: bool = False): - super().__init_subclass__() - # init class variables - cls._pytree__is_mutable = False - cls._pytree__class_is_mutable = mutable - - # TODO: clean up this in the future once minimal supported version is 0.4.7 - if hasattr(jax.tree_util, 'register_pytree_with_keys'): - if ( - 'flatten_func' - in inspect.signature(jax.tree_util.register_pytree_with_keys).parameters - ): - jax.tree_util.register_pytree_with_keys( - cls, - partial( - cls._pytree__flatten, - with_key_paths=True, - ), - cls._pytree__unflatten, - flatten_func=partial( - cls._pytree__flatten, - with_key_paths=False, - ), - ) - else: - jax.tree_util.register_pytree_with_keys( - cls, - partial( - cls._pytree__flatten, - with_key_paths=True, - ), - cls._pytree__unflatten, - ) - else: - jax.tree_util.register_pytree_node( - cls, - partial( - cls._pytree__flatten, - with_key_paths=False, - ), - cls._pytree__unflatten, - ) - - # flax serialization support - if importlib.util.find_spec('flax') is not None: - from flax import serialization - - serialization.register_serialization_state( - cls, cls._to_flax_state_dict, cls._from_flax_state_dict - ) - - @classmethod - def _pytree__flatten( - cls, - pytree: 'Pytree', - *, - with_key_paths: bool, - ): - all_vars = vars(pytree) - static = {} - node_values = [] - node_names = [] - - for field in pytree._pytree__sorted_fields: - value = all_vars[field] - - if isinstance(value, (modulelib.Module, variables.Variable, Pytree)): - node_names.append(field) - if with_key_paths: - node_values.append((jax.tree_util.GetAttrKey(field), value)) - else: - node_values.append(value) - else: - static[field] = value - - return node_values, (tuple(node_names), MappingProxyType(static)) - - @classmethod - def _pytree__unflatten( - cls: tp.Type[P], - metadata: tp.Tuple[tp.Tuple[str, ...], tp.Mapping[str, tp.Any]], - node_values: tp.Tuple[tp.Any, ...], - ) -> P: - node_names, static_fields = metadata - pytree = object.__new__(cls) - pytree.__dict__.update(zip(node_names, node_values)) - pytree.__dict__.update(static_fields) - return pytree - - @classmethod - def _to_flax_state_dict(cls, pytree: 'Pytree') -> dict[str, tp.Any]: - from flax import serialization - - state_dict = { - name: serialization.to_state_dict(getattr(pytree, name)) - for name, value in vars(pytree).items() - if isinstance(value, (modulelib.Module, variables.Variable, Pytree)) - } - return state_dict - - @classmethod - def _from_flax_state_dict( - cls, - pytree: P, - state: dict[str, tp.Any], - ) -> P: - """Restore the state of a data class.""" - from flax import serialization - - state = state.copy() # copy the state so we can pop the restored fields. - updates = {} - for name, value in vars(pytree).items(): - if not isinstance(value, (modulelib.Module, variables.Variable, Pytree)): - continue - if name not in state: - raise ValueError( - f'Missing field {name} in state dict while restoring' - f' an instance of {type(pytree).__name__},' - f' at path {serialization.current_path()}' - ) - value_state = state.pop(name) - updates[name] = serialization.from_state_dict( - value, value_state, name=name - ) - if state: - names = ','.join(state.keys()) - raise ValueError( - f'Unknown field(s) "{names}" in state dict while' - f' restoring an instance of {type(pytree).__name__}' - f' at path {serialization.current_path()}' - ) - return pytree.replace(**updates) - - def replace(self: P, **kwargs: tp.Any) -> P: - """ - Replace the values of the fields of the object with the values of the - keyword arguments. If the object is a dataclass, `dataclasses.replace` - will be used. Otherwise, a new object will be created with the same - type as the original object. - """ - if dataclasses.is_dataclass(self): - return dataclasses.replace(self, **kwargs) - - unknown_keys = set(kwargs) - set(vars(self)) - if unknown_keys and not self._pytree__class_is_mutable: - raise ValueError( - f'Trying to replace unknown fields {unknown_keys} ' - f"for '{type(self).__name__}'" - ) - - pytree = copy(self) - with _mutable(pytree): - for key, value in kwargs.items(): - setattr(pytree, key, value) - - return pytree - - def __nnx_repr__(self): - yield reprlib.Object(type(self)) - for name, value in vars(self).items(): - yield reprlib.Attr(name, repr(value)) diff --git a/flax/experimental/nnx/nnx/spmd.py b/flax/experimental/nnx/nnx/spmd.py index 51f5f93f30..0554901487 100644 --- a/flax/experimental/nnx/nnx/spmd.py +++ b/flax/experimental/nnx/nnx/spmd.py @@ -20,7 +20,6 @@ from jax.sharding import Mesh, PartitionSpec from flax.experimental.nnx.nnx import variables -from flax.experimental.nnx.nnx.pytreelib import TreeNode from flax.experimental.nnx.nnx.state import State from flax.typing import ( Array, @@ -107,17 +106,16 @@ def f(x): return _maybe_replicate(x) - return jax.tree_util.tree_map( - f, - tree, - is_leaf=lambda x: isinstance(x, variables.Variable) - and not isinstance(x, TreeNode), + return jax.tree_map( + f, tree, is_leaf=lambda x: isinstance(x, variables.Variable) ) def get_named_sharding(tree: A, mesh: jax.sharding.Mesh) -> A: spec = get_partition_spec(tree) - sharding = jax.tree_util.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), spec) + sharding = jax.tree_util.tree_map( + lambda p: jax.sharding.NamedSharding(mesh, p), spec + ) return sharding diff --git a/flax/experimental/nnx/tests/test_module.py b/flax/experimental/nnx/tests/test_module.py index 405cef10b2..ef44281217 100644 --- a/flax/experimental/nnx/tests/test_module.py +++ b/flax/experimental/nnx/tests/test_module.py @@ -527,7 +527,7 @@ def test_basic(self): @dataclasses.dataclass class Foo(nnx.Module): a: int - b: nnx.TreeNode[int] + b: nnx.Variable[int] c: nnx.Param[int] d: nnx.Variable[int] e: nnx.Variable[int] @@ -535,7 +535,7 @@ class Foo(nnx.Module): m = Foo( a=1, # static - b=nnx.TreeNode(2), # node + b=nnx.Variable(2), # node c=nnx.Param(3), # param d=nnx.Variable(4), # var e=nnx.BatchStat(5), # var @@ -545,7 +545,7 @@ class Foo(nnx.Module): graphdef, state = m.split() assert len(state) == 4 - assert state.b == nnx.TreeNode(2) + assert state.b == nnx.Variable(2) assert state.c == nnx.Param(3) assert state.d == nnx.Variable(4) assert state.e == nnx.BatchStat(5) diff --git a/flax/experimental/nnx/tests/test_pytree.py b/flax/experimental/nnx/tests/test_pytree.py deleted file mode 100644 index e34a14664f..0000000000 --- a/flax/experimental/nnx/tests/test_pytree.py +++ /dev/null @@ -1,265 +0,0 @@ -# Copyright 2024 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import dataclasses -from typing import Generic, TypeVar - -import jax -import pytest - -from flax import serialization -from flax.experimental import nnx - - -class TestPytree: - def test_immutable_pytree(self): - class Foo(nnx.Pytree): - def __init__(self, y) -> None: - self.x = 2 - self.y = nnx.Variable(y) - - pytree = Foo(y=3) - - leaves = jax.tree_util.tree_leaves(pytree) - assert leaves == [3] - - pytree = jax.tree_util.tree_map(lambda x: x * 2, pytree) - assert pytree.x == 2 - assert pytree.y.value == 6 - - pytree = pytree.replace(x=3) - assert pytree.x == 3 - assert pytree.y.value == 6 - - with pytest.raises( - AttributeError, match='is immutable, trying to update field' - ): - pytree.x = 4 - - def test_immutable_pytree_dataclass(self): - @dataclasses.dataclass(frozen=True) - class Foo(nnx.Pytree): - y: nnx.TreeNode[int] - x: int = dataclasses.field(default=2) - - pytree = Foo(y=nnx.TreeNode(3)) - - leaves = jax.tree_util.tree_leaves(pytree) - assert leaves == [3] - - pytree = jax.tree_util.tree_map(lambda x: x * 2, pytree) - assert pytree.x == 2 - assert pytree.y.value == 6 - - pytree = pytree.replace(x=3) - assert pytree.x == 3 - assert pytree.y.value == 6 - - with pytest.raises(AttributeError, match='cannot assign to field'): - pytree.x = 4 - - def test_jit(self): - @dataclasses.dataclass - class Foo(nnx.Pytree): - a: nnx.TreeNode[int] - b: int = dataclasses.field() - - module = Foo(a=nnx.TreeNode(1), b=2) - - @jax.jit - def f(m: Foo): - return m.a.value + m.b - - assert f(module) == 3 - - def test_flax_serialization(self): - class Bar(nnx.Pytree): - def __init__(self, a, b): - self.a = a - self.b = nnx.Variable(b) - - @dataclasses.dataclass - class Foo(nnx.Pytree): - bar: Bar - c: nnx.TreeNode[int] - d: int = dataclasses.field() - - foo: Foo = Foo(bar=Bar(a=1, b=2), c=nnx.TreeNode(3), d=4) - - state_dict = serialization.to_state_dict(foo) - - assert state_dict == { - 'bar': { - 'b': nnx.Variable(2), - }, - 'c': nnx.TreeNode(3), - } - - state_dict['bar']['b'] = nnx.Variable(5) - - foo = serialization.from_state_dict(foo, state_dict) - - assert foo.bar.b == nnx.Variable(5) - - del state_dict['bar']['b'] - - with pytest.raises(ValueError, match='Missing field'): - serialization.from_state_dict(foo, state_dict) - - state_dict['bar']['b'] = 5 - - # add unknown field - state_dict['x'] = 6 - - with pytest.raises(ValueError, match='Unknown field'): - serialization.from_state_dict(foo, state_dict) - - def test_generics(self): - T = TypeVar('T') - - class MyClass(nnx.Pytree, Generic[T]): - def __init__(self, x: T): - self.x = x - - MyClass[int] - - def test_key_paths(self): - @dataclasses.dataclass - class Bar(nnx.Pytree): - a: nnx.TreeNode[int] = dataclasses.field(default_factory=lambda: nnx.TreeNode(1)) - b: int = dataclasses.field(default=2) - - @dataclasses.dataclass - class Foo(nnx.Pytree): - x: nnx.TreeNode[int] = dataclasses.field(default_factory=lambda: nnx.TreeNode(3)) - y: int = dataclasses.field(default=4) - z: nnx.TreeNode[Bar] = dataclasses.field(default_factory=lambda: nnx.TreeNode(Bar())) - - foo = Foo() - - path_values, treedef = jax.tree_util.tree_flatten_with_path(foo) - path_values = [(list(map(str, path)), value) for path, value in path_values] - - assert path_values[0] == (['.x', '.raw_value'], 3) - assert path_values[1] == (['.z', '.raw_value', '.a', '.raw_value'], 1) - - def test_replace_unknown_fields_error(self): - class Foo(nnx.Pytree): - pass - - with pytest.raises(ValueError, match='Trying to replace unknown fields'): - Foo().replace(y=1) - - def test_dataclass_inheritance(self): - @dataclasses.dataclass - class A(nnx.Pytree): - a: nnx.TreeNode[int] = dataclasses.field(default_factory=lambda: nnx.TreeNode(1)) - b: int = dataclasses.field(default=2) - - @dataclasses.dataclass - class B(A): - c: nnx.TreeNode[int] = dataclasses.field(default_factory=lambda: nnx.TreeNode(3)) - - pytree = B() - leaves = jax.tree_util.tree_leaves(pytree) - assert leaves == [1, 3] - - def test_pytree_with_new(self): - class A(nnx.Pytree): - def __init__(self, a): - self.a = a - - def __new__(cls, a): - return super().__new__(cls) - - pytree = A(a=1) - - pytree = jax.tree_util.tree_map(lambda x: x * 2, pytree) - - def test_deterministic_order(self): - class A(nnx.Pytree): - def __init__(self, order: bool): - if order: - self.a = 1 - self.b = 2 - else: - self.b = 2 - self.a = 1 - - p1 = A(order=True) - p2 = A(order=False) - - leaves1 = jax.tree_util.tree_leaves(p1) - leaves2 = jax.tree_util.tree_leaves(p2) - - assert leaves1 == leaves2 - - -class TestMutablePytree: - def test_pytree(self): - class Foo(nnx.Pytree, mutable=True): - def __init__(self, y) -> None: - self.x = 2 - self.y = nnx.Variable(y) - - pytree = Foo(y=3) - - leaves = jax.tree_util.tree_leaves(pytree) - assert leaves == [3] - - pytree = jax.tree_util.tree_map(lambda x: x * 2, pytree) - assert pytree.x == 2 - assert pytree.y.value == 6 - - pytree = pytree.replace(x=3) - assert pytree.x == 3 - assert pytree.y.value == 6 - - # test mutation - pytree.x = 4 - assert pytree.x == 4 - - def test_no_new_fields_after_init(self): - class Foo(nnx.Pytree, mutable=True): - def __init__(self, x): - self.x = nnx.Variable(x) - - foo = Foo(x=1) - foo.x = 2 - - with pytest.raises(AttributeError, match=r'Cannot add new fields to'): - foo.y = 2 - - def test_pytree_dataclass(self): - @dataclasses.dataclass - class Foo(nnx.Pytree, mutable=True): - y: nnx.TreeNode[int] - x: int = dataclasses.field(default=2) - - pytree: Foo = Foo(y=nnx.TreeNode(3)) - - leaves = jax.tree_util.tree_leaves(pytree) - assert leaves == [3] - - pytree = jax.tree_util.tree_map(lambda x: x * 2, pytree) - assert pytree.x == 2 - assert pytree.y.value == 6 - - pytree = pytree.replace(x=3) - assert pytree.x == 3 - assert pytree.y.value == 6 - - # test mutation - pytree.x = 4 - assert pytree.x == 4