Whisper: fix prompted max length #24666
Conversation
Thanks for fixing this!
Just a comment on applying checks or an upper bound to what we can set as `max_new_tokens`.
    # If the user passes `max_new_tokens`, increase its number to account for the prompt
    if kwargs.get("max_new_tokens", None) is not None:
        kwargs["max_new_tokens"] += len(text_prompt_ids)
If I've understood correctly, the previous issue was in part because `max_new_tokens` is not set by default, and therefore `specified_max_length` defaulted to `max_length` (the maximum length of the model). However, the issue was found because it resulted in `max_new_tokens` being set to `max_length + len(text_prompt_ids)`, going out of bounds. This could still happen, since we could set `max_new_tokens` to `max_length`.
Could we either:
- Place an upper bound on the value of `max_new_tokens`
- Or raise a warning if it's going out of bounds?
e.g.:

    if kwargs.get("max_new_tokens", None) is not None:
        max_new_tokens_w_prompt = kwargs.get("max_new_tokens") + len(text_prompt_ids)
        kwargs["max_new_tokens"] = min(max_length, max_new_tokens_w_prompt)
Ideally, that check should be done at the model level -- some models accept going beyond their maximum length (e.g. rotary and ALiBi position embeddings), so it makes more sense to place that check in the model, and not in `generate`.
At the moment, we don't do any check of any form, regardless of the model. Should we open a PR to add an informative exception on models with restrictive position embeddings (like Whisper)?
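As an illustration of what such a model-level guard could look like (purely a sketch, not existing transformers code; the function name is made up), a model with a fixed-size learned position table could validate the requested length up front:

```python
# Sketch of a model-side check for architectures with a fixed-size
# (learned) position embedding; names here are illustrative only.
def check_generation_length(requested_max_length: int, max_position_embeddings: int):
    if requested_max_length > max_position_embeddings:
        raise ValueError(
            f"`max_length` ({requested_max_length}) exceeds the model's maximum "
            f"number of positions ({max_position_embeddings}); generation would "
            "index past the position embedding table."
        )

try:
    check_generation_length(449, 448)  # a Whisper-like limit of 448 positions
except ValueError as err:
    print(err)
```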
A warning message when you set `max_length > max_position_embeddings` would be pretty useful for models like Whisper that have a fixed max length (note that it can be a warning message, since we might predict the EOS before we hit `max_position_embeddings` tokens, so the generation could still be valid). Otherwise they fail silently with a very cryptic error.
@amyeroberts @sanchit-gandhi alright, we're aligned in terms of the need for additional checks and messaging 👍

I'm not fully happy with emitting a warning as soon as we cross `current_length > max_position_embeddings`, as some models can safely cross this limit, but the alternatives (that I've envisioned) have a high engineering cost -- I'm going to add a warning and I'll tag you again when it's included :)
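For illustration, roughly the shape of the soft check being discussed (a sketch only, not the code that actually landed in `generate`; the function name is made up):

```python
import warnings

def maybe_warn_on_length(max_length, max_position_embeddings):
    # Soft check: warn instead of raising, since EOS may be generated before
    # the limit and some position embeddings (rotary/ALiBi) extrapolate safely.
    if max_position_embeddings is not None and max_length > max_position_embeddings:
        warnings.warn(
            f"`max_length` ({max_length}) is larger than the model's maximum number "
            f"of positions ({max_position_embeddings}); generation may fail or "
            "degrade past that point.",
            UserWarning,
        )

maybe_warn_on_length(1025, 1024)  # e.g. a distilgpt2-like limit of 1024
```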
After the latest changes, a warning is emitted when we cross `max_position_embeddings`. For instance, if you now run

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
    model = AutoModelForCausalLM.from_pretrained("distilgpt2").to("cuda")

    inputs = tokenizer(["The quick brown"], return_tensors="pt").to("cuda")

    # distilgpt2 has a maximum length of 1024
    gen_out = model.generate(**inputs, do_sample=True, eos_token_id=-1, max_length=1025)

you'll see the warning.

(And if you set …)
Thanks for fixing and iterating on the warning! 🤗
Very nice - thanks @gante
What does this PR do?
Fixes #24600
#23724 added the ability to guide generation with Whisper through `prompt_ids`. It was increasing the generation length by the length of the prompt -- these tokens were being hardcoded, and thus "not generated". However, in the default case, we were already setting the generation length to the maximum allowed model length (see model config). This increment was forcing us to go beyond the maximum length and, because the model uses an `nn.Embedding` for the position embedding, indexing exceptions started popping up on long audio inputs :D

This PR modifies the length extension to what I believe was the author's original goal: only increment the length if `max_new_tokens` is passed. By default, this argument is not set and should correspond to the "new" (= non-prompt) generated tokens.
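As a small standalone illustration of why exceeding the limit crashes (448 matches Whisper's decoder positions; the hidden size here is arbitrary):

```python
import torch

# A learned position table has a fixed number of rows, so any position id
# >= num_embeddings raises an IndexError during the embedding lookup.
position_table = torch.nn.Embedding(num_embeddings=448, embedding_dim=64)

_ = position_table(torch.arange(448))    # positions 0..447: fine
try:
    position_table(torch.tensor([448]))  # position 448: out of range
except IndexError as err:
    print(err)
```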