From ce630ea78004d33733229a9d783dda98c96ba227 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Mon, 8 Jul 2024 14:55:58 +0000 Subject: [PATCH 01/40] WiP adding support for Mamba --- examples/offline_inference.py | 4 +- vllm/attention/backends/no_attention.py | 161 ++++++ vllm/attention/selector.py | 2 + vllm/config.py | 12 + vllm/core/scheduler.py | 4 + vllm/engine/llm_engine.py | 6 +- vllm/model_executor/models/__init__.py | 3 +- vllm/model_executor/models/mamba.py | 704 ++++++++++++++++++++++++ vllm/sequence.py | 3 + vllm/worker/model_runner.py | 7 +- vllm/worker/worker.py | 16 +- 11 files changed, 911 insertions(+), 11 deletions(-) create mode 100644 vllm/attention/backends/no_attention.py create mode 100644 vllm/model_executor/models/mamba.py diff --git a/examples/offline_inference.py b/examples/offline_inference.py index 9b758fa2479f6..f64082ac0fb12 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -8,10 +8,10 @@ "The future of AI is", ] # Create a sampling params object. -sampling_params = SamplingParams(temperature=0.8, top_p=0.95) +sampling_params = SamplingParams(temperature=0.0, top_p=0.95) # Create an LLM. -llm = LLM(model="facebook/opt-125m") +llm = LLM(model="state-spaces/mamba-370m-hf") # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) diff --git a/vllm/attention/backends/no_attention.py b/vllm/attention/backends/no_attention.py new file mode 100644 index 0000000000000..c42f39789ebe2 --- /dev/null +++ b/vllm/attention/backends/no_attention.py @@ -0,0 +1,161 @@ +from dataclasses import dataclass, fields +from typing import (Any, Dict, Generic, List, Optional, Set, Tuple, Type, + TypeVar) +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata) +import torch + + +class NoAttentionBackend(AttentionBackend): + """Placeholder backend for when no attention is needed.""" + + @staticmethod + def get_name() -> str: + return "No attention" + + @staticmethod + def get_impl_cls() -> Type["NoAttentionImpl"]: + return NoAttentionImpl + + @staticmethod + def get_metadata_cls() -> Type["NoAttentionMetadata"]: + return NoAttentionMetadata + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (1, 1, 1, 1, 1) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + return + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + return + + +@dataclass +class NoAttentionMetadata(AttentionMetadata): + """Attention metadata for prefill and decode batched together.""" + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + + # Maximum query length in the batch. None for decoding. + max_query_len: Optional[int] + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int + # (batch_size + 1,). 
The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] + + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. + block_tables: Optional[torch.Tensor] + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool + + _cached_prefill_metadata: Optional["NoAttentionMetadata"] = None + _cached_decode_metadata: Optional["FlashAttentionMetadata"] = None + + @property + def prefill_metadata(self) -> Optional["NoAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + assert self.query_start_loc is not None + assert self.context_lens_tensor is not None + assert self.block_tables is not None + assert self.seq_start_loc is not None + + self._cached_prefill_metadata = NoAttentionMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + seq_lens=self.seq_lens[:self.num_prefills], + seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], + context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + block_tables=self.block_tables[:self.num_prefills], + use_cuda_graph=False, + ) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.block_tables is not None + assert self.seq_lens_tensor is not None + + self._cached_decode_metadata = FlashAttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + seq_lens=None, + seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self.block_tables[self.num_prefills:], + use_cuda_graph=self.use_cuda_graph, + ) + return self._cached_decode_metadata + + +class NoAttentionImpl(AttentionImpl): + def __init__(self, *args, **kwargs) -> None: + return + + def forward(self, *args, **kwargs) -> torch.Tensor: + raise NotImplementedError diff --git 
a/vllm/attention/selector.py b/vllm/attention/selector.py index ae63eb1d48f8d..93c39e45635ed 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -35,6 +35,8 @@ def get_attn_backend( is_blocksparse: bool = False, ) -> Type[AttentionBackend]: """Selects which attention backend to use and lazily imports it.""" + import pdb + pdb.set_trace() if is_blocksparse: logger.info("Using BlocksparseFlashAttention backend.") diff --git a/vllm/config.py b/vllm/config.py index 1ea2888796808..940dea2dbb0d4 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -312,6 +312,12 @@ def get_head_size(self) -> int: # FlashAttention supports only head_size 32, 64, 128, 256, # we need to pad head_size 192 to 256 return 256 + + if hasattr(self.hf_text_config, "model_type" + ) and self.hf_text_config.model_type == 'mamba': + # Is this going to explode + return 0 + if hasattr(self.hf_text_config, "head_dim"): return self.hf_text_config.head_dim # FIXME(woosuk): This may not be true for all models. @@ -342,6 +348,8 @@ def get_total_num_kv_heads(self) -> int: if self.hf_config.model_type == "dbrx": return getattr(self.hf_config.attn_config, "kv_n_heads", self.hf_config.num_attention_heads) + if self.hf_config.model_type == "mamba": + return 0 attributes = [ # For Falcon: @@ -393,6 +401,10 @@ def contains_seqlen_agnostic_layers( def get_layers_block_type(self, parallel_config: "ParallelConfig") -> List[str]: num_layers = self.get_num_layers(parallel_config) + + if self.hf_config.model_type == "mamba": + return ["mamba"] * num_layers + # Transformers supports layers_block_type @property return getattr(self.hf_config, "layers_block_type", ["attention"] * num_layers) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 9e626b2883975..3382f19d25475 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -695,6 +695,7 @@ def _schedule_prefills( # If the sequence group cannot be allocated, stop. can_allocate = self.block_manager.can_allocate(seq_group) + can_allocate = True #TODO HACK TMS if can_allocate == AllocStatus.LATER: break elif can_allocate == AllocStatus.NEVER: @@ -757,6 +758,8 @@ def _schedule_default(self) -> SchedulerOutputs: decodes. If there's a pressure on GPU memory, decode requests can be swapped or preempted. """ + import pdb + pdb.set_trace() # Include running requests to the budget. budget = SchedulingBudget( token_budget=self.scheduler_config.max_num_batched_tokens, @@ -1054,6 +1057,7 @@ def free_finished_seq_groups(self) -> None: if not seq_group.is_finished()) def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None: + return #TODO TMS HACK self.block_manager.allocate(seq_group) for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): seq.status = SequenceStatus.RUNNING diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index de7604ece7c31..6a45524c0feda 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -252,9 +252,13 @@ def __init__( load_config=load_config, ) - if not self.model_config.embedding_mode: + if self.model_config.get_num_attention_layers(parallel_config) == 0: + self.cache_config.num_gpu_blocks = 0 + self.cache_config.num_cpu_blocks = 0 + elif not self.model_config.embedding_mode: self._initialize_kv_caches() + # If usage stat is enabled, collect relevant info. 
if is_usage_stats_enabled(): from vllm.model_executor.model_loader import ( diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index a4fe18d52d608..0cca07dc567de 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -63,7 +63,8 @@ "XverseForCausalLM": ("xverse", "XverseForCausalLM"), "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), - "JambaForCausalLM": ("jamba", "JambaForCausalLM") + "JambaForCausalLM": ("jamba", "JambaForCausalLM"), + "MambaForCausalLM": ("mamba", "MambaForCausalLM") } _EMBEDDING_MODELS = { diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py new file mode 100644 index 0000000000000..156d045d29675 --- /dev/null +++ b/vllm/model_executor/models/mamba.py @@ -0,0 +1,704 @@ +# coding=utf-8 +"""PyTorch MAMBA model.""" +from dataclasses import dataclass +from typing import Dict, Iterable, List, Optional, Tuple + +import torch +from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +from mamba_ssm.ops.selective_scan_interface import selective_scan_fn +from mamba_ssm.ops.triton.selective_state_update import selective_state_update +from torch import nn +from torch.nn.parameter import Parameter +from transformers import MambaConfig + +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.attention.layer import Attention +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs +from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.worker.model_runner import _BATCH_SIZES_TO_CAPTURE + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +@dataclass +class MambaCacheParams: + is_prompt: bool = False + conv_state: torch.Tensor = torch.Tensor() + ssm_state: torch.Tensor = torch.Tensor() + +# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer +class MambaMixer(nn.Module): + """ + Compute ∆, A, B, C, and D the state space parameters and compute + the `contextualized_states`. 
A, D are input independent + (see Mamba paper [1] Section 3.5.2 "Interpretation of A" + for why A isn't selective) ∆, B, C are input-dependent + (this is a key difference between Mamba and the linear time + invariant S4, and is why Mamba is called + **selective** state spaces) + """ + + def __init__(self, config: MambaConfig, layer_idx): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.ssm_state_size = config.state_size + self.conv_kernel_size = config.conv_kernel + self.intermediate_size = config.intermediate_size + self.time_step_rank = int(config.time_step_rank) + self.use_conv_bias = config.use_conv_bias + + # TODO: ?? + #self.use_bias = config.mamba_proj_bias + self.use_bias = False + + self.conv1d = ColumnParallelLinear( + input_size=self.conv_kernel_size, + output_size=self.intermediate_size, + bias=self.use_conv_bias, + ) + # unsqueeze to fit conv1d weights shape into the linear weights shape. + # Can't do this in `weight_loader` since it already exists in + # `ColumnParallelLinear` and `set_weight_attrs` + # doesn't allow to override it + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + + self.in_proj = MergedColumnParallelLinear(self.hidden_size, + [self.intermediate_size] * 2, + bias=self.use_bias) + # selective projection used to make dt, B and C input dependent + self.x_proj = RowParallelLinear( + self.intermediate_size, + self.time_step_rank + self.ssm_state_size * 2, + bias=False, + ) + # time step projection (discretization) - + # In the forward we need to apply dt_proj without the bias, + # as the bias is added in the selective scan kernel. + self.dt_proj = ColumnParallelLinear(self.time_step_rank, + self.intermediate_size, + bias=True, + skip_bias_add=True) + + def weight_loader(param: Parameter, loaded_weight: torch.Tensor): + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + param.data.copy_( + loaded_weight.data.split(loaded_weight.shape[0] // tp_size, + dim=0)[tp_rank]) + + def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): + weight_loader(param, -torch.exp(loaded_weight.float())) + + tp_size = get_tensor_model_parallel_world_size() + self.A = nn.Parameter( + torch.empty( + self.intermediate_size // tp_size, + self.ssm_state_size, + dtype=torch.float32, + )) + self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size)) + + set_weight_attrs(self.D, {"weight_loader": weight_loader}) + set_weight_attrs(self.A, {"weight_loader": A_weight_loader}) + + self.out_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=self.use_bias, + input_is_parallel=True, + ) + self.activation = config.hidden_act + + self.dt_layernorm = RMSNorm(self.time_step_rank, + eps=config.layer_norm_epsilon) + self.b_layernorm = RMSNorm(self.ssm_state_size, + eps=config.layer_norm_epsilon) + self.c_layernorm = RMSNorm(self.ssm_state_size, + eps=config.layer_norm_epsilon) + + def mamba_forward(self, + hidden_states: torch.Tensor, + cache_params: MambaCacheParams = None): + # 1. Gated MLP's linear projection + projected_states = self.in_proj(hidden_states)[0].transpose(1, 2) + hidden_states, gate = projected_states.chunk(2, dim=1) + + # 2. 
Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), + self.conv1d.weight.size(2)) + if cache_params is not None and not cache_params.is_prompt: + hidden_states = causal_conv1d_update( + hidden_states.squeeze(-1), + cache_params.conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + ) + hidden_states = hidden_states.unsqueeze(-1) + else: + if cache_params is not None: + conv_states = nn.functional.pad( + hidden_states, + (self.conv_kernel_size - hidden_states.shape[-1], 0)) + cache_params.conv_state.copy_(conv_states) + + hidden_states = causal_conv1d_fn( + hidden_states, + conv_weights, + self.conv1d.bias, + activation=self.activation, + ) + + # 3. State Space Model sequence transformation + # 3.a. input varying initialization of time_step, B and C + ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))[0] + + time_step, B, C = torch.split( + ssm_parameters, + [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], + dim=-1, + ) + time_step = self.dt_layernorm(time_step.contiguous()) + B = self.b_layernorm(B.contiguous()) + C = self.c_layernorm(C.contiguous()) + + discrete_time_step = self.dt_proj(time_step)[0].transpose(1, 2) + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + time_proj_bias = (self.dt_proj.bias.float() if hasattr( + self.dt_proj, "bias") else None) + if cache_params is not None and not cache_params.is_prompt: + scan_outputs = selective_state_update( + cache_params.ssm_state, + hidden_states[..., 0], + discrete_time_step[..., 0], + self.A, + B[:, 0], + C[:, 0], + self.D, + gate[..., 0], + time_proj_bias, + dt_softplus=True, + ).unsqueeze(-1) + else: + scan_outputs, ssm_state = selective_scan_fn( + hidden_states, + discrete_time_step, + self.A, + B.transpose(1, 2), + C.transpose(1, 2), + self.D.float(), + gate, + time_proj_bias, + delta_softplus=True, + return_last_state=True, + ) + if ssm_state is not None and cache_params is not None: + cache_params.ssm_state.copy_(ssm_state) + + # 4. 
Final linear projection + contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))[0] + return contextualized_states + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + conv_state: torch.Tensor, + ssm_state: torch.Tensor, + ): + if attn_metadata.prefill_metadata is not None: + offset = 0 + for i, prompt_len in enumerate( + attn_metadata.prefill_metadata.seq_lens): + cache = MambaCacheParams(True, + conv_state=conv_state[i].unsqueeze(0), + ssm_state=ssm_state[i].unsqueeze(0)) + hidden_states[offset:offset + prompt_len].copy_( + self.mamba_forward(hidden_states[offset:offset + + prompt_len].unsqueeze(0), + cache_params=cache)[0]) + offset += prompt_len + else: + cache = MambaCacheParams(False, + conv_state=conv_state, + ssm_state=ssm_state) + hidden_states = self.mamba_forward(hidden_states.unsqueeze(1), + cache_params=cache) + hidden_states = hidden_states.squeeze(1) + + return hidden_states + +class MambaMLP(nn.Module): + + def __init__( + self, + config: MambaConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + hidden_size = config.hidden_size + intermediate_size = config.intermediate_size + hidden_act = config.hidden_act + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + +class MambaDecoderLayer(nn.Module): + + def __init__(self, + config: MambaConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None) -> None: + super().__init__() + self.layer_idx = layer_idx + self.config = config + self.mixer = MambaMixer(config, layer_idx) + + self.feed_forward = MambaMLP(config, quant_config=quant_config) + self.norm = RMSNorm(config.hidden_size, + eps=config.layer_norm_epsilon) + self.pre_ff_layernorm = RMSNorm(config.hidden_size, + eps=config.layer_norm_epsilon) + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + conv_state: torch.Tensor, + ssm_state: torch.Tensor, + **kwargs, + ): + if residual is None: + residual = hidden_states + hidden_states = self.norm(hidden_states) + else: + hidden_states, residual = self.norm( + hidden_states, residual) + + hidden_states = self.mixer(hidden_states, attn_metadata, conv_state, + ssm_state) + # Fully Connected + hidden_states, residual = self.pre_ff_layernorm( + hidden_states, residual) + hidden_states = self.feed_forward(hidden_states) + return hidden_states, residual + +class MambaModel(nn.Module): + + def __init__( + self, + config: MambaConfig, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embeddings = VocabParallelEmbedding( + self.vocab_size, + 
config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + decoder_layers = [] + for i in range(config.num_hidden_layers): + decoder_layers.append( + MambaDecoderLayer(config, + layer_idx=i, + cache_config=cache_config, + quant_config=quant_config)) + self.layers = nn.ModuleList(decoder_layers) + self.norm_f = RMSNorm(config.hidden_size, + eps=config.layer_norm_epsilon) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + conv_state: torch.Tensor, + ssm_state: torch.Tensor, + ) -> torch.Tensor: + hidden_states = self.embeddings(input_ids) + residual = None + + for i in range(len(self.layers)): + layer = self.layers[i] + current_ssm_state = ssm_state[i] + current_conv_state = conv_state[i] + + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + attn_metadata=attn_metadata, + residual=residual, + conv_state=current_conv_state, + ssm_state=current_ssm_state, + ) + hidden_states, _ = self.norm_f(hidden_states, residual) + return hidden_states + +class MambaForCausalLM(nn.Module): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "embed_tokens", + "lm_head", + ] + embedding_modules = { + "embeddings": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__( + self, + config: MambaConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.backbone = MambaModel(config, + cache_config=cache_config, + quant_config=quant_config, + lora_config=lora_config) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + ) + # Current step used indices + self.current_indices: List[int] = [] + # Used to track and store by the Mamba cache between steps. + self.mamba_cache: Tuple[torch.Tensor, torch.Tensor] = tuple() + # Used as an input_buffer for the CUDA graph runs. 
+ self.mamba_gc_cache_buffer: Tuple[torch.Tensor, torch.Tensor] = tuple() + # Maps between the request id and a dict that maps between the seq_id + # and its index inside the self.mamba_cache + self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {} + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = Sampler() + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs): + if not self.mamba_cache: + self._prepare_mamba_cache() + + if "seqlen_agnostic_capture_inputs" not in kwargs: + # We get here only on Prefill/Eager mode runs + assert all( + key in kwargs + for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) + + request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] + batch_size = input_ids.shape[0] + if attn_metadata.prefill_metadata: + batch_size = len(request_ids_to_seq_ids) + ( + current_seqlen_agnostic_cache, + indices, + ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids, + batch_size) + finished_requests_ids = kwargs["finished_requests_ids"] + self._release_mamba_cache(finished_requests_ids) + else: + # CUDA graph capturing runs + current_seqlen_agnostic_cache, indices = ( + kwargs["seqlen_agnostic_capture_inputs"], + [], + ) + self.current_indices = indices + + hidden_states = self.backbone(input_ids, positions, kv_caches, + attn_metadata, + current_seqlen_agnostic_cache[0], + current_seqlen_agnostic_cache[1]) + + if "seqlen_agnostic_capture_inputs" not in kwargs: + self._copy_mamba_cache_by_indices(self.current_indices, + current_seqlen_agnostic_cache) + + return hidden_states + + def _copy_mamba_cache_by_indices( + self, indices: List[int], + current_seqlen_agnostic_cache: Tuple[torch.Tensor, torch.Tensor]): + for i, offset in enumerate(indices): + self._copy_mamba_cache(offset, i, current_seqlen_agnostic_cache) + + def _copy_mamba_cache(self, index_to: int, index_from: int, + from_buffer: Tuple[torch.Tensor, torch.Tensor]): + assert len(self.mamba_cache) > 0 + for (cache_t, from_buffer_t) in zip(self.mamba_cache, from_buffer): + cache_t[:, index_to].copy_(from_buffer_t[:, index_from], + non_blocking=True) + + def _assign_seq_id_to_mamba_cache(self, cur_rid: str, + seqs_id: List[int]) -> List[int]: + indices_for_current_run = [] + for seq_id in seqs_id: + if cur_rid not in self.mamba_cache_indices_mapping: + self.mamba_cache_indices_mapping[cur_rid] = {} + first_free_index = self._first_free_index_in_mamba_cache() + self.mamba_cache_indices_mapping[cur_rid][ + seq_id] = first_free_index + index_for_current_run = first_free_index + ## case of decoding n>1, copy prefill cache to decoding indices + elif seq_id not in (seq_ids2indices := + self.mamba_cache_indices_mapping[cur_rid]): + first_free_index = self._first_free_index_in_mamba_cache() + index_exist = list(seq_ids2indices.values())[0] + self._copy_mamba_cache(index_from=index_exist, + index_to=first_free_index, + from_buffer=self.mamba_cache) + self.mamba_cache_indices_mapping[cur_rid][ + seq_id] = first_free_index + index_for_current_run = first_free_index + else: + index_for_current_run = self.mamba_cache_indices_mapping[ + cur_rid][seq_id] + + indices_for_current_run.append(index_for_current_run) + return indices_for_current_run + + def _prepare_current_run_mamba_cache( + self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int + ) -> Tuple[Tuple[torch.Tensor, 
torch.Tensor], List[int]]: + indices_for_current_run = [] + for request_id, seqs_id in request_ids_to_seq_ids.items(): + indices_for_current_run += self._assign_seq_id_to_mamba_cache( + request_id, seqs_id) + ## Pad the batch in case of running batch that was not captured via CG + padded_indices = indices_for_current_run.copy() + pad_index = self._first_free_index_in_mamba_cache() + + for _ in range(batch_size - len(indices_for_current_run)): + padded_indices.append(pad_index) + + conv_state = self.mamba_cache[0][:, padded_indices] + temporal_state = self.mamba_cache[1][:, padded_indices] + + return (conv_state, temporal_state), indices_for_current_run + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + """ + Copy the relevant Mamba cache into the CUDA graph input buffer + that was provided during the capture runs + (MambaForCausalLM.mamba_gc_cache_buffer). + """ + assert all( + key in kwargs + for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) + request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] + batch_size = len(request_ids_to_seq_ids) + ( + current_mamba_cache, + indices, + ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids, + batch_size) + self.current_indices = indices + finished_requests_ids = kwargs["finished_requests_ids"] + self._release_mamba_cache(finished_requests_ids) + + for input_buffer, current_cache_buffer in zip( + input_buffers["seqlen_agnostic_capture_inputs"], + current_mamba_cache): + input_buffer.copy_(current_cache_buffer, non_blocking=True) + + def copy_outputs_after_cuda_graphs(self, input_buffers, **kwargs): + """ + Copy the relevant Mamba cache from the CUDA graph input_buffers + back to the MambaForCausalLM.mamba_cache after CUDA + graph replay run is done. + """ + self._copy_mamba_cache_by_indices( + self.current_indices, + input_buffers["seqlen_agnostic_capture_inputs"]) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + """ + Provide the CUDA graph capture runs with a buffer in adjusted size. + The buffer is used to maintain the Mamba Cache during the CUDA graph + replay runs. 
+ """ + return tuple(buffer[:, :batch_size] + for buffer in self.mamba_gc_cache_buffer) + + def _release_mamba_cache(self, finished_seq_groups_req_ids: List[str]): + for req_id in finished_seq_groups_req_ids: + if req_id in self.mamba_cache_indices_mapping: + self.mamba_cache_indices_mapping.pop(req_id) + + def _first_free_index_in_mamba_cache(self) -> int: + if self.mamba_cache: + max_possible_batch_size = self.mamba_cache[0].shape[1] + occupied = [ + id for seq_ids in self.mamba_cache_indices_mapping.values() + for id in seq_ids.values() + ] + first_free_index = [ + i not in occupied for i in range(max_possible_batch_size) + ].index(True) + return first_free_index + return 0 + + def _get_mamba_cache_shape( + self + ) -> Tuple[Optional[Tuple[int, int]], Optional[Tuple[int, int]]]: + world_size = get_tensor_model_parallel_world_size() + hidden_size = self.config.hidden_size + conv_state_shape = ( + self.config.intermediate_size // world_size, + self.config.conv_kernel, + ) + temporal_state_shape = ( + self.config.intermediate_size // world_size, + self.config.state_size, + ) + return conv_state_shape, temporal_state_shape + + def _prepare_mamba_cache(self): + dtype = self.lm_head.weight.dtype + num_mamba_layers = self.config.num_hidden_layers + max_batch_size = _BATCH_SIZES_TO_CAPTURE[-1] + 10 + conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape() + assert conv_state_shape is not None and temporal_state_shape is not None + for buffername in ["mamba_cache", "mamba_gc_cache_buffer"]: + buffer = (torch.empty(size=(num_mamba_layers, max_batch_size) + + conv_state_shape, + dtype=dtype, + device="cuda"), + torch.empty(size=(num_mamba_layers, max_batch_size) + + temporal_state_shape, + dtype=dtype, + device="cuda")) + setattr(self, buffername, buffer) + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + for k, v in params_dict.items(): + print(k) + + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + if "A_log" in name: + name = name.replace("A_log", "A") + + if ".self_attn." in name: + name = name.replace(".self_attn", "") + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/sequence.py b/vllm/sequence.py index d200115aa0921..5ea623f1cc3fd 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -643,6 +643,9 @@ def __init__( encoder_seq_data: Optional[SequenceData] = None, cross_block_table: Optional[List[int]] = None, ) -> None: + import pdb + pdb.set_trace() + self.request_id = request_id self.is_prompt = is_prompt self.seq_data = seq_data diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index d0c82d6bbedf3..ac257d77bd5b1 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -23,6 +23,8 @@ FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 from vllm.attention import AttentionMetadata, get_attn_backend +from vllm.attention.backends.no_attention import NoAttentionBackend + from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, SchedulerConfig) @@ -222,7 +224,7 @@ def __init__( self.model_config.dtype, self.kv_cache_dtype, self.block_size, - ) if num_attn_heads else None + ) if num_attn_heads else NoAttentionBackend() # Multi-modal data support self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \ @@ -395,6 +397,9 @@ def _prepare_model_input_tensors( block_aligned_sliding_window = \ sliding_window_blocks * self.block_size + import pdb + pdb.set_trace() + for seq_group_metadata in seq_group_metadata_list: seq_ids = list(seq_group_metadata.seq_data.keys()) is_prompt = seq_group_metadata.is_prompt diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 58707269bd68c..d0be781101222 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -184,11 +184,15 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: "not properly cleaned up before initializing the vLLM instance.") cache_block_size = self.get_cache_block_size_bytes() - num_gpu_blocks = int( - (total_gpu_memory * self.cache_config.gpu_memory_utilization - - peak_memory) // cache_block_size) - num_cpu_blocks = int(self.cache_config.swap_space_bytes // - cache_block_size) + if cache_block_size == 0: + num_gpu_blocks = 0 + num_cpu_blocks = 0 + else: + num_gpu_blocks = int( + (total_gpu_memory * self.cache_config.gpu_memory_utilization - + peak_memory) // cache_block_size) + num_cpu_blocks = int(self.cache_config.swap_space_bytes // + cache_block_size) num_gpu_blocks = max(num_gpu_blocks, 0) num_cpu_blocks = max(num_cpu_blocks, 0) if self.model_runner.lora_manager: @@ -209,7 +213,7 @@ def initialize_cache(self, num_gpu_blocks: int, self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks - + self._init_cache_engine() self._warm_up_model() From 6c59b06a569e3d2abcbb90d6507e118bc483d6c4 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 9 Jul 2024 21:50:18 +0000 Subject: [PATCH 02/40] wip --- vllm/config.py | 3 +++ vllm/core/scheduler.py | 5 ++--- vllm/sequence.py | 5 +---- vllm/worker/model_runner.py | 3 --- 4 files changed, 6 insertions(+), 10 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 940dea2dbb0d4..d31433552a85a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -798,6 +798,9 @@ def __init__(self, if enable_chunked_prefill: logger.info("Chunked prefill is enabled (EXPERIMENTAL).") + #TODO: already perfect + self.its_mamba = True + self.max_num_seqs = max_num_seqs self.max_model_len = 
max_model_len self.use_v2_block_manager = use_v2_block_manager diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 3382f19d25475..f671da01e6924 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -270,6 +270,8 @@ def __init__( version = "v2" if self.scheduler_config.embedding_mode: version = "embedding" + if self.scheduler_config.its_mamba: + version = "embedding" BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class( version) @@ -758,8 +760,6 @@ def _schedule_default(self) -> SchedulerOutputs: decodes. If there's a pressure on GPU memory, decode requests can be swapped or preempted. """ - import pdb - pdb.set_trace() # Include running requests to the budget. budget = SchedulingBudget( token_budget=self.scheduler_config.max_num_batched_tokens, @@ -1057,7 +1057,6 @@ def free_finished_seq_groups(self) -> None: if not seq_group.is_finished()) def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None: - return #TODO TMS HACK self.block_manager.allocate(seq_group) for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): seq.status = SequenceStatus.RUNNING diff --git a/vllm/sequence.py b/vllm/sequence.py index 5ea623f1cc3fd..753d7875cacb7 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -643,9 +643,6 @@ def __init__( encoder_seq_data: Optional[SequenceData] = None, cross_block_table: Optional[List[int]] = None, ) -> None: - import pdb - pdb.set_trace() - self.request_id = request_id self.is_prompt = is_prompt self.seq_data = seq_data @@ -660,7 +657,7 @@ def __init__( self.cross_block_table = cross_block_table self._token_chunk_size = token_chunk_size self.do_sample = do_sample - + # The number of speculative tokens adopted in this request. # None means specuative decoding is not used. # Zero means speculative decoding is disabled for some reasons. diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index ac257d77bd5b1..30f2b1d366cc2 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -397,9 +397,6 @@ def _prepare_model_input_tensors( block_aligned_sliding_window = \ sliding_window_blocks * self.block_size - import pdb - pdb.set_trace() - for seq_group_metadata in seq_group_metadata_list: seq_ids = list(seq_group_metadata.seq_data.keys()) is_prompt = seq_group_metadata.is_prompt From eb9bf348032b51520c313a55d7813b43567e5763 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Wed, 10 Jul 2024 21:04:07 +0000 Subject: [PATCH 03/40] WIP -- runs through. Generates tokens. Bad tokens. 
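
A minimal way to chase down the bad tokens is to diff greedy output against the
HuggingFace Mamba implementation. Sketch only, not part of this patch: it assumes a
transformers version that ships MambaForCausalLM (4.39+) and reuses the checkpoint,
prompt, and fp32 dtype from examples/offline_inference.py.

    # compare_mamba_outputs.py -- sanity-check the vLLM Mamba port against HF transformers
    import torch
    from transformers import AutoTokenizer, MambaForCausalLM
    from vllm import LLM, SamplingParams

    MODEL = "state-spaces/mamba-370m-hf"
    prompt = "The future of AI is"

    # Reference: greedy decode with the HF implementation.
    tokenizer = AutoTokenizer.from_pretrained(MODEL)
    hf_model = MambaForCausalLM.from_pretrained(MODEL, torch_dtype=torch.float32)
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    hf_ids = hf_model.generate(input_ids, max_new_tokens=16, do_sample=False)
    print("HF:  ", tokenizer.decode(hf_ids[0], skip_special_tokens=True))

    # vLLM: temperature=0.0 gives deterministic greedy decoding.
    llm = LLM(model=MODEL, dtype=torch.float32)
    out = llm.generate([prompt], SamplingParams(temperature=0.0, max_tokens=16))
    print("vLLM:", prompt + out[0].outputs[0].text)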
--- vllm/attention/backends/no_attention.py | 6 +++--- vllm/attention/selector.py | 2 -- vllm/config.py | 2 ++ vllm/core/embedding_model_block_manager.py | 2 +- vllm/engine/arg_utils.py | 1 + vllm/engine/llm_engine.py | 9 ++++----- vllm/model_executor/models/mamba.py | 23 +++++++++++++--------- vllm/worker/model_runner.py | 7 ++++--- vllm/worker/worker.py | 7 ++++--- 9 files changed, 33 insertions(+), 26 deletions(-) diff --git a/vllm/attention/backends/no_attention.py b/vllm/attention/backends/no_attention.py index c42f39789ebe2..25e239b603bc5 100644 --- a/vllm/attention/backends/no_attention.py +++ b/vllm/attention/backends/no_attention.py @@ -89,7 +89,7 @@ class NoAttentionMetadata(AttentionMetadata): use_cuda_graph: bool _cached_prefill_metadata: Optional["NoAttentionMetadata"] = None - _cached_decode_metadata: Optional["FlashAttentionMetadata"] = None + _cached_decode_metadata: Optional["NoAttentionMetadata"] = None @property def prefill_metadata(self) -> Optional["NoAttentionMetadata"]: @@ -125,7 +125,7 @@ def prefill_metadata(self) -> Optional["NoAttentionMetadata"]: return self._cached_prefill_metadata @property - def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: + def decode_metadata(self) -> Optional["NoAttentionMetadata"]: if self.num_decode_tokens == 0: return None @@ -134,7 +134,7 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: assert self.block_tables is not None assert self.seq_lens_tensor is not None - self._cached_decode_metadata = FlashAttentionMetadata( + self._cached_decode_metadata = NoAttentionMetadata( num_prefills=0, num_prefill_tokens=0, num_decode_tokens=self.num_decode_tokens, diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 93c39e45635ed..ae63eb1d48f8d 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -35,8 +35,6 @@ def get_attn_backend( is_blocksparse: bool = False, ) -> Type[AttentionBackend]: """Selects which attention backend to use and lazily imports it.""" - import pdb - pdb.set_trace() if is_blocksparse: logger.info("Using BlocksparseFlashAttention backend.") diff --git a/vllm/config.py b/vllm/config.py index d31433552a85a..0ea675cbe869f 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -443,6 +443,7 @@ def __init__( gpu_memory_utilization: float, swap_space: int, cache_dtype: str, + cache_grows: bool, num_gpu_blocks_override: Optional[int] = None, sliding_window: Optional[int] = None, enable_prefix_caching: bool = False, @@ -452,6 +453,7 @@ def __init__( self.swap_space_bytes = swap_space * _GB self.num_gpu_blocks_override = num_gpu_blocks_override self.cache_dtype = cache_dtype + self.cache_grows = cache_grows self.sliding_window = sliding_window self.enable_prefix_caching = enable_prefix_caching self._verify_args() diff --git a/vllm/core/embedding_model_block_manager.py b/vllm/core/embedding_model_block_manager.py index f2d67306d7ceb..43a9f9de6767c 100644 --- a/vllm/core/embedding_model_block_manager.py +++ b/vllm/core/embedding_model_block_manager.py @@ -37,7 +37,7 @@ def append_slots( seq: Sequence, num_lookahead_slots: int, ) -> List[Tuple[int, int]]: - return None # type: ignore + return [] def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: pass diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index afa6892d49eb8..20c010a09b069 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -650,6 +650,7 @@ def create_engine_config(self, ) -> EngineConfig: gpu_memory_utilization=self.gpu_memory_utilization, 
swap_space=self.swap_space, cache_dtype=self.kv_cache_dtype, + cache_grows=False, num_gpu_blocks_override=self.num_gpu_blocks_override, sliding_window=model_config.get_sliding_window(), enable_prefix_caching=self.enable_prefix_caching) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 6a45524c0feda..07b6912576ab1 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -252,10 +252,9 @@ def __init__( load_config=load_config, ) - if self.model_config.get_num_attention_layers(parallel_config) == 0: - self.cache_config.num_gpu_blocks = 0 - self.cache_config.num_cpu_blocks = 0 - elif not self.model_config.embedding_mode: + if not self.model_config.embedding_mode: + # TODO: Even for mamba, we must initialize the KV caches, + # Because model warmup and CUDA graphs are created here. self._initialize_kv_caches() @@ -852,7 +851,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: 0].schedule() finished_requests_ids = self.scheduler[ 0].get_and_reset_finished_requests_ids() - + if not scheduler_outputs.is_empty(): execute_model_req = ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 156d045d29675..1e5632816b19b 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -422,15 +422,19 @@ def __init__( self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size - self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, - ) + + #TODO: this ends up all 0s -- we don't put anything in here when loading weights. + #TODO: Does mamba share weights between the lm head and embeddings? +# self.lm_head = ParallelLMHead( +# self.unpadded_vocab_size, +# config.hidden_size, +# org_num_embeddings=config.vocab_size, +# padding_size=DEFAULT_VOCAB_PADDING_SIZE +# # We need bigger padding if using lora for kernel +# # compatibility +# if not lora_config else lora_config.lora_vocab_padding_size, +# ) + self.lm_head = self.backbone.embeddings # Current step used indices self.current_indices: List[int] = [] # Used to track and store by the Mamba cache between steps. @@ -451,6 +455,7 @@ def forward(self, attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs): + if not self.mamba_cache: self._prepare_mamba_cache() diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 30f2b1d366cc2..006eb3c2d2dc6 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -933,7 +933,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: "You can also reduce the `max_num_seqs` as needed " "to decrease memory usage.") start_time = time.perf_counter() - + # Prepare dummy inputs. These will be reused for all batch sizes. max_batch_size = max(_BATCH_SIZES_TO_CAPTURE) input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda() @@ -1410,8 +1410,9 @@ def forward( # Copy the input tensors to the input buffers. 
self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True) self.input_buffers["positions"].copy_(positions, non_blocking=True) - self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping, - non_blocking=True) + if self.backend_name != "No attention": + self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping, + non_blocking=True) if self.backend_name != "flashinfer": self.input_buffers["seq_lens_tensor"].copy_( attn_metadata.decode_metadata.seq_lens_tensor, diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index d0be781101222..52ac40a5dca28 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -209,6 +209,7 @@ def initialize_cache(self, num_gpu_blocks: int, """ raise_if_cache_size_invalid(num_gpu_blocks, self.cache_config.block_size, + self.cache_config.cache_grows, self.model_config.max_model_len) self.cache_config.num_gpu_blocks = num_gpu_blocks @@ -346,14 +347,14 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): "`dtype` flag in CLI, for example: --dtype=half.") -def raise_if_cache_size_invalid(num_gpu_blocks, block_size, +def raise_if_cache_size_invalid(num_gpu_blocks, block_size, cache_grows, max_model_len) -> None: - if num_gpu_blocks <= 0: + if num_gpu_blocks <= 0 and cache_grows: raise ValueError("No available memory for the cache blocks. " "Try increasing `gpu_memory_utilization` when " "initializing the engine.") max_seq_len = block_size * num_gpu_blocks - if max_model_len > max_seq_len: + if max_model_len > max_seq_len and cache_grows: raise ValueError( f"The model's max seq len ({max_model_len}) " "is larger than the maximum number of tokens that can be " From 320f79b348694175e80745a2f0de10fe4874d7f7 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Mon, 15 Jul 2024 19:43:43 +0000 Subject: [PATCH 04/40] Good output for mamba-370m --- examples/offline_inference.py | 3 ++- vllm/model_executor/models/mamba.py | 26 +++++++++++++++++--------- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/examples/offline_inference.py b/examples/offline_inference.py index f64082ac0fb12..6e6f91542bb7c 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -1,4 +1,5 @@ from vllm import LLM, SamplingParams +import torch # Sample prompts. prompts = [ @@ -11,7 +12,7 @@ sampling_params = SamplingParams(temperature=0.0, top_p=0.95) # Create an LLM. -llm = LLM(model="state-spaces/mamba-370m-hf") +llm = LLM(model="state-spaces/mamba-370m-hf", dtype=torch.float32) # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 1e5632816b19b..fa65f081bca89 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -131,16 +131,19 @@ def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): ) self.activation = config.hidden_act - self.dt_layernorm = RMSNorm(self.time_step_rank, - eps=config.layer_norm_epsilon) - self.b_layernorm = RMSNorm(self.ssm_state_size, - eps=config.layer_norm_epsilon) - self.c_layernorm = RMSNorm(self.ssm_state_size, - eps=config.layer_norm_epsilon) + # Jamba has layer norms here. Mamba doesn't. + # TODO: Leaving these in for now, just as a placeholder in case mamba2 needs them. 
+ #self.dt_layernorm = RMSNorm(self.time_step_rank, + # eps=config.layer_norm_epsilon) + #self.b_layernorm = RMSNorm(self.ssm_state_size, + # eps=config.layer_norm_epsilon) + #self.c_layernorm = RMSNorm(self.ssm_state_size, + # eps=config.layer_norm_epsilon) def mamba_forward(self, hidden_states: torch.Tensor, cache_params: MambaCacheParams = None): + # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states)[0].transpose(1, 2) hidden_states, gate = projected_states.chunk(2, dim=1) @@ -180,9 +183,12 @@ def mamba_forward(self, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1, ) - time_step = self.dt_layernorm(time_step.contiguous()) - B = self.b_layernorm(B.contiguous()) - C = self.c_layernorm(C.contiguous()) + + # Jamba has layer norms here. Mamba doesn't. + # TODO: Leaving these in for now, just as a placeholder in case mamba2 needs them. + # time_step = self.dt_layernorm(time_step.contiguous()) + # B = self.b_layernorm(B.contiguous()) + # C = self.c_layernorm(C.contiguous()) discrete_time_step = self.dt_proj(time_step)[0].transpose(1, 2) # 3.c perform the recurrence y ← SSM(A, B, C)(x) @@ -382,6 +388,8 @@ def forward( ssm_state=current_ssm_state, ) hidden_states, _ = self.norm_f(hidden_states, residual) + + return hidden_states class MambaForCausalLM(nn.Module): From 5ab6622f2d143e3242e44f1bd25c2faceaa93b1a Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 16 Jul 2024 15:37:34 +0000 Subject: [PATCH 05/40] wip --- examples/offline_inference.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/offline_inference.py b/examples/offline_inference.py index 6e6f91542bb7c..cb74561f35e3e 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -12,7 +12,9 @@ sampling_params = SamplingParams(temperature=0.0, top_p=0.95) # Create an LLM. -llm = LLM(model="state-spaces/mamba-370m-hf", dtype=torch.float32) +#llm = LLM(model="state-spaces/mamba-370m-hf", dtype=torch.float32) +llm = LLM(model="state-spaces/mamba2-130m", dtype=torch.float32) + # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. 
outputs = llm.generate(prompts, sampling_params) From 25b54d95458670402908eb4b4f11f1da575b0f85 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 16 Jul 2024 17:09:55 +0000 Subject: [PATCH 06/40] cleanup --- .../{no_attention.py => placeholder_attn.py} | 38 +- vllm/config.py | 38 +- vllm/core/scheduler.py | 6 +- vllm/engine/arg_utils.py | 3 +- vllm/engine/llm_engine.py | 6 +- vllm/model_executor/models/2 | 728 ++++++++++++++++++ vllm/model_executor/models/__init__.py | 1 - vllm/model_executor/models/mamba.py | 81 +- vllm/sequence.py | 2 +- vllm/worker/model_runner.py | 10 +- vllm/worker/worker.py | 14 +- 11 files changed, 822 insertions(+), 105 deletions(-) rename vllm/attention/backends/{no_attention.py => placeholder_attn.py} (82%) create mode 100644 vllm/model_executor/models/2 diff --git a/vllm/attention/backends/no_attention.py b/vllm/attention/backends/placeholder_attn.py similarity index 82% rename from vllm/attention/backends/no_attention.py rename to vllm/attention/backends/placeholder_attn.py index 25e239b603bc5..6bc766ba4e3f7 100644 --- a/vllm/attention/backends/no_attention.py +++ b/vllm/attention/backends/placeholder_attn.py @@ -1,12 +1,15 @@ -from dataclasses import dataclass, fields -from typing import (Any, Dict, Generic, List, Optional, Set, Tuple, Type, - TypeVar) +from dataclasses import dataclass +from typing import (List, Optional, Tuple, Type) from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata) import torch +# Placeholder attention backend for models like Mamba that don't have attention. +# Mainly exists to sidestep get_attn_backend. +# The attention metadata is still needed for Mamba. -class NoAttentionBackend(AttentionBackend): + +class PlaceholderAttentionBackend(AttentionBackend): """Placeholder backend for when no attention is needed.""" @staticmethod @@ -14,12 +17,12 @@ def get_name() -> str: return "No attention" @staticmethod - def get_impl_cls() -> Type["NoAttentionImpl"]: - return NoAttentionImpl + def get_impl_cls() -> Type["PlaceholderAttentionImpl"]: + return PlaceholderAttentionImpl @staticmethod - def get_metadata_cls() -> Type["NoAttentionMetadata"]: - return NoAttentionMetadata + def get_metadata_cls() -> Type["PlaceholderAttentionMetadata"]: + return PlaceholderAttentionMetadata @staticmethod def get_kv_cache_shape( @@ -47,7 +50,7 @@ def copy_blocks( @dataclass -class NoAttentionMetadata(AttentionMetadata): +class PlaceholderAttentionMetadata(AttentionMetadata): """Attention metadata for prefill and decode batched together.""" # (batch_size,). The sequence length per sequence. Sequence length means # the computed tokens + new tokens None if it is a decoding. @@ -88,11 +91,11 @@ class NoAttentionMetadata(AttentionMetadata): # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. 
use_cuda_graph: bool - _cached_prefill_metadata: Optional["NoAttentionMetadata"] = None - _cached_decode_metadata: Optional["NoAttentionMetadata"] = None + _cached_prefill_metadata: Optional["PlaceholderAttentionMetadata"] = None + _cached_decode_metadata: Optional["PlaceholderAttentionMetadata"] = None @property - def prefill_metadata(self) -> Optional["NoAttentionMetadata"]: + def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: if self.num_prefills == 0: return None @@ -106,7 +109,7 @@ def prefill_metadata(self) -> Optional["NoAttentionMetadata"]: assert self.block_tables is not None assert self.seq_start_loc is not None - self._cached_prefill_metadata = NoAttentionMetadata( + self._cached_prefill_metadata = PlaceholderAttentionMetadata( num_prefills=self.num_prefills, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=0, @@ -125,7 +128,7 @@ def prefill_metadata(self) -> Optional["NoAttentionMetadata"]: return self._cached_prefill_metadata @property - def decode_metadata(self) -> Optional["NoAttentionMetadata"]: + def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: if self.num_decode_tokens == 0: return None @@ -134,7 +137,7 @@ def decode_metadata(self) -> Optional["NoAttentionMetadata"]: assert self.block_tables is not None assert self.seq_lens_tensor is not None - self._cached_decode_metadata = NoAttentionMetadata( + self._cached_decode_metadata = PlaceholderAttentionMetadata( num_prefills=0, num_prefill_tokens=0, num_decode_tokens=self.num_decode_tokens, @@ -153,9 +156,10 @@ def decode_metadata(self) -> Optional["NoAttentionMetadata"]: return self._cached_decode_metadata -class NoAttentionImpl(AttentionImpl): +class PlaceholderAttentionImpl(AttentionImpl): + def __init__(self, *args, **kwargs) -> None: return - + def forward(self, *args, **kwargs) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/config.py b/vllm/config.py index 0ae4acf4403ea..26ad889cceba4 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -276,6 +276,19 @@ def verify_with_parallel_config( raise ValueError( "BitAndBytes quantization with TP or PP is not supported yet.") + def is_attention_free(self) -> bool: + """Returns True if the model has no attention, i.e. the model has no + state that grows with the size of the context. + """ + + # Return true if the model is mamba. + # This check should be augmented with more models in the future, + # and made more robust if possible. 
+ if hasattr(self.hf_text_config, + "model_type") and self.hf_text_config.model_type == 'mamba': + return True + return False + def get_hf_config_sliding_window(self) -> Optional[int]: """Get the sliding window size, or None if disabled.""" @@ -310,10 +323,8 @@ def get_head_size(self) -> int: # we need to pad head_size 192 to 256 return 256 - if hasattr(self.hf_text_config, "model_type" - ) and self.hf_text_config.model_type == 'mamba': - # Is this going to explode - return 0 + if self.is_attention_free(): + return 0 if hasattr(self.hf_text_config, "head_dim"): return self.hf_text_config.head_dim @@ -345,7 +356,8 @@ def get_total_num_kv_heads(self) -> int: if self.hf_config.model_type == "dbrx": return getattr(self.hf_config.attn_config, "kv_n_heads", self.hf_config.num_attention_heads) - if self.hf_config.model_type == "mamba": + + if self.is_attention_free(): return 0 attributes = [ @@ -398,8 +410,9 @@ def contains_seqlen_agnostic_layers( def get_layers_block_type(self, parallel_config: "ParallelConfig") -> List[str]: num_layers = self.get_num_layers(parallel_config) - - if self.hf_config.model_type == "mamba": + + if self.is_attention_free(): + assert (self.hf_config.model_type == "mamba") return ["mamba"] * num_layers # Transformers supports layers_block_type @property @@ -440,7 +453,7 @@ def __init__( gpu_memory_utilization: float, swap_space: int, cache_dtype: str, - cache_grows: bool, + is_attention_free: bool, num_gpu_blocks_override: Optional[int] = None, sliding_window: Optional[int] = None, enable_prefix_caching: bool = False, @@ -450,7 +463,7 @@ def __init__( self.swap_space_bytes = swap_space * _GB self.num_gpu_blocks_override = num_gpu_blocks_override self.cache_dtype = cache_dtype - self.cache_grows = cache_grows + self.is_attention_free = is_attention_free self.sliding_window = sliding_window self.enable_prefix_caching = enable_prefix_caching self._verify_args() @@ -745,6 +758,8 @@ class SchedulerConfig: iteration. max_model_len: Maximum length of a sequence (including prompt and generated text). + is_attention_free: True if the running model does not have state that + grows as the context size increases. use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not. num_lookahead_slots: The number of slots to allocate per sequence per step, beyond the known token ids. 
This is used in speculative @@ -767,6 +782,7 @@ def __init__(self, max_num_batched_tokens: Optional[int], max_num_seqs: int, max_model_len: int, + is_attention_free: bool, use_v2_block_manager: bool = False, num_lookahead_slots: int = 0, delay_factor: float = 0.0, @@ -791,11 +807,9 @@ def __init__(self, if enable_chunked_prefill: logger.info("Chunked prefill is enabled (EXPERIMENTAL).") - #TODO: already perfect - self.its_mamba = True - self.max_num_seqs = max_num_seqs self.max_model_len = max_model_len + self.is_attention_free = is_attention_free self.use_v2_block_manager = use_v2_block_manager self.num_lookahead_slots = num_lookahead_slots self.delay_factor = delay_factor diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index ae46f2f1a96ac..ef183cda9d6f3 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -279,9 +279,8 @@ def __init__( version = "v1" if self.scheduler_config.use_v2_block_manager: version = "v2" - if self.scheduler_config.embedding_mode: - version = "embedding" - if self.scheduler_config.its_mamba: + if (self.scheduler_config.embedding_mode + or self.scheduler_config.is_attention_free): version = "embedding" BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class( @@ -708,7 +707,6 @@ def _schedule_prefills( # If the sequence group cannot be allocated, stop. can_allocate = self.block_manager.can_allocate(seq_group) - can_allocate = True #TODO HACK TMS if can_allocate == AllocStatus.LATER: break elif can_allocate == AllocStatus.NEVER: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index cd6912466a3b3..20bfd71221e4d 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -664,7 +664,7 @@ def create_engine_config(self, ) -> EngineConfig: gpu_memory_utilization=self.gpu_memory_utilization, swap_space=self.swap_space, cache_dtype=self.kv_cache_dtype, - cache_grows=False, + is_attention_free=model_config.is_attention_free(), num_gpu_blocks_override=self.num_gpu_blocks_override, sliding_window=model_config.get_sliding_window(), enable_prefix_caching=self.enable_prefix_caching) @@ -709,6 +709,7 @@ def create_engine_config(self, ) -> EngineConfig: max_num_batched_tokens=self.max_num_batched_tokens, max_num_seqs=self.max_num_seqs, max_model_len=model_config.max_model_len, + is_attention_free=model_config.is_attention_free(), use_v2_block_manager=self.use_v2_block_manager, num_lookahead_slots=(self.num_lookahead_slots if speculative_config is None else diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index af6191b7b3527..c43f7fcb85484 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -260,11 +260,11 @@ def __init__( ) if not self.model_config.embedding_mode: - # TODO: Even for mamba, we must initialize the KV caches, - # Because model warmup and CUDA graphs are created here. + # For all decoders including attention-free models like mamba, + # this must call _initialize_kv_caches, as this is where model + # warmup and CUDA graphs creation happens. self._initialize_kv_caches() - # If usage stat is enabled, collect relevant info. 
if is_usage_stats_enabled(): from vllm.model_executor.model_loader import ( diff --git a/vllm/model_executor/models/2 b/vllm/model_executor/models/2 new file mode 100644 index 0000000000000..0452e9b3381b0 --- /dev/null +++ b/vllm/model_executor/models/2 @@ -0,0 +1,728 @@ +# coding=utf-8 +"""PyTorch MAMBA model.""" +import pdb +import traceback +import inspect + +from dataclasses import dataclass +from typing import Dict, Iterable, List, Optional, Tuple + +import torch +from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +from mamba_ssm.ops.selective_scan_interface import selective_scan_fn +from mamba_ssm.ops.triton.selective_state_update import selective_state_update +from torch import nn +from torch.nn.parameter import Parameter +from transformers import MambaConfig + +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.attention.layer import Attention +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs +from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.worker.model_runner import _BATCH_SIZES_TO_CAPTURE + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +def function_in_stack(function_name): + stack = traceback.extract_stack() + for frame in stack: + if frame.name == function_name: + return True + return False + +@dataclass +class MambaCacheParams: + is_prompt: bool = False + conv_state: torch.Tensor = torch.Tensor() + ssm_state: torch.Tensor = torch.Tensor() + +# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer +class MambaMixer(nn.Module): + """ + Compute ∆, A, B, C, and D the state space parameters and compute + the `contextualized_states`. A, D are input independent + (see Mamba paper [1] Section 3.5.2 "Interpretation of A" + for why A isn't selective) ∆, B, C are input-dependent + (this is a key difference between Mamba and the linear time + invariant S4, and is why Mamba is called + **selective** state spaces) + """ + + def __init__(self, config: MambaConfig, layer_idx): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.ssm_state_size = config.state_size + self.conv_kernel_size = config.conv_kernel + self.intermediate_size = config.intermediate_size + self.time_step_rank = int(config.time_step_rank) + self.use_conv_bias = config.use_conv_bias + + # TODO: ?? 
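+        # (The HF MambaConfig appears to expose this setting as `use_bias`,
+        #  which defaults to False; hard-coding False matches that default.)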
+ #self.use_bias = config.mamba_proj_bias + self.use_bias = False + + self.conv1d = ColumnParallelLinear( + input_size=self.conv_kernel_size, + output_size=self.intermediate_size, + bias=self.use_conv_bias, + ) + # unsqueeze to fit conv1d weights shape into the linear weights shape. + # Can't do this in `weight_loader` since it already exists in + # `ColumnParallelLinear` and `set_weight_attrs` + # doesn't allow to override it + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + + self.in_proj = MergedColumnParallelLinear(self.hidden_size, + [self.intermediate_size] * 2, + bias=self.use_bias) + # selective projection used to make dt, B and C input dependent + self.x_proj = RowParallelLinear( + self.intermediate_size, + self.time_step_rank + self.ssm_state_size * 2, + bias=False, + ) + # time step projection (discretization) - + # In the forward we need to apply dt_proj without the bias, + # as the bias is added in the selective scan kernel. + self.dt_proj = ColumnParallelLinear(self.time_step_rank, + self.intermediate_size, + bias=True, + skip_bias_add=True) + + def weight_loader(param: Parameter, loaded_weight: torch.Tensor): + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + param.data.copy_( + loaded_weight.data.split(loaded_weight.shape[0] // tp_size, + dim=0)[tp_rank]) + + def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): + weight_loader(param, -torch.exp(loaded_weight.float())) + + tp_size = get_tensor_model_parallel_world_size() + self.A = nn.Parameter( + torch.empty( + self.intermediate_size // tp_size, + self.ssm_state_size, + dtype=torch.float32, + )) + self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size)) + + set_weight_attrs(self.D, {"weight_loader": weight_loader}) + set_weight_attrs(self.A, {"weight_loader": A_weight_loader}) + + self.out_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=self.use_bias, + input_is_parallel=True, + ) + self.activation = config.hidden_act + + # Jamba has layer norms here. Mamba doesn't. + # TODO: Leaving these in for now, just as a placeholder in case mamba2 needs them. + #self.dt_layernorm = RMSNorm(self.time_step_rank, + # eps=config.layer_norm_epsilon) + #self.b_layernorm = RMSNorm(self.ssm_state_size, + # eps=config.layer_norm_epsilon) + #self.c_layernorm = RMSNorm(self.ssm_state_size, + # eps=config.layer_norm_epsilon) + + def mamba_forward(self, + hidden_states: torch.Tensor, + cache_params: MambaCacheParams = None): + + # 1. Gated MLP's linear projection + projected_states = self.in_proj(hidden_states)[0].transpose(1, 2) + hidden_states, gate = projected_states.chunk(2, dim=1) + + # 2. Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), + self.conv1d.weight.size(2)) + if cache_params is not None and not cache_params.is_prompt: + hidden_states = causal_conv1d_update( + hidden_states.squeeze(-1), + cache_params.conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + ) + hidden_states = hidden_states.unsqueeze(-1) + else: + if cache_params is not None: + conv_states = nn.functional.pad( + hidden_states, + (self.conv_kernel_size - hidden_states.shape[-1], 0)) + cache_params.conv_state.copy_(conv_states) + + hidden_states = causal_conv1d_fn( + hidden_states, + conv_weights, + self.conv1d.bias, + activation=self.activation, + ) + + # 3. State Space Model sequence transformation + # 3.a. 
input varying initialization of time_step, B and C + ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))[0] + + time_step, B, C = torch.split( + ssm_parameters, + [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], + dim=-1, + ) + + # Jamba has layer norms here. Mamba doesn't. + # TODO: Leaving these in for now, just as a placeholder in case mamba2 needs them. + # time_step = self.dt_layernorm(time_step.contiguous()) + # B = self.b_layernorm(B.contiguous()) + # C = self.c_layernorm(C.contiguous()) + + discrete_time_step = self.dt_proj(time_step)[0].transpose(1, 2) + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + time_proj_bias = (self.dt_proj.bias.float() if hasattr( + self.dt_proj, "bias") else None) + if cache_params is not None and not cache_params.is_prompt: + scan_outputs = selective_state_update( + cache_params.ssm_state, + hidden_states[..., 0], + discrete_time_step[..., 0], + self.A, + B[:, 0], + C[:, 0], + self.D, + gate[..., 0], + time_proj_bias, + dt_softplus=True, + ).unsqueeze(-1) + else: + scan_outputs, ssm_state = selective_scan_fn( + hidden_states, + discrete_time_step, + self.A, + B.transpose(1, 2), + C.transpose(1, 2), + self.D.float(), + gate, + time_proj_bias, + delta_softplus=True, + return_last_state=True, + ) + if ssm_state is not None and cache_params is not None: + cache_params.ssm_state.copy_(ssm_state) + + # 4. Final linear projection + contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))[0] + return contextualized_states + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + conv_state: torch.Tensor, + ssm_state: torch.Tensor, + ): + if attn_metadata.prefill_metadata is not None: + offset = 0 + for i, prompt_len in enumerate( + attn_metadata.prefill_metadata.seq_lens): + cache = MambaCacheParams(True, + conv_state=conv_state[i].unsqueeze(0), + ssm_state=ssm_state[i].unsqueeze(0)) + hidden_states[offset:offset + prompt_len].copy_( + self.mamba_forward(hidden_states[offset:offset + + prompt_len].unsqueeze(0), + cache_params=cache)[0]) + offset += prompt_len + else: + cache = MambaCacheParams(False, + conv_state=conv_state, + ssm_state=ssm_state) + hidden_states = self.mamba_forward(hidden_states.unsqueeze(1), + cache_params=cache) + hidden_states = hidden_states.squeeze(1) + + return hidden_states + +class MambaMLP(nn.Module): + + def __init__( + self, + config: MambaConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + hidden_size = config.hidden_size + intermediate_size = config.intermediate_size + hidden_act = config.hidden_act + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. 
" + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + +class MambaDecoderLayer(nn.Module): + + def __init__(self, + config: MambaConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None) -> None: + super().__init__() + self.layer_idx = layer_idx + self.config = config + self.mixer = MambaMixer(config, layer_idx) + + self.feed_forward = MambaMLP(config, quant_config=quant_config) + self.norm = RMSNorm(config.hidden_size, + eps=config.layer_norm_epsilon) + self.pre_ff_layernorm = RMSNorm(config.hidden_size, + eps=config.layer_norm_epsilon) + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + conv_state: torch.Tensor, + ssm_state: torch.Tensor, + **kwargs, + ): + if residual is None: + residual = hidden_states + hidden_states = self.norm(hidden_states) + else: + hidden_states, residual = self.norm( + hidden_states, residual) + + hidden_states = self.mixer(hidden_states, attn_metadata, conv_state, + ssm_state) + # Fully Connected + hidden_states, residual = self.pre_ff_layernorm( + hidden_states, residual) + hidden_states = self.feed_forward(hidden_states) + return hidden_states, residual + +class MambaModel(nn.Module): + + def __init__( + self, + config: MambaConfig, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embeddings = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + decoder_layers = [] + for i in range(config.num_hidden_layers): + decoder_layers.append( + MambaDecoderLayer(config, + layer_idx=i, + cache_config=cache_config, + quant_config=quant_config)) + self.layers = nn.ModuleList(decoder_layers) + self.norm_f = RMSNorm(config.hidden_size, + eps=config.layer_norm_epsilon) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + conv_state: torch.Tensor, + ssm_state: torch.Tensor, + ) -> torch.Tensor: + hidden_states = self.embeddings(input_ids) + residual = None + + for i in range(len(self.layers)): + layer = self.layers[i] + current_ssm_state = ssm_state[i] + current_conv_state = conv_state[i] + + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + attn_metadata=attn_metadata, + residual=residual, + conv_state=current_conv_state, + ssm_state=current_ssm_state, + ) + hidden_states, _ = self.norm_f(hidden_states, residual) + + + return hidden_states + +class MambaForCausalLM(nn.Module): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "embed_tokens", + "lm_head", + ] + embedding_modules = { + "embeddings": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__( + self, + config: MambaConfig, + cache_config: Optional[CacheConfig] 
= None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.backbone = MambaModel(config, + cache_config=cache_config, + quant_config=quant_config, + lora_config=lora_config) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + + #TODO: this ends up all 0s -- we don't put anything in here when loading weights. + #TODO: Does mamba share weights between the lm head and embeddings? +# self.lm_head = ParallelLMHead( +# self.unpadded_vocab_size, +# config.hidden_size, +# org_num_embeddings=config.vocab_size, +# padding_size=DEFAULT_VOCAB_PADDING_SIZE +# # We need bigger padding if using lora for kernel +# # compatibility +# if not lora_config else lora_config.lora_vocab_padding_size, +# ) + self.lm_head = self.backbone.embeddings + # Current step used indices + self.current_indices: List[int] = [] + # Used to track and store by the Mamba cache between steps. + self.mamba_cache: Tuple[torch.Tensor, torch.Tensor] = tuple() + # Used as an input_buffer for the CUDA graph runs. + self.mamba_gc_cache_buffer: Tuple[torch.Tensor, torch.Tensor] = tuple() + # Maps between the request id and a dict that maps between the seq_id + # and its index inside the self.mamba_cache + self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {} + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = Sampler() + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs): + + if not self.mamba_cache: + self._prepare_mamba_cache() + + if "seqlen_agnostic_capture_inputs" not in kwargs: + # We get here only on Prefill/Eager mode runs + assert all( + key in kwargs + for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) + + request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] + batch_size = input_ids.shape[0] + if attn_metadata.prefill_metadata: + batch_size = len(request_ids_to_seq_ids) + ( + current_seqlen_agnostic_cache, + indices, + ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids, + batch_size) + finished_requests_ids = kwargs["finished_requests_ids"] + self._release_mamba_cache(finished_requests_ids) + else: + # CUDA graph capturing runs + current_seqlen_agnostic_cache, indices = ( + kwargs["seqlen_agnostic_capture_inputs"], + [], + ) + self.current_indices = indices + + hidden_states = self.backbone(input_ids, positions, kv_caches, + attn_metadata, + current_seqlen_agnostic_cache[0], + current_seqlen_agnostic_cache[1]) + + if "seqlen_agnostic_capture_inputs" not in kwargs: + self._copy_mamba_cache_by_indices(self.current_indices, + current_seqlen_agnostic_cache) + + return hidden_states + + def _copy_mamba_cache_by_indices( + self, indices: List[int], + current_seqlen_agnostic_cache: Tuple[torch.Tensor, torch.Tensor]): + for i, offset in enumerate(indices): + self._copy_mamba_cache(offset, i, current_seqlen_agnostic_cache) + + def _copy_mamba_cache(self, index_to: int, index_from: int, + from_buffer: Tuple[torch.Tensor, torch.Tensor]): + assert len(self.mamba_cache) > 0 + for (cache_t, from_buffer_t) in zip(self.mamba_cache, from_buffer): + cache_t[:, index_to].copy_(from_buffer_t[:, index_from], + non_blocking=True) + + def _assign_seq_id_to_mamba_cache(self, cur_rid: str, + seqs_id: List[int]) -> 
List[int]: + indices_for_current_run = [] + for seq_id in seqs_id: + if cur_rid not in self.mamba_cache_indices_mapping: + self.mamba_cache_indices_mapping[cur_rid] = {} + first_free_index = self._first_free_index_in_mamba_cache() + self.mamba_cache_indices_mapping[cur_rid][ + seq_id] = first_free_index + index_for_current_run = first_free_index + ## case of decoding n>1, copy prefill cache to decoding indices + elif seq_id not in (seq_ids2indices := + self.mamba_cache_indices_mapping[cur_rid]): + first_free_index = self._first_free_index_in_mamba_cache() + index_exist = list(seq_ids2indices.values())[0] + self._copy_mamba_cache(index_from=index_exist, + index_to=first_free_index, + from_buffer=self.mamba_cache) + self.mamba_cache_indices_mapping[cur_rid][ + seq_id] = first_free_index + index_for_current_run = first_free_index + else: + index_for_current_run = self.mamba_cache_indices_mapping[ + cur_rid][seq_id] + + indices_for_current_run.append(index_for_current_run) + return indices_for_current_run + + def _prepare_current_run_mamba_cache( + self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], List[int]]: + indices_for_current_run = [] + for request_id, seqs_id in request_ids_to_seq_ids.items(): + indices_for_current_run += self._assign_seq_id_to_mamba_cache( + request_id, seqs_id) + ## Pad the batch in case of running batch that was not captured via CG + padded_indices = indices_for_current_run.copy() + pad_index = self._first_free_index_in_mamba_cache() + + for _ in range(batch_size - len(indices_for_current_run)): + padded_indices.append(pad_index) + + conv_state = self.mamba_cache[0][:, padded_indices] + temporal_state = self.mamba_cache[1][:, padded_indices] + + return (conv_state, temporal_state), indices_for_current_run + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + """ + Copy the relevant Mamba cache into the CUDA graph input buffer + that was provided during the capture runs + (MambaForCausalLM.mamba_gc_cache_buffer). + """ + assert all( + key in kwargs + for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) + request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] + batch_size = len(request_ids_to_seq_ids) + ( + current_mamba_cache, + indices, + ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids, + batch_size) + self.current_indices = indices + finished_requests_ids = kwargs["finished_requests_ids"] + self._release_mamba_cache(finished_requests_ids) + + for input_buffer, current_cache_buffer in zip( + input_buffers["seqlen_agnostic_capture_inputs"], + current_mamba_cache): + input_buffer.copy_(current_cache_buffer, non_blocking=True) + + def copy_outputs_after_cuda_graphs(self, input_buffers, **kwargs): + """ + Copy the relevant Mamba cache from the CUDA graph input_buffers + back to the MambaForCausalLM.mamba_cache after CUDA + graph replay run is done. + """ + self._copy_mamba_cache_by_indices( + self.current_indices, + input_buffers["seqlen_agnostic_capture_inputs"]) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + """ + Provide the CUDA graph capture runs with a buffer in adjusted size. + The buffer is used to maintain the Mamba Cache during the CUDA graph + replay runs. 
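+        For example, each buffer is allocated with shape
+        (num_mamba_layers, max_batch_size, ...), so slicing with
+        [:, :batch_size] returns views sized to the batch being captured.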
+ """ + return tuple(buffer[:, :batch_size] + for buffer in self.mamba_gc_cache_buffer) + + def _release_mamba_cache(self, finished_seq_groups_req_ids: List[str]): + for req_id in finished_seq_groups_req_ids: + if req_id in self.mamba_cache_indices_mapping: + self.mamba_cache_indices_mapping.pop(req_id) + + def _first_free_index_in_mamba_cache(self) -> int: + if self.mamba_cache: + max_possible_batch_size = self.mamba_cache[0].shape[1] + occupied = [ + id for seq_ids in self.mamba_cache_indices_mapping.values() + for id in seq_ids.values() + ] + first_free_index = [ + i not in occupied for i in range(max_possible_batch_size) + ].index(True) + return first_free_index + return 0 + + def _get_mamba_cache_shape( + self + ) -> Tuple[Optional[Tuple[int, int]], Optional[Tuple[int, int]]]: + world_size = get_tensor_model_parallel_world_size() + hidden_size = self.config.hidden_size + conv_state_shape = ( + self.config.intermediate_size // world_size, + self.config.conv_kernel, + ) + temporal_state_shape = ( + self.config.intermediate_size // world_size, + self.config.state_size, + ) + return conv_state_shape, temporal_state_shape + + def _prepare_mamba_cache(self): + dtype = self.lm_head.weight.dtype + num_mamba_layers = self.config.num_hidden_layers + max_batch_size = _BATCH_SIZES_TO_CAPTURE[-1] + 10 + conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape() + assert conv_state_shape is not None and temporal_state_shape is not None + for buffername in ["mamba_cache", "mamba_gc_cache_buffer"]: + buffer = (torch.empty(size=(num_mamba_layers, max_batch_size) + + conv_state_shape, + dtype=dtype, + device="cuda"), + torch.empty(size=(num_mamba_layers, max_batch_size) + + temporal_state_shape, + dtype=dtype, + device="cuda")) + setattr(self, buffername, buffer) + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + for k, v in params_dict.items(): + print(k) + + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + if "A_log" in name: + name = name.replace("A_log", "A") + + if ".self_attn." in name: + name = name.replace(".self_attn", "") + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 130f00e8645d1..7ce2b6fa5c91b 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -75,7 +75,6 @@ _EMBEDDING_MODELS = { "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"), } - _MODELS = {**_GENERATION_MODELS, **_EMBEDDING_MODELS} # Architecture -> type. diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index fa65f081bca89..ca8d58fd3d6aa 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -12,25 +12,20 @@ from transformers import MambaConfig from vllm.attention.backends.abstract import AttentionMetadata -from vllm.attention.layer import Attention from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) + get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs @@ -46,6 +41,7 @@ class MambaCacheParams: conv_state: torch.Tensor = torch.Tensor() ssm_state: torch.Tensor = torch.Tensor() + # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer class MambaMixer(nn.Module): """ @@ -131,15 +127,6 @@ def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): ) self.activation = config.hidden_act - # Jamba has layer norms here. Mamba doesn't. - # TODO: Leaving these in for now, just as a placeholder in case mamba2 needs them. - #self.dt_layernorm = RMSNorm(self.time_step_rank, - # eps=config.layer_norm_epsilon) - #self.b_layernorm = RMSNorm(self.ssm_state_size, - # eps=config.layer_norm_epsilon) - #self.c_layernorm = RMSNorm(self.ssm_state_size, - # eps=config.layer_norm_epsilon) - def mamba_forward(self, hidden_states: torch.Tensor, cache_params: MambaCacheParams = None): @@ -184,11 +171,7 @@ def mamba_forward(self, dim=-1, ) - # Jamba has layer norms here. Mamba doesn't. - # TODO: Leaving these in for now, just as a placeholder in case mamba2 needs them. - # time_step = self.dt_layernorm(time_step.contiguous()) - # B = self.b_layernorm(B.contiguous()) - # C = self.c_layernorm(C.contiguous()) + # Note that Jamba normalizes B, C, and time_step here but Mamba doesn't. 
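+        # 3.b. discretize time_step via dt_proj (its bias is deliberately
+        #      skipped here and applied later, inside the selective scan kernel)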
discrete_time_step = self.dt_proj(time_step)[0].transpose(1, 2) # 3.c perform the recurrence y ← SSM(A, B, C)(x) @@ -256,6 +239,7 @@ def forward( return hidden_states + class MambaMLP(nn.Module): def __init__( @@ -286,6 +270,7 @@ def forward(self, x): x, _ = self.down_proj(x) return x + class MambaDecoderLayer(nn.Module): def __init__(self, @@ -299,8 +284,7 @@ def __init__(self, self.mixer = MambaMixer(config, layer_idx) self.feed_forward = MambaMLP(config, quant_config=quant_config) - self.norm = RMSNorm(config.hidden_size, - eps=config.layer_norm_epsilon) + self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) @@ -317,8 +301,7 @@ def forward( residual = hidden_states hidden_states = self.norm(hidden_states) else: - hidden_states, residual = self.norm( - hidden_states, residual) + hidden_states, residual = self.norm(hidden_states, residual) hidden_states = self.mixer(hidden_states, attn_metadata, conv_state, ssm_state) @@ -328,6 +311,7 @@ def forward( hidden_states = self.feed_forward(hidden_states) return hidden_states, residual + class MambaModel(nn.Module): def __init__( @@ -355,12 +339,12 @@ def __init__( for i in range(config.num_hidden_layers): decoder_layers.append( MambaDecoderLayer(config, - layer_idx=i, - cache_config=cache_config, - quant_config=quant_config)) + layer_idx=i, + cache_config=cache_config, + quant_config=quant_config)) self.layers = nn.ModuleList(decoder_layers) self.norm_f = RMSNorm(config.hidden_size, - eps=config.layer_norm_epsilon) + eps=config.layer_norm_epsilon) def forward( self, @@ -389,9 +373,9 @@ def forward( ) hidden_states, _ = self.norm_f(hidden_states, residual) - return hidden_states + class MambaForCausalLM(nn.Module): packed_modules_mapping = { "qkv_proj": [ @@ -424,24 +408,13 @@ def __init__( super().__init__() self.config = config self.backbone = MambaModel(config, - cache_config=cache_config, - quant_config=quant_config, - lora_config=lora_config) + cache_config=cache_config, + quant_config=quant_config, + lora_config=lora_config) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size - #TODO: this ends up all 0s -- we don't put anything in here when loading weights. - #TODO: Does mamba share weights between the lm head and embeddings? -# self.lm_head = ParallelLMHead( -# self.unpadded_vocab_size, -# config.hidden_size, -# org_num_embeddings=config.vocab_size, -# padding_size=DEFAULT_VOCAB_PADDING_SIZE -# # We need bigger padding if using lora for kernel -# # compatibility -# if not lora_config else lora_config.lora_vocab_padding_size, -# ) self.lm_head = self.backbone.embeddings # Current step used indices self.current_indices: List[int] = [] @@ -493,9 +466,9 @@ def forward(self, self.current_indices = indices hidden_states = self.backbone(input_ids, positions, kv_caches, - attn_metadata, - current_seqlen_agnostic_cache[0], - current_seqlen_agnostic_cache[1]) + attn_metadata, + current_seqlen_agnostic_cache[0], + current_seqlen_agnostic_cache[1]) if "seqlen_agnostic_capture_inputs" not in kwargs: self._copy_mamba_cache_by_indices(self.current_indices, @@ -565,9 +538,9 @@ def _prepare_current_run_mamba_cache( def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): """ - Copy the relevant Mamba cache into the CUDA graph input buffer - that was provided during the capture runs - (MambaForCausalLM.mamba_gc_cache_buffer). 
+ Copy the relevant Mamba cache into the CUDA graph input buffer + that was provided during the capture runs + (MambaForCausalLM.mamba_gc_cache_buffer). """ assert all( key in kwargs @@ -591,7 +564,7 @@ def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): def copy_outputs_after_cuda_graphs(self, input_buffers, **kwargs): """ Copy the relevant Mamba cache from the CUDA graph input_buffers - back to the MambaForCausalLM.mamba_cache after CUDA + back to the MambaForCausalLM.mamba_cache after CUDA graph replay run is done. """ self._copy_mamba_cache_by_indices( @@ -601,7 +574,7 @@ def copy_outputs_after_cuda_graphs(self, input_buffers, **kwargs): def get_seqlen_agnostic_capture_inputs(self, batch_size: int): """ Provide the CUDA graph capture runs with a buffer in adjusted size. - The buffer is used to maintain the Mamba Cache during the CUDA graph + The buffer is used to maintain the Mamba Cache during the CUDA graph replay runs. """ return tuple(buffer[:, :batch_size] @@ -629,7 +602,6 @@ def _get_mamba_cache_shape( self ) -> Tuple[Optional[Tuple[int, int]], Optional[Tuple[int, int]]]: world_size = get_tensor_model_parallel_world_size() - hidden_size = self.config.hidden_size conv_state_shape = ( self.config.intermediate_size // world_size, self.config.conv_kernel, @@ -682,9 +654,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ] params_dict = dict(self.named_parameters()) - for k, v in params_dict.items(): - print(k) - for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue diff --git a/vllm/sequence.py b/vllm/sequence.py index 6753d7f86b639..1cebf68d463db 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -684,7 +684,7 @@ def __init__( self.cross_block_table = cross_block_table self._token_chunk_size = token_chunk_size self.do_sample = do_sample - + # The number of speculative tokens adopted in this request. # None means specuative decoding is not used. # Zero means speculative decoding is disabled for some reasons. diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 56c2693d661ba..459798e418c30 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -23,7 +23,7 @@ FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 from vllm.attention import AttentionMetadata, get_attn_backend -from vllm.attention.backends.no_attention import NoAttentionBackend +from vllm.attention.backends.placeholder_attn import PlaceholderAttentionBackend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, @@ -236,7 +236,7 @@ def __init__( self.model_config.dtype, self.kv_cache_dtype, self.block_size, - ) if num_attn_heads else NoAttentionBackend() + ) if num_attn_heads else PlaceholderAttentionBackend() # Multi-modal data support self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \ @@ -1016,7 +1016,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: "You can also reduce the `max_num_seqs` as needed " "to decrease memory usage.") start_time = time.perf_counter() - + # Prepare dummy inputs. These will be reused for all batch sizes. 
max_batch_size = max(_BATCH_SIZES_TO_CAPTURE) input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda() @@ -1509,8 +1509,8 @@ def forward( self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True) self.input_buffers["positions"].copy_(positions, non_blocking=True) if self.backend_name != "No attention": - self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping, - non_blocking=True) + self.input_buffers["slot_mapping"].copy_( + attn_metadata.slot_mapping, non_blocking=True) if self.backend_name != "flashinfer": self.input_buffers["seq_lens_tensor"].copy_( attn_metadata.decode_metadata.seq_lens_tensor, diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 60f80b135e322..f80b8be89a8f3 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -215,12 +215,12 @@ def initialize_cache(self, num_gpu_blocks: int, """ raise_if_cache_size_invalid(num_gpu_blocks, self.cache_config.block_size, - self.cache_config.cache_grows, + self.cache_config.is_attention_free, self.model_config.max_model_len) self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks - + self._init_cache_engine() self._warm_up_model() @@ -366,14 +366,18 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): "`dtype` flag in CLI, for example: --dtype=half.") -def raise_if_cache_size_invalid(num_gpu_blocks, block_size, cache_grows, +def raise_if_cache_size_invalid(num_gpu_blocks, block_size, is_attention_free, max_model_len) -> None: - if num_gpu_blocks <= 0 and cache_grows: + if is_attention_free and num_gpu_blocks != 0: + raise ValueError("No memory should be allocated for the cache blocks " + f"for an attention-free model, but {num_gpu_blocks}" + "blocks are allocated.") + if not is_attention_free and num_gpu_blocks <= 0: raise ValueError("No available memory for the cache blocks. 
" "Try increasing `gpu_memory_utilization` when " "initializing the engine.") max_seq_len = block_size * num_gpu_blocks - if max_model_len > max_seq_len and cache_grows: + if not is_attention_free and max_model_len > max_seq_len: raise ValueError( f"The model's max seq len ({max_model_len}) " "is larger than the maximum number of tokens that can be " From ebc12f1ee699f8d0b83dd5f75e402447729180fa Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 16 Jul 2024 17:15:23 +0000 Subject: [PATCH 07/40] Rename embedding block space manager --- vllm/core/interfaces.py | 8 ++++---- ...lock_manager.py => placeholder_block_space_manager.py} | 7 ++++--- vllm/core/scheduler.py | 2 +- 3 files changed, 9 insertions(+), 8 deletions(-) rename vllm/core/{embedding_model_block_manager.py => placeholder_block_space_manager.py} (90%) diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index 8759ee06795b8..d964898af19bc 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -35,10 +35,10 @@ def get_block_space_manager_class(version: str): from vllm.core.block_manager_v2 import BlockSpaceManagerV2 return BlockSpaceManagerV2 - if version == "embedding": - from vllm.core.embedding_model_block_manager import ( - EmbeddingModelBlockSpaceManager) - return EmbeddingModelBlockSpaceManager + if version == "placeholder": + from vllm.core.placeholder_block_space_manager import ( + PlaceholderBlockSpaceManager) + return PlaceholderBlockSpaceManager raise ValueError(f"Unknown version {version=}") diff --git a/vllm/core/embedding_model_block_manager.py b/vllm/core/placeholder_block_space_manager.py similarity index 90% rename from vllm/core/embedding_model_block_manager.py rename to vllm/core/placeholder_block_space_manager.py index 43a9f9de6767c..a71e6f79b6d2f 100644 --- a/vllm/core/embedding_model_block_manager.py +++ b/vllm/core/placeholder_block_space_manager.py @@ -4,9 +4,10 @@ from vllm.sequence import Sequence, SequenceGroup -class EmbeddingModelBlockSpaceManager(BlockSpaceManager): - """An embedding version of BlockSpaceManager for use in environments - with embedding models where block management is not required. +class PlaceholderBlockSpaceManager(BlockSpaceManager): + """A version of BlockSpaceManager for use in environments + where block management is not required. + For example: embedding models or attention-free models like Mamba. This class provides the same interface as BlockSpaceManager, but its methods perform no actions or return simple values like True in specific diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index ef183cda9d6f3..f004df21169bb 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -281,7 +281,7 @@ def __init__( version = "v2" if (self.scheduler_config.embedding_mode or self.scheduler_config.is_attention_free): - version = "embedding" + version = "placeholder" BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class( version) From ac60374b8637ea2ed00ebaa159c06979e38ada44 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 16 Jul 2024 17:30:29 +0000 Subject: [PATCH 08/40] cleanup --- examples/offline_inference.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/examples/offline_inference.py b/examples/offline_inference.py index cb74561f35e3e..9b758fa2479f6 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -1,5 +1,4 @@ from vllm import LLM, SamplingParams -import torch # Sample prompts. 
prompts = [ @@ -9,12 +8,10 @@ "The future of AI is", ] # Create a sampling params object. -sampling_params = SamplingParams(temperature=0.0, top_p=0.95) +sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create an LLM. -#llm = LLM(model="state-spaces/mamba-370m-hf", dtype=torch.float32) -llm = LLM(model="state-spaces/mamba2-130m", dtype=torch.float32) - +llm = LLM(model="facebook/opt-125m") # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) From adb6713830e1f5c252f6de71d7173a55197dfe1d Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 16 Jul 2024 17:31:51 +0000 Subject: [PATCH 09/40] remove file --- vllm/model_executor/models/2 | 728 ----------------------------------- 1 file changed, 728 deletions(-) delete mode 100644 vllm/model_executor/models/2 diff --git a/vllm/model_executor/models/2 b/vllm/model_executor/models/2 deleted file mode 100644 index 0452e9b3381b0..0000000000000 --- a/vllm/model_executor/models/2 +++ /dev/null @@ -1,728 +0,0 @@ -# coding=utf-8 -"""PyTorch MAMBA model.""" -import pdb -import traceback -import inspect - -from dataclasses import dataclass -from typing import Dict, Iterable, List, Optional, Tuple - -import torch -from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -from mamba_ssm.ops.selective_scan_interface import selective_scan_fn -from mamba_ssm.ops.triton.selective_state_update import selective_state_update -from torch import nn -from torch.nn.parameter import Parameter -from transformers import MambaConfig - -from vllm.attention.backends.abstract import AttentionMetadata -from vllm.attention.layer import Attention -from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import fused_moe -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) -from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.utils import set_weight_attrs -from vllm.sequence import IntermediateTensors, SamplerOutput -from vllm.worker.model_runner import _BATCH_SIZES_TO_CAPTURE - -KVCache = Tuple[torch.Tensor, torch.Tensor] - - -def function_in_stack(function_name): - stack = traceback.extract_stack() - for frame in stack: - if frame.name == function_name: - return True - return False - -@dataclass -class MambaCacheParams: - is_prompt: bool = False - conv_state: torch.Tensor = torch.Tensor() - ssm_state: torch.Tensor = torch.Tensor() - -# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer -class MambaMixer(nn.Module): - """ - Compute ∆, A, B, C, and D the state space parameters and compute - the `contextualized_states`. 
A, D are input independent - (see Mamba paper [1] Section 3.5.2 "Interpretation of A" - for why A isn't selective) ∆, B, C are input-dependent - (this is a key difference between Mamba and the linear time - invariant S4, and is why Mamba is called - **selective** state spaces) - """ - - def __init__(self, config: MambaConfig, layer_idx): - super().__init__() - self.config = config - self.layer_idx = layer_idx - self.hidden_size = config.hidden_size - self.ssm_state_size = config.state_size - self.conv_kernel_size = config.conv_kernel - self.intermediate_size = config.intermediate_size - self.time_step_rank = int(config.time_step_rank) - self.use_conv_bias = config.use_conv_bias - - # TODO: ?? - #self.use_bias = config.mamba_proj_bias - self.use_bias = False - - self.conv1d = ColumnParallelLinear( - input_size=self.conv_kernel_size, - output_size=self.intermediate_size, - bias=self.use_conv_bias, - ) - # unsqueeze to fit conv1d weights shape into the linear weights shape. - # Can't do this in `weight_loader` since it already exists in - # `ColumnParallelLinear` and `set_weight_attrs` - # doesn't allow to override it - self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) - - self.in_proj = MergedColumnParallelLinear(self.hidden_size, - [self.intermediate_size] * 2, - bias=self.use_bias) - # selective projection used to make dt, B and C input dependent - self.x_proj = RowParallelLinear( - self.intermediate_size, - self.time_step_rank + self.ssm_state_size * 2, - bias=False, - ) - # time step projection (discretization) - - # In the forward we need to apply dt_proj without the bias, - # as the bias is added in the selective scan kernel. - self.dt_proj = ColumnParallelLinear(self.time_step_rank, - self.intermediate_size, - bias=True, - skip_bias_add=True) - - def weight_loader(param: Parameter, loaded_weight: torch.Tensor): - tp_rank = get_tensor_model_parallel_rank() - tp_size = get_tensor_model_parallel_world_size() - param.data.copy_( - loaded_weight.data.split(loaded_weight.shape[0] // tp_size, - dim=0)[tp_rank]) - - def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): - weight_loader(param, -torch.exp(loaded_weight.float())) - - tp_size = get_tensor_model_parallel_world_size() - self.A = nn.Parameter( - torch.empty( - self.intermediate_size // tp_size, - self.ssm_state_size, - dtype=torch.float32, - )) - self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size)) - - set_weight_attrs(self.D, {"weight_loader": weight_loader}) - set_weight_attrs(self.A, {"weight_loader": A_weight_loader}) - - self.out_proj = RowParallelLinear( - self.intermediate_size, - self.hidden_size, - bias=self.use_bias, - input_is_parallel=True, - ) - self.activation = config.hidden_act - - # Jamba has layer norms here. Mamba doesn't. - # TODO: Leaving these in for now, just as a placeholder in case mamba2 needs them. - #self.dt_layernorm = RMSNorm(self.time_step_rank, - # eps=config.layer_norm_epsilon) - #self.b_layernorm = RMSNorm(self.ssm_state_size, - # eps=config.layer_norm_epsilon) - #self.c_layernorm = RMSNorm(self.ssm_state_size, - # eps=config.layer_norm_epsilon) - - def mamba_forward(self, - hidden_states: torch.Tensor, - cache_params: MambaCacheParams = None): - - # 1. Gated MLP's linear projection - projected_states = self.in_proj(hidden_states)[0].transpose(1, 2) - hidden_states, gate = projected_states.chunk(2, dim=1) - - # 2. 
Convolution sequence transformation - conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), - self.conv1d.weight.size(2)) - if cache_params is not None and not cache_params.is_prompt: - hidden_states = causal_conv1d_update( - hidden_states.squeeze(-1), - cache_params.conv_state, - conv_weights, - self.conv1d.bias, - self.activation, - ) - hidden_states = hidden_states.unsqueeze(-1) - else: - if cache_params is not None: - conv_states = nn.functional.pad( - hidden_states, - (self.conv_kernel_size - hidden_states.shape[-1], 0)) - cache_params.conv_state.copy_(conv_states) - - hidden_states = causal_conv1d_fn( - hidden_states, - conv_weights, - self.conv1d.bias, - activation=self.activation, - ) - - # 3. State Space Model sequence transformation - # 3.a. input varying initialization of time_step, B and C - ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))[0] - - time_step, B, C = torch.split( - ssm_parameters, - [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], - dim=-1, - ) - - # Jamba has layer norms here. Mamba doesn't. - # TODO: Leaving these in for now, just as a placeholder in case mamba2 needs them. - # time_step = self.dt_layernorm(time_step.contiguous()) - # B = self.b_layernorm(B.contiguous()) - # C = self.c_layernorm(C.contiguous()) - - discrete_time_step = self.dt_proj(time_step)[0].transpose(1, 2) - # 3.c perform the recurrence y ← SSM(A, B, C)(x) - time_proj_bias = (self.dt_proj.bias.float() if hasattr( - self.dt_proj, "bias") else None) - if cache_params is not None and not cache_params.is_prompt: - scan_outputs = selective_state_update( - cache_params.ssm_state, - hidden_states[..., 0], - discrete_time_step[..., 0], - self.A, - B[:, 0], - C[:, 0], - self.D, - gate[..., 0], - time_proj_bias, - dt_softplus=True, - ).unsqueeze(-1) - else: - scan_outputs, ssm_state = selective_scan_fn( - hidden_states, - discrete_time_step, - self.A, - B.transpose(1, 2), - C.transpose(1, 2), - self.D.float(), - gate, - time_proj_bias, - delta_softplus=True, - return_last_state=True, - ) - if ssm_state is not None and cache_params is not None: - cache_params.ssm_state.copy_(ssm_state) - - # 4. 
Final linear projection - contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))[0] - return contextualized_states - - def forward( - self, - hidden_states: torch.Tensor, - attn_metadata: AttentionMetadata, - conv_state: torch.Tensor, - ssm_state: torch.Tensor, - ): - if attn_metadata.prefill_metadata is not None: - offset = 0 - for i, prompt_len in enumerate( - attn_metadata.prefill_metadata.seq_lens): - cache = MambaCacheParams(True, - conv_state=conv_state[i].unsqueeze(0), - ssm_state=ssm_state[i].unsqueeze(0)) - hidden_states[offset:offset + prompt_len].copy_( - self.mamba_forward(hidden_states[offset:offset + - prompt_len].unsqueeze(0), - cache_params=cache)[0]) - offset += prompt_len - else: - cache = MambaCacheParams(False, - conv_state=conv_state, - ssm_state=ssm_state) - hidden_states = self.mamba_forward(hidden_states.unsqueeze(1), - cache_params=cache) - hidden_states = hidden_states.squeeze(1) - - return hidden_states - -class MambaMLP(nn.Module): - - def __init__( - self, - config: MambaConfig, - quant_config: Optional[QuantizationConfig] = None, - ) -> None: - super().__init__() - hidden_size = config.hidden_size - intermediate_size = config.intermediate_size - hidden_act = config.hidden_act - self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, - bias=False, - quant_config=quant_config) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config) - if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") - self.act_fn = SiluAndMul() - - def forward(self, x): - gate_up, _ = self.gate_up_proj(x) - x = self.act_fn(gate_up) - x, _ = self.down_proj(x) - return x - -class MambaDecoderLayer(nn.Module): - - def __init__(self, - config: MambaConfig, - layer_idx: int, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None) -> None: - super().__init__() - self.layer_idx = layer_idx - self.config = config - self.mixer = MambaMixer(config, layer_idx) - - self.feed_forward = MambaMLP(config, quant_config=quant_config) - self.norm = RMSNorm(config.hidden_size, - eps=config.layer_norm_epsilon) - self.pre_ff_layernorm = RMSNorm(config.hidden_size, - eps=config.layer_norm_epsilon) - - def forward( - self, - hidden_states: torch.Tensor, - attn_metadata: AttentionMetadata, - residual: Optional[torch.Tensor], - conv_state: torch.Tensor, - ssm_state: torch.Tensor, - **kwargs, - ): - if residual is None: - residual = hidden_states - hidden_states = self.norm(hidden_states) - else: - hidden_states, residual = self.norm( - hidden_states, residual) - - hidden_states = self.mixer(hidden_states, attn_metadata, conv_state, - ssm_state) - # Fully Connected - hidden_states, residual = self.pre_ff_layernorm( - hidden_states, residual) - hidden_states = self.feed_forward(hidden_states) - return hidden_states, residual - -class MambaModel(nn.Module): - - def __init__( - self, - config: MambaConfig, - quant_config: Optional[QuantizationConfig] = None, - cache_config: Optional[CacheConfig] = None, - lora_config: Optional[LoRAConfig] = None, - ) -> None: - super().__init__() - self.config = config - self.padding_idx = config.pad_token_id - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size - - self.embeddings = VocabParallelEmbedding( - self.vocab_size, - 
config.hidden_size, - org_num_embeddings=config.vocab_size, - ) - - decoder_layers = [] - for i in range(config.num_hidden_layers): - decoder_layers.append( - MambaDecoderLayer(config, - layer_idx=i, - cache_config=cache_config, - quant_config=quant_config)) - self.layers = nn.ModuleList(decoder_layers) - self.norm_f = RMSNorm(config.hidden_size, - eps=config.layer_norm_epsilon) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - conv_state: torch.Tensor, - ssm_state: torch.Tensor, - ) -> torch.Tensor: - hidden_states = self.embeddings(input_ids) - residual = None - - for i in range(len(self.layers)): - layer = self.layers[i] - current_ssm_state = ssm_state[i] - current_conv_state = conv_state[i] - - hidden_states, residual = layer( - positions=positions, - hidden_states=hidden_states, - attn_metadata=attn_metadata, - residual=residual, - conv_state=current_conv_state, - ssm_state=current_ssm_state, - ) - hidden_states, _ = self.norm_f(hidden_states, residual) - - - return hidden_states - -class MambaForCausalLM(nn.Module): - packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - } - - # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", - "o_proj", - "embed_tokens", - "lm_head", - ] - embedding_modules = { - "embeddings": "input_embeddings", - "lm_head": "output_embeddings", - } - embedding_padding_modules = ["lm_head"] - - def __init__( - self, - config: MambaConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, - ) -> None: - super().__init__() - self.config = config - self.backbone = MambaModel(config, - cache_config=cache_config, - quant_config=quant_config, - lora_config=lora_config) - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size - - #TODO: this ends up all 0s -- we don't put anything in here when loading weights. - #TODO: Does mamba share weights between the lm head and embeddings? -# self.lm_head = ParallelLMHead( -# self.unpadded_vocab_size, -# config.hidden_size, -# org_num_embeddings=config.vocab_size, -# padding_size=DEFAULT_VOCAB_PADDING_SIZE -# # We need bigger padding if using lora for kernel -# # compatibility -# if not lora_config else lora_config.lora_vocab_padding_size, -# ) - self.lm_head = self.backbone.embeddings - # Current step used indices - self.current_indices: List[int] = [] - # Used to track and store by the Mamba cache between steps. - self.mamba_cache: Tuple[torch.Tensor, torch.Tensor] = tuple() - # Used as an input_buffer for the CUDA graph runs. 
- self.mamba_gc_cache_buffer: Tuple[torch.Tensor, torch.Tensor] = tuple() - # Maps between the request id and a dict that maps between the seq_id - # and its index inside the self.mamba_cache - self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {} - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size) - self.sampler = Sampler() - - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[KVCache], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, - **kwargs): - - if not self.mamba_cache: - self._prepare_mamba_cache() - - if "seqlen_agnostic_capture_inputs" not in kwargs: - # We get here only on Prefill/Eager mode runs - assert all( - key in kwargs - for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) - - request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] - batch_size = input_ids.shape[0] - if attn_metadata.prefill_metadata: - batch_size = len(request_ids_to_seq_ids) - ( - current_seqlen_agnostic_cache, - indices, - ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids, - batch_size) - finished_requests_ids = kwargs["finished_requests_ids"] - self._release_mamba_cache(finished_requests_ids) - else: - # CUDA graph capturing runs - current_seqlen_agnostic_cache, indices = ( - kwargs["seqlen_agnostic_capture_inputs"], - [], - ) - self.current_indices = indices - - hidden_states = self.backbone(input_ids, positions, kv_caches, - attn_metadata, - current_seqlen_agnostic_cache[0], - current_seqlen_agnostic_cache[1]) - - if "seqlen_agnostic_capture_inputs" not in kwargs: - self._copy_mamba_cache_by_indices(self.current_indices, - current_seqlen_agnostic_cache) - - return hidden_states - - def _copy_mamba_cache_by_indices( - self, indices: List[int], - current_seqlen_agnostic_cache: Tuple[torch.Tensor, torch.Tensor]): - for i, offset in enumerate(indices): - self._copy_mamba_cache(offset, i, current_seqlen_agnostic_cache) - - def _copy_mamba_cache(self, index_to: int, index_from: int, - from_buffer: Tuple[torch.Tensor, torch.Tensor]): - assert len(self.mamba_cache) > 0 - for (cache_t, from_buffer_t) in zip(self.mamba_cache, from_buffer): - cache_t[:, index_to].copy_(from_buffer_t[:, index_from], - non_blocking=True) - - def _assign_seq_id_to_mamba_cache(self, cur_rid: str, - seqs_id: List[int]) -> List[int]: - indices_for_current_run = [] - for seq_id in seqs_id: - if cur_rid not in self.mamba_cache_indices_mapping: - self.mamba_cache_indices_mapping[cur_rid] = {} - first_free_index = self._first_free_index_in_mamba_cache() - self.mamba_cache_indices_mapping[cur_rid][ - seq_id] = first_free_index - index_for_current_run = first_free_index - ## case of decoding n>1, copy prefill cache to decoding indices - elif seq_id not in (seq_ids2indices := - self.mamba_cache_indices_mapping[cur_rid]): - first_free_index = self._first_free_index_in_mamba_cache() - index_exist = list(seq_ids2indices.values())[0] - self._copy_mamba_cache(index_from=index_exist, - index_to=first_free_index, - from_buffer=self.mamba_cache) - self.mamba_cache_indices_mapping[cur_rid][ - seq_id] = first_free_index - index_for_current_run = first_free_index - else: - index_for_current_run = self.mamba_cache_indices_mapping[ - cur_rid][seq_id] - - indices_for_current_run.append(index_for_current_run) - return indices_for_current_run - - def _prepare_current_run_mamba_cache( - self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int - ) -> Tuple[Tuple[torch.Tensor, 
torch.Tensor], List[int]]: - indices_for_current_run = [] - for request_id, seqs_id in request_ids_to_seq_ids.items(): - indices_for_current_run += self._assign_seq_id_to_mamba_cache( - request_id, seqs_id) - ## Pad the batch in case of running batch that was not captured via CG - padded_indices = indices_for_current_run.copy() - pad_index = self._first_free_index_in_mamba_cache() - - for _ in range(batch_size - len(indices_for_current_run)): - padded_indices.append(pad_index) - - conv_state = self.mamba_cache[0][:, padded_indices] - temporal_state = self.mamba_cache[1][:, padded_indices] - - return (conv_state, temporal_state), indices_for_current_run - - def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): - """ - Copy the relevant Mamba cache into the CUDA graph input buffer - that was provided during the capture runs - (MambaForCausalLM.mamba_gc_cache_buffer). - """ - assert all( - key in kwargs - for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) - request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] - batch_size = len(request_ids_to_seq_ids) - ( - current_mamba_cache, - indices, - ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids, - batch_size) - self.current_indices = indices - finished_requests_ids = kwargs["finished_requests_ids"] - self._release_mamba_cache(finished_requests_ids) - - for input_buffer, current_cache_buffer in zip( - input_buffers["seqlen_agnostic_capture_inputs"], - current_mamba_cache): - input_buffer.copy_(current_cache_buffer, non_blocking=True) - - def copy_outputs_after_cuda_graphs(self, input_buffers, **kwargs): - """ - Copy the relevant Mamba cache from the CUDA graph input_buffers - back to the MambaForCausalLM.mamba_cache after CUDA - graph replay run is done. - """ - self._copy_mamba_cache_by_indices( - self.current_indices, - input_buffers["seqlen_agnostic_capture_inputs"]) - - def get_seqlen_agnostic_capture_inputs(self, batch_size: int): - """ - Provide the CUDA graph capture runs with a buffer in adjusted size. - The buffer is used to maintain the Mamba Cache during the CUDA graph - replay runs. 
- """ - return tuple(buffer[:, :batch_size] - for buffer in self.mamba_gc_cache_buffer) - - def _release_mamba_cache(self, finished_seq_groups_req_ids: List[str]): - for req_id in finished_seq_groups_req_ids: - if req_id in self.mamba_cache_indices_mapping: - self.mamba_cache_indices_mapping.pop(req_id) - - def _first_free_index_in_mamba_cache(self) -> int: - if self.mamba_cache: - max_possible_batch_size = self.mamba_cache[0].shape[1] - occupied = [ - id for seq_ids in self.mamba_cache_indices_mapping.values() - for id in seq_ids.values() - ] - first_free_index = [ - i not in occupied for i in range(max_possible_batch_size) - ].index(True) - return first_free_index - return 0 - - def _get_mamba_cache_shape( - self - ) -> Tuple[Optional[Tuple[int, int]], Optional[Tuple[int, int]]]: - world_size = get_tensor_model_parallel_world_size() - hidden_size = self.config.hidden_size - conv_state_shape = ( - self.config.intermediate_size // world_size, - self.config.conv_kernel, - ) - temporal_state_shape = ( - self.config.intermediate_size // world_size, - self.config.state_size, - ) - return conv_state_shape, temporal_state_shape - - def _prepare_mamba_cache(self): - dtype = self.lm_head.weight.dtype - num_mamba_layers = self.config.num_hidden_layers - max_batch_size = _BATCH_SIZES_TO_CAPTURE[-1] + 10 - conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape() - assert conv_state_shape is not None and temporal_state_shape is not None - for buffername in ["mamba_cache", "mamba_gc_cache_buffer"]: - buffer = (torch.empty(size=(num_mamba_layers, max_batch_size) + - conv_state_shape, - dtype=dtype, - device="cuda"), - torch.empty(size=(num_mamba_layers, max_batch_size) + - temporal_state_shape, - dtype=dtype, - device="cuda")) - setattr(self, buffername, buffer) - - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits - - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - - params_dict = dict(self.named_parameters()) - for k, v in params_dict.items(): - print(k) - - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - - if "A_log" in name: - name = name.replace("A_log", "A") - - if ".self_attn." in name: - name = name.replace(".self_attn", "") - - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. 
- if name.endswith(".bias") and name not in params_dict: - continue - - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) From b733a840010c054f3bb069e49335e9c7926d5a35 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 16 Jul 2024 18:26:46 +0000 Subject: [PATCH 10/40] format --- vllm/attention/backends/placeholder_attn.py | 6 ++++-- vllm/engine/llm_engine.py | 2 +- vllm/worker/model_runner.py | 4 ++-- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py index 6bc766ba4e3f7..f5728756c6e5d 100644 --- a/vllm/attention/backends/placeholder_attn.py +++ b/vllm/attention/backends/placeholder_attn.py @@ -1,8 +1,10 @@ from dataclasses import dataclass -from typing import (List, Optional, Tuple, Type) +from typing import List, Optional, Tuple, Type + +import torch + from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata) -import torch # Placeholder attention backend for models like Mamba that don't have attention. # Mainly exists to sidestep get_attn_backend. diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index c43f7fcb85484..f1ce03171ebf7 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -261,7 +261,7 @@ def __init__( if not self.model_config.embedding_mode: # For all decoders including attention-free models like mamba, - # this must call _initialize_kv_caches, as this is where model + # this must call _initialize_kv_caches, as this is where model # warmup and CUDA graphs creation happens. self._initialize_kv_caches() diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 459798e418c30..2f4a0657c3f1a 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -23,8 +23,8 @@ FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 from vllm.attention import AttentionMetadata, get_attn_backend -from vllm.attention.backends.placeholder_attn import PlaceholderAttentionBackend - +from vllm.attention.backends.placeholder_attn import ( + PlaceholderAttentionBackend) from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) From fb846ce85cde68ce6b22fcab596ed0ac06fef601 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 16 Jul 2024 21:57:51 +0000 Subject: [PATCH 11/40] apply fix from #6214 --- vllm/model_executor/models/mamba.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index ca8d58fd3d6aa..a76c3757be739 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -538,20 +538,20 @@ def _prepare_current_run_mamba_cache( def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): """ - Copy the relevant Mamba cache into the CUDA graph input buffer - that was provided during the capture runs - (MambaForCausalLM.mamba_gc_cache_buffer). + Copy the relevant Mamba cache into the CUDA graph input buffer + that was provided during the capture runs + (MambaForCausalLM.mamba_gc_cache_buffer). 
""" assert all( key in kwargs for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] - batch_size = len(request_ids_to_seq_ids) + cg_batch_size = input_buffers['input_ids'].shape[0] ( current_mamba_cache, indices, ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids, - batch_size) + cg_batch_size) self.current_indices = indices finished_requests_ids = kwargs["finished_requests_ids"] self._release_mamba_cache(finished_requests_ids) From d8017cb5044eb7c3458c2976e7aed9bf17753ace Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 16 Jul 2024 22:27:19 +0000 Subject: [PATCH 12/40] fixes from 6425 --- vllm/model_executor/models/interfaces.py | 2 +- vllm/model_executor/models/mamba.py | 32 ++++++++++++++++++------ 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 6fdacd4469788..b0b614d6b5242 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -152,7 +152,7 @@ class HasInnerState(Protocol): """ A flag that indicates this model has inner state. Models that has inner state usually need access to the scheduler_config - for max_num_seqs ,etc... (Currently only used by Jamba) + for max_num_seqs ,etc... (Currently used by Jamba and Mamba) """ def __init__(self, diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index a76c3757be739..49cfd5c186800 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -12,7 +12,7 @@ from transformers import MambaConfig from vllm.attention.backends.abstract import AttentionMetadata -from vllm.config import CacheConfig, LoRAConfig +from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul @@ -27,10 +27,12 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.interfaces import HasInnerState from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors, SamplerOutput -from vllm.worker.model_runner import _BATCH_SIZES_TO_CAPTURE +from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, + _get_graph_batch_size) KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -376,7 +378,7 @@ def forward( return hidden_states -class MambaForCausalLM(nn.Module): +class MambaForCausalLM(nn.Module, HasInnerState): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -404,9 +406,11 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, + scheduler_config: Optional[SchedulerConfig] = None, ) -> None: super().__init__() self.config = config + self.scheduler_config = scheduler_config self.backbone = MambaModel(config, cache_config=cache_config, quant_config=quant_config, @@ -436,7 +440,6 @@ def forward(self, attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs): - if not self.mamba_cache: self._prepare_mamba_cache() @@ -447,6 +450,7 @@ def forward(self, for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) 
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] + finished_requests_ids = kwargs["finished_requests_ids"] batch_size = input_ids.shape[0] if attn_metadata.prefill_metadata: batch_size = len(request_ids_to_seq_ids) @@ -454,7 +458,8 @@ def forward(self, current_seqlen_agnostic_cache, indices, ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids, - batch_size) + batch_size, + finished_requests_ids) finished_requests_ids = kwargs["finished_requests_ids"] self._release_mamba_cache(finished_requests_ids) else: @@ -518,10 +523,15 @@ def _assign_seq_id_to_mamba_cache(self, cur_rid: str, return indices_for_current_run def _prepare_current_run_mamba_cache( - self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int + self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int, + finished_requests_ids: List[str] ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], List[int]]: indices_for_current_run = [] for request_id, seqs_id in request_ids_to_seq_ids.items(): + if request_id in finished_requests_ids: + # Do not allocate cache for requests that run + # and finish right after + continue indices_for_current_run += self._assign_seq_id_to_mamba_cache( request_id, seqs_id) ## Pad the batch in case of running batch that was not captured via CG @@ -545,13 +555,16 @@ def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): assert all( key in kwargs for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) + finished_requests_ids = kwargs["finished_requests_ids"] + self._release_mamba_cache(finished_requests_ids) request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] cg_batch_size = input_buffers['input_ids'].shape[0] ( current_mamba_cache, indices, ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids, - cg_batch_size) + cg_batch_size, + finished_requests_ids) self.current_indices = indices finished_requests_ids = kwargs["finished_requests_ids"] self._release_mamba_cache(finished_requests_ids) @@ -615,9 +628,12 @@ def _get_mamba_cache_shape( def _prepare_mamba_cache(self): dtype = self.lm_head.weight.dtype num_mamba_layers = self.config.num_hidden_layers - max_batch_size = _BATCH_SIZES_TO_CAPTURE[-1] + 10 + max_batch_size = (_get_graph_batch_size( + self.scheduler_config.max_num_seqs) if self.scheduler_config else + max(_BATCH_SIZES_TO_CAPTURE)) + 10 conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape() assert conv_state_shape is not None and temporal_state_shape is not None + for buffername in ["mamba_cache", "mamba_gc_cache_buffer"]: buffer = (torch.empty(size=(num_mamba_layers, max_batch_size) + conv_state_shape, From 7ab2b9e7d3a2ce8648e35c9ab34bb1c627dbac2a Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 23 Jul 2024 19:59:54 +0000 Subject: [PATCH 13/40] add an integration test --- tests/models/test_mamba.py | 79 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 tests/models/test_mamba.py diff --git a/tests/models/test_mamba.py b/tests/models/test_mamba.py new file mode 100644 index 0000000000000..6a09d5f98f088 --- /dev/null +++ b/tests/models/test_mamba.py @@ -0,0 +1,79 @@ +"""Compare the outputs of HF and vLLM when using greedy sampling for Mamba. + +Run `pytest tests/models/test_mamba.py`. 
+""" +import pytest +from transformers import AutoModelForCausalLM, AutoTokenizer, TextGenerationPipeline +import torch + +from .utils import check_outputs_equal + +MODELS = [ + "state-spaces/mamba-370m-hf", +] + +# Use lower-level interfaces to create this greedy generator, as mamba will +# choke on the model_kwarg 'attention_mask' if hf_model.generate_greedy is used. +def generate_greedy(model_name, example_prompts, max_tokens): + # Create a text generation pipeline + tokenizer = AutoTokenizer.from_pretrained(model_name) + model = AutoModelForCausalLM.from_pretrained(model_name) + + generator = TextGenerationPipeline(model=model, tokenizer=tokenizer, + device=torch.cuda.current_device() + if torch.cuda.is_available() else -1) + + # Generate texts from the prompts + outputs = [] + for prompt in example_prompts: + # Tokenize the input prompt with truncation + inputs = tokenizer(prompt, return_tensors="pt", truncation=True) + input_ids = inputs["input_ids"].to(model.device) + + # Generate text using the model's generate method directly + generated_ids = model.generate(input_ids, max_new_tokens=max_tokens) + generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + + outputs.append((generated_ids[0].tolist(), generated_text)) + + return outputs + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [96]) +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + # To pass the small model tests, we need full precision. + assert dtype == "float" + + hf_outputs = generate_greedy(model, example_prompts, max_tokens) + + with vllm_runner(model, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + + check_outputs_equal( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +def test_model_print( + vllm_runner, + model: str, + dtype: str, +) -> None: + with vllm_runner(model, dtype=dtype) as vllm_model: + # This test is for verifying whether the model's extra_repr + # can be printed correctly. + print(vllm_model.model.llm_engine.model_executor.driver_worker. + model_runner.model) From c319a21c9d203de56addc662691f176003ca5d91 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 23 Jul 2024 20:06:44 +0000 Subject: [PATCH 14/40] lint --- tests/models/test_mamba.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/tests/models/test_mamba.py b/tests/models/test_mamba.py index 6a09d5f98f088..509027681f404 100644 --- a/tests/models/test_mamba.py +++ b/tests/models/test_mamba.py @@ -3,8 +3,7 @@ Run `pytest tests/models/test_mamba.py`. """ import pytest -from transformers import AutoModelForCausalLM, AutoTokenizer, TextGenerationPipeline -import torch +from transformers import AutoModelForCausalLM, AutoTokenizer from .utils import check_outputs_equal @@ -12,6 +11,7 @@ "state-spaces/mamba-370m-hf", ] + # Use lower-level interfaces to create this greedy generator, as mamba will # choke on the model_kwarg 'attention_mask' if hf_model.generate_greedy is used. 
def generate_greedy(model_name, example_prompts, max_tokens): @@ -19,25 +19,23 @@ def generate_greedy(model_name, example_prompts, max_tokens): tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForCausalLM.from_pretrained(model_name) - generator = TextGenerationPipeline(model=model, tokenizer=tokenizer, - device=torch.cuda.current_device() - if torch.cuda.is_available() else -1) - # Generate texts from the prompts outputs = [] for prompt in example_prompts: # Tokenize the input prompt with truncation inputs = tokenizer(prompt, return_tensors="pt", truncation=True) input_ids = inputs["input_ids"].to(model.device) - + # Generate text using the model's generate method directly generated_ids = model.generate(input_ids, max_new_tokens=max_tokens) - generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + generated_text = tokenizer.decode(generated_ids[0], + skip_special_tokens=True) outputs.append((generated_ids[0].tolist(), generated_text)) return outputs + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("max_tokens", [96]) From 76022d30600c8f550cc99838bb9e82d84f988520 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Wed, 31 Jul 2024 19:39:55 +0000 Subject: [PATCH 15/40] fixup --- vllm/attention/backends/placeholder_attn.py | 137 +++++++++++++++++++- vllm/worker/model_runner.py | 5 +- 2 files changed, 133 insertions(+), 9 deletions(-) diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py index f5728756c6e5d..21966af8933eb 100644 --- a/vllm/attention/backends/placeholder_attn.py +++ b/vllm/attention/backends/placeholder_attn.py @@ -4,7 +4,8 @@ import torch from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, + AttentionMetadataBuilder) # Placeholder attention backend for models like Mamba that don't have attention. # Mainly exists to sidestep get_attn_backend. 
@@ -22,6 +23,10 @@ def get_name() -> str: def get_impl_cls() -> Type["PlaceholderAttentionImpl"]: return PlaceholderAttentionImpl + @staticmethod + def get_builder_cls() -> Type["PlaceholderAttentionMetadataBuilder"]: + return PlaceholderAttentionMetadataBuilder + @staticmethod def get_metadata_cls() -> Type["PlaceholderAttentionMetadata"]: return PlaceholderAttentionMetadata @@ -33,6 +38,7 @@ def get_kv_cache_shape( num_kv_heads: int, head_size: int, ) -> Tuple[int, ...]: + return None return (1, 1, 1, 1, 1) @staticmethod @@ -108,14 +114,13 @@ def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: assert self.seq_lens_tensor is not None assert self.query_start_loc is not None assert self.context_lens_tensor is not None - assert self.block_tables is not None assert self.seq_start_loc is not None self._cached_prefill_metadata = PlaceholderAttentionMetadata( num_prefills=self.num_prefills, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=0, - slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + slot_mapping=None, seq_lens=self.seq_lens[:self.num_prefills], seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], max_query_len=self.max_query_len, @@ -124,7 +129,7 @@ def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: query_start_loc=self.query_start_loc[:self.num_prefills + 1], seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], context_lens_tensor=self.context_lens_tensor[:self.num_prefills], - block_tables=self.block_tables[:self.num_prefills], + block_tables=None, use_cuda_graph=False, ) return self._cached_prefill_metadata @@ -136,14 +141,13 @@ def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: if self._cached_decode_metadata is not None: return self._cached_decode_metadata - assert self.block_tables is not None assert self.seq_lens_tensor is not None self._cached_decode_metadata = PlaceholderAttentionMetadata( num_prefills=0, num_prefill_tokens=0, num_decode_tokens=self.num_decode_tokens, - slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + slot_mapping=None, seq_lens=None, seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], max_query_len=None, @@ -152,11 +156,130 @@ def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: query_start_loc=None, seq_start_loc=None, context_lens_tensor=None, - block_tables=self.block_tables[self.num_prefills:], + block_tables=None, use_cuda_graph=self.use_cuda_graph, ) return self._cached_decode_metadata +class PlaceholderAttentionMetadataBuilder( + AttentionMetadataBuilder[PlaceholderAttentionMetadata]): + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + self.prefill_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.curr_seq_lens: List[int] = [] + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + + self.input_builder = input_builder + self.runner = input_builder.runner + + def _add_seq_group( + self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool): + """Add a sequence group to the metadata. Specifically update/append + 1. context length. 
+ """ + is_prompt = inter_data.is_prompt + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block) in zip( + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks): + self.context_lens.append(context_len) + + if is_prompt: + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + assert query_len == 1, ( + "seq_len: {}, context_len: {}, query_len: {}".format( + seq_len, context_len, query_len)) + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + """Build attention metadata with on-device tensors. + + Args: + seq_lens: The maybe padded sequence lengths of the input sequences. + query_lens: The query lengths of the input sequences. + cuda_graph_pad_size: The padding size for cuda graph. + -1 if cuda graph is not used. + batch_size: The maybe padded batch size. + """ + for inter_data in self.input_builder.inter_data_list: + self._add_seq_group(inter_data, + self.input_builder.chunked_prefill_enabled) + + device = self.runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + logits_soft_cap = getattr(self.runner.model_config.hf_config, + "attn_logit_softcapping", None) + if logits_soft_cap is not None: + raise ValueError( + "Please use Flashinfer backend for models with logits_soft_cap" + " (i.e., Gemma-2). Otherwise, the output might be wrong." + " Set Flashinfer backend by " + "export VLLM_ATTENTION_BACKEND=FLASHINFER.") + + max_query_len = max(query_lens) + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + max_decode_seq_len = max(self.curr_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + + if use_captured_graph: + num_decode_tokens = batch_size + + assert max_query_len > 0, ("query_lens: {}".format(query_lens)) + + context_lens_tensor = torch.tensor(self.context_lens, + dtype=torch.int, + device=device) + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=device) + query_lens_tensor = torch.tensor(query_lens, + dtype=torch.long, + device=device) + query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + torch.cumsum(seq_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + torch.cumsum(query_lens_tensor, + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:]) + + return PlaceholderAttentionMetadata( + num_prefills=self.num_prefills, + slot_mapping=None, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_prefill_seq_len=max_prefill_seq_len, + max_decode_seq_len=max_decode_seq_len, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=None, + use_cuda_graph=use_captured_graph, + ) + class PlaceholderAttentionImpl(AttentionImpl): diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index a57756232394d..fd15f33a3547e 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1532,8 +1532,9 @@ def forward( 
self.input_buffers["seq_lens_tensor"].copy_( attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) - self.input_buffers["block_tables"].copy_( - attn_metadata.decode_metadata.block_tables, non_blocking=True) + if self.backend_name != "No attention": + self.input_buffers["block_tables"].copy_( + attn_metadata.decode_metadata.block_tables, non_blocking=True) if "seqlen_agnostic_capture_inputs" in self.input_buffers: self.model.copy_inputs_before_cuda_graphs(self.input_buffers, **kwargs) From 9ffc0572635f46fa044e19fe99a03cc066577c8d Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Wed, 31 Jul 2024 19:48:21 +0000 Subject: [PATCH 16/40] backend selector changes --- vllm/attention/backends/placeholder_attn.py | 1 - vllm/attention/selector.py | 10 +++++++++- vllm/worker/cache_engine.py | 1 + 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py index 21966af8933eb..770f1c099b599 100644 --- a/vllm/attention/backends/placeholder_attn.py +++ b/vllm/attention/backends/placeholder_attn.py @@ -39,7 +39,6 @@ def get_kv_cache_shape( head_size: int, ) -> Tuple[int, ...]: return None - return (1, 1, 1, 1, 1) @staticmethod def swap_blocks( diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 8fcd85585a18f..0ae6326587d33 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -22,6 +22,7 @@ class _Backend(enum.Enum): FLASHINFER = enum.auto() PALLAS = enum.auto() IPEX = enum.auto() + NO_ATTENTION = enum.auto() @lru_cache(maxsize=None) @@ -33,6 +34,7 @@ def get_attn_backend( dtype: torch.dtype, kv_cache_dtype: Optional[str], block_size: int, + num_attention_layers: int, is_blocksparse: bool = False, ) -> Type[AttentionBackend]: """Selects which attention backend to use and lazily imports it.""" @@ -45,7 +47,7 @@ def get_attn_backend( backend = which_attn_to_use(num_heads, head_size, num_kv_heads, sliding_window, dtype, kv_cache_dtype, - block_size) + num_attention_layers, block_size) if backend == _Backend.FLASH_ATTN: from vllm.attention.backends.flash_attn import ( # noqa: F401 FlashAttentionBackend) @@ -84,6 +86,8 @@ def get_attn_backend( logger.info("Using Pallas backend.") from vllm.attention.backends.pallas import PallasAttentionBackend return PallasAttentionBackend + elif backend == _Backend.NO_ATTENTION: + return PlaceholderAttentionBackend else: raise ValueError("Invalid attention backend.") @@ -96,11 +100,15 @@ def which_attn_to_use( dtype: torch.dtype, kv_cache_dtype: Optional[str], block_size: int, + num_attention_layers: int, ) -> _Backend: """Returns which flash attention backend to use.""" # Default case. selected_backend = _Backend.FLASH_ATTN + if num_attention_layers == 0: + return _Backend.NO_ATTENTION + # Check the environment variable and override if specified backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND if backend_by_env_var is not None: diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 252440c7b7e08..ac2a513f36f5b 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -60,6 +60,7 @@ def __init__( model_config.dtype, cache_config.cache_dtype, self.block_size, + self.num_attention_layers, ) # Initialize the cache. 
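
The selector changes in the patch above can be summarized with a small, self-contained sketch. This is an illustration only, not the code that ships in vLLM: the enum members, the trailing "_sketch" names, and the reduced argument list are assumptions made for brevity, based on the diffs to vllm/attention/selector.py shown above.

    import enum
    from typing import Type


    class _Backend(enum.Enum):
        FLASH_ATTN = enum.auto()
        NO_ATTENTION = enum.auto()


    def which_attn_to_use_sketch(num_attention_layers: int) -> _Backend:
        # Attention-free models such as Mamba report zero attention layers, so
        # the selector short-circuits to the placeholder backend instead of
        # probing FlashAttention / XFormers availability.
        if num_attention_layers == 0:
            return _Backend.NO_ATTENTION
        return _Backend.FLASH_ATTN


    def get_attn_backend_sketch(num_attention_layers: int) -> Type:
        # Backends are imported lazily, mirroring vllm/attention/selector.py,
        # so that importing the selector does not pull in every backend.
        backend = which_attn_to_use_sketch(num_attention_layers)
        if backend == _Backend.NO_ATTENTION:
            from vllm.attention.backends.placeholder_attn import (
                PlaceholderAttentionBackend)
            return PlaceholderAttentionBackend
        from vllm.attention.backends.flash_attn import FlashAttentionBackend
        return FlashAttentionBackend

The design choice here is that the placeholder backend supplies only a stub kv-cache shape plus a metadata class that still carries sequence lengths, which is what the Mamba layers need for their prefill/decode bookkeeping even though no real KV cache is ever used.
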
From 65d7e220397a3d1b1ee82eb476cfde648c871b52 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Wed, 31 Jul 2024 20:16:29 +0000 Subject: [PATCH 17/40] lint --- vllm/attention/backends/placeholder_attn.py | 8 ++++++-- vllm/attention/selector.py | 2 ++ vllm/worker/model_runner.py | 3 ++- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py index 770f1c099b599..490b7261a26a1 100644 --- a/vllm/attention/backends/placeholder_attn.py +++ b/vllm/attention/backends/placeholder_attn.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import List, Optional, Tuple, Type +from typing import TYPE_CHECKING, List, Optional, Tuple, Type import torch @@ -7,6 +7,9 @@ AttentionMetadata, AttentionMetadataBuilder) +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUBuilder + # Placeholder attention backend for models like Mamba that don't have attention. # Mainly exists to sidestep get_attn_backend. # The attention metadata is still needed for Mamba. @@ -38,7 +41,7 @@ def get_kv_cache_shape( num_kv_heads: int, head_size: int, ) -> Tuple[int, ...]: - return None + return (1, 1, 1, 1, 1) @staticmethod def swap_blocks( @@ -160,6 +163,7 @@ def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: ) return self._cached_decode_metadata + class PlaceholderAttentionMetadataBuilder( AttentionMetadataBuilder[PlaceholderAttentionMetadata]): diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 0ae6326587d33..6f43c74fb1f85 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -87,6 +87,8 @@ def get_attn_backend( from vllm.attention.backends.pallas import PallasAttentionBackend return PallasAttentionBackend elif backend == _Backend.NO_ATTENTION: + from vllm.attention.backends.placeholder_attn import ( + PlaceholderAttentionBackend) return PlaceholderAttentionBackend else: raise ValueError("Invalid attention backend.") diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index fd15f33a3547e..d26c22aeb267e 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1534,7 +1534,8 @@ def forward( non_blocking=True) if self.backend_name != "No attention": self.input_buffers["block_tables"].copy_( - attn_metadata.decode_metadata.block_tables, non_blocking=True) + attn_metadata.decode_metadata.block_tables, + non_blocking=True) if "seqlen_agnostic_capture_inputs" in self.input_buffers: self.model.copy_inputs_before_cuda_graphs(self.input_buffers, **kwargs) From e76a6178651d728e80dda133726c605779191439 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 20 Aug 2024 22:21:26 +0000 Subject: [PATCH 18/40] Factor out mamba cache from jamba.py, and fixes --- vllm/attention/layer.py | 6 +- vllm/attention/selector.py | 10 +- vllm/model_executor/models/jamba.py | 227 ++-------------------- vllm/model_executor/models/mamba_cache.py | 222 +++++++++++++++++++++ vllm/worker/cache_engine.py | 2 +- vllm/worker/model_runner.py | 1 + 6 files changed, 246 insertions(+), 222 deletions(-) create mode 100644 vllm/model_executor/models/mamba_cache.py diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index ecf964fa49d9b..db385eb066a11 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -42,10 +42,12 @@ def __init__( kv_cache_dtype = cache_config.cache_dtype block_size = cache_config.block_size sliding_window = cache_config.sliding_window + is_attention_free = 
cache_config.is_attention_free else: kv_cache_dtype = "auto" block_size = 16 sliding_window = None + is_attention_free = False if num_kv_heads is None: num_kv_heads = num_heads @@ -78,8 +80,8 @@ def __init__( dtype = torch.get_default_dtype() attn_backend = get_attn_backend(num_heads, head_size, num_kv_heads, sliding_window, dtype, kv_cache_dtype, - block_size, blocksparse_params - is not None) + block_size, is_attention_free, + blocksparse_params is not None) impl_cls = attn_backend.get_impl_cls() self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 10778f4adf2d3..4277025bf0168 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -96,11 +96,11 @@ def get_attn_backend( dtype: torch.dtype, kv_cache_dtype: Optional[str], block_size: int, - num_attention_layers: int, + is_attention_free: bool, #TODO: pass in from all users is_blocksparse: bool = False, ) -> Type[AttentionBackend]: """Selects which attention backend to use and lazily imports it.""" - + if is_blocksparse: logger.info("Using BlocksparseFlashAttention backend.") from vllm.attention.backends.blocksparse_attn import ( @@ -109,7 +109,7 @@ def get_attn_backend( backend = which_attn_to_use(num_heads, head_size, num_kv_heads, sliding_window, dtype, kv_cache_dtype, - num_attention_layers, block_size) + block_size, is_attention_free) if backend == _Backend.FLASH_ATTN: from vllm.attention.backends.flash_attn import ( # noqa: F401 FlashAttentionBackend) @@ -164,7 +164,7 @@ def which_attn_to_use( dtype: torch.dtype, kv_cache_dtype: Optional[str], block_size: int, - num_attention_layers: int, + is_attention_free: bool, ) -> _Backend: """Returns which flash attention backend to use.""" # Default case. @@ -172,7 +172,7 @@ def which_attn_to_use( # If there are no attention layers (e.g. we are running Mamba), # use the placeholder NO_ATTENTION - if num_attention_layers == 0: + if is_attention_free: return _Backend.NO_ATTENTION # Check whether a particular choice of backend was diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index b82eb14fb5f23..2004dad0620f6 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -37,6 +37,8 @@ from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, _get_graph_batch_size) +from vllm.model_executor.models.mamba_cache import MambaCacheManager + KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -593,10 +595,8 @@ def __init__( if not lora_config else lora_config.lora_vocab_padding_size, ) # Used to track and store by the Mamba cache between steps. 
- self.mamba_cache: Tuple[torch.Tensor, torch.Tensor] = tuple() - # Maps between the request id and a dict that maps between the seq_id - # and its index inside the self.mamba_cache - self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {} + self.mamba_cache = MambaCacheManager(config) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) self.sampler = Sampler() @@ -608,8 +608,11 @@ def forward(self, attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs): - if not self.mamba_cache: - self._prepare_mamba_cache() + if not self.mamba_cache.initialized: + max_batch_size = (_get_graph_batch_size( + self.scheduler_config.max_num_seqs) if self.scheduler_config else + max(_BATCH_SIZES_TO_CAPTURE) + 2) + self.mamba_cache.prepare(self.lm_head.weight.dtype, max_batch_size) if "seqlen_agnostic_capture_inputs" not in kwargs: # We get here only on Prefill/Eager mode runs @@ -619,11 +622,11 @@ def forward(self, request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] finished_requests_ids = kwargs["finished_requests_ids"] - self._release_mamba_cache(finished_requests_ids) + self.mamba_cache.release(finished_requests_ids) batch_size = input_ids.shape[0] if attn_metadata.prefill_metadata: batch_size = len(request_ids_to_seq_ids) - mamba_cache = self._prepare_current_run_mamba_cache( + mamba_cache = self.mamba_cache.prepare_current_run_state( request_ids_to_seq_ids, batch_size, finished_requests_ids) else: # CUDA graph capturing runs @@ -634,215 +637,11 @@ def forward(self, mamba_cache[1]) return hidden_states - def _swap_mamba_cache(self, from_index: int, to_index: int): - assert len(self.mamba_cache) > 0 - for cache_t in self.mamba_cache: - cache_t[:, [to_index,from_index]] = \ - cache_t[:, [from_index,to_index]] - - def _copy_mamba_cache(self, from_index: int, to_index: int): - assert len(self.mamba_cache) > 0 - for cache_t in self.mamba_cache: - cache_t[:, to_index].copy_(cache_t[:, from_index], - non_blocking=True) - - def _move_out_if_already_occupied(self, index: int, - all_occupied_indices: List[int]): - if index in all_occupied_indices: - first_free_index = self._first_free_index_in_mamba_cache() - # In case occupied, move the occupied to a new empty block - self._move_cache_index_and_mappings(from_index=index, - to_index=first_free_index) - - def _assign_seq_id_to_mamba_cache_in_specific_dest(self, cur_rid: str, - seq_id: int, - destination_index: int): - """ - Assign (req_id,seq_id) pair to a `destination_index` index, if - already occupied, move the occupying index to a free index. 
- """ - all_occupied_indices = self._get_all_occupied_indices() - if cur_rid not in self.mamba_cache_indices_mapping: - self._move_out_if_already_occupied( - index=destination_index, - all_occupied_indices=all_occupied_indices) - self.mamba_cache_indices_mapping[cur_rid] = { - seq_id: destination_index - } - elif seq_id not in (seq_ids2indices := - self.mamba_cache_indices_mapping[cur_rid]): - # parallel sampling , where n > 1, assume prefill have - # already happened now we only need to copy the already - # existing cache into the siblings seq_ids caches - self._move_out_if_already_occupied( - index=destination_index, - all_occupied_indices=all_occupied_indices) - index_exists = list(seq_ids2indices.values())[0] - # case of decoding n>1, copy prefill cache to decoding indices - self._copy_mamba_cache(from_index=index_exists, - to_index=destination_index) - self.mamba_cache_indices_mapping[cur_rid][ - seq_id] = destination_index - else: - # already exists - cache_index_already_exists = self.mamba_cache_indices_mapping[ - cur_rid][seq_id] - if cache_index_already_exists != destination_index: - # In case the seq id already exists but not in - # the right destination, swap it with what's occupying it - self._swap_pair_indices_and_mappings( - from_index=cache_index_already_exists, - to_index=destination_index) - - def _prepare_current_run_mamba_cache( - self, request_ids_to_seq_ids: Dict[str, list[int]], - batch_size: int, finished_requests_ids: List[str]): - running_indices = [] - request_ids_to_seq_ids_flatten = [ - (req_id, seq_id) - for req_id, seq_ids in request_ids_to_seq_ids.items() - for seq_id in seq_ids - ] - for dest_index, (request_id, - seq_id) in enumerate(request_ids_to_seq_ids_flatten): - if request_id in finished_requests_ids: - # Do not allocate cache index for requests that run - # and finish right after - continue - self._assign_seq_id_to_mamba_cache_in_specific_dest( - request_id, seq_id, dest_index) - running_indices.append(dest_index) - - self._clean_up_first_bs_blocks(batch_size, running_indices) - conv_state = self.mamba_cache[0][:, :batch_size] - temporal_state = self.mamba_cache[1][:, :batch_size] - - return (conv_state, temporal_state) - - def _get_all_occupied_indices(self): - return [ - cache_idx - for seq_ids2indices in self.mamba_cache_indices_mapping.values() - for cache_idx in seq_ids2indices.values() - ] - - def _clean_up_first_bs_blocks(self, batch_size: int, - indices_for_current_run: List[int]): - # move out all of the occupied but currently not running blocks - # outside of the first n blocks - destination_indices = set([range(batch_size)]) - max_possible_batch_size = self.mamba_cache[0].shape[1] - for destination_index in destination_indices: - if destination_index in self._get_all_occupied_indices() and \ - destination_index not in indices_for_current_run: - # move not running indices outside of the batch - all_other_indices = list( - range(batch_size, max_possible_batch_size)) - first_avail_index = self._first_free_index_in_mamba_cache( - all_other_indices) - self._swap_indices(from_index=destination_index, - to_index=first_avail_index) - - def _move_cache_index_and_mappings(self, from_index: int, to_index: int): - self._copy_mamba_cache(from_index=from_index, to_index=to_index) - self._update_mapping_index(from_index=from_index, to_index=to_index) - - def _swap_pair_indices_and_mappings(self, from_index: int, to_index: int): - self._swap_mamba_cache(from_index=from_index, to_index=to_index) - self._swap_mapping_index(from_index=from_index, 
to_index=to_index) - - def _swap_mapping_index(self, from_index: int, to_index: int): - for seq_ids2index in self.mamba_cache_indices_mapping.values(): - for seq_id, index in seq_ids2index.items(): - if from_index == index: - seq_ids2index.update({seq_id: to_index}) - elif to_index == index: - seq_ids2index.update({seq_id: from_index}) - - def _update_mapping_index(self, from_index: int, to_index: int): - for seq_ids2index in self.mamba_cache_indices_mapping.values(): - for seq_id, index in seq_ids2index.items(): - if from_index == index: - seq_ids2index.update({seq_id: to_index}) - return - def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): - """ - Copy the relevant Mamba cache into the CUDA graph input buffer - that was provided during the capture runs - (JambaForCausalLM.mamba_gc_cache_buffer). - """ - assert all( - key in kwargs - for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) - finished_requests_ids = kwargs["finished_requests_ids"] - self._release_mamba_cache(finished_requests_ids) - request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] - cg_batch_size = input_buffers['input_ids'].shape[0] - self._prepare_current_run_mamba_cache(request_ids_to_seq_ids, - cg_batch_size, - finished_requests_ids) + return self.mamba_cache.copy_inputs_before_cuda_graphs(input_buffers, **kwargs) def get_seqlen_agnostic_capture_inputs(self, batch_size: int): - """ - Provide the CUDA graph capture runs with a buffer in adjusted size. - The buffer is used to maintain the Mamba Cache during the CUDA graph - replay runs. - """ - return tuple(buffer[:, :batch_size] for buffer in self.mamba_cache) - - def _release_mamba_cache(self, finished_seq_groups_req_ids: List[str]): - for req_id in finished_seq_groups_req_ids: - if req_id in self.mamba_cache_indices_mapping: - self.mamba_cache_indices_mapping.pop(req_id) - - def _first_free_index_in_mamba_cache( - self, indices_range: Optional[List[int]] = None) -> int: - assert self.mamba_cache is not None - if indices_range is None: - max_possible_batch_size = self.mamba_cache[0].shape[1] - indices_range = list(range(max_possible_batch_size)) - all_occupied_indices = self._get_all_occupied_indices() - for i in indices_range: - if i not in all_occupied_indices: - return i - raise Exception("Couldn't find a free spot in the mamba cache! 
This" - "should never happen") - - def _get_mamba_cache_shape( - self - ) -> Tuple[Optional[Tuple[int, int]], Optional[Tuple[int, int]]]: - world_size = get_tensor_model_parallel_world_size() - hidden_size = self.config.hidden_size - conv_state_shape = ( - self.config.mamba_expand * hidden_size // world_size, - self.config.mamba_d_conv, - ) - temporal_state_shape = ( - self.config.mamba_expand * self.config.hidden_size // world_size, - self.config.mamba_d_state, - ) - return conv_state_shape, temporal_state_shape - - def _prepare_mamba_cache(self): - dtype = self.lm_head.weight.dtype - layers_type = self.config.layers_block_type - mamba_layers = sum( - [layer_type == "mamba" for layer_type in layers_type]) - max_batch_size = (_get_graph_batch_size( - self.scheduler_config.max_num_seqs) if self.scheduler_config else - max(_BATCH_SIZES_TO_CAPTURE) + 2) - conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape() - assert conv_state_shape is not None and temporal_state_shape is not None - - self.mamba_cache = (torch.empty(size=(mamba_layers, max_batch_size) + - conv_state_shape, - dtype=dtype, - device="cuda"), - torch.empty(size=(mamba_layers, max_batch_size) + - temporal_state_shape, - dtype=dtype, - device="cuda")) + return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) def compute_logits( self, diff --git a/vllm/model_executor/models/mamba_cache.py b/vllm/model_executor/models/mamba_cache.py new file mode 100644 index 0000000000000..5b7016f57e185 --- /dev/null +++ b/vllm/model_executor/models/mamba_cache.py @@ -0,0 +1,222 @@ +from typing import Dict, List, Optional, Tuple + +import torch + +from vllm.distributed import (get_tensor_model_parallel_world_size) + +class MambaCacheManager: + def __init__(self, config): + self.mamba_cache: Tuple[torch.Tensor, torch.Tensor] = tuple() + + # Maps between the request id and a dict that maps between the seq_id + # and its index inside the self.mamba_cache + self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {} + self.config = config + + self.initialized = False + + def _swap_mamba_cache(self, from_index: int, to_index: int): + assert len(self.mamba_cache) > 0 + for cache_t in self.mamba_cache: + cache_t[:, [to_index,from_index]] = \ + cache_t[:, [from_index,to_index]] + + def _copy_mamba_cache(self, from_index: int, to_index: int): + assert len(self.mamba_cache) > 0 + for cache_t in self.mamba_cache: + cache_t[:, to_index].copy_(cache_t[:, from_index], + non_blocking=True) + + def _move_out_if_already_occupied(self, index: int, + all_occupied_indices: List[int]): + if index in all_occupied_indices: + first_free_index = self._first_free_index_in_mamba_cache() + # In case occupied, move the occupied to a new empty block + self._move_cache_index_and_mappings(from_index=index, + to_index=first_free_index) + + def _assign_seq_id_to_mamba_cache_in_specific_dest(self, cur_rid: str, + seq_id: int, + destination_index: int): + """ + Assign (req_id,seq_id) pair to a `destination_index` index, if + already occupied, move the occupying index to a free index. 
+ """ + all_occupied_indices = self._get_all_occupied_indices() + if cur_rid not in self.mamba_cache_indices_mapping: + self._move_out_if_already_occupied( + index=destination_index, + all_occupied_indices=all_occupied_indices) + self.mamba_cache_indices_mapping[cur_rid] = { + seq_id: destination_index + } + elif seq_id not in (seq_ids2indices := + self.mamba_cache_indices_mapping[cur_rid]): + # parallel sampling , where n > 1, assume prefill have + # already happened now we only need to copy the already + # existing cache into the siblings seq_ids caches + self._move_out_if_already_occupied( + index=destination_index, + all_occupied_indices=all_occupied_indices) + index_exists = list(seq_ids2indices.values())[0] + # case of decoding n>1, copy prefill cache to decoding indices + self._copy_mamba_cache(from_index=index_exists, + to_index=destination_index) + self.mamba_cache_indices_mapping[cur_rid][ + seq_id] = destination_index + else: + # already exists + cache_index_already_exists = self.mamba_cache_indices_mapping[ + cur_rid][seq_id] + if cache_index_already_exists != destination_index: + # In case the seq id already exists but not in + # the right destination, swap it with what's occupying it + self._swap_pair_indices_and_mappings( + from_index=cache_index_already_exists, + to_index=destination_index) + + def prepare_current_run_state( + self, request_ids_to_seq_ids: Dict[str, list[int]], + batch_size: int, finished_requests_ids: List[str]): + running_indices = [] + request_ids_to_seq_ids_flatten = [ + (req_id, seq_id) + for req_id, seq_ids in request_ids_to_seq_ids.items() + for seq_id in seq_ids + ] + for dest_index, (request_id, + seq_id) in enumerate(request_ids_to_seq_ids_flatten): + if request_id in finished_requests_ids: + # Do not allocate cache index for requests that run + # and finish right after + continue + self._assign_seq_id_to_mamba_cache_in_specific_dest( + request_id, seq_id, dest_index) + running_indices.append(dest_index) + + self._clean_up_first_bs_blocks(batch_size, running_indices) + conv_state = self.mamba_cache[0][:, :batch_size] + temporal_state = self.mamba_cache[1][:, :batch_size] + + return (conv_state, temporal_state) + + def _get_all_occupied_indices(self): + return [ + cache_idx + for seq_ids2indices in self.mamba_cache_indices_mapping.values() + for cache_idx in seq_ids2indices.values() + ] + + def _clean_up_first_bs_blocks(self, batch_size: int, + indices_for_current_run: List[int]): + # move out all of the occupied but currently not running blocks + # outside of the first n blocks + destination_indices = set([range(batch_size)]) + max_possible_batch_size = self.mamba_cache[0].shape[1] + for destination_index in destination_indices: + if destination_index in self._get_all_occupied_indices() and \ + destination_index not in indices_for_current_run: + # move not running indices outside of the batch + all_other_indices = list( + range(batch_size, max_possible_batch_size)) + first_avail_index = self._first_free_index_in_mamba_cache( + all_other_indices) + self._swap_indices(from_index=destination_index, + to_index=first_avail_index) + + def _move_cache_index_and_mappings(self, from_index: int, to_index: int): + self._copy_mamba_cache(from_index=from_index, to_index=to_index) + self._update_mapping_index(from_index=from_index, to_index=to_index) + + def _swap_pair_indices_and_mappings(self, from_index: int, to_index: int): + self._swap_mamba_cache(from_index=from_index, to_index=to_index) + self._swap_mapping_index(from_index=from_index, to_index=to_index) 
+ + def _swap_mapping_index(self, from_index: int, to_index: int): + for seq_ids2index in self.mamba_cache_indices_mapping.values(): + for seq_id, index in seq_ids2index.items(): + if from_index == index: + seq_ids2index.update({seq_id: to_index}) + elif to_index == index: + seq_ids2index.update({seq_id: from_index}) + + def _update_mapping_index(self, from_index: int, to_index: int): + for seq_ids2index in self.mamba_cache_indices_mapping.values(): + for seq_id, index in seq_ids2index.items(): + if from_index == index: + seq_ids2index.update({seq_id: to_index}) + return + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + """ + Copy the relevant Mamba cache into the CUDA graph input buffer + that was provided during the capture runs + (JambaForCausalLM.mamba_gc_cache_buffer). + """ + assert all( + key in kwargs + for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) + finished_requests_ids = kwargs["finished_requests_ids"] + self.release(finished_requests_ids) + request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] + cg_batch_size = input_buffers['input_ids'].shape[0] + self.prepare_current_run_state(request_ids_to_seq_ids, + cg_batch_size, + finished_requests_ids) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + """ + Provide the CUDA graph capture runs with a buffer in adjusted size. + The buffer is used to maintain the Mamba Cache during the CUDA graph + replay runs. + """ + return tuple(buffer[:, :batch_size] for buffer in self.mamba_cache) + + def release(self, finished_seq_groups_req_ids: List[str]): + for req_id in finished_seq_groups_req_ids: + if req_id in self.mamba_cache_indices_mapping: + self.mamba_cache_indices_mapping.pop(req_id) + + def _first_free_index_in_mamba_cache( + self, indices_range: Optional[List[int]] = None) -> int: + assert self.mamba_cache is not None + if indices_range is None: + max_possible_batch_size = self.mamba_cache[0].shape[1] + indices_range = list(range(max_possible_batch_size)) + all_occupied_indices = self._get_all_occupied_indices() + for i in indices_range: + if i not in all_occupied_indices: + return i + raise Exception("Couldn't find a free spot in the mamba cache! 
This" + "should never happen") + + def _get_mamba_cache_shape( + self + ) -> Tuple[Optional[Tuple[int, int]], Optional[Tuple[int, int]]]: + world_size = get_tensor_model_parallel_world_size() + hidden_size = self.config.hidden_size + conv_state_shape = ( + self.config.mamba_expand * hidden_size // world_size, + self.config.mamba_d_conv, + ) + temporal_state_shape = ( + self.config.mamba_expand * self.config.hidden_size // world_size, + self.config.mamba_d_state, + ) + return conv_state_shape, temporal_state_shape + + def prepare(self, dtype, max_batch_size): + layers_type = self.config.layers_block_type + num_mamba_layers = sum( + [layer_type == "mamba" for layer_type in layers_type]) + conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape() + assert conv_state_shape is not None and temporal_state_shape is not None + + self.mamba_cache = (torch.empty(size=(num_mamba_layers, max_batch_size) + + conv_state_shape, + dtype=dtype, + device="cuda"), + torch.empty(size=(num_mamba_layers, max_batch_size) + + temporal_state_shape, + dtype=dtype, + device="cuda")) diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index ac2a513f36f5b..aa7499dfb6151 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -60,7 +60,7 @@ def __init__( model_config.dtype, cache_config.cache_dtype, self.block_size, - self.num_attention_layers, + model_config.is_attention_free() ) # Initialize the cache. diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index fc42b80902c51..9c461c4d81d15 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -859,6 +859,7 @@ def __init__( self.model_config.dtype, self.kv_cache_dtype, self.block_size, + self.model_config.is_attention_free(), ) if num_attn_heads else PlaceholderAttentionBackend() # Multi-modal data support From b9723fe44fc3191466ee4936e2335975b6eac7d0 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Wed, 21 Aug 2024 14:10:57 +0000 Subject: [PATCH 19/40] Fix mamba cache initialized bool. 
format and renames --- vllm/attention/layer.py | 2 +- vllm/attention/selector.py | 4 +-- vllm/model_executor/models/jamba.py | 25 +++++++------- vllm/model_executor/models/mamba_cache.py | 42 +++++++++++++---------- vllm/worker/cache_engine.py | 12 +++---- 5 files changed, 43 insertions(+), 42 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index db385eb066a11..06d2e55cfa710 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -80,7 +80,7 @@ def __init__( dtype = torch.get_default_dtype() attn_backend = get_attn_backend(num_heads, head_size, num_kv_heads, sliding_window, dtype, kv_cache_dtype, - block_size, is_attention_free, + block_size, is_attention_free, blocksparse_params is not None) impl_cls = attn_backend.get_impl_cls() self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 4277025bf0168..ccd42f35ada77 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -96,11 +96,11 @@ def get_attn_backend( dtype: torch.dtype, kv_cache_dtype: Optional[str], block_size: int, - is_attention_free: bool, #TODO: pass in from all users + is_attention_free: bool, #TODO: pass in from all users is_blocksparse: bool = False, ) -> Type[AttentionBackend]: """Selects which attention backend to use and lazily imports it.""" - + if is_blocksparse: logger.info("Using BlocksparseFlashAttention backend.") from vllm.attention.backends.blocksparse_attn import ( diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 2004dad0620f6..aab6192f92a44 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -1,7 +1,7 @@ # coding=utf-8 """Inference-only Jamba model.""" from dataclasses import dataclass -from typing import Dict, Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple import torch from causal_conv1d import causal_conv1d_fn, causal_conv1d_update @@ -31,14 +31,13 @@ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import HasInnerState +from vllm.model_executor.models.mamba_cache import MambaCacheManager from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, _get_graph_batch_size) -from vllm.model_executor.models.mamba_cache import MambaCacheManager - KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -610,9 +609,10 @@ def forward(self, **kwargs): if not self.mamba_cache.initialized: max_batch_size = (_get_graph_batch_size( - self.scheduler_config.max_num_seqs) if self.scheduler_config else - max(_BATCH_SIZES_TO_CAPTURE) + 2) - self.mamba_cache.prepare(self.lm_head.weight.dtype, max_batch_size) + self.scheduler_config.max_num_seqs) if self.scheduler_config + else max(_BATCH_SIZES_TO_CAPTURE) + 2) + self.mamba_cache.initialize_tensors(self.lm_head.weight.dtype, + max_batch_size) if "seqlen_agnostic_capture_inputs" not in kwargs: # We get here only on Prefill/Eager mode runs @@ -622,23 +622,24 @@ def forward(self, request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] finished_requests_ids = kwargs["finished_requests_ids"] - self.mamba_cache.release(finished_requests_ids) + self.mamba_cache.release_finished_requests(finished_requests_ids) 
batch_size = input_ids.shape[0] if attn_metadata.prefill_metadata: batch_size = len(request_ids_to_seq_ids) - mamba_cache = self.mamba_cache.prepare_current_run_state( + mamba_cache_tensors = self.mamba_cache.prepare_current_run_state( request_ids_to_seq_ids, batch_size, finished_requests_ids) else: # CUDA graph capturing runs - mamba_cache = kwargs["seqlen_agnostic_capture_inputs"] + mamba_cache_tensors = kwargs["seqlen_agnostic_capture_inputs"] hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, mamba_cache[0], - mamba_cache[1]) + attn_metadata, mamba_cache_tensors[0], + mamba_cache_tensors[1]) return hidden_states def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): - return self.mamba_cache.copy_inputs_before_cuda_graphs(input_buffers, **kwargs) + return self.mamba_cache.copy_inputs_before_cuda_graphs( + input_buffers, **kwargs) def get_seqlen_agnostic_capture_inputs(self, batch_size: int): return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) diff --git a/vllm/model_executor/models/mamba_cache.py b/vllm/model_executor/models/mamba_cache.py index 5b7016f57e185..ef44009471b34 100644 --- a/vllm/model_executor/models/mamba_cache.py +++ b/vllm/model_executor/models/mamba_cache.py @@ -2,9 +2,11 @@ import torch -from vllm.distributed import (get_tensor_model_parallel_world_size) +from vllm.distributed import get_tensor_model_parallel_world_size + class MambaCacheManager: + def __init__(self, config): self.mamba_cache: Tuple[torch.Tensor, torch.Tensor] = tuple() @@ -12,7 +14,6 @@ def __init__(self, config): # and its index inside the self.mamba_cache self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {} self.config = config - self.initialized = False def _swap_mamba_cache(self, from_index: int, to_index: int): @@ -75,9 +76,10 @@ def _assign_seq_id_to_mamba_cache_in_specific_dest(self, cur_rid: str, from_index=cache_index_already_exists, to_index=destination_index) - def prepare_current_run_state( - self, request_ids_to_seq_ids: Dict[str, list[int]], - batch_size: int, finished_requests_ids: List[str]): + def prepare_current_run_state(self, + request_ids_to_seq_ids: Dict[str, list[int]], + batch_size: int, + finished_requests_ids: List[str]): running_indices = [] request_ids_to_seq_ids_flatten = [ (req_id, seq_id) @@ -157,11 +159,10 @@ def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): key in kwargs for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) finished_requests_ids = kwargs["finished_requests_ids"] - self.release(finished_requests_ids) + self.release_finished_requests(finished_requests_ids) request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] cg_batch_size = input_buffers['input_ids'].shape[0] - self.prepare_current_run_state(request_ids_to_seq_ids, - cg_batch_size, + self.prepare_current_run_state(request_ids_to_seq_ids, cg_batch_size, finished_requests_ids) def get_seqlen_agnostic_capture_inputs(self, batch_size: int): @@ -172,7 +173,8 @@ def get_seqlen_agnostic_capture_inputs(self, batch_size: int): """ return tuple(buffer[:, :batch_size] for buffer in self.mamba_cache) - def release(self, finished_seq_groups_req_ids: List[str]): + def release_finished_requests(self, + finished_seq_groups_req_ids: List[str]): for req_id in finished_seq_groups_req_ids: if req_id in self.mamba_cache_indices_mapping: self.mamba_cache_indices_mapping.pop(req_id) @@ -200,23 +202,25 @@ def _get_mamba_cache_shape( self.config.mamba_d_conv, ) temporal_state_shape = ( - self.config.mamba_expand * 
self.config.hidden_size // world_size, + self.config.mamba_expand * hidden_size // world_size, self.config.mamba_d_state, ) return conv_state_shape, temporal_state_shape - def prepare(self, dtype, max_batch_size): + def initialize_tensors(self, dtype, max_batch_size): layers_type = self.config.layers_block_type num_mamba_layers = sum( [layer_type == "mamba" for layer_type in layers_type]) conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape() assert conv_state_shape is not None and temporal_state_shape is not None - self.mamba_cache = (torch.empty(size=(num_mamba_layers, max_batch_size) + - conv_state_shape, - dtype=dtype, - device="cuda"), - torch.empty(size=(num_mamba_layers, max_batch_size) + - temporal_state_shape, - dtype=dtype, - device="cuda")) + self.mamba_cache = (torch.empty( + size=(num_mamba_layers, max_batch_size) + conv_state_shape, + dtype=dtype, + device="cuda"), + torch.empty( + size=(num_mamba_layers, max_batch_size) + + temporal_state_shape, + dtype=dtype, + device="cuda")) + self.initialized = True diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index aa7499dfb6151..52c3dec6453f1 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -54,14 +54,10 @@ def __init__( # Get attention backend. self.attn_backend = get_attn_backend( model_config.get_num_attention_heads(parallel_config), - self.head_size, - self.num_kv_heads, - model_config.get_sliding_window(), - model_config.dtype, - cache_config.cache_dtype, - self.block_size, - model_config.is_attention_free() - ) + self.head_size, self.num_kv_heads, + model_config.get_sliding_window(), model_config.dtype, + cache_config.cache_dtype, self.block_size, + model_config.is_attention_free()) # Initialize the cache. self.gpu_cache = self._allocate_kv_cache( From b2a8cd838be469feabc686d773c616c83562614d Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Wed, 21 Aug 2024 17:28:21 +0000 Subject: [PATCH 20/40] Refactor mamba to use the MambaCacheManager --- vllm/model_executor/models/jamba.py | 29 ++- vllm/model_executor/models/mamba.py | 211 +++------------------- vllm/model_executor/models/mamba_cache.py | 61 +++---- 3 files changed, 79 insertions(+), 222 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index aab6192f92a44..36dc34cc272f8 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -594,7 +594,7 @@ def __init__( if not lora_config else lora_config.lora_vocab_padding_size, ) # Used to track and store by the Mamba cache between steps. 
- self.mamba_cache = MambaCacheManager(config) + self.mamba_cache: Optional[MambaCacheManager] = None self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) @@ -607,12 +607,19 @@ def forward(self, attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs): - if not self.mamba_cache.initialized: + if self.mamba_cache is None: max_batch_size = (_get_graph_batch_size( self.scheduler_config.max_num_seqs) if self.scheduler_config else max(_BATCH_SIZES_TO_CAPTURE) + 2) - self.mamba_cache.initialize_tensors(self.lm_head.weight.dtype, - max_batch_size) + + layers_type = self.config.layers_block_type + num_mamba_layers = sum( + [layer_type == "mamba" for layer_type in layers_type]) + + self.mamba_cache = MambaCacheManager(self.lm_head.weight.dtype, + num_mamba_layers, + max_batch_size, + *self._get_mamba_cache_shape()) if "seqlen_agnostic_capture_inputs" not in kwargs: # We get here only on Prefill/Eager mode runs @@ -623,6 +630,7 @@ def forward(self, request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] finished_requests_ids = kwargs["finished_requests_ids"] self.mamba_cache.release_finished_requests(finished_requests_ids) + batch_size = input_ids.shape[0] if attn_metadata.prefill_metadata: batch_size = len(request_ids_to_seq_ids) @@ -637,6 +645,19 @@ def forward(self, mamba_cache_tensors[1]) return hidden_states + def _get_mamba_cache_shape(self) -> Tuple[Tuple[int, int], Tuple[int, int]]: + world_size = get_tensor_model_parallel_world_size() + hidden_size = self.config.hidden_size + conv_state_shape = ( + self.config.mamba_expand * hidden_size // world_size, + self.config.mamba_d_conv, + ) + temporal_state_shape = ( + self.config.mamba_expand * hidden_size // world_size, + self.config.mamba_d_state, + ) + return conv_state_shape, temporal_state_shape + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): return self.mamba_cache.copy_inputs_before_cuda_graphs( input_buffers, **kwargs) diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 49cfd5c186800..9a94ca04a43c0 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -1,7 +1,7 @@ # coding=utf-8 """PyTorch MAMBA model.""" from dataclasses import dataclass -from typing import Dict, Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple import torch from causal_conv1d import causal_conv1d_fn, causal_conv1d_update @@ -28,6 +28,7 @@ VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import HasInnerState +from vllm.model_executor.models.mamba_cache import MambaCacheManager from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors, SamplerOutput @@ -420,15 +421,10 @@ def __init__( self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = self.backbone.embeddings - # Current step used indices - self.current_indices: List[int] = [] + # Used to track and store by the Mamba cache between steps. - self.mamba_cache: Tuple[torch.Tensor, torch.Tensor] = tuple() - # Used as an input_buffer for the CUDA graph runs. 
- self.mamba_gc_cache_buffer: Tuple[torch.Tensor, torch.Tensor] = tuple() - # Maps between the request id and a dict that maps between the seq_id - # and its index inside the self.mamba_cache - self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {} + self.mamba_cache: Optional[MambaCacheManager] = None + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) self.sampler = Sampler() @@ -440,8 +436,14 @@ def forward(self, attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs): - if not self.mamba_cache: - self._prepare_mamba_cache() + if self.mamba_cache is None: + max_batch_size = (_get_graph_batch_size( + self.scheduler_config.max_num_seqs) if self.scheduler_config + else max(_BATCH_SIZES_TO_CAPTURE) + 2) + self.mamba_cache = MambaCacheManager(self.lm_head.weight.dtype, + self.config.num_hidden_layers, + max_batch_size, + *self._get_mamba_cache_shape()) if "seqlen_agnostic_capture_inputs" not in kwargs: # We get here only on Prefill/Eager mode runs @@ -451,169 +453,25 @@ def forward(self, request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] finished_requests_ids = kwargs["finished_requests_ids"] + self.mamba_cache.release_finished_requests(finished_requests_ids) + batch_size = input_ids.shape[0] if attn_metadata.prefill_metadata: batch_size = len(request_ids_to_seq_ids) - ( - current_seqlen_agnostic_cache, - indices, - ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids, - batch_size, - finished_requests_ids) - finished_requests_ids = kwargs["finished_requests_ids"] - self._release_mamba_cache(finished_requests_ids) + mamba_cache_tensors = self.mamba_cache.prepare_current_run_state( + request_ids_to_seq_ids, batch_size, finished_requests_ids) + else: # CUDA graph capturing runs - current_seqlen_agnostic_cache, indices = ( - kwargs["seqlen_agnostic_capture_inputs"], - [], - ) - self.current_indices = indices + mamba_cache_tensors = kwargs["seqlen_agnostic_capture_inputs"] hidden_states = self.backbone(input_ids, positions, kv_caches, - attn_metadata, - current_seqlen_agnostic_cache[0], - current_seqlen_agnostic_cache[1]) - - if "seqlen_agnostic_capture_inputs" not in kwargs: - self._copy_mamba_cache_by_indices(self.current_indices, - current_seqlen_agnostic_cache) + attn_metadata, mamba_cache_tensors[0], + mamba_cache_tensors[1]) return hidden_states - def _copy_mamba_cache_by_indices( - self, indices: List[int], - current_seqlen_agnostic_cache: Tuple[torch.Tensor, torch.Tensor]): - for i, offset in enumerate(indices): - self._copy_mamba_cache(offset, i, current_seqlen_agnostic_cache) - - def _copy_mamba_cache(self, index_to: int, index_from: int, - from_buffer: Tuple[torch.Tensor, torch.Tensor]): - assert len(self.mamba_cache) > 0 - for (cache_t, from_buffer_t) in zip(self.mamba_cache, from_buffer): - cache_t[:, index_to].copy_(from_buffer_t[:, index_from], - non_blocking=True) - - def _assign_seq_id_to_mamba_cache(self, cur_rid: str, - seqs_id: List[int]) -> List[int]: - indices_for_current_run = [] - for seq_id in seqs_id: - if cur_rid not in self.mamba_cache_indices_mapping: - self.mamba_cache_indices_mapping[cur_rid] = {} - first_free_index = self._first_free_index_in_mamba_cache() - self.mamba_cache_indices_mapping[cur_rid][ - seq_id] = first_free_index - index_for_current_run = first_free_index - ## case of decoding n>1, copy prefill cache to decoding indices - elif seq_id not in (seq_ids2indices := - self.mamba_cache_indices_mapping[cur_rid]): - first_free_index = 
self._first_free_index_in_mamba_cache() - index_exist = list(seq_ids2indices.values())[0] - self._copy_mamba_cache(index_from=index_exist, - index_to=first_free_index, - from_buffer=self.mamba_cache) - self.mamba_cache_indices_mapping[cur_rid][ - seq_id] = first_free_index - index_for_current_run = first_free_index - else: - index_for_current_run = self.mamba_cache_indices_mapping[ - cur_rid][seq_id] - - indices_for_current_run.append(index_for_current_run) - return indices_for_current_run - - def _prepare_current_run_mamba_cache( - self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int, - finished_requests_ids: List[str] - ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], List[int]]: - indices_for_current_run = [] - for request_id, seqs_id in request_ids_to_seq_ids.items(): - if request_id in finished_requests_ids: - # Do not allocate cache for requests that run - # and finish right after - continue - indices_for_current_run += self._assign_seq_id_to_mamba_cache( - request_id, seqs_id) - ## Pad the batch in case of running batch that was not captured via CG - padded_indices = indices_for_current_run.copy() - pad_index = self._first_free_index_in_mamba_cache() - - for _ in range(batch_size - len(indices_for_current_run)): - padded_indices.append(pad_index) - - conv_state = self.mamba_cache[0][:, padded_indices] - temporal_state = self.mamba_cache[1][:, padded_indices] - - return (conv_state, temporal_state), indices_for_current_run - - def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): - """ - Copy the relevant Mamba cache into the CUDA graph input buffer - that was provided during the capture runs - (MambaForCausalLM.mamba_gc_cache_buffer). - """ - assert all( - key in kwargs - for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) - finished_requests_ids = kwargs["finished_requests_ids"] - self._release_mamba_cache(finished_requests_ids) - request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] - cg_batch_size = input_buffers['input_ids'].shape[0] - ( - current_mamba_cache, - indices, - ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids, - cg_batch_size, - finished_requests_ids) - self.current_indices = indices - finished_requests_ids = kwargs["finished_requests_ids"] - self._release_mamba_cache(finished_requests_ids) - - for input_buffer, current_cache_buffer in zip( - input_buffers["seqlen_agnostic_capture_inputs"], - current_mamba_cache): - input_buffer.copy_(current_cache_buffer, non_blocking=True) - - def copy_outputs_after_cuda_graphs(self, input_buffers, **kwargs): - """ - Copy the relevant Mamba cache from the CUDA graph input_buffers - back to the MambaForCausalLM.mamba_cache after CUDA - graph replay run is done. - """ - self._copy_mamba_cache_by_indices( - self.current_indices, - input_buffers["seqlen_agnostic_capture_inputs"]) - - def get_seqlen_agnostic_capture_inputs(self, batch_size: int): - """ - Provide the CUDA graph capture runs with a buffer in adjusted size. - The buffer is used to maintain the Mamba Cache during the CUDA graph - replay runs. 
- """ - return tuple(buffer[:, :batch_size] - for buffer in self.mamba_gc_cache_buffer) - - def _release_mamba_cache(self, finished_seq_groups_req_ids: List[str]): - for req_id in finished_seq_groups_req_ids: - if req_id in self.mamba_cache_indices_mapping: - self.mamba_cache_indices_mapping.pop(req_id) - - def _first_free_index_in_mamba_cache(self) -> int: - if self.mamba_cache: - max_possible_batch_size = self.mamba_cache[0].shape[1] - occupied = [ - id for seq_ids in self.mamba_cache_indices_mapping.values() - for id in seq_ids.values() - ] - first_free_index = [ - i not in occupied for i in range(max_possible_batch_size) - ].index(True) - return first_free_index - return 0 - - def _get_mamba_cache_shape( - self - ) -> Tuple[Optional[Tuple[int, int]], Optional[Tuple[int, int]]]: + def _get_mamba_cache_shape(self) -> Tuple[Tuple[int, int], Tuple[int, int]]: world_size = get_tensor_model_parallel_world_size() conv_state_shape = ( self.config.intermediate_size // world_size, @@ -625,25 +483,12 @@ def _get_mamba_cache_shape( ) return conv_state_shape, temporal_state_shape - def _prepare_mamba_cache(self): - dtype = self.lm_head.weight.dtype - num_mamba_layers = self.config.num_hidden_layers - max_batch_size = (_get_graph_batch_size( - self.scheduler_config.max_num_seqs) if self.scheduler_config else - max(_BATCH_SIZES_TO_CAPTURE)) + 10 - conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape() - assert conv_state_shape is not None and temporal_state_shape is not None - - for buffername in ["mamba_cache", "mamba_gc_cache_buffer"]: - buffer = (torch.empty(size=(num_mamba_layers, max_batch_size) + - conv_state_shape, - dtype=dtype, - device="cuda"), - torch.empty(size=(num_mamba_layers, max_batch_size) + - temporal_state_shape, - dtype=dtype, - device="cuda")) - setattr(self, buffername, buffer) + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + return self.mamba_cache.copy_inputs_before_cuda_graphs( + input_buffers, **kwargs) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: diff --git a/vllm/model_executor/models/mamba_cache.py b/vllm/model_executor/models/mamba_cache.py index ef44009471b34..54d71d63a63f9 100644 --- a/vllm/model_executor/models/mamba_cache.py +++ b/vllm/model_executor/models/mamba_cache.py @@ -7,14 +7,23 @@ class MambaCacheManager: - def __init__(self, config): - self.mamba_cache: Tuple[torch.Tensor, torch.Tensor] = tuple() + def __init__(self, dtype, num_mamba_layers, max_batch_size, + conv_state_shape, temporal_state_shape): + + conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) + + conv_state_shape, + dtype=dtype, + device="cuda") + temporal_state = torch.empty(size=(num_mamba_layers, max_batch_size) + + temporal_state_shape, + dtype=dtype, + device="cuda") + + self.mamba_cache = (conv_state, temporal_state) # Maps between the request id and a dict that maps between the seq_id # and its index inside the self.mamba_cache self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {} - self.config = config - self.initialized = False def _swap_mamba_cache(self, from_index: int, to_index: int): assert len(self.mamba_cache) > 0 @@ -192,35 +201,17 @@ def _first_free_index_in_mamba_cache( raise Exception("Couldn't find a free spot in the mamba cache! 
This" "should never happen") - def _get_mamba_cache_shape( - self - ) -> Tuple[Optional[Tuple[int, int]], Optional[Tuple[int, int]]]: - world_size = get_tensor_model_parallel_world_size() - hidden_size = self.config.hidden_size - conv_state_shape = ( - self.config.mamba_expand * hidden_size // world_size, - self.config.mamba_d_conv, - ) - temporal_state_shape = ( - self.config.mamba_expand * hidden_size // world_size, - self.config.mamba_d_state, - ) - return conv_state_shape, temporal_state_shape - - def initialize_tensors(self, dtype, max_batch_size): - layers_type = self.config.layers_block_type - num_mamba_layers = sum( - [layer_type == "mamba" for layer_type in layers_type]) - conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape() - assert conv_state_shape is not None and temporal_state_shape is not None - - self.mamba_cache = (torch.empty( - size=(num_mamba_layers, max_batch_size) + conv_state_shape, - dtype=dtype, - device="cuda"), - torch.empty( - size=(num_mamba_layers, max_batch_size) + - temporal_state_shape, - dtype=dtype, - device="cuda")) + def initialize_tensors(self, dtype, num_mamba_layers, max_batch_size, + conv_state_shape, temporal_state_shape): + + conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) + + conv_state_shape, + dtype=dtype, + device="cuda") + temporal_state = torch.empty(size=(num_mamba_layers, max_batch_size) + + temporal_state_shape, + dtype=dtype, + device="cuda") + + self.mamba_cache = (conv_state, temporal_state) self.initialized = True From f87a8e2eb046b9d53a722284adf6bfb7ad90d5d1 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Thu, 29 Aug 2024 17:54:20 +0000 Subject: [PATCH 21/40] fixes --- examples/offline_inference.py | 2 +- vllm/attention/backends/placeholder_attn.py | 34 +++++++++++++++------ vllm/attention/layer.py | 8 ++--- vllm/attention/selector.py | 11 ++----- vllm/model_executor/models/jamba.py | 10 +++--- vllm/model_executor/models/mamba.py | 10 +++--- vllm/model_executor/models/mamba_cache.py | 6 ++-- vllm/worker/cache_engine.py | 12 ++++---- vllm/worker/cpu_model_runner.py | 3 +- vllm/worker/cpu_worker.py | 3 +- vllm/worker/model_runner.py | 27 +++------------- vllm/worker/openvino_model_runner.py | 2 -- vllm/worker/openvino_worker.py | 3 +- vllm/worker/tpu_model_runner.py | 3 +- vllm/worker/xpu_model_runner.py | 3 +- 15 files changed, 60 insertions(+), 77 deletions(-) diff --git a/examples/offline_inference.py b/examples/offline_inference.py index 9b758fa2479f6..54002fd8eb148 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -11,7 +11,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create an LLM. -llm = LLM(model="facebook/opt-125m") +llm = LLM(model="state-spaces/mamba-370m-hf") # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. 
outputs = llm.generate(prompts, sampling_params) diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py index 490b7261a26a1..d8daae4f5060e 100644 --- a/vllm/attention/backends/placeholder_attn.py +++ b/vllm/attention/backends/placeholder_attn.py @@ -6,13 +6,13 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionMetadataBuilder) +from vllm.attention.backends.utils import CommonAttentionState if TYPE_CHECKING: from vllm.worker.model_runner import ModelInputForGPUBuilder -# Placeholder attention backend for models like Mamba that don't have attention. -# Mainly exists to sidestep get_attn_backend. -# The attention metadata is still needed for Mamba. +# Placeholder attention backend for models like Mamba and embedding models that +# lack attention. class PlaceholderAttentionBackend(AttentionBackend): @@ -34,6 +34,10 @@ def get_builder_cls() -> Type["PlaceholderAttentionMetadataBuilder"]: def get_metadata_cls() -> Type["PlaceholderAttentionMetadata"]: return PlaceholderAttentionMetadata + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + @staticmethod def get_kv_cache_shape( num_blocks: int, @@ -118,11 +122,15 @@ def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: assert self.context_lens_tensor is not None assert self.seq_start_loc is not None + # Placeholders + slot_mapping = torch.empty(0) + block_tables = torch.empty(0) + self._cached_prefill_metadata = PlaceholderAttentionMetadata( num_prefills=self.num_prefills, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=0, - slot_mapping=None, + slot_mapping=slot_mapping, seq_lens=self.seq_lens[:self.num_prefills], seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], max_query_len=self.max_query_len, @@ -131,7 +139,7 @@ def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: query_start_loc=self.query_start_loc[:self.num_prefills + 1], seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], context_lens_tensor=self.context_lens_tensor[:self.num_prefills], - block_tables=None, + block_tables=block_tables, use_cuda_graph=False, ) return self._cached_prefill_metadata @@ -145,11 +153,15 @@ def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: return self._cached_decode_metadata assert self.seq_lens_tensor is not None + # Placeholders + slot_mapping = torch.empty(0) + block_tables = torch.empty(0) + self._cached_decode_metadata = PlaceholderAttentionMetadata( num_prefills=0, num_prefill_tokens=0, num_decode_tokens=self.num_decode_tokens, - slot_mapping=None, + slot_mapping=slot_mapping, seq_lens=None, seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], max_query_len=None, @@ -158,7 +170,7 @@ def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: query_start_loc=None, seq_start_loc=None, context_lens_tensor=None, - block_tables=None, + block_tables=block_tables, use_cuda_graph=self.use_cuda_graph, ) return self._cached_decode_metadata @@ -266,9 +278,13 @@ def build(self, seq_lens: List[int], query_lens: List[int], dtype=query_start_loc.dtype, out=query_start_loc[1:]) + # Placeholders + slot_mapping = torch.empty(0) + block_tables = torch.empty(0) + return PlaceholderAttentionMetadata( num_prefills=self.num_prefills, - slot_mapping=None, + slot_mapping=slot_mapping, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, @@ -279,7 +295,7 @@ def build(self, 
seq_lens: List[int], query_lens: List[int], query_start_loc=query_start_loc, seq_start_loc=seq_start_loc, context_lens_tensor=context_lens_tensor, - block_tables=None, + block_tables=block_tables, use_cuda_graph=use_captured_graph, ) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 06d2e55cfa710..0112f49876996 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -78,10 +78,10 @@ def __init__( # During model initialization, the default dtype is set as the model # weight and activation dtype. dtype = torch.get_default_dtype() - attn_backend = get_attn_backend(num_heads, head_size, num_kv_heads, - sliding_window, dtype, kv_cache_dtype, - block_size, is_attention_free, - blocksparse_params is not None) + attn_backend = get_attn_backend(head_size, sliding_window, dtype, + kv_cache_dtype, block_size, + is_attention_free, blocksparse_params + is not None) impl_cls = attn_backend.get_impl_cls() self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index ccd42f35ada77..0aff6df6d6d80 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -89,14 +89,12 @@ def get_global_forced_attn_backend() -> Optional[_Backend]: @lru_cache(maxsize=None) def get_attn_backend( - num_heads: int, head_size: int, - num_kv_heads: int, sliding_window: Optional[int], dtype: torch.dtype, kv_cache_dtype: Optional[str], block_size: int, - is_attention_free: bool, #TODO: pass in from all users + is_attention_free: bool, is_blocksparse: bool = False, ) -> Type[AttentionBackend]: """Selects which attention backend to use and lazily imports it.""" @@ -107,9 +105,8 @@ def get_attn_backend( BlocksparseFlashAttentionBackend) return BlocksparseFlashAttentionBackend - backend = which_attn_to_use(num_heads, head_size, num_kv_heads, - sliding_window, dtype, kv_cache_dtype, - block_size, is_attention_free) + backend = which_attn_to_use(head_size, sliding_window, dtype, + kv_cache_dtype, block_size, is_attention_free) if backend == _Backend.FLASH_ATTN: from vllm.attention.backends.flash_attn import ( # noqa: F401 FlashAttentionBackend) @@ -157,9 +154,7 @@ def get_attn_backend( def which_attn_to_use( - num_heads: int, head_size: int, - num_kv_heads: int, sliding_window: Optional[int], dtype: torch.dtype, kv_cache_dtype: Optional[str], diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index ba521dde15550..fbbba7e6b7edc 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -616,10 +616,9 @@ def forward(self, num_mamba_layers = sum( [layer_type == "mamba" for layer_type in layers_type]) - self.mamba_cache = MambaCacheManager(self.lm_head.weight.dtype, - num_mamba_layers, - max_batch_size, - *self._get_mamba_cache_shape()) + self.mamba_cache = MambaCacheManager( + self.lm_head.weight.dtype, num_mamba_layers, max_batch_size, + *self._get_mamba_cache_shape()) if "seqlen_agnostic_capture_inputs" not in kwargs: # We get here only on Prefill/Eager mode runs @@ -645,7 +644,8 @@ def forward(self, mamba_cache_tensors[1]) return hidden_states - def _get_mamba_cache_shape(self) -> Tuple[Tuple[int, int], Tuple[int, int]]: + def _get_mamba_cache_shape( + self) -> Tuple[Tuple[int, int], Tuple[int, int]]: world_size = get_tensor_model_parallel_world_size() hidden_size = self.config.hidden_size conv_state_shape = ( diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py 
index 9a94ca04a43c0..38b51c20b7eae 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -440,10 +440,9 @@ def forward(self, max_batch_size = (_get_graph_batch_size( self.scheduler_config.max_num_seqs) if self.scheduler_config else max(_BATCH_SIZES_TO_CAPTURE) + 2) - self.mamba_cache = MambaCacheManager(self.lm_head.weight.dtype, - self.config.num_hidden_layers, - max_batch_size, - *self._get_mamba_cache_shape()) + self.mamba_cache = MambaCacheManager( + self.lm_head.weight.dtype, self.config.num_hidden_layers, + max_batch_size, *self._get_mamba_cache_shape()) if "seqlen_agnostic_capture_inputs" not in kwargs: # We get here only on Prefill/Eager mode runs @@ -471,7 +470,8 @@ def forward(self, return hidden_states - def _get_mamba_cache_shape(self) -> Tuple[Tuple[int, int], Tuple[int, int]]: + def _get_mamba_cache_shape( + self) -> Tuple[Tuple[int, int], Tuple[int, int]]: world_size = get_tensor_model_parallel_world_size() conv_state_shape = ( self.config.intermediate_size // world_size, diff --git a/vllm/model_executor/models/mamba_cache.py b/vllm/model_executor/models/mamba_cache.py index 54d71d63a63f9..eca15d99e6444 100644 --- a/vllm/model_executor/models/mamba_cache.py +++ b/vllm/model_executor/models/mamba_cache.py @@ -1,14 +1,12 @@ -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional import torch -from vllm.distributed import get_tensor_model_parallel_world_size - class MambaCacheManager: def __init__(self, dtype, num_mamba_layers, max_batch_size, - conv_state_shape, temporal_state_shape): + conv_state_shape, temporal_state_shape): conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) + conv_state_shape, diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 52c3dec6453f1..56dc3da0ab719 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -52,12 +52,12 @@ def __init__( self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] # Get attention backend. - self.attn_backend = get_attn_backend( - model_config.get_num_attention_heads(parallel_config), - self.head_size, self.num_kv_heads, - model_config.get_sliding_window(), model_config.dtype, - cache_config.cache_dtype, self.block_size, - model_config.is_attention_free()) + self.attn_backend = get_attn_backend(self.head_size, + model_config.get_sliding_window(), + model_config.dtype, + cache_config.cache_dtype, + self.block_size, + model_config.is_attention_free()) # Initialize the cache. 
self.gpu_cache = self._allocate_kv_cache( diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index f69afa4c43149..6a848def06172 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -103,13 +103,12 @@ def __init__( self.sliding_window = model_config.get_sliding_window() self.block_size = cache_config.block_size self.attn_backend = get_attn_backend( - self.model_config.get_num_attention_heads(self.parallel_config), self.model_config.get_head_size(), - self.model_config.get_num_kv_heads(self.parallel_config), self.model_config.get_sliding_window(), self.model_config.dtype, self.kv_cache_dtype, self.block_size, + self.model_config.is_attention_free(), ) # Multi-modal data support diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 52d1806018f51..f07dc5266c965 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -55,13 +55,12 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig, # Get attention backend. self.attn_backend = get_attn_backend( - self.model_config.get_num_attention_heads(self.parallel_config), self.model_config.get_head_size(), - self.model_config.get_num_kv_heads(self.parallel_config), self.model_config.get_sliding_window(), self.model_config.dtype, cache_config.cache_dtype, self.block_size, + self.model_config.is_attention_free(), ) # Initialize the cache. diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 76e8f65a1b823..e37e7a9042f8b 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -17,7 +17,6 @@ import vllm.envs as envs from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention.backends.abstract import AttentionState -from vllm.attention.backends.utils import CommonAttentionState from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) @@ -864,23 +863,16 @@ def __init__( self.graph_block_tables = np.zeros( (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()), dtype=np.int32) - num_attn_heads = self.model_config.get_num_attention_heads( - self.parallel_config) self.attn_backend = get_attn_backend( - num_attn_heads, self.model_config.get_head_size(), - self.model_config.get_num_kv_heads(self.parallel_config), self.model_config.get_sliding_window(), self.model_config.dtype, self.kv_cache_dtype, self.block_size, self.model_config.is_attention_free(), - ) if num_attn_heads else None - if self.attn_backend: - self.attn_state = self.attn_backend.get_state_cls()( - weakref.proxy(self)) - else: - self.attn_state = CommonAttentionState(weakref.proxy(self)) + ) + self.attn_state = self.attn_backend.get_state_cls()( + weakref.proxy(self)) # Multi-modal data support self.input_registry = input_registry @@ -1635,21 +1627,10 @@ def forward( # Copy the input tensors to the input buffers. 
self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True) self.input_buffers["positions"].copy_(positions, non_blocking=True) -""" + if self.backend_name != "No attention": self.input_buffers["slot_mapping"].copy_( attn_metadata.slot_mapping, non_blocking=True) - if self.backend_name != "flashinfer": - self.input_buffers["seq_lens_tensor"].copy_( - attn_metadata.decode_metadata.seq_lens_tensor, - non_blocking=True) - if self.backend_name != "No attention": - self.input_buffers["block_tables"].copy_( - attn_metadata.decode_metadata.block_tables, - non_blocking=True) -""" - self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping, - non_blocking=True) self.attn_state.prepare_graph_input_buffers(self.input_buffers, attn_metadata) if "seqlen_agnostic_capture_inputs" in self.input_buffers: diff --git a/vllm/worker/openvino_model_runner.py b/vllm/worker/openvino_model_runner.py index a1d09a2f9e53e..f1b0f76bed88e 100644 --- a/vllm/worker/openvino_model_runner.py +++ b/vllm/worker/openvino_model_runner.py @@ -71,9 +71,7 @@ def __init__( self.block_size = cache_config.block_size self.attn_backend = get_attn_backend( - self.model_config.get_num_attention_heads(self.parallel_config), self.model_config.get_head_size(), - self.model_config.get_num_kv_heads(self.parallel_config), self.model_config.get_sliding_window(), self.model_config.dtype, self.kv_cache_dtype, diff --git a/vllm/worker/openvino_worker.py b/vllm/worker/openvino_worker.py index c47f9acc4423d..b7009b3f584c5 100644 --- a/vllm/worker/openvino_worker.py +++ b/vllm/worker/openvino_worker.py @@ -61,13 +61,12 @@ def __init__( # Get attention backend. self.attn_backend = get_attn_backend( - self.model_config.get_num_attention_heads(self.parallel_config), self.head_size, - self.model_config.get_num_kv_heads(self.parallel_config), self.model_config.get_sliding_window(), self.model_config.dtype, self.cache_config.cache_dtype, self.block_size, + self.model_config.is_attention_free(), ) # Initialize the cache. 
diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 01daa64b5a32f..b4383d3981b28 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -106,13 +106,12 @@ def __init__( (self.scheduler_config.max_num_seqs, self.max_num_blocks_per_seq), dtype=np.int32) self.attn_backend = get_attn_backend( - self.model_config.get_num_attention_heads(self.parallel_config), self.model_config.get_head_size(), - self.model_config.get_num_kv_heads(self.parallel_config), self.model_config.get_sliding_window(), self.model_config.dtype, self.cache_config.cache_dtype, self.block_size, + self.model_config.is_attention_free(), False, ) diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 3894658a095f3..2f07f18bac46c 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -371,13 +371,12 @@ def __init__( self.block_size = cache_config.block_size self.attn_backend = get_attn_backend( - self.model_config.get_num_attention_heads(self.parallel_config), self.model_config.get_head_size(), - self.model_config.get_num_kv_heads(self.parallel_config), self.model_config.get_sliding_window(), self.model_config.dtype, self.kv_cache_dtype, self.block_size, + model_config.is_attention_free(), ) # Multi-modal data support From 8e16aca590983e441a0fe1982ffeb6ac32a7471d Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Thu, 29 Aug 2024 18:24:40 +0000 Subject: [PATCH 22/40] Update to use kernels from #7651 --- vllm/model_executor/models/mamba.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 38b51c20b7eae..7cd0614309b85 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -4,9 +4,6 @@ from typing import Iterable, List, Optional, Tuple import torch -from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -from mamba_ssm.ops.selective_scan_interface import selective_scan_fn -from mamba_ssm.ops.triton.selective_state_update import selective_state_update from torch import nn from torch.nn.parameter import Parameter from transformers import MambaConfig @@ -21,6 +18,10 @@ MergedColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn, causal_conv1d_update) +from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( + selective_scan_fn, selective_state_update) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler @@ -157,7 +158,7 @@ def mamba_forward(self, (self.conv_kernel_size - hidden_states.shape[-1], 0)) cache_params.conv_state.copy_(conv_states) - hidden_states = causal_conv1d_fn( + hidden_states, _ = causal_conv1d_fn( hidden_states, conv_weights, self.conv1d.bias, From 120b7616f61ad3d65e9743702daf8becdafd0de2 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Thu, 29 Aug 2024 18:28:41 +0000 Subject: [PATCH 23/40] some cruft --- examples/offline_inference.py | 2 +- vllm/engine/llm_engine.py | 3 --- vllm/model_executor/models/jamba.py | 14 +++++++------- vllm/model_executor/models/mamba.py | 14 +++++++------- 4 files changed, 15 insertions(+), 18 deletions(-) diff --git a/examples/offline_inference.py b/examples/offline_inference.py index 54002fd8eb148..9b758fa2479f6 100644 --- a/examples/offline_inference.py +++ 
b/examples/offline_inference.py @@ -11,7 +11,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create an LLM. -llm = LLM(model="state-spaces/mamba-370m-hf") +llm = LLM(model="facebook/opt-125m") # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index ac9ba587395cb..a6de8817946cc 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -311,9 +311,6 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: ) if not self.model_config.embedding_mode: - # For all decoders including attention-free models like mamba, - # this must call _initialize_kv_caches, as this is where model - # warmup and CUDA graphs creation happens. self._initialize_kv_caches() # If usage stat is enabled, collect relevant info. diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index dff632fab1dab..7ef8ded6d4c71 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -645,6 +645,13 @@ def forward(self, mamba_cache_tensors[1]) return hidden_states + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + return self.mamba_cache.copy_inputs_before_cuda_graphs( + input_buffers, **kwargs) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) + def _get_mamba_cache_shape( self) -> Tuple[Tuple[int, int], Tuple[int, int]]: world_size = get_tensor_model_parallel_world_size() @@ -659,13 +666,6 @@ def _get_mamba_cache_shape( ) return conv_state_shape, temporal_state_shape - def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): - return self.mamba_cache.copy_inputs_before_cuda_graphs( - input_buffers, **kwargs) - - def get_seqlen_agnostic_capture_inputs(self, batch_size: int): - return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - def compute_logits( self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 7cd0614309b85..da61ddd5b8a98 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -471,6 +471,13 @@ def forward(self, return hidden_states + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + return self.mamba_cache.copy_inputs_before_cuda_graphs( + input_buffers, **kwargs) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) + def _get_mamba_cache_shape( self) -> Tuple[Tuple[int, int], Tuple[int, int]]: world_size = get_tensor_model_parallel_world_size() @@ -484,13 +491,6 @@ def _get_mamba_cache_shape( ) return conv_state_shape, temporal_state_shape - def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): - return self.mamba_cache.copy_inputs_before_cuda_graphs( - input_buffers, **kwargs) - - def get_seqlen_agnostic_capture_inputs(self, batch_size: int): - return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: logits = self.logits_processor(self.lm_head, hidden_states, From a5bd7d2cf9b55ddeebc320ed80d5c52201f2263c Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 13 Sep 2024 15:41:48 -0400 Subject: [PATCH 24/40] 
Move test_mamba.py (for #7820) --- tests/models/{ => decoder_only/language}/test_mamba.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/models/{ => decoder_only/language}/test_mamba.py (100%) diff --git a/tests/models/test_mamba.py b/tests/models/decoder_only/language/test_mamba.py similarity index 100% rename from tests/models/test_mamba.py rename to tests/models/decoder_only/language/test_mamba.py From 6546bd9b5b16a3cd1d3eb68c96a0ae2106c49774 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 13 Sep 2024 17:12:15 -0400 Subject: [PATCH 25/40] fixes --- tests/models/decoder_only/language/test_mamba.py | 2 +- vllm/model_executor/models/mamba.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/models/decoder_only/language/test_mamba.py b/tests/models/decoder_only/language/test_mamba.py index 509027681f404..75caa9f581b3d 100644 --- a/tests/models/decoder_only/language/test_mamba.py +++ b/tests/models/decoder_only/language/test_mamba.py @@ -5,7 +5,7 @@ import pytest from transformers import AutoModelForCausalLM, AutoTokenizer -from .utils import check_outputs_equal +from ...utils import check_outputs_equal MODELS = [ "state-spaces/mamba-370m-hf", diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index da61ddd5b8a98..1d035a59bb97b 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -24,7 +24,7 @@ selective_scan_fn, selective_state_update) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -32,7 +32,7 @@ from vllm.model_executor.models.mamba_cache import MambaCacheManager from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, _get_graph_batch_size) From 85a83781ba7aafa42fcf5c3848ff3d480206f1f4 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 24 Sep 2024 12:16:40 -0400 Subject: [PATCH 26/40] Review comments --- vllm/config.py | 7 +-- .../model_loader/weight_utils.py | 35 +++++++++++- vllm/model_executor/models/__init__.py | 15 ++--- vllm/model_executor/models/jamba.py | 29 +++------- vllm/model_executor/models/mamba.py | 55 ++++--------------- vllm/model_executor/models/mamba_cache.py | 40 ++++++++++++-- 6 files changed, 99 insertions(+), 82 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index f180b4a9490e5..eeefab607f356 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -420,10 +420,9 @@ def is_attention_free(self) -> bool: # Return true if the model is mamba. # This check should be augmented with more models in the future, # and made more robust if possible. 
- if hasattr(self.hf_text_config, - "model_type") and self.hf_text_config.model_type == 'mamba': - return True - return False + return hasattr( + self.hf_text_config, + "model_type") and self.hf_text_config.model_type == 'mamba' def get_hf_config_sliding_window(self) -> Optional[int]: """Get the sliding window size, or None if disabled.""" diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 5051d45dd1154..1e2857ee28cbf 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -6,7 +6,8 @@ import os import tempfile from collections import defaultdict -from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Union +from typing import (Any, Callable, Dict, Generator, Iterable, List, Optional, + Tuple, Union) import filelock import gguf @@ -559,6 +560,38 @@ def row_parallel_weight_loader(param: torch.Tensor, return default_weight_loader(param, loaded_weight) +LoaderFunction = Callable[[torch.Tensor, torch.Tensor], torch.Tensor] + + +def sharded_weight_loader(shard_axis: int) -> LoaderFunction: + """Create a weight loader that shards the weights along the given axis""" + + def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: + tp_rank = get_tensor_model_parallel_rank() + + shard_size = param.data.shape[shard_axis] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(shard_axis, start_idx, shard_size) + + return default_weight_loader(param, loaded_weight) + + return loader + + +def composed_weight_loader( + loader: LoaderFunction, fn: Callable[[torch.Tensor], + torch.Tensor]) -> LoaderFunction: + """Create a weight loader that post-processes the weights after loading""" + + def composed_loader(param: torch.Tensor, + loaded_weight: torch.Tensor) -> None: + loader(param, loaded_weight) + param.data.copy_(fn(param)) + return + + return composed_loader + + def initialize_dummy_weights( model: torch.nn.Module, low: float = -1e-3, diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 17c6126ffc756..04fb638f38995 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -12,6 +12,7 @@ _GENERATION_MODELS = { "AquilaModel": ("llama", "LlamaForCausalLM"), "AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2 + "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"), "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b "BloomForCausalLM": ("bloom", "BloomForCausalLM"), @@ -22,6 +23,7 @@ "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"), "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"), "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"), + "EAGLEModel": ("eagle", "EAGLE"), "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"), "FalconForCausalLM": ("falcon", "FalconForCausalLM"), "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"), @@ -30,15 +32,20 @@ "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"), "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"), "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"), + "GraniteForCausalLM": ("granite", "GraniteForCausalLM"), "InternLMForCausalLM": ("llama", "LlamaForCausalLM"), "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"), "JAISLMHeadModel": ("jais", "JAISLMHeadModel"), + "JambaForCausalLM": ("jamba", "JambaForCausalLM"), "LlamaForCausalLM": 
("llama", "LlamaForCausalLM"), # For decapoda-research/llama-* "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), + "MambaForCausalLM": ("mamba", "MambaForCausalLM"), + "MedusaModel": ("medusa", "Medusa"), "MistralForCausalLM": ("llama", "LlamaForCausalLM"), "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"), + "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), # transformers's mpt class has lower case "MptForCausalLM": ("mpt", "MPTForCausalLM"), "MPTForCausalLM": ("mpt", "MPTForCausalLM"), @@ -52,6 +59,7 @@ "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"), "PhiForCausalLM": ("phi", "PhiForCausalLM"), "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), + "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"), @@ -62,14 +70,7 @@ "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"), "SolarForCausalLM": ("solar", "SolarForCausalLM"), - "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"), "XverseForCausalLM": ("xverse", "XverseForCausalLM"), - "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), - "MedusaModel": ("medusa", "Medusa"), - "EAGLEModel": ("eagle", "EAGLE"), - "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), - "JambaForCausalLM": ("jamba", "JambaForCausalLM"), - "MambaForCausalLM": ("mamba", "MambaForCausalLM"), } _EMBEDDING_MODELS = { diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index cbf4146fbf5a6..f30d393273e81 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -30,7 +30,8 @@ from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.model_loader.weight_utils import ( + composed_weight_loader, default_weight_loader, sharded_weight_loader) from vllm.model_executor.models.interfaces import HasInnerState from vllm.model_executor.models.mamba_cache import MambaCacheManager from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -121,8 +122,10 @@ def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): )) self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size)) - set_weight_attrs(self.D, {"weight_loader": weight_loader}) - set_weight_attrs(self.A, {"weight_loader": A_weight_loader}) + set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)}) + a_weight_loader = composed_weight_loader( + sharded_weight_loader(0), lambda x: -torch.exp(x.float())) + set_weight_attrs(self.A, {"weight_loader": a_weight_loader}) self.out_proj = RowParallelLinear( self.intermediate_size, @@ -623,24 +626,8 @@ def forward(self, self.lm_head.weight.dtype, num_mamba_layers, max_batch_size, *self._get_mamba_cache_shape()) - if "seqlen_agnostic_capture_inputs" not in kwargs: - # We get here only on Prefill/Eager mode runs - assert all( - key in kwargs - for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) - - request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] - finished_requests_ids = kwargs["finished_requests_ids"] - 
self.mamba_cache.release_finished_requests(finished_requests_ids) - - batch_size = input_ids.shape[0] - if attn_metadata.prefill_metadata: - batch_size = len(request_ids_to_seq_ids) - mamba_cache_tensors = self.mamba_cache.prepare_current_run_state( - request_ids_to_seq_ids, batch_size, finished_requests_ids) - else: - # CUDA graph capturing runs - mamba_cache_tensors = kwargs["seqlen_agnostic_capture_inputs"] + mamba_cache_tensors = self.mamba_cache.current_run_tensors( + input_ids, attn_metadata, **kwargs) hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata, mamba_cache_tensors[0], diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 1d035a59bb97b..7909d4b075bfd 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -5,13 +5,11 @@ import torch from torch import nn -from torch.nn.parameter import Parameter from transformers import MambaConfig from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig -from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -27,7 +25,8 @@ from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.model_loader.weight_utils import ( + composed_weight_loader, default_weight_loader, sharded_weight_loader) from vllm.model_executor.models.interfaces import HasInnerState from vllm.model_executor.models.mamba_cache import MambaCacheManager from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -67,16 +66,11 @@ def __init__(self, config: MambaConfig, layer_idx): self.conv_kernel_size = config.conv_kernel self.intermediate_size = config.intermediate_size self.time_step_rank = int(config.time_step_rank) - self.use_conv_bias = config.use_conv_bias - - # TODO: ?? - #self.use_bias = config.mamba_proj_bias - self.use_bias = False self.conv1d = ColumnParallelLinear( input_size=self.conv_kernel_size, output_size=self.intermediate_size, - bias=self.use_conv_bias, + bias=config.use_conv_bias, ) # unsqueeze to fit conv1d weights shape into the linear weights shape. 
# Can't do this in `weight_loader` since it already exists in @@ -86,7 +80,7 @@ def __init__(self, config: MambaConfig, layer_idx): self.in_proj = MergedColumnParallelLinear(self.hidden_size, [self.intermediate_size] * 2, - bias=self.use_bias) + bias=config.use_bias) # selective projection used to make dt, B and C input dependent self.x_proj = RowParallelLinear( self.intermediate_size, @@ -101,16 +95,6 @@ def __init__(self, config: MambaConfig, layer_idx): bias=True, skip_bias_add=True) - def weight_loader(param: Parameter, loaded_weight: torch.Tensor): - tp_rank = get_tensor_model_parallel_rank() - tp_size = get_tensor_model_parallel_world_size() - param.data.copy_( - loaded_weight.data.split(loaded_weight.shape[0] // tp_size, - dim=0)[tp_rank]) - - def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): - weight_loader(param, -torch.exp(loaded_weight.float())) - tp_size = get_tensor_model_parallel_world_size() self.A = nn.Parameter( torch.empty( @@ -120,13 +104,15 @@ def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): )) self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size)) - set_weight_attrs(self.D, {"weight_loader": weight_loader}) - set_weight_attrs(self.A, {"weight_loader": A_weight_loader}) + set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)}) + a_weight_loader = composed_weight_loader( + sharded_weight_loader(0), lambda x: -torch.exp(x.float())) + set_weight_attrs(self.A, {"weight_loader": a_weight_loader}) self.out_proj = RowParallelLinear( self.intermediate_size, self.hidden_size, - bias=self.use_bias, + bias=config.use_bias, input_is_parallel=True, ) self.activation = config.hidden_act @@ -445,25 +431,8 @@ def forward(self, self.lm_head.weight.dtype, self.config.num_hidden_layers, max_batch_size, *self._get_mamba_cache_shape()) - if "seqlen_agnostic_capture_inputs" not in kwargs: - # We get here only on Prefill/Eager mode runs - assert all( - key in kwargs - for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) - - request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] - finished_requests_ids = kwargs["finished_requests_ids"] - self.mamba_cache.release_finished_requests(finished_requests_ids) - - batch_size = input_ids.shape[0] - if attn_metadata.prefill_metadata: - batch_size = len(request_ids_to_seq_ids) - mamba_cache_tensors = self.mamba_cache.prepare_current_run_state( - request_ids_to_seq_ids, batch_size, finished_requests_ids) - - else: - # CUDA graph capturing runs - mamba_cache_tensors = kwargs["seqlen_agnostic_capture_inputs"] + mamba_cache_tensors = self.mamba_cache.current_run_tensors( + input_ids, attn_metadata, **kwargs) hidden_states = self.backbone(input_ids, positions, kv_caches, attn_metadata, mamba_cache_tensors[0], diff --git a/vllm/model_executor/models/mamba_cache.py b/vllm/model_executor/models/mamba_cache.py index 0fcea422c38e7..bb9f7ed8e09e6 100644 --- a/vllm/model_executor/models/mamba_cache.py +++ b/vllm/model_executor/models/mamba_cache.py @@ -2,6 +2,8 @@ import torch +from vllm.attention.backends.abstract import AttentionMetadata + class MambaCacheManager: @@ -83,10 +85,11 @@ def _assign_seq_id_to_mamba_cache_in_specific_dest(self, cur_rid: str, from_index=cache_index_already_exists, to_index=destination_index) - def prepare_current_run_state(self, - request_ids_to_seq_ids: Dict[str, list[int]], - batch_size: int, - finished_requests_ids: List[str]): + def _prepare_current_run_state(self, + request_ids_to_seq_ids: Dict[str, + list[int]], + batch_size: int, + 
finished_requests_ids: List[str]): running_indices = [] request_ids_to_seq_ids_flatten = [ (req_id, seq_id) @@ -109,6 +112,31 @@ def prepare_current_run_state(self, return (conv_state, temporal_state) + def current_run_tensors(self, input_ids: torch.Tensor, + attn_metadata: AttentionMetadata, **kwargs): + + if "seqlen_agnostic_capture_inputs" not in kwargs: + # We get here only on Prefill/Eager mode runs + assert all( + key in kwargs + for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) + + request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] + finished_requests_ids = kwargs["finished_requests_ids"] + self.release_finished_requests(finished_requests_ids) + + batch_size = input_ids.shape[0] + if attn_metadata.prefill_metadata: + batch_size = len(request_ids_to_seq_ids) + mamba_cache_tensors = self._prepare_current_run_state( + request_ids_to_seq_ids, batch_size, finished_requests_ids) + + else: + # CUDA graph capturing runs + mamba_cache_tensors = kwargs["seqlen_agnostic_capture_inputs"] + + return mamba_cache_tensors + def _get_all_occupied_indices(self): return [ cache_idx @@ -169,8 +197,8 @@ def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): self.release_finished_requests(finished_requests_ids) request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] cg_batch_size = input_buffers['input_ids'].shape[0] - self.prepare_current_run_state(request_ids_to_seq_ids, cg_batch_size, - finished_requests_ids) + self._prepare_current_run_state(request_ids_to_seq_ids, cg_batch_size, + finished_requests_ids) def get_seqlen_agnostic_capture_inputs(self, batch_size: int): """ From 80e3c770bdf8d13bf15561dffc833c2a7d739a85 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 24 Sep 2024 12:21:29 -0400 Subject: [PATCH 27/40] cache attention free --- vllm/config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/config.py b/vllm/config.py index eeefab607f356..642924a50bbbd 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -412,6 +412,7 @@ def verify_with_parallel_config( "pipeline parallelism currently. Disabling it.") self.use_async_output_proc = False + @cached_property def is_attention_free(self) -> bool: """Returns True if the model has no attention, i.e. the model has no state that grows with the size of the context. 
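For reference, the sharded/composed weight-loader pattern wired up above for Mamba's `A` and `D` parameters (shard along dim 0, then replace the loaded value with its negated exponential) can be sketched outside vLLM. The snippet below is illustrative only: it hard-codes the tensor-parallel rank and size instead of calling vLLM's `get_tensor_model_parallel_rank()`, uses made-up shapes, and inlines the final copy that the real helpers delegate to `default_weight_loader`.

import torch

# Stand-ins for the tensor-parallel rank/size (assumed values, not vLLM's
# distributed helpers).
TP_RANK, TP_SIZE = 1, 4

def sharded_weight_loader(shard_axis):
    """Return a loader that copies only this rank's shard of the full weight."""
    def loader(param, loaded_weight):
        shard_size = param.data.shape[shard_axis]
        start_idx = TP_RANK * shard_size
        param.data.copy_(loaded_weight.narrow(shard_axis, start_idx, shard_size))
    return loader

def composed_weight_loader(loader, fn):
    """Return a loader that applies `fn` to the parameter after loading."""
    def composed(param, loaded_weight):
        loader(param, loaded_weight)
        param.data.copy_(fn(param))
    return composed

# Mirrors how mamba.py registers the loader for A: shard dim 0, then -exp(...).
full_weight = torch.rand(16, 8)                              # full checkpoint tensor
a_param = torch.nn.Parameter(torch.empty(16 // TP_SIZE, 8))  # this rank's shard
loader = composed_weight_loader(sharded_weight_loader(0),
                                lambda x: -torch.exp(x.float()))
loader(a_param, full_weight)
assert torch.allclose(a_param, -torch.exp(full_weight[4:8].float()))

The real loaders differ only in that they read the rank from the distributed state and hand the final copy off to `default_weight_loader`.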
From 184e808b1eeb515bcc4ddb78d51da6ff37ae3046 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 24 Sep 2024 12:31:00 -0400 Subject: [PATCH 28/40] fixup --- vllm/config.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 642924a50bbbd..a8a5d6d478dca 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,6 +1,7 @@ import enum import json from dataclasses import dataclass, field, fields +from functools import cached_property from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Mapping, Optional, Tuple, Type, Union) @@ -459,7 +460,7 @@ def get_head_size(self) -> int: # we need to pad head_size 192 to 256 return 256 - if self.is_attention_free(): + if self.is_attention_free: return 0 if hasattr(self.hf_text_config, "head_dim"): @@ -493,7 +494,7 @@ def get_total_num_kv_heads(self) -> int: return getattr(self.hf_config.attn_config, "kv_n_heads", self.hf_config.num_attention_heads) - if self.is_attention_free(): + if self.is_attention_free: return 0 attributes = [ @@ -547,7 +548,7 @@ def get_layers_block_type(self, parallel_config: "ParallelConfig") -> List[str]: num_layers = self.get_num_layers(parallel_config) - if self.is_attention_free(): + if self.is_attention_free: assert (self.hf_config.model_type == "mamba") return ["mamba"] * num_layers From 05d6aab4324a4dc94f9f4c70e02abc7f9712701e Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 24 Sep 2024 12:40:23 -0400 Subject: [PATCH 29/40] fixup --- vllm/engine/arg_utils.py | 2 +- vllm/worker/cache_engine.py | 2 +- vllm/worker/cpu_model_runner.py | 2 +- vllm/worker/cpu_worker.py | 2 +- vllm/worker/openvino_worker.py | 2 +- vllm/worker/tpu_model_runner.py | 2 +- vllm/worker/xpu_model_runner.py | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index edaae3214a3f6..7aae6861be0cc 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -881,7 +881,7 @@ def create_engine_config(self) -> EngineConfig: gpu_memory_utilization=self.gpu_memory_utilization, swap_space=self.swap_space, cache_dtype=self.kv_cache_dtype, - is_attention_free=model_config.is_attention_free(), + is_attention_free=model_config.is_attention_free, num_gpu_blocks_override=self.num_gpu_blocks_override, sliding_window=model_config.get_sliding_window(), enable_prefix_caching=self.enable_prefix_caching, diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 56dc3da0ab719..090f95e6e892c 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -57,7 +57,7 @@ def __init__( model_config.dtype, cache_config.cache_dtype, self.block_size, - model_config.is_attention_free()) + model_config.is_attention_free) # Initialize the cache. 
self.gpu_cache = self._allocate_kv_cache( diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 39ab3ebd56ec7..e28764b8fed48 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -356,7 +356,7 @@ def __init__( self.model_config.dtype, self.kv_cache_dtype, self.block_size, - self.model_config.is_attention_free(), + self.model_config.is_attention_free, ) # Multi-modal data support diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 27471d4a43d75..eaa3e67f39af3 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -60,7 +60,7 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig, self.model_config.dtype, cache_config.cache_dtype, self.block_size, - self.model_config.is_attention_free(), + self.model_config.is_attention_free, ) # Initialize the cache. diff --git a/vllm/worker/openvino_worker.py b/vllm/worker/openvino_worker.py index ac63e001b8ff8..1e9bfdcc63253 100644 --- a/vllm/worker/openvino_worker.py +++ b/vllm/worker/openvino_worker.py @@ -67,7 +67,7 @@ def __init__( self.model_config.dtype, self.cache_config.cache_dtype, self.block_size, - self.model_config.is_attention_free(), + self.model_config.is_attention_free, ) # Initialize the cache. diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 4e2971fe24064..8eb6ee55fa14e 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -118,7 +118,7 @@ def __init__( self.model_config.dtype, self.cache_config.cache_dtype, self.block_size, - self.model_config.is_attention_free(), + self.model_config.is_attention_free, False, ) self.cached_step_outputs: List[torch.Tensor] = [] diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 455741f7f3026..ba1ede3c2f649 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -376,7 +376,7 @@ def __init__( self.model_config.dtype, self.kv_cache_dtype, self.block_size, - model_config.is_attention_free(), + model_config.is_attention_free, ) # Multi-modal data support From 4ebd4ccd3071cc47335697f39816bde7b94acf65 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 24 Sep 2024 17:05:33 -0400 Subject: [PATCH 30/40] missed two --- vllm/engine/arg_utils.py | 2 +- vllm/worker/model_runner.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 7aae6861be0cc..db862c2967696 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -992,7 +992,7 @@ def create_engine_config(self) -> EngineConfig: max_num_batched_tokens=self.max_num_batched_tokens, max_num_seqs=self.max_num_seqs, max_model_len=model_config.max_model_len, - is_attention_free=model_config.is_attention_free(), + is_attention_free=model_config.is_attention_free, use_v2_block_manager=self.use_v2_block_manager, num_lookahead_slots=num_lookahead_slots, delay_factor=self.scheduler_delay_factor, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index d4aa1218a4c2d..c207e0a233f7b 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -977,7 +977,7 @@ def __init__( self.model_config.dtype, self.kv_cache_dtype, self.block_size, - self.model_config.is_attention_free(), + self.model_config.is_attention_free, ) self.attn_state = self.attn_backend.get_state_cls()( weakref.proxy(self)) From ca3788ed4ad51f31bddea4f71f4abb86c938d93d Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 
24 Sep 2024 18:04:53 -0400 Subject: [PATCH 31/40] Remove is_attention_free from SchedulerConfig --- vllm/config.py | 4 ---- vllm/core/scheduler.py | 2 +- vllm/engine/arg_utils.py | 1 - 3 files changed, 1 insertion(+), 6 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index a8a5d6d478dca..2289ff431e185 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -952,8 +952,6 @@ class SchedulerConfig: iteration. max_model_len: Maximum length of a sequence (including prompt and generated text). - is_attention_free: True if the running model does not have state that - grows as the context size increases. use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not. num_lookahead_slots: The number of slots to allocate per sequence per step, beyond the known token ids. This is used in speculative @@ -981,7 +979,6 @@ def __init__(self, max_num_batched_tokens: Optional[int], max_num_seqs: int, max_model_len: int, - is_attention_free: bool, use_v2_block_manager: bool = False, num_lookahead_slots: int = 0, delay_factor: float = 0.0, @@ -1023,7 +1020,6 @@ def __init__(self, self.max_num_seqs = max_num_seqs self.max_model_len = max_model_len - self.is_attention_free = is_attention_free self.use_v2_block_manager = use_v2_block_manager self.num_lookahead_slots = num_lookahead_slots self.delay_factor = delay_factor diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index fc7a0e3760406..a3abd40efd746 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -315,7 +315,7 @@ def __init__( if self.scheduler_config.use_v2_block_manager: version = "v2" if (self.scheduler_config.embedding_mode - or self.scheduler_config.is_attention_free): + or self.cache_config.is_attention_free): version = "placeholder" BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class( diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index db862c2967696..1aa6937e2df38 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -992,7 +992,6 @@ def create_engine_config(self) -> EngineConfig: max_num_batched_tokens=self.max_num_batched_tokens, max_num_seqs=self.max_num_seqs, max_model_len=model_config.max_model_len, - is_attention_free=model_config.is_attention_free, use_v2_block_manager=self.use_v2_block_manager, num_lookahead_slots=num_lookahead_slots, delay_factor=self.scheduler_delay_factor, From c67a6501294ce18703cce4b1b6fe4cbf5ad1d836 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 24 Sep 2024 21:43:14 -0400 Subject: [PATCH 32/40] default `is_attention_free` for unit tests --- vllm/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index 2289ff431e185..394a7b8a9016d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -616,7 +616,7 @@ def __init__( gpu_memory_utilization: float, swap_space: float, cache_dtype: str, - is_attention_free: bool, + is_attention_free: bool = False, num_gpu_blocks_override: Optional[int] = None, sliding_window: Optional[int] = None, enable_prefix_caching: bool = False, From 9e2edf6b69fac0decbac032debaee9e3ad5badbc Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Wed, 25 Sep 2024 09:31:07 -0400 Subject: [PATCH 33/40] Fix attention selector tests --- tests/kernels/test_attention_selector.py | 37 ++++++++++++--------- vllm/attention/backends/placeholder_attn.py | 2 +- vllm/worker/model_runner.py | 2 +- 3 files changed, 23 insertions(+), 18 deletions(-) diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index 
c1fb45955a0e5..f471dcee938be 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -20,22 +20,22 @@ def test_env(name: str, device: str, monkeypatch): if device == "cpu": with patch("vllm.attention.selector.is_cpu", return_value=True): - backend = which_attn_to_use(8, 16, 8, None, torch.float16, - torch.float16, 16) + backend = which_attn_to_use(16, None, torch.float16, torch.float16, + 16, False) assert backend.name == "TORCH_SDPA" elif device == "hip": with patch("vllm.attention.selector.is_hip", return_value=True): - backend = which_attn_to_use(8, 16, 8, None, torch.float16, - torch.float16, 16) + backend = which_attn_to_use(16, None, torch.float16, torch.float16, + 16, False) assert backend.name == "ROCM_FLASH" elif device == "openvino": with patch("vllm.attention.selector.is_openvino", return_value=True): - backend = which_attn_to_use(8, 16, 8, None, torch.float16, - torch.float16, 16) + backend = which_attn_to_use(16, None, torch.float16, torch.float16, + 16, False) assert backend.name == "OPENVINO" else: - backend = which_attn_to_use(8, 16, 8, None, torch.float16, - torch.float16, 16) + backend = which_attn_to_use(16, None, torch.float16, torch.float16, 16, + False) assert backend.name == name @@ -46,32 +46,37 @@ def test_flash_attn(monkeypatch): # Unsupported CUDA arch with patch("torch.cuda.get_device_capability", return_value=(7, 5)): - backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) + backend = which_attn_to_use(16, None, torch.float16, None, 16, False) assert backend.name != STR_FLASH_ATTN_VAL # Unsupported data type - backend = which_attn_to_use(8, 16, 8, None, torch.float8_e4m3fn, None, 16) + backend = which_attn_to_use(16, None, torch.float8_e4m3fn, None, 16, False) assert backend.name != STR_FLASH_ATTN_VAL # Unsupported kv cache data type - backend = which_attn_to_use(8, 16, 8, None, torch.float16, "fp8", 16) + backend = which_attn_to_use(16, None, torch.float16, "fp8", 16, False) assert backend.name != STR_FLASH_ATTN_VAL # Unsupported block size - backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 8) + backend = which_attn_to_use(16, None, torch.float16, None, 8, False) assert backend.name != STR_FLASH_ATTN_VAL # Unsupported sliding window - backend = which_attn_to_use(8, 16, 8, 1, torch.float16, None, 16) + backend = which_attn_to_use(16, 1, torch.float16, None, 16, False) assert backend.name != STR_FLASH_ATTN_VAL # flash-attn is not installed with patch.dict('sys.modules', {'vllm_flash_attn': None}): - backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) + backend = which_attn_to_use(16, None, torch.float16, None, 16, False) assert backend.name != STR_FLASH_ATTN_VAL # Unsupported head size - backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16) + backend = which_attn_to_use(17, None, torch.float16, None, 16, False) + assert backend.name != STR_FLASH_ATTN_VAL + + # Attention-free models should bypass env and use PlaceholderAttention + backend = which_attn_to_use(16, None, torch.float16, torch.float16, 16, + True) assert backend.name != STR_FLASH_ATTN_VAL @@ -79,4 +84,4 @@ def test_invalid_env(monkeypatch): """Throw an exception if the backend name is invalid.""" override_backend_env_variable(monkeypatch, STR_INVALID_VAL) with pytest.raises(ValueError): - which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) + which_attn_to_use(16, None, torch.float16, None, 16, False) diff --git a/vllm/attention/backends/placeholder_attn.py 
b/vllm/attention/backends/placeholder_attn.py index d8daae4f5060e..4ed6b6394d5a4 100644 --- a/vllm/attention/backends/placeholder_attn.py +++ b/vllm/attention/backends/placeholder_attn.py @@ -20,7 +20,7 @@ class PlaceholderAttentionBackend(AttentionBackend): @staticmethod def get_name() -> str: - return "No attention" + return "placeholder-attn" @staticmethod def get_impl_cls() -> Type["PlaceholderAttentionImpl"]: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index c207e0a233f7b..38c76f2edf9ef 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1778,7 +1778,7 @@ def forward( self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True) self.input_buffers["positions"].copy_(positions, non_blocking=True) - if self.backend_name != "No attention": + if self.backend_name != "placeholder-attn": self.input_buffers["slot_mapping"].copy_( attn_metadata.slot_mapping, non_blocking=True) From 8729b439441144b4896df359442adbdf5290038a Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Thu, 10 Oct 2024 15:28:40 -0400 Subject: [PATCH 34/40] Review comments --- docs/source/models/supported_models.rst | 7 +++ vllm/attention/backends/placeholder_attn.py | 17 +++++- vllm/commit_id.py | 1 - vllm/config.py | 58 +++++++-------------- vllm/engine/arg_utils.py | 6 +-- vllm/model_executor/models/interfaces.py | 45 +++++++++++++++- vllm/model_executor/models/jamba.py | 4 +- vllm/model_executor/models/mamba.py | 5 +- vllm/model_executor/models/registry.py | 29 ++++++++++- vllm/worker/enc_dec_model_runner.py | 2 +- vllm/worker/model_runner.py | 7 ++- vllm/worker/openvino_model_runner.py | 1 + vllm/worker/xpu_model_runner.py | 2 +- 13 files changed, 124 insertions(+), 60 deletions(-) delete mode 100644 vllm/commit_id.py diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 084607c155cb0..337db96e8e71d 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -151,6 +151,13 @@ Text Generation - Llama 3.1, Llama 3, Llama 2, LLaMA, Yi - :code:`meta-llama/Meta-Llama-3.1-405B-Instruct`, :code:`meta-llama/Meta-Llama-3.1-70B`, :code:`meta-llama/Meta-Llama-3-70B-Instruct`, :code:`meta-llama/Llama-2-70b-hf`, :code:`01-ai/Yi-34B`, etc. - ✅︎ + - + - ✅︎ + * - :code:`MambaForCausalLM` + - Mamba + - :code:`state-spaces/mamba-130m-hf`, :code:`state-spaces/mamba-790m-hf`, :code:`state-spaces/mamba-2.8b-hf`, etc. + - ✅︎ + - - ✅︎ * - :code:`MiniCPMForCausalLM` - MiniCPM diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py index 4ed6b6394d5a4..99c68a863f599 100644 --- a/vllm/attention/backends/placeholder_attn.py +++ b/vllm/attention/backends/placeholder_attn.py @@ -72,8 +72,15 @@ class PlaceholderAttentionMetadata(AttentionMetadata): # seq_lens stored as a tensor. seq_lens_tensor: Optional[torch.Tensor] - # Maximum query length in the batch. None for decoding. + # Maximum query length in the batch. max_query_len: Optional[int] + + # Number of query tokens for each request in the batch. + # Currently, we require that all requests have the same number of query + # tokens during the decoding phase. When speculavie decoding is enabled, + # decode_query_len might be greater than 1. In all other cases, it is 1. + decode_query_len: Optional[int] + # Maximum sequence length among prefill batch. 0 if there are decoding # requests only. 
max_prefill_seq_len: int @@ -133,6 +140,7 @@ def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: slot_mapping=slot_mapping, seq_lens=self.seq_lens[:self.num_prefills], seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + decode_query_len=0, max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, max_decode_seq_len=0, @@ -164,6 +172,7 @@ def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: slot_mapping=slot_mapping, seq_lens=None, seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + decode_query_len=self.decode_query_len, max_query_len=None, max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, @@ -245,6 +254,11 @@ def build(self, seq_lens: List[int], query_lens: List[int], "export VLLM_ATTENTION_BACKEND=FLASHINFER.") max_query_len = max(query_lens) + decode_query_lens = query_lens[self.num_prefills:] + if len(decode_query_lens) > 0: + decode_query_len = max(decode_query_lens) + else: + decode_query_len = 1 max_prefill_seq_len = max(self.prefill_seq_lens, default=0) max_decode_seq_len = max(self.curr_seq_lens, default=0) num_decode_tokens = self.num_decode_tokens @@ -290,6 +304,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, max_query_len=max_query_len, + decode_query_len=decode_query_len, max_prefill_seq_len=max_prefill_seq_len, max_decode_seq_len=max_decode_seq_len, query_start_loc=query_start_loc, diff --git a/vllm/commit_id.py b/vllm/commit_id.py deleted file mode 100644 index 2542bca2b0287..0000000000000 --- a/vllm/commit_id.py +++ /dev/null @@ -1 +0,0 @@ -__commit__ = "aa808f5e63261587019610377b507eb70d43021d" diff --git a/vllm/config.py b/vllm/config.py index 9d1f453e7c550..f964928aa0a68 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,7 +1,6 @@ import enum import json from dataclasses import dataclass, field, fields -from functools import cached_property from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Mapping, Optional, Tuple, Type, Union) @@ -197,6 +196,9 @@ def __init__(self, if not self.skip_tokenizer_init: self._verify_tokenizer_mode() + self.is_attention_free = self._init_attention_free() + self.has_inner_state = self._init_has_inner_state() + self.override_neuron_config = override_neuron_config if is_neuron( ) else None self._verify_embedding_mode() @@ -217,6 +219,14 @@ def _init_multimodal_config( return None + def _init_attention_free(self) -> bool: + architectures = getattr(self.hf_config, "architectures", []) + return ModelRegistry.is_attention_free_model(architectures) + + def _init_has_inner_state(self) -> bool: + architectures = getattr(self.hf_config, "architectures", []) + return ModelRegistry.model_has_inner_state(architectures) + def _verify_tokenizer_mode(self) -> None: tokenizer_mode = self.tokenizer_mode.lower() if tokenizer_mode not in ["auto", "slow", "mistral"]: @@ -406,19 +416,6 @@ def verify_with_parallel_config( "pipeline parallelism currently. Disabling it.") self.use_async_output_proc = False - @cached_property - def is_attention_free(self) -> bool: - """Returns True if the model has no attention, i.e. the model has no - state that grows with the size of the context. - """ - - # Return true if the model is mamba. - # This check should be augmented with more models in the future, - # and made more robust if possible. 
- return hasattr( - self.hf_text_config, - "model_type") and self.hf_text_config.model_type == 'mamba' - def get_hf_config_sliding_window(self) -> Optional[int]: """Get the sliding window size, or None if disabled.""" @@ -532,36 +529,17 @@ def get_num_layers(self, parallel_config: "ParallelConfig") -> int: start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size) return end - start - def contains_seqlen_agnostic_layers( - self, parallel_config: "ParallelConfig") -> bool: - """True for Mamba/SSM models (Jamba)""" - return self._get_num_seqlen_agnostic_layers(parallel_config) > 0 + def get_num_attention_layers(self, + parallel_config: "ParallelConfig") -> int: + if self.is_attention_free: + return 0 - def get_layers_block_type(self, - parallel_config: "ParallelConfig") -> List[str]: num_layers = self.get_num_layers(parallel_config) - if self.is_attention_free: - assert (self.hf_config.model_type == "mamba") - return ["mamba"] * num_layers - # Transformers supports layers_block_type @property - return getattr(self.hf_config, "layers_block_type", - ["attention"] * num_layers) - - def get_num_attention_layers(self, - parallel_config: "ParallelConfig") -> int: - return len([ - t for t in self.get_layers_block_type(parallel_config) - if t == "attention" - ]) - - def _get_num_seqlen_agnostic_layers( - self, parallel_config: "ParallelConfig") -> int: - return len([ - t for t in self.get_layers_block_type(parallel_config) - if t != "attention" - ]) + layers = getattr(self.hf_config, "layers_block_type", + ["attention"] * num_layers) + return len([t for t in layers if t == "attention"]) def get_multimodal_config(self) -> "MultiModalConfig": """ diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 5506575257768..ad101d23de194 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -944,13 +944,9 @@ def create_engine_config(self) -> EngineConfig: use_sliding_window = (model_config.get_sliding_window() is not None) use_spec_decode = self.speculative_model is not None - has_seqlen_agnostic_layers = ( - model_config.contains_seqlen_agnostic_layers( - parallel_config)) if (is_gpu and not use_sliding_window and not use_spec_decode and not self.enable_lora - and not self.enable_prompt_adapter - and not has_seqlen_agnostic_layers): + and not self.enable_prompt_adapter): self.enable_chunked_prefill = True logger.warning( "Chunked prefill is enabled by default for models with " diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 09f140c751fd9..9377d8fc0984b 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -271,7 +271,7 @@ class HasInnerState(Protocol): """ A flag that indicates this model has inner state. Models that has inner state usually need access to the scheduler_config - for max_num_seqs ,etc... (Currently used by Jamba and Mamba) + for max_num_seqs, etc. True for e.g. both Mamba and Jamba. """ def __init__(self, @@ -307,3 +307,46 @@ def has_inner_state( return isinstance(model, _HasInnerStateType) return isinstance(model, HasInnerState) + + +@runtime_checkable +class IsAttentionFree(Protocol): + """The interface required for all models like Mamba that lack attention, + but do have state whose size is constant wrt the number of tokens.""" + + is_attention_free: ClassVar[Literal[True]] = True + """ + A flag that indicates this model has no attention. + Used for block manager and attention backend selection. + True for Mamba but not Jamba. 
+ """ + + def __init__(self) -> None: + ... + + +@runtime_checkable +class _IsAttentionFreeType(Protocol): + has_inner_state: ClassVar[Literal[True]] + + def __init__(self) -> None: + ... + + +@overload +def is_attention_free(model: object) -> TypeIs[IsAttentionFree]: + ... + + +@overload +def is_attention_free(model: Type[object]) -> TypeIs[Type[IsAttentionFree]]: + ... + + +def is_attention_free( + model: Union[Type[object], object] +) -> Union[TypeIs[Type[IsAttentionFree]], TypeIs[IsAttentionFree]]: + if isinstance(model, type): + return isinstance(model, _IsAttentionFreeType) + + return isinstance(model, IsAttentionFree) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 950c42619e420..ac251b88e872c 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -5,14 +5,12 @@ import torch from torch import nn -from torch.nn.parameter import Parameter from transformers import JambaConfig from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.layer import Attention from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig -from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 3af61baf60c48..1112a2181135a 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -27,7 +27,8 @@ VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( composed_weight_loader, default_weight_loader, sharded_weight_loader) -from vllm.model_executor.models.interfaces import HasInnerState +from vllm.model_executor.models.interfaces import (HasInnerState, + IsAttentionFree) from vllm.model_executor.models.mamba_cache import MambaCacheManager from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs @@ -344,7 +345,7 @@ def forward( return hidden_states -class MambaForCausalLM(nn.Module, HasInnerState): +class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree): packed_modules_mapping = { "qkv_proj": [ "q_proj", diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index ad0ba01bf08a3..e21abc8de6ea1 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -12,7 +12,8 @@ from vllm.logger import init_logger from vllm.utils import is_hip -from .interfaces import supports_multimodal, supports_pp +from .interfaces import (has_inner_state, is_attention_free, + supports_multimodal, supports_pp) from .interfaces_base import is_embedding_model, is_text_generation_model logger = init_logger(__name__) @@ -359,6 +360,32 @@ def is_pp_supported_model(architectures: Union[str, List[str]]) -> bool: return any(is_pp(arch) for arch in architectures) + @staticmethod + def model_has_inner_state(architectures: Union[str, List[str]]) -> bool: + if isinstance(architectures, str): + architectures = [architectures] + if not architectures: + logger.warning("No model architectures are specified") + + has_instate = partial(ModelRegistry._check_stateless, + has_inner_state, + default=False) + + return any(has_instate(arch) for arch in 
architectures) + + @staticmethod + def is_attention_free_model(architectures: Union[str, List[str]]) -> bool: + if isinstance(architectures, str): + architectures = [architectures] + if not architectures: + logger.warning("No model architectures are specified") + + is_attn_free = partial(ModelRegistry._check_stateless, + is_attention_free, + default=False) + + return any(is_attn_free(arch) for arch in architectures) + if __name__ == "__main__": (mod_name, cls_name, func, diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 59b4b8c4ddf38..6a00444f5098b 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -196,7 +196,7 @@ def execute_model( seqlen_agnostic_kwargs = { "finished_requests_ids": model_input.finished_requests_ids, "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, - } if self.has_seqlen_agnostic else {} + } if self.has_inner_state else {} multi_modal_kwargs = model_input.multi_modal_kwargs or {} with set_forward_context(model_input.attn_metadata): diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index af34216fc08b5..b94240d51a5b7 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -988,8 +988,7 @@ def __init__( self.graph_memory_pool: Optional[Tuple[ int, int]] = None # Set during graph capture. - self.has_seqlen_agnostic = model_config.contains_seqlen_agnostic_layers( - parallel_config) + self.has_inner_state = model_config.has_inner_state # When using CUDA graph, the input block tables must be padded to # max_seq_len_to_capture. However, creating the block table in @@ -1481,7 +1480,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: "previous_hidden_states"] = previous_hidden_states[: batch_size] - if self.has_seqlen_agnostic: + if self.has_inner_state: # Only used by Mamba-based models CUDA graph atm (Jamba) capture_inputs.update({ "seqlen_agnostic_capture_inputs": @@ -1630,7 +1629,7 @@ def execute_model( seqlen_agnostic_kwargs = { "finished_requests_ids": model_input.finished_requests_ids, "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, - } if self.has_seqlen_agnostic else {} + } if self.has_inner_state else {} if (self.observability_config is not None and self.observability_config.collect_model_forward_time): model_forward_start = torch.cuda.Event(enable_timing=True) diff --git a/vllm/worker/openvino_model_runner.py b/vllm/worker/openvino_model_runner.py index 89f809f7fd727..760b18427e22b 100644 --- a/vllm/worker/openvino_model_runner.py +++ b/vllm/worker/openvino_model_runner.py @@ -79,6 +79,7 @@ def __init__( self.model_config.dtype, self.kv_cache_dtype, self.block_size, + self.model_config.is_attention_free, ) # Multi-modal data support diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 36d3865120bc0..20dceee849ae5 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -377,7 +377,7 @@ def __init__( self.model_config.dtype, self.kv_cache_dtype, self.block_size, - model_config.is_attention_free, + self.model_config.is_attention_free, ) # Multi-modal data support From 16d3f1d06d2671b43bc65e566c28724818eddb89 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Thu, 10 Oct 2024 16:09:04 -0400 Subject: [PATCH 35/40] format --- vllm/worker/model_runner.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 58fa17efb8a13..9db3261b8ac36 100644 --- a/vllm/worker/model_runner.py 
+++ b/vllm/worker/model_runner.py @@ -17,7 +17,6 @@ import vllm.envs as envs from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention.backends.abstract import AttentionState -from vllm.attention.backends.utils import CommonAttentionState from vllm.compilation.compile_context import set_compile_context from vllm.compilation.levels import CompilationLevel from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, From 4b21a08914dcf3b11abb55f10f86b42e6a06bdd2 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Thu, 10 Oct 2024 16:13:37 -0400 Subject: [PATCH 36/40] Fix supported_models.rst --- docs/source/models/supported_models.rst | 2 -- 1 file changed, 2 deletions(-) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index d866397bf7d83..f5d53edcebd35 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -151,14 +151,12 @@ Text Generation - Llama 3.1, Llama 3, Llama 2, LLaMA, Yi - :code:`meta-llama/Meta-Llama-3.1-405B-Instruct`, :code:`meta-llama/Meta-Llama-3.1-70B`, :code:`meta-llama/Meta-Llama-3-70B-Instruct`, :code:`meta-llama/Llama-2-70b-hf`, :code:`01-ai/Yi-34B`, etc. - ✅︎ - - - ✅︎ * - :code:`MambaForCausalLM` - Mamba - :code:`state-spaces/mamba-130m-hf`, :code:`state-spaces/mamba-790m-hf`, :code:`state-spaces/mamba-2.8b-hf`, etc. - ✅︎ - - - ✅︎ * - :code:`MiniCPMForCausalLM` - MiniCPM - :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, etc. From ec8ef04348f1f06dadb09f84cb3498a1fa5b5595 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Thu, 10 Oct 2024 18:53:24 -0400 Subject: [PATCH 37/40] jambafix --- vllm/model_executor/models/interfaces.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 9377d8fc0984b..dcead65115132 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -327,7 +327,7 @@ def __init__(self) -> None: @runtime_checkable class _IsAttentionFreeType(Protocol): - has_inner_state: ClassVar[Literal[True]] + is_attention_free: ClassVar[Literal[True]] def __init__(self) -> None: ... 
From 49e1f3c660dac14fc60098a0b3bbfbce441277f7 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 11 Oct 2024 09:14:03 -0400 Subject: [PATCH 38/40] fix softfail on cpu tests --- .buildkite/run-cpu-test-ppc64le.sh | 8 +++++++- .buildkite/run-cpu-test.sh | 1 + 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/.buildkite/run-cpu-test-ppc64le.sh b/.buildkite/run-cpu-test-ppc64le.sh index 49ae838cf0690..fd60f5b6afeca 100755 --- a/.buildkite/run-cpu-test-ppc64le.sh +++ b/.buildkite/run-cpu-test-ppc64le.sh @@ -18,7 +18,13 @@ docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/hugg # Run basic model test docker exec cpu-test bash -c " pip install pytest matplotlib einops transformers_stream_generator - pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_oot_registration.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py --ignore=tests/models/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported + pytest -v -s tests/models -m \"not vlm\" \ + --ignore=tests/models/test_embedding.py \ + --ignore=tests/models/test_oot_registration.py \ + --ignore=tests/models/test_registry.py \ + --ignore=tests/models/test_jamba.py \ + --ignore=tests/models/test_mamba.py \ + --ignore=tests/models/test_danube3_4b.py" # Mamba kernels and Danube3-4B on CPU is not supported # online inference docker exec cpu-test bash -c " diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh index 62d3afb0212fd..c2818c38965ea 100644 --- a/.buildkite/run-cpu-test.sh +++ b/.buildkite/run-cpu-test.sh @@ -27,6 +27,7 @@ docker exec cpu-test bash -c " pytest -v -s tests/models/decoder_only/language \ --ignore=tests/models/test_fp8.py \ --ignore=tests/models/decoder_only/language/test_jamba.py \ + --ignore=tests/models/decoder_only/language/test_mamba.py \ --ignore=tests/models/decoder_only/language/test_granitemoe.py \ --ignore=tests/models/decoder_only/language/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported From 609e9fbcd59c4262acbe9392e273eb8d096396fd Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 11 Oct 2024 09:28:42 -0400 Subject: [PATCH 39/40] fix for #9233 --- vllm/model_executor/models/registry.py | 36 +++++++------------------- 1 file changed, 10 insertions(+), 26 deletions(-) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 6063146988560..fbf75c2a15412 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -159,6 +159,8 @@ class _ModelInfo: is_embedding_model: bool supports_multimodal: bool supports_pp: bool + has_inner_state: bool + is_attention_free: bool @staticmethod def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo": @@ -167,6 +169,8 @@ def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo": is_embedding_model=is_embedding_model(model), supports_multimodal=supports_multimodal(model), supports_pp=supports_pp(model), + has_inner_state=has_inner_state(model), + is_attention_free=is_attention_free(model), ) @@ -382,6 +386,12 @@ def is_pp_supported_model( ) -> bool: return self.inspect_model_cls(architectures).supports_pp + def model_has_inner_state(self, architectures: Union[str, List[str]]) -> bool: + return self.inspect_model_cls(architectures).has_inner_state + + def is_attention_free_model(self, architectures: Union[str, List[str]]) -> bool: + return self.inspect_model_cls(architectures).is_attention_free + ModelRegistry = 
_ModelRegistry({ model_arch: _LazyRegisteredModel( @@ -430,32 +440,6 @@ def _run() -> None: with open(output_file, "wb") as f: f.write(pickle.dumps(result)) - @staticmethod - def model_has_inner_state(architectures: Union[str, List[str]]) -> bool: - if isinstance(architectures, str): - architectures = [architectures] - if not architectures: - logger.warning("No model architectures are specified") - - has_instate = partial(ModelRegistry._check_stateless, - has_inner_state, - default=False) - - return any(has_instate(arch) for arch in architectures) - - @staticmethod - def is_attention_free_model(architectures: Union[str, List[str]]) -> bool: - if isinstance(architectures, str): - architectures = [architectures] - if not architectures: - logger.warning("No model architectures are specified") - - is_attn_free = partial(ModelRegistry._check_stateless, - is_attention_free, - default=False) - - return any(is_attn_free(arch) for arch in architectures) - if __name__ == "__main__": _run() From 93129e5290db35ee51dd208b7b44c19fd145762e Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 11 Oct 2024 09:29:47 -0400 Subject: [PATCH 40/40] format --- vllm/model_executor/models/registry.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index fbf75c2a15412..3c8c600c2c026 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -386,10 +386,12 @@ def is_pp_supported_model( ) -> bool: return self.inspect_model_cls(architectures).supports_pp - def model_has_inner_state(self, architectures: Union[str, List[str]]) -> bool: + def model_has_inner_state(self, architectures: Union[str, + List[str]]) -> bool: return self.inspect_model_cls(architectures).has_inner_state - def is_attention_free_model(self, architectures: Union[str, List[str]]) -> bool: + def is_attention_free_model(self, architectures: Union[str, + List[str]]) -> bool: return self.inspect_model_cls(architectures).is_attention_free