-
Notifications
You must be signed in to change notification settings - Fork 662
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Align bridge variable tree structures #4194
Conversation
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
8eceb3c
to
172e832
Compare
flax/nnx/bridge/wrappers.py
Outdated
import dataclasses | ||
import typing as tp | ||
from typing import Any | ||
|
||
from flax import nnx | ||
from flax import linen | ||
from flax import traverse_util |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NIT: flax.nnx.traversals
has similar same APIs but with more accurate typing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice! Adopted.
flax/nnx/bridge/wrappers.py
Outdated
@@ -85,6 +87,50 @@ def lazy_init(fn: Module | tp.Callable[..., tp.Any], *args, **kwargs): | |||
return fn | |||
|
|||
|
|||
def _recursive_merge(dict1, dict2): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this could be expressed using traversals
/ traverse_util
as:
flat_map = traversal.flatten_mapping(dict1)
flat_map |= traversal.flatten_mapping(dict2)
return traversal.unflatten_mapping(flat_map)
Not sure if there are edge cases though but it seems easy.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good! Adopted.
flax/nnx/bridge/wrappers.py
Outdated
dict1[key] = _recursive_merge(dict1[key], value) | ||
else: | ||
# Merge non-dictionary values | ||
dict1[key] = value |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is mutating dict1
always safe?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think using nnx.traversals
is indeed easier!
flax/nnx/bridge/wrappers.py
Outdated
def nnx_attrs_to_linen_vars(nnx_attrs: dict) -> dict: | ||
linen_structured = {} | ||
for kp, v in traverse_util.flatten_dict( | ||
nnx_attrs, is_leaf=lambda _, x: isinstance(x, nnx.Variable | nnx.GraphDef)).items(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I didn't know Unions were now valid in isinstance
, nice!
55cbdef
to
cac253e
Compare
Fixing bridge API in a bunch of ways:
ToNNX
will convert the whole variables structure to NNX style. If your underlying Linen module has variablefoo
at collectionbar
, itsToNNX
version will have an attributefoo
with typebar
, instead of an attributebar
with a dict{'foo': ...}
.This means you can freely put
ToNNX
in the top or middle or back of the whole model layer stack, and the weight pytree structure shouldn't change.Same goes for
ToLinen
- if your top-level type is Linen, the whole variable tree shall be Linen-style.If you have a vanilla
nnx.Variable
with no sharding metadata, hooks, etc,ToLinen
will not convert it into anNNXMeta
, but instead just keep the vanilla JAX array inside. This makes it more intuitive and pytree-structure-proof for any Linen users not using partitioning metadata.nn.get_partition_spec
now works onNNXMeta
wrappers, and any other wrapper that hasget_partition_spec
method.Updated the
nnx.bridge
guide accordingly.