
Whisper Prompting max_new_tokens #25422

Closed

Helene-Maxcici opened this issue Aug 9, 2023 · 6 comments · Fixed by #26164
Comments

Helene-Maxcici commented Aug 9, 2023

System Info

  • transformers version: 4.31.0
  • Platform: Linux-5.15.109+-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.16.4
  • Safetensors version: 0.3.2
  • Accelerate version: not installed
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.0.1+cu118 (False)
  • Tensorflow version (GPU?): 2.12.0 (False)
  • Flax version (CPU?/GPU?/TPU?): 0.7.1 (cpu)
  • Jax version: 0.4.14
  • JaxLib version: 0.4.14
  • Using GPU in script?: no
  • Using distributed or parallel set-up in script?: no

Who can help?

@sanchit-gandhi

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Bug Related

We keep model.config.max_length=448. The error happens when:

  1. len(prompt_ids) + max_new_tokens > model.config.max_length + 1
  2. We set max_new_tokens in model.generate()
  3. The length of the generated new tokens reaches its maximum. This mainly occurs when Whisper fails to predict the eos token and starts repeating some sequence of tokens.

from transformers import (WhisperFeatureExtractor, WhisperProcessor, WhisperForConditionalGeneration)
from datasets import load_dataset

# Load dataset
fleurs_fr = load_dataset("google/fleurs", "fr_fr", split="test")

# Load Processor + Model
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# Choose a sample that causes repetition
i = 512
input_speech = fleurs_fr[i]["audio"]["array"]
sr = fleurs_fr[i]["audio"]["sampling_rate"]

# Create big enough prompt text
# It should be sliced inside generate anyway
prompt_text = " bien," * 113
prompt_ids = processor.get_prompt_ids(prompt_text)

# Generate
input_features = processor(input_speech, return_tensors="pt",
                            sampling_rate=16e3).input_features

output_with_prompt = model.generate(input_features,
                                    language="fr",
                                    task="transcribe",
                                    prompt_ids=prompt_ids,
                                    max_new_tokens=224)

Output:

IndexError                                Traceback (most recent call last)
<ipython-input-4-3420d576291f> in <cell line: 4>()
      2                             sampling_rate=16e3).input_features
      3 
----> 4 output_with_prompt = model.generate(input_features,
      5                                     language="fr",
      6                                     task="transcribe",

3 frames
/usr/local/lib/python3.10/dist-packages/transformers/models/whisper/modeling_whisper.py in generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, return_timestamps, task, language, is_multilingual, prompt_ids, return_token_timestamps, **kwargs)
   1747                 )
   1748 
-> 1749         outputs = super().generate(
   1750             inputs,
   1751             generation_config,

/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py in decorate_context(*args, **kwargs)
    113     def decorate_context(*args, **kwargs):
    114         with ctx_factory():
--> 115             return func(*args, **kwargs)
    116 
    117     return decorate_context

/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py in generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, **kwargs)
   1536 
   1537             # 11. run greedy search
-> 1538             return self.greedy_search(
   1539                 input_ids,
   1540                 logits_processor=logits_processor,

/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py in greedy_search(self, input_ids, logits_processor, stopping_criteria, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, streamer, **model_kwargs)
   2370                 continue  # don't waste resources running the code we don't need
   2371 
-> 2372             next_token_logits = outputs.logits[:, -1, :]
   2373 
   2374             # pre-process distribution

IndexError: index -1 is out of bounds for dimension 1 with size 0

The bug might be caused by the absence of any check on max_new_tokens inside the generate() function, and it might be a general generation bug rather than one specific to prompting.
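
For intuition on where the size-0 dimension comes from, here is a minimal sketch of what I suspect happens (the shapes are illustrative, not taken from the model): once the decoder runs past position 448, the slice of its learned positional embedding is empty, and adding it to the single new token embedding broadcasts the sequence dimension down to zero, which is what outputs.logits[:, -1, :] then trips over.

import torch

d_model = 384  # illustrative hidden size
inputs_embeds = torch.zeros(1, 1, d_model)      # embedding of the one new token
positions = torch.zeros(448, d_model)[448:449]  # out-of-range slice -> shape (0, d_model)
hidden_states = inputs_embeds + positions       # broadcasts to shape (1, 0, d_model)
print(hidden_states.shape)                      # torch.Size([1, 0, 384])
# The logits inherit this zero-length sequence dimension, hence the IndexError above.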

Note

Also, as I was reading the code I noticed this line:
text_prompt_ids = text_prompt_ids[-self.config.max_length // 2 - 1 :]

It slices the text prompt ids and keeps (self.config.max_length // 2 + 1) tokens instead of the (self.config.max_length // 2 - 1) taken in the original Whisper code here.
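
The difference is easy to see with plain Python (nothing Whisper-specific, just operator precedence; the numbers below assume max_length = 448):

max_length = 448
tokens = list(range(1000))  # stand-in for a long tokenized prompt

# Current Transformers slicing: the unary minus binds to max_length first, so the slice
# start is (-448) // 2 - 1 = -225, i.e. the last max_length // 2 + 1 = 225 tokens are kept.
print(len(tokens[-max_length // 2 - 1 :]))      # 225

# OpenAI's reference slicing keeps max_length // 2 - 1 = 223 tokens.
print(len(tokens[-(max_length // 2 - 1) :]))    # 223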

Expected behavior

  • A clear warning or error about exceeding model.config.max_length.
  • Being able to set max_new_tokens=224 ( = max_length // 2) during prompting.

connor-henderson (Contributor) commented

Hi @Helene-Maxcici! Thanks for writing this issue; there's definitely an out-of-bounds issue here.

Appreciate you catching the precedence issue where the slicing doesn't quite match OpenAI's; we should change that in the fix PR so it's slicing one less than half the max_length instead of one more. Ultimately it's not at the root of this problem, since the prompt isn't competing for space with anything else (like a prefix): we could just decrement the max_new_tokens param by 1 and this script would run, or alternatively, after updating the slicing to match OpenAI's, we could still increase max_new_tokens by 2 to 226 and hit the same error.

Instead, I think the issue is that the length stopping criteria warning here doesn't capture the out-of-bounds issue for this model, since it looks here for max_position_embeddings in the generation_config, but the value is named max_target_positions for Whisper. Not sure whether Hugging Face would prefer that we rename the value in Whisper's generation config to max_position_embeddings, add a second config attribute check for max_target_positions to determine what to pass to the stopping criteria, or something else, but @sanchit-gandhi could say more.
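
To illustrate the naming mismatch, here's a small check against the config (a sketch: it assumes the generic length check effectively does a getattr lookup, while the exact logic lives in generation/utils.py):

from transformers import WhisperConfig

config = WhisperConfig()  # library defaults; max_target_positions defaults to 448

# A generic lookup for `max_position_embeddings` finds nothing on Whisper's config,
# while the same limit is exposed under `max_target_positions`.
print(getattr(config, "max_position_embeddings", None))  # None
print(config.max_target_positions)                       # 448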

M-Ali-ML commented

I'm not sure if this will help or not, but I faced the same error running:

generated_tokens = (
    model.generate(
        input_features=batch["input_features"].to("cuda"),
        decoder_input_ids=batch["labels"][:, :4].to("cuda"),
        max_new_tokens=448,
    )
)

However, if I use a PEFT model, as in:

model = WhisperForConditionalGeneration.from_pretrained(
    peft_config.base_model_name_or_path, device_map="auto", load_in_8bit=True
)
model = PeftModel.from_pretrained(model, evaluate_model)

I don't face this issue if I set max_new_tokens to 224 in either case (PEFT or without).

sanchit-gandhi (Contributor) commented Aug 11, 2023

Thanks for the excellent issue description @Helene-Maxcici and for the astute remarks @connor-henderson! IMO each of the findings deserves a PR of its own:

  • For the max length issue, I think the best thing we can do is throw a warning in the .generate method for Whisper when the model's max length is exceeded. This can probably be placed after we determine the correct max_length / max_new_tokens with prompting:
    # If the user passes `max_new_tokens`, increase its number to account for the prompt
    I would be against changing the config/generation_config for the model, since this is very difficult to do without breaking changes. Since Whisper is quite unique in its approach to prompting, I think we're safe to just add a check in the Whisper model's .generate method, rather than the more generic one (cc @gante). A rough sketch of such a check is shown after this list.
  • Agree with your spot and @connor-henderson's remarks on the slicing difference: this would be a quick PR to fix!
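
A rough sketch of the kind of check described in the first bullet (the helper name and message below are illustrative, not the actual implementation a fix PR would add):

import warnings

def warn_if_prompt_plus_new_tokens_too_long(num_prompt_tokens, max_new_tokens, max_target_positions=448):
    # Hypothetical helper: after `max_new_tokens` has been increased to account for
    # the prompt, compare the total requested length against the decoder's hard limit.
    total_length = num_prompt_tokens + max_new_tokens
    if total_length > max_target_positions:
        warnings.warn(
            f"The prompt ({num_prompt_tokens} tokens) plus `max_new_tokens` ({max_new_tokens}) "
            f"adds up to {total_length}, which exceeds the model's maximum target length "
            f"of {max_target_positions}. Generation will run out of decoder positions; "
            "reduce `max_new_tokens` or shorten the prompt."
        )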

Would you like to open a PR for one or both of these issues @Helene-Maxcici? Happy to help guide the integration process, or answer any questions / queries along the way!

Helene-Maxcici (Author) commented

Hi @sanchit-gandhi , thank you for your response! I would be happy to open a PR for each.

gante (Member) commented Aug 16, 2023

Thank you for opening a well-explained issue, @Helene-Maxcici! 🤗

Since this issue is particular to Whisper, which modifies max_new_tokens in its generate function, I agree -- we should add a warning in Whisper's generate (cc @sanchit-gandhi)

sanchit-gandhi (Contributor) commented

The slicing bug was fixed by @connor-henderson in #23724. The check for exceeding the max length of the model should be fixed by #26164.
