Skip to content

Commit

Permalink
llama mixer hybrid
Browse files Browse the repository at this point in the history
  • Loading branch information
blbadger committed Jul 3, 2024
1 parent 3cbd619 commit dbf1653
Showing 1 changed file with 20 additions and 1 deletion.
21 changes: 20 additions & 1 deletion src/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
replace_return_docstrings,
)
from .configuration_llama import LlamaConfig
import einops
from einops import rearrange


if is_flash_attn_2_available():
Expand Down Expand Up @@ -668,6 +670,21 @@ def forward(
"sdpa": LlamaSdpaAttention,
}

class MixerLayer(nn.Module):
    """Causally masked mixing layer built on a kernel-size-1 ``nn.Conv1d``.

    The convolution treats dimension ``-2`` of the input as its channel axis
    (``length`` channels in, ``length`` channels out), so every output channel
    is a learned linear combination of the input channels. The weight matrix
    is restricted to its lower triangle on every forward pass, which means
    output channel ``i`` only ever reads input channels ``j <= i``.

    The mask is applied functionally: the stored parameter is never
    overwritten, and masked (upper-triangular) entries receive a zero
    gradient. Their stored values are simply never used.

    Args:
        length: Maximum size of the mixed (channel) axis. Inputs whose mixed
            axis is shorter than ``length`` are supported and use the leading
            sub-block of the weight and bias.
    """

    def __init__(self, length):
        super().__init__()
        self.length = length
        # Kernel size 1, so 'same' padding amounts to no padding at all.
        self.conv = nn.Conv1d(length, length, 1, padding='same')

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the masked mixing convolution.

        Args:
            x: Tensor whose dimension ``-2`` is the axis to mix and whose last
                dimension is the convolved axis. Presumably
                ``(batch, tokens, features)`` -- confirm against the caller.
                Inputs with more than three dimensions have all leading
                dimensions merged into a single batch dimension.

        Returns:
            A tensor with the same shape as the (possibly flattened) input.

        Raises:
            ValueError: If the mixed axis is longer than ``self.length``.
        """
        if x.dim() > 3:
            # Merge every leading dimension into one batch dimension.
            x = x.reshape(-1, *x.shape[-2:])

        mixed_len = x.shape[-2]
        if mixed_len > self.length:
            raise ValueError(
                f"MixerLayer was built for at most {self.length} positions "
                f"but received {mixed_len}"
            )

        # (out, in, 1) -> (out, in): zero the upper triangle so that output
        # channel i cannot read input channels j > i. Building the mask here,
        # instead of writing it back into `weight.data`, keeps the parameter
        # untouched and keeps the whole operation visible to autograd.
        causal_weight = torch.tril(self.conv.weight.squeeze(-1))
        # Because the mask is lower-triangular, the leading sub-block yields
        # exactly what a zero-padded full-length input would, truncated.
        causal_weight = causal_weight[:mixed_len, :mixed_len].unsqueeze(-1)

        bias = self.conv.bias
        if bias is not None:
            bias = bias[:mixed_len]

        return nn.functional.conv1d(x, causal_weight, bias)


class LlamaDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig, layer_idx: int):
Expand All @@ -679,6 +696,7 @@ def __init__(self, config: LlamaConfig, layer_idx: int):
self.mlp = LlamaMLP(config)
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.mixer_layer = MixerLayer(512)

def forward(
self,
Expand Down Expand Up @@ -724,7 +742,8 @@ def forward(
use_cache=use_cache,
cache_position=cache_position,
)
hidden_states = residual + hidden_states
mixer_states = self.mixer_layer(hidden_states)
hidden_states = residual + hidden_states + mixer_states

# Fully Connected
residual = hidden_states
Expand Down

0 comments on commit dbf1653

Please sign in to comment.