add SeptNet
DustinWang1 committed Nov 27, 2024
1 parent fcc026b commit f1e6c09
Showing 8 changed files with 1,251 additions and 2 deletions.
9 changes: 7 additions & 2 deletions fla/__init__.py
@@ -3,14 +3,16 @@
from fla.layers import (ABCAttention, Attention, BasedLinearAttention,
DeltaNet, GatedLinearAttention, GatedSlotAttention,
HGRN2Attention, HGRNAttention, LinearAttention,
MultiScaleRetention, ReBasedLinearAttention, BitAttention)
MultiScaleRetention, ReBasedLinearAttention, BitAttention,
SeptAttention)
from fla.models import (ABCForCausalLM, ABCModel, DeltaNetForCausalLM,
DeltaNetModel, GLAForCausalLM, GLAModel,
GSAForCausalLM, GSAModel, HGRN2ForCausalLM, HGRN2Model,
HGRNForCausalLM, LinearAttentionForCausalLM,
LinearAttentionModel, RetNetForCausalLM, RetNetModel,
RWKV6ForCausalLM, RWKV6Model, TransformerForCausalLM,
TransformerModel, BitNetForCausalLM, BitNetModel)
TransformerModel, BitNetForCausalLM, BitNetModel,
SeptNetForCausalLM, SeptNetModel)

__all__ = [
'ABCAttention',
@@ -25,6 +27,7 @@
'LinearAttention',
'MultiScaleRetention',
'ReBasedLinearAttention',
'SeptAttention',
'ABCForCausalLM',
'ABCModel',
'BitNetForCausalLM',
@@ -47,6 +50,8 @@
'RWKV6Model',
'TransformerForCausalLM',
'TransformerModel',
'SeptNetForCausalLM',
'SeptNetModel',
'chunk_gla',
'chunk_retention',
'fused_chunk_based',
2 changes: 2 additions & 0 deletions fla/layers/__init__.py
@@ -13,6 +13,7 @@
from .rebased import ReBasedLinearAttention
from .rwkv6 import RWKV6Attention
from .bitattn import BitAttention
from .septattn import SeptAttention

__all__ = [
'ABCAttention',
@@ -28,4 +29,5 @@
'MultiScaleRetention',
'ReBasedLinearAttention',
'RWKV6Attention',
'SeptAttention'
]
170 changes: 170 additions & 0 deletions fla/layers/septattn.py
@@ -0,0 +1,170 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2024, Songlin Yang, Yu Zhang

from __future__ import annotations

import warnings
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from einops import rearrange
from transformers.cache_utils import Cache
from transformers.utils import logging

from fla.modules import RotaryEmbedding
from fla.modules.fused_septlinear import FusedSeptLinear

try:
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import (index_first_axis, pad_input,
unpad_input)
except ImportError:
warnings.warn("Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`",
category=ImportWarning)
flash_attn_func = None

logger = logging.get_logger(__name__)


class SeptAttention(nn.Module):

def __init__(
self,
hidden_size: int = 2048,
num_heads: int = 32,
num_kv_heads: Optional[int] = None,
window_size: Optional[int] = None,
max_position_embeddings: Optional[int] = None,
        layer_idx: Optional[int] = None
):
super().__init__()

self.num_heads = num_heads
if num_kv_heads is None:
self.num_kv_heads = self.num_heads
else:
self.num_kv_heads = num_kv_heads
self.num_kv_groups = num_heads // self.num_kv_heads
self.hidden_size = hidden_size
self.head_dim = self.hidden_size // self.num_heads
        self.kv_dim = self.num_kv_heads * self.head_dim
self.window_size = window_size
self.max_position_embeddings = max_position_embeddings
self.layer_idx = layer_idx

self.q_proj = FusedSeptLinear(self.hidden_size, self.hidden_size, bias=False)
self.k_proj = FusedSeptLinear(self.hidden_size, self.kv_dim, bias=False)
self.v_proj = FusedSeptLinear(self.hidden_size, self.kv_dim, bias=False)
self.o_proj = FusedSeptLinear(self.hidden_size, self.hidden_size, bias=False)

self.rotary = RotaryEmbedding(self.head_dim)

self.apply(self._initialize_weights)

def _initialize_weights(self, module: nn.Module):
if getattr(module, "_is_hf_initialized", False):
return
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
if module.bias is not None:
nn.init.zeros_(module.bias)
module._is_hf_initialized = True

def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
batch_size, q_len, _ = hidden_states.size()
q = rearrange(self.q_proj(hidden_states), '... (h d) -> ... h d', h=self.num_heads)
k = rearrange(self.k_proj(hidden_states), '... (h d) -> ... h d', h=self.num_kv_heads)
v = rearrange(self.v_proj(hidden_states), 'b t (h d) -> b h t d', h=self.num_kv_heads)

seqlen_offset, max_seqlen = 0, q.shape[1]
if past_key_values is not None:
seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
max_seqlen = q.shape[1] + seqlen_offset

if attention_mask is not None:
            # to eliminate the offsets of padding tokens
seqlen_offset = (seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]).clamp(min=0)
max_seqlen = q.shape[1] + max(seqlen_offset)

if self.max_position_embeddings is not None:
max_seqlen = max(max_seqlen, self.max_position_embeddings)
q, k = self.rotary(q, k, seqlen_offset, max_seqlen)

k = rearrange(k, 'b t h d -> b h t d')
if past_key_values is not None:
k, v = past_key_values.update(k, v, self.layer_idx)
k, v = rearrange(k, 'b h t d -> b t h d'), rearrange(v, 'b h t d -> b t h d')
if self.num_kv_groups > 1:
k = rearrange(k.unsqueeze(-2).repeat(1, 1, 1, self.num_kv_groups, 1), 'b t h g d -> b t (h g) d')
v = rearrange(v.unsqueeze(-2).repeat(1, 1, 1, self.num_kv_groups, 1), 'b t h g d -> b t (h g) d')

if flash_attn_func is None:
raise ImportError("Please install Flash Attention via `pip install flash-attn --no-build-isolation` first")

# Contains at least one padding token in the sequence
if attention_mask is not None:
q, k, v, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(q, k, v, attention_mask, q_len)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_q, max_seqlen_k = max_seq_lens
o = flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
causal=True,
window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
)
o = pad_input(o, indices_q, batch_size, q_len)
else:
o = flash_attn_func(
q, k, v,
causal=True,
window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
)
o = o.reshape(batch_size, q_len, self.hidden_size)
o = self.o_proj(o)

        # Flash Attention does not return attention weights, so this is always None
        attentions = None

return o, attentions, past_key_values

def _upad_input(self, q, k, v, attention_mask, q_len):
seqlens = attention_mask.sum(-1, dtype=torch.int32)
indices_k = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_k = seqlens.max().item()
cu_seqlens_k = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
batch_size, seq_len, num_key_value_heads, head_dim = k.shape

k = index_first_axis(k.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k)
v = index_first_axis(v.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k)
if q_len == seq_len:
q = index_first_axis(q.reshape(batch_size * seq_len, self.num_heads, head_dim), indices_k)
cu_seqlens_q = cu_seqlens_k
max_seqlen_q = max_seqlen_k
indices_q = indices_k
elif q_len == 1:
max_seqlen_q = 1
# There is a memcpy here, that is very bad.
cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device)
indices_q = cu_seqlens_q[:-1]
q = q.squeeze(1)
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -q_len:]
q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, attention_mask)

return q, k, v, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
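
A minimal forward-pass sketch for the new layer above, showing how it is meant to be called. It assumes a CUDA device with flash-attn installed and that FusedSeptLinear (added elsewhere in this commit) behaves as a drop-in linear layer; the sizes are illustrative, not taken from the commit.

import torch

from fla.layers import SeptAttention

# hidden_size=512 with num_heads=8 gives head_dim=64, a size flash-attn supports
attn = SeptAttention(hidden_size=512, num_heads=8, layer_idx=0).cuda().to(torch.bfloat16)
x = torch.randn(2, 128, 512, device='cuda', dtype=torch.bfloat16)  # (batch, seq_len, hidden)

o, attentions, cache = attn(x)  # attentions is None; cache stays None since no past_key_values were passed
assert o.shape == x.shape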
2 changes: 2 additions & 0 deletions fla/models/__init__.py
@@ -18,6 +18,7 @@
from fla.models.transformer import (TransformerConfig, TransformerForCausalLM,
TransformerModel)
from fla.models.bitnet import (BitNetConfig, BitNetForCausalLM, BitNetModel)
from fla.models.septnet import (SeptNetConfig, SeptNetForCausalLM, SeptNetModel)

__all__ = [
'ABCConfig', 'ABCForCausalLM', 'ABCModel',
@@ -33,5 +34,6 @@
'RetNetConfig', 'RetNetForCausalLM', 'RetNetModel',
'RWKV6Config', 'RWKV6ForCausalLM', 'RWKV6Model',
'SambaConfig', 'SambaForCausalLM', 'SambaModel',
'SeptNetConfig', 'SeptNetForCausalLM', 'SeptNetModel',
'TransformerConfig', 'TransformerForCausalLM', 'TransformerModel'
]
14 changes: 14 additions & 0 deletions fla/models/septnet/__init__.py
@@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.septnet.configuration_septnet import SeptNetConfig
from fla.models.septnet.modeling_septnet import (
SeptNetForCausalLM, SeptNetModel)

AutoConfig.register(SeptNetConfig.model_type, SeptNetConfig)
AutoModel.register(SeptNetConfig, SeptNetModel)
AutoModelForCausalLM.register(SeptNetConfig, SeptNetForCausalLM)


__all__ = ['SeptNetConfig', 'SeptNetForCausalLM', 'SeptNetModel']
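
With the registrations above, the standard transformers Auto classes can construct a SeptNet model straight from its config, provided the package imports cleanly (including modeling_septnet.py, which is not shown on this page) and SeptNetConfig carries its own unique model_type. A sketch with illustrative sizes:

from transformers import AutoModelForCausalLM

from fla.models import SeptNetConfig

config = SeptNetConfig(hidden_size=512, num_hidden_layers=2, num_heads=8)
model = AutoModelForCausalLM.from_config(config)  # resolves to SeptNetForCausalLM via the registry
print(f"{sum(p.numel() for p in model.parameters()):,} parameters")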
63 changes: 63 additions & 0 deletions fla/models/septnet/configuration_septnet.py
@@ -0,0 +1,63 @@
# -*- coding: utf-8 -*-

from typing import Optional

from transformers.configuration_utils import PretrainedConfig


class SeptNetConfig(PretrainedConfig):

    model_type = 'septnet'
keys_to_ignore_at_inference = ['past_key_values']

def __init__(
self,
vocab_size: int = 32000,
hidden_size: int = 2048,
hidden_ratio: Optional[int] = 4,
intermediate_size: Optional[int] = None,
num_hidden_layers: int = 24,
num_heads: int = 32,
        num_kv_heads: Optional[int] = None,
hidden_act: str = "swish",
window_size: Optional[int] = None,
max_position_embeddings: int = 2048,
initializer_range: float = 0.02,
elementwise_affine: Optional[bool] = True,
norm_eps: float = 1e-6,
use_cache: bool = True,
        pad_token_id: Optional[int] = None,
bos_token_id: int = 1,
eos_token_id: int = 2,
tie_word_embeddings: bool = False,
attention_bias: bool = False,
fuse_norm: bool = True,
fuse_cross_entropy: bool = True,
**kwargs,
):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.hidden_ratio = hidden_ratio
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.window_size = window_size
self.max_position_embeddings = max_position_embeddings

self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.elementwise_affine = elementwise_affine
self.norm_eps = norm_eps
self.use_cache = use_cache
self.attention_bias = attention_bias
self.fuse_cross_entropy = fuse_cross_entropy
self.fuse_norm = fuse_norm

super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
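
The config stores the knobs consumed by SeptAttention (num_heads, num_kv_heads, window_size, max_position_embeddings) plus the usual language-model settings, and since it subclasses PretrainedConfig it gets the standard JSON save/load round trip for free. A brief sketch with illustrative values:

from fla.models.septnet import SeptNetConfig

config = SeptNetConfig(hidden_size=1024, num_heads=16, num_kv_heads=4, window_size=2048)
config.save_pretrained("septnet-config")  # writes septnet-config/config.json
reloaded = SeptNetConfig.from_pretrained("septnet-config")
assert reloaded.num_kv_heads == 4 and reloaded.window_size == 2048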
