-
Notifications
You must be signed in to change notification settings - Fork 648
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support linen.LogicallyPartitioned <-> nnx.Variable #4161
Conversation
fdb4e73
to
821bc64
Compare
flax/core/meta.py
Outdated
metadata['sharding'] = metadata['names'] | ||
del metadata['names'] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
metadata['sharding'] = metadata['names'] | |
del metadata['names'] | |
metadata['sharding'] = metadata.pop('names') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great idea!
flax/core/meta.py
Outdated
"""Given a dict of `nnx.Variable` format metadata, create a `nn.Partitioned`.""" | ||
metadata['names'] = metadata['sharding'] | ||
del metadata['sharding'] | ||
fields = [x.name for x in dataclasses.fields(cls)] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
feels like fields
should be set
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think dataclasses.fields
should already return de-duplicated names, otherwise the class cannot be built... but yeah we can add it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh I mean because in the line below we just use it to check for membership
flax/linen/spmd.py
Outdated
"""Given a dict of `nnx.Variable` format metadata, create a `nn.LogicallyPartitioned`.""" | ||
metadata['names'], metadata['rules'] = metadata['sharding'], metadata['sharding_rules'] | ||
del metadata['sharding'], metadata['sharding_rules'] | ||
fields = [x.name for x in dataclasses.fields(cls)] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same as before
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here
nnx.spmd
API will now recognize anothernnx.Variable
field:sharding_rules
linen.LogicallyPartitioned
boxes and convert them to this format.