From 1c80c4a4746f21a523441db27a0da13c92a8febf Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 25 Mar 2023 23:35:32 +0000 Subject: [PATCH 01/14] Add sentencepiece and Ray to dependency --- README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 7008d8a204cbf..dd4cb27df4f09 100644 --- a/README.md +++ b/README.md @@ -3,8 +3,9 @@ ## Installation ```bash -pip install psutil numpy torch transformers -pip install flash-attn # This may take up to 10 mins. +pip install psutil numpy ray torch transformers +pip install sentencepiece # Required for LlamaTokenizer. +pip install flash-attn # This may take up to 20 mins. pip install -e . ``` From 6ed70571f1560267865a516bd51f98c97516c258 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 26 Mar 2023 05:34:18 +0000 Subject: [PATCH 02/14] Implement LLaMA --- cacheflow/models/llama.py | 307 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 307 insertions(+) create mode 100644 cacheflow/models/llama.py diff --git a/cacheflow/models/llama.py b/cacheflow/models/llama.py new file mode 100644 index 0000000000000..7f6ea2b018ac9 --- /dev/null +++ b/cacheflow/models/llama.py @@ -0,0 +1,307 @@ +"""1D LLaMA model compatible with HuggingFace weights.""" +from typing import Dict, List, Optional, Tuple + +import torch +from torch import nn +import torch.nn.functional as F +from transformers import LlamaConfig +from transformers import PreTrainedModel + +from cacheflow.models import InputMetadata +from cacheflow.models.attention import OPTCacheFlowAttention +from cacheflow.models.sample import Sampler +from cacheflow.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from cacheflow.parallel_utils.tensor_parallel import (VocabParallelEmbedding, + ColumnParallelLinear, + RowParallelLinear) +from cacheflow.sequence import SequenceOutputs + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class LlamaRMSNorm(nn.Module): + + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + return self.weight * hidden_states + + +class LlamaRotaryEmbedding(torch.nn.Module): + + def __init__(self, dim, max_position_embeddings=2048, base=10000): + super().__init__() + self.max_position_embeddings = max_position_embeddings + + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2) / dim)) + self.register_buffer("inv_freq", inv_freq) + + # Create cos and sin embeddings. + t = torch.arange(max_position_embeddings).float() + freqs = torch.einsum("i,j->ij", t, self.inv_freq.float()) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos().to(dtype=self.inv_freq.dtype) + sin = emb.sin().to(dtype=self.inv_freq.dtype) + self.register_buffer("cos_cached", cos, persistent=False) + self.register_buffer("sin_cached", sin, persistent=False) + + def forward(self, positions: torch.LongTensor): + cos = F.embedding(positions, self.cos_cached) + sin = F.embedding(positions, self.sin_cached) + return cos, sin + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin): + # TODO: Optimize. + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class LlamaMLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + ): + super().__init__() + # TODO: Merge the gate and down linear layers. + self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) + self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + assert hidden_act == 'silu' + self.act_fn = nn.SiLU() + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class LlamaAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + ): + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + self.scaling = self.head_dim ** -0.5 + + # TODO: Merge the QKV linear layers. + self.q_proj = nn.Linear( + hidden_size, + num_heads * self.head_dim, + bias=False, + ) + self.k_proj = nn.Linear( + hidden_size, + num_heads * self.head_dim, + bias=False, + ) + self.v_proj = nn.Linear( + hidden_size, + num_heads * self.head_dim, + bias=False, + ) + self.o_proj = nn.Linear( + num_heads * self.head_dim, + hidden_size, + bias=False, + ) + self.rotary_emb = LlamaRotaryEmbedding(self.head_dim) + # FIXME(woosuk): Rename this. + self.attn = OPTCacheFlowAttention(scale=self.scaling) + + def forward( + self, + positions: torch.LongTensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + cache_event: Optional[torch.cuda.Event], + ) -> torch.Tensor: + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + # Apply rotrary embedding. + # TODO: Optimize. + cos, sin = self.rotary_emb(positions) + q_t = q.view(-1, self.num_heads, self.head_dim).transpose(0, 1) + k_t = k.view(-1, self.num_heads, self.head_dim).transpose(0, 1) + q_t, k_t = apply_rotary_pos_emb(q_t, k_t, cos, sin) + q = q_t.transpose(0, 1).contiguous().view(-1, self.hidden_size) + k = k_t.transpose(0, 1).contiguous().view(-1, self.hidden_size) + + key_cache, value_cache = kv_cache + attn_output = self.attn( + q, k, v, key_cache, value_cache, input_metadata, cache_event) + output = self.o_proj(attn_output) + return output + + +class LlamaDecoderLayer(nn.Module): + + def __init__(self, config: LlamaConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = LlamaAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + ) + self.mlp = LlamaMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + ) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.LongTensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + cache_event: Optional[torch.cuda.Event], + ) -> torch.Tensor: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + cache_event=cache_event, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class LlamaPreTrainedModel(PreTrainedModel): + config_class = LlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaDecoderLayer"] + _keys_to_ignore_on_load_unexpected = [r"decoder\.version"] + + def _init_weights(self, module): + pass + + +class LlamaModel(LlamaPreTrainedModel): + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor, + positions: torch.LongTensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + cache_events: Optional[List[torch.cuda.Event]], + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + for i in range(len(self.layers)): + if cache_events is None: + cache_event = None + else: + cache_event = cache_events[i] + layer = self.layers[i] + hidden_states = layer( + positions, + hidden_states, + kv_caches[i], + input_metadata, + cache_event, + ) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class LlamaForCausalLM(LlamaPreTrainedModel): + _keys_to_ignore_on_load_missing = [r"lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = LlamaModel(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.sampler = Sampler() + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def forward( + self, + input_ids: torch.LongTensor, + positions: torch.LongTensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + cache_events: Optional[List[torch.cuda.Event]], + ) -> Dict[int, SequenceOutputs]: + hidden_states = self.model( + input_ids, positions, kv_caches, input_metadata, cache_events) + next_tokens = self.sampler( + self.lm_head.weight, hidden_states, input_metadata) + return next_tokens From bd8664cd76f7f75d3fa4b0530980f201b26834e3 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 26 Mar 2023 05:35:56 +0000 Subject: [PATCH 03/14] Implement LLaMA memory analyzer --- cacheflow/models/memory_analyzer.py | 137 +++++++++++++++++++++++----- 1 file changed, 113 insertions(+), 24 deletions(-) diff --git a/cacheflow/models/memory_analyzer.py b/cacheflow/models/memory_analyzer.py index 69675588c3c43..0fab357d8976f 100644 --- a/cacheflow/models/memory_analyzer.py +++ b/cacheflow/models/memory_analyzer.py @@ -17,11 +17,30 @@ def get_max_num_gpu_blocks( ) -> int: raise NotImplementedError() + def get_workspace_size(self) -> int: + return 1 * _GiB + + def get_cache_block_size(self) -> int: + raise NotImplementedError() + def get_max_num_cpu_blocks( self, - memory_utilization: float, + swap_space: int, ) -> int: - raise NotImplementedError() + swap_space = swap_space * _GiB + cpu_memory = get_cpu_memory() + if swap_space > 0.8 * cpu_memory: + raise ValueError(f'The swap space ({swap_space / _GiB:.2f} GiB) ' + 'takes more than 80% of the available memory ' + f'({cpu_memory / _GiB:.2f} GiB).' + 'Please check the swap space size.') + if swap_space > 0.5 * cpu_memory: + print(f'WARNING: The swap space ({swap_space / _GiB:.2f} GiB) ' + 'takes more than 50% of the available memory ' + f'({cpu_memory / _GiB:.2f} GiB).' + 'This may slow the system performance.') + max_num_blocks = swap_space // self.get_cache_block_size() + return max_num_blocks class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer): @@ -50,9 +69,9 @@ def __init__( def _get_param_size(self) -> int: word_embedding = self.vocab_size * self.embedding_size // self.tensor_parallel_size - if self.embedding_size != self.vocab_size: + if self.embedding_size != self.hidden_size: # Project in/out. - word_embedding += 2 * self.embedding_size * self.vocab_size + word_embedding += 2 * self.embedding_size * self.hidden_size position_embedding = self.max_position * self.hidden_size ln1 = 2 * self.hidden_size @@ -90,10 +109,7 @@ def _get_max_act_size( dtype_size = get_dtype_size(self.dtype) return dtype_size * max_act - def _get_workspace_size(self) -> int: - return 1 * _GiB - - def _get_cache_block_size(self) -> int: + def get_cache_block_size(self) -> int: key_cache_block = self.block_size * self.num_heads * self.head_size value_cache_block = self.block_size * self.num_heads * self.head_size total = self.num_layers * (key_cache_block + value_cache_block) @@ -111,27 +127,100 @@ def get_max_num_gpu_blocks( param_size = self._get_param_size() act_size = self._get_max_act_size(max_num_batched_tokens) - workspace_size = self._get_workspace_size() + workspace_size = self.get_workspace_size() max_cache_size = usable_memory - (param_size + act_size + workspace_size) max_num_blocks = max_cache_size // self._get_cache_block_size() return max_num_blocks - def get_max_num_cpu_blocks( + +class LlamaMemoryAnalyzer(CacheFlowMemoryAnalyzer): + + def __init__( self, - swap_space: int, + model_name: str, + block_size: int, + dtype: torch.dtype, + tensor_parallel_size: int, + ) -> None: + self.model_name = model_name + self.block_size = block_size + self.dtype = dtype + self.tensor_parallel_size = tensor_parallel_size + + config = AutoConfig.from_pretrained(model_name) + self.num_layers = config.num_hidden_layers + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_size = config.hidden_size // self.num_heads + self.ffn_size = config.intermediate_size + self.vocab_size = config.vocab_size + # FIXME + self.max_position = 2048 + + def _get_param_size(self) -> int: + word_embedding = self.vocab_size * self.hidden_size // self.tensor_parallel_size + position_embedding = self.max_position * self.hidden_size + + # NOTE: LLaMA does not have bias terms. + ln1 = self.hidden_size + q = self.hidden_size * self.hidden_size // self.tensor_parallel_size + k = self.hidden_size * self.hidden_size // self.tensor_parallel_size + v = self.hidden_size * self.hidden_size // self.tensor_parallel_size + out = self.hidden_size * self.hidden_size // self.tensor_parallel_size + # Rotary embedding. + # TODO(woosuk): Share the rotary embedding between layers. + rot = self.max_position * self.head_size + mha = ln1 + q + k + v + out + rot + + ln2 = self.hidden_size + gate = self.hidden_size * self.ffn_size // self.tensor_parallel_size + down = self.ffn_size * self.hidden_size // self.tensor_parallel_size + up = self.hidden_size * self.ffn_size // self.tensor_parallel_size + ffn = ln2 + gate + down + up + + total = (word_embedding + position_embedding + self.num_layers * (mha + ffn)) + dtype_size = get_dtype_size(self.dtype) + return dtype_size * total + + def _get_max_act_size( + self, + max_num_batched_tokens: int, ) -> int: - swap_space = swap_space * _GiB - cpu_memory = get_cpu_memory() - if swap_space > 0.8 * cpu_memory: - raise ValueError(f'The swap space ({swap_space / _GiB:.2f} GiB) ' - 'takes more than 80% of the available memory ' - f'({cpu_memory / _GiB:.2f} GiB).' - 'Please check the swap space size.') - if swap_space > 0.5 * cpu_memory: - print(f'WARNING: The swap space ({swap_space / _GiB:.2f} GiB) ' - 'takes more than 50% of the available memory ' - f'({cpu_memory / _GiB:.2f} GiB).' - 'This may slow the system performance.') - max_num_blocks = swap_space // self._get_cache_block_size() + # NOTE: We approxmiately calculate the maximum activation size by + # estimating + # 1) the maximum activation tensor size during inference + # 2) the residual tensor size during inference + # Here, we assume that FlashAttention is used and + # thus the attention maps are never materialized in GPU DRAM. + residual = max_num_batched_tokens * self.hidden_size + qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size + ffn = 2 * (max_num_batched_tokens * self.ffn_size) // self.tensor_parallel_size + # Double the activation size for input and output. + max_act = 2 * (max(qkv, ffn) + residual) + dtype_size = get_dtype_size(self.dtype) + return dtype_size * max_act + + def get_cache_block_size(self) -> int: + key_cache_block = self.block_size * self.num_heads * self.head_size + value_cache_block = self.block_size * self.num_heads * self.head_size + total = self.num_layers * (key_cache_block + value_cache_block) + dtype_size = get_dtype_size(self.dtype) + return dtype_size * total + + def get_max_num_gpu_blocks( + self, + max_num_batched_tokens: int, + memory_utilization: float = 0.95, + ) -> int: + # NOTE(woosuk): This assumes that the machine has homogeneous GPUs. + gpu_memory = get_gpu_memory() + usable_memory = int(memory_utilization * gpu_memory) + + param_size = self._get_param_size() + act_size = self._get_max_act_size(max_num_batched_tokens) + workspace_size = self.get_workspace_size() + + max_cache_size = usable_memory - (param_size + act_size + workspace_size) + max_num_blocks = max_cache_size // self.get_cache_block_size() return max_num_blocks From 4ea2216b2e7a32b18cbbcfd5814b55e411eb411e Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 26 Mar 2023 05:36:17 +0000 Subject: [PATCH 04/14] [WIP] Add LLaMA model loader --- cacheflow/models/model_utils.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/cacheflow/models/model_utils.py b/cacheflow/models/model_utils.py index b1fdacea075ae..369b57acf6315 100644 --- a/cacheflow/models/model_utils.py +++ b/cacheflow/models/model_utils.py @@ -6,16 +6,20 @@ from transformers import AutoConfig from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer +from cacheflow.models.memory_analyzer import LlamaMemoryAnalyzer from cacheflow.models.memory_analyzer import OPTMemoryAnalyzer +from cacheflow.models.llama import LlamaForCausalLM from cacheflow.models.opt import OPTForCausalLM from cacheflow.models.utils import get_torch_dtype _MODELS = { + 'llama': LlamaForCausalLM, 'opt': OPTForCausalLM, } _MEMORY_ANALYZERS = { + 'llama': LlamaMemoryAnalyzer, 'opt': OPTMemoryAnalyzer, } @@ -28,6 +32,11 @@ def get_model( torch_dtype = get_torch_dtype(dtype) torch.set_default_dtype(torch_dtype) config = AutoConfig.from_pretrained(model_name) + # FIXME + if 'llama' in model_name: + model = LlamaForCausalLM.from_pretrained(model_name) + return model.eval(), torch_dtype + for model_class_name, model_class in _MODELS.items(): if model_class_name in model_name: # Download model weights if it's not cached. From 742ac4a087660ea1e79bbf5a7760bfa22e53a309 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 26 Mar 2023 05:37:27 +0000 Subject: [PATCH 05/14] Fix README --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index dd4cb27df4f09..9dedb08752f3a 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,8 @@ ## Installation ```bash -pip install psutil numpy ray torch transformers +pip install psutil numpy ray torch +pip install git+https://github.com/huggingface/transformers # Required for LLaMA. pip install sentencepiece # Required for LlamaTokenizer. pip install flash-attn # This may take up to 20 mins. pip install -e . From f3d5e78a51295e4c978f6f27e77147da9cc718b8 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 26 Mar 2023 05:37:53 +0000 Subject: [PATCH 06/14] Minor --- cacheflow/master/frontend.py | 1 + cacheflow/sequence.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/cacheflow/master/frontend.py b/cacheflow/master/frontend.py index cfa17684fd56a..6aaa83b1c97e9 100644 --- a/cacheflow/master/frontend.py +++ b/cacheflow/master/frontend.py @@ -76,4 +76,5 @@ def print_response( for seq in seq_group.seqs: token_ids = seq.get_token_ids() output = self.tokenizer.decode(token_ids, skip_special_tokens=True) + output = output.strip() print(f'Seq {seq.seq_id}: {output!r}') diff --git a/cacheflow/sequence.py b/cacheflow/sequence.py index 471052fbd5a94..8cdd977237f1e 100644 --- a/cacheflow/sequence.py +++ b/cacheflow/sequence.py @@ -30,7 +30,7 @@ def __init__( self.status = SequenceStatus.PENDING self.output_logprobs: List[Dict[int, float]] = [] - self.cumulative_logprobs = 1.0 + self.cumulative_logprobs = 0.0 def add_block(self) -> None: block = LogicalTokenBlock( From d2e08a23b79035eb1c7ff482146939983c014841 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 26 Mar 2023 07:10:22 +0000 Subject: [PATCH 07/14] Minor --- cacheflow/models/llama.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/cacheflow/models/llama.py b/cacheflow/models/llama.py index 7f6ea2b018ac9..4471683b29372 100644 --- a/cacheflow/models/llama.py +++ b/cacheflow/models/llama.py @@ -54,7 +54,10 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000): self.register_buffer("cos_cached", cos, persistent=False) self.register_buffer("sin_cached", sin, persistent=False) - def forward(self, positions: torch.LongTensor): + def forward( + self, + positions: torch.LongTensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: cos = F.embedding(positions, self.cos_cached) sin = F.embedding(positions, self.sin_cached) return cos, sin @@ -145,12 +148,12 @@ def forward( # Apply rotrary embedding. # TODO: Optimize. + q = q.view(-1, self.num_heads, self.head_dim).transpose(0, 1) + k = k.view(-1, self.num_heads, self.head_dim).transpose(0, 1) cos, sin = self.rotary_emb(positions) - q_t = q.view(-1, self.num_heads, self.head_dim).transpose(0, 1) - k_t = k.view(-1, self.num_heads, self.head_dim).transpose(0, 1) - q_t, k_t = apply_rotary_pos_emb(q_t, k_t, cos, sin) - q = q_t.transpose(0, 1).contiguous().view(-1, self.hidden_size) - k = k_t.transpose(0, 1).contiguous().view(-1, self.hidden_size) + q, k = apply_rotary_pos_emb(q, k, cos, sin) + q = q.transpose(0, 1).contiguous().view(-1, self.hidden_size) + k = k.transpose(0, 1).contiguous().view(-1, self.hidden_size) key_cache, value_cache = kv_cache attn_output = self.attn( @@ -226,7 +229,6 @@ def __init__(self, config: LlamaConfig): self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() From 48b6dd1c1a0d7882d6971b05664b90567ef8d8b4 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Tue, 28 Mar 2023 14:58:35 +0000 Subject: [PATCH 08/14] [WIP] Incorrect TP implementation --- cacheflow/models/llama.py | 169 ++++++++++++++++++++------------ cacheflow/models/model_utils.py | 7 +- cacheflow/models/opt.py | 7 +- 3 files changed, 110 insertions(+), 73 deletions(-) diff --git a/cacheflow/models/llama.py b/cacheflow/models/llama.py index 4471683b29372..5db6e371fc26c 100644 --- a/cacheflow/models/llama.py +++ b/cacheflow/models/llama.py @@ -1,6 +1,11 @@ """1D LLaMA model compatible with HuggingFace weights.""" +import os +import glob +import filelock +from tqdm import tqdm from typing import Dict, List, Optional, Tuple +import numpy as np import torch from torch import nn import torch.nn.functional as F @@ -86,14 +91,24 @@ def __init__( ): super().__init__() # TODO: Merge the gate and down linear layers. - self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) - self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) - self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.gate_proj = ColumnParallelLinear(hidden_size, intermediate_size, + bias=False, gather_output=False, + perform_initialization=False) + self.down_proj = RowParallelLinear(intermediate_size, hidden_size, + bias=False, input_is_parallel=True, + perform_initialization=False) + self.up_proj = ColumnParallelLinear(hidden_size, intermediate_size, + bias=False, gather_output=False, + perform_initialization=False) assert hidden_act == 'silu' self.act_fn = nn.SiLU() def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + gate, _ = self.gate_proj(x) + up, _ = self.up_proj(x) + x = self.act_fn(gate) * up + x, _ = self.down_proj(x) + return x class LlamaAttention(nn.Module): @@ -110,25 +125,33 @@ def __init__( self.scaling = self.head_dim ** -0.5 # TODO: Merge the QKV linear layers. - self.q_proj = nn.Linear( + self.q_proj = ColumnParallelLinear( hidden_size, num_heads * self.head_dim, bias=False, + gather_output=False, + perform_initialization=False, ) - self.k_proj = nn.Linear( + self.k_proj = ColumnParallelLinear( hidden_size, num_heads * self.head_dim, bias=False, + gather_output=False, + perform_initialization=False, ) - self.v_proj = nn.Linear( + self.v_proj = ColumnParallelLinear( hidden_size, num_heads * self.head_dim, bias=False, + gather_output=False, + perform_initialization=False, ) - self.o_proj = nn.Linear( + self.o_proj = RowParallelLinear( num_heads * self.head_dim, hidden_size, bias=False, + input_is_parallel=True, + perform_initialization=False, ) self.rotary_emb = LlamaRotaryEmbedding(self.head_dim) # FIXME(woosuk): Rename this. @@ -142,9 +165,9 @@ def forward( input_metadata: InputMetadata, cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: - q = self.q_proj(hidden_states) - k = self.k_proj(hidden_states) - v = self.v_proj(hidden_states) + q, _ = self.q_proj(hidden_states) + k, _ = self.k_proj(hidden_states) + v, _ = self.v_proj(hidden_states) # Apply rotrary embedding. # TODO: Optimize. @@ -158,7 +181,7 @@ def forward( key_cache, value_cache = kv_cache attn_output = self.attn( q, k, v, key_cache, value_cache, input_metadata, cache_event) - output = self.o_proj(attn_output) + output, _ = self.o_proj(attn_output) return output @@ -207,37 +230,19 @@ def forward( return hidden_states -class LlamaPreTrainedModel(PreTrainedModel): - config_class = LlamaConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["LlamaDecoderLayer"] - _keys_to_ignore_on_load_unexpected = [r"decoder\.version"] - - def _init_weights(self, module): - pass - - -class LlamaModel(LlamaPreTrainedModel): +class LlamaModel(nn.Module): def __init__(self, config: LlamaConfig): - super().__init__(config) + super().__init__() + self.config = config self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size, + perform_initialization=False) self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - def forward( self, input_ids: torch.LongTensor, @@ -264,36 +269,16 @@ def forward( return hidden_states -class LlamaForCausalLM(LlamaPreTrainedModel): - _keys_to_ignore_on_load_missing = [r"lm_head.weight"] - +class LlamaForCausalLM(nn.Module): def __init__(self, config): - super().__init__(config) + super().__init__() + self.config = config self.model = LlamaModel(config) - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + # TODO(zhuohan): create a new weight after implementing pipeline + # parallelism + self.lm_head_weight = self.model.embed_tokens.weight self.sampler = Sampler() - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - def forward( self, input_ids: torch.LongTensor, @@ -305,5 +290,65 @@ def forward( hidden_states = self.model( input_ids, positions, kv_caches, input_metadata, cache_events) next_tokens = self.sampler( - self.lm_head.weight, hidden_states, input_metadata) + self.lm_head_weight, hidden_states, input_metadata) return next_tokens + + _column_parallel_weights = ["embed_tokens.weight", + "q_proj.weight", "k_proj.weight", + "v_proj.weight", "gate_proj.weight", + "up_proj.weight"] + _row_parallel_weights = ["o_proj.weight", "down_proj.weight"] + + def load_weights(self, weights_path: str): + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + state_dict = self.state_dict() + for name, param in state_dict.items(): + if "lm_head_weight" in name: + continue + loaded_weight = torch.from_numpy(np.load(os.path.join(weights_path, + name))) + for p in self._column_parallel_weights: + if p in name: + shard_size = param.shape[0] + loaded_weight = loaded_weight[ + shard_size * tensor_model_parallel_rank + :shard_size * (tensor_model_parallel_rank + 1)] + break + for p in self._row_parallel_weights: + if p in name: + shard_size = param.shape[1] + loaded_weight = loaded_weight[ + :, + shard_size * tensor_model_parallel_rank + :shard_size * (tensor_model_parallel_rank + 1)] + break + + assert param.shape == loaded_weight.shape + param.data.copy_(loaded_weight) + + @staticmethod + def get_weights(model_name: str, path: str): + if not os.path.isfile(os.path.join(model_name, "config.json")): + raise ValueError("LLaMA model's model_name has to be a path" + "to the huggingface model's directory.") + path = os.path.join(model_name, f"np") + path = os.path.abspath(os.path.expanduser(path)) + os.makedirs(path, exist_ok=True) + lock_path = os.path.join(path, "file_lock") + lock = filelock.FileLock(lock_path) + + with lock: + test_weight_path = os.path.join(path, "model.embed_tokens.weight") + if os.path.exists(test_weight_path): + return path + + bin_files = glob.glob(os.path.join(model_name, "*.bin")) + + for bin_file in tqdm(bin_files, desc="Convert format"): + state = torch.load(bin_file, map_location="cpu") + for name, param in tqdm(state.items(), leave=False): + param_path = os.path.join(path, name) + with open(param_path, "wb") as f: + np.save(f, param.cpu().detach().numpy()) + + return path diff --git a/cacheflow/models/model_utils.py b/cacheflow/models/model_utils.py index 369b57acf6315..afbef25e5bcb1 100644 --- a/cacheflow/models/model_utils.py +++ b/cacheflow/models/model_utils.py @@ -32,15 +32,10 @@ def get_model( torch_dtype = get_torch_dtype(dtype) torch.set_default_dtype(torch_dtype) config = AutoConfig.from_pretrained(model_name) - # FIXME - if 'llama' in model_name: - model = LlamaForCausalLM.from_pretrained(model_name) - return model.eval(), torch_dtype - for model_class_name, model_class in _MODELS.items(): if model_class_name in model_name: # Download model weights if it's not cached. - weights_dir = model_class.download_weights(model_name, path=path) + weights_dir = model_class.get_weights(model_name, path=path) # Create a model instance. model = model_class(config) # Load the weights from the cached or downloaded files. diff --git a/cacheflow/models/opt.py b/cacheflow/models/opt.py index e9d8e853cc082..3a7e6a1103855 100644 --- a/cacheflow/models/opt.py +++ b/cacheflow/models/opt.py @@ -299,7 +299,7 @@ def load_weights(self, weights_path: str): param.data.copy_(loaded_weight) @staticmethod - def download_weights(model_name: str, path: str): + def get_weights(model_name: str, path: str): path = os.path.join(path, f"{model_name}-np") path = os.path.abspath(os.path.expanduser(path)) os.makedirs(path, exist_ok=True) @@ -316,11 +316,8 @@ def download_weights(model_name: str, path: str): cache_dir=os.path.join(path, "cache")) bin_files = glob.glob(os.path.join(folder, "*.bin")) - if "/" in model_name: - model_name = model_name.split("/")[1].lower() - for bin_file in tqdm(bin_files, desc="Convert format"): - state = torch.load(bin_file) + state = torch.load(bin_file, map_location="cpu") for name, param in tqdm(state.items(), leave=False): if name.startswith("decoder."): name = "model." + name From 914416785f01e0b3320ae61c7cafb890b6126cd1 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Tue, 28 Mar 2023 15:56:22 +0000 Subject: [PATCH 09/14] fix lm_head --- cacheflow/models/llama.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/cacheflow/models/llama.py b/cacheflow/models/llama.py index 5db6e371fc26c..e772adfd0ab1b 100644 --- a/cacheflow/models/llama.py +++ b/cacheflow/models/llama.py @@ -274,9 +274,11 @@ def __init__(self, config): super().__init__() self.config = config self.model = LlamaModel(config) - # TODO(zhuohan): create a new weight after implementing pipeline - # parallelism - self.lm_head_weight = self.model.embed_tokens.weight + self.lm_head = ColumnParallelLinear(config.hidden_size, + config.vocab_size, + bias=False, + gather_output=False, + perform_initialization=False) self.sampler = Sampler() def forward( @@ -290,7 +292,7 @@ def forward( hidden_states = self.model( input_ids, positions, kv_caches, input_metadata, cache_events) next_tokens = self.sampler( - self.lm_head_weight, hidden_states, input_metadata) + self.lm_head.weight, hidden_states, input_metadata) return next_tokens _column_parallel_weights = ["embed_tokens.weight", @@ -303,8 +305,6 @@ def load_weights(self, weights_path: str): tensor_model_parallel_rank = get_tensor_model_parallel_rank() state_dict = self.state_dict() for name, param in state_dict.items(): - if "lm_head_weight" in name: - continue loaded_weight = torch.from_numpy(np.load(os.path.join(weights_path, name))) for p in self._column_parallel_weights: From fcb4f957f56b71092c84e69f2b98f6b5fcb96e2a Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Wed, 29 Mar 2023 07:18:11 +0000 Subject: [PATCH 10/14] fix merge error --- cacheflow/models/memory_analyzer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cacheflow/models/memory_analyzer.py b/cacheflow/models/memory_analyzer.py index 855d5c3a17a7d..c5a2a123e2f5b 100644 --- a/cacheflow/models/memory_analyzer.py +++ b/cacheflow/models/memory_analyzer.py @@ -131,7 +131,7 @@ def get_max_num_gpu_blocks( workspace_size = self.get_workspace_size() max_cache_size = usable_memory - (param_size + act_size + workspace_size) - max_num_blocks = max_cache_size // self._get_cache_block_size() + max_num_blocks = max_cache_size // self.get_cache_block_size() return max_num_blocks From 91ec0c54572d2bc129acbb14e70a1fb1a15d5d39 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Wed, 29 Mar 2023 09:28:27 +0000 Subject: [PATCH 11/14] fix distributed execution --- cacheflow/models/llama.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/cacheflow/models/llama.py b/cacheflow/models/llama.py index e772adfd0ab1b..c515cc97ff43f 100644 --- a/cacheflow/models/llama.py +++ b/cacheflow/models/llama.py @@ -120,34 +120,37 @@ def __init__( ): super().__init__() self.hidden_size = hidden_size - self.num_heads = num_heads - self.head_dim = hidden_size // num_heads + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tensor_model_parallel_world_size == 0 + self.num_heads = self.total_num_heads // tensor_model_parallel_world_size + self.head_dim = hidden_size // self.total_num_heads self.scaling = self.head_dim ** -0.5 # TODO: Merge the QKV linear layers. self.q_proj = ColumnParallelLinear( hidden_size, - num_heads * self.head_dim, + self.total_num_heads * self.head_dim, bias=False, gather_output=False, perform_initialization=False, ) self.k_proj = ColumnParallelLinear( hidden_size, - num_heads * self.head_dim, + self.total_num_heads * self.head_dim, bias=False, gather_output=False, perform_initialization=False, ) self.v_proj = ColumnParallelLinear( hidden_size, - num_heads * self.head_dim, + self.total_num_heads * self.head_dim, bias=False, gather_output=False, perform_initialization=False, ) self.o_proj = RowParallelLinear( - num_heads * self.head_dim, + self.total_num_heads * self.head_dim, hidden_size, bias=False, input_is_parallel=True, @@ -295,7 +298,7 @@ def forward( self.lm_head.weight, hidden_states, input_metadata) return next_tokens - _column_parallel_weights = ["embed_tokens.weight", + _column_parallel_weights = ["embed_tokens.weight", "lm_head.weight", "q_proj.weight", "k_proj.weight", "v_proj.weight", "gate_proj.weight", "up_proj.weight"] From 7e727475feb04f10ce3b96d8b4a091937b936504 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 29 Mar 2023 18:14:43 +0000 Subject: [PATCH 12/14] Fix a bug in memory analyzer when using TP --- cacheflow/models/memory_analyzer.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/cacheflow/models/memory_analyzer.py b/cacheflow/models/memory_analyzer.py index c5a2a123e2f5b..02cd764a3a1bc 100644 --- a/cacheflow/models/memory_analyzer.py +++ b/cacheflow/models/memory_analyzer.py @@ -112,8 +112,8 @@ def _get_max_act_size( return dtype_size * max_act def get_cache_block_size(self) -> int: - key_cache_block = self.block_size * self.num_heads * self.head_size - value_cache_block = self.block_size * self.num_heads * self.head_size + key_cache_block = self.block_size * self.hidden_size // self.tensor_parallel_size + value_cache_block = key_cache_block total = self.num_layers * (key_cache_block + value_cache_block) dtype_size = get_dtype_size(self.dtype) return dtype_size * total @@ -131,6 +131,8 @@ def get_max_num_gpu_blocks( workspace_size = self.get_workspace_size() max_cache_size = usable_memory - (param_size + act_size + workspace_size) + if max_cache_size <= 0: + raise RuntimeError('Not enough GPU memory.') max_num_blocks = max_cache_size // self.get_cache_block_size() return max_num_blocks @@ -207,8 +209,8 @@ def _get_max_act_size( return dtype_size * max_act def get_cache_block_size(self) -> int: - key_cache_block = self.block_size * self.num_heads * self.head_size - value_cache_block = self.block_size * self.num_heads * self.head_size + key_cache_block = self.block_size * self.hidden_size // self.tensor_parallel_size + value_cache_block = key_cache_block total = self.num_layers * (key_cache_block + value_cache_block) dtype_size = get_dtype_size(self.dtype) return dtype_size * total @@ -227,5 +229,7 @@ def get_max_num_gpu_blocks( workspace_size = self.get_workspace_size() max_cache_size = usable_memory - (param_size + act_size + workspace_size) + if max_cache_size <= 0: + raise RuntimeError('Not enough GPU memory.') max_num_blocks = max_cache_size // self.get_cache_block_size() return max_num_blocks From dd385488607ea64c920d961d342e5dfbcfa1dd53 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Thu, 30 Mar 2023 03:59:37 +0000 Subject: [PATCH 13/14] Fix the shape of q and k --- cacheflow/models/llama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cacheflow/models/llama.py b/cacheflow/models/llama.py index c515cc97ff43f..4ddbc698eb789 100644 --- a/cacheflow/models/llama.py +++ b/cacheflow/models/llama.py @@ -178,8 +178,8 @@ def forward( k = k.view(-1, self.num_heads, self.head_dim).transpose(0, 1) cos, sin = self.rotary_emb(positions) q, k = apply_rotary_pos_emb(q, k, cos, sin) - q = q.transpose(0, 1).contiguous().view(-1, self.hidden_size) - k = k.transpose(0, 1).contiguous().view(-1, self.hidden_size) + q = q.transpose(0, 1).contiguous().view(-1, self.num_heads * self.head_dim) + k = k.transpose(0, 1).contiguous().view(-1, self.num_heads * self.head_dim) key_cache, value_cache = kv_cache attn_output = self.attn( From ffd964473c540b85fd13360217f054417bd7cb2c Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Thu, 30 Mar 2023 04:23:49 +0000 Subject: [PATCH 14/14] fix memory analyzer --- cacheflow/models/memory_analyzer.py | 6 ++++++ cacheflow/models/sample.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/cacheflow/models/memory_analyzer.py b/cacheflow/models/memory_analyzer.py index 02cd764a3a1bc..d3dc8f44bbf98 100644 --- a/cacheflow/models/memory_analyzer.py +++ b/cacheflow/models/memory_analyzer.py @@ -108,6 +108,9 @@ def _get_max_act_size( ffn = max_num_batched_tokens * self.ffn_size // self.tensor_parallel_size # Double the activation size for input and output. max_act = 2 * (max(qkv, ffn) + residual) + # Size of output logits. + output_logits = 2 * (max_num_batched_tokens * self.vocab_size) + max_act = max(max_act, output_logits) dtype_size = get_dtype_size(self.dtype) return dtype_size * max_act @@ -205,6 +208,9 @@ def _get_max_act_size( ffn = 2 * (max_num_batched_tokens * self.ffn_size) // self.tensor_parallel_size # Double the activation size for input and output. max_act = 2 * (max(qkv, ffn) + residual) + # Size of output logits. + output_logits = 2 * (max_num_batched_tokens * self.vocab_size) + max_act = max(max_act, output_logits) dtype_size = get_dtype_size(self.dtype) return dtype_size * max_act diff --git a/cacheflow/models/sample.py b/cacheflow/models/sample.py index 8cbe73365a3f5..6c7dcbedd3b26 100644 --- a/cacheflow/models/sample.py +++ b/cacheflow/models/sample.py @@ -39,7 +39,7 @@ def forward( # Compute the probabilities. probs = torch.softmax(logits, dim=-1, dtype=torch.float) # Compute the log probabilities (before applying top-p). - logprobs = torch.log(probs) + logprobs = torch.log(probs, out=logits) # Apply top-p truncation. top_ps = _get_top_ps(input_metadata)