Add BitNet #85

Merged · 10 commits · Nov 27, 2024
24 changes: 14 additions & 10 deletions fla/__init__.py
@@ -1,21 +1,23 @@
# -*- coding: utf-8 -*-

from fla.layers import (ABCAttention, Attention, BasedLinearAttention,
-                        DeltaNet, GatedLinearAttention, GatedSlotAttention,
-                        HGRN2Attention, HGRNAttention, LinearAttention,
-                        MultiScaleRetention, ReBasedLinearAttention)
-from fla.models import (ABCForCausalLM, ABCModel, DeltaNetForCausalLM,
-                        DeltaNetModel, GLAForCausalLM, GLAModel,
-                        GSAForCausalLM, GSAModel, HGRN2ForCausalLM, HGRN2Model,
-                        HGRNForCausalLM, LinearAttentionForCausalLM,
-                        LinearAttentionModel, RetNetForCausalLM, RetNetModel,
-                        RWKV6ForCausalLM, RWKV6Model, TransformerForCausalLM,
-                        TransformerModel)
+                        BitAttention, DeltaNet, GatedLinearAttention,
+                        GatedSlotAttention, HGRN2Attention, HGRNAttention,
+                        LinearAttention, MultiScaleRetention,
+                        ReBasedLinearAttention)
+from fla.models import (ABCForCausalLM, ABCModel, BitNetForCausalLM,
+                        BitNetModel, DeltaNetForCausalLM, DeltaNetModel,
+                        GLAForCausalLM, GLAModel, GSAForCausalLM, GSAModel,
+                        HGRN2ForCausalLM, HGRN2Model, HGRNForCausalLM,
+                        LinearAttentionForCausalLM, LinearAttentionModel,
+                        RetNetForCausalLM, RetNetModel, RWKV6ForCausalLM,
+                        RWKV6Model, TransformerForCausalLM, TransformerModel)

__all__ = [
    'ABCAttention',
    'Attention',
    'BasedLinearAttention',
+    'BitAttention',
    'DeltaNet',
    'HGRNAttention',
    'HGRN2Attention',
@@ -26,6 +28,8 @@
    'ReBasedLinearAttention',
    'ABCForCausalLM',
    'ABCModel',
+    'BitNetForCausalLM',
+    'BitNetModel',
    'DeltaNetForCausalLM',
    'DeltaNetModel',
    'HGRNForCausalLM',
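With these re-exports, the new classes are importable straight from the package root. A quick sanity check (a sketch, assuming the package is installed from this branch):

```python
# Quick sanity check of the new top-level exports.
from fla import BitAttention, BitNetForCausalLM, BitNetModel

print(BitAttention.__module__)       # fla.layers.bitattn
print(BitNetForCausalLM.__module__)  # fla.models.bitnet.modeling_bitnet
```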
4 changes: 3 additions & 1 deletion fla/layers/__init__.py
@@ -3,6 +3,7 @@
from .abc import ABCAttention
from .attn import Attention
from .based import BasedLinearAttention
+from .bitattn import BitAttention
from .delta_net import DeltaNet
from .gla import GatedLinearAttention
from .gsa import GatedSlotAttention
@@ -17,6 +18,7 @@
    'ABCAttention',
    'Attention',
    'BasedLinearAttention',
+    'BitAttention',
    'DeltaNet',
    'GatedLinearAttention',
    'GatedSlotAttention',
@@ -25,5 +27,5 @@
    'LinearAttention',
    'MultiScaleRetention',
    'ReBasedLinearAttention',
-    'RWKV6Attention'
+    'RWKV6Attention',
]
183 changes: 183 additions & 0 deletions fla/layers/bitattn.py
@@ -0,0 +1,183 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2024, Songlin Yang, Yu Zhang

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from einops import rearrange
from transformers.utils import logging

from fla.modules import RMSNorm, RotaryEmbedding
from fla.modules.fused_bitlinear import FusedBitLinear

if TYPE_CHECKING:
    from fla.models.utils import Cache

try:
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import (index_first_axis, pad_input,
                                         unpad_input)
except ImportError:
    warnings.warn(
        "Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`",
        category=ImportWarning
    )
    flash_attn_func = None

logger = logging.get_logger(__name__)


class BitAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int = 2048,
        num_heads: int = 32,
        num_kv_heads: Optional[int] = None,
        window_size: Optional[int] = None,
        rope_theta: Optional[float] = 10000.,
        max_position_embeddings: Optional[int] = None,
        norm_first: bool = False,
        norm_eps: float = 1e-5,
        layer_idx: Optional[int] = None
    ):
        super().__init__()

        self.num_heads = num_heads
        if num_kv_heads is None:
            self.num_kv_heads = self.num_heads
        else:
            self.num_kv_heads = num_kv_heads
        self.num_kv_groups = num_heads // self.num_kv_heads
        self.hidden_size = hidden_size
        self.head_dim = self.hidden_size // self.num_heads
        self.kv_dim = self.num_kv_heads * self.head_dim
        self.window_size = window_size
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.norm_first = norm_first
        self.layer_idx = layer_idx

        if norm_first:
            self.norm = RMSNorm(self.hidden_size, eps=norm_eps)
        self.q_proj = FusedBitLinear(self.hidden_size, self.hidden_size, bias=False)
        self.k_proj = FusedBitLinear(self.hidden_size, self.kv_dim, bias=False)
        self.v_proj = FusedBitLinear(self.hidden_size, self.kv_dim, bias=False)
        self.o_proj = FusedBitLinear(self.hidden_size, self.hidden_size, bias=False)

        self.rotary = RotaryEmbedding(dim=self.head_dim, base=self.rope_theta)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        batch_size, q_len, _ = hidden_states.size()

        if self.norm_first:
            hidden_states = self.norm(hidden_states)

        q = rearrange(self.q_proj(hidden_states), '... (h d) -> ... h d', h=self.num_heads)
        k = rearrange(self.k_proj(hidden_states), '... (h d) -> ... h d', h=self.num_kv_heads)
        v = rearrange(self.v_proj(hidden_states), '... (h d) -> ... h d', h=self.num_kv_heads)

        seqlen_offset, max_seqlen = 0, q_len
        if past_key_values is not None:
            seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
            max_seqlen = q.shape[1] + seqlen_offset

            if attention_mask is not None:
                # to determine the actual offsets of the padding tokens
                seqlen_offset = (seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]).clamp(min=0)
                max_seqlen = q.shape[1] + max(seqlen_offset)

        if self.max_position_embeddings is not None:
            max_seqlen = max(max_seqlen, self.max_position_embeddings)
        q, k = self.rotary(q, k, seqlen_offset, max_seqlen)

        if past_key_values is not None:
            k, v = past_key_values.update(
                attn_state=(k.flatten(-2, -1), v.flatten(-2, -1)),
                layer_idx=self.layer_idx,
                offset=q_len,
                cache_kwargs=dict(window_size=self.window_size)
            )['attn_state']
            k = rearrange(k, '... (h d) -> ... h d', h=self.num_kv_heads)
            v = rearrange(v, '... (h d) -> ... h d', h=self.num_kv_heads)

        if flash_attn_func is None:
            raise ImportError("Please install Flash Attention via `pip install flash-attn --no-build-isolation` first")

        # Contains at least one padding token in the sequence
        if attention_mask is not None:
            q, k, v, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(q, k, v, attention_mask, q_len)
            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_q, max_seqlen_k = max_seq_lens
            o = flash_attn_varlen_func(
                q, k, v,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_k=max_seqlen_k,
                causal=True,
                window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
            )
            o = pad_input(o, indices_q, batch_size, q_len)
        else:
            o = flash_attn_func(
                q, k, v,
                causal=True,
                window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
            )
        o = o.reshape(batch_size, q_len, self.hidden_size)
        o = self.o_proj(o)

        # attention weights are never materialized by the flash-attn kernels
        attentions = None

        return o, attentions, past_key_values

    def _upad_input(self, q, k, v, attention_mask, q_len):
        seqlens = attention_mask.sum(-1, dtype=torch.int32)
        indices_k = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
        max_seqlen_k = seqlens.max().item()
        cu_seqlens_k = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
        batch_size, seq_len, num_key_value_heads, head_dim = k.shape

        k = index_first_axis(k.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k)
        v = index_first_axis(v.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k)
        if q_len == seq_len:
            q = index_first_axis(q.reshape(batch_size * seq_len, self.num_heads, head_dim), indices_k)
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_q = max_seqlen_k
            indices_q = indices_k
        elif q_len == 1:
            max_seqlen_q = 1
            # There is a memcpy here, that is very bad.
            cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device)
            indices_q = cu_seqlens_q[:-1]
            q = q.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -q_len:]
            q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, attention_mask)

        return q, k, v, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
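For reviewers who want to try the layer in isolation, here is a minimal smoke test of `BitAttention` — a sketch rather than part of the PR; the shapes and dtype are illustrative, and it assumes a CUDA device with `flash-attn` installed, since the forward pass goes through `flash_attn_func`:

```python
import torch

from fla.layers import BitAttention

# Illustrative sizes; head_dim = hidden_size // num_heads = 64 here.
attn = BitAttention(hidden_size=2048, num_heads=32, layer_idx=0)
attn = attn.cuda().to(torch.bfloat16)

x = torch.randn(2, 128, 2048, device='cuda', dtype=torch.bfloat16)  # [batch, seq_len, hidden]
o, attentions, past_key_values = attn(x)

print(o.shape)                      # torch.Size([2, 128, 2048])
print(attentions, past_key_values)  # None, None (no cache passed, no attention weights returned)
```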
2 changes: 2 additions & 0 deletions fla/models/__init__.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-

from fla.models.abc import ABCConfig, ABCForCausalLM, ABCModel
+from fla.models.bitnet import BitNetConfig, BitNetForCausalLM, BitNetModel
from fla.models.delta_net import (DeltaNetConfig, DeltaNetForCausalLM,
                                  DeltaNetModel)
from fla.models.gla import GLAConfig, GLAForCausalLM, GLAModel
@@ -20,6 +21,7 @@

__all__ = [
    'ABCConfig', 'ABCForCausalLM', 'ABCModel',
+    'BitNetConfig', 'BitNetForCausalLM', 'BitNetModel',
    'DeltaNetConfig', 'DeltaNetForCausalLM', 'DeltaNetModel',
    'GLAConfig', 'GLAForCausalLM', 'GLAModel',
    'GSAConfig', 'GSAForCausalLM', 'GSAModel',
13 changes: 13 additions & 0 deletions fla/models/bitnet/__init__.py
@@ -0,0 +1,13 @@
# -*- coding: utf-8 -*-

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.bitnet.configuration_bitnet import BitNetConfig
from fla.models.bitnet.modeling_bitnet import BitNetForCausalLM, BitNetModel

AutoConfig.register(BitNetConfig.model_type, BitNetConfig)
AutoModel.register(BitNetConfig, BitNetModel)
AutoModelForCausalLM.register(BitNetConfig, BitNetForCausalLM)


__all__ = ['BitNetConfig', 'BitNetForCausalLM', 'BitNetModel']
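Because the config and model classes are registered with the `transformers` Auto* machinery above, the model can also be built through `AutoModelForCausalLM`. A minimal sketch (the sizes below are illustrative, not the defaults):

```python
from transformers import AutoModelForCausalLM

from fla.models import BitNetConfig  # importing fla.models runs the registrations above

config = BitNetConfig(hidden_size=256, num_heads=4, num_hidden_layers=2)
model = AutoModelForCausalLM.from_config(config)
print(type(model).__name__)  # BitNetForCausalLM
```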
68 changes: 68 additions & 0 deletions fla/models/bitnet/configuration_bitnet.py
@@ -0,0 +1,68 @@
# -*- coding: utf-8 -*-

from typing import Optional

from transformers.configuration_utils import PretrainedConfig


class BitNetConfig(PretrainedConfig):

    model_type = 'bitnet'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        vocab_size: int = 32000,
        hidden_size: int = 2048,
        num_hidden_layers: int = 24,
        num_heads: int = 32,
        num_kv_heads: Optional[int] = None,
        window_size: Optional[int] = None,
        rope_theta: Optional[float] = 10000.,
        max_position_embeddings: int = 2048,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        hidden_act: str = "swish",
        initializer_range: float = 0.02,
        elementwise_affine: Optional[bool] = True,
        norm_first: bool = False,
        norm_eps: float = 1e-6,
        use_cache: bool = True,
        pad_token_id: Optional[int] = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        attention_bias: bool = False,
        fuse_norm: bool = True,
        fuse_cross_entropy: bool = True,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.window_size = window_size
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act

        self.initializer_range = initializer_range
        self.elementwise_affine = elementwise_affine
        self.norm_first = norm_first
        self.norm_eps = norm_eps
        self.use_cache = use_cache
        self.attention_bias = attention_bias
        self.fuse_cross_entropy = fuse_cross_entropy
        self.fuse_norm = fuse_norm

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
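Putting the pieces together, here is a hedged end-to-end sketch that builds a tiny BitNet from this config and runs one forward pass. The sizes are illustrative (not the 32000-vocab, 2048-dim, 24-layer defaults above), and a CUDA device is assumed since the fused bit-linear and attention kernels run on GPU:

```python
import torch

from fla.models import BitNetConfig, BitNetForCausalLM

# Toy sizes for a quick check; real runs would use the defaults above.
config = BitNetConfig(
    vocab_size=1000,
    hidden_size=256,
    num_hidden_layers=2,
    num_heads=4,
)
model = BitNetForCausalLM(config).cuda().to(torch.bfloat16)

input_ids = torch.randint(0, config.vocab_size, (1, 32), device='cuda')
out = model(input_ids, labels=input_ids)  # HF-style causal-LM loss on shifted labels
print(out.loss, out.logits.shape)         # scalar loss, torch.Size([1, 32, 1000])
```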