From 3ebf0dbcc1b7b322556939dc98d2b6c727cfe0d0 Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Fri, 26 Jul 2024 16:30:06 -0700 Subject: [PATCH 1/7] Add OLMoE --- docs/source/models/supported_models.rst | 4 + vllm/model_executor/models/__init__.py | 1 + vllm/model_executor/models/olmoe.py | 424 ++++++++++++++++++++++++ 3 files changed, 429 insertions(+) create mode 100644 vllm/model_executor/models/olmoe.py diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index dc8bd6fb245df..f07880fdbbeea 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -113,6 +113,10 @@ Decoder-only Language Models - MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter - :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc. - + * - :code:`OLMoEForCausalLM` + - OLMoE + - :code:`allenai/OLMoE-7B-A1B`, :code:`allenai/OLMoE-7B-A1B`, etc. + - * - :code:`OLMoForCausalLM` - OLMo - :code:`allenai/OLMo-1B-hf`, :code:`allenai/OLMo-7B-hf`, etc. diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index ead64c0e92553..8b3f2e13d135a 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -53,6 +53,7 @@ "MiniCPMV": ("minicpmv", "MiniCPMV"), "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"), "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"), + "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"), "OPTForCausalLM": ("opt", "OPTForCausalLM"), "OrionForCausalLM": ("orion", "OrionForCausalLM"), "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"), diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py new file mode 100644 index 0000000000000..fb3766d263892 --- /dev/null +++ b/vllm/model_executor/models/olmoe.py @@ -0,0 +1,424 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only OLMoE model compatible with HuggingFace weights.""" +from typing import Any, Dict, Iterable, List, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import (get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.utils import print_warning_once + + +class OlmoeMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class OlmoeSparseMoeBlock(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + + if self.tp_size > config.num_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.num_experts}.") + + self.experts = FusedMoE(num_experts=config.num_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config) + + self.gate = ReplicatedLinear(config.hidden_size, + config.num_experts, + bias=False, + quant_config=None) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. + orig_shape = hidden_states.shape + hidden_dim = hidden_states.shape[-1] + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts(hidden_states=hidden_states, + router_logits=router_logits) + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + return final_hidden_states.view(orig_shape) + + +class OlmoeAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=True, + quant_config=quant_config, + ) + self.q_layernorm = RMSNorm(hidden_size, eps=1e-5) + self.k_layernorm = RMSNorm(hidden_size, eps=1e-5) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class OlmoeDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.self_attn = OlmoeAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + ) + + self.mlp = OlmoeSparseMoeBlock(config=config, quant_config=quant_config) + self.input_layernorm = RMSNorm(config.hidden_size, 1e-5) + self.post_attention_layernorm = RMSNorm(config.hidden_size, 1e-5) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class OlmoeModel(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.layers = nn.ModuleList([ + OlmoeDecoderLayer(config, layer_idx, cache_config, quant_config=quant_config) + for layer_idx in range(config.num_hidden_layers) + ]) + self.norm = RMSNorm(config.hidden_size, eps=1e-5) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + residual = None + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, + kv_caches[i], attn_metadata, + residual) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class OlmoeForCausalLM(nn.Module): + + fall_back_to_pt_during_load = False + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.quant_config = quant_config + self.model = OlmoeModel(config, cache_config, quant_config) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> torch.Tensor: + return self.model(input_ids, positions, kv_caches, attn_metadata) + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + return self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + return self.sampler(logits, sampling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts) + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if "mlp.experts" in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + weight_name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + if name.endswith("kv_scale"): + remapped_kv_scale_name = name.replace( + ".kv_scale", ".attn.kv_scale") + if remapped_kv_scale_name not in params_dict: + print_warning_once( + "Found kv scale in the checkpoint " + f"(e.g. {name}), but not found the expected " + f"name in the model " + f"(e.g. {remapped_kv_scale_name}). " + "kv-scale is not loaded.") + continue + else: + name = remapped_kv_scale_name + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) From 9d11e1265fbe2840c2649d9bfa482f1813d3a279 Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Sun, 28 Jul 2024 15:13:46 +0000 Subject: [PATCH 2/7] Add OLMoE --- vllm/model_executor/models/olmoe.py | 176 ++++++++++++++++------------ 1 file changed, 102 insertions(+), 74 deletions(-) diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py index fb3766d263892..dfa574d47a056 100644 --- a/vllm/model_executor/models/olmoe.py +++ b/vllm/model_executor/models/olmoe.py @@ -9,7 +9,36 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Inference-only OLMoE model compatible with HuggingFace weights.""" +"""Inference-only OLMoE model compatible with HuggingFace weights. + +from vllm import LLM +llm = LLM(model="allenai/OLMoE-7B-A1B") +print(llm.generate("Bitcoin is")) + +from vllm import LLM, SamplingParams +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +sampling_params = SamplingParams(temperature=0.8, top_p=0.95) +outputs = llm.generate(prompts, sampling_params) +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + +import torch +from transformers import OlmoeForCausalLM, AutoTokenizer +model = OlmoeForCausalLM.from_pretrained("allenai/OLMoE-7B-A1B", torch_dtype=torch.bfloat16).cuda() +tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-7B-A1B") +inputs = tokenizer("Bitcoin is", return_tensors="pt") +inputs = {k: v.cuda() for k, v in inputs.items()} +out = model.generate(**inputs, max_length=64) +print(tokenizer.decode(out[0])) +""" from typing import Any, Dict, Iterable, List, Optional, Tuple import torch @@ -41,82 +70,52 @@ from vllm.utils import print_warning_once -class OlmoeMLP(nn.Module): - - def __init__( - self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, - reduce_results: bool = True, - ) -> None: +class OlmoeMoE(nn.Module): + """A tensor-parallel MoE implementation for Olmoe that shards each expert + across all ranks. + + Each expert's weights are sharded across all ranks and a fused MoE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. + """ + + def __init__(self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + prefix: str = ""): super().__init__() - self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, - bias=False, - quant_config=quant_config) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config, - reduce_results=reduce_results) - if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") - self.act_fn = SiluAndMul() - - def forward(self, x): - gate_up, _ = self.gate_up_proj(x) - x = self.act_fn(gate_up) - x, _ = self.down_proj(x) - return x - - -class OlmoeSparseMoeBlock(nn.Module): + self.hidden_size = hidden_size - def __init__( - self, - config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, - ): - super().__init__() - self.tp_size = get_tensor_model_parallel_world_size() - - if self.tp_size > config.num_experts: - raise ValueError( - f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {config.num_experts}.") - - self.experts = FusedMoE(num_experts=config.num_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=config.norm_topk_prob, - quant_config=quant_config) - - self.gate = ReplicatedLinear(config.hidden_size, - config.num_experts, + # Gate always runs at half / full precision for now. + self.gate = ReplicatedLinear(hidden_size, + num_experts, bias=False, quant_config=None) + self.experts = FusedMoE(num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + reduce_results=True, + renormalize=False, + quant_config=quant_config, + tp_size=tp_size) + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + #import pdb; pdb.set_trace() # NOTE: hidden_states can have either 1D or 2D shape. orig_shape = hidden_states.shape - hidden_dim = hidden_states.shape[-1] - hidden_states = hidden_states.view(-1, hidden_dim) + hidden_states = hidden_states.view(-1, self.hidden_size) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts(hidden_states=hidden_states, - router_logits=router_logits) - if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) - + final_hidden_states = self.experts(hidden_states, router_logits) return final_hidden_states.view(orig_shape) - class OlmoeAttention(nn.Module): def __init__( @@ -126,7 +125,7 @@ def __init__( num_kv_heads: int, rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, - max_position_embeddings: int = 8192, + max_position_embeddings: int = 4096, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: @@ -158,11 +157,11 @@ def __init__( self.head_dim, self.total_num_heads, self.total_num_kv_heads, - bias=True, + bias=False, quant_config=quant_config, ) - self.q_layernorm = RMSNorm(hidden_size, eps=1e-5) - self.k_layernorm = RMSNorm(hidden_size, eps=1e-5) + self.q_norm = RMSNorm(hidden_size, eps=1e-5) + self.k_norm = RMSNorm(hidden_size, eps=1e-5) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, @@ -176,6 +175,7 @@ def __init__( max_position=max_position_embeddings, base=rope_theta, rope_scaling=rope_scaling, + is_neox_style=True, ) self.attn = Attention(self.num_heads, self.head_dim, @@ -191,8 +191,10 @@ def forward( kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, ) -> torch.Tensor: + #import pdb; pdb.set_trace() qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.q_norm(q), self.k_norm(k) q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.o_proj(attn_output) @@ -213,7 +215,9 @@ def __init__( rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + 4096) + + #""" self.self_attn = OlmoeAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -224,11 +228,25 @@ def __init__( cache_config=cache_config, quant_config=quant_config, ) + #""" + + + + self.mlp = OlmoeMoE( + num_experts=config.num_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config, + ) - self.mlp = OlmoeSparseMoeBlock(config=config, quant_config=quant_config) self.input_layernorm = RMSNorm(config.hidden_size, 1e-5) self.post_attention_layernorm = RMSNorm(config.hidden_size, 1e-5) + #from transformers import OlmoeForCausalLM + #self.model = OlmoeForCausalLM.from_pretrained("allenai/OLMoE-7B-A1B", torch_dtype=torch.bfloat16).cuda() + self.layer_idx = layer_idx + def forward( self, positions: torch.Tensor, @@ -237,6 +255,7 @@ def forward( attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> torch.Tensor: + #import pdb; pdb.set_trace() # Self Attention if residual is None: residual = hidden_states @@ -244,12 +263,21 @@ def forward( else: hidden_states, residual = self.input_layernorm( hidden_states, residual) - hidden_states = self.self_attn( + """ + hidden_states_old = self.self_attn_old( positions=positions, hidden_states=hidden_states, kv_cache=kv_cache, attn_metadata=attn_metadata, ) + """ + from transformers import OlmoeForCausalLM + model = OlmoeForCausalLM.from_pretrained("allenai/OLMoE-7B-A1B", torch_dtype=torch.bfloat16) + self_attn = model.model.layers[self.layer_idx].self_attn.cuda() + hidden_states = self_attn( + hidden_states=hidden_states.unsqueeze(0), + position_ids=torch.arange(hidden_states.size(0)).unsqueeze(0).cuda(), + )[0] # Fully Connected hidden_states, residual = self.post_attention_layernorm( @@ -287,6 +315,7 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, ) -> torch.Tensor: + #import pdb; pdb.set_trace() hidden_states = self.embed_tokens(input_ids) residual = None for i in range(len(self.layers)): @@ -360,8 +389,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue + if "rotary_emb.inv_freq" in name: continue for (param_name, weight_name, shard_id) in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: From 58feb19aace551e608d98fffeeb38596030d1bfc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=87=BA=E8=9B=B0?= Date: Tue, 27 Aug 2024 23:50:37 +0800 Subject: [PATCH 3/7] update: olmoe --- docs/source/models/supported_models.rst | 2 +- vllm/model_executor/models/olmoe.py | 102 ++++++++++-------------- 2 files changed, 43 insertions(+), 61 deletions(-) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index f07880fdbbeea..0b9fc80daa65e 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -115,7 +115,7 @@ Decoder-only Language Models - * - :code:`OLMoEForCausalLM` - OLMoE - - :code:`allenai/OLMoE-7B-A1B`, :code:`allenai/OLMoE-7B-A1B`, etc. + - :code:`OLMoE/OLMoE-1B-7B-0824`, :code:`OLMoE/OLMoE-1B-7B-0824`, etc. - * - :code:`OLMoForCausalLM` - OLMo diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py index dfa574d47a056..314b28e432b0a 100644 --- a/vllm/model_executor/models/olmoe.py +++ b/vllm/model_executor/models/olmoe.py @@ -9,36 +9,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Inference-only OLMoE model compatible with HuggingFace weights. - -from vllm import LLM -llm = LLM(model="allenai/OLMoE-7B-A1B") -print(llm.generate("Bitcoin is")) - -from vllm import LLM, SamplingParams -prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", -] -sampling_params = SamplingParams(temperature=0.8, top_p=0.95) -outputs = llm.generate(prompts, sampling_params) -for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - - -import torch -from transformers import OlmoeForCausalLM, AutoTokenizer -model = OlmoeForCausalLM.from_pretrained("allenai/OLMoE-7B-A1B", torch_dtype=torch.bfloat16).cuda() -tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-7B-A1B") -inputs = tokenizer("Bitcoin is", return_tensors="pt") -inputs = {k: v.cuda() for k, v in inputs.items()} -out = model.generate(**inputs, max_length=64) -print(tokenizer.decode(out[0])) -""" +"""Inference-only OLMoE model compatible with HuggingFace weights.""" from typing import Any, Dict, Iterable, List, Optional, Tuple import torch @@ -107,15 +78,17 @@ def __init__(self, tp_size=tp_size) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - #import pdb; pdb.set_trace() # NOTE: hidden_states can have either 1D or 2D shape. orig_shape = hidden_states.shape - hidden_states = hidden_states.view(-1, self.hidden_size) + hidden_dim = hidden_states.shape[-1] + hidden_states = hidden_states.view(-1, hidden_dim) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts(hidden_states, router_logits) + final_hidden_states = self.experts(hidden_states=hidden_states, + router_logits=router_logits) return final_hidden_states.view(orig_shape) + class OlmoeAttention(nn.Module): def __init__( @@ -191,7 +164,6 @@ def forward( kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, ) -> torch.Tensor: - #import pdb; pdb.set_trace() qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.q_norm(q), self.k_norm(k) @@ -200,6 +172,8 @@ def forward( output, _ = self.o_proj(attn_output) return output +# from transformers import OlmoeForCausalLM +# model = OlmoeForCausalLM.from_pretrained("OLMoE/OLMoE-1B-7B-0824", torch_dtype=torch.bfloat16) class OlmoeDecoderLayer(nn.Module): @@ -217,7 +191,6 @@ def __init__( max_position_embeddings = getattr(config, "max_position_embeddings", 4096) - #""" self.self_attn = OlmoeAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -228,9 +201,6 @@ def __init__( cache_config=cache_config, quant_config=quant_config, ) - #""" - - self.mlp = OlmoeMoE( num_experts=config.num_experts, @@ -239,13 +209,13 @@ def __init__( intermediate_size=config.intermediate_size, quant_config=quant_config, ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=1e-5) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=1e-5) - self.input_layernorm = RMSNorm(config.hidden_size, 1e-5) - self.post_attention_layernorm = RMSNorm(config.hidden_size, 1e-5) + # self.layer_idx = layer_idx - #from transformers import OlmoeForCausalLM - #self.model = OlmoeForCausalLM.from_pretrained("allenai/OLMoE-7B-A1B", torch_dtype=torch.bfloat16).cuda() - self.layer_idx = layer_idx def forward( self, @@ -255,7 +225,6 @@ def forward( attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> torch.Tensor: - #import pdb; pdb.set_trace() # Self Attention if residual is None: residual = hidden_states @@ -263,21 +232,27 @@ def forward( else: hidden_states, residual = self.input_layernorm( hidden_states, residual) - """ - hidden_states_old = self.self_attn_old( + + # print("#-"*20) + # print(positions) + # print(hidden_states) + # print("#-"*20) + + # self_attn = model.model.layers[self.layer_idx].self_attn.cuda() + # hidden_states_old = self_attn( + # hidden_states=hidden_states.unsqueeze(0), + # position_ids=positions.unsqueeze(0) + # )[0][0] + # print("old:", hidden_states_old) + + + hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, kv_cache=kv_cache, attn_metadata=attn_metadata, ) - """ - from transformers import OlmoeForCausalLM - model = OlmoeForCausalLM.from_pretrained("allenai/OLMoE-7B-A1B", torch_dtype=torch.bfloat16) - self_attn = model.model.layers[self.layer_idx].self_attn.cuda() - hidden_states = self_attn( - hidden_states=hidden_states.unsqueeze(0), - position_ids=torch.arange(hidden_states.size(0)).unsqueeze(0).cuda(), - )[0] + # print("new: ", hidden_states) # Fully Connected hidden_states, residual = self.post_attention_layernorm( @@ -303,7 +278,10 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - OlmoeDecoderLayer(config, layer_idx, cache_config, quant_config=quant_config) + OlmoeDecoderLayer(config, + layer_idx, + cache_config, + quant_config=quant_config) for layer_idx in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=1e-5) @@ -315,7 +293,6 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, ) -> torch.Tensor: - #import pdb; pdb.set_trace() hidden_states = self.embed_tokens(input_ids) residual = None for i in range(len(self.layers)): @@ -355,19 +332,23 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: - return self.model(input_ids, positions, kv_caches, attn_metadata) + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata) + return hidden_states def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - return self.logits_processor(self.lm_head, hidden_states, + logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) + return logits def sample( self, logits: Optional[torch.Tensor], sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - return self.sampler(logits, sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ @@ -389,7 +370,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: continue + if "rotary_emb.inv_freq" in name: + continue for (param_name, weight_name, shard_id) in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: From 70b4f298e2484092ab7a1a818275b411c803a2a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=87=BA=E8=9B=B0?= Date: Wed, 28 Aug 2024 01:24:26 +0800 Subject: [PATCH 4/7] remove: unused comments --- vllm/model_executor/models/olmoe.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py index 314b28e432b0a..6667c8d7de830 100644 --- a/vllm/model_executor/models/olmoe.py +++ b/vllm/model_executor/models/olmoe.py @@ -172,8 +172,6 @@ def forward( output, _ = self.o_proj(attn_output) return output -# from transformers import OlmoeForCausalLM -# model = OlmoeForCausalLM.from_pretrained("OLMoE/OLMoE-1B-7B-0824", torch_dtype=torch.bfloat16) class OlmoeDecoderLayer(nn.Module): @@ -214,8 +212,6 @@ def __init__( self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5) - # self.layer_idx = layer_idx - def forward( self, @@ -233,26 +229,12 @@ def forward( hidden_states, residual = self.input_layernorm( hidden_states, residual) - # print("#-"*20) - # print(positions) - # print(hidden_states) - # print("#-"*20) - - # self_attn = model.model.layers[self.layer_idx].self_attn.cuda() - # hidden_states_old = self_attn( - # hidden_states=hidden_states.unsqueeze(0), - # position_ids=positions.unsqueeze(0) - # )[0][0] - # print("old:", hidden_states_old) - - hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, kv_cache=kv_cache, attn_metadata=attn_metadata, ) - # print("new: ", hidden_states) # Fully Connected hidden_states, residual = self.post_attention_layernorm( From 94d92769b04b3fe0300f3b5723eb3aa00d69565d Mon Sep 17 00:00:00 2001 From: Niklas Muennighoff Date: Mon, 9 Sep 2024 10:43:45 -0700 Subject: [PATCH 5/7] Update name --- docs/source/models/supported_models.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 0b9fc80daa65e..0f43560661274 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -115,7 +115,7 @@ Decoder-only Language Models - * - :code:`OLMoEForCausalLM` - OLMoE - - :code:`OLMoE/OLMoE-1B-7B-0824`, :code:`OLMoE/OLMoE-1B-7B-0824`, etc. + - :code:`OLMoE/OLMoE-1B-7B-0924`, :code:`OLMoE/OLMoE-1B-7B-0924`, etc. - * - :code:`OLMoForCausalLM` - OLMo From e8030aaeb5e7242668d17692daa98f0313772220 Mon Sep 17 00:00:00 2001 From: Niklas Muennighoff Date: Mon, 9 Sep 2024 12:48:01 -0700 Subject: [PATCH 6/7] Fix path Co-authored-by: Michael Goin --- docs/source/models/supported_models.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 0f43560661274..f1fa937eb3c45 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -115,7 +115,7 @@ Decoder-only Language Models - * - :code:`OLMoEForCausalLM` - OLMoE - - :code:`OLMoE/OLMoE-1B-7B-0924`, :code:`OLMoE/OLMoE-1B-7B-0924`, etc. + - :code:`allenai/OLMoE-1B-7B-0924`, :code:`allenai/OLMoE-1B-7B-0924-Instruct`, etc. - * - :code:`OLMoForCausalLM` - OLMo From d646a9041a44cdddb7eb2ff94b5b38be26440607 Mon Sep 17 00:00:00 2001 From: mgoin Date: Thu, 19 Sep 2024 21:09:56 +0000 Subject: [PATCH 7/7] Format --- vllm/model_executor/models/olmoe.py | 31 +++++++++++------------------ 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py index 6667c8d7de830..c76e5e86c89d8 100644 --- a/vllm/model_executor/models/olmoe.py +++ b/vllm/model_executor/models/olmoe.py @@ -13,31 +13,27 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple import torch -import torch.nn.functional as F from torch import nn from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import (get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) -from vllm.model_executor.layers.activation import SiluAndMul +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, +from vllm.model_executor.layers.linear import (QKVParallelLinear, ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from vllm.utils import print_warning_once @@ -166,7 +162,7 @@ def forward( ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self.q_norm(q), self.k_norm(k) + q, k = self.q_norm(q.contiguous()), self.k_norm(k.contiguous()) q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.o_proj(attn_output) @@ -188,7 +184,7 @@ def __init__( rope_scaling = getattr(config, "rope_scaling", None) max_position_embeddings = getattr(config, "max_position_embeddings", 4096) - + self.self_attn = OlmoeAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -207,11 +203,8 @@ def __init__( intermediate_size=config.intermediate_size, quant_config=quant_config, ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=1e-5) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=1e-5) - + self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5) + self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5) def forward( self, @@ -261,9 +254,9 @@ def __init__( ) self.layers = nn.ModuleList([ OlmoeDecoderLayer(config, - layer_idx, - cache_config, - quant_config=quant_config) + layer_idx, + cache_config, + quant_config=quant_config) for layer_idx in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=1e-5) @@ -387,7 +380,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = param.weight_loader weight_loader(param, loaded_weight, - weight_name, + name, shard_id=shard_id, expert_id=expert_id) break