[core] gemma2 full context length support #10584

Merged 8 commits on Nov 23, 2024
Changes from 5 commits
tests/basic_correctness/test_basic_correctness.py (21 changes: 14 additions & 7 deletions)
@@ -14,11 +14,12 @@
from vllm.platforms import current_platform
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata

from ..conftest import VllmRunner
from ..models.utils import check_outputs_equal
from ..utils import multi_gpu_test

MODELS = [
"facebook/opt-125m",
"google/gemma-2-2b-it",
"meta-llama/Llama-3.2-1B",
]

@@ -42,8 +43,6 @@ def test_vllm_gc_ed():
@pytest.mark.parametrize("enforce_eager", [False, True])
def test_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    backend: str,
    dtype: str,
@@ -56,13 +55,21 @@ def test_models(

    os.environ["VLLM_ATTENTION_BACKEND"] = backend

    # 5042 tokens for gemma2
    # gemma2 has alternating sliding window size of 4096
    # we need a prompt with more than 4096 tokens to test the sliding window
    prompt = "The following numbers of the sequence " + ", ".join(
        str(i) for i in range(1024)) + " are:"
    example_prompts = [prompt]

    with hf_runner(model, dtype=dtype) as hf_model:
        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

    with vllm_runner(model,
                     dtype=dtype,
                     enforce_eager=enforce_eager,
                     gpu_memory_utilization=0.7) as vllm_model:
    with VllmRunner(model,
                    max_model_len=8192,
                    dtype=dtype,
                    enforce_eager=enforce_eager,
                    gpu_memory_utilization=0.7) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

    check_outputs_equal(
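For reference, the test's comment puts the constructed prompt at about 5042 gemma2 tokens, comfortably past the 4096-token window. A minimal sketch of how one might double-check that claim, assuming the (gated) google/gemma-2-2b-it tokenizer can be downloaded via transformers; this check is illustrative and not part of the PR:

# Hypothetical sanity check: confirm the prompt used in test_models() really
# exceeds Gemma 2's 4096-token sliding window.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
prompt = "The following numbers of the sequence " + ", ".join(
    str(i) for i in range(1024)) + " are:"
num_tokens = len(tokenizer(prompt).input_ids)
print(num_tokens)  # roughly 5000 tokens per the PR's comment
assert num_tokens > 4096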
vllm/attention/layer.py (12 changes: 10 additions & 2 deletions)
@@ -40,18 +40,26 @@ def __init__(
        quant_config: Optional[QuantizationConfig] = None,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        per_layer_sliding_window: Optional[int] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        if per_layer_sliding_window is not None:
            # per-layer sliding window
            sliding_window = per_layer_sliding_window
        elif cache_config is not None:
            # model-level sliding window
            sliding_window = cache_config.sliding_window
        else:
            sliding_window = None

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            sliding_window = cache_config.sliding_window
            is_attention_free = cache_config.is_attention_free
        else:
            kv_cache_dtype = "auto"
            block_size = 16
            sliding_window = None
            is_attention_free = False
        if num_kv_heads is None:
            num_kv_heads = num_heads
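The new constructor logic gives an explicitly passed per-layer window priority over the model-level value from the cache config. A standalone sketch of that resolution order (the function name and arguments here are illustrative, not vLLM APIs):

from typing import Optional

def resolve_sliding_window(per_layer_sliding_window: Optional[int],
                           model_level_sliding_window: Optional[int]
                           ) -> Optional[int]:
    # Per-layer value (e.g. Gemma 2's interleaved layers) wins; otherwise fall
    # back to the model-level window; if neither is set, full attention.
    if per_layer_sliding_window is not None:
        return per_layer_sliding_window
    return model_level_sliding_window

assert resolve_sliding_window(4096, None) == 4096  # gemma2 sliding-window layer
assert resolve_sliding_window(None, 2048) == 2048  # model-wide sliding window
assert resolve_sliding_window(None, None) is None  # full attention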
vllm/config.py (19 changes: 9 additions & 10 deletions)
@@ -232,16 +232,15 @@ def __init__(
            isinstance(sliding_window, list) or
            (self.hf_text_config.model_type in ["gemma2"]))

        if (not self.disable_sliding_window and has_interleaved_attention):
            sliding_window_len_min = get_min_sliding_window(
                self.hf_text_config.sliding_window)

            print_warning_once(
                f"{self.hf_text_config.model_type} has interleaved attention, "
                "which is currently not supported by vLLM. Disabling sliding "
                "window and capping the max length to the sliding window size "
                f"({sliding_window_len_min}).")
            self.disable_sliding_window = True
        if has_interleaved_attention:
            # for a model with interleaved attention,
            # the scheduler and the model treat it as full attention
            # (i.e., not dropping any tokens outside the window).
            # only the attention layer itself is aware of the sliding window
            # and uses the window size to compute the attention.
            self.hf_text_config.interleaved_sliding_window = sliding_window
            delattr(self.hf_text_config, "sliding_window")
            sliding_window = None

        self.max_model_len = _get_and_verify_max_len(
            hf_config=self.hf_text_config,
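To make the intent of the rewritten branch concrete: the window size is moved under a new attribute so that generic max-length and KV-cache logic sees no model-level sliding window, while the attention layers can still read it. A self-contained sketch using a SimpleNamespace stand-in for the HF text config (illustrative only, not the real transformers config object):

from types import SimpleNamespace

# Stand-in for the hf_text_config of an interleaved-attention model (gemma2).
hf_text_config = SimpleNamespace(model_type="gemma2", sliding_window=4096)
sliding_window = hf_text_config.sliding_window

has_interleaved_attention = sliding_window is not None and (
    hf_text_config.model_type in ["gemma2"])

if has_interleaved_attention:
    # Expose the window only under the new name; schedulers and cache
    # managers then treat the model as plain full attention.
    hf_text_config.interleaved_sliding_window = sliding_window
    delattr(hf_text_config, "sliding_window")
    sliding_window = None

assert sliding_window is None
assert not hasattr(hf_text_config, "sliding_window")
assert hf_text_config.interleaved_sliding_window == 4096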
vllm/model_executor/models/gemma2.py (13 changes: 7 additions & 6 deletions)
@@ -143,19 +143,20 @@ def __init__(self,
            is_neox_style=True,
        )

        # FIXME(woosuk): While Gemma 2 uses sliding window attention for every
        # odd layer, vLLM currently ignores it and uses global attention for
        # all layers.
        use_sliding_window = (layer_idx % 2 == 1
                              and config.sliding_window is not None)
        del use_sliding_window  # Unused.
        # reference:
        # https://github.com/huggingface/transformers/blob/54be2d7ae87e873482b984cc956e165ca4dc0ba3/src/transformers/models/gemma2/modeling_gemma2.py#L312  # noqa
        use_sliding_window = (layer_idx % 2 == 0 and
                              config.interleaved_sliding_window is not None)
        sliding_window = config.interleaved_sliding_window if \
            use_sliding_window else None
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              logits_soft_cap=attn_logits_soft_cap,
                              per_layer_sliding_window=sliding_window,
                              prefix=f"{prefix}.attn")

    def forward(
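The even-layer check above yields the interleaved pattern referenced in the transformers link: layers 0, 2, 4, ... attend within the 4096-token window, while layers 1, 3, 5, ... attend globally. A tiny illustrative loop (window size hard-coded here purely for demonstration):

interleaved_sliding_window = 4096  # gemma2's window size

for layer_idx in range(6):
    use_sliding_window = (layer_idx % 2 == 0
                          and interleaved_sliding_window is not None)
    sliding_window = interleaved_sliding_window if use_sliding_window else None
    print(layer_idx, sliding_window)
# 0 4096   <- sliding-window layer
# 1 None   <- full-attention layer
# 2 4096
# 3 None
# 4 4096
# 5 None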