diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index 7f41436045..3cc3d3dd69 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -2062,7 +2062,7 @@ def prepare_inputs_for_generation(self, *args, task_ids: torch.Tensor = None, ** model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs) if peft_config.peft_type == PeftType.POLY: model_kwargs["task_ids"] = task_ids - if model_kwargs["past_key_values"] is None and peft_config.peft_type == PeftType.PREFIX_TUNING: + if model_kwargs.get("past_key_values", None) is None and peft_config.peft_type == PeftType.PREFIX_TUNING: batch_size = model_kwargs["decoder_input_ids"].shape[0] past_key_values = self.get_prompt(batch_size) model_kwargs["past_key_values"] = past_key_values