Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add gradient checkpointing to Whisper Flax #22954

Merged

Conversation

versae
Copy link
Contributor

@versae versae commented Apr 24, 2023

It uses flax.linen.remat and follows on PRs #13657 and #17994.

What does this PR do?

Adds gradient_checkpointing to Flax Whisper models.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@sanchit-gandhi @peregilk

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Apr 24, 2023

The documentation is not available anymore as the PR was closed or merged.

@versae versae mentioned this pull request Apr 24, 2023
5 tasks
@versae versae marked this pull request as ready for review April 24, 2023 10:59
Copy link
Contributor

@sanchit-gandhi sanchit-gandhi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice @versae! Just a few minor nits, but otherwise this PR is looking good!

src/transformers/models/whisper/modeling_flax_whisper.py Outdated Show resolved Hide resolved
src/transformers/models/whisper/modeling_flax_whisper.py Outdated Show resolved Hide resolved
init_cache=init_cache,
output_attentions=output_attentions,
deterministic=deterministic,
attention_mask,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to reviewer: remat does not support key-word arguments, hence the need to change to pure arguments

src/transformers/models/whisper/modeling_flax_whisper.py Outdated Show resolved Hide resolved
src/transformers/models/whisper/modeling_flax_whisper.py Outdated Show resolved Hide resolved
src/transformers/models/whisper/modeling_flax_whisper.py Outdated Show resolved Hide resolved
@versae
Copy link
Contributor Author

versae commented Apr 26, 2023

Thanks for the review, @sanchit-gandhi! Should be all good now 😃.

@sanchit-gandhi
Copy link
Contributor

Amazing @versae! Requesting final review before we can get this merged 🤗

Copy link
Collaborator

@sgugger sgugger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your contribution!

@sgugger sgugger merged commit ba0dc54 into huggingface:main Apr 26, 2023
@versae versae deleted the add-gradient-checkpointing-whisper-flax branch April 27, 2023 08:13
@versae
Copy link
Contributor Author

versae commented Apr 27, 2023

Thank you! I learnt a lot 🤓

gojiteji pushed a commit to gojiteji/transformers that referenced this pull request Jun 5, 2023
* Add gradient checkpointing to Whisper Flax

* self.gradient_checkpointing only needed in nn.Module, removing unnecessary comments
novice03 pushed a commit to novice03/transformers that referenced this pull request Jun 23, 2023
* Add gradient checkpointing to Whisper Flax

* self.gradient_checkpointing only needed in nn.Module, removing unnecessary comments
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants