diff --git a/src/transformers/generation_logits_process.py b/src/transformers/generation_logits_process.py
index 0e1414aa08bb0b..f16b46ce6b68e4 100644
--- a/src/transformers/generation_logits_process.py
+++ b/src/transformers/generation_logits_process.py
@@ -735,10 +735,11 @@ def __call__(self, input_ids, scores):
 
 
 class ForceTokensLogitsProcessor(LogitsProcessor):
-    r"""This processor can be used to force a list of tokens. The processor will set their log probs to `inf` so that they
-    are sampled at their corresponding index."""
+    r"""This processor takes a list of pairs of integers which indicates a mapping from generation indices to token
+    indices that will be forced before sampling. The processor will set their log probs to `inf` so that they are
+    sampled at their corresponding index."""
 
-    def __init__(self, force_token_map):
+    def __init__(self, force_token_map: List[List[int]]):
         self.force_token_map = dict(force_token_map)
 
     def __call__(self, input_ids, scores):
diff --git a/src/transformers/generation_tf_logits_process.py b/src/transformers/generation_tf_logits_process.py
index 25e287d3875a15..1a6e0ba97b817c 100644
--- a/src/transformers/generation_tf_logits_process.py
+++ b/src/transformers/generation_tf_logits_process.py
@@ -547,10 +547,11 @@ def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.
 
 
 class TFForceTokensLogitsProcessor(TFLogitsProcessor):
-    r"""This processor can be used to force a list of tokens. The processor will set their log probs to `0` and all
-    other tokens to `-inf` so that they are sampled at their corresponding index."""
+    r"""This processor takes a list of pairs of integers which indicates a mapping from generation indices to token
+    indices that will be forced before sampling. The processor will set their log probs to `0` and all other tokens to
+    `-inf` so that they are sampled at their corresponding index."""
 
-    def __init__(self, force_token_map):
+    def __init__(self, force_token_map: List[List[int]]):
         force_token_map = dict(force_token_map)
         # Converts the dictionary of format {index: token} containing the tokens to be forced to an array, where the
         # index of the array corresponds to the index of the token to be forced, for XLA compatibility.
diff --git a/src/transformers/generation_tf_utils.py b/src/transformers/generation_tf_utils.py
index 14d4d0072e4d49..68a3f27c9e8eac 100644
--- a/src/transformers/generation_tf_utils.py
+++ b/src/transformers/generation_tf_utils.py
@@ -406,7 +406,7 @@ def generate(
         forced_eos_token_id=None,
         suppress_tokens: Optional[List[int]] = None,
         begin_suppress_tokens: Optional[List[int]] = None,
-        forced_decoder_ids: Optional[List[int]] = None,
+        forced_decoder_ids: Optional[List[List[int]]] = None,
         **model_kwargs,
     ) -> Union[TFGreedySearchOutput, TFSampleOutput, TFBeamSearchOutput, TFBeamSampleOutput, tf.Tensor]:
         r"""
@@ -506,8 +506,10 @@ def generate(
             begin_suppress_tokens (`List[int]`, *optional*, defaults to `model.config.begin_suppress_tokens`):
                 A list of tokens that will be supressed at the begining of the generation. The `SupressBeginTokens`
                 logit processor will set their log probs to `-inf` so that they are not sampled.
-            forced_decoder_ids (`List[int]`, *optional*, defaults to `model.config.forced_decoder_ids`):
-                A list of tokens that will be forced as beginning tokens, before sampling.
+            forced_decoder_ids (`List[List[int]]`, *optional*, defaults to `model.config.forced_decoder_ids`):
+                A list of pairs of integers which indicates a mapping from generation indices to token indices that
+                will be forced before sampling. For example, `[[1, 123]]` means the second generated token will always
+                be a token of index 123.
             model_specific_kwargs:
                 Additional model specific kwargs will be forwarded to the `forward` function of the model.
 
@@ -1493,9 +1495,10 @@ def _generate(
             begin_suppress_tokens (`List[int]`, *optional*, defaults to `model.config.begin_suppress_tokens`):
                 A list of tokens that will be supressed at the begining of the generation. The `SupressBeginTokens`
                 logit processor will set their log probs to `-inf` so that they are not sampled.
-            forced_decoder_ids (`List[int]`, *optional*, defaults to `model.config.forced_decoder_ids`):
-                A list of tokens that will be forced as beginning tokens.
-
+            forced_decoder_ids (`List[List[int]]`, *optional*, defaults to `model.config.forced_decoder_ids`):
+                A list of pairs of integers which indicates a mapping from generation indices to token indices that
+                will be forced before sampling. For example, `[[1, 123]]` means the second generated token will always
+                be a token of index 123.
             model_kwargs:
                 Additional model specific kwargs will be forwarded to the `call` function of the model.
 
@@ -2147,7 +2150,7 @@ def _get_logits_processor(
         forced_eos_token_id: int,
         suppress_tokens: Optional[List[int]] = None,
         begin_suppress_tokens: Optional[List[int]] = None,
-        forced_decoder_ids: Optional[List[int]] = None,
+        forced_decoder_ids: Optional[List[List[int]]] = None,
     ) -> TFLogitsProcessorList:
         """
         This class returns a [`TFLogitsProcessorList`] list object that contains all relevant [`TFLogitsProcessor`]
diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py
index 6356ead5441dcc..ad533a06f1db5b 100644
--- a/src/transformers/generation_utils.py
+++ b/src/transformers/generation_utils.py
@@ -696,7 +696,7 @@ def _get_logits_processor(
         renormalize_logits: Optional[bool],
         suppress_tokens: Optional[List[int]] = None,
         begin_suppress_tokens: Optional[List[int]] = None,
-        forced_decoder_ids: Optional[List[int]] = None,
+        forced_decoder_ids: Optional[List[List[int]]] = None,
     ) -> LogitsProcessorList:
         """
         This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`]
@@ -956,7 +956,7 @@ def generate(
         exponential_decay_length_penalty: Optional[Tuple[int, float]] = None,
         suppress_tokens: Optional[List[int]] = None,
         begin_suppress_tokens: Optional[List[int]] = None,
-        forced_decoder_ids: Optional[List[int]] = None,
+        forced_decoder_ids: Optional[List[List[int]]] = None,
         **model_kwargs,
     ) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]:
         r"""
@@ -1121,9 +1121,10 @@ def generate(
            begin_suppress_tokens (`List[int]`, *optional*, defaults to `model.config.begin_suppress_tokens`):
                 A list of tokens that will be supressed at the begining of the generation. The `SupressBeginTokens`
                 logit processor will set their log probs to `-inf` so that they are not sampled.
-            forced_decoder_ids (`List[int]`, *optional*, defaults to `model.config.forced_decoder_ids`):
-                A list of tokens that will be forced as beginning tokens, before sampling.
-
+            forced_decoder_ids (`List[List[int]]`, *optional*, defaults to `model.config.forced_decoder_ids`):
+                A list of pairs of integers which indicates a mapping from generation indices to token indices that
+                will be forced before sampling. For example, `[[1, 123]]` means the second generated token will always
+                be a token of index 123.
             model_kwargs:
                 Additional model specific kwargs will be forwarded to the `forward` function of the model. If the model
                 is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs