[Model] Support Mamba #6484

Merged: 52 commits into main from tms/add_mamba on Oct 11, 2024

Changes from 32 commits

Commits (52)
ce630ea
WiP adding support for Mamba
tlrmchlsmth Jul 8, 2024
6c59b06
wip
tlrmchlsmth Jul 9, 2024
eb9bf34
WIP -- runs through. Generates tokens. Bad tokens.
tlrmchlsmth Jul 10, 2024
320f79b
Good output for mamba-370m
tlrmchlsmth Jul 15, 2024
5ab6622
wip
tlrmchlsmth Jul 16, 2024
71173a0
Merge branch 'upstream-main' into tms/add_mamba
tlrmchlsmth Jul 16, 2024
25b54d9
cleanup
tlrmchlsmth Jul 16, 2024
ebc12f1
Rename embedding block space manager
tlrmchlsmth Jul 16, 2024
ac60374
cleanup
tlrmchlsmth Jul 16, 2024
adb6713
remove file
tlrmchlsmth Jul 16, 2024
b733a84
format
tlrmchlsmth Jul 16, 2024
fb846ce
apply fix from #6214
tlrmchlsmth Jul 16, 2024
09b1495
Merge branch 'upstream-main' into tms/add_mamba
tlrmchlsmth Jul 16, 2024
d8017cb
fixes from 6425
tlrmchlsmth Jul 16, 2024
7ab2b9e
add an integration test
tlrmchlsmth Jul 23, 2024
c319a21
lint
tlrmchlsmth Jul 23, 2024
3374d8f
Merge branch 'upstream-main' into tms/add_mamba
tlrmchlsmth Jul 31, 2024
76022d3
fixup
tlrmchlsmth Jul 31, 2024
9ffc057
backend selector changes
tlrmchlsmth Jul 31, 2024
65d7e22
lint
tlrmchlsmth Jul 31, 2024
f14648e
Merge branch 'main' into tms/add_mamba
tlrmchlsmth Aug 20, 2024
e76a617
Factor out mamba cache from jamba.py, and fixes
tlrmchlsmth Aug 20, 2024
b9723fe
Fix mamba cache initialized bool. format and renames
tlrmchlsmth Aug 21, 2024
b2a8cd8
Refactor mamba to use the MambaCacheManager
tlrmchlsmth Aug 21, 2024
9ba8734
Merge branch 'upstream-main' into tms/add_mamba
tlrmchlsmth Aug 28, 2024
f87a8e2
fixes
tlrmchlsmth Aug 29, 2024
06b146e
Merge branch 'upstream-main' into tms/add_mamba
tlrmchlsmth Aug 29, 2024
8e16aca
Update to use kernels from #7651
tlrmchlsmth Aug 29, 2024
120b761
some cruft
tlrmchlsmth Aug 29, 2024
698f666
Merge branch 'main' into tms/add_mamba
tlrmchlsmth Sep 13, 2024
a5bd7d2
Move test_mamba.py (for #7820)
tlrmchlsmth Sep 13, 2024
6546bd9
fixes
tlrmchlsmth Sep 13, 2024
f42af9b
Merge branch 'main' into tms/add_mamba
tlrmchlsmth Sep 23, 2024
85a8378
Review comments
tlrmchlsmth Sep 24, 2024
80e3c77
cache attention free
tlrmchlsmth Sep 24, 2024
184e808
fixup
tlrmchlsmth Sep 24, 2024
05d6aab
fixup
tlrmchlsmth Sep 24, 2024
4ebd4cc
missed two
tlrmchlsmth Sep 24, 2024
ca3788e
Remove is_attention_free from SchedulerConfig
tlrmchlsmth Sep 24, 2024
c67a650
default `is_attention_free` for unit tests
tlrmchlsmth Sep 25, 2024
9e2edf6
Fix attention selector tests
tlrmchlsmth Sep 25, 2024
f41b474
merge main, support chunked prefill, more tests
tlrmchlsmth Sep 30, 2024
7ef3c68
Merge branch 'main' into tms/add_mamba
tlrmchlsmth Oct 10, 2024
8729b43
Review comments
tlrmchlsmth Oct 10, 2024
5fb01c4
Merge branch 'main' into tms/add_mamba
tlrmchlsmth Oct 10, 2024
16d3f1d
format
tlrmchlsmth Oct 10, 2024
4b21a08
Fix supported_models.rst
tlrmchlsmth Oct 10, 2024
ec8ef04
jambafix
tlrmchlsmth Oct 10, 2024
49e1f3c
fix softfail on cpu tests
tlrmchlsmth Oct 11, 2024
e80b82a
Merge branch 'main' into tms/add_mamba
tlrmchlsmth Oct 11, 2024
609e9fb
fix for #9233
tlrmchlsmth Oct 11, 2024
93129e5
format
tlrmchlsmth Oct 11, 2024
77 changes: 77 additions & 0 deletions tests/models/decoder_only/language/test_mamba.py
@@ -0,0 +1,77 @@
"""Compare the outputs of HF and vLLM when using greedy sampling for Mamba.

Run `pytest tests/models/decoder_only/language/test_mamba.py`.
"""
import pytest
from transformers import AutoModelForCausalLM, AutoTokenizer

from ...utils import check_outputs_equal

MODELS = [
"state-spaces/mamba-370m-hf",
]


# Use lower-level interfaces to create this greedy generator, as mamba will
# choke on the model_kwarg 'attention_mask' if hf_model.generate_greedy is used.
def generate_greedy(model_name, example_prompts, max_tokens):
# Create a text generation pipeline
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Generate texts from the prompts
outputs = []
for prompt in example_prompts:
# Tokenize the input prompt with truncation
inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
input_ids = inputs["input_ids"].to(model.device)

# Generate text using the model's generate method directly
generated_ids = model.generate(input_ids, max_new_tokens=max_tokens)
generated_text = tokenizer.decode(generated_ids[0],
skip_special_tokens=True)

outputs.append((generated_ids[0].tolist(), generated_text))

return outputs


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
# To pass the small model tests, we need full precision.
assert dtype == "float"

hf_outputs = generate_greedy(model, example_prompts, max_tokens)

with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

check_outputs_equal(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_model_print(
vllm_runner,
model: str,
dtype: str,
) -> None:
with vllm_runner(model, dtype=dtype) as vllm_model:
# This test is for verifying whether the model's extra_repr
# can be printed correctly.
print(vllm_model.model.llm_engine.model_executor.driver_worker.
model_runner.model)
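
For context, the vllm_runner fixture used above drives vLLM's offline LLM entrypoint. The following is a minimal standalone sketch of the vLLM side of the comparison (illustrative only; it assumes a build that includes this PR's Mamba support, and the prompts are placeholder examples):

from vllm import LLM, SamplingParams

# Greedy decoding: temperature=0 mirrors generate_greedy in the test above.
params = SamplingParams(temperature=0.0, max_tokens=96)

# Full precision, matching the dtype used by the test.
llm = LLM(model="state-spaces/mamba-370m-hf", dtype="float32")

prompts = ["vLLM is", "Mamba is a state-space model that"]
outputs = llm.generate(prompts, params)
for out in outputs:
    print(repr(out.prompt), "->", repr(out.outputs[0].text))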
309 changes: 309 additions & 0 deletions vllm/attention/backends/placeholder_attn.py
@@ -0,0 +1,309 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Tuple, Type

import torch

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata,
AttentionMetadataBuilder)
from vllm.attention.backends.utils import CommonAttentionState

if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUBuilder

# Placeholder attention backend for models like Mamba and embedding models that
# lack attention.


class PlaceholderAttentionBackend(AttentionBackend):
"""Placeholder backend for when no attention is needed."""

@staticmethod
def get_name() -> str:
return "No attention"

@staticmethod
def get_impl_cls() -> Type["PlaceholderAttentionImpl"]:
return PlaceholderAttentionImpl

@staticmethod
def get_builder_cls() -> Type["PlaceholderAttentionMetadataBuilder"]:
return PlaceholderAttentionMetadataBuilder

@staticmethod
def get_metadata_cls() -> Type["PlaceholderAttentionMetadata"]:
return PlaceholderAttentionMetadata

@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState

@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return (1, 1, 1, 1, 1)

@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor,
) -> None:
return

@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
) -> None:
return


@dataclass
class PlaceholderAttentionMetadata(AttentionMetadata):
"""Attention metadata for prefill and decode batched together."""
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decode-only batch.
seq_lens: Optional[List[int]]
# seq_lens stored as a tensor.
seq_lens_tensor: Optional[torch.Tensor]

# Maximum query length in the batch. None for decoding.
max_query_len: Optional[int]
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len: int
# Maximum sequence length among decode batch. 0 if there are prefill
# requests only.
max_decode_seq_len: int
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
query_start_loc: Optional[torch.Tensor]
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor]
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor]

# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
    # The 2nd dimension is padded up to max_blocks_per_seq if it is
    # cuda-graph captured.
block_tables: Optional[torch.Tensor]

    # Whether or not cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool

_cached_prefill_metadata: Optional["PlaceholderAttentionMetadata"] = None
_cached_decode_metadata: Optional["PlaceholderAttentionMetadata"] = None

@property
def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
if self.num_prefills == 0:
return None

if self._cached_prefill_metadata is not None:
return self._cached_prefill_metadata

assert self.seq_lens is not None
assert self.seq_lens_tensor is not None
assert self.query_start_loc is not None
assert self.context_lens_tensor is not None
assert self.seq_start_loc is not None

# Placeholders
slot_mapping = torch.empty(0)
block_tables = torch.empty(0)

self._cached_prefill_metadata = PlaceholderAttentionMetadata(
num_prefills=self.num_prefills,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0,
slot_mapping=slot_mapping,
seq_lens=self.seq_lens[:self.num_prefills],
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
max_query_len=self.max_query_len,
max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_seq_len=0,
query_start_loc=self.query_start_loc[:self.num_prefills + 1],
seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
block_tables=block_tables,
use_cuda_graph=False,
)
return self._cached_prefill_metadata

@property
def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
if self.num_decode_tokens == 0:
return None

if self._cached_decode_metadata is not None:
return self._cached_decode_metadata
assert self.seq_lens_tensor is not None

# Placeholders
slot_mapping = torch.empty(0)
block_tables = torch.empty(0)

self._cached_decode_metadata = PlaceholderAttentionMetadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens,
slot_mapping=slot_mapping,
seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
max_query_len=None,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
query_start_loc=None,
seq_start_loc=None,
context_lens_tensor=None,
block_tables=block_tables,
use_cuda_graph=self.use_cuda_graph,
)
return self._cached_decode_metadata


class PlaceholderAttentionMetadataBuilder(
AttentionMetadataBuilder[PlaceholderAttentionMetadata]):

def __init__(self, input_builder: "ModelInputForGPUBuilder"):
self.prefill_seq_lens: List[int] = []
self.context_lens: List[int] = []
self.curr_seq_lens: List[int] = []
self.num_prefills = 0
self.num_prefill_tokens = 0
self.num_decode_tokens = 0

self.input_builder = input_builder
self.runner = input_builder.runner

def _add_seq_group(
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
chunked_prefill_enabled: bool):
"""Add a sequence group to the metadata. Specifically update/append
1. context length.
"""
is_prompt = inter_data.is_prompt

for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
curr_sliding_window_block) in zip(
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
inter_data.orig_seq_lens, inter_data.seq_lens,
inter_data.query_lens, inter_data.context_lens,
inter_data.curr_sliding_window_blocks):
self.context_lens.append(context_len)

if is_prompt:
self.num_prefills += 1
self.num_prefill_tokens += token_len
self.prefill_seq_lens.append(seq_len)
else:
assert query_len == 1, (
"seq_len: {}, context_len: {}, query_len: {}".format(
seq_len, context_len, query_len))
self.num_decode_tokens += query_len
self.curr_seq_lens.append(curr_seq_len)

def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int):
"""Build attention metadata with on-device tensors.

Args:
seq_lens: The maybe padded sequence lengths of the input sequences.
query_lens: The query lengths of the input sequences.
cuda_graph_pad_size: The padding size for cuda graph.
-1 if cuda graph is not used.
batch_size: The maybe padded batch size.
"""
for inter_data in self.input_builder.inter_data_list:
self._add_seq_group(inter_data,
self.input_builder.chunked_prefill_enabled)

device = self.runner.device
use_captured_graph = cuda_graph_pad_size != -1

logits_soft_cap = getattr(self.runner.model_config.hf_config,
"attn_logit_softcapping", None)
if logits_soft_cap is not None:
raise ValueError(
"Please use Flashinfer backend for models with logits_soft_cap"
" (i.e., Gemma-2). Otherwise, the output might be wrong."
" Set Flashinfer backend by "
"export VLLM_ATTENTION_BACKEND=FLASHINFER.")

max_query_len = max(query_lens)
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
max_decode_seq_len = max(self.curr_seq_lens, default=0)
num_decode_tokens = self.num_decode_tokens

if use_captured_graph:
num_decode_tokens = batch_size

assert max_query_len > 0, ("query_lens: {}".format(query_lens))

context_lens_tensor = torch.tensor(self.context_lens,
dtype=torch.int,
device=device)
seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
device=device)
query_lens_tensor = torch.tensor(query_lens,
dtype=torch.long,
device=device)
query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=device)
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=device)
torch.cumsum(seq_lens_tensor,
dim=0,
dtype=seq_start_loc.dtype,
out=seq_start_loc[1:])
torch.cumsum(query_lens_tensor,
dim=0,
dtype=query_start_loc.dtype,
out=query_start_loc[1:])

# Placeholders
slot_mapping = torch.empty(0)
block_tables = torch.empty(0)

return PlaceholderAttentionMetadata(
num_prefills=self.num_prefills,
slot_mapping=slot_mapping,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len,
query_start_loc=query_start_loc,
seq_start_loc=seq_start_loc,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
)


class PlaceholderAttentionImpl(AttentionImpl):

def __init__(self, *args, **kwargs) -> None:
return

def forward(self, *args, **kwargs) -> torch.Tensor:
raise NotImplementedError
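
To make the start-location bookkeeping in build() concrete, here is a small standalone sketch (it depends only on torch and mirrors the cumsum pattern above; the example lengths are arbitrary): for query lengths [4, 6], query_start_loc comes out as [0, 4, 10], matching the comment on the metadata fields.

import torch

# Mirror of the construction in PlaceholderAttentionMetadataBuilder.build():
# a zero-initialized (N + 1,) tensor whose tail is filled with the cumulative
# sum of the per-sequence query lengths.
query_lens = [4, 6]
query_lens_tensor = torch.tensor(query_lens, dtype=torch.long)
query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, dtype=torch.int32)
torch.cumsum(query_lens_tensor,
             dim=0,
             dtype=query_start_loc.dtype,
             out=query_start_loc[1:])
print(query_start_loc.tolist())  # [0, 4, 10]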