diff --git a/setup.py b/setup.py
index 5741aa7b49..3130ad737e 100644
--- a/setup.py
+++ b/setup.py
@@ -36,6 +36,8 @@
     "datasets",
     "diffusers",
     "scipy",
+    "protobuf",
+    "sentencepiece",
 ]
 
 setup(
diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py
index 089cda9e4c..4690e76615 100644
--- a/src/peft/peft_model.py
+++ b/src/peft/peft_model.py
@@ -34,7 +34,7 @@
 from safetensors import safe_open
 from safetensors.torch import save_file as safe_save_file
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-from transformers import PreTrainedModel
+from transformers import Cache, DynamicCache, EncoderDecoderCache, PreTrainedModel
 from transformers.modeling_outputs import QuestionAnsweringModelOutput, SequenceClassifierOutput, TokenClassifierOutput
 from transformers.utils import PushToHubMixin
 
@@ -730,6 +730,18 @@ def get_prompt(self, batch_size: int, task_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
             if TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING.get(self.config.model_type, None) is not None:
                 post_process_fn = TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING[self.config.model_type]
                 past_key_values = post_process_fn(past_key_values)
+            elif peft_config.num_transformer_submodules == 1:
+                # Don't apply this to encoder-decoder models or to models requiring special processing.
+                # Wrap the legacy cache tuple in the new DynamicCache format.
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+            elif peft_config.num_transformer_submodules == 2 and self.base_model._supports_cache_class:
+                # Don't apply this to encoder-decoder models that don't support the new Cache format yet.
+                # If we don't apply this, prefix-tuning fails to update the cross-attention cache.
+                past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
+                past_key_values.cross_attention_cache = DynamicCache()
+                past_key_values.is_updated = {
+                    layer_idx: False for layer_idx in range(len(past_key_values.cross_attention_cache.key_cache))
+                }
             return past_key_values
         else:
             if peft_config.peft_type == PeftType.MULTITASK_PROMPT_TUNING:
@@ -2066,10 +2078,20 @@ def prepare_inputs_for_generation(self, *args, task_ids: torch.Tensor = None, **kwargs):
         model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs)
         if peft_config.peft_type == PeftType.POLY:
             model_kwargs["task_ids"] = task_ids
-        if model_kwargs.get("past_key_values", None) is None and peft_config.peft_type == PeftType.PREFIX_TUNING:
-            batch_size = model_kwargs["decoder_input_ids"].shape[0]
-            past_key_values = self.get_prompt(batch_size)
-            model_kwargs["past_key_values"] = past_key_values
+        elif peft_config.peft_type == PeftType.PREFIX_TUNING:
+            past_key_values = model_kwargs.get("past_key_values", None)
+            cache_position = model_kwargs.get("cache_position", [None])
+            # check prefill stage
+            is_prefill_stage = (
+                # old cache implementation
+                (past_key_values is None)
+                # new cache implementation
+                or (isinstance(past_key_values, Cache) and (cache_position[0] == 0))
+            )
+            if is_prefill_stage:
+                batch_size = model_kwargs["decoder_input_ids"].shape[0]
+                new_past_key_values = self.get_prompt(batch_size)
+                model_kwargs["past_key_values"] = new_past_key_values
 
         return model_kwargs
 
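
The prefill-stage check in `prepare_inputs_for_generation` above is the subtle part of this change, so here is a small standalone sketch (illustrative only, not part of the patch): with the legacy tuple format the cache simply does not exist yet on the first generation step, while with the new `Cache` API a cache object is already passed in and the first step is recognized by `cache_position` starting at 0. The `is_prefill_stage` helper name below is made up for illustration; the patch inlines the expression. It assumes a transformers release recent enough to export `Cache` and `DynamicCache`, which the import change above requires anyway.

```python
# Standalone sketch of the prefill-stage detection added above (illustrative, not part of the patch).
import torch
from transformers import Cache, DynamicCache


def is_prefill_stage(past_key_values, cache_position) -> bool:
    return bool(
        # old cache implementation: no cache object exists yet on the first step
        past_key_values is None
        # new cache implementation: a Cache object is passed in; prefill means position 0
        or (isinstance(past_key_values, Cache) and cache_position[0] == 0)
    )


print(is_prefill_stage(None, [None]))                             # True  (legacy path, first step)
print(is_prefill_stage(DynamicCache(), torch.tensor([0, 1, 2])))  # True  (prefill with new Cache API)
print(is_prefill_stage(DynamicCache(), torch.tensor([3])))        # False (already decoding)
```
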
diff --git a/tests/test_decoder_models.py b/tests/test_decoder_models.py
index ad10805baf..3ad373ac01 100644
--- a/tests/test_decoder_models.py
+++ b/tests/test_decoder_models.py
@@ -11,13 +11,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import tempfile
 import unittest
 from unittest.mock import Mock, call, patch
 
 import pytest
 import torch
+from datasets import load_dataset
 from parameterized import parameterized
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    DataCollatorForLanguageModeling,
+    Trainer,
+    TrainingArguments,
+)
 
 from peft import (
     AdaLoraConfig,
@@ -466,3 +474,34 @@ def test_prompt_learning_with_grouped_query_attention(self):
         x = torch.tensor([[1, 2, 3]])
         # does not raise
         model(x)
+
+    def test_prefix_tuning_mistral(self):
+        # See issue 869, 1962
+        model_id = "hf-internal-testing/tiny-random-MistralForCausalLM"
+        base_model = AutoModelForCausalLM.from_pretrained(model_id)
+        peft_config = PrefixTuningConfig(num_virtual_tokens=10, task_type="CAUSAL_LM")
+        model = get_peft_model(base_model, peft_config)
+
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        tokenizer.pad_token = tokenizer.eos_token
+
+        def process(samples):
+            tokenized = tokenizer(samples["quote"], truncation=True, max_length=128)
+            return tokenized
+
+        data = load_dataset("ybelkada/english_quotes_copy")
+        data = data.map(process, batched=True)
+
+        with tempfile.TemporaryDirectory() as tmp_dirname:
+            trainer = Trainer(
+                model=model,
+                train_dataset=data["train"],
+                args=TrainingArguments(
+                    num_train_epochs=1,
+                    max_steps=5,
+                    per_device_train_batch_size=4,
+                    output_dir=tmp_dirname,
+                ),
+                data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
+            )
+            trainer.train()
diff --git a/tests/testing_common.py b/tests/testing_common.py
index 3eec02510f..954f79be5f 100644
--- a/tests/testing_common.py
+++ b/tests/testing_common.py
@@ -1601,8 +1601,8 @@ def get_output(model):
 
         output_peft = get_output(peft_model)
 
-        # first check trivial case is not true that peft does not affect the output; for this to work, init_lora_weight
-        # must be False
+        # first check trivial case is not true that peft does not affect the output; for this to work, init_weight
+        # must be False (if the config supports it)
         if isinstance(peft_model, StableDiffusionPipeline):
             # for SD, check that most pixels have different values
             assert (output_before != output_peft).float().mean() > 0.8
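
As an end-to-end illustration of what the patch fixes (again a sketch, not part of the diff; it assumes a transformers version that ships the new cache classes and access to the tiny checkpoint used in the new test): for a decoder-only model, `get_prompt` now returns the prefix as a `DynamicCache`, so architectures like Mistral that expect the new cache format can run a forward pass without errors (see issues 869 and 1962).

```python
# Illustrative usage sketch, not part of the patch. Assumes network access to the tiny
# test checkpoint referenced in the new test and a transformers version with DynamicCache.
import torch
from transformers import AutoModelForCausalLM, DynamicCache

from peft import PrefixTuningConfig, get_peft_model

model_id = "hf-internal-testing/tiny-random-MistralForCausalLM"
base_model = AutoModelForCausalLM.from_pretrained(model_id)
peft_model = get_peft_model(base_model, PrefixTuningConfig(num_virtual_tokens=10, task_type="CAUSAL_LM"))

# With this patch, the prefix for a decoder-only model (num_transformer_submodules == 1)
# comes back as a DynamicCache rather than a legacy tuple of key/value tensors.
prefix = peft_model.get_prompt(batch_size=1)
assert isinstance(prefix, DynamicCache)

# A plain forward pass exercises the same code path internally and should not raise.
x = torch.tensor([[1, 2, 3]])
peft_model(x)
```
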