From 9d6a8daa87e2e0af3ff45d03d08ad5a94ec089a8 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Wed, 3 Jul 2024 02:11:29 +0300 Subject: [PATCH] [Model] Jamba support (#4115) Signed-off-by: Muralidhar Andoorveedu Co-authored-by: Erez Schwartz Co-authored-by: Mor Zusman Co-authored-by: tomeras91 <57313761+tomeras91@users.noreply.github.com> Co-authored-by: Tomer Asida Co-authored-by: Zhuohan Li Co-authored-by: Muralidhar Andoorveedu --- .buildkite/run-cpu-test.sh | 2 +- Dockerfile | 23 + docs/source/models/supported_models.rst | 4 + requirements-mamba.txt | 3 + tests/models/test_jamba.py | 65 ++ vllm/config.py | 29 +- vllm/core/scheduler.py | 16 +- vllm/engine/async_llm_engine.py | 4 +- vllm/engine/llm_engine.py | 4 +- vllm/model_executor/models/__init__.py | 1 + vllm/model_executor/models/jamba.py | 955 ++++++++++++++++++++++++ vllm/sequence.py | 4 +- vllm/spec_decode/draft_model_runner.py | 12 +- vllm/worker/cache_engine.py | 15 +- vllm/worker/cpu_model_runner.py | 7 +- vllm/worker/embedding_model_runner.py | 3 +- vllm/worker/model_runner.py | 67 +- vllm/worker/model_runner_base.py | 1 + vllm/worker/neuron_model_runner.py | 1 + vllm/worker/worker_base.py | 3 +- vllm/worker/xpu_model_runner.py | 7 +- 21 files changed, 1192 insertions(+), 34 deletions(-) create mode 100644 requirements-mamba.txt create mode 100644 tests/models/test_jamba.py create mode 100644 vllm/model_executor/models/jamba.py diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh index f4fa24be1f20f..9d4b2bb1cd582 100644 --- a/.buildkite/run-cpu-test.sh +++ b/.buildkite/run-cpu-test.sh @@ -23,4 +23,4 @@ docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py" docker exec cpu-test bash -c "cd tests; pip install pytest Pillow protobuf cd ../ - pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py" + pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py" # Mamba on CPU is not supported diff --git a/Dockerfile b/Dockerfile index d031d98c5b7e4..f571e8be421e8 100644 --- a/Dockerfile +++ b/Dockerfile @@ -43,6 +43,10 @@ COPY requirements-cuda.txt requirements-cuda.txt RUN --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install -r requirements-cuda.txt +COPY requirements-mamba.txt requirements-mamba.txt +RUN python3 -m pip install packaging +RUN python3 -m pip install -r requirements-mamba.txt + # cuda arch list used by torch # can be useful for both `dev` and `test` # explicitly set the list to avoid issues with torch 2.2 @@ -123,6 +127,21 @@ RUN --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install -r requirements-dev.txt #################### DEV IMAGE #################### +#################### MAMBA Build IMAGE #################### +FROM dev as mamba-builder +# max jobs used for build +ARG max_jobs=2 +ENV MAX_JOBS=${max_jobs} + +WORKDIR /usr/src/mamba + +COPY requirements-mamba.txt requirements-mamba.txt + +# Download the wheel or build it if a pre-compiled release doesn't exist +RUN pip --verbose wheel -r requirements-mamba.txt \ + --no-build-isolation --no-deps --no-cache-dir + +#################### MAMBA Build IMAGE #################### #################### vLLM installation IMAGE #################### # image with vLLM installed @@ -143,6 +162,10 @@ RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. 
-f1,2)/compat/ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \ --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install dist/*.whl --verbose + +RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamba \ + --mount=type=cache,target=/root/.cache/pip \ + python3 -m pip install /usr/src/mamba/*.whl --no-cache-dir #################### vLLM installation IMAGE #################### diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 544322582f8e9..0283f36ea52b8 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -87,6 +87,10 @@ Alongside each architecture, we include some popular models that use it. - Jais - :code:`core42/jais-13b`, :code:`core42/jais-13b-chat`, :code:`core42/jais-30b-v3`, :code:`core42/jais-30b-chat-v3`, etc. - + * - :code:`JambaForCausalLM` + - Jamba + - :code:`ai21labs/Jamba-v0.1`, etc. + - ✅︎ * - :code:`LlamaForCausalLM` - LLaMA, Llama 2, Meta Llama 3, Vicuna, Alpaca, Yi - :code:`meta-llama/Meta-Llama-3-8B-Instruct`, :code:`meta-llama/Meta-Llama-3-70B-Instruct`, :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc. diff --git a/requirements-mamba.txt b/requirements-mamba.txt new file mode 100644 index 0000000000000..1838e87d063da --- /dev/null +++ b/requirements-mamba.txt @@ -0,0 +1,3 @@ +# Mamba dependencies +mamba-ssm>=1.2.2 +causal-conv1d>=1.2.0 diff --git a/tests/models/test_jamba.py b/tests/models/test_jamba.py new file mode 100644 index 0000000000000..d7e3a2fc4a71b --- /dev/null +++ b/tests/models/test_jamba.py @@ -0,0 +1,65 @@ +import pytest + +MODELS = ["ai21labs/Jamba-tiny-random"] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [20]) +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + # To pass the small model tests, we need full precision. + assert dtype == "float" + + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + + with vllm_runner(model, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + + for i in range(len(example_prompts)): + hf_output_ids, hf_output_str = hf_outputs[i] + vllm_output_ids, vllm_output_str = vllm_outputs[i] + assert hf_output_str == vllm_output_str, ( + f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") + assert hf_output_ids == vllm_output_ids, ( + f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +def test_state_cleanup( + vllm_runner, + model: str, + dtype: str, + example_prompts, +) -> None: + # This test is for verifying that the Jamba state is cleaned up between + # steps, If its not cleaned, an error would be expected. 
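+    # The cleanup is driven by the finished_requests_ids the scheduler hands
+    # to the model each step; if slots leak, the Mamba cache runs out of free
+    # indices and raises the ValueError caught below.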
+ try: + with vllm_runner(model, dtype=dtype) as vllm_model: + for _ in range(10): + vllm_model.generate_greedy([example_prompts[0]] * 100, 1) + except ValueError: + pytest.fail("Jamba inner state wasn't cleaned up between states, " + "could be related to finished_requests_ids") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +def test_model_print( + vllm_runner, + model: str, + dtype: str, +) -> None: + with vllm_runner(model, dtype=dtype) as vllm_model: + # This test is for verifying whether the model's extra_repr + # can be printed correctly. + print(vllm_model.model.llm_engine.model_executor.driver_worker. + model_runner.model) diff --git a/vllm/config.py b/vllm/config.py index 9a7e0ea7a3a10..8c449323f7a17 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -386,9 +386,36 @@ def get_num_attention_heads(self, return num_heads // parallel_config.tensor_parallel_size def get_num_layers(self, parallel_config: "ParallelConfig") -> int: - total_num_hidden_layers = self.hf_text_config.num_hidden_layers + total_num_hidden_layers = getattr(self.hf_text_config, + "num_hidden_layers", 0) return total_num_hidden_layers // parallel_config.pipeline_parallel_size + def contains_seqlen_agnostic_layers( + self, parallel_config: "ParallelConfig") -> bool: + """True for Mamba/SSM models (Jamba)""" + return self._get_num_seqlen_agnostic_layers(parallel_config) > 0 + + def get_layers_block_type(self, + parallel_config: "ParallelConfig") -> List[str]: + num_layers = self.get_num_layers(parallel_config) + # Transformers supports layers_block_type @property + return getattr(self.hf_config, "layers_block_type", + ["attention"] * num_layers) + + def get_num_attention_layers(self, + parallel_config: "ParallelConfig") -> int: + return len([ + t for t in self.get_layers_block_type(parallel_config) + if t == "attention" + ]) + + def _get_num_seqlen_agnostic_layers( + self, parallel_config: "ParallelConfig") -> int: + return len([ + t for t in self.get_layers_block_type(parallel_config) + if t != "attention" + ]) + class CacheConfig: """Configuration for the KV cache. diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 5fb3b78141b12..9e626b2883975 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -299,7 +299,10 @@ def __init__( # Sequence groups in the SWAPPED state. # Contain decode requests that are swapped out. self.swapped: Deque[SequenceGroup] = deque() - + # Sequence groups finished requests ids since last step iteration. + # It lets the model know that any state associated with these requests + # can and must be released after the current step. + self._finished_requests_ids: List[str] = list() # Time at previous scheduling step self.prev_time = 0.0 # Did we schedule a prompt at previous step? 
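The scheduler additions above and the engine changes below implement a flush-on-read handoff: finished request ids accumulate in _finished_requests_ids and are returned exactly once by get_and_reset_finished_requests_ids, so they can ride along on the next ExecuteModelRequest and let the model release any per-request state (here, Jamba's Mamba cache slots). A minimal standalone sketch of that pattern follows; StateCache, allocate, mark_finished and release are hypothetical names for illustration, not vLLM APIs.

from typing import Dict, List


class StateCache:
    """Toy per-request state store, released via finished request ids."""

    def __init__(self) -> None:
        self._state: Dict[str, int] = {}  # stand-in for Mamba cache slots
        self._finished_requests_ids: List[str] = []

    def allocate(self, request_id: str) -> None:
        self._state[request_id] = 0

    def mark_finished(self, request_id: str) -> None:
        # Scheduler side: only record the id; release happens on a later step.
        self._finished_requests_ids.append(request_id)

    def get_and_reset_finished_requests_ids(self) -> List[str]:
        # Engine side: hand the accumulated ids out exactly once.
        finished, self._finished_requests_ids = self._finished_requests_ids, []
        return finished

    def release(self, finished_requests_ids: List[str]) -> None:
        # Model side: drop any state tied to the finished requests.
        for request_id in finished_requests_ids:
            self._state.pop(request_id, None)


cache = StateCache()
cache.allocate("req-0")
cache.mark_finished("req-0")
cache.release(cache.get_and_reset_finished_requests_ids())
assert not cache._state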
@@ -373,6 +376,12 @@ def has_unfinished_seqs(self) -> bool: def get_num_unfinished_seq_groups(self) -> int: return len(self.waiting) + len(self.running) + len(self.swapped) + def get_and_reset_finished_requests_ids(self) -> List[str]: + """Flushes the list of request ids of previously finished seq_groups.""" + finished_requests_ids = self._finished_requests_ids + self._finished_requests_ids = list() + return finished_requests_ids + def _schedule_running( self, running_queue: deque, @@ -1036,6 +1045,11 @@ def free_seq(self, seq: Sequence) -> None: self.block_manager.free(seq) def free_finished_seq_groups(self) -> None: + for queue in [self.running, self.swapped, self.waiting]: + self._finished_requests_ids += [ + seq_group.request_id for seq_group in queue + if seq_group.is_finished() + ] self.running = deque(seq_group for seq_group in self.running if not seq_group.is_finished()) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 0ce511ce42476..13b4635cb8855 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -224,6 +224,8 @@ async def step_async( """ seq_group_metadata_list, scheduler_outputs = self.scheduler[ virtual_engine].schedule() + finished_requests_ids = self.scheduler[ + virtual_engine].get_and_reset_finished_requests_ids() if not scheduler_outputs.is_empty(): # Execute the model. @@ -235,7 +237,7 @@ async def step_async( virtual_engine=virtual_engine, num_lookahead_slots=scheduler_outputs.num_lookahead_slots, running_queue_size=scheduler_outputs.running_queue_size, - ) + finished_requests_ids=finished_requests_ids) output = await self.model_executor.execute_model_async( execute_model_req) else: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index a790570051491..a7428d0101033 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -846,6 +846,8 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: "as performance will be severely degraded otherwise.") seq_group_metadata_list, scheduler_outputs = self.scheduler[ 0].schedule() + finished_requests_ids = self.scheduler[ + 0].get_and_reset_finished_requests_ids() if not scheduler_outputs.is_empty(): execute_model_req = ExecuteModelRequest( @@ -855,7 +857,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: blocks_to_copy=scheduler_outputs.blocks_to_copy, num_lookahead_slots=scheduler_outputs.num_lookahead_slots, running_queue_size=scheduler_outputs.running_queue_size, - ) + finished_requests_ids=finished_requests_ids) output = self.model_executor.execute_model( execute_model_req=execute_model_req) else: diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 69a65ff023bc9..a4fe18d52d608 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -63,6 +63,7 @@ "XverseForCausalLM": ("xverse", "XverseForCausalLM"), "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), + "JambaForCausalLM": ("jamba", "JambaForCausalLM") } _EMBEDDING_MODELS = { diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py new file mode 100644 index 0000000000000..c485d3779d9a6 --- /dev/null +++ b/vllm/model_executor/models/jamba.py @@ -0,0 +1,955 @@ +# coding=utf-8 +"""Inference-only Jurassic model.""" +from dataclasses import dataclass +from typing import Dict, Iterable, List, Optional, Tuple + +import torch +from causal_conv1d 
import causal_conv1d_fn, causal_conv1d_update +from mamba_ssm.ops.selective_scan_interface import selective_scan_fn +from mamba_ssm.ops.triton.selective_state_update import selective_state_update +from torch import nn +from torch.nn.parameter import Parameter +from transformers import JambaConfig + +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.attention.layer import Attention +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs +from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.worker.model_runner import _BATCH_SIZES_TO_CAPTURE + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +@dataclass +class MambaCacheParams: + is_prompt: bool = False + conv_state: torch.Tensor = torch.Tensor() + ssm_state: torch.Tensor = torch.Tensor() + + +# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer +class JambaMambaMixer(nn.Module): + """ + Compute ∆, A, B, C, and D the state space parameters and compute + the `contextualized_states`. A, D are input independent + (see Mamba paper [1] Section 3.5.2 "Interpretation of A" + for why A isn't selective) ∆, B, C are input-dependent + (this is a key difference between Mamba and the linear time + invariant S4, and is why Mamba is called + **selective** state spaces) + """ + + def __init__(self, config: JambaConfig, layer_idx): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.ssm_state_size = config.mamba_d_state + self.conv_kernel_size = config.mamba_d_conv + self.intermediate_size = config.mamba_expand * config.hidden_size + self.time_step_rank = config.mamba_dt_rank + self.use_conv_bias = config.mamba_conv_bias + self.use_bias = config.mamba_proj_bias + self.conv1d = ColumnParallelLinear( + input_size=self.conv_kernel_size, + output_size=self.intermediate_size, + bias=self.use_conv_bias, + ) + # unsqueeze to fit conv1d weights shape into the linear weights shape. 
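+        # (The checkpoint stores the depthwise conv weight as
+        # (channels, 1, kernel_size), while ColumnParallelLinear keeps a 2D
+        # (channels, kernel_size) tensor, hence the extra dim.)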
+ # Can't do this in `weight_loader` since it already exists in + # `ColumnParallelLinear` and `set_weight_attrs` + # doesn't allow to override it + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + + self.in_proj = MergedColumnParallelLinear(self.hidden_size, + [self.intermediate_size] * 2, + bias=self.use_bias) + # selective projection used to make dt, B and C input dependent + self.x_proj = RowParallelLinear( + self.intermediate_size, + self.time_step_rank + self.ssm_state_size * 2, + bias=False, + ) + # time step projection (discretization) - + # In the forward we need to apply dt_proj without the bias, + # as the bias is added in the selective scan kernel. + self.dt_proj = ColumnParallelLinear(self.time_step_rank, + self.intermediate_size, + bias=True, + skip_bias_add=True) + + def weight_loader(param: Parameter, loaded_weight: torch.Tensor): + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + param.data.copy_( + loaded_weight.data.split(loaded_weight.shape[0] // tp_size, + dim=0)[tp_rank]) + + def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): + weight_loader(param, -torch.exp(loaded_weight.float())) + + tp_size = get_tensor_model_parallel_world_size() + self.A = nn.Parameter( + torch.empty( + self.intermediate_size // tp_size, + self.ssm_state_size, + dtype=torch.float32, + )) + self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size)) + + set_weight_attrs(self.D, {"weight_loader": weight_loader}) + set_weight_attrs(self.A, {"weight_loader": A_weight_loader}) + + self.out_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=self.use_bias, + input_is_parallel=True, + ) + self.activation = config.hidden_act + + self.dt_layernorm = RMSNorm(self.time_step_rank, + eps=config.rms_norm_eps) + self.b_layernorm = RMSNorm(self.ssm_state_size, + eps=config.rms_norm_eps) + self.c_layernorm = RMSNorm(self.ssm_state_size, + eps=config.rms_norm_eps) + + def mamba_forward(self, + hidden_states: torch.Tensor, + cache_params: MambaCacheParams = None): + # 1. Gated MLP's linear projection + projected_states = self.in_proj(hidden_states)[0].transpose(1, 2) + hidden_states, gate = projected_states.chunk(2, dim=1) + + # 2. Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), + self.conv1d.weight.size(2)) + if cache_params is not None and not cache_params.is_prompt: + hidden_states = causal_conv1d_update( + hidden_states.squeeze(-1), + cache_params.conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + ) + hidden_states = hidden_states.unsqueeze(-1) + else: + if cache_params is not None: + conv_states = nn.functional.pad( + hidden_states, + (self.conv_kernel_size - hidden_states.shape[-1], 0)) + cache_params.conv_state.copy_(conv_states) + + hidden_states = causal_conv1d_fn( + hidden_states, + conv_weights, + self.conv1d.bias, + activation=self.activation, + ) + + # 3. State Space Model sequence transformation + # 3.a. 
input varying initialization of time_step, B and C + ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))[0] + + time_step, B, C = torch.split( + ssm_parameters, + [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], + dim=-1, + ) + time_step = self.dt_layernorm(time_step.contiguous()) + B = self.b_layernorm(B.contiguous()) + C = self.c_layernorm(C.contiguous()) + + discrete_time_step = self.dt_proj(time_step)[0].transpose(1, 2) + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + time_proj_bias = (self.dt_proj.bias.float() if hasattr( + self.dt_proj, "bias") else None) + if cache_params is not None and not cache_params.is_prompt: + scan_outputs = selective_state_update( + cache_params.ssm_state, + hidden_states[..., 0], + discrete_time_step[..., 0], + self.A, + B[:, 0], + C[:, 0], + self.D, + gate[..., 0], + time_proj_bias, + dt_softplus=True, + ).unsqueeze(-1) + else: + scan_outputs, ssm_state = selective_scan_fn( + hidden_states, + discrete_time_step, + self.A, + B.transpose(1, 2), + C.transpose(1, 2), + self.D.float(), + gate, + time_proj_bias, + delta_softplus=True, + return_last_state=True, + ) + if ssm_state is not None and cache_params is not None: + cache_params.ssm_state.copy_(ssm_state) + + # 4. Final linear projection + contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))[0] + return contextualized_states + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + conv_state: torch.Tensor, + ssm_state: torch.Tensor, + ): + if attn_metadata.prefill_metadata is not None: + offset = 0 + for i, prompt_len in enumerate( + attn_metadata.prefill_metadata.seq_lens): + cache = MambaCacheParams(True, + conv_state=conv_state[i].unsqueeze(0), + ssm_state=ssm_state[i].unsqueeze(0)) + hidden_states[offset:offset + prompt_len].copy_( + self.mamba_forward(hidden_states[offset:offset + + prompt_len].unsqueeze(0), + cache_params=cache)[0]) + offset += prompt_len + else: + cache = MambaCacheParams(False, + conv_state=conv_state, + ssm_state=ssm_state) + hidden_states = self.mamba_forward(hidden_states.unsqueeze(1), + cache_params=cache) + hidden_states = hidden_states.squeeze(1) + + return hidden_states + + +class JambaMLP(nn.Module): + + def __init__( + self, + config: JambaConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + hidden_size = config.hidden_size + intermediate_size = config.intermediate_size + hidden_act = config.hidden_act + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class JambaMoE(nn.Module): + """A tensor-parallel MoE implementation for Mixtral that shards each expert + across all ranks. + + Each expert's weights are sharded across all ranks and a fused MoE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. 
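+
+    Adapted from the Mixtral MoE layer; Jamba uses it as the feed-forward
+    block on layers configured with more than one expert.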
+ """ + + def __init__( + self, + config: JambaConfig, + params_dtype: Optional[torch.dtype] = None, + tp_size: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.tp_size = tp_size or get_tensor_model_parallel_world_size() + self.num_total_experts = config.num_experts + self.top_k = config.num_experts_per_tok + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size // self.tp_size + + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + + self.router = ReplicatedLinear(self.hidden_size, + self.num_total_experts, + bias=False, + params_dtype=self.params_dtype) + + self.ws = nn.Parameter( + torch.empty( + self.num_total_experts, + 2 * self.intermediate_size, + self.hidden_size, + device="cuda", + dtype=self.params_dtype, + )) + self.w2s = nn.Parameter( + torch.empty( + self.num_total_experts, + self.hidden_size, + self.intermediate_size, + device="cuda", + dtype=self.params_dtype, + )) + + set_weight_attrs( + self.ws, + { + "weight_loader": self.weight_loader, + }, + ) + set_weight_attrs( + self.w2s, + { + "weight_loader": self.weight_loader, + }, + ) + + def weight_loader( + self, + param: nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + expert_id: int, + ): + tp_rank = get_tensor_model_parallel_rank() + param_data = param.data + shard_size = self.intermediate_size + shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) + if weight_name.endswith("gate_proj.weight"): + param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] + if weight_name.endswith("up_proj.weight"): + param_data[expert_id, + shard_size:2 * shard_size, :] = loaded_weight[shard, :] + if weight_name.endswith("down_proj.weight"): + param_data[expert_id, :, :] = loaded_weight[:, shard] + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_size = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + # router_logits: (batch * sequence_length, n_experts) + router_logits, _ = self.router(hidden_states) + + final_hidden_states = fused_moe( + hidden_states, + self.ws, + self.w2s, + router_logits, + self.top_k, + renormalize= + False, # Mixtral normalize the expert probs to 1. We don't! 
+ inplace=True, + ) + + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + return final_hidden_states.view(num_tokens, hidden_size) + + +class JambaMambaDecoderLayer(nn.Module): + + def __init__(self, + config: JambaConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None) -> None: + super().__init__() + self.layer_idx = layer_idx + self.config = config + self.mamba = JambaMambaMixer(config, layer_idx) + + num_experts = config.layers_num_experts[layer_idx] + ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP + self.feed_forward = ffn_layer_class(config, quant_config=quant_config) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_ff_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + conv_state: torch.Tensor, + ssm_state: torch.Tensor, + **kwargs, + ): + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + hidden_states = self.mamba(hidden_states, attn_metadata, conv_state, + ssm_state) + # Fully Connected + hidden_states, residual = self.pre_ff_layernorm( + hidden_states, residual) + hidden_states = self.feed_forward(hidden_states) + return hidden_states, residual + + +class JambaAttentionDecoderLayer(nn.Module): + + def __init__( + self, + config: JambaConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = config.hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + + self.qkv_proj = QKVParallelLinear( + config.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, + config.hidden_size, + bias=False, + quant_config=quant_config) + + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + ) + + num_experts = config.layers_num_experts[layer_idx] + ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP + self.feed_forward = ffn_layer_class(config, quant_config=quant_config) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_ff_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def self_attention( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + **kwargs, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + **kwargs, + ): + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + hidden_states = self.self_attention( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + # Fully Connected + hidden_states, residual = self.pre_ff_layernorm( + hidden_states, residual) + hidden_states = self.feed_forward(hidden_states) + return hidden_states, residual + + +ALL_DECODER_LAYER_TYPES = { + "attention": JambaAttentionDecoderLayer, + "mamba": JambaMambaDecoderLayer +} + + +class JambaModel(nn.Module): + + def __init__( + self, + config: JambaConfig, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + decoder_layers = [] + for i in range(config.num_hidden_layers): + layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[i]] + decoder_layers.append( + layer_class(config, + layer_idx=i, + cache_config=cache_config, + quant_config=quant_config)) + self.layers = nn.ModuleList(decoder_layers) + self.final_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, 
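+        # Stacked Mamba states, one slice per Mamba layer (indexed below):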
+ conv_state: torch.Tensor, + ssm_state: torch.Tensor, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + residual = None + + for i in range(len(self.layers)): + layer = self.layers[i] + kv_cache = None + current_ssm_state = None + current_conv_state = None + if isinstance(layer, JambaAttentionDecoderLayer): + kv_cache = kv_caches[(i - self.config.attn_layer_offset) // + self.config.attn_layer_period] + if isinstance(layer, JambaMambaDecoderLayer): + current_state_layer = i - (1 + + (i - self.config.attn_layer_offset) + // self.config.attn_layer_period) + current_ssm_state = ssm_state[current_state_layer] + current_conv_state = conv_state[current_state_layer] + + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + residual=residual, + conv_state=current_conv_state, + ssm_state=current_ssm_state, + ) + hidden_states, _ = self.final_layernorm(hidden_states, residual) + return hidden_states + + +class JambaForCausalLM(nn.Module): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "embed_tokens", + "lm_head", + ] + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__( + self, + config: JambaConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.model = JambaModel(config, + cache_config=cache_config, + quant_config=quant_config, + lora_config=lora_config) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + ) + # Current step used indices + self.current_indices: List[int] = [] + # Used to track and store by the Mamba cache between steps. + self.mamba_cache: Tuple[torch.Tensor, torch.Tensor] = tuple() + # Used as an input_buffer for the CUDA graph runs. 
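+        # (self.mamba_cache holds the persistent per-sequence state, while
+        # mamba_gc_cache_buffer is the fixed-address staging copy that the
+        # captured graphs read from and write to.)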
+ self.mamba_gc_cache_buffer: Tuple[torch.Tensor, torch.Tensor] = tuple() + # Maps between the request id and a dict that maps between the seq_id + # and its index inside the self.mamba_cache + self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {} + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = Sampler() + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs): + if not self.mamba_cache: + self._prepare_mamba_cache() + + if "seqlen_agnostic_capture_inputs" not in kwargs: + # We get here only on Prefill/Eager mode runs + assert all( + key in kwargs + for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) + + request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] + batch_size = input_ids.shape[0] + if attn_metadata.prefill_metadata: + batch_size = len(request_ids_to_seq_ids) + ( + current_seqlen_agnostic_cache, + indices, + ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids, + batch_size) + finished_requests_ids = kwargs["finished_requests_ids"] + self._release_mamba_cache(finished_requests_ids) + else: + # CUDA graph capturing runs + current_seqlen_agnostic_cache, indices = ( + kwargs["seqlen_agnostic_capture_inputs"], + [], + ) + self.current_indices = indices + + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, + current_seqlen_agnostic_cache[0], + current_seqlen_agnostic_cache[1]) + + if "seqlen_agnostic_capture_inputs" not in kwargs: + self._copy_mamba_cache_by_indices(self.current_indices, + current_seqlen_agnostic_cache) + + return hidden_states + + def _copy_mamba_cache_by_indices( + self, indices: List[int], + current_seqlen_agnostic_cache: Tuple[torch.Tensor, torch.Tensor]): + for i, offset in enumerate(indices): + self._copy_mamba_cache(offset, i, current_seqlen_agnostic_cache) + + def _copy_mamba_cache(self, index_to: int, index_from: int, + from_buffer: Tuple[torch.Tensor, torch.Tensor]): + assert len(self.mamba_cache) > 0 + for (cache_t, from_buffer_t) in zip(self.mamba_cache, from_buffer): + cache_t[:, index_to].copy_(from_buffer_t[:, index_from], + non_blocking=True) + + def _assign_seq_id_to_mamba_cache(self, cur_rid: str, + seqs_id: List[int]) -> List[int]: + indices_for_current_run = [] + for seq_id in seqs_id: + if cur_rid not in self.mamba_cache_indices_mapping: + self.mamba_cache_indices_mapping[cur_rid] = {} + first_free_index = self._first_free_index_in_mamba_cache() + self.mamba_cache_indices_mapping[cur_rid][ + seq_id] = first_free_index + index_for_current_run = first_free_index + ## case of decoding n>1, copy prefill cache to decoding indices + elif seq_id not in (seq_ids2indices := + self.mamba_cache_indices_mapping[cur_rid]): + first_free_index = self._first_free_index_in_mamba_cache() + index_exist = list(seq_ids2indices.values())[0] + self._copy_mamba_cache(index_from=index_exist, + index_to=first_free_index, + from_buffer=self.mamba_cache) + self.mamba_cache_indices_mapping[cur_rid][ + seq_id] = first_free_index + index_for_current_run = first_free_index + else: + index_for_current_run = self.mamba_cache_indices_mapping[ + cur_rid][seq_id] + + indices_for_current_run.append(index_for_current_run) + return indices_for_current_run + + def _prepare_current_run_mamba_cache( + self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int + ) -> Tuple[Tuple[torch.Tensor, 
torch.Tensor], List[int]]: + indices_for_current_run = [] + for request_id, seqs_id in request_ids_to_seq_ids.items(): + indices_for_current_run += self._assign_seq_id_to_mamba_cache( + request_id, seqs_id) + ## Pad the batch in case of running batch that was not captured via CG + padded_indices = indices_for_current_run.copy() + pad_index = self._first_free_index_in_mamba_cache() + + for _ in range(batch_size - len(indices_for_current_run)): + padded_indices.append(pad_index) + + conv_state = self.mamba_cache[0][:, padded_indices] + temporal_state = self.mamba_cache[1][:, padded_indices] + + return (conv_state, temporal_state), indices_for_current_run + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + """ + Copy the relevant Mamba cache into the CUDA graph input buffer + that was provided during the capture runs + (JambaForCausalLM.mamba_gc_cache_buffer). + """ + assert all( + key in kwargs + for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) + request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] + batch_size = len(request_ids_to_seq_ids) + ( + current_mamba_cache, + indices, + ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids, + batch_size) + self.current_indices = indices + finished_requests_ids = kwargs["finished_requests_ids"] + self._release_mamba_cache(finished_requests_ids) + + for input_buffer, current_cache_buffer in zip( + input_buffers["seqlen_agnostic_capture_inputs"], + current_mamba_cache): + input_buffer.copy_(current_cache_buffer, non_blocking=True) + + def copy_outputs_after_cuda_graphs(self, input_buffers, **kwargs): + """ + Copy the relevant Mamba cache from the CUDA graph input_buffers + back to the JambaForCausalLM.mamba_cache after CUDA + graph replay run is done. + """ + self._copy_mamba_cache_by_indices( + self.current_indices, + input_buffers["seqlen_agnostic_capture_inputs"]) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + """ + Provide the CUDA graph capture runs with a buffer in adjusted size. + The buffer is used to maintain the Mamba Cache during the CUDA graph + replay runs. 
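+        The returned tensors are views of `mamba_gc_cache_buffer` sliced to
+        `batch_size`, so the captured graph always reads and writes the same
+        memory.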
+ """ + return tuple(buffer[:, :batch_size] + for buffer in self.mamba_gc_cache_buffer) + + def _release_mamba_cache(self, finished_seq_groups_req_ids: List[str]): + for req_id in finished_seq_groups_req_ids: + if req_id in self.mamba_cache_indices_mapping: + self.mamba_cache_indices_mapping.pop(req_id) + + def _first_free_index_in_mamba_cache(self) -> int: + if self.mamba_cache: + max_possible_batch_size = self.mamba_cache[0].shape[1] + occupied = [ + id for seq_ids in self.mamba_cache_indices_mapping.values() + for id in seq_ids.values() + ] + first_free_index = [ + i not in occupied for i in range(max_possible_batch_size) + ].index(True) + return first_free_index + return 0 + + def _get_mamba_cache_shape( + self + ) -> Tuple[Optional[Tuple[int, int]], Optional[Tuple[int, int]]]: + world_size = get_tensor_model_parallel_world_size() + hidden_size = self.config.hidden_size + conv_state_shape = ( + self.config.mamba_expand * hidden_size // world_size, + self.config.mamba_d_conv, + ) + temporal_state_shape = ( + self.config.mamba_expand * self.config.hidden_size // world_size, + self.config.mamba_d_state, + ) + return conv_state_shape, temporal_state_shape + + def _prepare_mamba_cache(self): + dtype = self.lm_head.weight.dtype + layers_type = self.config.layers_block_type + mamba_layers = sum( + [layer_type == "mamba" for layer_type in layers_type]) + max_batch_size = _BATCH_SIZES_TO_CAPTURE[-1] + 10 + conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape() + assert conv_state_shape is not None and temporal_state_shape is not None + for buffername in ["mamba_cache", "mamba_gc_cache_buffer"]: + buffer = (torch.empty(size=(mamba_layers, max_batch_size) + + conv_state_shape, + dtype=dtype, + device="cuda"), + torch.empty(size=(mamba_layers, max_batch_size) + + temporal_state_shape, + dtype=dtype, + device="cuda")) + setattr(self, buffername, buffer) + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head.weight, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + expert_params_mapping = [ + # (param_name, weight_name, expert_id) + ( + "ws" if weight_name in ["gate_proj", "up_proj"] else "w2s", + f"experts.{expert_id}.{weight_name}.weight", + expert_id, + ) for expert_id in range(self.config.num_experts) + for weight_name in ["down_proj", "up_proj", "gate_proj"] + ] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + if "A_log" in name: + name = name.replace("A_log", "A") + + if ".self_attn." in name: + name = name.replace(".self_attn", "") + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + if 'experts' in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for param_name, weight_name, expert_id in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + weight_name, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/sequence.py b/vllm/sequence.py index b036e76d7ccec..7e08586cdfd93 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -934,6 +934,8 @@ class ExecuteModelRequest: previous_hidden_states: Optional[HiddenStates] = None # The number of forward steps to run. num_steps: int = 1 + # Finished request ids since last step. + finished_requests_ids: List[str] = field(default_factory=list) def clone( self, seq_group_metadata_list: List[SequenceGroupMetadata] @@ -949,4 +951,4 @@ def clone( running_queue_size=self.running_queue_size, previous_hidden_states=self.previous_hidden_states, num_steps=self.num_steps, - ) + finished_requests_ids=self.finished_requests_ids) diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index b4c953162e2b4..1c7b8c07e89e5 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -75,15 +75,19 @@ def __init__( List[SequenceGroupMetadata]] = None def prepare_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - virtual_engine: int = 0) -> ModelInputForGPUWithSamplingMetadata: + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None + ) -> ModelInputForGPUWithSamplingMetadata: """A temporary solution that caches the seq_group_metadata_list for multi-step execution. TODO: In-place update model_input and remove this function. """ self.cached_seq_group_metadata_list = seq_group_metadata_list - return super().prepare_model_input(seq_group_metadata_list) + return super().prepare_model_input( + seq_group_metadata_list, + finished_requests_ids=finished_requests_ids) def update_model_input( self, model_input: ModelInputForGPUWithSamplingMetadata, diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 891e74f8ab940..252440c7b7e08 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -33,7 +33,9 @@ def __init__( self.device_config = device_config self.head_size = model_config.get_head_size() - self.num_layers = model_config.get_num_layers(parallel_config) + # Models like Jamba, have mixed typed layers, E.g Mamba + self.num_attention_layers = model_config.get_num_attention_layers( + parallel_config) self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) self.block_size = cache_config.block_size @@ -75,7 +77,7 @@ def _allocate_kv_cache( num_blocks, self.block_size, self.num_kv_heads, self.head_size) pin_memory = is_pin_memory_available() if device == "cpu" else False kv_cache: List[torch.Tensor] = [] - for _ in range(self.num_layers): + for _ in range(self.num_attention_layers): # null block in CpuGpuBlockAllocator requires at least that # block to be zeroed-out. 
# We zero-out everything for simplicity. @@ -87,12 +89,12 @@ def _allocate_kv_cache( return kv_cache def swap_in(self, src_to_dst: torch.Tensor) -> None: - for i in range(self.num_layers): + for i in range(self.num_attention_layers): self.attn_backend.swap_blocks(self.cpu_cache[i], self.gpu_cache[i], src_to_dst) def swap_out(self, src_to_dst: torch.Tensor) -> None: - for i in range(self.num_layers): + for i in range(self.num_attention_layers): self.attn_backend.swap_blocks(self.gpu_cache[i], self.cpu_cache[i], src_to_dst) @@ -107,11 +109,12 @@ def get_cache_block_size( ) -> int: head_size = model_config.get_head_size() num_heads = model_config.get_num_kv_heads(parallel_config) - num_layers = model_config.get_num_layers(parallel_config) + num_attention_layers = model_config.get_num_attention_layers( + parallel_config) key_cache_block = cache_config.block_size * num_heads * head_size value_cache_block = key_cache_block - total = num_layers * (key_cache_block + value_cache_block) + total = num_attention_layers * (key_cache_block + value_cache_block) if cache_config.cache_dtype == "auto": dtype = model_config.dtype else: diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index f46e9e8aba9db..fd6c2b8546dfb 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -314,9 +314,10 @@ def make_model_input_from_broadcasted_tensor_dict( ) def prepare_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - virtual_engine: int = 0, + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None ) -> CPUModelInput: multi_modal_kwargs = None # NOTE: We assume that all sequences in the group are all prompts or diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index faf6e99ab646f..0e1bb1bfe273d 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -120,10 +120,11 @@ def prepare_model_input( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None ) -> ModelInputForGPUWithPoolingMetadata: assert seq_group_metadata_list is not None model_input = self._prepare_model_input_tensors( - seq_group_metadata_list) + seq_group_metadata_list, finished_requests_ids) # Prepare PoolingMetadata. 
assert model_input.seq_lens is not None pooling_metadata = self._prepare_pooling(seq_group_metadata_list, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 28b447c0dc8a9..bd30281471d19 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -84,6 +84,8 @@ class ModelInputForGPU(ModelRunnerInputBase): lora_requests: Optional[Set[LoRARequest]] = None attn_metadata: Optional["AttentionMetadata"] = None multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None + request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None + finished_requests_ids: Optional[List[str]] = None virtual_engine: int = 0 def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: @@ -94,6 +96,8 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: "lora_mapping": self.lora_mapping, "multi_modal_kwargs": self.multi_modal_kwargs, "virtual_engine": self.virtual_engine, + "request_ids_to_seq_ids": self.request_ids_to_seq_ids, + "finished_requests_ids": self.finished_requests_ids, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) return tensor_dict @@ -128,6 +132,8 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: "lora_mapping": self.lora_mapping, "multi_modal_kwargs": self.multi_modal_kwargs, "virtual_engine": self.virtual_engine, + "request_ids_to_seq_ids": self.request_ids_to_seq_ids, + "finished_requests_ids": self.finished_requests_ids, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) _add_sampling_metadata_broadcastable_dict(tensor_dict, @@ -191,6 +197,10 @@ def __init__( ] self.graph_memory_pool: Optional[Tuple[ int, int]] = None # Set during graph capture. + + self.has_seqlen_agnostic = model_config.contains_seqlen_agnostic_layers( + parallel_config) + # When using CUDA graph, the input block tables must be padded to # max_seq_len_to_capture. However, creating the block table in # Python can be expensive. To optimize this, we cache the block table @@ -317,6 +327,7 @@ def get_max_block_per_batch(self) -> int: def _prepare_model_input_tensors( self, seq_group_metadata_list: List[SequenceGroupMetadata], + finished_requests_ids: Optional[List[str]] = None ) -> TModelInputForGPU: """Helper method to prepare the model input based on a given sequence group. Prepares metadata needed for the base model forward pass but not @@ -347,6 +358,7 @@ def _prepare_model_input_tensors( block_tables: List[List[int]] = [] multi_modal_kwargs_list: Dict[str, List[torch.Tensor]] = defaultdict(list) + request_ids_to_seq_ids: Dict[str, List[int]] = defaultdict(list) decode_only = True num_prefills = 0 num_prefill_tokens = 0 @@ -738,7 +750,11 @@ def _prepare_model_input_tensors( k: torch.cat(v, dim=0).to(self.device) for k, v in multi_modal_kwargs_list.items() } - + request_ids_to_seq_ids = { + seq_group_metadata.request_id: + list(seq_group_metadata.seq_data.keys()) + for seq_group_metadata in seq_group_metadata_list + } return self._model_input_cls( input_tokens=input_tokens_tensor, input_positions=input_positions_tensor, @@ -748,7 +764,8 @@ def _prepare_model_input_tensors( lora_mapping=lora_mapping, lora_requests=lora_requests, multi_modal_kwargs=multi_modal_kwargs, - ) + request_ids_to_seq_ids=request_ids_to_seq_ids, + finished_requests_ids=finished_requests_ids) @torch.inference_mode() def profile_run(self) -> None: @@ -821,7 +838,9 @@ def profile_run(self) -> None: # Run the model with the dummy inputs. 
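+        # All profiling sequences are reported as finished so that any
+        # seqlen-agnostic (Mamba) cache slots allocated for the dummy run
+        # are released right after it.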
num_layers = self.model_config.get_num_layers(self.parallel_config) kv_caches = [None] * num_layers - model_input = self.prepare_model_input(seqs) + finished_requests_ids = [seq.request_id for seq in seqs] + model_input = self.prepare_model_input( + seqs, finished_requests_ids=finished_requests_ids) intermediate_tensors = None if not get_pp_group().is_first_rank: intermediate_tensors = self.model.make_empty_intermediate_tensors( @@ -1033,21 +1052,37 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: graph_runner.flashinfer_decode_wrapper = \ decode_wrapper - graph_runner.capture( + capture_inputs = { + "input_ids": input_tokens[:batch_size], + "positions": input_positions[:batch_size], + "hidden_or_intermediate_states": hidden_or_intermediate_states[ virtual_engine] # type: ignore [:batch_size] if hidden_or_intermediate_states[virtual_engine] is not None else None, + "intermediate_inputs": intermediate_inputs[:batch_size] if intermediate_inputs is not None else None, + "kv_caches": kv_caches[virtual_engine], + "attn_metadata": attn_metadata, - memory_pool=self.graph_memory_pool, - stream=graph_capture_context.stream, - ) + "memory_pool": + self.graph_memory_pool, + "stream": + graph_capture_context.stream + } + if self.has_seqlen_agnostic: + # Only used by Mamba-based models CUDA graph atm (Jamba) + capture_inputs.update({ + "seqlen_agnostic_capture_inputs": + self.model.get_seqlen_agnostic_capture_inputs( + batch_size) + }) + graph_runner.capture(**capture_inputs) self.graph_memory_pool = graph_runner.graph.pool() self.graph_runners[virtual_engine][batch_size] = ( graph_runner) @@ -1084,6 +1119,7 @@ def prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None ) -> ModelInputForGPUWithSamplingMetadata: """Prepare the model input based on a given sequence group, including metadata for the sampling step. @@ -1099,7 +1135,7 @@ def prepare_model_input( If cuda graph is required, this API automatically pads inputs. """ model_input = self._prepare_model_input_tensors( - seq_group_metadata_list) + seq_group_metadata_list, finished_requests_ids) sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, model_input.seq_lens, model_input.query_lens, @@ -1175,6 +1211,10 @@ def execute_model( model_executable = self.model multi_modal_kwargs = model_input.multi_modal_kwargs or {} + seqlen_agnostic_kwargs = { + "finished_requests_ids": model_input.finished_requests_ids, + "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, + } if self.has_seqlen_agnostic else {} hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, @@ -1182,7 +1222,7 @@ def execute_model( attn_metadata=model_input.attn_metadata, intermediate_tensors=intermediate_tensors, **multi_modal_kwargs, - ) + **seqlen_agnostic_kwargs) # Compute the logits in the last pipeline stage. 
if not get_pp_group().is_last_rank: @@ -1305,6 +1345,7 @@ def capture( "positions": positions, "kv_caches": kv_caches, "slot_mapping": attn_metadata.slot_mapping, + **kwargs, } else: self.input_buffers = { @@ -1315,6 +1356,7 @@ def capture( "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor, "block_tables": attn_metadata.decode_metadata.block_tables, + **kwargs, } if intermediate_inputs is not None: self.input_buffers.update(intermediate_inputs.tensors) @@ -1349,13 +1391,18 @@ def forward( non_blocking=True) self.input_buffers["block_tables"].copy_( attn_metadata.decode_metadata.block_tables, non_blocking=True) + if "seqlen_agnostic_capture_inputs" in self.input_buffers: + self.model.copy_inputs_before_cuda_graphs(self.input_buffers, + **kwargs) if intermediate_tensors is not None: for key in intermediate_tensors.tensors: self.input_buffers[key].copy_(intermediate_tensors[key], non_blocking=True) # Run the graph. self.graph.replay() - + if "seqlen_agnostic_capture_inputs" in self.input_buffers: + self.model.copy_outputs_after_cuda_graphs(self.input_buffers, + **kwargs) # Return the output tensor. if get_pp_group().is_last_rank: return self.output_buffers["hidden_states"] diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index f66bb466228be..bc0960fa16221 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -139,6 +139,7 @@ def prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None, ) -> T: """ Prepare the inputs to ModelRunnerBase.execute_model from an execution diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index ab8e485281293..8b96966be4704 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -177,6 +177,7 @@ def prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None ) -> ModelInputForNeuron: # NOTE: We assume that all sequences in the group are all prompts or # all decodes. diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 118173a4ca94b..b082f45344863 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -234,7 +234,8 @@ def execute_model( model_input: ModelRunnerInputBase = ( self.model_runner.prepare_model_input( execute_model_req.seq_group_metadata_list, - execute_model_req.virtual_engine)) + execute_model_req.virtual_engine, + execute_model_req.finished_requests_ids)) num_steps = execute_model_req.num_steps if self.do_metadata_broadcast: diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 73b771c4395f8..e652f1b1042e3 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -189,9 +189,10 @@ def make_model_input_from_broadcasted_tensor_dict( )) def prepare_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - virtual_engine: int = 0, + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None ) -> ModelInputForXPU: multi_modal_input = None if self.is_driver_worker: