diff --git a/.buildkite/run-cpu-test-ppc64le.sh b/.buildkite/run-cpu-test-ppc64le.sh
index 49ae838cf0690..fd60f5b6afeca 100755
--- a/.buildkite/run-cpu-test-ppc64le.sh
+++ b/.buildkite/run-cpu-test-ppc64le.sh
@@ -18,7 +18,13 @@ docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/hugg
 # Run basic model test
 docker exec cpu-test bash -c "
   pip install pytest matplotlib einops transformers_stream_generator
-  pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_oot_registration.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py --ignore=tests/models/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported
+  pytest -v -s tests/models -m \"not vlm\" \
+    --ignore=tests/models/test_embedding.py \
+    --ignore=tests/models/test_oot_registration.py \
+    --ignore=tests/models/test_registry.py \
+    --ignore=tests/models/test_jamba.py \
+    --ignore=tests/models/test_mamba.py \
+    --ignore=tests/models/test_danube3_4b.py" # Mamba kernels and Danube3-4B are not supported on CPU

 # online inference
 docker exec cpu-test bash -c "
diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh
index 62d3afb0212fd..c2818c38965ea 100644
--- a/.buildkite/run-cpu-test.sh
+++ b/.buildkite/run-cpu-test.sh
@@ -27,6 +27,7 @@ docker exec cpu-test bash -c "
   pytest -v -s tests/models/decoder_only/language \
     --ignore=tests/models/test_fp8.py \
     --ignore=tests/models/decoder_only/language/test_jamba.py \
+    --ignore=tests/models/decoder_only/language/test_mamba.py \
     --ignore=tests/models/decoder_only/language/test_granitemoe.py \
     --ignore=tests/models/decoder_only/language/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported
diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst
index ec64a82de84d4..f5d53edcebd35 100644
--- a/docs/source/models/supported_models.rst
+++ b/docs/source/models/supported_models.rst
@@ -152,6 +152,11 @@ Text Generation
     - :code:`meta-llama/Meta-Llama-3.1-405B-Instruct`, :code:`meta-llama/Meta-Llama-3.1-70B`, :code:`meta-llama/Meta-Llama-3-70B-Instruct`, :code:`meta-llama/Llama-2-70b-hf`, :code:`01-ai/Yi-34B`, etc.
     - ✅︎
     - ✅︎
+  * - :code:`MambaForCausalLM`
+    - Mamba
+    - :code:`state-spaces/mamba-130m-hf`, :code:`state-spaces/mamba-790m-hf`, :code:`state-spaces/mamba-2.8b-hf`, etc.
+    - ✅︎
+    -
   * - :code:`MiniCPMForCausalLM`
     - MiniCPM
     - :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, etc.
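Note: the MambaForCausalLM entry added to supported_models.rst above can be exercised through vLLM's existing offline inference API. A minimal sketch follows; the model name comes from the table, while the prompt and sampling settings are illustrative only and not part of this change:

    from vllm import LLM, SamplingParams

    # Greedy decoding (temperature=0.0) mirrors the HF-vs-vLLM comparisons in the new tests.
    llm = LLM(model="state-spaces/mamba-130m-hf")
    params = SamplingParams(temperature=0.0, max_tokens=32)
    outputs = llm.generate(["The theory of special relativity states that"], params)
    print(outputs[0].outputs[0].text)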
diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index c1fb45955a0e5..f471dcee938be 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -20,22 +20,22 @@ def test_env(name: str, device: str, monkeypatch): if device == "cpu": with patch("vllm.attention.selector.is_cpu", return_value=True): - backend = which_attn_to_use(8, 16, 8, None, torch.float16, - torch.float16, 16) + backend = which_attn_to_use(16, None, torch.float16, torch.float16, + 16, False) assert backend.name == "TORCH_SDPA" elif device == "hip": with patch("vllm.attention.selector.is_hip", return_value=True): - backend = which_attn_to_use(8, 16, 8, None, torch.float16, - torch.float16, 16) + backend = which_attn_to_use(16, None, torch.float16, torch.float16, + 16, False) assert backend.name == "ROCM_FLASH" elif device == "openvino": with patch("vllm.attention.selector.is_openvino", return_value=True): - backend = which_attn_to_use(8, 16, 8, None, torch.float16, - torch.float16, 16) + backend = which_attn_to_use(16, None, torch.float16, torch.float16, + 16, False) assert backend.name == "OPENVINO" else: - backend = which_attn_to_use(8, 16, 8, None, torch.float16, - torch.float16, 16) + backend = which_attn_to_use(16, None, torch.float16, torch.float16, 16, + False) assert backend.name == name @@ -46,32 +46,37 @@ def test_flash_attn(monkeypatch): # Unsupported CUDA arch with patch("torch.cuda.get_device_capability", return_value=(7, 5)): - backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) + backend = which_attn_to_use(16, None, torch.float16, None, 16, False) assert backend.name != STR_FLASH_ATTN_VAL # Unsupported data type - backend = which_attn_to_use(8, 16, 8, None, torch.float8_e4m3fn, None, 16) + backend = which_attn_to_use(16, None, torch.float8_e4m3fn, None, 16, False) assert backend.name != STR_FLASH_ATTN_VAL # Unsupported kv cache data type - backend = which_attn_to_use(8, 16, 8, None, torch.float16, "fp8", 16) + backend = which_attn_to_use(16, None, torch.float16, "fp8", 16, False) assert backend.name != STR_FLASH_ATTN_VAL # Unsupported block size - backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 8) + backend = which_attn_to_use(16, None, torch.float16, None, 8, False) assert backend.name != STR_FLASH_ATTN_VAL # Unsupported sliding window - backend = which_attn_to_use(8, 16, 8, 1, torch.float16, None, 16) + backend = which_attn_to_use(16, 1, torch.float16, None, 16, False) assert backend.name != STR_FLASH_ATTN_VAL # flash-attn is not installed with patch.dict('sys.modules', {'vllm_flash_attn': None}): - backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) + backend = which_attn_to_use(16, None, torch.float16, None, 16, False) assert backend.name != STR_FLASH_ATTN_VAL # Unsupported head size - backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16) + backend = which_attn_to_use(17, None, torch.float16, None, 16, False) + assert backend.name != STR_FLASH_ATTN_VAL + + # Attention-free models should bypass env and use PlaceholderAttention + backend = which_attn_to_use(16, None, torch.float16, torch.float16, 16, + True) assert backend.name != STR_FLASH_ATTN_VAL @@ -79,4 +84,4 @@ def test_invalid_env(monkeypatch): """Throw an exception if the backend name is invalid.""" override_backend_env_variable(monkeypatch, STR_INVALID_VAL) with pytest.raises(ValueError): - which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) + which_attn_to_use(16, None, torch.float16, 
None, 16, False) diff --git a/tests/models/decoder_only/language/test_mamba.py b/tests/models/decoder_only/language/test_mamba.py new file mode 100644 index 0000000000000..c27bf6a60a4f4 --- /dev/null +++ b/tests/models/decoder_only/language/test_mamba.py @@ -0,0 +1,295 @@ +"""Compare the outputs of HF and vLLM when using greedy sampling for Mamba. + +Run `pytest tests/models/decoder_only/language/test_mamba.py`. +""" +import pytest +from transformers import AutoModelForCausalLM, AutoTokenizer + +from vllm.sampling_params import SamplingParams +from vllm.worker.model_runner import _get_graph_batch_size + +from ...utils import check_outputs_equal + +MODELS = ["state-spaces/mamba-130m-hf"] + + +# Use lower-level interfaces to create this greedy generator, as Mamba will +# choke on the model_kwarg 'attention_mask' if hf_model.generate_greedy is used. +def generate_greedy(model_name, example_prompts, max_tokens): + # Load the tokenizer and model + tokenizer = AutoTokenizer.from_pretrained(model_name) + model = AutoModelForCausalLM.from_pretrained(model_name) + + # Generate texts from the prompts + outputs = [] + for prompt in example_prompts: + # Tokenize the input prompt with truncation + inputs = tokenizer(prompt, return_tensors="pt", truncation=True) + input_ids = inputs["input_ids"].to(model.device) + + # Generate text using the model's generate method directly + generated_ids = model.generate(input_ids, max_new_tokens=max_tokens) + generated_text = tokenizer.decode(generated_ids[0], + skip_special_tokens=True) + + outputs.append((generated_ids[0].tolist(), generated_text)) + + return outputs + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [96]) +def test_models( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + hf_outputs = generate_greedy(model, example_prompts, max_tokens) + + with vllm_runner(model, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + + for i in range(len(example_prompts)): + hf_output_ids, hf_output_str = hf_outputs[i] + vllm_output_ids, vllm_output_str = vllm_outputs[i] + assert hf_output_str == vllm_output_str, ( + f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") + assert hf_output_ids == vllm_output_ids, ( + f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [96]) +def test_batching( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + # To pass the small model tests, we need full precision. + for_loop_outputs = [] + with vllm_runner(model, dtype=dtype) as vllm_model: + for prompt in example_prompts: + for_loop_outputs.append( + vllm_model.generate_greedy([prompt], max_tokens)[0]) + + batched_outputs = vllm_model.generate_greedy(example_prompts, + max_tokens) + + check_outputs_equal( + outputs_0_lst=for_loop_outputs, + outputs_1_lst=batched_outputs, + name_0="for_loop_vllm", + name_1="batched_vllm", + ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [10]) +def test_chunked_prefill_with_parallel_sampling(vllm_runner, example_prompts, + model: str, dtype: str, + max_tokens: int) -> None: + # Tests chunked prefill in conjunction with n>1. In this case, prefill is + # populated with decoding tokens and we test that it doesn't fail.
+ # This test might fail if cache is not allocated correctly for n > 1 + # decoding steps inside a chunked prefill forward pass (where we have both + # prefill and decode together ) + sampling_params = SamplingParams(n=3, + temperature=1, + seed=0, + max_tokens=max_tokens) + with vllm_runner( + model, + dtype=dtype, + enable_chunked_prefill=True, + max_num_batched_tokens=30, + max_num_seqs=10 # forces prefill chunks with decoding + ) as vllm_model: + vllm_model.generate(example_prompts, sampling_params) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [32]) +@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16]) +def test_chunked_prefill(vllm_runner, example_prompts, model: str, dtype: str, + max_tokens: int, + chunked_prefill_token_size: int) -> None: + """ + Checks exact match decode between huggingface model and vllm runner with + chunked prefill. + """ + max_num_seqs = chunked_prefill_token_size + max_num_batched_tokens = chunked_prefill_token_size + + non_chunked = generate_greedy(model, example_prompts, max_tokens) + + with vllm_runner(model, + dtype=dtype, + enable_chunked_prefill=True, + max_num_batched_tokens=max_num_batched_tokens, + max_num_seqs=max_num_seqs) as vllm_model: + chunked = vllm_model.generate_greedy(example_prompts, + max_tokens=max_tokens) + + check_outputs_equal( + outputs_0_lst=chunked, + outputs_1_lst=non_chunked, + name_0="chunked", + name_1="non_chunked", + ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [15]) +def test_parallel_sampling( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + + with vllm_runner(model, dtype=dtype) as vllm_model: + for_loop_outputs = [] + for _ in range(10): + for_loop_outputs.append( + # using example_prompts index 1 instead of 0 since with 0 the + # logprobs get really close and the test doesn't pass + vllm_model.generate_greedy([example_prompts[1]], max_tokens) + [0]) + sampling_params = SamplingParams(n=10, + temperature=0.001, + seed=0, + max_tokens=max_tokens) + n_lt_1_outputs = vllm_model.generate([example_prompts[1]], + sampling_params) + token_ids, texts = n_lt_1_outputs[0] + n_lt_1_outputs = [(token_id, text) + for token_id, text in zip(token_ids, texts)] + + check_outputs_equal( + outputs_0_lst=n_lt_1_outputs, + outputs_1_lst=for_loop_outputs, + name_0="vllm_n_lt_1_outputs", + name_1="vllm", + ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [20]) +def test_mamba_cache_cg_padding( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + # This test is for verifying that mamba cache is padded to CG captured + # batch size. If it's not, a torch RuntimeError will be raised because + # tensor dimensions aren't compatible + while len(example_prompts) == _get_graph_batch_size(len(example_prompts)): + example_prompts.append(example_prompts[0]) + + try: + with vllm_runner(model, dtype=dtype) as vllm_model: + vllm_model.generate_greedy(example_prompts, max_tokens) + except RuntimeError: + pytest.fail( + "Couldn't run batch size which is not equal to a Cuda Graph " + "captured batch size. 
" + "Could be related to mamba cache not padded correctly") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [20]) +def test_models_preemption_recompute( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + # Tests that outputs are identical with and w/o preemtions (recompute) + assert dtype == "float" + + with vllm_runner(model, dtype=dtype) as vllm_model: + vllm_model.model.llm_engine.scheduler[ + 0].ENABLE_ARTIFICIAL_PREEMPT = True + preempt_vllm_outputs = vllm_model.generate_greedy( + example_prompts, max_tokens) + + vllm_model.model.llm_engine.scheduler[ + 0].ENABLE_ARTIFICIAL_PREEMPT = False + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + + check_outputs_equal( + outputs_0_lst=preempt_vllm_outputs, + outputs_1_lst=vllm_outputs, + name_0="vllm_preepmtions", + name_1="vllm", + ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks( + vllm_runner, + model: str, + dtype: str, + example_prompts, +) -> None: + # This test is for verifying that the Mamba inner state management doesn't + # collapse in case where the number of incoming requests and + # finished_requests_ids is larger than the maximum Mamba block capacity. + # This could generally happen due to the fact that Mamba does support + # statelessness mechanism where it can cleanup new incoming requests in + # a single step. + try: + with vllm_runner(model, dtype=dtype, max_num_seqs=10) as vllm_model: + vllm_model.generate_greedy([example_prompts[0]] * 100, 10) + except ValueError: + pytest.fail("Mamba inner state wasn't cleaned up properly between" + "steps finished requests registered unnecessarily ") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +def test_state_cleanup( + vllm_runner, + model: str, + dtype: str, + example_prompts, +) -> None: + # This test is for verifying that the Mamba state is cleaned up between + # steps, If its not cleaned, an error would be expected. + try: + with vllm_runner(model, dtype=dtype) as vllm_model: + for _ in range(10): + vllm_model.generate_greedy([example_prompts[0]] * 100, 1) + except ValueError: + pytest.fail("Mamba inner state wasn't cleaned up between states, " + "could be related to finished_requests_ids") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +def test_model_print( + vllm_runner, + model: str, + dtype: str, +) -> None: + with vllm_runner(model, dtype=dtype) as vllm_model: + # This test is for verifying whether the model's extra_repr + # can be printed correctly. + print(vllm_model.model.llm_engine.model_executor.driver_worker. 
+ model_runner.model) diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py new file mode 100644 index 0000000000000..99c68a863f599 --- /dev/null +++ b/vllm/attention/backends/placeholder_attn.py @@ -0,0 +1,324 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, List, Optional, Tuple, Type + +import torch + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata, + AttentionMetadataBuilder) +from vllm.attention.backends.utils import CommonAttentionState + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUBuilder + +# Placeholder attention backend for models like Mamba and embedding models that +# lack attention. + + +class PlaceholderAttentionBackend(AttentionBackend): + """Placeholder backend for when no attention is needed.""" + + @staticmethod + def get_name() -> str: + return "placeholder-attn" + + @staticmethod + def get_impl_cls() -> Type["PlaceholderAttentionImpl"]: + return PlaceholderAttentionImpl + + @staticmethod + def get_builder_cls() -> Type["PlaceholderAttentionMetadataBuilder"]: + return PlaceholderAttentionMetadataBuilder + + @staticmethod + def get_metadata_cls() -> Type["PlaceholderAttentionMetadata"]: + return PlaceholderAttentionMetadata + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (1, 1, 1, 1, 1) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + return + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + return + + +@dataclass +class PlaceholderAttentionMetadata(AttentionMetadata): + """Attention metadata for prefill and decode batched together.""" + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + + # Maximum query length in the batch. + max_query_len: Optional[int] + + # Number of query tokens for each request in the batch. + # Currently, we require that all requests have the same number of query + # tokens during the decoding phase. When speculavie decoding is enabled, + # decode_query_len might be greater than 1. In all other cases, it is 1. + decode_query_len: Optional[int] + + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] + + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. 
(Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. + block_tables: Optional[torch.Tensor] + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool + + _cached_prefill_metadata: Optional["PlaceholderAttentionMetadata"] = None + _cached_decode_metadata: Optional["PlaceholderAttentionMetadata"] = None + + @property + def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + assert self.query_start_loc is not None + assert self.context_lens_tensor is not None + assert self.seq_start_loc is not None + + # Placeholders + slot_mapping = torch.empty(0) + block_tables = torch.empty(0) + + self._cached_prefill_metadata = PlaceholderAttentionMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=slot_mapping, + seq_lens=self.seq_lens[:self.num_prefills], + seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + decode_query_len=0, + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], + context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + block_tables=block_tables, + use_cuda_graph=False, + ) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.seq_lens_tensor is not None + + # Placeholders + slot_mapping = torch.empty(0) + block_tables = torch.empty(0) + + self._cached_decode_metadata = PlaceholderAttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=slot_mapping, + seq_lens=None, + seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + decode_query_len=self.decode_query_len, + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=block_tables, + use_cuda_graph=self.use_cuda_graph, + ) + return self._cached_decode_metadata + + +class PlaceholderAttentionMetadataBuilder( + AttentionMetadataBuilder[PlaceholderAttentionMetadata]): + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + self.prefill_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.curr_seq_lens: List[int] = [] + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + + self.input_builder = input_builder + self.runner = input_builder.runner + + def _add_seq_group( + self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool): + """Add a sequence group to the metadata. Specifically update/append + 1. context length. 
+ """ + is_prompt = inter_data.is_prompt + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block) in zip( + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks): + self.context_lens.append(context_len) + + if is_prompt: + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + assert query_len == 1, ( + "seq_len: {}, context_len: {}, query_len: {}".format( + seq_len, context_len, query_len)) + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + """Build attention metadata with on-device tensors. + + Args: + seq_lens: The maybe padded sequence lengths of the input sequences. + query_lens: The query lengths of the input sequences. + cuda_graph_pad_size: The padding size for cuda graph. + -1 if cuda graph is not used. + batch_size: The maybe padded batch size. + """ + for inter_data in self.input_builder.inter_data_list: + self._add_seq_group(inter_data, + self.input_builder.chunked_prefill_enabled) + + device = self.runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + logits_soft_cap = getattr(self.runner.model_config.hf_config, + "attn_logit_softcapping", None) + if logits_soft_cap is not None: + raise ValueError( + "Please use Flashinfer backend for models with logits_soft_cap" + " (i.e., Gemma-2). Otherwise, the output might be wrong." + " Set Flashinfer backend by " + "export VLLM_ATTENTION_BACKEND=FLASHINFER.") + + max_query_len = max(query_lens) + decode_query_lens = query_lens[self.num_prefills:] + if len(decode_query_lens) > 0: + decode_query_len = max(decode_query_lens) + else: + decode_query_len = 1 + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + max_decode_seq_len = max(self.curr_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + + if use_captured_graph: + num_decode_tokens = batch_size + + assert max_query_len > 0, ("query_lens: {}".format(query_lens)) + + context_lens_tensor = torch.tensor(self.context_lens, + dtype=torch.int, + device=device) + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=device) + query_lens_tensor = torch.tensor(query_lens, + dtype=torch.long, + device=device) + query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + torch.cumsum(seq_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + torch.cumsum(query_lens_tensor, + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:]) + + # Placeholders + slot_mapping = torch.empty(0) + block_tables = torch.empty(0) + + return PlaceholderAttentionMetadata( + num_prefills=self.num_prefills, + slot_mapping=slot_mapping, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + decode_query_len=decode_query_len, + max_prefill_seq_len=max_prefill_seq_len, + max_decode_seq_len=max_decode_seq_len, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=use_captured_graph, + ) + + +class 
PlaceholderAttentionImpl(AttentionImpl): + + def __init__(self, *args, **kwargs) -> None: + return + + def forward(self, *args, **kwargs) -> torch.Tensor: + raise NotImplementedError diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index ecf964fa49d9b..0112f49876996 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -42,10 +42,12 @@ def __init__( kv_cache_dtype = cache_config.cache_dtype block_size = cache_config.block_size sliding_window = cache_config.sliding_window + is_attention_free = cache_config.is_attention_free else: kv_cache_dtype = "auto" block_size = 16 sliding_window = None + is_attention_free = False if num_kv_heads is None: num_kv_heads = num_heads @@ -76,9 +78,9 @@ def __init__( # During model initialization, the default dtype is set as the model # weight and activation dtype. dtype = torch.get_default_dtype() - attn_backend = get_attn_backend(num_heads, head_size, num_kv_heads, - sliding_window, dtype, kv_cache_dtype, - block_size, blocksparse_params + attn_backend = get_attn_backend(head_size, sliding_window, dtype, + kv_cache_dtype, block_size, + is_attention_free, blocksparse_params is not None) impl_cls = attn_backend.get_impl_cls() self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 30aa7cb311afb..7edb7676ea2cd 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -24,6 +24,7 @@ class _Backend(enum.Enum): FLASHINFER = enum.auto() PALLAS = enum.auto() IPEX = enum.auto() + NO_ATTENTION = enum.auto() def backend_name_to_enum(backend_name: str) -> _Backend: @@ -88,13 +89,12 @@ def get_global_forced_attn_backend() -> Optional[_Backend]: @lru_cache(maxsize=None) def get_attn_backend( - num_heads: int, head_size: int, - num_kv_heads: int, sliding_window: Optional[int], dtype: torch.dtype, kv_cache_dtype: Optional[str], block_size: int, + is_attention_free: bool, is_blocksparse: bool = False, ) -> Type[AttentionBackend]: """Selects which attention backend to use and lazily imports it.""" @@ -105,9 +105,8 @@ def get_attn_backend( BlocksparseFlashAttentionBackend) return BlocksparseFlashAttentionBackend - backend = which_attn_to_use(num_heads, head_size, num_kv_heads, - sliding_window, dtype, kv_cache_dtype, - block_size) + backend = which_attn_to_use(head_size, sliding_window, dtype, + kv_cache_dtype, block_size, is_attention_free) if backend == _Backend.FLASH_ATTN: from vllm.attention.backends.flash_attn import ( # noqa: F401 FlashAttentionBackend) @@ -146,23 +145,31 @@ def get_attn_backend( logger.info("Using Pallas backend.") from vllm.attention.backends.pallas import PallasAttentionBackend return PallasAttentionBackend + elif backend == _Backend.NO_ATTENTION: + from vllm.attention.backends.placeholder_attn import ( + PlaceholderAttentionBackend) + return PlaceholderAttentionBackend else: raise ValueError("Invalid attention backend.") def which_attn_to_use( - num_heads: int, head_size: int, - num_kv_heads: int, sliding_window: Optional[int], dtype: torch.dtype, kv_cache_dtype: Optional[str], block_size: int, + is_attention_free: bool, ) -> _Backend: """Returns which flash attention backend to use.""" # Default case. selected_backend = _Backend.FLASH_ATTN + # If there are no attention layers (e.g. we are running Mamba), + # use the placeholder NO_ATTENTION + if is_attention_free: + return _Backend.NO_ATTENTION + # Check whether a particular choice of backend was # previously forced. 
# diff --git a/vllm/config.py b/vllm/config.py index 91ba45798b4ba..f964928aa0a68 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -196,6 +196,9 @@ def __init__(self, if not self.skip_tokenizer_init: self._verify_tokenizer_mode() + self.is_attention_free = self._init_attention_free() + self.has_inner_state = self._init_has_inner_state() + self.override_neuron_config = override_neuron_config if is_neuron( ) else None self._verify_embedding_mode() @@ -216,6 +219,14 @@ def _init_multimodal_config( return None + def _init_attention_free(self) -> bool: + architectures = getattr(self.hf_config, "architectures", []) + return ModelRegistry.is_attention_free_model(architectures) + + def _init_has_inner_state(self) -> bool: + architectures = getattr(self.hf_config, "architectures", []) + return ModelRegistry.model_has_inner_state(architectures) + def _verify_tokenizer_mode(self) -> None: tokenizer_mode = self.tokenizer_mode.lower() if tokenizer_mode not in ["auto", "slow", "mistral"]: @@ -438,6 +449,10 @@ def get_head_size(self) -> int: # FlashAttention supports only head_size 32, 64, 128, 256, # we need to pad head_size 192 to 256 return 256 + + if self.is_attention_free: + return 0 + if hasattr(self.hf_text_config, "head_dim"): return self.hf_text_config.head_dim # FIXME(woosuk): This may not be true for all models. @@ -469,6 +484,9 @@ def get_total_num_kv_heads(self) -> int: return getattr(self.hf_config.attn_config, "kv_n_heads", self.hf_config.num_attention_heads) + if self.is_attention_free: + return 0 + attributes = [ # For Falcon: "n_head_kv", @@ -511,31 +529,17 @@ def get_num_layers(self, parallel_config: "ParallelConfig") -> int: start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size) return end - start - def contains_seqlen_agnostic_layers( - self, parallel_config: "ParallelConfig") -> bool: - """True for Mamba/SSM models (Jamba)""" - return self._get_num_seqlen_agnostic_layers(parallel_config) > 0 + def get_num_attention_layers(self, + parallel_config: "ParallelConfig") -> int: + if self.is_attention_free: + return 0 - def get_layers_block_type(self, - parallel_config: "ParallelConfig") -> List[str]: num_layers = self.get_num_layers(parallel_config) - # Transformers supports layers_block_type @property - return getattr(self.hf_config, "layers_block_type", - ["attention"] * num_layers) - def get_num_attention_layers(self, - parallel_config: "ParallelConfig") -> int: - return len([ - t for t in self.get_layers_block_type(parallel_config) - if t == "attention" - ]) - - def _get_num_seqlen_agnostic_layers( - self, parallel_config: "ParallelConfig") -> int: - return len([ - t for t in self.get_layers_block_type(parallel_config) - if t != "attention" - ]) + # Transformers supports layers_block_type @property + layers = getattr(self.hf_config, "layers_block_type", + ["attention"] * num_layers) + return len([t for t in layers if t == "attention"]) def get_multimodal_config(self) -> "MultiModalConfig": """ @@ -585,6 +589,7 @@ def __init__( gpu_memory_utilization: float, swap_space: float, cache_dtype: str, + is_attention_free: bool = False, num_gpu_blocks_override: Optional[int] = None, sliding_window: Optional[int] = None, enable_prefix_caching: bool = False, @@ -595,6 +600,7 @@ def __init__( self.swap_space_bytes = swap_space * GiB_bytes self.num_gpu_blocks_override = num_gpu_blocks_override self.cache_dtype = cache_dtype + self.is_attention_free = is_attention_free self.sliding_window = sliding_window self.enable_prefix_caching = enable_prefix_caching self.cpu_offload_gb = 
cpu_offload_gb diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index 6346711587301..9e1d1b02f6805 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -36,10 +36,10 @@ def get_block_space_manager_class(version: str): from vllm.core.block_manager_v2 import BlockSpaceManagerV2 return BlockSpaceManagerV2 - if version == "embedding": - from vllm.core.embedding_model_block_manager import ( - EmbeddingModelBlockSpaceManager) - return EmbeddingModelBlockSpaceManager + if version == "placeholder": + from vllm.core.placeholder_block_space_manager import ( + PlaceholderBlockSpaceManager) + return PlaceholderBlockSpaceManager raise ValueError(f"Unknown version {version=}") diff --git a/vllm/core/embedding_model_block_manager.py b/vllm/core/placeholder_block_space_manager.py similarity index 90% rename from vllm/core/embedding_model_block_manager.py rename to vllm/core/placeholder_block_space_manager.py index 476e043ecc52d..a337392bbed53 100644 --- a/vllm/core/embedding_model_block_manager.py +++ b/vllm/core/placeholder_block_space_manager.py @@ -5,9 +5,10 @@ from vllm.utils import Device -class EmbeddingModelBlockSpaceManager(BlockSpaceManager): - """An embedding version of BlockSpaceManager for use in environments - with embedding models where block management is not required. +class PlaceholderBlockSpaceManager(BlockSpaceManager): + """A version of BlockSpaceManager for use in environments + where block management is not required. + For example: embedding models or attention-free models like Mamba. This class provides the same interface as BlockSpaceManager, but its methods perform no actions or return simple values like True in specific @@ -40,7 +41,7 @@ def append_slots( seq: Sequence, num_lookahead_slots: int, ) -> List[Tuple[int, int]]: - return None # type: ignore + return [] def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: pass diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 2d7a27d1377e4..1f0a121711db5 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -314,8 +314,9 @@ def __init__( version = "v1" if self.scheduler_config.use_v2_block_manager: version = "v2" - if self.scheduler_config.embedding_mode: - version = "embedding" + if (self.scheduler_config.embedding_mode + or self.cache_config.is_attention_free): + version = "placeholder" BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class( version) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index efdcec4ab797a..bdfecabf96f2c 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -912,6 +912,7 @@ def create_engine_config(self) -> EngineConfig: gpu_memory_utilization=self.gpu_memory_utilization, swap_space=self.swap_space, cache_dtype=self.kv_cache_dtype, + is_attention_free=model_config.is_attention_free, num_gpu_blocks_override=self.num_gpu_blocks_override, sliding_window=model_config.get_sliding_window(), enable_prefix_caching=self.enable_prefix_caching, @@ -945,13 +946,9 @@ def create_engine_config(self) -> EngineConfig: use_sliding_window = (model_config.get_sliding_window() is not None) use_spec_decode = self.speculative_model is not None - has_seqlen_agnostic_layers = ( - model_config.contains_seqlen_agnostic_layers( - parallel_config)) if (is_gpu and not use_sliding_window and not use_spec_decode and not self.enable_lora - and not self.enable_prompt_adapter - and not has_seqlen_agnostic_layers): + and not self.enable_prompt_adapter): self.enable_chunked_prefill = True logger.warning( "Chunked 
prefill is enabled by default for models with " diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 5051d45dd1154..1e2857ee28cbf 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -6,7 +6,8 @@ import os import tempfile from collections import defaultdict -from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Union +from typing import (Any, Callable, Dict, Generator, Iterable, List, Optional, + Tuple, Union) import filelock import gguf @@ -559,6 +560,38 @@ def row_parallel_weight_loader(param: torch.Tensor, return default_weight_loader(param, loaded_weight) +LoaderFunction = Callable[[torch.Tensor, torch.Tensor], torch.Tensor] + + +def sharded_weight_loader(shard_axis: int) -> LoaderFunction: + """Create a weight loader that shards the weights along the given axis""" + + def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: + tp_rank = get_tensor_model_parallel_rank() + + shard_size = param.data.shape[shard_axis] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(shard_axis, start_idx, shard_size) + + return default_weight_loader(param, loaded_weight) + + return loader + + +def composed_weight_loader( + loader: LoaderFunction, fn: Callable[[torch.Tensor], + torch.Tensor]) -> LoaderFunction: + """Create a weight loader that post-processes the weights after loading""" + + def composed_loader(param: torch.Tensor, + loaded_weight: torch.Tensor) -> None: + loader(param, loaded_weight) + param.data.copy_(fn(param)) + return + + return composed_loader + + def initialize_dummy_weights( model: torch.nn.Module, low: float = -1e-3, diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 278dfc52078ef..dcead65115132 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -271,7 +271,7 @@ class HasInnerState(Protocol): """ A flag that indicates this model has inner state. Models that has inner state usually need access to the scheduler_config - for max_num_seqs ,etc... (Currently only used by Jamba) + for max_num_seqs, etc. True for e.g. both Mamba and Jamba. """ def __init__(self, @@ -307,3 +307,46 @@ def has_inner_state( return isinstance(model, _HasInnerStateType) return isinstance(model, HasInnerState) + + +@runtime_checkable +class IsAttentionFree(Protocol): + """The interface required for all models like Mamba that lack attention, + but do have state whose size is constant wrt the number of tokens.""" + + is_attention_free: ClassVar[Literal[True]] = True + """ + A flag that indicates this model has no attention. + Used for block manager and attention backend selection. + True for Mamba but not Jamba. + """ + + def __init__(self) -> None: + ... + + +@runtime_checkable +class _IsAttentionFreeType(Protocol): + is_attention_free: ClassVar[Literal[True]] + + def __init__(self) -> None: + ... + + +@overload +def is_attention_free(model: object) -> TypeIs[IsAttentionFree]: + ... + + +@overload +def is_attention_free(model: Type[object]) -> TypeIs[Type[IsAttentionFree]]: + ... 
+ + +def is_attention_free( + model: Union[Type[object], object] +) -> Union[TypeIs[Type[IsAttentionFree]], TypeIs[IsAttentionFree]]: + if isinstance(model, type): + return isinstance(model, _IsAttentionFreeType) + + return isinstance(model, IsAttentionFree) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 06ec324b3e108..ac251b88e872c 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -1,18 +1,16 @@ # coding=utf-8 """Inference-only Jamba model.""" from dataclasses import dataclass -from typing import Dict, Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple import torch from torch import nn -from torch.nn.parameter import Parameter from transformers import JambaConfig from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.layer import Attention from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig -from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -29,7 +27,9 @@ from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.model_loader.weight_utils import ( + composed_weight_loader, default_weight_loader, sharded_weight_loader) +from vllm.model_executor.models.mamba_cache import MambaCacheManager from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors @@ -99,16 +99,6 @@ def __init__(self, config: JambaConfig, layer_idx): bias=True, skip_bias_add=True) - def weight_loader(param: Parameter, loaded_weight: torch.Tensor): - tp_rank = get_tensor_model_parallel_rank() - tp_size = get_tensor_model_parallel_world_size() - param.data.copy_( - loaded_weight.data.split(loaded_weight.shape[0] // tp_size, - dim=0)[tp_rank]) - - def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): - weight_loader(param, -torch.exp(loaded_weight.float())) - tp_size = get_tensor_model_parallel_world_size() self.A = nn.Parameter( torch.empty( @@ -118,8 +108,10 @@ def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): )) self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size)) - set_weight_attrs(self.D, {"weight_loader": weight_loader}) - set_weight_attrs(self.A, {"weight_loader": A_weight_loader}) + set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)}) + a_weight_loader = composed_weight_loader( + sharded_weight_loader(0), lambda x: -torch.exp(x.float())) + set_weight_attrs(self.A, {"weight_loader": a_weight_loader}) self.out_proj = RowParallelLinear( self.intermediate_size, @@ -571,10 +563,8 @@ def __init__( if not lora_config else lora_config.lora_vocab_padding_size, ) # Used to track and store by the Mamba cache between steps. 
- self.mamba_cache: Tuple[torch.Tensor, torch.Tensor] = tuple() - # Maps between the request id and a dict that maps between the seq_id - # and its index inside the self.mamba_cache - self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {} + self.mamba_cache: Optional[MambaCacheManager] = None + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) self.sampler = Sampler() @@ -586,203 +576,36 @@ def forward(self, attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs): - if not self.mamba_cache: - self._prepare_mamba_cache() - - if "seqlen_agnostic_capture_inputs" not in kwargs: - # We get here only on Prefill/Eager mode runs - request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] - finished_requests_ids = kwargs["finished_requests_ids"] - mamba_cache = self._release_finished_and_prepare_mamba_cache( - finished_requests_ids, request_ids_to_seq_ids) - else: - # CUDA graph capturing runs - mamba_cache = kwargs["seqlen_agnostic_capture_inputs"] - - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, mamba_cache[0], - mamba_cache[1]) - return hidden_states - - def _swap_mamba_cache(self, from_index: int, to_index: int): - assert len(self.mamba_cache) > 0 - for cache_t in self.mamba_cache: - cache_t[:, [to_index,from_index]] = \ - cache_t[:, [from_index,to_index]] - - def _copy_mamba_cache(self, from_index: int, to_index: int): - assert len(self.mamba_cache) > 0 - for cache_t in self.mamba_cache: - cache_t[:, to_index].copy_(cache_t[:, from_index], - non_blocking=True) - - def _move_out_if_already_occupied(self, index: int, - all_occupied_indices: List[int]): - if index in all_occupied_indices: - first_free_index = self._first_free_index_in_mamba_cache() - # In case occupied, move the occupied to a new empty block - self._move_cache_index_and_mappings(from_index=index, - to_index=first_free_index) - - def _assign_seq_id_to_mamba_cache_in_specific_dest(self, cur_rid: str, - seq_id: int, - destination_index: int): - """ - Assign (req_id,seq_id) pair to a `destination_index` index, if - already occupied, move the occupying index to a free index. 
- """ - all_occupied_indices = self._get_all_occupied_indices() - if cur_rid not in self.mamba_cache_indices_mapping: - self._move_out_if_already_occupied( - index=destination_index, - all_occupied_indices=all_occupied_indices) - self.mamba_cache_indices_mapping[cur_rid] = { - seq_id: destination_index - } - elif seq_id not in (seq_ids2indices := - self.mamba_cache_indices_mapping[cur_rid]): - # parallel sampling , where n > 1, assume prefill have - # already happened now we only need to copy the already - # existing cache into the siblings seq_ids caches - self._move_out_if_already_occupied( - index=destination_index, - all_occupied_indices=all_occupied_indices) - index_exists = list(seq_ids2indices.values())[0] - # case of decoding n>1, copy prefill cache to decoding indices - self._copy_mamba_cache(from_index=index_exists, - to_index=destination_index) - self.mamba_cache_indices_mapping[cur_rid][ - seq_id] = destination_index - else: - # already exists - cache_index_already_exists = self.mamba_cache_indices_mapping[ - cur_rid][seq_id] - if cache_index_already_exists != destination_index: - # In case the seq id already exists but not in - # the right destination, swap it with what's occupying it - self._swap_pair_indices_and_mappings( - from_index=cache_index_already_exists, - to_index=destination_index) - - def _prepare_current_run_mamba_cache( - self, request_ids_to_seq_ids: Dict[str, list[int]], - finished_requests_ids: List[str] - ) -> Tuple[torch.Tensor, torch.Tensor]: - running_indices = [] - request_ids_to_seq_ids_flatten = [ - (req_id, seq_id) - for req_id, seq_ids in request_ids_to_seq_ids.items() - for seq_id in seq_ids - ] - batch_size = len(request_ids_to_seq_ids_flatten) - for dest_index, (request_id, - seq_id) in enumerate(request_ids_to_seq_ids_flatten): - if request_id in finished_requests_ids: - # Do not allocate cache index for requests that run - # and finish right after - continue - self._assign_seq_id_to_mamba_cache_in_specific_dest( - request_id, seq_id, dest_index) - running_indices.append(dest_index) + if self.mamba_cache is None: + max_batch_size = (_get_graph_batch_size( + self.scheduler_config.max_num_seqs) if self.scheduler_config + else max(_BATCH_SIZES_TO_CAPTURE) + 2) - self._clean_up_first_bs_blocks(batch_size, running_indices) - conv_state = self.mamba_cache[0][:, :batch_size] - temporal_state = self.mamba_cache[1][:, :batch_size] + layers_type = self.config.layers_block_type + num_mamba_layers = sum( + [layer_type == "mamba" for layer_type in layers_type]) - return (conv_state, temporal_state) + self.mamba_cache = MambaCacheManager( + self.lm_head.weight.dtype, num_mamba_layers, max_batch_size, + *self._get_mamba_cache_shape()) - def _get_all_occupied_indices(self): - return [ - cache_idx - for seq_ids2indices in self.mamba_cache_indices_mapping.values() - for cache_idx in seq_ids2indices.values() - ] + mamba_cache_tensors = self.mamba_cache.current_run_tensors( + input_ids, attn_metadata, **kwargs) - def _clean_up_first_bs_blocks(self, batch_size: int, - indices_for_current_run: List[int]): - # move out all of the occupied but currently not running blocks - # outside of the first n blocks - destination_indices = range(batch_size) - max_possible_batch_size = self.mamba_cache[0].shape[1] - for destination_index in destination_indices: - if destination_index in self._get_all_occupied_indices() and \ - destination_index not in indices_for_current_run: - # move not running indices outside of the batch - all_other_indices = list( - range(batch_size, 
max_possible_batch_size)) - first_avail_index = self._first_free_index_in_mamba_cache( - all_other_indices) - self._swap_indices(from_index=destination_index, - to_index=first_avail_index) - - def _move_cache_index_and_mappings(self, from_index: int, to_index: int): - self._copy_mamba_cache(from_index=from_index, to_index=to_index) - self._update_mapping_index(from_index=from_index, to_index=to_index) - - def _swap_pair_indices_and_mappings(self, from_index: int, to_index: int): - self._swap_mamba_cache(from_index=from_index, to_index=to_index) - self._swap_mapping_index(from_index=from_index, to_index=to_index) - - def _swap_mapping_index(self, from_index: int, to_index: int): - for seq_ids2index in self.mamba_cache_indices_mapping.values(): - for seq_id, index in seq_ids2index.items(): - if from_index == index: - seq_ids2index.update({seq_id: to_index}) - elif to_index == index: - seq_ids2index.update({seq_id: from_index}) - - def _update_mapping_index(self, from_index: int, to_index: int): - for seq_ids2index in self.mamba_cache_indices_mapping.values(): - for seq_id, index in seq_ids2index.items(): - if from_index == index: - seq_ids2index.update({seq_id: to_index}) - return - - def _release_finished_and_prepare_mamba_cache( - self, finished_requests_ids, - request_ids_to_seq_ids) -> Tuple[torch.Tensor, torch.Tensor]: - self._release_mamba_cache(finished_requests_ids) - return self._prepare_current_run_mamba_cache(request_ids_to_seq_ids, - finished_requests_ids) + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, mamba_cache_tensors[0], + mamba_cache_tensors[1]) + return hidden_states def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): - """ - Copy the relevant Mamba cache into the CUDA graph input buffer - that was provided during the capture runs - (JambaForCausalLM.mamba_gc_cache_buffer). - """ - self._release_finished_and_prepare_mamba_cache( - kwargs["finished_requests_ids"], kwargs["request_ids_to_seq_ids"]) + return self.mamba_cache.copy_inputs_before_cuda_graphs( + input_buffers, **kwargs) def get_seqlen_agnostic_capture_inputs(self, batch_size: int): - """ - Provide the CUDA graph capture runs with a buffer in adjusted size. - The buffer is used to maintain the Mamba Cache during the CUDA graph - replay runs. - """ - return tuple(buffer[:, :batch_size] for buffer in self.mamba_cache) - - def _release_mamba_cache(self, finished_seq_groups_req_ids: List[str]): - for req_id in finished_seq_groups_req_ids: - if req_id in self.mamba_cache_indices_mapping: - self.mamba_cache_indices_mapping.pop(req_id) - - def _first_free_index_in_mamba_cache( - self, indices_range: Optional[List[int]] = None) -> int: - assert self.mamba_cache is not None - if indices_range is None: - max_possible_batch_size = self.mamba_cache[0].shape[1] - indices_range = list(range(max_possible_batch_size)) - all_occupied_indices = self._get_all_occupied_indices() - for i in indices_range: - if i not in all_occupied_indices: - return i - raise Exception("Couldn't find a free spot in the mamba cache! 
This" - "should never happen") + return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) def _get_mamba_cache_shape( - self - ) -> Tuple[Optional[Tuple[int, int]], Optional[Tuple[int, int]]]: + self) -> Tuple[Tuple[int, int], Tuple[int, int]]: world_size = get_tensor_model_parallel_world_size() hidden_size = self.config.hidden_size conv_state_shape = ( @@ -790,31 +613,11 @@ def _get_mamba_cache_shape( self.config.mamba_d_conv - 1, ) temporal_state_shape = ( - self.config.mamba_expand * self.config.hidden_size // world_size, + self.config.mamba_expand * hidden_size // world_size, self.config.mamba_d_state, ) return conv_state_shape, temporal_state_shape - def _prepare_mamba_cache(self): - dtype = self.lm_head.weight.dtype - layers_type = self.config.layers_block_type - mamba_layers = sum( - [layer_type == "mamba" for layer_type in layers_type]) - max_batch_size = (_get_graph_batch_size( - self.scheduler_config.max_num_seqs) if self.scheduler_config else - max(_BATCH_SIZES_TO_CAPTURE) + 2) - conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape() - assert conv_state_shape is not None and temporal_state_shape is not None - - self.mamba_cache = (torch.empty(size=(mamba_layers, max_batch_size) + - conv_state_shape, - dtype=dtype, - device="cuda"), - torch.empty(size=(mamba_layers, max_batch_size) + - temporal_state_shape, - dtype=dtype, - device="cuda")) - def compute_logits( self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py new file mode 100644 index 0000000000000..1112a2181135a --- /dev/null +++ b/vllm/model_executor/models/mamba.py @@ -0,0 +1,499 @@ +# coding=utf-8 +"""PyTorch MAMBA model.""" +from dataclasses import dataclass +from typing import Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers import MambaConfig + +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn, causal_conv1d_update) +from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( + selective_scan_fn, selective_state_update) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + composed_weight_loader, default_weight_loader, sharded_weight_loader) +from vllm.model_executor.models.interfaces import (HasInnerState, + IsAttentionFree) +from vllm.model_executor.models.mamba_cache import MambaCacheManager +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs +from vllm.sequence import IntermediateTensors +from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, + _get_graph_batch_size) + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +@dataclass +class MambaCacheParams: + is_prompt: bool = False + conv_state: torch.Tensor 
= torch.Tensor() + ssm_state: torch.Tensor = torch.Tensor() + + +# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer +class MambaMixer(nn.Module): + """ + Compute ∆, A, B, C, and D the state space parameters and compute + the `contextualized_states`. A, D are input independent + (see Mamba paper [1] Section 3.5.2 "Interpretation of A" + for why A isn't selective) ∆, B, C are input-dependent + (this is a key difference between Mamba and the linear time + invariant S4, and is why Mamba is called + **selective** state spaces) + """ + + def __init__(self, config: MambaConfig, layer_idx): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.ssm_state_size = config.state_size + self.conv_kernel_size = config.conv_kernel + self.intermediate_size = config.intermediate_size + self.time_step_rank = int(config.time_step_rank) + + self.conv1d = ColumnParallelLinear( + input_size=self.conv_kernel_size, + output_size=self.intermediate_size, + bias=config.use_conv_bias, + ) + # unsqueeze to fit conv1d weights shape into the linear weights shape. + # Can't do this in `weight_loader` since it already exists in + # `ColumnParallelLinear` and `set_weight_attrs` + # doesn't allow to override it + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + + self.in_proj = MergedColumnParallelLinear(self.hidden_size, + [self.intermediate_size] * 2, + bias=config.use_bias) + # selective projection used to make dt, B and C input dependent + self.x_proj = RowParallelLinear( + self.intermediate_size, + self.time_step_rank + self.ssm_state_size * 2, + bias=False, + ) + # time step projection (discretization) - + # In the forward we need to apply dt_proj without the bias, + # as the bias is added in the selective scan kernel. + self.dt_proj = ColumnParallelLinear(self.time_step_rank, + self.intermediate_size, + bias=True, + skip_bias_add=True) + + tp_size = get_tensor_model_parallel_world_size() + self.A = nn.Parameter( + torch.empty( + self.intermediate_size // tp_size, + self.ssm_state_size, + dtype=torch.float32, + )) + self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size)) + + set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)}) + a_weight_loader = composed_weight_loader( + sharded_weight_loader(0), lambda x: -torch.exp(x.float())) + set_weight_attrs(self.A, {"weight_loader": a_weight_loader}) + + self.out_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=config.use_bias, + input_is_parallel=True, + ) + self.activation = config.hidden_act + + def forward(self, hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, conv_state: torch.Tensor, + ssm_state: torch.Tensor): + + # 1. Gated MLP's linear projection + projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) + hidden_states, gate = projected_states.chunk(2, dim=-2) + + # 2. 
Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), + self.conv1d.weight.size(2)) + + if attn_metadata.query_start_loc is not None \ + and attn_metadata.context_lens_tensor is not None: + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + hidden_states = causal_conv1d_fn( + hidden_states, + conv_weights, + self.conv1d.bias, + activation=self.activation, + conv_states=conv_state, + has_initial_state=attn_metadata.context_lens_tensor > 0, + query_start_loc=attn_metadata.query_start_loc) + else: + hidden_states = causal_conv1d_update( + hidden_states.transpose(0, 1), + conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + ) + hidden_states = hidden_states.transpose(0, 1) + + # 3. State Space Model sequence transformation + # 3.a. input varying initialization of time_step, B and C + ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0] + + time_step, B, C = torch.split( + ssm_parameters, + [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], + dim=-1, + ) + + # Note that Jamba normalizes B, C, and time_step here but Mamba doesn't. + + discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1) + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + time_proj_bias = (self.dt_proj.bias.float() if hasattr( + self.dt_proj, "bias") else None) + + if attn_metadata.query_start_loc is not None \ + and attn_metadata.context_lens_tensor is not None: + scan_outputs = selective_scan_fn( + hidden_states, + ssm_state, + discrete_time_step, + self.A, + B.transpose(-2, -1), + C.transpose(-2, -1), + self.D.float(), + gate, + time_proj_bias, + delta_softplus=True, + has_initial_state=attn_metadata.context_lens_tensor > 0, + query_start_loc=attn_metadata.query_start_loc) + else: + scan_outputs = selective_state_update( + ssm_state, + hidden_states.transpose(0, 1), + discrete_time_step.transpose(0, 1), + self.A, + B, + C, + self.D, + gate.transpose(0, 1), + time_proj_bias, + dt_softplus=True, + ) + scan_outputs = scan_outputs.transpose(0, 1) + + # 4. Final linear projection + contextualized_states = self.out_proj(scan_outputs.transpose(-2, + -1))[0] + return contextualized_states + + +class MambaMLP(nn.Module): + + def __init__( + self, + config: MambaConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + hidden_size = config.hidden_size + intermediate_size = config.intermediate_size + hidden_act = config.hidden_act + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. 
" + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class MambaDecoderLayer(nn.Module): + + def __init__(self, + config: MambaConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None) -> None: + super().__init__() + self.layer_idx = layer_idx + self.config = config + self.mixer = MambaMixer(config, layer_idx) + + self.feed_forward = MambaMLP(config, quant_config=quant_config) + self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.pre_ff_layernorm = RMSNorm(config.hidden_size, + eps=config.layer_norm_epsilon) + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + conv_state: torch.Tensor, + ssm_state: torch.Tensor, + **kwargs, + ): + if residual is None: + residual = hidden_states + hidden_states = self.norm(hidden_states) + else: + hidden_states, residual = self.norm(hidden_states, residual) + + hidden_states = self.mixer(hidden_states, attn_metadata, conv_state, + ssm_state) + # Fully Connected + hidden_states, residual = self.pre_ff_layernorm( + hidden_states, residual) + hidden_states = self.feed_forward(hidden_states) + return hidden_states, residual + + +class MambaModel(nn.Module): + + def __init__( + self, + config: MambaConfig, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embeddings = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + decoder_layers = [] + for i in range(config.num_hidden_layers): + decoder_layers.append( + MambaDecoderLayer(config, + layer_idx=i, + cache_config=cache_config, + quant_config=quant_config)) + self.layers = nn.ModuleList(decoder_layers) + self.norm_f = RMSNorm(config.hidden_size, + eps=config.layer_norm_epsilon) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + conv_state: torch.Tensor, + ssm_state: torch.Tensor, + ) -> torch.Tensor: + hidden_states = self.embeddings(input_ids) + residual = None + + for i in range(len(self.layers)): + layer = self.layers[i] + current_ssm_state = ssm_state[i] + current_conv_state = conv_state[i] + + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + attn_metadata=attn_metadata, + residual=residual, + conv_state=current_conv_state, + ssm_state=current_ssm_state, + ) + hidden_states, _ = self.norm_f(hidden_states, residual) + + return hidden_states + + +class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "embed_tokens", + "lm_head", + ] + embedding_modules = { + "embeddings": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__( + self, + config: MambaConfig, + 
cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + scheduler_config: Optional[SchedulerConfig] = None, + ) -> None: + assert not cache_config.enable_prefix_caching, \ + "Mamba does not support prefix caching" + + super().__init__() + self.config = config + self.scheduler_config = scheduler_config + self.backbone = MambaModel(config, + cache_config=cache_config, + quant_config=quant_config, + lora_config=lora_config) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + + self.lm_head = self.backbone.embeddings + + # Used to track and store by the Mamba cache between steps. + self.mamba_cache: Optional[MambaCacheManager] = None + + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = Sampler() + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs): + if self.mamba_cache is None: + max_batch_size = (_get_graph_batch_size( + self.scheduler_config.max_num_seqs) if self.scheduler_config + else max(_BATCH_SIZES_TO_CAPTURE) + 2) + self.mamba_cache = MambaCacheManager( + self.lm_head.weight.dtype, self.config.num_hidden_layers, + max_batch_size, *self._get_mamba_cache_shape()) + + mamba_cache_tensors = self.mamba_cache.current_run_tensors( + input_ids, attn_metadata, **kwargs) + + hidden_states = self.backbone(input_ids, positions, kv_caches, + attn_metadata, mamba_cache_tensors[0], + mamba_cache_tensors[1]) + + return hidden_states + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + return self.mamba_cache.copy_inputs_before_cuda_graphs( + input_buffers, **kwargs) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) + + def _get_mamba_cache_shape( + self) -> Tuple[Tuple[int, int], Tuple[int, int]]: + world_size = get_tensor_model_parallel_world_size() + conv_state_shape = ( + self.config.intermediate_size // world_size, + self.config.conv_kernel - 1, + ) + temporal_state_shape = ( + self.config.intermediate_size // world_size, + self.config.state_size, + ) + return conv_state_shape, temporal_state_shape + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + if "A_log" in name: + name = name.replace("A_log", "A") + + if ".self_attn." 
in name: + name = name.replace(".self_attn", "") + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/mamba_cache.py b/vllm/model_executor/models/mamba_cache.py new file mode 100644 index 0000000000000..8d1ba3737d4a5 --- /dev/null +++ b/vllm/model_executor/models/mamba_cache.py @@ -0,0 +1,222 @@ +from typing import Dict, List, Optional + +import torch + +from vllm.attention.backends.abstract import AttentionMetadata + + +class MambaCacheManager: + + def __init__(self, dtype, num_mamba_layers, max_batch_size, + conv_state_shape, temporal_state_shape): + + conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) + + conv_state_shape, + dtype=dtype, + device="cuda") + temporal_state = torch.empty(size=(num_mamba_layers, max_batch_size) + + temporal_state_shape, + dtype=dtype, + device="cuda") + + self.mamba_cache = (conv_state, temporal_state) + + # Maps between the request id and a dict that maps between the seq_id + # and its index inside the self.mamba_cache + self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {} + + def current_run_tensors(self, input_ids: torch.Tensor, + attn_metadata: AttentionMetadata, **kwargs): + """ + Return the tensors for the current run's conv and ssm state. + """ + if "seqlen_agnostic_capture_inputs" not in kwargs: + # We get here only on Prefill/Eager mode runs + request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] + finished_requests_ids = kwargs["finished_requests_ids"] + + self._release_finished_requests(finished_requests_ids) + mamba_cache_tensors = self._prepare_current_run_mamba_cache( + request_ids_to_seq_ids, finished_requests_ids) + + else: + # CUDA graph capturing runs + mamba_cache_tensors = kwargs["seqlen_agnostic_capture_inputs"] + + return mamba_cache_tensors + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + """ + Copy the relevant Mamba cache into the CUDA graph input buffer + that was provided during the capture runs + (JambaForCausalLM.mamba_gc_cache_buffer). + """ + assert all( + key in kwargs + for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) + finished_requests_ids = kwargs["finished_requests_ids"] + request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] + + self._release_finished_requests(finished_requests_ids) + self._prepare_current_run_mamba_cache(request_ids_to_seq_ids, + finished_requests_ids) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + """ + Provide the CUDA graph capture runs with a buffer in adjusted size. + The buffer is used to maintain the Mamba Cache during the CUDA graph + replay runs. 
+ """ + return tuple(buffer[:, :batch_size] for buffer in self.mamba_cache) + + def _swap_mamba_cache(self, from_index: int, to_index: int): + assert len(self.mamba_cache) > 0 + for cache_t in self.mamba_cache: + cache_t[:, [to_index,from_index]] = \ + cache_t[:, [from_index,to_index]] + + def _copy_mamba_cache(self, from_index: int, to_index: int): + assert len(self.mamba_cache) > 0 + for cache_t in self.mamba_cache: + cache_t[:, to_index].copy_(cache_t[:, from_index], + non_blocking=True) + + def _move_out_if_already_occupied(self, index: int, + all_occupied_indices: List[int]): + if index in all_occupied_indices: + first_free_index = self._first_free_index_in_mamba_cache() + # In case occupied, move the occupied to a new empty block + self._move_cache_index_and_mappings(from_index=index, + to_index=first_free_index) + + def _assign_seq_id_to_mamba_cache_in_specific_dest(self, cur_rid: str, + seq_id: int, + destination_index: int): + """ + Assign (req_id,seq_id) pair to a `destination_index` index, if + already occupied, move the occupying index to a free index. + """ + all_occupied_indices = self._get_all_occupied_indices() + if cur_rid not in self.mamba_cache_indices_mapping: + self._move_out_if_already_occupied( + index=destination_index, + all_occupied_indices=all_occupied_indices) + self.mamba_cache_indices_mapping[cur_rid] = { + seq_id: destination_index + } + elif seq_id not in (seq_ids2indices := + self.mamba_cache_indices_mapping[cur_rid]): + # parallel sampling , where n > 1, assume prefill have + # already happened now we only need to copy the already + # existing cache into the siblings seq_ids caches + self._move_out_if_already_occupied( + index=destination_index, + all_occupied_indices=all_occupied_indices) + index_exists = list(seq_ids2indices.values())[0] + # case of decoding n>1, copy prefill cache to decoding indices + self._copy_mamba_cache(from_index=index_exists, + to_index=destination_index) + self.mamba_cache_indices_mapping[cur_rid][ + seq_id] = destination_index + else: + # already exists + cache_index_already_exists = self.mamba_cache_indices_mapping[ + cur_rid][seq_id] + if cache_index_already_exists != destination_index: + # In case the seq id already exists but not in + # the right destination, swap it with what's occupying it + self._swap_pair_indices_and_mappings( + from_index=cache_index_already_exists, + to_index=destination_index) + + def _prepare_current_run_mamba_cache( + self, request_ids_to_seq_ids: Dict[str, list[int]], + finished_requests_ids: List[str]): + running_indices = [] + request_ids_to_seq_ids_flatten = [ + (req_id, seq_id) + for req_id, seq_ids in request_ids_to_seq_ids.items() + for seq_id in seq_ids + ] + batch_size = len(request_ids_to_seq_ids_flatten) + for dest_index, (request_id, + seq_id) in enumerate(request_ids_to_seq_ids_flatten): + if request_id in finished_requests_ids: + # Do not allocate cache index for requests that run + # and finish right after + continue + self._assign_seq_id_to_mamba_cache_in_specific_dest( + request_id, seq_id, dest_index) + running_indices.append(dest_index) + + self._clean_up_first_bs_blocks(batch_size, running_indices) + conv_state = self.mamba_cache[0][:, :batch_size] + temporal_state = self.mamba_cache[1][:, :batch_size] + + return (conv_state, temporal_state) + + def _get_all_occupied_indices(self): + return [ + cache_idx + for seq_ids2indices in self.mamba_cache_indices_mapping.values() + for cache_idx in seq_ids2indices.values() + ] + + def _clean_up_first_bs_blocks(self, batch_size: int, + 
indices_for_current_run: List[int]): + # move out all of the occupied but currently not running blocks + # outside of the first n blocks + destination_indices = range(batch_size) + max_possible_batch_size = self.mamba_cache[0].shape[1] + for destination_index in destination_indices: + if destination_index in self._get_all_occupied_indices() and \ + destination_index not in indices_for_current_run: + # move not running indices outside of the batch + all_other_indices = list( + range(batch_size, max_possible_batch_size)) + first_avail_index = self._first_free_index_in_mamba_cache( + all_other_indices) + self._swap_indices(from_index=destination_index, + to_index=first_avail_index) + + def _move_cache_index_and_mappings(self, from_index: int, to_index: int): + self._copy_mamba_cache(from_index=from_index, to_index=to_index) + self._update_mapping_index(from_index=from_index, to_index=to_index) + + def _swap_pair_indices_and_mappings(self, from_index: int, to_index: int): + self._swap_mamba_cache(from_index=from_index, to_index=to_index) + self._swap_mapping_index(from_index=from_index, to_index=to_index) + + def _swap_mapping_index(self, from_index: int, to_index: int): + for seq_ids2index in self.mamba_cache_indices_mapping.values(): + for seq_id, index in seq_ids2index.items(): + if from_index == index: + seq_ids2index.update({seq_id: to_index}) + elif to_index == index: + seq_ids2index.update({seq_id: from_index}) + + def _update_mapping_index(self, from_index: int, to_index: int): + for seq_ids2index in self.mamba_cache_indices_mapping.values(): + for seq_id, index in seq_ids2index.items(): + if from_index == index: + seq_ids2index.update({seq_id: to_index}) + return + + def _release_finished_requests(self, + finished_seq_groups_req_ids: List[str]): + for req_id in finished_seq_groups_req_ids: + if req_id in self.mamba_cache_indices_mapping: + self.mamba_cache_indices_mapping.pop(req_id) + + def _first_free_index_in_mamba_cache( + self, indices_range: Optional[List[int]] = None) -> int: + assert self.mamba_cache is not None + if indices_range is None: + max_possible_batch_size = self.mamba_cache[0].shape[1] + indices_range = list(range(max_possible_batch_size)) + all_occupied_indices = self._get_all_occupied_indices() + for i in indices_range: + if i not in all_occupied_indices: + return i + raise Exception("Couldn't find a free spot in the mamba cache! 
This" + "should never happen") diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index b37452877cf0c..3c8c600c2c026 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -14,7 +14,8 @@ from vllm.logger import init_logger from vllm.utils import is_hip -from .interfaces import supports_multimodal, supports_pp +from .interfaces import (has_inner_state, is_attention_free, + supports_multimodal, supports_pp) from .interfaces_base import is_embedding_model, is_text_generation_model logger = init_logger(__name__) @@ -52,6 +53,7 @@ "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), # For decapoda-research/llama-* "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), + "MambaForCausalLM": ("mamba", "MambaForCausalLM"), "MistralForCausalLM": ("llama", "LlamaForCausalLM"), "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"), @@ -157,6 +159,8 @@ class _ModelInfo: is_embedding_model: bool supports_multimodal: bool supports_pp: bool + has_inner_state: bool + is_attention_free: bool @staticmethod def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo": @@ -165,6 +169,8 @@ def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo": is_embedding_model=is_embedding_model(model), supports_multimodal=supports_multimodal(model), supports_pp=supports_pp(model), + has_inner_state=has_inner_state(model), + is_attention_free=is_attention_free(model), ) @@ -380,6 +386,14 @@ def is_pp_supported_model( ) -> bool: return self.inspect_model_cls(architectures).supports_pp + def model_has_inner_state(self, architectures: Union[str, + List[str]]) -> bool: + return self.inspect_model_cls(architectures).has_inner_state + + def is_attention_free_model(self, architectures: Union[str, + List[str]]) -> bool: + return self.inspect_model_cls(architectures).is_attention_free + ModelRegistry = _ModelRegistry({ model_arch: _LazyRegisteredModel( diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 252440c7b7e08..090f95e6e892c 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -52,15 +52,12 @@ def __init__( self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] # Get attention backend. - self.attn_backend = get_attn_backend( - model_config.get_num_attention_heads(parallel_config), - self.head_size, - self.num_kv_heads, - model_config.get_sliding_window(), - model_config.dtype, - cache_config.cache_dtype, - self.block_size, - ) + self.attn_backend = get_attn_backend(self.head_size, + model_config.get_sliding_window(), + model_config.dtype, + cache_config.cache_dtype, + self.block_size, + model_config.is_attention_free) # Initialize the cache. 
self.gpu_cache = self._allocate_kv_cache( diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index f67b086796411..795511aea6754 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -418,13 +418,12 @@ def __init__( self.sliding_window = model_config.get_sliding_window() self.block_size = cache_config.block_size self.attn_backend = get_attn_backend( - self.model_config.get_num_attention_heads(self.parallel_config), self.model_config.get_head_size(), - self.model_config.get_num_kv_heads(self.parallel_config), self.model_config.get_sliding_window(), self.model_config.dtype, self.kv_cache_dtype, self.block_size, + self.model_config.is_attention_free, ) # Multi-modal data support diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index d6e3670e304d5..b84562851f0f8 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -56,13 +56,12 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig, # Get attention backend. self.attn_backend = get_attn_backend( - self.model_config.get_num_attention_heads(self.parallel_config), self.model_config.get_head_size(), - self.model_config.get_num_kv_heads(self.parallel_config), self.model_config.get_sliding_window(), self.model_config.dtype, cache_config.cache_dtype, self.block_size, + self.model_config.is_attention_free, ) # Initialize the cache. diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 59b4b8c4ddf38..6a00444f5098b 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -196,7 +196,7 @@ def execute_model( seqlen_agnostic_kwargs = { "finished_requests_ids": model_input.finished_requests_ids, "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, - } if self.has_seqlen_agnostic else {} + } if self.has_inner_state else {} multi_modal_kwargs = model_input.multi_modal_kwargs or {} with set_forward_context(model_input.attn_metadata): diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 5bc7100732291..9db3261b8ac36 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -17,7 +17,6 @@ import vllm.envs as envs from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention.backends.abstract import AttentionState -from vllm.attention.backends.utils import CommonAttentionState from vllm.compilation.compile_context import set_compile_context from vllm.compilation.levels import CompilationLevel from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, @@ -991,8 +990,7 @@ def __init__( self.graph_memory_pool: Optional[Tuple[ int, int]] = None # Set during graph capture. - self.has_seqlen_agnostic = model_config.contains_seqlen_agnostic_layers( - parallel_config) + self.has_inner_state = model_config.has_inner_state # When using CUDA graph, the input block tables must be padded to # max_seq_len_to_capture. 
However, creating the block table in @@ -1003,22 +1001,16 @@ def __init__( self.graph_block_tables = np.zeros( (self.max_batchsize_to_capture, self.get_max_block_per_batch()), dtype=np.int32) - num_attn_heads = self.model_config.get_num_attention_heads( - self.parallel_config) self.attn_backend = get_attn_backend( - num_attn_heads, self.model_config.get_head_size(), - self.model_config.get_num_kv_heads(self.parallel_config), self.model_config.get_sliding_window(), self.model_config.dtype, self.kv_cache_dtype, self.block_size, - ) if num_attn_heads else None - if self.attn_backend: - self.attn_state = self.attn_backend.get_state_cls()( - weakref.proxy(self)) - else: - self.attn_state = CommonAttentionState(weakref.proxy(self)) + self.model_config.is_attention_free, + ) + self.attn_state = self.attn_backend.get_state_cls()( + weakref.proxy(self)) # Multi-modal data support self.input_registry = input_registry @@ -1498,7 +1490,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: "previous_hidden_states"] = previous_hidden_states[: batch_size] - if self.has_seqlen_agnostic: + if self.has_inner_state: # Only used by Mamba-based models CUDA graph atm (Jamba) capture_inputs.update({ "seqlen_agnostic_capture_inputs": @@ -1647,7 +1639,7 @@ def execute_model( seqlen_agnostic_kwargs = { "finished_requests_ids": model_input.finished_requests_ids, "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, - } if self.has_seqlen_agnostic else {} + } if self.has_inner_state else {} if (self.observability_config is not None and self.observability_config.collect_model_forward_time): model_forward_start = torch.cuda.Event(enable_timing=True) @@ -1852,10 +1844,14 @@ def forward( # Copy the input tensors to the input buffers. self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True) self.input_buffers["positions"].copy_(positions, non_blocking=True) - self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping, - non_blocking=True) + + if self.backend_name != "placeholder-attn": + self.input_buffers["slot_mapping"].copy_( + attn_metadata.slot_mapping, non_blocking=True) + self.attn_state.prepare_graph_input_buffers( self.input_buffers, attn_metadata, self._is_encoder_decoder_model) + if "seqlen_agnostic_capture_inputs" in self.input_buffers: self.model.copy_inputs_before_cuda_graphs(self.input_buffers, **kwargs) diff --git a/vllm/worker/openvino_model_runner.py b/vllm/worker/openvino_model_runner.py index de3088695dfef..760b18427e22b 100644 --- a/vllm/worker/openvino_model_runner.py +++ b/vllm/worker/openvino_model_runner.py @@ -74,13 +74,12 @@ def __init__( self.block_size = cache_config.block_size self.attn_backend = get_attn_backend( - self.model_config.get_num_attention_heads(self.parallel_config), self.model_config.get_head_size(), - self.model_config.get_num_kv_heads(self.parallel_config), self.model_config.get_sliding_window(), self.model_config.dtype, self.kv_cache_dtype, self.block_size, + self.model_config.is_attention_free, ) # Multi-modal data support diff --git a/vllm/worker/openvino_worker.py b/vllm/worker/openvino_worker.py index 6b818186779b6..24425fece850f 100644 --- a/vllm/worker/openvino_worker.py +++ b/vllm/worker/openvino_worker.py @@ -70,13 +70,12 @@ def __init__( # Get attention backend. 
self.attn_backend = get_attn_backend( - self.model_config.get_num_attention_heads(self.parallel_config), self.head_size, - self.model_config.get_num_kv_heads(self.parallel_config), self.model_config.get_sliding_window(), self.model_config.dtype, self.cache_config.cache_dtype, self.block_size, + self.model_config.is_attention_free, ) # Initialize the cache. diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index b3ae5b4a9a0ce..f26d1c8cf7dff 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -113,13 +113,12 @@ def __init__( (self.scheduler_config.max_num_seqs, self.max_num_blocks_per_seq), dtype=np.int32) self.attn_backend = get_attn_backend( - self.model_config.get_num_attention_heads(self.parallel_config), self.model_config.get_head_size(), - self.model_config.get_num_kv_heads(self.parallel_config), self.model_config.get_sliding_window(), self.model_config.dtype, self.cache_config.cache_dtype, self.block_size, + self.model_config.is_attention_free, False, ) self.cached_step_outputs: List[torch.Tensor] = [] diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 3851843afc960..ab61e4377f900 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -236,11 +236,15 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: "not properly cleaned up before initializing the vLLM instance.") cache_block_size = self.get_cache_block_size_bytes() - num_gpu_blocks = int( - (total_gpu_memory * self.cache_config.gpu_memory_utilization - - peak_memory) // cache_block_size) - num_cpu_blocks = int(self.cache_config.swap_space_bytes // - cache_block_size) + if cache_block_size == 0: + num_gpu_blocks = 0 + num_cpu_blocks = 0 + else: + num_gpu_blocks = int( + (total_gpu_memory * self.cache_config.gpu_memory_utilization - + peak_memory) // cache_block_size) + num_cpu_blocks = int(self.cache_config.swap_space_bytes // + cache_block_size) num_gpu_blocks = max(num_gpu_blocks, 0) num_cpu_blocks = max(num_cpu_blocks, 0) if self.model_runner.lora_manager: @@ -257,6 +261,7 @@ def initialize_cache(self, num_gpu_blocks: int, """ raise_if_cache_size_invalid(num_gpu_blocks, self.cache_config.block_size, + self.cache_config.is_attention_free, self.model_config.max_model_len) self.cache_config.num_gpu_blocks = num_gpu_blocks @@ -472,14 +477,18 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): "`dtype` flag in CLI, for example: --dtype=half.") -def raise_if_cache_size_invalid(num_gpu_blocks, block_size, +def raise_if_cache_size_invalid(num_gpu_blocks, block_size, is_attention_free, max_model_len) -> None: - if num_gpu_blocks <= 0: + if is_attention_free and num_gpu_blocks != 0: + raise ValueError("No memory should be allocated for the cache blocks " + f"for an attention-free model, but {num_gpu_blocks} " + "blocks are allocated.") + if not is_attention_free and num_gpu_blocks <= 0: raise ValueError("No available memory for the cache blocks. 
" "Try increasing `gpu_memory_utilization` when " "initializing the engine.") max_seq_len = block_size * num_gpu_blocks - if max_model_len > max_seq_len: + if not is_attention_free and max_model_len > max_seq_len: raise ValueError( f"The model's max seq len ({max_model_len}) " "is larger than the maximum number of tokens that can be " diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 612428180226a..20dceee849ae5 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -372,13 +372,12 @@ def __init__( self.block_size = cache_config.block_size self.attn_backend = get_attn_backend( - self.model_config.get_num_attention_heads(self.parallel_config), self.model_config.get_head_size(), - self.model_config.get_num_kv_heads(self.parallel_config), self.model_config.get_sliding_window(), self.model_config.dtype, self.kv_cache_dtype, self.block_size, + self.model_config.is_attention_free, ) # Multi-modal data support