
Whisper: fix prompted max length #24666

Merged: 5 commits merged into huggingface:main on Jul 7, 2023
Conversation

@gante (Member) commented on Jul 5, 2023

What does this PR do?

Fixes #24600

#23724 added the ability to guide generation with Whisper through prompt_ids. It was increasing the generation length by the length of the prompt -- these tokens are hardcoded, and thus "not generated".

However, in the default case, we were already setting the generation length to the maximum allowed model length (see the model config). This increment was forcing us to go beyond the maximum length and, because the model uses an nn.Embedding for the position embeddings, indexing exceptions started popping up on long audio inputs :D

This PR modifies the length extension to what I believe was the author's original goal: only increment the length if max_new_tokens is passed. By default, this argument is not set; when it is set, it should correspond to the "new" (i.e. non-prompt) generated tokens.
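
To make the failure mode concrete, here is a rough illustrative sketch (not actual transformers code; the numbers assume a decoder position table of 448 entries, as in the released Whisper checkpoints):

# Illustrative sketch of the failure mode (not actual transformers code).
# Assume the decoder position table has 448 entries
# (config.max_target_positions == 448 for the released Whisper checkpoints).
max_length = 448          # default generation budget == model maximum
prompt_len = 32           # length of the tokens passed via `prompt_ids`

# Old behaviour: the prompt length was always added to the budget, so long
# audio could push decoding to positions 448..479, past the nn.Embedding
# table, raising an IndexError.
old_budget = max_length + prompt_len      # 480

# New behaviour (this PR): the budget is only extended when the user
# explicitly passes `max_new_tokens`; the default `max_length` stays at 448.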

@HuggingFaceDocBuilderDev commented on Jul 5, 2023

The documentation is not available anymore as the PR was closed or merged.

@amyeroberts (Collaborator) left a comment

Thanks for fixing this!

Just a comment about applying checks or an upper bound on what we can set as max_new_tokens.

Comment on lines +1718 to +1720
# If the user passes `max_new_tokens`, increase its number to account for the prompt
if kwargs.get("max_new_tokens", None) is not None:
kwargs["max_new_tokens"] += len(text_prompt_ids)
@amyeroberts (Collaborator) commented on the diff:

If I've understood correctly, the previous issue arose in part because max_new_tokens is not set by default, so specified_max_length fell back to max_length (the model's maximum length).

However, the issue surfaced because this resulted in the budget being set to max_length + len(text_prompt_ids), which goes out of bounds; and that could still happen here, since a user could set max_new_tokens to max_length.

Could we either:

  • Place an upper bound on the value of max_new_tokens
  • Or raise a warning if it's going out of bounds?

e.g.:

            if kwargs.get("max_new_tokens", None) is not None:
                max_new_tokens_w_prompt = kwargs.get("max_new_tokens") + len(text_prompt_ids)
                kwargs["max_new_tokens"] = min(max_length, max_new_tokens_w_prompt)

@gante (Member, Author) replied:

Ideally, that check should be done at the model level -- some models accept going beyond their maximum length (e.g. rotary and ALiBi position embeddings), so it makes more sense to place that check in the model, and not in generate.

ATM, we don't do any check of any form, regardless of the model. Should we open a PR to add an informative exception on models with restrictive position embeddings (like Whisper)?

@sanchit-gandhi (Contributor) commented:

A warning message when you set max_length > max_position_embeddings would be pretty useful for models like Whisper that have a fixed max length (note that it can be a warning rather than an error, since we might predict the EOS token before we hit max_position_embeddings tokens, so the generation could still be valid). Otherwise these models fail with a very cryptic error.
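
For illustration, a minimal sketch of what such a check could look like (a hypothetical helper, not the code that was ultimately merged into generate):

import warnings

def warn_if_exceeds_max_length(current_length: int, max_position_embeddings: int) -> None:
    # Hypothetical sketch: emit a warning rather than raise, since generation
    # may still stop on EOS before reaching `max_position_embeddings` tokens
    # and therefore remain perfectly valid.
    if current_length > max_position_embeddings:
        warnings.warn(
            "The current text generation call will exceed the model's predefined "
            f"maximum length ({max_position_embeddings}). Depending on the model, you "
            "may observe exceptions, performance degradation, or nothing at all."
        )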

@gante (Member, Author) replied:

@amyeroberts @sanchit-gandhi alright, we're aligned in terms of the need for additional checks and messaging 👍

I'm not fully happy with emitting a warning as soon as current_length > max_position_embeddings, as some models can safely cross this limit, but the alternatives (that I've envisioned) have a high engineering cost -- I'm going to add a warning and I'll tag you again when it's included :)

@gante (Member, Author) commented on Jul 7, 2023

@amyeroberts @sanchit-gandhi

After the latest changes, a warning is emitted when we cross config.max_position_embeddings for the first time.

For instance, if you now run

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2").to("cuda")

inputs = tokenizer(["The quick brown"], return_tensors="pt").to("cuda")
# distilgpt2 has a maximum length of 1024
gen_out = model.generate(**inputs, do_sample=True, eos_token_id=-1, max_length=1025)

You'll see

This is a friendly reminder - the current text generation call will exceed the model's predefined maximum length (1024). Depending on the model, you may observe exceptions, performance degradation, or nothing at all.

(And if you set max_length=1026, you'll see the warning right before the exceptions. This is because we can technically generate config.max_position_embeddings + 1 tokens even with restrictive position embeddings, although we shouldn't!)

@amyeroberts (Collaborator) left a comment

Thanks for fixing and iterating on the warning! 🤗

@gante merged commit f614b6e into huggingface:main on Jul 7, 2023
@gante deleted the prompted_whisper branch on July 7, 2023 at 17:11
@sanchit-gandhi (Contributor) left a comment

Very nice - thanks @gante


Successfully merging this pull request may close these issues.

IndexError: index -1 is out of bounds for dimension 1 with size 0
4 participants