FIX Don't assume past_key_values for encoder-decoder models
Don't assume that past_key_values is part of the model_kwargs.

This fix is similar to huggingface#2140 but for encoder-decoder models. It became necessary after huggingface/transformers#34048 was merged into transformers.
BenjaminBossan committed Oct 14, 2024
1 parent 749b924 commit a4d6e4c
Showing 1 changed file with 1 addition and 1 deletion.
src/peft/peft_model.py (1 addition, 1 deletion)
@@ -2062,7 +2062,7 @@ def prepare_inputs_for_generation(self, *args, task_ids: torch.Tensor = None, **
         model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs)
         if peft_config.peft_type == PeftType.POLY:
             model_kwargs["task_ids"] = task_ids
-        if model_kwargs["past_key_values"] is None and peft_config.peft_type == PeftType.PREFIX_TUNING:
+        if model_kwargs.get("past_key_values", None) is None and peft_config.peft_type == PeftType.PREFIX_TUNING:
             batch_size = model_kwargs["decoder_input_ids"].shape[0]
             past_key_values = self.get_prompt(batch_size)
             model_kwargs["past_key_values"] = past_key_values
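
For context, a minimal, self-contained sketch (not PEFT code) of the failure mode this commit addresses: after the linked transformers change, "past_key_values" may be missing from model_kwargs entirely on the first generation step, so direct indexing raises KeyError, while dict.get() returns None and lets the prefix-tuning branch inject its own cache. The model_kwargs contents and the placeholder value below are illustrative assumptions, not real tensors.

# Stands in for the kwargs returned by base_model_prepare_inputs_for_generation
# (assumed shape, illustrative only).
model_kwargs = {"decoder_input_ids": [[0, 1, 2]], "attention_mask": [[1, 1, 1]]}

# Old code: direct indexing assumes the key is always present.
try:
    past = model_kwargs["past_key_values"]
except KeyError:
    print("KeyError: 'past_key_values' missing from model_kwargs")

# New code: .get() treats a missing key the same as an explicit None, so the
# prefix-tuning branch can still fill in the prompt-derived cache.
if model_kwargs.get("past_key_values", None) is None:
    model_kwargs["past_key_values"] = "prompt-derived cache (placeholder)"

print(model_kwargs["past_key_values"])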
