
Flax Whisper gradient checkpointing #22897

Closed

Conversation

@versae (Contributor) commented Apr 20, 2023

It uses flax.linen.remat and follows up on PRs #13657 and #17994.

What does this PR do?

Adds gradient_checkpointing to Flax Whisper models.
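
For context, here is a minimal, self-contained sketch of the technique being applied (toy modules, not the actual FlaxWhisperEncoderLayer / FlaxWhisperDecoderLayer classes): wrapping a layer class in flax.linen.remat makes JAX recompute that layer's activations during the backward pass instead of storing them, trading extra compute for lower memory.

```python
# A minimal sketch of gradient checkpointing with flax.linen.remat.
# ToyLayer / ToyEncoder are illustrative stand-ins, not the Whisper modules.
import jax
import jax.numpy as jnp
import flax.linen as nn


class ToyLayer(nn.Module):
    hidden_size: int = 64

    @nn.compact
    def __call__(self, hidden_states):
        hidden_states = nn.Dense(4 * self.hidden_size)(hidden_states)
        hidden_states = nn.gelu(hidden_states)
        return nn.Dense(self.hidden_size)(hidden_states)


class ToyEncoder(nn.Module):
    num_layers: int = 4
    gradient_checkpointing: bool = False

    def setup(self):
        # nn.remat (alias of nn.checkpoint) rematerializes the wrapped layer's
        # activations in the backward pass instead of keeping them in memory.
        # The real model layers take extra Python-bool arguments (deterministic,
        # output_attentions, ...) that have to be marked static via static_argnums
        # and passed positionally; that is omitted here to keep the toy small.
        layer_cls = nn.remat(ToyLayer) if self.gradient_checkpointing else ToyLayer
        self.layers = [layer_cls(name=str(i)) for i in range(self.num_layers)]

    def __call__(self, hidden_states):
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        return hidden_states


x = jnp.ones((2, 16, 64))
model = ToyEncoder(gradient_checkpointing=True)
params = model.init(jax.random.PRNGKey(0), x)
grads = jax.grad(lambda p: model.apply(p, x).sum())(params)
```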

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@sanchit-gandhi @peregilk

@versae (Contributor) commented Apr 20, 2023

At the moment, the model loads fine but I then get a weird error when training or generating:

/data/venvflax/lib/python3.8/site-packages/transformers/models/whisper/modeling_flax_whisper.py:520 in __call__

   517             residual = hidden_states
   518
   519             hidden_states = self.encoder_attn_layer_norm(hidden_states)
 ❱ 520             hidden_states, cross_attn_weights = self.encoder_attn(
   521                 hidden_states=hidden_states,
   522                 key_value_states=encoder_hidden_states,
   523                 attention_mask=encoder_attention_mask,

/data/venvflax/lib/python3.8/site-packages/transformers/models/whisper/modeling_flax_whisper.py:256 in __call__

   253             elif self.causal:
   254                 attention_mask = causal_mask
   255             elif attention_mask is not None:
 ❱ 256                 attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
   257
   258             # During fast autoregressive decoding, we feed one position at a time,
   259             # and cache the keys and values step by step.

/data/venvflax/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:896 in expand_dims

   893       axis = _ensure_index_tuple(axis)
   894       if hasattr(a, "expand_dims"):
   895           return a.expand_dims(axis)
 ❱ 896       return lax.expand_dims(a, axis)

ValueError: axis -3 is out of bounds for array of dimension 2

I'm not sure what's happening, so I thought maybe @sanchit-gandhi could provide some feedback :)

@HuggingFaceDocBuilderDev commented Apr 20, 2023

The documentation is not available anymore as the PR was closed or merged.

@versae (Contributor) commented Apr 22, 2023

I've been digging, and the only difference I can find is that for some reason the parameters passed to FlaxWhisperDecoderLayerCollection.__call__() from FlaxWhisperDecoder.__call__() differ between this PR's model and the original implementation. I tested this with a tiny model:

Original model:

encoder_attention_mask=None
deterministic=True
output_hidden_states=False

This PR's model:

encoder_attention_mask=True
deterministic=False
output_hidden_states=True

The rest of the parameters are the same: hidden_states, attention_mask, encoder_hidden_states, init_cache, output_attentions, and return_dict. The problem is that while the first decoder layer loads fine, the second one gets an attention_mask value of True for some reason, making any tensor operation on it fail.
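
As a sanity check on that diagnosis: a bare boolean where a 2-D (batch, seq_len) mask is expected is enough to reproduce the exact error in the traceback above, since a 0-dimensional value has no -3/-2 axes to expand into. A quick illustration (not from the PR itself):

```python
import jax.numpy as jnp

# A proper 2-D attention mask of shape (batch, seq_len) expands fine
# to the (batch, 1, 1, seq_len) broadcast shape used by the attention layer.
mask = jnp.ones((2, 7))
print(jnp.expand_dims(mask, axis=(-3, -2)).shape)  # (2, 1, 1, 7)

# A stray boolean is a 0-d array: the expanded result would only have 2
# dimensions, so axis -3 is out of bounds -- the same ValueError as above.
jnp.expand_dims(jnp.asarray(True), axis=(-3, -2))
```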

@versae (Contributor) commented Apr 22, 2023

All tests passing! The main issue was a missing self.gradient_checkpointing in FlaxWhisperPreTrainedModel.__init__(); it took me forever to track down.

I'll clean up the git history mess, but other than that I think it's finally ready :)
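
For anyone running into the same class of bug, the gist of the fix above (a hypothetical sketch with made-up Toy* names, not the PR's actual diff) is that the high-level wrapper has to record the gradient_checkpointing flag and forward it every time it builds the inner nn.Module, so the checkpointed and non-checkpointed code paths never get mixed:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class ToyModule(nn.Module):
    gradient_checkpointing: bool = False

    @nn.compact
    def __call__(self, x):
        dense_cls = nn.remat(nn.Dense) if self.gradient_checkpointing else nn.Dense
        return dense_cls(features=8)(x)


class ToyPreTrainedModel:
    """Hypothetical wrapper in the spirit of FlaxWhisperPreTrainedModel."""

    def __init__(self, gradient_checkpointing: bool = False):
        # The crucial part: keep the flag on the wrapper itself...
        self.gradient_checkpointing = gradient_checkpointing
        # ...and forward it whenever the inner module is (re)built; if it is
        # forgotten, some instantiations silently fall back to the default.
        self.module = ToyModule(gradient_checkpointing=self.gradient_checkpointing)

    def enable_gradient_checkpointing(self):
        self.gradient_checkpointing = True
        self.module = ToyModule(gradient_checkpointing=self.gradient_checkpointing)


model = ToyPreTrainedModel(gradient_checkpointing=True)
params = model.module.init(jax.random.PRNGKey(0), jnp.ones((2, 4)))
```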

@versae versae marked this pull request as ready for review April 22, 2023 18:10
@versae (Contributor) commented Apr 24, 2023

Closing in favor of #22954.

@versae versae closed this Apr 24, 2023
@versae versae deleted the flax-whisper-gradient-checkpointing branch April 27, 2023 08:13