Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Whisper: fix prompted max length #24666

Merged
merged 5 commits into from
Jul 7, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions src/transformers/generation/stopping_criteria.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@

import torch

from ..utils import add_start_docstrings
from ..utils import add_start_docstrings, logging


logger = logging.get_logger(__name__)


STOPPING_CRITERIA_INPUTS_DOCSTRING = r"""
Expand Down Expand Up @@ -46,14 +49,25 @@ class MaxLengthCriteria(StoppingCriteria):
Args:
max_length (`int`):
The maximum length that the output sequence can have in number of tokens.
max_position_embeddings (`int`, `optional`):
The maximum model length, as defined by the model's `config.max_position_embeddings` attribute.
"""

def __init__(self, max_length: int):
def __init__(self, max_length: int, max_position_embeddings: Optional[int] = None):
self.max_length = max_length
self.max_position_embeddings = max_position_embeddings

@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
return input_ids.shape[-1] >= self.max_length
cur_len = input_ids.shape[-1]
is_done = cur_len >= self.max_length
if self.max_position_embeddings is not None and not is_done and cur_len >= self.max_position_embeddings:
logger.warning_once(
"This is a friendly reminder - the current text generation call will exceed the model's predefined "
f"maximum length ({self.max_position_embeddings}). Depending on the model, you may observe "
"exceptions, performance degradation, or nothing at all."
)
return is_done


class MaxNewTokensCriteria(StoppingCriteria):
Expand Down
7 changes: 6 additions & 1 deletion src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -954,7 +954,12 @@ def _get_stopping_criteria(
) -> StoppingCriteriaList:
criteria = StoppingCriteriaList()
if generation_config.max_length is not None:
criteria.append(MaxLengthCriteria(max_length=generation_config.max_length))
criteria.append(
MaxLengthCriteria(
max_length=generation_config.max_length,
max_position_embeddings=self.config.max_position_embeddings,
)
)
if generation_config.max_time is not None:
criteria.append(MaxTimeCriteria(max_time=generation_config.max_time))
criteria = self._merge_criteria_processor_list(criteria, stopping_criteria)
Expand Down
8 changes: 3 additions & 5 deletions src/transformers/models/whisper/modeling_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1715,11 +1715,9 @@ def generate(
# Set the decoder_start_token_id to <|startofprev|>
kwargs.update({"decoder_start_token_id": decoder_start_token_id})

# Update the max generation length to include the prompt
specified_max_length = kwargs.pop("max_new_tokens", None) or kwargs.pop("max_length", None)
default_max_length = generation_config.max_new_tokens or generation_config.max_length
non_prompt_max_length = specified_max_length or default_max_length
kwargs["max_new_tokens"] = non_prompt_max_length + len(text_prompt_ids)
# If the user passes `max_new_tokens`, increase its number to account for the prompt
if kwargs.get("max_new_tokens", None) is not None:
kwargs["max_new_tokens"] += len(text_prompt_ids)
Comment on lines +1718 to +1720
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I've understood correctly, the previous issue was in part because "max_new_tokens" is not set by default and therefore specified_max_length defaulted to max_length - the max length of the model.

However, the issue was found because it resulted in max_new_tokens being set to max_length + len(text_prompt_ids), resulting in out of bounds, which could still happen (we could set max_new_tokens to max_length.

Could we either:

  • Place an upper bound on the value of max_new_tokens
  • Or raise a warning if it's going out of bounds?

e.g.:

            if kwargs.get("max_new_tokens", None) is not None:
                max_new_tokens_w_prompt = kwargs.get("max_new_tokens") + len(text_prompt_ids)
                kwargs["max_new_tokens"] = min(max_length, max_new_tokens_w_prompt)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally, that check should be done at a model level -- some models accept going beyond its maximum length (e.g. rotary and alibi position embeddings), so it makes more sense to place that check in the model, and not on generate.

ATM, we don't do any check of any form, regardless of the model. Should we open a PR to add an informative exception on models with restrictive position embeddings (like Whisper)?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A warning message when you set max_length > max_position_embeddings would be pretty useful for models like Whisper that have a fixed max length (note that it can be a warning message since we might predict the EOS before we hit max_position_embeddings tokens so the generation could still be valid). Otherwise they fail silently with a very cryptic error

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@amyeroberts @sanchit-gandhi alright, we're aligned in terms of the need for additional checks and messaging 👍

I'm not fully happy with emitting a warning as soon as we cross current_length > max_position_embeddings, as some models can safely cross this limit, but the alternatives (that I've envisioned) have a high engineering cost -- I'm going to add a warning and I'll tag you again when it's included :)


# Reformat the forced_decoder_ids to incorporate the prompt
non_prompt_forced_decoder_ids = (
Expand Down