From bc9de71ce5f06217ea86953c5e30326fce043cda Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 29 Mar 2023 21:25:32 -0700 Subject: [PATCH] Implement LLaMA (#9) Co-authored-by: Zhuohan Li --- README.md | 6 +- cacheflow/master/simple_frontend.py | 1 + cacheflow/models/llama.py | 357 ++++++++++++++++++++++++++++ cacheflow/models/memory_analyzer.py | 156 ++++++++++-- cacheflow/models/model_utils.py | 6 +- cacheflow/models/opt.py | 7 +- cacheflow/models/sample.py | 2 +- 7 files changed, 500 insertions(+), 35 deletions(-) create mode 100644 cacheflow/models/llama.py diff --git a/README.md b/README.md index 41fd7107a116e..69e9e1afdf26e 100644 --- a/README.md +++ b/README.md @@ -3,8 +3,10 @@ ## Installation ```bash -pip install psutil numpy torch transformers -pip install flash-attn # This may take up to 10 mins. +pip install psutil numpy ray torch +pip install git+https://github.com/huggingface/transformers # Required for LLaMA. +pip install sentencepiece # Required for LlamaTokenizer. +pip install flash-attn # This may take up to 20 mins. pip install -e . ``` diff --git a/cacheflow/master/simple_frontend.py b/cacheflow/master/simple_frontend.py index 3e3fa5987b252..9a86226b90e8e 100644 --- a/cacheflow/master/simple_frontend.py +++ b/cacheflow/master/simple_frontend.py @@ -61,4 +61,5 @@ def print_response( for seq in seq_group.seqs: token_ids = seq.get_token_ids() output = self.tokenizer.decode(token_ids, skip_special_tokens=True) + output = output.strip() print(f'Seq {seq.seq_id}: {output!r}') diff --git a/cacheflow/models/llama.py b/cacheflow/models/llama.py new file mode 100644 index 0000000000000..4ddbc698eb789 --- /dev/null +++ b/cacheflow/models/llama.py @@ -0,0 +1,357 @@ +"""1D LLaMA model compatible with HuggingFace weights.""" +import os +import glob +import filelock +from tqdm import tqdm +from typing import Dict, List, Optional, Tuple + +import numpy as np +import torch +from torch import nn +import torch.nn.functional as F +from transformers import LlamaConfig +from transformers import PreTrainedModel + +from cacheflow.models import InputMetadata +from cacheflow.models.attention import OPTCacheFlowAttention +from cacheflow.models.sample import Sampler +from cacheflow.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from cacheflow.parallel_utils.tensor_parallel import (VocabParallelEmbedding, + ColumnParallelLinear, + RowParallelLinear) +from cacheflow.sequence import SequenceOutputs + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class LlamaRMSNorm(nn.Module): + + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + return self.weight * hidden_states + + +class LlamaRotaryEmbedding(torch.nn.Module): + + def __init__(self, dim, max_position_embeddings=2048, base=10000): + super().__init__() + self.max_position_embeddings = max_position_embeddings + + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2) / dim)) + self.register_buffer("inv_freq", inv_freq) + + # Create cos and sin embeddings. 
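+        # The tables below have shape [max_position_embeddings, dim] and are
+        # computed once here; forward() only performs an embedding lookup on
+        # them, so no trigonometry happens at inference time.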
+ t = torch.arange(max_position_embeddings).float() + freqs = torch.einsum("i,j->ij", t, self.inv_freq.float()) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos().to(dtype=self.inv_freq.dtype) + sin = emb.sin().to(dtype=self.inv_freq.dtype) + self.register_buffer("cos_cached", cos, persistent=False) + self.register_buffer("sin_cached", sin, persistent=False) + + def forward( + self, + positions: torch.LongTensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + cos = F.embedding(positions, self.cos_cached) + sin = F.embedding(positions, self.sin_cached) + return cos, sin + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin): + # TODO: Optimize. + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class LlamaMLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + ): + super().__init__() + # TODO: Merge the gate and down linear layers. + self.gate_proj = ColumnParallelLinear(hidden_size, intermediate_size, + bias=False, gather_output=False, + perform_initialization=False) + self.down_proj = RowParallelLinear(intermediate_size, hidden_size, + bias=False, input_is_parallel=True, + perform_initialization=False) + self.up_proj = ColumnParallelLinear(hidden_size, intermediate_size, + bias=False, gather_output=False, + perform_initialization=False) + assert hidden_act == 'silu' + self.act_fn = nn.SiLU() + + def forward(self, x): + gate, _ = self.gate_proj(x) + up, _ = self.up_proj(x) + x = self.act_fn(gate) * up + x, _ = self.down_proj(x) + return x + + +class LlamaAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + ): + super().__init__() + self.hidden_size = hidden_size + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tensor_model_parallel_world_size == 0 + self.num_heads = self.total_num_heads // tensor_model_parallel_world_size + self.head_dim = hidden_size // self.total_num_heads + self.scaling = self.head_dim ** -0.5 + + # TODO: Merge the QKV linear layers. + self.q_proj = ColumnParallelLinear( + hidden_size, + self.total_num_heads * self.head_dim, + bias=False, + gather_output=False, + perform_initialization=False, + ) + self.k_proj = ColumnParallelLinear( + hidden_size, + self.total_num_heads * self.head_dim, + bias=False, + gather_output=False, + perform_initialization=False, + ) + self.v_proj = ColumnParallelLinear( + hidden_size, + self.total_num_heads * self.head_dim, + bias=False, + gather_output=False, + perform_initialization=False, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + input_is_parallel=True, + perform_initialization=False, + ) + self.rotary_emb = LlamaRotaryEmbedding(self.head_dim) + # FIXME(woosuk): Rename this. + self.attn = OPTCacheFlowAttention(scale=self.scaling) + + def forward( + self, + positions: torch.LongTensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + cache_event: Optional[torch.cuda.Event], + ) -> torch.Tensor: + q, _ = self.q_proj(hidden_states) + k, _ = self.k_proj(hidden_states) + v, _ = self.v_proj(hidden_states) + + # Apply rotrary embedding. + # TODO: Optimize. 
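+        # q and k come out of the column-parallel projections as
+        # [num_tokens, num_heads * head_dim]. Reshape them to
+        # [num_heads, num_tokens, head_dim] so the [num_tokens, head_dim]
+        # cos/sin tables broadcast across heads, then flatten back before
+        # attention and the KV-cache writes.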
+ q = q.view(-1, self.num_heads, self.head_dim).transpose(0, 1) + k = k.view(-1, self.num_heads, self.head_dim).transpose(0, 1) + cos, sin = self.rotary_emb(positions) + q, k = apply_rotary_pos_emb(q, k, cos, sin) + q = q.transpose(0, 1).contiguous().view(-1, self.num_heads * self.head_dim) + k = k.transpose(0, 1).contiguous().view(-1, self.num_heads * self.head_dim) + + key_cache, value_cache = kv_cache + attn_output = self.attn( + q, k, v, key_cache, value_cache, input_metadata, cache_event) + output, _ = self.o_proj(attn_output) + return output + + +class LlamaDecoderLayer(nn.Module): + + def __init__(self, config: LlamaConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = LlamaAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + ) + self.mlp = LlamaMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + ) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.LongTensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + cache_event: Optional[torch.cuda.Event], + ) -> torch.Tensor: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + cache_event=cache_event, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class LlamaModel(nn.Module): + + def __init__(self, config: LlamaConfig): + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size, + perform_initialization=False) + self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.LongTensor, + positions: torch.LongTensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + cache_events: Optional[List[torch.cuda.Event]], + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + for i in range(len(self.layers)): + if cache_events is None: + cache_event = None + else: + cache_event = cache_events[i] + layer = self.layers[i] + hidden_states = layer( + positions, + hidden_states, + kv_caches[i], + input_metadata, + cache_event, + ) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class LlamaForCausalLM(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.model = LlamaModel(config) + self.lm_head = ColumnParallelLinear(config.hidden_size, + config.vocab_size, + bias=False, + gather_output=False, + perform_initialization=False) + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.LongTensor, + positions: torch.LongTensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + cache_events: Optional[List[torch.cuda.Event]], + ) -> Dict[int, SequenceOutputs]: + hidden_states = self.model( + input_ids, 
+            positions, kv_caches, input_metadata, cache_events)
+        next_tokens = self.sampler(
+            self.lm_head.weight, hidden_states, input_metadata)
+        return next_tokens
+
+    _column_parallel_weights = ["embed_tokens.weight", "lm_head.weight",
+                                "q_proj.weight", "k_proj.weight",
+                                "v_proj.weight", "gate_proj.weight",
+                                "up_proj.weight"]
+    _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
+
+    def load_weights(self, weights_path: str):
+        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
+        state_dict = self.state_dict()
+        for name, param in state_dict.items():
+            loaded_weight = torch.from_numpy(
+                np.load(os.path.join(weights_path, name)))
+            for p in self._column_parallel_weights:
+                if p in name:
+                    shard_size = param.shape[0]
+                    loaded_weight = loaded_weight[
+                        shard_size * tensor_model_parallel_rank
+                        :shard_size * (tensor_model_parallel_rank + 1)]
+                    break
+            for p in self._row_parallel_weights:
+                if p in name:
+                    shard_size = param.shape[1]
+                    loaded_weight = loaded_weight[
+                        :,
+                        shard_size * tensor_model_parallel_rank
+                        :shard_size * (tensor_model_parallel_rank + 1)]
+                    break
+
+            assert param.shape == loaded_weight.shape
+            param.data.copy_(loaded_weight)
+
+    @staticmethod
+    def get_weights(model_name: str, path: str):
+        if not os.path.isfile(os.path.join(model_name, "config.json")):
+            raise ValueError("LLaMA model's model_name has to be a path "
+                             "to the huggingface model's directory.")
+        path = os.path.join(model_name, "np")
+        path = os.path.abspath(os.path.expanduser(path))
+        os.makedirs(path, exist_ok=True)
+        lock_path = os.path.join(path, "file_lock")
+        lock = filelock.FileLock(lock_path)
+
+        with lock:
+            test_weight_path = os.path.join(path, "model.embed_tokens.weight")
+            if os.path.exists(test_weight_path):
+                return path
+
+            bin_files = glob.glob(os.path.join(model_name, "*.bin"))
+
+            for bin_file in tqdm(bin_files, desc="Convert format"):
+                state = torch.load(bin_file, map_location="cpu")
+                for name, param in tqdm(state.items(), leave=False):
+                    param_path = os.path.join(path, name)
+                    with open(param_path, "wb") as f:
+                        np.save(f, param.cpu().detach().numpy())
+
+            return path
diff --git a/cacheflow/models/memory_analyzer.py b/cacheflow/models/memory_analyzer.py
index 45d6a36b90222..d3dc8f44bbf98 100644
--- a/cacheflow/models/memory_analyzer.py
+++ b/cacheflow/models/memory_analyzer.py
@@ -15,11 +15,30 @@ def get_max_num_gpu_blocks(
     ) -> int:
         raise NotImplementedError()
 
+    def get_workspace_size(self) -> int:
+        return 1 * _GiB
+
+    def get_cache_block_size(self) -> int:
+        raise NotImplementedError()
+
     def get_max_num_cpu_blocks(
         self,
-        memory_utilization: float,
+        swap_space: int,
     ) -> int:
-        raise NotImplementedError()
+        swap_space = swap_space * _GiB
+        cpu_memory = self.cpu_memory
+        if swap_space > 0.8 * cpu_memory:
+            raise ValueError(f'The swap space ({swap_space / _GiB:.2f} GiB) '
+                             'takes more than 80% of the available memory '
+                             f'({cpu_memory / _GiB:.2f} GiB). '
+                             'Please check the swap space size.')
+        if swap_space > 0.5 * cpu_memory:
+            print(f'WARNING: The swap space ({swap_space / _GiB:.2f} GiB) '
+                  'takes more than 50% of the available memory '
+                  f'({cpu_memory / _GiB:.2f} GiB). '
+ 'This may slow the system performance.') + max_num_blocks = swap_space // self.get_cache_block_size() + return max_num_blocks class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer): @@ -52,9 +71,9 @@ def __init__( def _get_param_size(self) -> int: word_embedding = self.vocab_size * self.embedding_size // self.tensor_parallel_size - if self.embedding_size != self.vocab_size: + if self.embedding_size != self.hidden_size: # Project in/out. - word_embedding += 2 * self.embedding_size * self.vocab_size + word_embedding += 2 * self.embedding_size * self.hidden_size position_embedding = self.max_position * self.hidden_size ln1 = 2 * self.hidden_size @@ -89,15 +108,15 @@ def _get_max_act_size( ffn = max_num_batched_tokens * self.ffn_size // self.tensor_parallel_size # Double the activation size for input and output. max_act = 2 * (max(qkv, ffn) + residual) + # Size of output logits. + output_logits = 2 * (max_num_batched_tokens * self.vocab_size) + max_act = max(max_act, output_logits) dtype_size = get_dtype_size(self.dtype) return dtype_size * max_act - def _get_workspace_size(self) -> int: - return 1 * _GiB - - def _get_cache_block_size(self) -> int: - key_cache_block = self.block_size * self.num_heads * self.head_size - value_cache_block = self.block_size * self.num_heads * self.head_size + def get_cache_block_size(self) -> int: + key_cache_block = self.block_size * self.hidden_size // self.tensor_parallel_size + value_cache_block = key_cache_block total = self.num_layers * (key_cache_block + value_cache_block) dtype_size = get_dtype_size(self.dtype) return dtype_size * total @@ -112,26 +131,111 @@ def get_max_num_gpu_blocks( param_size = self._get_param_size() act_size = self._get_max_act_size(max_num_batched_tokens) - workspace_size = self._get_workspace_size() + workspace_size = self.get_workspace_size() max_cache_size = usable_memory - (param_size + act_size + workspace_size) - max_num_blocks = max_cache_size // self._get_cache_block_size() + if max_cache_size <= 0: + raise RuntimeError('Not enough GPU memory.') + max_num_blocks = max_cache_size // self.get_cache_block_size() return max_num_blocks - def get_max_num_cpu_blocks( + +class LlamaMemoryAnalyzer(CacheFlowMemoryAnalyzer): + + def __init__( self, - swap_space: int, + model_name: str, + block_size: int, + dtype: torch.dtype, + gpu_memory: int, + cpu_memory: int, + tensor_parallel_size: int, + ) -> None: + self.model_name = model_name + self.block_size = block_size + self.dtype = dtype + self.gpu_memory = gpu_memory + self.cpu_memory = cpu_memory + self.tensor_parallel_size = tensor_parallel_size + + config = AutoConfig.from_pretrained(model_name) + self.num_layers = config.num_hidden_layers + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_size = config.hidden_size // self.num_heads + self.ffn_size = config.intermediate_size + self.vocab_size = config.vocab_size + # FIXME + self.max_position = 2048 + + def _get_param_size(self) -> int: + word_embedding = self.vocab_size * self.hidden_size // self.tensor_parallel_size + position_embedding = self.max_position * self.hidden_size + + # NOTE: LLaMA does not have bias terms. + ln1 = self.hidden_size + q = self.hidden_size * self.hidden_size // self.tensor_parallel_size + k = self.hidden_size * self.hidden_size // self.tensor_parallel_size + v = self.hidden_size * self.hidden_size // self.tensor_parallel_size + out = self.hidden_size * self.hidden_size // self.tensor_parallel_size + # Rotary embedding. 
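+        # The cos/sin tables of the rotary embedding are buffers rather than
+        # trained weights, but each layer keeps its own copy on the GPU, so
+        # they are counted towards the parameter footprint here.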
+        # TODO(woosuk): Share the rotary embedding between layers.
+        rot = self.max_position * self.head_size
+        mha = ln1 + q + k + v + out + rot
+
+        ln2 = self.hidden_size
+        gate = self.hidden_size * self.ffn_size // self.tensor_parallel_size
+        down = self.ffn_size * self.hidden_size // self.tensor_parallel_size
+        up = self.hidden_size * self.ffn_size // self.tensor_parallel_size
+        ffn = ln2 + gate + down + up
+
+        total = (word_embedding + position_embedding + self.num_layers * (mha + ffn))
+        dtype_size = get_dtype_size(self.dtype)
+        return dtype_size * total
+
+    def _get_max_act_size(
+        self,
+        max_num_batched_tokens: int,
     ) -> int:
-        swap_space = swap_space * _GiB
-        if swap_space > 0.8 * self.cpu_memory:
-            raise ValueError(f'The swap space ({swap_space / _GiB:.2f} GiB) '
-                             'takes more than 80% of the available memory '
-                             f'({self.cpu_memory / _GiB:.2f} GiB).'
-                             'Please check the swap space size.')
-        if swap_space > 0.5 * self.cpu_memory:
-            print(f'WARNING: The swap space ({swap_space / _GiB:.2f} GiB) '
-                  'takes more than 50% of the available memory '
-                  f'({self.cpu_memory / _GiB:.2f} GiB).'
-                  'This may slow the system performance.')
-        max_num_blocks = swap_space // self._get_cache_block_size()
+        # NOTE: We approximately calculate the maximum activation size by
+        # estimating
+        # 1) the maximum activation tensor size during inference, and
+        # 2) the residual tensor size during inference.
+        # Here, we assume that FlashAttention is used and
+        # thus the attention maps are never materialized in GPU DRAM.
+        residual = max_num_batched_tokens * self.hidden_size
+        qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
+        ffn = 2 * (max_num_batched_tokens * self.ffn_size) // self.tensor_parallel_size
+        # Double the activation size for input and output.
+        max_act = 2 * (max(qkv, ffn) + residual)
+        # Size of output logits.
+        output_logits = 2 * (max_num_batched_tokens * self.vocab_size)
+        max_act = max(max_act, output_logits)
+        dtype_size = get_dtype_size(self.dtype)
+        return dtype_size * max_act
+
+    def get_cache_block_size(self) -> int:
+        key_cache_block = self.block_size * self.hidden_size // self.tensor_parallel_size
+        value_cache_block = key_cache_block
+        total = self.num_layers * (key_cache_block + value_cache_block)
+        dtype_size = get_dtype_size(self.dtype)
+        return dtype_size * total
+
+    def get_max_num_gpu_blocks(
+        self,
+        max_num_batched_tokens: int,
+        memory_utilization: float = 0.95,
+    ) -> int:
+        # NOTE(woosuk): This assumes that the machine has homogeneous GPUs.
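+        # The KV cache gets whatever memory is left after the parameters,
+        # peak activations, and workspace are subtracted from the usable
+        # fraction of GPU memory; that remainder is divided into cache blocks.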
+ gpu_memory = self.gpu_memory + usable_memory = int(memory_utilization * gpu_memory) + + param_size = self._get_param_size() + act_size = self._get_max_act_size(max_num_batched_tokens) + workspace_size = self.get_workspace_size() + + max_cache_size = usable_memory - (param_size + act_size + workspace_size) + if max_cache_size <= 0: + raise RuntimeError('Not enough GPU memory.') + max_num_blocks = max_cache_size // self.get_cache_block_size() return max_num_blocks diff --git a/cacheflow/models/model_utils.py b/cacheflow/models/model_utils.py index 3a2a6b2b5a35a..aaf81bc2b5130 100644 --- a/cacheflow/models/model_utils.py +++ b/cacheflow/models/model_utils.py @@ -6,16 +6,20 @@ from transformers import AutoConfig from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer +from cacheflow.models.memory_analyzer import LlamaMemoryAnalyzer from cacheflow.models.memory_analyzer import OPTMemoryAnalyzer +from cacheflow.models.llama import LlamaForCausalLM from cacheflow.models.opt import OPTForCausalLM from cacheflow.models.utils import get_torch_dtype _MODELS = { + 'llama': LlamaForCausalLM, 'opt': OPTForCausalLM, } _MEMORY_ANALYZERS = { + 'llama': LlamaMemoryAnalyzer, 'opt': OPTMemoryAnalyzer, } @@ -31,7 +35,7 @@ def get_model( for model_class_name, model_class in _MODELS.items(): if model_class_name in model_name: # Download model weights if it's not cached. - weights_dir = model_class.download_weights(model_name, path=path) + weights_dir = model_class.get_weights(model_name, path=path) # Create a model instance. model = model_class(config) # Load the weights from the cached or downloaded files. diff --git a/cacheflow/models/opt.py b/cacheflow/models/opt.py index e9d8e853cc082..3a7e6a1103855 100644 --- a/cacheflow/models/opt.py +++ b/cacheflow/models/opt.py @@ -299,7 +299,7 @@ def load_weights(self, weights_path: str): param.data.copy_(loaded_weight) @staticmethod - def download_weights(model_name: str, path: str): + def get_weights(model_name: str, path: str): path = os.path.join(path, f"{model_name}-np") path = os.path.abspath(os.path.expanduser(path)) os.makedirs(path, exist_ok=True) @@ -316,11 +316,8 @@ def download_weights(model_name: str, path: str): cache_dir=os.path.join(path, "cache")) bin_files = glob.glob(os.path.join(folder, "*.bin")) - if "/" in model_name: - model_name = model_name.split("/")[1].lower() - for bin_file in tqdm(bin_files, desc="Convert format"): - state = torch.load(bin_file) + state = torch.load(bin_file, map_location="cpu") for name, param in tqdm(state.items(), leave=False): if name.startswith("decoder."): name = "model." + name diff --git a/cacheflow/models/sample.py b/cacheflow/models/sample.py index 8cbe73365a3f5..6c7dcbedd3b26 100644 --- a/cacheflow/models/sample.py +++ b/cacheflow/models/sample.py @@ -39,7 +39,7 @@ def forward( # Compute the probabilities. probs = torch.softmax(logits, dim=-1, dtype=torch.float) # Compute the log probabilities (before applying top-p). - logprobs = torch.log(probs) + logprobs = torch.log(probs, out=logits) # Apply top-p truncation. top_ps = _get_top_ps(input_metadata)