Commit f1e6c09
1 parent fcc026b
Showing 8 changed files with 1,251 additions and 2 deletions.
@@ -0,0 +1,170 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2024, Songlin Yang, Yu Zhang

from __future__ import annotations

import warnings
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from einops import rearrange
from transformers.cache_utils import Cache
from transformers.utils import logging

from fla.modules import RotaryEmbedding
from fla.modules.fused_septlinear import FusedSeptLinear

try:
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import (index_first_axis, pad_input,
                                         unpad_input)
except ImportError:
    warnings.warn("Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`",
                  category=ImportWarning)
    flash_attn_func = None

logger = logging.get_logger(__name__)

class SeptAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int = 2048,
        num_heads: int = 32,
        num_kv_heads: Optional[int] = None,
        window_size: Optional[int] = None,
        max_position_embeddings: Optional[int] = None,
        layer_idx: Optional[int] = None
    ):
        super().__init__()

        self.num_heads = num_heads
        if num_kv_heads is None:
            self.num_kv_heads = self.num_heads
        else:
            self.num_kv_heads = num_kv_heads
        self.num_kv_groups = num_heads // self.num_kv_heads
        self.hidden_size = hidden_size
        self.head_dim = self.hidden_size // self.num_heads
        self.kv_dim = self.num_kv_heads * self.head_dim
        self.window_size = window_size
        self.max_position_embeddings = max_position_embeddings
        self.layer_idx = layer_idx

        self.q_proj = FusedSeptLinear(self.hidden_size, self.hidden_size, bias=False)
        self.k_proj = FusedSeptLinear(self.hidden_size, self.kv_dim, bias=False)
        self.v_proj = FusedSeptLinear(self.hidden_size, self.kv_dim, bias=False)
        self.o_proj = FusedSeptLinear(self.hidden_size, self.hidden_size, bias=False)

        self.rotary = RotaryEmbedding(self.head_dim)

        self.apply(self._initialize_weights)

    def _initialize_weights(self, module: nn.Module):
        if getattr(module, "_is_hf_initialized", False):
            return
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        module._is_hf_initialized = True

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        batch_size, q_len, _ = hidden_states.size()
        q = rearrange(self.q_proj(hidden_states), '... (h d) -> ... h d', h=self.num_heads)
        k = rearrange(self.k_proj(hidden_states), '... (h d) -> ... h d', h=self.num_kv_heads)
        v = rearrange(self.v_proj(hidden_states), 'b t (h d) -> b h t d', h=self.num_kv_heads)

        seqlen_offset, max_seqlen = 0, q.shape[1]
        if past_key_values is not None:
            seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
            max_seqlen = q.shape[1] + seqlen_offset

            if attention_mask is not None:
                # to eliminate the offsets of padding tokens
                seqlen_offset = (seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]).clamp(min=0)
                max_seqlen = q.shape[1] + max(seqlen_offset)

        if self.max_position_embeddings is not None:
            max_seqlen = max(max_seqlen, self.max_position_embeddings)
        q, k = self.rotary(q, k, seqlen_offset, max_seqlen)

        k = rearrange(k, 'b t h d -> b h t d')
        if past_key_values is not None:
            k, v = past_key_values.update(k, v, self.layer_idx)
        k, v = rearrange(k, 'b h t d -> b t h d'), rearrange(v, 'b h t d -> b t h d')
        if self.num_kv_groups > 1:
            k = rearrange(k.unsqueeze(-2).repeat(1, 1, 1, self.num_kv_groups, 1), 'b t h g d -> b t (h g) d')
            v = rearrange(v.unsqueeze(-2).repeat(1, 1, 1, self.num_kv_groups, 1), 'b t h g d -> b t (h g) d')

        if flash_attn_func is None:
            raise ImportError("Please install Flash Attention via `pip install flash-attn --no-build-isolation` first")

        # Contains at least one padding token in the sequence
        if attention_mask is not None:
            q, k, v, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(q, k, v, attention_mask, q_len)
            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_q, max_seqlen_k = max_seq_lens
            o = flash_attn_varlen_func(
                q, k, v,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_k=max_seqlen_k,
                causal=True,
                window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
            )
            o = pad_input(o, indices_q, batch_size, q_len)
        else:
            o = flash_attn_func(
                q, k, v,
                causal=True,
                window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
            )
        o = o.reshape(batch_size, q_len, self.hidden_size)
        o = self.o_proj(o)

        # Flash Attention does not return attention weights
        attentions = None

        return o, attentions, past_key_values

    def _upad_input(self, q, k, v, attention_mask, q_len):
        seqlens = attention_mask.sum(-1, dtype=torch.int32)
        indices_k = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
        max_seqlen_k = seqlens.max().item()
        cu_seqlens_k = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
        batch_size, seq_len, num_key_value_heads, head_dim = k.shape

        k = index_first_axis(k.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k)
        v = index_first_axis(v.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k)
        if q_len == seq_len:
            q = index_first_axis(q.reshape(batch_size * seq_len, self.num_heads, head_dim), indices_k)
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_q = max_seqlen_k
            indices_q = indices_k
        elif q_len == 1:
            max_seqlen_q = 1
            # There is a memcpy here, that is very bad.
            cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device)
            indices_q = cu_seqlens_q[:-1]
            q = q.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -q_len:]
            q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, attention_mask)

        return q, k, v, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
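
For orientation, a minimal smoke test of the attention layer above could look like the sketch below. The import path, device/dtype setup, and shape values are assumptions for illustration, not part of this commit; flash-attn on a CUDA device is required for the forward pass to run.

import torch
from fla.layers.septnet import SeptAttention  # NOTE: guessed import path; adjust to where this module lives

# hypothetical usage: batch of 2, sequence length 128, GQA with 8 KV heads
attn = SeptAttention(hidden_size=2048, num_heads=32, num_kv_heads=8, layer_idx=0)
attn = attn.to('cuda', dtype=torch.bfloat16)

x = torch.randn(2, 128, 2048, device='cuda', dtype=torch.bfloat16)
o, attentions, cache = attn(x)       # no padding mask -> flash_attn_func branch
assert o.shape == x.shape            # output keeps (batch, seq_len, hidden_size)
assert attentions is None            # Flash Attention never returns attention weights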
@@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.septnet.configuration_septnet import SeptNetConfig
from fla.models.septnet.modeling_septnet import (SeptNetForCausalLM,
                                                 SeptNetModel)

AutoConfig.register(SeptNetConfig.model_type, SeptNetConfig)
AutoModel.register(SeptNetConfig, SeptNetModel)
AutoModelForCausalLM.register(SeptNetConfig, SeptNetForCausalLM)


__all__ = ['SeptNetConfig', 'SeptNetForCausalLM', 'SeptNetModel']
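
Because the config and model classes are registered with the Auto* factories above, they can be constructed through the standard transformers entry points. A hedged sketch, assuming this file is fla/models/septnet/__init__.py so that importing the subpackage runs the registrations:

from transformers import AutoModelForCausalLM

from fla.models.septnet import SeptNetConfig  # importing the package executes the registrations above

config = SeptNetConfig(hidden_size=1024, num_hidden_layers=12, num_heads=16)
model = AutoModelForCausalLM.from_config(config)  # dispatches to SeptNetForCausalLM via the registry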
@@ -0,0 +1,63 @@
# -*- coding: utf-8 -*-

from typing import Optional

from transformers.configuration_utils import PretrainedConfig


class SeptNetConfig(PretrainedConfig):

    model_type = 'septnet'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        vocab_size: int = 32000,
        hidden_size: int = 2048,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        num_hidden_layers: int = 24,
        num_heads: int = 32,
        num_kv_heads: Optional[int] = None,
        hidden_act: str = "swish",
        window_size: Optional[int] = None,
        max_position_embeddings: int = 2048,
        initializer_range: float = 0.02,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-6,
        use_cache: bool = True,
        pad_token_id: Optional[int] = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        attention_bias: bool = False,
        fuse_norm: bool = True,
        fuse_cross_entropy: bool = True,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.window_size = window_size
        self.max_position_embeddings = max_position_embeddings

        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.elementwise_affine = elementwise_affine
        self.norm_eps = norm_eps
        self.use_cache = use_cache
        self.attention_bias = attention_bias
        self.fuse_cross_entropy = fuse_cross_entropy
        self.fuse_norm = fuse_norm

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
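
As a quick sanity check, the config can be instantiated with its defaults and round-tripped through the usual PretrainedConfig serialization. The directory name below is made up for illustration:

cfg = SeptNetConfig()                  # defaults: vocab_size=32000, hidden_size=2048, num_heads=32, ...
cfg.save_pretrained('septnet-config')  # writes config.json with the fields set above
restored = SeptNetConfig.from_pretrained('septnet-config')
assert restored.hidden_size == cfg.hidden_size and restored.num_heads == cfg.num_heads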