Skip to content

Commit

Permalink
Implement LLaMA (vllm-project#9)
Browse files Browse the repository at this point in the history
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
  • Loading branch information
WoosukKwon and zhuohan123 authored Mar 30, 2023
1 parent 54bcc0f commit bc9de71
Show file tree
Hide file tree
Showing 7 changed files with 500 additions and 35 deletions.
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
## Installation

```bash
pip install psutil numpy torch transformers
pip install flash-attn # This may take up to 10 mins.
pip install psutil numpy ray torch
pip install git+https://github.com/huggingface/transformers # Required for LLaMA.
pip install sentencepiece # Required for LlamaTokenizer.
pip install flash-attn # This may take up to 20 mins.
pip install -e .
```

Expand Down
1 change: 1 addition & 0 deletions cacheflow/master/simple_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,5 @@ def print_response(
for seq in seq_group.seqs:
token_ids = seq.get_token_ids()
output = self.tokenizer.decode(token_ids, skip_special_tokens=True)
output = output.strip()
print(f'Seq {seq.seq_id}: {output!r}')
357 changes: 357 additions & 0 deletions cacheflow/models/llama.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,357 @@
"""1D LLaMA model compatible with HuggingFace weights."""
import os
import glob
import filelock
from tqdm import tqdm
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from transformers import LlamaConfig
from transformers import PreTrainedModel

from cacheflow.models import InputMetadata
from cacheflow.models.attention import OPTCacheFlowAttention
from cacheflow.models.sample import Sampler
from cacheflow.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from cacheflow.parallel_utils.tensor_parallel import (VocabParallelEmbedding,
ColumnParallelLinear,
RowParallelLinear)
from cacheflow.sequence import SequenceOutputs

KVCache = Tuple[torch.Tensor, torch.Tensor]


class LlamaRMSNorm(nn.Module):

def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps

def forward(self, hidden_states):
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states


class LlamaRotaryEmbedding(torch.nn.Module):

def __init__(self, dim, max_position_embeddings=2048, base=10000):
super().__init__()
self.max_position_embeddings = max_position_embeddings

inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2) / dim))
self.register_buffer("inv_freq", inv_freq)

# Create cos and sin embeddings.
t = torch.arange(max_position_embeddings).float()
freqs = torch.einsum("i,j->ij", t, self.inv_freq.float())
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos().to(dtype=self.inv_freq.dtype)
sin = emb.sin().to(dtype=self.inv_freq.dtype)
self.register_buffer("cos_cached", cos, persistent=False)
self.register_buffer("sin_cached", sin, persistent=False)

def forward(
self,
positions: torch.LongTensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
cos = F.embedding(positions, self.cos_cached)
sin = F.embedding(positions, self.sin_cached)
return cos, sin


def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
# TODO: Optimize.
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed


class LlamaMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
):
super().__init__()
# TODO: Merge the gate and down linear layers.
self.gate_proj = ColumnParallelLinear(hidden_size, intermediate_size,
bias=False, gather_output=False,
perform_initialization=False)
self.down_proj = RowParallelLinear(intermediate_size, hidden_size,
bias=False, input_is_parallel=True,
perform_initialization=False)
self.up_proj = ColumnParallelLinear(hidden_size, intermediate_size,
bias=False, gather_output=False,
perform_initialization=False)
assert hidden_act == 'silu'
self.act_fn = nn.SiLU()

def forward(self, x):
gate, _ = self.gate_proj(x)
up, _ = self.up_proj(x)
x = self.act_fn(gate) * up
x, _ = self.down_proj(x)
return x


class LlamaAttention(nn.Module):

def __init__(
self,
hidden_size: int,
num_heads: int,
):
super().__init__()
self.hidden_size = hidden_size
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
self.head_dim = hidden_size // self.total_num_heads
self.scaling = self.head_dim ** -0.5

# TODO: Merge the QKV linear layers.
self.q_proj = ColumnParallelLinear(
hidden_size,
self.total_num_heads * self.head_dim,
bias=False,
gather_output=False,
perform_initialization=False,
)
self.k_proj = ColumnParallelLinear(
hidden_size,
self.total_num_heads * self.head_dim,
bias=False,
gather_output=False,
perform_initialization=False,
)
self.v_proj = ColumnParallelLinear(
hidden_size,
self.total_num_heads * self.head_dim,
bias=False,
gather_output=False,
perform_initialization=False,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False,
)
self.rotary_emb = LlamaRotaryEmbedding(self.head_dim)
# FIXME(woosuk): Rename this.
self.attn = OPTCacheFlowAttention(scale=self.scaling)

def forward(
self,
positions: torch.LongTensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
q, _ = self.q_proj(hidden_states)
k, _ = self.k_proj(hidden_states)
v, _ = self.v_proj(hidden_states)

# Apply rotrary embedding.
# TODO: Optimize.
q = q.view(-1, self.num_heads, self.head_dim).transpose(0, 1)
k = k.view(-1, self.num_heads, self.head_dim).transpose(0, 1)
cos, sin = self.rotary_emb(positions)
q, k = apply_rotary_pos_emb(q, k, cos, sin)
q = q.transpose(0, 1).contiguous().view(-1, self.num_heads * self.head_dim)
k = k.transpose(0, 1).contiguous().view(-1, self.num_heads * self.head_dim)

key_cache, value_cache = kv_cache
attn_output = self.attn(
q, k, v, key_cache, value_cache, input_metadata, cache_event)
output, _ = self.o_proj(attn_output)
return output


class LlamaDecoderLayer(nn.Module):

def __init__(self, config: LlamaConfig):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = LlamaAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
)
self.mlp = LlamaMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
)
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

def forward(
self,
positions: torch.LongTensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
hidden_states = residual + hidden_states

# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states


class LlamaModel(nn.Module):

def __init__(self, config: LlamaConfig):
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size

self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size,
perform_initialization=False)
self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

def forward(
self,
input_ids: torch.LongTensor,
positions: torch.LongTensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
for i in range(len(self.layers)):
if cache_events is None:
cache_event = None
else:
cache_event = cache_events[i]
layer = self.layers[i]
hidden_states = layer(
positions,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
)
hidden_states = self.norm(hidden_states)
return hidden_states


class LlamaForCausalLM(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.model = LlamaModel(config)
self.lm_head = ColumnParallelLinear(config.hidden_size,
config.vocab_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.sampler = Sampler()

def forward(
self,
input_ids: torch.LongTensor,
positions: torch.LongTensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
hidden_states = self.model(
input_ids, positions, kv_caches, input_metadata, cache_events)
next_tokens = self.sampler(
self.lm_head.weight, hidden_states, input_metadata)
return next_tokens

_column_parallel_weights = ["embed_tokens.weight", "lm_head.weight",
"q_proj.weight", "k_proj.weight",
"v_proj.weight", "gate_proj.weight",
"up_proj.weight"]
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]

def load_weights(self, weights_path: str):
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, param in state_dict.items():
loaded_weight = torch.from_numpy(np.load(os.path.join(weights_path,
name)))
for p in self._column_parallel_weights:
if p in name:
shard_size = param.shape[0]
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
break
for p in self._row_parallel_weights:
if p in name:
shard_size = param.shape[1]
loaded_weight = loaded_weight[
:,
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
break

assert param.shape == loaded_weight.shape
param.data.copy_(loaded_weight)

@staticmethod
def get_weights(model_name: str, path: str):
if not os.path.isfile(os.path.join(model_name, "config.json")):
raise ValueError("LLaMA model's model_name has to be a path"
"to the huggingface model's directory.")
path = os.path.join(model_name, f"np")
path = os.path.abspath(os.path.expanduser(path))
os.makedirs(path, exist_ok=True)
lock_path = os.path.join(path, "file_lock")
lock = filelock.FileLock(lock_path)

with lock:
test_weight_path = os.path.join(path, "model.embed_tokens.weight")
if os.path.exists(test_weight_path):
return path

bin_files = glob.glob(os.path.join(model_name, "*.bin"))

for bin_file in tqdm(bin_files, desc="Convert format"):
state = torch.load(bin_file, map_location="cpu")
for name, param in tqdm(state.items(), leave=False):
param_path = os.path.join(path, name)
with open(param_path, "wb") as f:
np.save(f, param.cpu().detach().numpy())

return path
Loading

0 comments on commit bc9de71

Please sign in to comment.