pre-commit formatting fixes
winglian committed Aug 6, 2023
1 parent 64852ae commit 9793faf
Showing 1 changed file with 44 additions and 31 deletions.
75 changes: 44 additions & 31 deletions src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
@@ -7,9 +7,9 @@
from typing import Optional, Tuple

import torch
+import torch.nn.functional as F
import transformers.models.llama.modeling_llama
from torch import nn
-import torch.nn.functional as F

try:
    import xformers.ops
@@ -39,44 +39,48 @@ def xformers_forward(
    # pylint: disable=duplicate-code
    bsz, q_len, _ = hidden_states.size()

-    if not hasattr(self, 'pretraining_tp'):
+    if not hasattr(self, "pretraining_tp"):
        self.pretraining_tp = 1

    if self.pretraining_tp > 1:
-        key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp
-        query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, dim=0)
+        key_value_slicing = (
+            self.num_key_value_heads * self.head_dim
+        ) // self.pretraining_tp
+        query_slices = self.q_proj.weight.split(
+            (self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
+        )
        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

-        query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)]
+        query_states = [
+            F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
+        ]
        query_states = torch.cat(query_states, dim=-1)

-        key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)]
+        key_states = [
+            F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
+        ]
        key_states = torch.cat(key_states, dim=-1)

-        value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)]
+        value_states = [
+            F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
+        ]
        value_states = torch.cat(value_states, dim=-1)

    else:
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
-
-    query_states = (
-        query_states
-        .view(bsz, q_len, self.num_heads, self.head_dim)
-        .transpose(1, 2)
-    )
-    key_states = (
-        key_states
-        .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
-        .transpose(1, 2)
-    )
-    value_states = (
-        value_states
-        .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
-        .transpose(1, 2)
-    )
+        value_states = self.v_proj(hidden_states)
+
+    query_states = query_states.view(
+        bsz, q_len, self.num_heads, self.head_dim
+    ).transpose(1, 2)
+    key_states = key_states.view(
+        bsz, q_len, self.num_key_value_heads, self.head_dim
+    ).transpose(1, 2)
+    value_states = value_states.view(
+        bsz, q_len, self.num_key_value_heads, self.head_dim
+    ).transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
@@ -98,8 +102,12 @@ def xformers_forward(
    past_key_value = (key_states, value_states) if use_cache else None

    # repeat k/v heads if n_kv_heads < n_heads
-    key_states = transformers.models.llama.modeling_llama.repeat_kv(key_states, self.num_key_value_groups)
-    value_states = transformers.models.llama.modeling_llama.repeat_kv(value_states, self.num_key_value_groups)
+    key_states = transformers.models.llama.modeling_llama.repeat_kv(
+        key_states, self.num_key_value_groups
+    )
+    value_states = transformers.models.llama.modeling_llama.repeat_kv(
+        value_states, self.num_key_value_groups
+    )

    # We only apply xformers optimizations if we don't need to output the whole attention matrix
    if not output_attentions:
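The body of this "if not output_attentions:" branch falls in the collapsed part of the diff. As a rough, assumed sketch (not the file's exact code), the xformers fast path it guards looks something like the following; memory_efficient_attention expects (bsz, seq_len, num_heads, head_dim) tensors, hence the transposes, and a LowerTriangularMask bias stands in for a materialized causal mask:

# Assumed sketch of the collapsed fast path, not part of the visible hunks.
attn_weights = None
attn_output = xformers.ops.memory_efficient_attention(
    query_states.transpose(1, 2),
    key_states.transpose(1, 2),
    value_states.transpose(1, 2),
    attn_bias=xformers.ops.LowerTriangularMask(),
)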
@@ -157,17 +165,22 @@ def xformers_forward(
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
-        #end x-formers vs. not x-formers if-else block
+        # end x-formers vs. not x-formers if-else block

    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

    if self.pretraining_tp > 1:
        attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
-        o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1)
-        attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)])
+        o_proj_slices = self.o_proj.weight.split(
+            self.hidden_size // self.pretraining_tp, dim=1
+        )
+        attn_output = sum(
+            F.linear(attn_output[i], o_proj_slices[i])
+            for i in range(self.pretraining_tp)
+        )
    else:
-        attn_output = self.o_proj(attn_output)
+        attn_output = self.o_proj(attn_output)

    return attn_output, attn_weights, past_key_value


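For context, a minimal usage sketch of how a monkeypatch module like this is typically activated; the hijack_llama_attention helper below is an assumption for illustration, since the actual entry point lies outside the visible diff:

# Assumed usage sketch (helper name illustrative, not taken from this commit).
import transformers.models.llama.modeling_llama


def hijack_llama_attention():
    # xformers_forward takes self as its first parameter, so assigning it to the
    # class attribute makes it act as LlamaAttention's regular forward method.
    transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward


# Call this before the model is instantiated so every attention layer picks it up.
hijack_llama_attention()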
