Flax whisper gradient checkpointing #22897
Conversation
It uses `flax.linen.remat` and follows on PRs huggingface#13657 and huggingface#17994
At the moment, the model loads fine, but I then get a weird error when training or generating:

```
/data/venvflax/lib/python3.8/site-packages/transformers/models/whisper/modeling_flax_whisper.py:520 in __call__

    517 │   │   │   residual = hidden_states
    518 │   │   │
    519 │   │   │   hidden_states = self.encoder_attn_layer_norm(hidden_states)
❱   520 │   │   │   hidden_states, cross_attn_weights = self.encoder_attn(
    521 │   │   │   │   hidden_states=hidden_states,
    522 │   │   │   │   key_value_states=encoder_hidden_states,
    523 │   │   │   │   attention_mask=encoder_attention_mask,

/data/venvflax/lib/python3.8/site-packages/transformers/models/whisper/modeling_flax_whisper.py:256 in __call__

    253 │   │   elif self.causal:
    254 │   │   │   attention_mask = causal_mask
    255 │   │   elif attention_mask is not None:
❱   256 │   │   │   attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
    257 │   │
    258 │   │   # During fast autoregressive decoding, we feed one position at a time,
    259 │   │   # and cache the keys and values step by step.

/data/venvflax/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:896 in expand_dims

    893   axis = _ensure_index_tuple(axis)
    894   if hasattr(a, "expand_dims"):
    895   │   return a.expand_dims(axis)
❱   896   return lax.expand_dims(a, axis)

ValueError: axis -3 is out of bounds for array of dimension 2
```

I'm not sure what's happening. So I thought maybe @sanchit-gandhi could provide some feedback :)
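For reference, `jnp.expand_dims` follows NumPy semantics here, so the failure mode can be reproduced with plain NumPy: `axis=(-3, -2)` is valid for a 2-D mask (the negative axes are resolved against the expanded 4-D result shape), but it goes out of bounds if the mask is a 0-D scalar, such as a bare Python bool passed where an array was expected. This is a minimal sketch of that behavior, not the actual Whisper code path:

```python
import numpy as np

# A proper 2-D attention mask: the two new axes are resolved against the
# expanded 4-dimensional result, so -3 and -2 are in range.
mask_2d = np.ones((1, 10))
expanded = np.expand_dims(mask_2d, axis=(-3, -2))
print(expanded.shape)  # (1, 1, 1, 10)

# A scalar "mask" (e.g. a bool passed by mistake instead of None or an
# array): the expanded result would only be 2-D, so axis -3 is out of
# bounds and NumPy raises, mirroring the ValueError in the traceback.
try:
    np.expand_dims(np.asarray(True), axis=(-3, -2))
except Exception as err:
    print("out of bounds:", err)
```

This suggests the error is less about `expand_dims` itself and more about what is being passed in as `attention_mask`.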
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
I've been digging, and the only difference I can find is in the parameters used for the call.

Original model:

```
encoder_attention_mask=None
deterministic=True
output_hidden_states=False
```

This PR's model:

```
encoder_attention_mask=True
deterministic=False
output_hidden_states=True
```

The rest of the params are the same.
All passing! The main issue was a missing
I'll clean up the git history mess, but other than that I think it's finally ready :)
Closing in favor of #22954. |
It uses `flax.linen.remat` and follows on PRs #13657 and #17994.

What does this PR do?
Adds gradient_checkpointing to Flax Whisper models.
Who can review?
@sanchit-gandhi @peregilk