
Documentation and implementation are inconsistent for forced_decoder_ids option in GenerationMixin.generate #19602

Closed
2 of 4 tasks
koreyou opened this issue Oct 14, 2022 · 1 comment · Fixed by #19640

Comments

koreyou (Contributor) commented Oct 14, 2022

System Info

  • transformers version: 4.23.0
  • Platform: macOS-12.6-arm64-arm-64bit
  • Python version: 3.9.13
  • Huggingface_hub version: 0.10.1
  • PyTorch version (GPU?): 1.11.0 (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: No
  • Using distributed or parallel set-up in script?: No

Who can help?

Text generation: @patrickvonplaten, @Narsil, @gante
Documentation: @sgugger, @stevhliu

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


tokenizer = AutoTokenizer.from_pretrained('t5-small')
model = AutoModelForSeq2SeqLM.from_pretrained('t5-small')

input_text = 'This is a dummy input.'
decoder_start_text = 'But it should still work, because'

input_ids = tokenizer.encode(input_text, return_tensors='pt')
decoder_start_ids = tokenizer.encode(decoder_start_text, add_special_tokens=False)

# This raises an error as attached below
outputs = model.generate(
    input_ids,
    forced_decoder_ids=decoder_start_ids
)

# This is against the documentation but works
outputs = model.generate(
    input_ids,
    forced_decoder_ids={i: token_id for i, token_id in enumerate(decoder_start_ids)}
)

Expected behavior

According to the documentation, GenerationMixin.generate accepts a List[int] for forced_decoder_ids. However, the reproduction above raises the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [10], in <cell line: 1>()
----> 1 outputs = model.generate(
      2     input_ids,
      3     forced_decoder_ids=decoder_start_ids
      4 )

File ~/.pyenv/versions/3.9.13/envs/dummy_proj/lib/python3.9/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     24 @functools.wraps(func)
     25 def decorate_context(*args, **kwargs):
     26     with self.clone():
---> 27         return func(*args, **kwargs)

File ~/.pyenv/versions/3.9.13/envs/dummy_proj/lib/python3.9/site-packages/transformers/generation_utils.py:1353, in GenerationMixin.generate(self, inputs, max_length, min_length, do_sample, early_stopping, num_beams, temperature, top_k, top_p, typical_p, repetition_penalty, bad_words_ids, force_words_ids, bos_token_id, pad_token_id, eos_token_id, length_penalty, no_repeat_ngram_size, encoder_no_repeat_ngram_size, num_return_sequences, max_time, max_new_tokens, decoder_start_token_id, use_cache, num_beam_groups, diversity_penalty, prefix_allowed_tokens_fn, logits_processor, renormalize_logits, stopping_criteria, constraints, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, forced_bos_token_id, forced_eos_token_id, remove_invalid_values, synced_gpus, exponential_decay_length_penalty, suppress_tokens, begin_suppress_tokens, forced_decoder_ids, **model_kwargs)
   1348     raise ValueError(
   1349         "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
   1350     )
   1352 # 7. prepare distribution pre_processing samplers
-> 1353 logits_processor = self._get_logits_processor(
   1354     repetition_penalty=repetition_penalty,
   1355     no_repeat_ngram_size=no_repeat_ngram_size,
   1356     encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
   1357     input_ids_seq_length=input_ids_seq_length,
   1358     encoder_input_ids=inputs_tensor,
   1359     bad_words_ids=bad_words_ids,
   1360     min_length=min_length,
   1361     max_length=max_length,
   1362     eos_token_id=eos_token_id,
   1363     forced_bos_token_id=forced_bos_token_id,
   1364     forced_eos_token_id=forced_eos_token_id,
   1365     prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
   1366     num_beams=num_beams,
   1367     num_beam_groups=num_beam_groups,
   1368     diversity_penalty=diversity_penalty,
   1369     remove_invalid_values=remove_invalid_values,
   1370     exponential_decay_length_penalty=exponential_decay_length_penalty,
   1371     logits_processor=logits_processor,
   1372     renormalize_logits=renormalize_logits,
   1373     suppress_tokens=suppress_tokens,
   1374     begin_suppress_tokens=begin_suppress_tokens,
   1375     forced_decoder_ids=forced_decoder_ids,
   1376 )
   1378 # 8. prepare stopping criteria
   1379 stopping_criteria = self._get_stopping_criteria(
   1380     max_length=max_length, max_time=max_time, stopping_criteria=stopping_criteria
   1381 )

File ~/.pyenv/versions/3.9.13/envs/dummy_proj/lib/python3.9/site-packages/transformers/generation_utils.py:786, in GenerationMixin._get_logits_processor(self, repetition_penalty, no_repeat_ngram_size, encoder_no_repeat_ngram_size, input_ids_seq_length, encoder_input_ids, bad_words_ids, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id, prefix_allowed_tokens_fn, num_beams, num_beam_groups, diversity_penalty, remove_invalid_values, exponential_decay_length_penalty, logits_processor, renormalize_logits, suppress_tokens, begin_suppress_tokens, forced_decoder_ids)
    784     processors.append(SuppressTokensAtBeginLogitsProcessor(begin_suppress_tokens, begin_index))
    785 if forced_decoder_ids is not None:
--> 786     processors.append(ForceTokensLogitsProcessor(forced_decoder_ids))
    787 processors = self._merge_criteria_processor_list(processors, logits_processor)
    788 # `LogitNormalization` should always be the last logit processor, when present

File ~/.pyenv/versions/3.9.13/envs/dummy_proj/lib/python3.9/site-packages/transformers/generation_logits_process.py:742, in ForceTokensLogitsProcessor.__init__(self, force_token_map)
    741 def __init__(self, force_token_map):
--> 742     self.force_token_map = dict(force_token_map)

It is clear that the implementation expects something convertible to a Dict[int, int], as shown here. Hence I believe the implementation and documentation are inconsistent.

FYI, other options in GenerationMixin.generate seem to expect List[int], as in the documentation.
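
For illustration, the failure reduces to plain dict() semantics: dict() accepts a sequence of key/value pairs but not a flat list of ints. A minimal standalone sketch (the token ids here are made-up values, not from the repro):

# dict() can consume a sequence of 2-element pairs...
pairs = [[1, 250], [2, 300]]
print(dict(pairs))  # {1: 250, 2: 300}

# ...but a flat list of ints raises the same TypeError as in the traceback:
flat = [250, 300, 350]
dict(flat)  # TypeError: cannot convert dictionary update sequence element #0 to a sequence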

gante (Member) commented Oct 14, 2022

Hi @koreyou 👋 The documentation is indeed incorrect -- it accepts a list of integer pairs (List[List[int]]) that is convertible to a Dict[int, int], where each pair holds the generation index and the token to be forced at that index, respectively (e.g. this list of lists).
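
A working call based on the reproduction above could look like the sketch below (an illustrative sketch, assuming forced indices start at 1 because position 0 already holds the decoder start token):

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('t5-small')
model = AutoModelForSeq2SeqLM.from_pretrained('t5-small')

input_ids = tokenizer.encode('This is a dummy input.', return_tensors='pt')
decoder_start_ids = tokenizer.encode(
    'But it should still work, because', add_special_tokens=False
)

# Build [generation_index, token_id] pairs; ForceTokensLogitsProcessor
# converts this list of pairs into a Dict[int, int] internally.
forced = [[i + 1, token_id] for i, token_id in enumerate(decoder_start_ids)]
outputs = model.generate(input_ids, forced_decoder_ids=forced)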

Would you like to open a PR to fix the documentation? 🤗

(cc @ArthurZucker @patrickvonplaten)
