-
Notifications
You must be signed in to change notification settings - Fork 26.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Flax] Add scan_with_axes #18341
base: main
Are you sure you want to change the base?
[Flax] Add scan_with_axes #18341
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. |
UpdateCurrent API:
model.scan_enable() # to enable scan
model.scan_disable() # to disable scan
model.scan_enable() # to enable scan in the nn.Module
params = model.convert_unroll_to_scan(params) # to convert the unrolled params to scan
model.scan_disable() # to disable scan in the nn.Module (i.e. unrolled)
params = model.convert_scan_to_unroll(params) # to convert the scan params to unrolled With automatic init, the params are converted from unrolled to scan under the hood ( The params are converted to/from scan on the fly, with no memory overhead for conversion (c.f. for i in range(self.config.num_hidden_layers):
# Stack the params for the N layers into one super block
# and remove the unrolled layer params on the fly
# -> no memory overhead for conversion!
unrolled_layer = params.pop(key.replace("0", str(i)))
stacked_params.append(unrolled_layer) Design question!Should we include the unroll -> scan weight conversion in the Possible cases (loaded params, flag for
|
What does this PR do?
Adds
scan_with_axes
to Flax Bert and its derived models.TODO:
make fix-copies
(after review)Fixes #17399
Before submitting
Pull Request section?
to it if that's the case.
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.