Update XFormers Attention Monkeypatch to handle Llama-2 70B (GQA) (#339)
* Fix XFormers attention for Llama-2 70B (GQA)

Updated the XFormers monkeypatch to handle GQA (grouped-query attention) as used in Llama-2 70B. All of the updated code is taken directly from the Transformers library's modeling_llama.py: huggingface/transformers@07360b6#diff-06392bad3b9e97be9ade60d4ac46f73b6809388f4d507c2ba1384ab872711c51
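
For context: with GQA, Llama-2 70B projects only 8 key/value heads against 64 query heads, so the k/v tensors have to be tiled up to the full head count before the attention product (this is what the repeat_kv calls in the diff below do). A minimal sketch of that expansion, written for illustration only; the function name and example shapes here are not taken from the patch:

import torch

def repeat_kv_sketch(kv: torch.Tensor, n_rep: int) -> torch.Tensor:
    # kv: (batch, num_kv_heads, seq_len, head_dim)
    batch, num_kv_heads, seq_len, head_dim = kv.shape
    if n_rep == 1:
        return kv
    # Insert a size-1 group axis, broadcast it n_rep times, then fold it
    # back into the head axis so k/v line up with the query heads.
    kv = kv[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return kv.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

# Llama-2 70B style shapes: 64 query heads, 8 kv heads, head_dim 128
keys = torch.randn(2, 8, 16, 128)
keys = repeat_kv_sketch(keys, 64 // 8)
assert keys.shape == (2, 64, 16, 128)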

* Catch configs without pretraining_tp
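
Presumably these are configs saved before the pretraining_tp field existed in Transformers; the guard defaults the attribute to 1, which routes through the plain (non-sliced) q/k/v/o projections. Shown in isolation below with a stand-in class, not the real LlamaAttention:

class AttentionStandIn:  # illustrative stand-in only
    def ensure_pretraining_tp(self) -> None:
        # Older configs never set pretraining_tp; default to 1 so the
        # non-sliced projection path is taken.
        if not hasattr(self, "pretraining_tp"):
            self.pretraining_tp = 1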

* Whitespace bug fix

A statement had accidentally been moved out of the if-else block.

* pre-commit formatting fixes

Thanks to @winglian
ssmi153 authored Aug 6, 2023
1 parent c93655c commit 10405b9
Showing 1 changed file with 66 additions and 17 deletions.
83 changes: 66 additions & 17 deletions src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
@@ -7,6 +7,7 @@
 from typing import Optional, Tuple

 import torch
+import torch.nn.functional as F
 import transformers.models.llama.modeling_llama
 from torch import nn

@@ -38,21 +39,48 @@ def xformers_forward(
     # pylint: disable=duplicate-code
     bsz, q_len, _ = hidden_states.size()

-    query_states = (
-        self.q_proj(hidden_states)
-        .view(bsz, q_len, self.num_heads, self.head_dim)
-        .transpose(1, 2)
-    )
-    key_states = (
-        self.k_proj(hidden_states)
-        .view(bsz, q_len, self.num_heads, self.head_dim)
-        .transpose(1, 2)
-    )
-    value_states = (
-        self.v_proj(hidden_states)
-        .view(bsz, q_len, self.num_heads, self.head_dim)
-        .transpose(1, 2)
-    )
+    if not hasattr(self, "pretraining_tp"):
+        self.pretraining_tp = 1
+
+    if self.pretraining_tp > 1:
+        key_value_slicing = (
+            self.num_key_value_heads * self.head_dim
+        ) // self.pretraining_tp
+        query_slices = self.q_proj.weight.split(
+            (self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
+        )
+        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
+        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
+
+        query_states = [
+            F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
+        ]
+        query_states = torch.cat(query_states, dim=-1)
+
+        key_states = [
+            F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
+        ]
+        key_states = torch.cat(key_states, dim=-1)
+
+        value_states = [
+            F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
+        ]
+        value_states = torch.cat(value_states, dim=-1)
+
+    else:
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+    query_states = query_states.view(
+        bsz, q_len, self.num_heads, self.head_dim
+    ).transpose(1, 2)
+    key_states = key_states.view(
+        bsz, q_len, self.num_key_value_heads, self.head_dim
+    ).transpose(1, 2)
+    value_states = value_states.view(
+        bsz, q_len, self.num_key_value_heads, self.head_dim
+    ).transpose(1, 2)

     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
@@ -73,6 +101,14 @@

     past_key_value = (key_states, value_states) if use_cache else None

+    # repeat k/v heads if n_kv_heads < n_heads
+    key_states = transformers.models.llama.modeling_llama.repeat_kv(
+        key_states, self.num_key_value_groups
+    )
+    value_states = transformers.models.llama.modeling_llama.repeat_kv(
+        value_states, self.num_key_value_groups
+    )
+
     # We only apply xformers optimizations if we don't need to output the whole attention matrix
     if not output_attentions:
         query_states = query_states.transpose(1, 2)
@@ -128,10 +164,23 @@ def xformers_forward(
f" {attn_output.size()}"
)

attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.transpose(1, 2).contiguous()
# end x-formers vs. not x-formers if-else block

attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)

if self.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(
self.hidden_size // self.pretraining_tp, dim=1
)
attn_output = sum(
F.linear(attn_output[i], o_proj_slices[i])
for i in range(self.pretraining_tp)
)
else:
attn_output = self.o_proj(attn_output)

return attn_output, attn_weights, past_key_value
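
Usage note: the hook that installs this forward is not part of the hunks shown above. A monkeypatch of this shape is typically wired up by reassigning the method on the Transformers class, roughly as below; the hook name and call site are assumptions for illustration, not part of this diff:

import transformers.models.llama.modeling_llama as modeling_llama

def hijack_llama_attention() -> None:
    # Route every LlamaAttention instance through the patched forward
    # (xformers_forward, defined above); call this once before the model
    # is instantiated so all layers pick it up.
    modeling_llama.LlamaAttention.forward = xformers_forward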


