[RFC] Scan & Gradient checkpointing in Flax #17399
Comments
I'm not sure you would need both versions within the same script (scanned and unscanned, or with and without checkpointing, which affects only training anyway). Then maybe you could just add it directly as an arg. You would just have to use some naming conventions on your params to see if you need to scan/unscan when loading a checkpoint.
Suppose you have a training script; it would be useful to be able to use
I'm not sure it would be worth it:
Hey @patrickvonplaten, I'm keen to get gradient checkpointing working in JAX for long-t5. If this is not on the cards to be added soon, I'm happy to work on a PR for it if that works for you all.
Feature request
We should add scan and remat (gradient checkpointing) to the most important Flax/JAX models (BERT, GPT2, OPT, T5, BART, Wav2Vec2).
Motivation
Scan allows for much faster compilation and memory savings, and `remat` is the equivalent of `gradient_checkpointing` in PyTorch. @sanchit-gandhi already uses both features in the Flax Seq2Seq Speech project - see: https://github.com/sanchit-gandhi/seq2seq-speech - so it'd be quite trivial to get them working.
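To make the two transforms concrete, here is a minimal pure-JAX sketch (a toy dense layer standing in for a real transformer block; all names here are illustrative, not from the issue): `jax.lax.scan` rolls identical layers into a single compiled step so XLA traces the layer body once, and `jax.checkpoint` (remat) recomputes activations in the backward pass instead of storing them.

```python
import jax
import jax.numpy as jnp

# Toy "layer": a dense projection + nonlinearity, standing in for a
# full transformer block purely for illustration.
def layer(params, x):
    w, b = params
    return jnp.tanh(x @ w + b)

# remat / gradient checkpointing: recompute this layer's activations
# during the backward pass, trading compute for memory.
layer_remat = jax.checkpoint(layer)

def forward(stacked_params, x):
    # scan: iterate one traced layer body over the stacked parameters,
    # so compilation cost does not grow with the number of layers.
    def step(carry, params):
        return layer_remat(params, carry), None
    y, _ = jax.lax.scan(step, x, stacked_params)
    return y

n_layers, dim = 4, 8
key = jax.random.PRNGKey(0)
ws = jax.random.normal(key, (n_layers, dim, dim)) * 0.1
bs = jnp.zeros((n_layers, dim))
x = jnp.ones((dim,))

out = forward((ws, bs), x)
grads = jax.grad(lambda p: forward(p, x).sum())((ws, bs))
```

Note that the per-layer weights are stacked along a leading axis, which is exactly the naming/layout difference the comments above discuss when loading scanned vs. unscanned checkpoints.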
Implementation details:
Given that both `scan` and `remat` are not related to the model architecture, they should IMO not be in the model's config (we made this mistake in PyTorch and don't want to repeat it here). I would advocate for the following API:
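The code blocks that originally followed did not survive the page export, so the exact proposed API is not recoverable here. As an illustrative sketch only (the names `use_scan`, `scan_enable`, and `scan_disable` are assumptions, not necessarily what the issue proposed), an opt-in keyword that stays out of the serialized config could look like:

```python
# Illustrative sketch only: `use_scan`, `scan_enable`, and `scan_disable`
# are hypothetical names, not the confirmed transformers API.
class FlaxDummyModel:
    def __init__(self, use_scan=False):
        # The flag lives on the model instance, never in the serialized
        # config, so saved checkpoints stay architecture-only.
        self.use_scan = use_scan

    @classmethod
    def from_pretrained(cls, repo_id, use_scan=False):
        # Real weight-loading logic omitted; only the keyword plumbing
        # that keeps the flag out of the config is shown.
        return cls(use_scan=use_scan)

    def scan_enable(self):
        self.use_scan = True

    def scan_disable(self):
        self.use_scan = False

model = FlaxDummyModel.from_pretrained("dummy/repo", use_scan=True)
```

The point of the sketch is the separation of concerns: the config describes the architecture, while scan/remat are runtime choices toggled on the model object.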
As can be seen here: https://github.com/sanchit-gandhi/seq2seq-speech/blob/b28d0c25c8fad0f9ffa6707f91f7aba320d44a4b/models/modeling_flax_wav2vec2.py#L504
We'll need to re-initialize the `flax.linen.Module` inside the model. However, this should be fine since it just means that we do something similar to this line:
`transformers/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py`, line 868 (commit 71e6027)
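Because Flax keeps parameters in a pytree outside the module object, swapping the inner module definition is cheap. A hedged sketch of that re-initialization pattern (all class and attribute names here are hypothetical, not the actual transformers code):

```python
# Hypothetical names throughout; this only illustrates the pattern of
# re-instantiating the inner flax.linen module when a flag changes.
class DummyInnerModule:
    def __init__(self, hidden_size, use_scan=False):
        self.hidden_size = hidden_size
        self.use_scan = use_scan

class DummyPretrainedModel:
    # Mirrors the `module_class` pattern used by Flax models in
    # transformers: the wrapper knows which inner module to build.
    module_class = DummyInnerModule

    def __init__(self, hidden_size, use_scan=False):
        self.hidden_size = hidden_size
        # Module definitions in Flax are stateless; parameters live in a
        # separate pytree, so rebuilding the module does not touch them.
        self.module = self.module_class(hidden_size, use_scan=use_scan)

    def enable_scan(self):
        # Re-initialize the inner module with the new flag.
        self.module = self.module_class(self.hidden_size, use_scan=True)

model = DummyPretrainedModel(hidden_size=16)
model.enable_scan()
```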
We can see along the PR how much logic can reside in `modeling_flax_utils.py` and how much would go into the specific models, e.g. `modeling_flax_wav2vec2.py`. The same API / logic could be used for `gradient_checkpointing`.

Your contribution
Happy to give this implementation a shot with @sanchit-gandhi and @patil-suraj. Would also love to hear feedback from @borisdayma and @marcvanzee about the API.