
Output None as attention when layer is skipped #30597

Merged

Conversation

jonghwanhyeon
Contributor

@jonghwanhyeon jonghwanhyeon commented May 1, 2024

What does this PR do?

Fixes an index-out-of-range error caused by skipped layers when output_attentions is True in training mode on the WavLM model.
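
For context, a minimal sketch of the pattern involved (simplified, with hypothetical names; not the exact transformers source). When LayerDrop skips a layer during training, its outputs are replaced by a placeholder tuple; before this PR the placeholder had no slot for the attention weights, so collecting attentions indexed past the end of the tuple. Emitting None as the attention of a skipped layer keeps one entry per layer:

import torch

# Sketch only: `layers` is any sequence of modules returning
# (hidden_states, position_bias, attentions)-style tuples.
def encoder_pass(layers, hidden_states, layerdrop, training, output_attentions):
    all_self_attentions = () if output_attentions else None
    for layer in layers:
        skip_the_layer = training and torch.rand([]).item() < layerdrop
        if skip_the_layer:
            # Pad the placeholder so the attention slot exists (and is None)
            layer_outputs = (None, None, None)
        else:
            layer_outputs = layer(hidden_states, output_attentions=output_attentions)
            hidden_states = layer_outputs[0]
        if output_attentions:
            all_self_attentions = all_self_attentions + (layer_outputs[-1],)
    return hidden_states, all_self_attentions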

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sanchit-gandhi @patrickvonplaten

@amyeroberts
Collaborator

Hi @jonghwanhyeon, thanks for adding this! Could you add a test which passes with this fix and fails on main?

cc @kamilakesbi for first review

@jonghwanhyeon force-pushed the fix-wavlm-skip-layer-out-of-index branch from e0f9396 to 4cd97cc on May 1, 2024 at 15:04
@jonghwanhyeon
Contributor Author

Hi @jonghwanhyeon, thanks for adding this! Could you add a test which passes with this fix and fails on main?

cc @kamilakesbi for first review

Added a test case
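
For reference, a minimal sketch of what such a test can look like (using the public WavLM API; the actual test added in this PR may differ):

import torch
from transformers import WavLMConfig, WavLMModel

def test_output_attentions_when_all_layers_are_skipped():
    model = WavLMModel(WavLMConfig())
    model.config.layerdrop = 1.0  # force every encoder layer to be skipped
    model.train()  # LayerDrop is only active in training mode
    outputs = model(torch.randn(1, 8_000), output_attentions=True)
    # On main this call raised an index error; with the fix, skipped layers
    # contribute None as their attention entry, one per layer.
    assert len(outputs.attentions) == model.config.num_hidden_layers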

@kamilakesbi
Contributor

kamilakesbi commented May 2, 2024

Hi @jonghwanhyeon,

Thank you very much for working on this!

Could you provide your working environment (by running transformers-cli env) as well as a minimal reproducer of the bug?

I've run the test you suggested both with layer_outputs = (None, None, None) and with layer_outputs = (None, None), and the test seems to pass in both cases (I'm using transformers version 4.41.0.dev0).

Thanks!

cc @sanchit-gandhi

@jonghwanhyeon force-pushed the fix-wavlm-skip-layer-out-of-index branch from 2bface9 to 853a95a on May 2, 2024 at 15:42
@jonghwanhyeon force-pushed the fix-wavlm-skip-layer-out-of-index branch from 853a95a to 4e9bafd on May 2, 2024 at 15:44
@jonghwanhyeon
Contributor Author

Oh, that is because the default value of model.config.layerdrop is so low that layers are rarely skipped. Please check the following code and the updated commits:

import torch

from transformers import WavLMConfig, WavLMModel

model = WavLMModel(WavLMConfig())
model.config.layerdrop = 1.0  # force every encoder layer to be dropped
model.train()  # LayerDrop is only active in training mode
outputs = model(torch.randn(1, 8_000), output_attentions=True)  # raises an index error without the fix

For your information, my environment is as below (with the (None, None, None) modification applied):

  • transformers version: 4.41.0.dev0
  • Platform: Linux-5.4.0-174-generic-x86_64-with-glibc2.35
  • Python version: 3.11.9
  • Huggingface_hub version: 0.22.2
  • Safetensors version: 0.4.3
  • Accelerate version: 0.29.2
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.2.2+cu118 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

@kamilakesbi
Copy link
Contributor

Ok, thanks a lot @jonghwanhyeon!

LGTM - I can reproduce the bug and your code seems to fix it. I've tried your test and it passes now.

CC @amyeroberts for a final check!

Collaborator

@amyeroberts amyeroberts left a comment


Thanks for fixing and adding a test!

@amyeroberts amyeroberts merged commit 4c94093 into huggingface:main May 2, 2024
17 checks passed
itazap pushed a commit that referenced this pull request May 14, 2024
* Output `None` as attention when layer is skipped

* Add test for output_attentions