
PEFT doesn't inject virtual tokens into generate forward pass #2134

Closed

Kamichanw opened this issue Oct 6, 2024 · 7 comments

@Kamichanw

System Info

  • transformers version: 4.46.0.dev0
  • Platform: Linux-5.4.0-148-generic-x86_64-with-glibc2.31
  • Python version: 3.9.19
  • Huggingface_hub version: 0.24.0
  • Safetensors version: 0.4.3
  • Accelerate version: 0.33.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.4.1+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?:
  • GPU type: NVIDIA RTX A6000

Who can help?

@BenjaminBossan @sayakpaul

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

I encountered the problem with the following code:

import peft
from transformers import IdeficsForVisionText2Text, AutoProcessor
import sys
import torch

sys.path.insert(0, "..")
import config  # my own file

device = torch.device("cuda:3")
model = IdeficsForVisionText2Text.from_pretrained(
    config.idefics_9b_path, torch_dtype=torch.float16
).to(device)
processor = AutoProcessor.from_pretrained(
    config.idefics_9b_path, torch_dtype=torch.float16
)
model = peft.get_peft_model(
    model,
    peft.PrefixTuningConfig(
        peft_type="PREFIX_TUNING",
        task_type="CAUSAL_LM",
        num_virtual_tokens=2,
        token_dim=4096,
        num_transformer_submodules=1,
        num_attention_heads=32,
        num_layers=32,
        encoder_hidden_size=768,
    ),
    mixed=False,
)
inputs = processor(["hello"]).to(device)
model.eval()
model.generate(**inputs)

When I added print(past_key_values) on the transformers side, I got DynamicCache(), which means the virtual tokens weren't injected into the forward pass.

Expected behavior

The forward pass should receive a cache whose length equals num_virtual_tokens.

@BenjaminBossan
Member

Could you please clarify where you check past_key_values and what you would expect there?

Also, note that #2096 is in the works, which should hopefully fix some of the issues that prefix tuning has with the latest transformers version. If possible, you could check whether that branch fixes the error.

@Kamichanw
Author

Kamichanw commented Oct 7, 2024

#2096 doesn't fix anything except suppressing the warnings, since here the legacy past_key_values is converted to a proper Cache instance. I added a print(past_key_values) at this line and got DynamicCache(), which made me curious whether the prefix tokens were injected into the forward pass during generation.

Then I debugged step by step until I reached here. The past_key_values is DynamicCache() while num_virtual_tokens is 1. I wonder whether this is the correct behavior, because in my opinion the virtual tokens should be injected as past_key_values, just as is done during training.

@BenjaminBossan
Member

When I added print(past_key_values) on the transformers side, I got DynamicCache(), which means the virtual tokens weren't injected into the forward pass.

Then I debugged step by step until I reached here. The past_key_values is DynamicCache() while num_virtual_tokens is 1. I wonder whether this is the correct behavior, because in my opinion the virtual tokens should be injected as past_key_values, just as is done during training.

I'm not sure if I follow. I set a breakpoint at the line you mentioned. past_key_values is indeed a DynamicCache instance, and it should contain the virtual tokens. When I have 2 virtual tokens, past_key_values.get_seq_length() returns 2. When I have 20 virtual tokens, past_key_values.get_seq_length() returns 20, etc. What would be your expectation?
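
For anyone who wants to verify this without attaching a debugger, here is a minimal sketch that can be used with the reproduction above (a forward pre-hook with with_kwargs requires PyTorch 2.x, and hooking model.base_model is an assumption about the prompt-learning wrapper; adjust names to your model):

# Print the cache length that actually reaches the base model's forward
# during generation; with a correctly injected prefix, the first step
# should already report num_virtual_tokens.
def check_cache(module, args, kwargs):
    pkv = kwargs.get("past_key_values")
    if pkv is not None and hasattr(pkv, "get_seq_length"):
        print("cache length seen by forward:", pkv.get_seq_length())

handle = model.base_model.register_forward_pre_hook(check_cache, with_kwargs=True)
model.generate(**inputs, max_new_tokens=3)
handle.remove()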

@DavdGao

DavdGao commented Oct 29, 2024

I had a similar problem.

I am currently training the Qwen2.5-7B-Instruct model. After completing the training, I loaded the model using the AutoPeftModelForCausalLM.from_pretrained interface. However, I discovered that the newly added weight named "prompt_embeddings" is not participating in the inference process.

To confirm this, I changed all the training data to {"prompt": "1+1=", "response": "11"}. After the training was completed, but before the Python program exited, I tested the prompt "1+1=". At this point, the model output "11", indicating that the training had indeed taken effect.

However, when I subsequently loaded the model from disk using AutoPeftModelForCausalLM.from_pretrained and asked "1+1=" again, I was unable to obtain the result "11". This led me to conclude that these new weights are not being incorporated into the inference process. (I printed model.named_parameters() to confirm that "prompt_embeddings" is loaded.)
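
A minimal sketch of that sanity check, assuming a hypothetical path_model pointing at the saved adapter (the "prompt" substring match on parameter names is an assumption and may differ across PEFT versions):

from transformers import AutoTokenizer
from peft import AutoPeftModelForCausalLM

path_model = "/path/to/saved/adapter"  # placeholder

model = AutoPeftModelForCausalLM.from_pretrained(path_model, device_map="cuda:0")
tokenizer = AutoTokenizer.from_pretrained(path_model)

# Confirm the prompt-tuning weights were actually loaded from disk.
for name, param in model.named_parameters():
    if "prompt" in name:
        print(name, tuple(param.shape))

# Probe with the memorized training example; if the virtual tokens are
# injected at inference time, this should reproduce "11".
inputs = tokenizer("1+1=", return_tensors="pt").to("cuda:0")
output = model.generate(**inputs, do_sample=False, max_new_tokens=5)[0]
print(tokenizer.decode(output, skip_special_tokens=True))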

@DavdGao

DavdGao commented Oct 29, 2024

After installing the peft package from the GitHub repo (0.13.3.dev0), I can use AutoPeftModelForCausalLM.from_pretrained to get the expected output from my trained model, i.e. 1+1=11.

However, the following code still fails:

from transformers import AutoTokenizer
from peft import AutoPeftModelForCausalLM


def inference(path_model, messages):
    tokenizer = AutoTokenizer.from_pretrained(path_model, use_fast=False)
    # device_map="auto" dispatches the model layers across all visible GPUs
    model = AutoPeftModelForCausalLM.from_pretrained(path_model, device_map="auto")
    model.eval()

    prompt_tokenized = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")

    output_tokenized = model.generate(input_ids=prompt_tokenized.to("cuda"), do_sample=False, max_length=10000)[0]
    # drop the prompt tokens and decode only the newly generated part
    output_tokenized = output_tokenized[prompt_tokenized.size(1):]
    output = tokenizer.decode(output_tokenized, skip_special_tokens=True)

    return output

I got the following error:

  File "/cpfs/data/gaodawei.gdw/utils.py", line 59, in get_local_model_response
    output_tokenized = model.generate(input_ids=prompt_tokenized.to("cuda"), do_sample=False, max_length=10000)[0]
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cpfs/data/gaodawei.gdw/peft/src/peft/peft_model.py", line 1746, in generate
    outputs = self.base_model.generate(**kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/transformers/generation/utils.py", line 2215, in generate
    result = self._sample(
             ^^^^^^^^^^^^^
  File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/transformers/generation/utils.py", line 3206, in _sample
    outputs = self(**model_inputs, return_dict=True)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 1164, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 895, in forward
    layer_outputs = decoder_layer(
                    ^^^^^^^^^^^^^^
  File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 623, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
                                                          ^^^^^^^^^^^^^^^
  File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 523, in forward
    key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cpfs/data/gaodawei.gdw/miniconda3/envs/pft/lib/python3.11/site-packages/transformers/cache_utils.py", line 447, in update
    self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument tensors in method wrapper_CUDA_cat)

Only when I change my code to device_map="cuda:0" and output_tokenized = model.generate(input_ids=prompt_tokenized.to("cuda:0"), ...) does the model load and work normally.
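
In other words, the single-device workaround amounts to a change like this in the inference helper above:

# Workaround: keep the whole PEFT model on one GPU so the prefix cache and
# the per-layer key/value states never end up on different devices.
model = AutoPeftModelForCausalLM.from_pretrained(path_model, device_map="cuda:0")
model.eval()

output_tokenized = model.generate(
    input_ids=prompt_tokenized.to("cuda:0"), do_sample=False, max_length=10000
)[0]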

BenjaminBossan added a commit to BenjaminBossan/peft that referenced this issue Oct 30, 2024
See huggingface#2134

After introducing the usage of DynamicCache for prefix tuning, a bug
could now occur if the model is dispatched to different devices. This is
because we need to move the key and value cache for each layer to that
layer's respective device.

The new code mostly consists of code copied from transformers to be
consistent with how transformers solves this.
@BenjaminBossan
Member

Thanks for reporting @DavdGao. #2189 should fix that issue; if you can give it a try and report back, that would help.

BenjaminBossan added a commit that referenced this issue Nov 1, 2024
See #2134

After introducing the usage of DynamicCache for prefix tuning, a bug
could now occur if the model is dispatched to different devices. This is
because we need to move the key and value cache for each layer to that
layer's respective device.

The new code mostly consists of code copied from transformers to be
consistent with how transformers solves this.
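
Roughly, the gist of that change is to move each layer's prefix key/value tensors to the device of the corresponding decoder layer before generation. A simplified sketch, not the actual PEFT code (the hf_device_map lookup is an assumption and only covers models dispatched with accelerate):

def move_prefix_cache_to_layer_devices(past_key_values, model):
    # hf_device_map maps module names (e.g. "model.layers.3") to devices when
    # the model was loaded with device_map="auto" via accelerate.
    device_map = getattr(model, "hf_device_map", None) or {}
    for layer_idx in range(len(past_key_values.key_cache)):
        layer_device = next(
            (dev for name, dev in device_map.items() if name.endswith(f".layers.{layer_idx}")),
            None,
        )
        if layer_device is None:
            continue
        # Move this layer's cached keys/values to the same device as the layer
        # that will concatenate new key/value states onto them.
        past_key_values.key_cache[layer_idx] = past_key_values.key_cache[layer_idx].to(layer_device)
        past_key_values.value_cache[layer_idx] = past_key_values.value_cache[layer_idx].to(layer_device)
    return past_key_values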
@DavdGao

DavdGao commented Nov 5, 2024

Thanks for reporting @DavdGao. #2189 should fix that issue; if you can give it a try and report back, that would help.

@BenjaminBossan Thanks, I have pulled the latest commit 7295b33, and it fixes the issue in inference.
