diff --git a/flax/core/meta.py b/flax/core/meta.py index 27686a40b5..531b463c7d 100644 --- a/flax/core/meta.py +++ b/flax/core/meta.py @@ -22,6 +22,7 @@ """ import abc +import dataclasses import functools from typing import Any, Generic, TypeVar from collections.abc import Callable @@ -287,6 +288,19 @@ def get_sharding(self, mesh: jax.sharding.Mesh) -> jax.sharding.Sharding: """Returns the ``NamedSharding`` for this partitioned value.""" return jax.sharding.NamedSharding(mesh, self.get_partition_spec()) + def to_nnx_metadata(self) -> dict[str, Any]: + """Return a dict of metadata that can translate into an `nnx.Variable`.""" + metadata = vars(self) + metadata['sharding'] = metadata.pop('names') + return metadata + + @classmethod + def from_nnx_metadata(cls, metadata: dict[str, Any]): + """Given a dict of `nnx.Variable` format metadata, create a `nn.Partitioned`.""" + metadata['names'] = metadata.pop('sharding') + fields = {x.name for x in dataclasses.fields(cls)} + return cls(**{k: v for k, v in metadata.items() if k in fields}) + def with_partitioning( fn: Callable[..., Any], diff --git a/flax/linen/spmd.py b/flax/linen/spmd.py index 93afab7646..cd622bbdae 100644 --- a/flax/linen/spmd.py +++ b/flax/linen/spmd.py @@ -328,6 +328,21 @@ def unbox(self, apply_constraint=True) -> Any: else: return self.value + def to_nnx_metadata(self) -> dict[str, Any]: + """Return a dict of metadata that can translate into an `nnx.Variable`.""" + metadata = vars(self) + metadata['sharding'] = metadata.pop('names') + metadata['sharding_rules'] = metadata.pop('rules') + return metadata + + @classmethod + def from_nnx_metadata(cls, metadata: dict[str, Any]): + """Given a dict of `nnx.Variable` format metadata, create a `nn.LogicallyPartitioned`.""" + metadata['names'] = metadata.pop('sharding') + metadata['rules'] = metadata.pop('sharding_rules') + fields = {x.name for x in dataclasses.fields(cls)} + return cls(**{k: v for k, v in metadata.items() if k in fields}) + def with_logical_partitioning( fn: Callable[..., Any], diff --git a/flax/nnx/bridge/variables.py b/flax/nnx/bridge/variables.py index 9d8714274a..d73f645f3b 100644 --- a/flax/nnx/bridge/variables.py +++ b/flax/nnx/bridge/variables.py @@ -58,10 +58,10 @@ def variable_type_name(typ: tp.Type[variableslib.Variable[tp.Any]]) -> str: def register_variable_name_type_pair(name, typ, overwrite = False): - """Register a pair of variable type name (like Linen collections) and its NNX type.""" + """Register a pair of Linen collection name and its NNX type.""" if not overwrite and name in VariableTypeCache: raise ValueError(f'Name {name} already mapped to type {VariableTypeCache[name]}. ' - 'To overwrite, call with `overwrite=True`.') + 'To overwrite, call register_variable_name_type_pair() with `overwrite=True`.') VariableTypeCache[name] = typ @@ -85,8 +85,7 @@ def _variable_parents_count(t: type): class NNXMeta(struct.PyTreeNode, meta.AxisMetadata[A]): - """Default Flax metadata class for `nnx.VariableState`. - """ + """Default Flax metadata class for `nnx.VariableState`.""" var_type: type[variableslib.Variable[tp.Any]] = struct.field(pytree_node=False) value: Any = struct.field(pytree_node=True) @@ -110,10 +109,11 @@ def remove_axis(self, index: int, params: dict[Any, Any]) -> 'NNXMeta[A]': def to_linen_var(vs: variableslib.VariableState) -> meta.AxisMetadata: metadata = vs.get_metadata() if 'linen_meta_type' in metadata: - if metadata['linen_meta_type'] is not meta.Partitioned: - raise ValueError('Not supporting Linen metadata types other than nn.Partitioned') - return meta.Partitioned(vs.value, names=metadata['sharding'], mesh=metadata['mesh']) - return NNXMeta(vs.type, vs.value, vs.get_metadata()) + linen_type = metadata['linen_meta_type'] + if hasattr(linen_type, 'from_nnx_metadata'): + return linen_type.from_nnx_metadata({'value': vs.value, **metadata}) + return linen_type(vs.value, **metadata) + return NNXMeta(vs.type, vs.value, metadata) def get_col_name(keypath: tp.Sequence[Any]) -> str: @@ -124,15 +124,15 @@ def get_col_name(keypath: tp.Sequence[Any]) -> str: def to_nnx_var(col: str, x: meta.AxisMetadata | Any) -> variableslib.Variable: - """Convert a Linen variable to an NNX variable. - This process needs the collection name, - """ + """Convert a Linen variable to an NNX variable.""" vtype = variable_type(col) if isinstance(x, NNXMeta): assert vtype == x.var_type, f'Type stored in NNXMeta {x.var_type} != type inferred from collection name {vtype}' return x.var_type(x.value, **x.metadata) if isinstance(x, meta.AxisMetadata): - if isinstance(x, meta.Partitioned): - return vtype(x.value, sharding=x.names, mesh=x.mesh, linen_meta_type=meta.Partitioned) - raise ValueError('Not yet supporting metadata types other than nn.Partitioned and NNXMeta') - return vtype(x) + x_metadata = vars(x) + if hasattr(x, 'to_nnx_metadata'): + x_metadata = x.to_nnx_metadata() + assert hasattr(x, 'value') + return vtype(**x_metadata, linen_meta_type=type(x)) + return vtype(x) \ No newline at end of file diff --git a/flax/nnx/bridge/wrappers.py b/flax/nnx/bridge/wrappers.py index 20ac7a2601..d209d89819 100644 --- a/flax/nnx/bridge/wrappers.py +++ b/flax/nnx/bridge/wrappers.py @@ -74,7 +74,7 @@ def lazy_init(fn: Module | tp.Callable[..., tp.Any], *args, **kwargs): module = fn assert callable(fn) else: - if not (hasattr(fn, '__self__') and isinstance(fn.__self__, Module)): + if not hasattr(fn, '__self__') and isinstance(fn.__self__, Module): raise ValueError(f'{fn = } needs to be a method of an NNX Module.') module = fn.__self__ _set_initializing(module, True) @@ -124,6 +124,7 @@ def __init__( self.linen_collections: tuple[str, ...] = () def lazy_init(self, *args, **kwargs): + """A shortcut of calling `nnx.bridge.lazy_init()` upon this module.""" return lazy_init(self, *args, **kwargs) def __call__( @@ -224,28 +225,6 @@ class ToLinen(linen.Module): skip_rng: bool = False metadata_type: tp.Type = bv.NNXMeta - def update_variables(self, module): - """Store the NNX module's graph def and state inside Linen module variables.""" - gdef, state = nnx.split(module) - # Save the graph def. - if self.is_mutable_collection('nnx'): - self.put_variable('nnx', 'graphdef', gdef) - # Sort all the variable types. - types = set(jax.tree.leaves( - jax.tree.map(lambda x: x.type, state, - is_leaf=lambda x: isinstance(x, nnx.VariableState)))) - types = bv.sort_variable_types(types) - _, *state_by_types = nnx.split(module, *types) - # Each variable type goes to its own linen collection, and - # each attribute goes to its own linen variable - for typ, state in zip(types, state_by_types): - collection = bv.variable_type_name(typ) - if self.is_mutable_collection(collection): - for k, v in state.raw_mapping.items(): - v = jax.tree.map(bv.to_linen_var, v, - is_leaf=lambda x: isinstance(x, nnx.VariableState)) - self.put_variable(collection, k, v) - @linen.compact def __call__(self, *args, **kwargs): # init codepath @@ -255,7 +234,7 @@ def __call__(self, *args, **kwargs): module_kwargs |= dict(rngs=nnx.Rngs(**linen_rngs_dict(self))) module = self.nnx_class(*self.args, **module_kwargs) # TODO: add lazy_init here in case there's an `ToNNX` submodule under `module`. - self.update_variables(module) + self._update_variables(module) return module(*args, **kwargs) # apply codepath @@ -270,11 +249,33 @@ def __call__(self, *args, **kwargs): module = nnx.merge(gdef, nnx_state) nnx.reseed(module, **linen_rngs_dict(self)) # reseed with keys from linen apply call. out = module(*args, **kwargs) - self.update_variables(module) + self._update_variables(module) return out + def _update_variables(self, module): + """Store the NNX module's graph def and state inside Linen module variables.""" + gdef, state = nnx.split(module) + # Save the graph def. + if self.is_mutable_collection('nnx'): + self.put_variable('nnx', 'graphdef', gdef) + # Sort all the variable types. + types = set(jax.tree.leaves( + jax.tree.map(lambda x: x.type, state, + is_leaf=lambda x: isinstance(x, nnx.VariableState)))) + types = bv.sort_variable_types(types) + _, *state_by_types = nnx.split(module, *types) + # Each variable type goes to its own linen collection, and + # each attribute goes to its own linen variable + for typ, state in zip(types, state_by_types): + collection = bv.variable_type_name(typ) + if self.is_mutable_collection(collection): + for k, v in state.raw_mapping.items(): + v = jax.tree.map(bv.to_linen_var, v, + is_leaf=lambda x: isinstance(x, nnx.VariableState)) + self.put_variable(collection, k, v) + def to_linen(nnx_class: tp.Callable[..., Module], *args, name: str | None = None, **kwargs): - """Shortcut of `ToLinen` if user is not changing any of `ToLinen` default fields.""" + """Shortcut of `nnx.bridge.ToLinen` if user is not changing any of its default fields.""" return ToLinen(nnx_class, args=args, kwargs=kwargs, name=name) \ No newline at end of file diff --git a/flax/nnx/spmd.py b/flax/nnx/spmd.py index e18003276b..a7acbbc418 100644 --- a/flax/nnx/spmd.py +++ b/flax/nnx/spmd.py @@ -89,9 +89,15 @@ def _maybe_replicate(x): else: return None + def from_rules(sharding, sharding_rules): + rules = {alias: on_mesh for (alias, on_mesh) in sharding_rules} + return (rules[s] if s in rules else s for s in sharding) + def f(x): if isinstance(x, (variables.VariableState, variables.Variable)): if hasattr(x, 'sharding') and x.sharding: + if hasattr(x, 'sharding_rules') and x.sharding_rules: + return x.replace(PartitionSpec(*from_rules(x.sharding, x.sharding_rules))) return x.replace(PartitionSpec(*x.sharding)) else: return x.replace(_maybe_replicate(x.value)) diff --git a/tests/nnx/bridge/wrappers_test.py b/tests/nnx/bridge/wrappers_test.py index 72d42eb6d4..27f2927fd9 100644 --- a/tests/nnx/bridge/wrappers_test.py +++ b/tests/nnx/bridge/wrappers_test.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4' from absl.testing import absltest import flax @@ -24,6 +26,12 @@ class TestCompatibility(absltest.TestCase): + def setUp(self): + super().setUp() + dim1 = max(jax.device_count() // 2, 1) + device_mesh = np.array(jax.devices()).reshape(dim1, jax.device_count() // dim1) + self.mesh = jax.sharding.Mesh(devices=device_mesh, axis_names=('in', 'out')) + def test_functional(self): # Functional API for NNX Modules functional = bridge.functional(nnx.Linear)(32, 64) @@ -135,21 +143,35 @@ def vmap_fn(inner, x): def test_linen_to_nnx_metadata(self): linen_module = nn.Dense( features=64, - kernel_init=nn.with_partitioning(nn.initializers.lecun_normal(), ('in', 'out'))) + kernel_init=nn.with_partitioning(nn.initializers.lecun_normal(), ('in', 'out')), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros_init(), ('out-alias',), + rules=(('out-alias', 'out'),)), + ) x = jax.numpy.ones((1, 32)) linen_vars = linen_module.init(jax.random.key(0), x) - nnx_model = bridge.ToNNX(linen_module, rngs=nnx.Rngs(0)).lazy_init(x) - # nn.Partitioned metadata box is translated into a valid nnx.Variable / VariableState box. + + @nnx.jit + def create_sharded_nnx_module(x): + model = bridge.lazy_init(bridge.ToNNX(linen_module, rngs=nnx.Rngs(0)), x) + state = nnx.state(model) + sharded_state = nnx.with_sharding_constraint(state, nnx.get_partition_spec(state)) + nnx.update(model, sharded_state) + return model + with self.mesh: + nnx_model = create_sharded_nnx_module(x) + + # nn.Partitioned metadata boxes translated into valid nnx.Variable boxes. self.assertIsInstance(linen_vars['params']['kernel'], nn.Partitioned) + self.assertIsInstance(linen_vars['params']['bias'], nn.LogicallyPartitioned) self.assertIsInstance(nnx_model.params['kernel'], nnx.Variable) - np.testing.assert_array_equal(linen_vars['params']['kernel'].value, - nnx_model.params['kernel'].value) assert nnx_model.params['kernel'].sharding == ('in', 'out') - _, nnx_state = nnx.split(nnx_model) - self.assertIsInstance(nnx_state['params']['kernel'], nnx.VariableState) - np.testing.assert_array_equal(linen_vars['params']['kernel'].value, - nnx_state['params']['kernel'].value) - assert nnx_state['params']['kernel'].sharding == ('in', 'out') + assert nnx_model.params['kernel'].value.sharding.is_equivalent_to( + jax.sharding.NamedSharding(self.mesh, jax.sharding.PartitionSpec('in', 'out')), ndim=2) + + assert nnx_model.params['bias'].sharding == ('out-alias',) + assert nnx_model.params['bias'].sharding_rules == (('out-alias', 'out'),) + assert nnx_model.params['bias'].value.sharding.is_equivalent_to( + jax.sharding.NamedSharding(self.mesh, jax.sharding.PartitionSpec('out',)), ndim=1) ################## @@ -306,7 +328,9 @@ class LinenMiddle(nn.Module): @nn.compact def __call__(self, x): dot = bridge.to_linen(NNXInner, x.shape[-1], self.dout, self.dropout_rate, name='dot') - b = self.param('b', nn.initializers.lecun_normal(), (1, self.dout)) + logical_init = nn.with_logical_partitioning( + nn.initializers.lecun_normal(), ('out-alias',), rules=(('out-alias', 'out'))) + b = self.param('b', logical_init, (1, self.dout)) return dot(x) + b class NNXOuter(nnx.Module): @@ -335,6 +359,7 @@ def __call__(self, x): self.assertIsInstance(w, nnx.Param) np.testing.assert_allclose(model(x), x @ w + b) assert hasattr(w, 'sharding') and w.sharding == ('in', 'out') + assert hasattr(b, 'sharding') and b.sharding == ('out-alias', ) def test_linen_nnx_linen(self): # TODO: add when we can safely `lazy_init` the NNX module inside `ToLinen` without diff --git a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py index be487628fe..ecd6f5c790 100644 --- a/tests/nnx/transforms_test.py +++ b/tests/nnx/transforms_test.py @@ -323,7 +323,7 @@ def f(m: Foo): def test_apply_shardings(self): n_devices = max(jax.local_device_count() // 2, 1) - devices = mesh_utils.create_device_mesh((n_devices, n_devices)) + devices = mesh_utils.create_device_mesh((n_devices, jax.local_device_count() // n_devices)) mesh = jax.sharding.Mesh(devices, ('a', 'b')) def sharding(*args):