Support linen.LogicallyPartitioned <-> nnx.Variable #4161

Merged 1 commit on Sep 10, 2024
14 changes: 14 additions & 0 deletions flax/core/meta.py
@@ -22,6 +22,7 @@
"""

import abc
import dataclasses
import functools
from typing import Any, Generic, TypeVar
from collections.abc import Callable
@@ -287,6 +288,19 @@ def get_sharding(self, mesh: jax.sharding.Mesh) -> jax.sharding.Sharding:
"""Returns the ``NamedSharding`` for this partitioned value."""
return jax.sharding.NamedSharding(mesh, self.get_partition_spec())

def to_nnx_metadata(self) -> dict[str, Any]:
"""Return a dict of metadata that can translate into an `nnx.Variable`."""
metadata = vars(self)
metadata['sharding'] = metadata.pop('names')
return metadata

@classmethod
def from_nnx_metadata(cls, metadata: dict[str, Any]):
"""Given a dict of `nnx.Variable` format metadata, create a `nn.Partitioned`."""
metadata['names'] = metadata.pop('sharding')
fields = {x.name for x in dataclasses.fields(cls)}
return cls(**{k: v for k, v in metadata.items() if k in fields})


def with_partitioning(
fn: Callable[..., Any],
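
As a minimal sketch (not part of this diff), the new helpers round-trip an `nn.Partitioned` box through the dict format that `nnx.Variable` metadata uses, with `names` renamed to `sharding`; the array contents here are illustrative only:

import jax.numpy as jnp
import flax.linen as nn

boxed = nn.Partitioned(jnp.ones((4, 8)), names=('in', 'out'))
md = boxed.to_nnx_metadata()   # {'value': ..., 'sharding': ('in', 'out'), 'mesh': None}
restored = nn.Partitioned.from_nnx_metadata(md)
assert restored.names == ('in', 'out')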
15 changes: 15 additions & 0 deletions flax/linen/spmd.py
@@ -328,6 +328,21 @@ def unbox(self, apply_constraint=True) -> Any:
else:
return self.value

def to_nnx_metadata(self) -> dict[str, Any]:
"""Return a dict of metadata that can translate into an `nnx.Variable`."""
metadata = vars(self)
metadata['sharding'] = metadata.pop('names')
metadata['sharding_rules'] = metadata.pop('rules')
return metadata

@classmethod
def from_nnx_metadata(cls, metadata: dict[str, Any]):
"""Given a dict of `nnx.Variable` format metadata, create a `nn.LogicallyPartitioned`."""
metadata['names'] = metadata.pop('sharding')
metadata['rules'] = metadata.pop('sharding_rules')
fields = {x.name for x in dataclasses.fields(cls)}
return cls(**{k: v for k, v in metadata.items() if k in fields})


def with_logical_partitioning(
fn: Callable[..., Any],
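
Similarly, a hedged sketch (not in the diff) of the `LogicallyPartitioned` translation: `names` maps to `sharding` and `rules` to `sharding_rules`, matching the attribute names used on `nnx.Variable`; the values are illustrative:

import jax.numpy as jnp
import flax.linen as nn

boxed = nn.LogicallyPartitioned(
    jnp.ones((8,)), names=('out-alias',), rules=(('out-alias', 'out'),))
md = boxed.to_nnx_metadata()
assert md['sharding'] == ('out-alias',)
assert md['sharding_rules'] == (('out-alias', 'out'),)
restored = nn.LogicallyPartitioned.from_nnx_metadata(md)
assert restored.rules == (('out-alias', 'out'),)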
30 changes: 15 additions & 15 deletions flax/nnx/bridge/variables.py
@@ -58,10 +58,10 @@ def variable_type_name(typ: tp.Type[variableslib.Variable[tp.Any]]) -> str:


def register_variable_name_type_pair(name, typ, overwrite = False):
"""Register a pair of variable type name (like Linen collections) and its NNX type."""
"""Register a pair of Linen collection name and its NNX type."""
if not overwrite and name in VariableTypeCache:
raise ValueError(f'Name {name} already mapped to type {VariableTypeCache[name]}. '
'To overwrite, call with `overwrite=True`.')
'To overwrite, call register_variable_name_type_pair() with `overwrite=True`.')
VariableTypeCache[name] = typ


@@ -85,8 +85,7 @@ def _variable_parents_count(t: type):


class NNXMeta(struct.PyTreeNode, meta.AxisMetadata[A]):
"""Default Flax metadata class for `nnx.VariableState`.
"""
"""Default Flax metadata class for `nnx.VariableState`."""

var_type: type[variableslib.Variable[tp.Any]] = struct.field(pytree_node=False)
value: Any = struct.field(pytree_node=True)
@@ -110,10 +109,11 @@ def remove_axis(self, index: int, params: dict[Any, Any]) -> 'NNXMeta[A]':
def to_linen_var(vs: variableslib.VariableState) -> meta.AxisMetadata:
metadata = vs.get_metadata()
if 'linen_meta_type' in metadata:
if metadata['linen_meta_type'] is not meta.Partitioned:
raise ValueError('Not supporting Linen metadata types other than nn.Partitioned')
return meta.Partitioned(vs.value, names=metadata['sharding'], mesh=metadata['mesh'])
return NNXMeta(vs.type, vs.value, vs.get_metadata())
linen_type = metadata['linen_meta_type']
if hasattr(linen_type, 'from_nnx_metadata'):
return linen_type.from_nnx_metadata({'value': vs.value, **metadata})
return linen_type(vs.value, **metadata)
return NNXMeta(vs.type, vs.value, metadata)


def get_col_name(keypath: tp.Sequence[Any]) -> str:
@@ -124,15 +124,15 @@ def get_col_name(keypath: tp.Sequence[Any]) -> str:


def to_nnx_var(col: str, x: meta.AxisMetadata | Any) -> variableslib.Variable:
"""Convert a Linen variable to an NNX variable.
This process needs the collection name,
"""
"""Convert a Linen variable to an NNX variable."""
vtype = variable_type(col)
if isinstance(x, NNXMeta):
assert vtype == x.var_type, f'Type stored in NNXMeta {x.var_type} != type inferred from collection name {vtype}'
return x.var_type(x.value, **x.metadata)
if isinstance(x, meta.AxisMetadata):
if isinstance(x, meta.Partitioned):
return vtype(x.value, sharding=x.names, mesh=x.mesh, linen_meta_type=meta.Partitioned)
raise ValueError('Not yet supporting metadata types other than nn.Partitioned and NNXMeta')
return vtype(x)
x_metadata = vars(x)
if hasattr(x, 'to_nnx_metadata'):
x_metadata = x.to_nnx_metadata()
assert hasattr(x, 'value')
return vtype(**x_metadata, linen_meta_type=type(x))
return vtype(x)
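
A hedged end-to-end sketch of what these internal helpers now do: any Linen metadata box that defines `to_nnx_metadata` is unpacked into the NNX variable's attributes, the box class is remembered under `linen_meta_type`, and `to_linen_var` rebuilds the same box type on the way back. The `bv` alias and the `to_state()` call are assumptions of this sketch, not shown in the diff:

import jax.numpy as jnp
import flax.linen as nn
from flax.nnx.bridge import variables as bv

box = nn.LogicallyPartitioned(
    jnp.ones((8,)), names=('out-alias',), rules=(('out-alias', 'out'),))
var = bv.to_nnx_var('params', box)       # nnx.Param carrying the metadata
assert var.sharding == ('out-alias',)
assert var.sharding_rules == (('out-alias', 'out'),)
back = bv.to_linen_var(var.to_state())   # rebuilt via linen_meta_type
assert isinstance(back, nn.LogicallyPartitioned)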
53 changes: 27 additions & 26 deletions flax/nnx/bridge/wrappers.py
@@ -74,7 +74,7 @@ def lazy_init(fn: Module | tp.Callable[..., tp.Any], *args, **kwargs):
module = fn
assert callable(fn)
else:
if not (hasattr(fn, '__self__') and isinstance(fn.__self__, Module)):
if not hasattr(fn, '__self__') and isinstance(fn.__self__, Module):
raise ValueError(f'{fn = } needs to be a method of an NNX Module.')
module = fn.__self__
_set_initializing(module, True)
@@ -124,6 +124,7 @@ def __init__(
self.linen_collections: tuple[str, ...] = ()

def lazy_init(self, *args, **kwargs):
"""A shortcut of calling `nnx.bridge.lazy_init()` upon this module."""
return lazy_init(self, *args, **kwargs)

def __call__(
@@ -224,28 +225,6 @@ class ToLinen(linen.Module):
skip_rng: bool = False
metadata_type: tp.Type = bv.NNXMeta

def update_variables(self, module):
"""Store the NNX module's graph def and state inside Linen module variables."""
gdef, state = nnx.split(module)
# Save the graph def.
if self.is_mutable_collection('nnx'):
self.put_variable('nnx', 'graphdef', gdef)
# Sort all the variable types.
types = set(jax.tree.leaves(
jax.tree.map(lambda x: x.type, state,
is_leaf=lambda x: isinstance(x, nnx.VariableState))))
types = bv.sort_variable_types(types)
_, *state_by_types = nnx.split(module, *types)
# Each variable type goes to its own linen collection, and
# each attribute goes to its own linen variable
for typ, state in zip(types, state_by_types):
collection = bv.variable_type_name(typ)
if self.is_mutable_collection(collection):
for k, v in state.raw_mapping.items():
v = jax.tree.map(bv.to_linen_var, v,
is_leaf=lambda x: isinstance(x, nnx.VariableState))
self.put_variable(collection, k, v)

@linen.compact
def __call__(self, *args, **kwargs):
# init codepath
@@ -255,7 +234,7 @@ def __call__(self, *args, **kwargs):
module_kwargs |= dict(rngs=nnx.Rngs(**linen_rngs_dict(self)))
module = self.nnx_class(*self.args, **module_kwargs)
# TODO: add lazy_init here in case there's an `ToNNX` submodule under `module`.
self.update_variables(module)
self._update_variables(module)
return module(*args, **kwargs)

# apply codepath
@@ -270,11 +249,33 @@ def __call__(self, *args, **kwargs):
module = nnx.merge(gdef, nnx_state)
nnx.reseed(module, **linen_rngs_dict(self)) # reseed with keys from linen apply call.
out = module(*args, **kwargs)
self.update_variables(module)
self._update_variables(module)
return out

def _update_variables(self, module):
"""Store the NNX module's graph def and state inside Linen module variables."""
gdef, state = nnx.split(module)
# Save the graph def.
if self.is_mutable_collection('nnx'):
self.put_variable('nnx', 'graphdef', gdef)
# Sort all the variable types.
types = set(jax.tree.leaves(
jax.tree.map(lambda x: x.type, state,
is_leaf=lambda x: isinstance(x, nnx.VariableState))))
types = bv.sort_variable_types(types)
_, *state_by_types = nnx.split(module, *types)
# Each variable type goes to its own linen collection, and
# each attribute goes to its own linen variable
for typ, state in zip(types, state_by_types):
collection = bv.variable_type_name(typ)
if self.is_mutable_collection(collection):
for k, v in state.raw_mapping.items():
v = jax.tree.map(bv.to_linen_var, v,
is_leaf=lambda x: isinstance(x, nnx.VariableState))
self.put_variable(collection, k, v)


def to_linen(nnx_class: tp.Callable[..., Module], *args,
name: str | None = None, **kwargs):
"""Shortcut of `ToLinen` if user is not changing any of `ToLinen` default fields."""
"""Shortcut of `nnx.bridge.ToLinen` if user is not changing any of its default fields."""
return ToLinen(nnx_class, args=args, kwargs=kwargs, name=name)
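
For context, a usage sketch mirroring the tests below (not itself part of the diff): a Linen layer carrying partitioning metadata keeps that metadata on the resulting NNX params after `lazy_init`:

import jax
import flax.linen as nn
from flax import nnx
from flax.nnx import bridge

linen_layer = nn.Dense(
    features=64,
    kernel_init=nn.with_partitioning(nn.initializers.lecun_normal(), ('in', 'out')))
x = jax.numpy.ones((1, 32))
model = bridge.ToNNX(linen_layer, rngs=nnx.Rngs(0)).lazy_init(x)  # uses the new shortcut
assert model.params['kernel'].sharding == ('in', 'out')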
6 changes: 6 additions & 0 deletions flax/nnx/spmd.py
@@ -89,9 +89,15 @@ def _maybe_replicate(x):
else:
return None

def from_rules(sharding, sharding_rules):
rules = {alias: on_mesh for (alias, on_mesh) in sharding_rules}
return (rules[s] if s in rules else s for s in sharding)

def f(x):
if isinstance(x, (variables.VariableState, variables.Variable)):
if hasattr(x, 'sharding') and x.sharding:
if hasattr(x, 'sharding_rules') and x.sharding_rules:
return x.replace(PartitionSpec(*from_rules(x.sharding, x.sharding_rules)))
return x.replace(PartitionSpec(*x.sharding))
else:
return x.replace(_maybe_replicate(x.value))
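
A small illustration (not from the diff) of the alias resolution `from_rules` performs before the `PartitionSpec` is built: logical names with a matching rule are swapped for mesh axis names, anything else passes through unchanged:

sharding = ('out-alias', None)
sharding_rules = (('out-alias', 'out'),)
rules = dict(sharding_rules)
resolved = tuple(rules.get(s, s) for s in sharding)
assert resolved == ('out', None)   # becomes PartitionSpec('out', None)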
47 changes: 36 additions & 11 deletions tests/nnx/bridge/wrappers_test.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4'

from absl.testing import absltest
import flax
@@ -24,6 +26,12 @@


class TestCompatibility(absltest.TestCase):
def setUp(self):
super().setUp()
dim1 = max(jax.device_count() // 2, 1)
device_mesh = np.array(jax.devices()).reshape(dim1, jax.device_count() // dim1)
self.mesh = jax.sharding.Mesh(devices=device_mesh, axis_names=('in', 'out'))

def test_functional(self):
# Functional API for NNX Modules
functional = bridge.functional(nnx.Linear)(32, 64)
@@ -135,21 +143,35 @@ def vmap_fn(inner, x):
def test_linen_to_nnx_metadata(self):
linen_module = nn.Dense(
features=64,
kernel_init=nn.with_partitioning(nn.initializers.lecun_normal(), ('in', 'out')))
kernel_init=nn.with_partitioning(nn.initializers.lecun_normal(), ('in', 'out')),
bias_init=nn.with_logical_partitioning(nn.initializers.zeros_init(), ('out-alias',),
rules=(('out-alias', 'out'),)),
)
x = jax.numpy.ones((1, 32))
linen_vars = linen_module.init(jax.random.key(0), x)
nnx_model = bridge.ToNNX(linen_module, rngs=nnx.Rngs(0)).lazy_init(x)
# nn.Partitioned metadata box is translated into a valid nnx.Variable / VariableState box.

@nnx.jit
def create_sharded_nnx_module(x):
model = bridge.lazy_init(bridge.ToNNX(linen_module, rngs=nnx.Rngs(0)), x)
state = nnx.state(model)
sharded_state = nnx.with_sharding_constraint(state, nnx.get_partition_spec(state))
nnx.update(model, sharded_state)
return model
with self.mesh:
nnx_model = create_sharded_nnx_module(x)

# nn.Partitioned metadata boxes translated into valid nnx.Variable boxes.
self.assertIsInstance(linen_vars['params']['kernel'], nn.Partitioned)
self.assertIsInstance(linen_vars['params']['bias'], nn.LogicallyPartitioned)
self.assertIsInstance(nnx_model.params['kernel'], nnx.Variable)
np.testing.assert_array_equal(linen_vars['params']['kernel'].value,
nnx_model.params['kernel'].value)
assert nnx_model.params['kernel'].sharding == ('in', 'out')
_, nnx_state = nnx.split(nnx_model)
self.assertIsInstance(nnx_state['params']['kernel'], nnx.VariableState)
np.testing.assert_array_equal(linen_vars['params']['kernel'].value,
nnx_state['params']['kernel'].value)
assert nnx_state['params']['kernel'].sharding == ('in', 'out')
assert nnx_model.params['kernel'].value.sharding.is_equivalent_to(
jax.sharding.NamedSharding(self.mesh, jax.sharding.PartitionSpec('in', 'out')), ndim=2)

assert nnx_model.params['bias'].sharding == ('out-alias',)
assert nnx_model.params['bias'].sharding_rules == (('out-alias', 'out'),)
assert nnx_model.params['bias'].value.sharding.is_equivalent_to(
jax.sharding.NamedSharding(self.mesh, jax.sharding.PartitionSpec('out',)), ndim=1)


##################
@@ -306,7 +328,9 @@ class LinenMiddle(nn.Module):
@nn.compact
def __call__(self, x):
dot = bridge.to_linen(NNXInner, x.shape[-1], self.dout, self.dropout_rate, name='dot')
b = self.param('b', nn.initializers.lecun_normal(), (1, self.dout))
logical_init = nn.with_logical_partitioning(
nn.initializers.lecun_normal(), ('out-alias',), rules=(('out-alias', 'out'),))
b = self.param('b', logical_init, (1, self.dout))
return dot(x) + b

class NNXOuter(nnx.Module):
@@ -335,6 +359,7 @@ def __call__(self, x):
self.assertIsInstance(w, nnx.Param)
np.testing.assert_allclose(model(x), x @ w + b)
assert hasattr(w, 'sharding') and w.sharding == ('in', 'out')
assert hasattr(b, 'sharding') and b.sharding == ('out-alias', )

def test_linen_nnx_linen(self):
# TODO: add when we can safely `lazy_init` the NNX module inside `ToLinen` without
2 changes: 1 addition & 1 deletion tests/nnx/transforms_test.py
@@ -323,7 +323,7 @@ def f(m: Foo):

def test_apply_shardings(self):
n_devices = max(jax.local_device_count() // 2, 1)
devices = mesh_utils.create_device_mesh((n_devices, n_devices))
devices = mesh_utils.create_device_mesh((n_devices, jax.local_device_count() // n_devices))
mesh = jax.sharding.Mesh(devices, ('a', 'b'))

def sharding(*args):