
[Flax] Add scan_with_axes #18341

Open · sanchit-gandhi wants to merge 8 commits into main
Conversation

@sanchit-gandhi (Contributor) commented on Jul 28, 2022

What does this PR do?

Adds scan_with_axes to Flax Bert and its derived models.

TODO:

  • Fix cookie cutter template
  • Run make fix-copies (after review)

Fixes #17399

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sanchit-gandhi added the WIP label on Jul 28, 2022
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@sanchit-gandhi (Contributor, Author) commented on Aug 23, 2022

Update

Current API:

  1. With automatic init:

     model.scan_enable()   # to enable scan
     model.scan_disable()  # to disable scan

  2. Without automatic init:

     model.scan_enable()  # to enable scan in the nn.Module
     params = model.convert_unroll_to_scan(params)  # to convert the unrolled params to scan

     model.scan_disable()  # to disable scan in the nn.Module (i.e. unrolled)
     params = model.convert_scan_to_unroll(params)  # to convert the scan params to unrolled

With automatic init, the params are converted from unrolled to scan under the hood (.convert_unroll_to_scan()).
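For illustration, here is a minimal sketch of the second flow (no automatic init), assuming a Flax Bert checkpoint and the existing _do_init=False loading path; scan_enable and convert_unroll_to_scan are the methods proposed in this PR:

import jax.numpy as jnp
from transformers import FlaxBertModel

# load unrolled pretrained weights without automatically initialising the params
model, params = FlaxBertModel.from_pretrained("bert-base-uncased", _do_init=False)

model.scan_enable()                            # switch the nn.Module to the scanned layout
params = model.convert_unroll_to_scan(params)  # stack the per-layer params to match

input_ids = jnp.ones((1, 8), dtype="i4")
outputs = model(input_ids, params=params)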

The params are converted to/from scan on the fly, with no memory overhead for the conversion (cf. .convert_unroll_to_scan()):

# `key` is the parameter path for layer 0 (i.e. it contains the substring "0")
stacked_params = []
for i in range(self.config.num_hidden_layers):
    # Stack the params for the N layers into one super block
    # and remove the unrolled layer params on the fly
    # -> no memory overhead for conversion!
    unrolled_layer = params.pop(key.replace("0", str(i)))
    stacked_params.append(unrolled_layer)
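As a self-contained toy illustration of the same pop-and-stack pattern (not the PR's actual code; the layer names and shapes are made up):

import numpy as np

# four "unrolled" layers, each with a (2, 2) kernel
params = {f"layers_{i}/kernel": np.full((2, 2), i) for i in range(4)}

stacked_params = []
for i in range(4):
    # popping removes the unrolled entry as soon as it is stacked
    stacked_params.append(params.pop(f"layers_{i}/kernel"))

# one "super block" with a new leading layer axis
params["layers/kernel"] = np.stack(stacked_params)
assert params["layers/kernel"].shape == (4, 2, 2)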

Design question!

Should we include the unroll -> scan weight conversion in the .from_pretrained() method if use_scan=True is passed? Currently, setting use_scan=True enables scan in the nn.Module but leaves the params untouched. This is fine when loading scanned params from a pretrained checkpoint. When loading unrolled params, however, the weight shapes will not match those expected by the nn.Module (unrolled params, scanned nn.Module), meaning the weights will not be loaded!

Possible cases (loaded params, flag for use_scan):

| Params   | use_scan | Mismatch                           | Action                                    |
|----------|----------|------------------------------------|-------------------------------------------|
| Unrolled | False    | None                               | Fine!                                     |
| Unrolled | True     | params unrolled, nn.Module scanned | Requires conversion of params to scan     |
| Scan     | False    | params scan, nn.Module unrolled    | Requires conversion of params to unrolled |
| Scan     | True     | None                               | Fine!                                     |
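One possible way to dispatch on these cases inside .from_pretrained() (purely a hypothetical sketch; maybe_convert_params and the params_are_scanned flag are invented for illustration, only the two conversion methods come from this PR):

def maybe_convert_params(model, params, use_scan, params_are_scanned):
    # unrolled checkpoint loaded into a scanned nn.Module -> stack the params
    if use_scan and not params_are_scanned:
        return model.convert_unroll_to_scan(params)
    # scanned checkpoint loaded into an unrolled nn.Module -> unstack the params
    if not use_scan and params_are_scanned:
        return model.convert_scan_to_unroll(params)
    # params already match the nn.Module layout
    return params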

Successfully merging this pull request may close these issues.

[RFC] Scan & Gradient checkpointing in Flax