Add Support for Mistral Model in Llama-Adapter Method (huggingface#1433)
* Support Mistral For llama-adapter

* Update src/peft/tuners/adaption_prompt/layer.py

Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>

* Update src/peft/tuners/adaption_prompt/layer.py

Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>

* corrected logic and added test

* removed commented out code

* Added separate test functions for mistral

* missed self.assert

* ruff formatting

---------

Co-authored-by: Prakhar Saxena <prakharsxena11111@gmail.com>
Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
3 people authored and BenjaminBossan committed Mar 14, 2024
1 parent 2616ec0 commit df599a8
Showing 4 changed files with 401 additions and 7 deletions.
7 changes: 7 additions & 0 deletions src/peft/tuners/adaption_prompt/config.py
@@ -54,6 +54,13 @@ def is_adaption_prompt(self) -> bool:
         v_proj_layer="v_proj",
         o_proj_layer="o_proj",
     ),
+    "mistral": ModelTypeConfig(  # same as llama,
+        compute_query_states=llama_compute_query_states,
+        target_modules="self_attn",
+        k_proj_layer="k_proj",
+        v_proj_layer="v_proj",
+        o_proj_layer="o_proj",
+    ),
 }


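The "mistral" entry can reuse the llama configuration because the projection layer names match; what differs is grouped-query attention, where k_proj and v_proj project down to num_key_value_heads * head_dim instead of hidden_size. A minimal sketch of the factor the other files rely on, assuming Mistral-7B's default attention dimensions (hidden_size 4096, 32 query heads, 8 key/value heads, head_dim 128); the variable names are illustrative:

import torch.nn as nn

hidden_size, num_heads, num_kv_heads, head_dim = 4096, 32, 8, 128

# Mistral-style projections: q_proj keeps the full width, k_proj/v_proj shrink to the KV heads.
q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)     # 4096 -> 4096
k_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)  # 4096 -> 1024

factor = k_proj.in_features // k_proj.out_features  # 4096 // 1024 == 4
assert factor == num_heads // num_kv_heads

For a Llama-style model without grouped-query attention, k_proj has equal input and output width, so the same expression yields factor == 1 and the code path below reduces to the original llama behavior.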
18 changes: 13 additions & 5 deletions src/peft/tuners/adaption_prompt/layer.py
@@ -74,31 +74,38 @@ def forward(self, **kwargs):
         k_proj_layer = TRANSFORMERS_MODEL_CONFIG[self.model_type].k_proj_layer
         v_proj_layer = TRANSFORMERS_MODEL_CONFIG[self.model_type].v_proj_layer
         o_proj_layer = TRANSFORMERS_MODEL_CONFIG[self.model_type].o_proj_layer
+        factor = (
+            self.model.k_proj.in_features // self.model.k_proj.out_features
+        )  # Mistral has different input and output dimension for k_proj and v_proj layers

         if k_proj_layer == v_proj_layer:
             _, key, value = getattr(self.model, k_proj_layer)(self.adaption_prompt).split(embed_dim, dim=2)
         else:
             key = getattr(self.model, k_proj_layer)(self.adaption_prompt)
             value = getattr(self.model, v_proj_layer)(self.adaption_prompt)
-        # (bsz, num_heads, adapter_len, head_dim)
+
+        # (bsz, num_key_value_heads, adapter_len, head_dim)
         adapter_k = (
-            key.view(1, self.adapter_len, self.model.num_heads, self.model.head_dim)
+            key.view(1, self.adapter_len, (self.model.num_heads // factor), self.model.head_dim)
             .repeat(bsz, 1, 1, 1)
             .transpose(1, 2)
         )
-        # (bsz, num_heads, adapter_len, head_dim)
         adapter_v = (
-            value.view(1, self.adapter_len, self.model.num_heads, self.model.head_dim)
+            value.view(1, self.adapter_len, (self.model.num_heads // factor), self.model.head_dim)
             .repeat(bsz, 1, 1, 1)
             .transpose(1, 2)
         )

+        # Below is taken from https://github.com/huggingface/transformers/blob/e547458c43dfdbbb8f6a7757237e234c44e20a8f/src/transformers/models/mistral/modeling_mistral.py#L181
+        # (bsz, num_heads, adapter_len, head_dim)
+        adapter_k = torch.repeat_interleave(adapter_k, repeats=factor, dim=1)
+        adapter_v = torch.repeat_interleave(adapter_v, repeats=factor, dim=1)
         # Recompute query states.
         compute_query_states = TRANSFORMERS_MODEL_CONFIG[self.model_type].compute_query_states
         # (bsz, num_heads, q_len, head_dim)
         query_states = compute_query_states(model=self.model, **kwargs)

         previous_dtype = query_states.dtype

         # (bsz, num_heads, q_len, adapter_len)
         scores = torch.matmul(query_states, adapter_k.transpose(2, 3).to(previous_dtype)) / math.sqrt(
             self.model.head_dim
@@ -108,6 +115,7 @@ def forward(self, **kwargs):
         scores = self.adaption_gate * F.softmax(scores, dim=-1, dtype=torch.float32).to(previous_dtype)
         # (bsz, q_len, num_heads * head_dim)
         adapter_output = torch.matmul(scores, adapter_v).transpose(1, 2).reshape(bsz, q_len, -1)
+
         # (bsz, q_len, hidden_size)
         if o_proj_layer is not None:
             adapter_output = getattr(self.model, o_proj_layer)(adapter_output)
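The repeat_interleave step mirrors the key/value-head expansion in the linked transformers Mistral attention code: the adapter keys and values are first built with num_key_value_heads heads and each head is then repeated factor times so they line up with the query heads. A small shape sketch, assuming adapter_len=10, bsz=2 and Mistral-7B-like head counts (illustrative values only):

import torch

bsz, adapter_len, num_kv_heads, head_dim, factor = 2, 10, 8, 128, 4

# Adapter keys as built above: (bsz, num_key_value_heads, adapter_len, head_dim)
adapter_k = torch.randn(bsz, num_kv_heads, adapter_len, head_dim)

# Duplicate each KV head `factor` times along dim=1 so it matches the 32 query heads.
adapter_k = torch.repeat_interleave(adapter_k, repeats=factor, dim=1)
assert adapter_k.shape == (bsz, num_kv_heads * factor, adapter_len, head_dim)  # (2, 32, 10, 128)

With factor == 1 (Llama-style attention) repeat_interleave is a no-op, which is why the same forward pass serves both model types.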
7 changes: 6 additions & 1 deletion src/peft/tuners/adaption_prompt/utils.py
@@ -68,7 +68,12 @@ def llama_compute_query_states(model: nn.Module, **kwargs) -> torch.Tensor:
     past_key_value = kwargs.get("past_key_value")
     bsz, q_len, _ = hidden_states.size()
     query_states = model.q_proj(hidden_states).view(bsz, q_len, model.num_heads, model.head_dim).transpose(1, 2)
-    value_states = model.v_proj(hidden_states).view(bsz, q_len, model.num_heads, model.head_dim).transpose(1, 2)
+
+    factor = model.k_proj.in_features // model.k_proj.out_features
+    value_states = (
+        model.v_proj(hidden_states).view(bsz, q_len, (model.num_heads // factor), model.head_dim).transpose(1, 2)
+    )
+
     seq_len = q_len

     if past_key_value is not None:
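The same factor is reused here because value_states must be viewed with the key/value head count, not the query head count; with the old view the reshape would fail for Mistral, whose v_proj output is factor times narrower. A short shape check with illustrative dimensions (not taken from the diff):

import torch

bsz, q_len, num_heads, head_dim, factor = 2, 16, 32, 128, 4

# What a Mistral-style v_proj returns: (bsz, q_len, num_kv_heads * head_dim) = (2, 16, 1024)
v_out = torch.randn(bsz, q_len, (num_heads // factor) * head_dim)
value_states = v_out.view(bsz, q_len, num_heads // factor, head_dim).transpose(1, 2)
assert value_states.shape == (bsz, num_heads // factor, q_len, head_dim)  # (2, 8, 16, 128)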

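For context, a hedged end-to-end sketch of what this commit enables, not part of the diff itself; the checkpoint id and hyperparameters are illustrative. The adaption-prompt (Llama-Adapter) config is applied to a Mistral model exactly as it would be to a Llama model:

from transformers import AutoModelForCausalLM
from peft import AdaptionPromptConfig, get_peft_model

model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
# Same config fields as for Llama; the new "mistral" ModelTypeConfig entry makes them resolve.
peft_config = AdaptionPromptConfig(adapter_len=10, adapter_layers=30, task_type="CAUSAL_LM")
model = get_peft_model(model, peft_config)  # wraps the targeted self_attn modules with AdaptedAttention
model.print_trainable_parameters()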