Flax T5 models (at least) should use scan over layers technique #27418

Closed
colehaus opened this issue Nov 9, 2023 · 3 comments

@colehaus

colehaus commented Nov 9, 2023

Feature request

See the technique described here and here. The essence is to use a JAX scan instead of a Python loop to iterate over layers that have the same structure.
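A minimal sketch of the idea in plain JAX, assuming the per-layer weights have already been stacked along a leading axis. The `layer` function below is a hypothetical stand-in for a T5 block, not Transformers code. `jax.lax.scan` applies the same step function once per slice of the stacked weights and carries the hidden state between iterations; the Python-loop version is included only to check that both produce the same output.

```python
import jax
import jax.numpy as jnp

def layer(h, params):
    # Hypothetical stand-in for one transformer block: dense layer + tanh.
    w, b = params
    return jnp.tanh(h @ w + b)

def apply_loop(h, stacked):
    # Python loop: JAX traces one separate copy of `layer` per iteration.
    ws, bs = stacked
    for i in range(ws.shape[0]):
        h = layer(h, (ws[i], bs[i]))
    return h

def apply_scan(h, stacked):
    # lax.scan: the layer body is traced once and iterated over axis 0.
    def step(carry, params):
        return layer(carry, params), None  # (new carry, no per-step output)
    h, _ = jax.lax.scan(step, h, stacked)
    return h

num_layers, d = 25, 16
kw, kb, kx = jax.random.split(jax.random.PRNGKey(0), 3)
stacked = (
    jax.random.normal(kw, (num_layers, d, d)) / jnp.sqrt(d),  # (L, d, d)
    jax.random.normal(kb, (num_layers, d)) * 0.1,              # (L, d)
)
x = jax.random.normal(kx, (4, d))

out_loop = jax.jit(apply_loop)(x, stacked)
out_scan = jax.jit(apply_scan)(x, stacked)
print(jnp.allclose(out_loop, out_scan, atol=1e-5))
```

In Flax the same pattern is typically expressed with `flax.linen.scan` wrapped around a module, which handles the parameter stacking; the sketch above stacks the weights by hand to keep the moving parts visible.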

Motivation

The scan-over-layers technique lets JAX "see" that every iteration has the same computational structure. This can dramatically reduce compile time, as well as the system memory occupied by the JAX compilation cache (as I understand it, a model with 25 layers compiled with the naive approach ends up with roughly 25 times as much JIT-compiled code, because each layer is lowered to its own, essentially duplicate, copy of the code). My handwritten T5-like model uses ~1/50th of the system memory of the Transformers Flax T5 models of similar size. With the current Flax implementation it's easy to hit system OOM errors once several versions of the model have been compiled for different sequence lengths.
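One rough way to see the effect described above is to compare the size of the traced program (the jaxpr) as the number of layers grows. In this sketch, using hypothetical `loop_model` and `scan_model` functions rather than the Transformers T5 code, the Python loop adds operations for every layer, while the scan version's top-level program does not grow with depth because the layer body is traced once inside a single `scan`. This counts traced operations only; it is a proxy, not a direct measurement of compile time or compilation-cache memory.

```python
import jax
import jax.numpy as jnp

def layer(h, w):
    # Hypothetical stand-in for one block.
    return jnp.tanh(h @ w)

def loop_model(h, ws):
    # Unrolled: every iteration contributes its own operations to the trace.
    for i in range(ws.shape[0]):
        h = layer(h, ws[i])
    return h

def scan_model(h, ws):
    # Scanned: one `scan` operation whose body is traced once.
    h, _ = jax.lax.scan(lambda c, w: (layer(c, w), None), h, ws)
    return h

def traced_size(fn, num_layers, d=8):
    # Number of top-level operations in the traced program (jaxpr).
    h = jnp.zeros((2, d))
    ws = jnp.zeros((num_layers, d, d))
    return len(jax.make_jaxpr(fn)(h, ws).jaxpr.eqns)

for n in (5, 25):
    print(n, traced_size(loop_model, n), traced_size(scan_model, n))
```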

Your contribution

It's possible I could submit a PR for this at some point in the future, but I can't be certain.

@amyeroberts
Collaborator

cc @sanchit-gandhi

@sanchit-gandhi
Contributor

sanchit-gandhi commented Nov 22, 2023

Hey @colehaus! Sorry for the late reply here. We've currently decided not to implement scan for the Flax models in Transformers. You can see a brief reason for this here: #24587 (comment)

Happy to re-open the conversation if you feel strongly about this! There was a WIP PR that shows how this could be done generally for Transformers models here: #18341

But currently I tend to view scan as a specific feature that can be built on top of the Transformers library by advanced users who require it.

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions (bot) closed this as completed Jan 3, 2024
Labels: None yet
Projects: None yet
Development: No branches or pull requests
3 participants