Two bugs in whisper generate with prompt_ids regarding generation length #23723

Closed
2 of 4 tasks
connor-henderson opened this issue May 23, 2023 · 1 comment · Fixed by #23724

connor-henderson commented May 23, 2023

System Info

  • transformers version: 4.30.0.dev0
  • Platform: macOS-13.0-arm64-arm-64bit
  • Python version: 3.9.16
  • Huggingface_hub version: 0.12.0
  • Safetensors version: 0.2.8
  • PyTorch version (GPU?): 1.13.1 (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): 0.5.3 (cpu)
  • Jax version: 0.3.6
  • JaxLib version: 0.3.5
  • Using GPU in script?: No
  • Using distributed or parallel set-up in script?: no

Who can help?

@sanchit-gandhi

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

# -*- coding: utf-8 -*-
# the above line is for the `prompt_for_error`

from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny", language="English", task="transcribe")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# stream the test set and pull out the clip used for this reproduction
dataset = load_dataset("librispeech_asr", "all", split="test.other", streaming=True)
for clip in dataset:
    if clip["id"] == "7902-96592-0026":
        break

input_features = processor(clip['audio']['array'], sampling_rate=clip['audio']['sampling_rate'], return_tensors="pt").input_features


# Example of generation not being limited to max_new_tokens when the prompt_ids length is too large
long_prompt = 5 * "Bubalina is a subtribe of wild cattle that includes the various species of true buffalo. Species include the African buffalo, the anoas, and the wild water buffalo (including the domesticated variant, the water buffalo). Buffaloes can be found naturally in sub-Saharan Africa, South Asia and Southeast Asia, and domestic and feral populations have been introduced to Europe, the Americas, and Australia. In addition to the living species, bubalinans have an extensive fossil record where remains have been found in much of Afro-Eurasia."
prompt_ids = processor.get_prompt_ids(long_prompt)
pred_ids = model.generate(input_features, language="english", task="transcribe", max_new_tokens=10, prompt_ids=prompt_ids)
decoded = processor.decode(pred_ids[0], skip_special_tokens=True)
new_tokens = processor.tokenizer(decoded, add_special_tokens=False)["input_ids"]
print(len(new_tokens)) # should be <=10, is actually 25

# Example of erroring
prompt_for_error = "some text rich in domain specific vocabulary lives here - I wish you would believe me that I am in as great trouble about it as you are - then as archiestered in the dark literally a gas for the astonishment here at the faint and wrestling once more and again all with silent - I'll soon show them that I am not going to be played with - to do this he must scheme lie head till morning then make for the nearest point it's signal for help I also boats crew were already searching for him how to escape - no that was too bad you cannot do that - but there was no chance for his body there the head would not go first - shall I come to father? no - what a queer dream he thought to himself - and I am hungry too 今晚會是我 再回家吧 - oh those bars he meant 雷 exclaimed and he was  advancing towards them, and just as he drew near there was a wrestling noise nd to the window a couple of hands seized the bars there was a scratching of 布側 against stonework and ram スペース 敬射的 金融 敬射的 金融 敬射的 金融 敬射的 金融 敬射的 金融 敬射的 金融 � - I saw you last night and wondered whose boy he was - I think I don't know you Mr. Orphazard "
prompt_ids = processor.get_prompt_ids(prompt_for_error)
pred_ids = model.generate(input_features, language="english", task="transcribe", max_new_tokens=128, prompt_ids=prompt_ids)

Expected behavior

Two issues arising when using whisper generate with prompt_ids:

  1. max_new_tokens doesn't properly limit the number of newly generated tokens when the provided prompt_ids are too long
  2. An unclear error is thrown with certain long prompt + audio combinations; the exact mechanism here is less clear to me right now (thank you @dgram0 for raising this in feat: Whisper prompting #22496 (comment))

I believe they share the same root cause: when prompt_ids are provided, max_new_tokens is recalculated using the length of the text_prompt_ids, but before those are trimmed to fit within the context window. I'm not yet certain how 2. is caused or fixed by this, but I think it's because with a confusing prompt + audio combination the model doesn't know when to stop and needs max_new_tokens to be set properly, otherwise it runs into an index error. I can confirm that fixing the max_new_tokens recalculation fixes both issues in the example script.
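To make the ordering concrete, here is a rough sketch of what I believe happens, written as standalone helper functions rather than the actual transformers code (the names text_prompt_ids / max_target_positions follow the Whisper generate code, but the truncation rule shown is approximate):

# Hypothetical illustration of the suspected ordering bug, not the library code.
def prepare_prompt_buggy(text_prompt_ids, max_new_tokens, max_target_positions):
    # max_new_tokens is bumped by the *untrimmed* prompt length...
    max_new_tokens = max_new_tokens + len(text_prompt_ids)
    # ...and only afterwards is the prompt trimmed to fit the context,
    # so the allowance over-counts and the intended limit is never reached.
    text_prompt_ids = text_prompt_ids[-(max_target_positions // 2 - 1):]
    return text_prompt_ids, max_new_tokens

def prepare_prompt_fixed(text_prompt_ids, max_new_tokens, max_target_positions):
    # Trim first, then recalculate from the prompt length that is actually used.
    text_prompt_ids = text_prompt_ids[-(max_target_positions // 2 - 1):]
    max_new_tokens = max_new_tokens + len(text_prompt_ids)
    return text_prompt_ids, max_new_tokens

With the buggy ordering the extra allowance counts tokens that are discarded at trimming time, which lines up with the 25 (rather than <= 10) new tokens decoded in the first example above.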

@sanchit-gandhi

Thanks for the detailed write-up and reproducible code snippet @connor-henderson! Cool that you've found a fix for both already 🙌 By the sounds of it, I agree that the PR should fix both issues by moving the token slicing logic to before the recalculation of max_new_tokens.
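For reference, once that reordering is in place, the first reproduction snippet above should pass a check along these lines (same processor, model, input_features, and long-prompt prompt_ids as in the reproduction; the assertion is just a hypothetical sanity check, not part of the PR):

pred_ids = model.generate(input_features, language="english", task="transcribe",
                          max_new_tokens=10, prompt_ids=prompt_ids)
decoded = processor.decode(pred_ids[0], skip_special_tokens=True)
new_tokens = processor.tokenizer(decoded, add_special_tokens=False)["input_ids"]
assert len(new_tokens) <= 10  # previously this printed 25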
