[Encoder decoder] Add cuda graph support during decoding for encoder-decoder models #7631

Merged
91 commits merged on Sep 17, 2024
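For context, a minimal usage sketch (not part of this PR's diff) of how the feature is exercised: with this change, decode-phase CUDA graphs are captured for encoder-decoder models whenever enforce_eager is left at its default of False, so no new flag is needed. The model and dtype below mirror the test settings further down; treat the prompt and output handling as illustrative only.

# Minimal sketch, assuming the standard vLLM LLM entry point.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/bart-large-cnn", dtype="bfloat16", enforce_eager=False)
outputs = llm.generate(
    "The quick brown fox jumps over the lazy dog.",
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)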
Changes from all commits
Commits
91 commits
5650b95
Merge pull request #1 from vllm-project/main
sroy745 May 29, 2024
8f36146
Merge branch 'vllm-project:main' into main
sroy745 Jun 3, 2024
9e75057
Merge branch 'vllm-project:main' into main
sroy745 Jun 3, 2024
db2c679
Merge branch 'vllm-project:main' into main
sroy745 Jun 7, 2024
8d7512c
Merge branch 'vllm-project:main' into main
sroy745 Jun 10, 2024
1473f74
Merge branch 'vllm-project:main' into main
sroy745 Jun 12, 2024
4013e1a
Merge branch 'vllm-project:main' into main
sroy745 Jun 14, 2024
2dbdd78
Merge branch 'vllm-project:main' into main
sroy745 Jun 17, 2024
b3575e9
Merge branch 'vllm-project:main' into main
sroy745 Jun 20, 2024
94b0d43
Merge branch 'vllm-project:main' into main
sroy745 Jun 24, 2024
fa8fedf
Merge branch 'vllm-project:main' into main
sroy745 Jun 27, 2024
6ed96b4
Merge branch 'vllm-project:main' into main
sroy745 Jun 27, 2024
b71c533
Merge branch 'vllm-project:main' into main
sroy745 Jun 28, 2024
57babef
Merge branch 'vllm-project:main' into main
sroy745 Jun 29, 2024
4b19bac
Merge branch 'vllm-project:main' into main
sroy745 Jul 1, 2024
eb7a1c4
Merge branch 'vllm-project:main' into main
sroy745 Jul 6, 2024
7e2c87e
Merge branch 'vllm-project:main' into main
sroy745 Jul 10, 2024
6212d5f
Merge branch 'vllm-project:main' into main
sroy745 Jul 15, 2024
5491438
Merge branch 'vllm-project:main' into main
sroy745 Jul 17, 2024
68e080a
Merge branch 'vllm-project:main' into main
sroy745 Jul 31, 2024
55e4332
Merge branch 'vllm-project:main' into main
sroy745 Aug 13, 2024
c08dc0a
Add cuda graph support during decoding for encoder-decoder models
sroy745 Aug 18, 2024
fbc8837
Add logic for CUDA Graph capture
sroy745 Aug 21, 2024
b3b4e4a
Remove extra line
sroy745 Aug 21, 2024
bc9c32e
Remove debugs
sroy745 Aug 21, 2024
532eb48
Merge branch 'vllm-project:main' into main
sroy745 Aug 22, 2024
33342fb
Add comments
sroy745 Aug 22, 2024
b3425d6
Merge branch 'main' into enc-dec-cuda-graph
sroy745 Aug 22, 2024
7cea056
Merge branch 'vllm-project:main' into main
sroy745 Aug 22, 2024
47ba15b
Merge remote-tracking branch 'origin/main' into enc-dec-cuda-graph
sroy745 Aug 22, 2024
bf70ceb
Merge remote-tracking branch 'origin/main' into enc-dec-cuda-graph
sroy745 Aug 22, 2024
cef7273
Merge
sroy745 Aug 22, 2024
599cd6b
Move logic to backend/utils.py
sroy745 Aug 23, 2024
a9ca02f
Fix import
sroy745 Aug 23, 2024
6b09ee8
Fix formatting
sroy745 Aug 23, 2024
c199b50
Fix test documentation
sroy745 Aug 23, 2024
539b10e
Remove debug stmt
sroy745 Aug 23, 2024
727b3f2
Add a new test
sroy745 Aug 23, 2024
e2e16cf
Fix Batch Size
sroy745 Aug 23, 2024
1fb7cc6
Fix comments
sroy745 Aug 23, 2024
79e3928
Fix formatting
sroy745 Aug 23, 2024
09f9741
Fix test to run with CUDA Graph
sroy745 Aug 23, 2024
eb52df8
fix format
sroy745 Aug 23, 2024
f27e84a
Dummy commit
sroy745 Aug 23, 2024
ada2d05
Dummy commit
sroy745 Aug 24, 2024
185e056
Merge branch 'vllm-project:main' into main
sroy745 Aug 24, 2024
3c75775
Merge remote-tracking branch 'origin/main' into enc-dec-cuda-graph
sroy745 Aug 25, 2024
7b94a35
Format
sroy745 Aug 25, 2024
e2be95f
Merge branch 'vllm-project:main' into main
sroy745 Aug 27, 2024
783b37b
Merge remote-tracking branch 'origin/main' into enc-dec-cuda-graph
sroy745 Aug 28, 2024
4b72b1f
Addressing comments
sroy745 Aug 28, 2024
2ed5473
Merge branch 'vllm-project:main' into main
sroy745 Aug 28, 2024
efa4714
Merge branch 'vllm-project:main' into main
sroy745 Aug 29, 2024
105f6c3
Merge remote-tracking branch 'origin/main' into enc-dec-cuda-graph
sroy745 Aug 29, 2024
525c541
Addressing comments
sroy745 Aug 29, 2024
9f3ad3f
Merge branch 'main' into enc-dec-cuda-graph
sroy745 Aug 29, 2024
785dfc5
Fix format
sroy745 Aug 29, 2024
758c8d2
Add tests
sroy745 Aug 29, 2024
6f61f97
fix format
sroy745 Aug 29, 2024
f93ac1c
Add comments
sroy745 Aug 29, 2024
e408a00
Dummy fix
sroy745 Aug 29, 2024
fb87d34
Merge branch 'vllm-project:main' into main
sroy745 Aug 29, 2024
61af6ed
Merge remote-tracking branch 'origin/main' into enc-dec-cuda-graph
sroy745 Aug 29, 2024
bf7b4fc
Dummy commit
sroy745 Aug 29, 2024
a814a0b
Fix format
sroy745 Aug 29, 2024
5419e49
Merge branch 'vllm-project:main' into main
sroy745 Aug 31, 2024
2ca1c2f
Merge remote-tracking branch 'origin/main' into enc-dec-cuda-graph
sroy745 Aug 31, 2024
4617e39
Adding back the assertion
sroy745 Sep 1, 2024
9ba12f8
Merge branch 'vllm-project:main' into main
sroy745 Sep 2, 2024
25cef3d
Merge branch 'vllm-project:main' into main
sroy745 Sep 3, 2024
0d8b6bc
Merge remote-tracking branch 'origin/main' into enc-dec-cuda-graph
sroy745 Sep 3, 2024
2ab85b8
Dummy commit
sroy745 Sep 3, 2024
0592dc0
Dummy commit
sroy745 Sep 3, 2024
9d4cd09
Merge branch 'vllm-project:main' into main
sroy745 Sep 4, 2024
0b3c9e2
Merge remote-tracking branch 'origin/main' into enc-dec-cuda-graph
sroy745 Sep 4, 2024
c8eb247
Merge branch 'main' into enc-dec-cuda-graph
sroy745 Sep 4, 2024
12f8312
Dummy commit
sroy745 Sep 5, 2024
0ea4479
Dummy commit
sroy745 Sep 5, 2024
c48cacb
Merge branch 'vllm-project:main' into main
sroy745 Sep 5, 2024
c42c399
Merge branch 'vllm-project:main' into main
sroy745 Sep 7, 2024
3d13e43
Merge branch 'vllm-project:main' into main
sroy745 Sep 9, 2024
ed68558
Merge remote-tracking branch 'origin/main' into enc-dec-cuda-graph
sroy745 Sep 9, 2024
7417751
Dummy
sroy745 Sep 9, 2024
12c4af4
Dummy
sroy745 Sep 9, 2024
7cd2b43
Dummy Commit to rerun tests
sroy745 Sep 9, 2024
e4338ce
Dummy Commit to rerun tests
sroy745 Sep 9, 2024
7479775
Merge branch 'vllm-project:main' into main
sroy745 Sep 11, 2024
9c54a60
Merge remote-tracking branch 'origin/main' into enc-dec-cuda-graph
sroy745 Sep 11, 2024
df9b966
Merge branch 'vllm-project:main' into main
sroy745 Sep 17, 2024
7cc1a49
Merge remote-tracking branch 'origin/main' into enc-dec-cuda-graph
sroy745 Sep 17, 2024
3fb360c
Address comments
sroy745 Sep 17, 2024
7 changes: 7 additions & 0 deletions .buildkite/test-pipeline.yaml
@@ -259,6 +259,13 @@ steps:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- bash ./run-tests.sh -c configs/models-small.txt -t 1

- label: Encoder Decoder tests # 5min
source_file_dependencies:
- vllm/
- tests/encoder_decoder
commands:
- pytest -v -s encoder_decoder

- label: OpenAI-Compatible Tool Use # 20 min
fast_check: false
mirror_hardwares: [ amd ]
Empty file.
98 changes: 98 additions & 0 deletions tests/encoder_decoder/test_e2e_correctness.py
@@ -0,0 +1,98 @@
"""E2E tests to verify the correctness of the encoder-decoder framework
Run `pytest tests/encoder_decoder/test_e2e_correctness.py`.
"""
from typing import List, Optional, Tuple

import pytest
from transformers import AutoModelForSeq2SeqLM

from vllm.sequence import SampleLogprobs
from vllm.utils import is_cpu

from ..conftest import DecoderPromptType
from ..models.utils import check_logprobs_close


def vllm_to_hf_output(
vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]],
decoder_prompt_type: DecoderPromptType,
):
"""Sanitize vllm output to be comparable with hf output."""
output_ids, output_str, out_logprobs = vllm_output

hf_output_str = output_str + "</s>"
if decoder_prompt_type == DecoderPromptType.NONE:
hf_output_str = "<s>" + hf_output_str

return output_ids, hf_output_str, out_logprobs


@pytest.mark.parametrize("model", ["facebook/bart-large-cnn"])
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
@pytest.mark.parametrize("enforce_eager", [True, False])
@pytest.mark.skipif(
is_cpu(),
reason="CPU backend is not currently supported with encoder/decoder models"
)
def test_encoder_decoder_e2e(
hf_runner,
vllm_runner,
example_encoder_decoder_prompts,
model: str,
dtype: str,
max_tokens: int,
num_logprobs: int,
decoder_prompt_type: DecoderPromptType,
enforce_eager: bool,
) -> None:
'''
End-to-End (E2E) test for the encoder-decoder framework.
This test evaluates the encoder-decoder functionality using the BART
model. We compare the outputs of the Hugging Face and vLLM
implementations to ensure that both implementations produce consistent
and correct results.
'''
test_case_prompts = example_encoder_decoder_prompts[decoder_prompt_type]

# Configuration settings for HF baseline
hf_kwargs = {
"top_k": None,
"num_beams": 1,
"repetition_penalty": 1.0,
"top_p": 1.0,
"length_penalty": 1.0,
"early_stopping": False,
"no_repeat_ngram_size": None,
"min_length": 0
}

with hf_runner(model, dtype=dtype,
auto_cls=AutoModelForSeq2SeqLM) as hf_model:
hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit(
test_case_prompts,
max_tokens,
num_logprobs,
**hf_kwargs,
))
with vllm_runner(model, dtype=dtype,
enforce_eager=enforce_eager) as vllm_model:
vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
test_case_prompts, max_tokens, num_logprobs)

hf_skip_tokens = (1
if decoder_prompt_type == DecoderPromptType.NONE else 0)

check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=[
vllm_to_hf_output(vllm_output, decoder_prompt_type)
for vllm_output in vllm_outputs
],
name_0="hf",
name_1="vllm",
num_outputs_0_skip_tokens=hf_skip_tokens,
)
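To make the sanitization in vllm_to_hf_output concrete, a small worked illustration (values are made up; it assumes the DecoderPromptType enum from tests/conftest.py exposes a CUSTOM member alongside NONE):

# Illustrative only, reusing the definitions from the test file above: the
# HF baseline emits BART's </s> EOS token, and also a leading <s> BOS token
# when no decoder prompt is given, so the vLLM string is rewritten to match.
sample_vllm_output = ([0, 1, 2], "A short summary.", None)

_, text, _ = vllm_to_hf_output(sample_vllm_output, DecoderPromptType.NONE)
assert text == "<s>A short summary.</s>"   # BOS prepended only for NONE

_, text, _ = vllm_to_hf_output(sample_vllm_output, DecoderPromptType.CUSTOM)
assert text == "A short summary.</s>"      # otherwise only EOS is appended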
182 changes: 160 additions & 22 deletions tests/worker/test_encoder_decoder_model_runner.py
@@ -1,3 +1,4 @@
import itertools
from array import array
from typing import List

@@ -7,13 +8,9 @@
from vllm.engine.arg_utils import EngineArgs
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
SequenceData, SequenceGroupMetadata)
from vllm.utils import is_cpu
from vllm.utils import is_cpu, make_tensor_with_pad
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner

# CUDA graph scenarios to test
#
# Currently CUDA graph is not supported
ENFORCE_EAGER = [True]
from vllm.worker.model_runner import _get_graph_batch_size

BATCH_SIZES = [1, 4, 16, 64, 256]

@@ -40,8 +37,7 @@ def _create_model_runner(model: str, *args,
reason="CPU backend is currently "
"unsupported for encoder/ "
"decoder models")
@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER)
def test_empty_seq_group(enforce_eager, ):
def test_empty_seq_group():
"""Verify prepare prompt and decode returns empty output
for empty seq group list"""

@@ -52,7 +48,7 @@ def test_empty_seq_group(enforce_eager, ):
max_num_batched_tokens=100000,
max_num_seqs=100000,
enable_chunked_prefill=False,
enforce_eager=enforce_eager,
enforce_eager=True,
)
seq_group_metadata_list: List[SequenceGroupMetadata] = []
model_input = model_runner._prepare_model_input_tensors(
@@ -85,11 +81,7 @@ def test_empty_seq_group(enforce_eager, ),
"unsupported for encoder/ "
"decoder models")
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER)
def test_prepare_prompt(
batch_size,
enforce_eager,
):
def test_prepare_prompt(batch_size):
'''
Test the ability of the encoder/decoder model runner subclass to
produce prefill-phase model inputs & attention metadata.
@@ -115,7 +107,7 @@ def test_prepare_prompt(
max_num_batched_tokens=100000,
max_num_seqs=100000,
enable_chunked_prefill=False,
enforce_eager=enforce_eager,
enforce_eager=True,
)

seq_lens: List[int] = []
@@ -281,11 +273,7 @@ def test_prepare_prompt(
"unsupported for encoder/ "
"decoder models")
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER)
def test_prepare_decode(
batch_size,
enforce_eager,
):
def test_prepare_decode(batch_size):
'''
Test the ability of the encoder/decoder model runner subclass to
produce decode-phase model inputs & attention metadata.
@@ -311,7 +299,7 @@ def test_prepare_decode(
max_num_batched_tokens=100000,
max_num_seqs=100000,
enable_chunked_prefill=False,
enforce_eager=enforce_eager,
enforce_eager=True,
)

seq_lens: List[int] = []
@@ -428,7 +416,8 @@ def test_prepare_decode(
expected,
)

# Cuda graph should is currently not supported for encoder/decoer.
# Model runner's CUDAGraph setting should be propagated to attention
# metadata.
assert attn_metadata.use_cuda_graph is False

# Verify the lengths of input tokens & positions
@@ -484,3 +473,152 @@ def test_prepare_decode(
dtype=actual.dtype,
)
assert torch.equal(actual, expected)


@pytest.mark.parametrize("batch_size", list(range(1, 257)))
def test_prepare_decode_cuda_graph(batch_size):
"""
Tests that for encoder-decoder models with CUDA Graph capture and replay
enabled, the tensors used during the decode phase are correctly padded
for varying input batch sizes.
"""
model_runner = _create_model_runner(
"facebook/bart-base",
seed=0,
dtype="float16",
max_num_batched_tokens=100000,
max_num_seqs=100000,
enable_chunked_prefill=False,
enforce_eager=False,
)

seq_lens: List[int] = []
encoder_seq_lens: List[int] = []
seq_group_metadata_list: List[SequenceGroupMetadata] = []
block_tables = {0: [1]}
cross_block_table = [2]
for i in range(batch_size):
# make sure all tokens fit into one block
seq_len = i % (model_runner.block_size - 1) + 1
seq_lens.append(seq_len)
seq_data = SequenceData(
array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len))))
encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
encoder_seq_lens.append(encoder_seq_len)
encoder_seq_data = SequenceData(
array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len))))
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=False,
seq_data={0: seq_data},
sampling_params=SamplingParams(temperature=0),
block_tables=block_tables,
encoder_seq_data=encoder_seq_data,
cross_block_table=cross_block_table,
)
assert seq_group_metadata.token_chunk_size == 1
seq_group_metadata_list.append(seq_group_metadata)

model_input = model_runner.prepare_model_input(seq_group_metadata_list)
input_tokens = model_input.input_tokens
input_positions = model_input.input_positions
attn_metadata = model_input.attn_metadata
return_seq_lens = model_input.seq_lens
slot_mapping = attn_metadata.slot_mapping
encoder_input_tokens = model_input.encoder_input_tokens
encoder_input_positions = model_input.encoder_input_positions
cross_slot_mapping = attn_metadata.cross_slot_mapping

# With CUDA Graph capture and replay enabled, the decoder and encoder
# input sequences will be padded. Create the expected padded tensors
# accordingly.
graph_batch_size = _get_graph_batch_size(batch_size)
cuda_graph_pad_size = graph_batch_size - batch_size
padded_seq_lens = seq_lens + list(itertools.repeat(1, cuda_graph_pad_size))
padded_encoder_seq_lens = encoder_seq_lens + list(
itertools.repeat(1, cuda_graph_pad_size))

assert return_seq_lens == padded_seq_lens
assert len(slot_mapping) == len(input_tokens)
assert len(cross_slot_mapping) == len(encoder_input_tokens)

# Verify attention metadata
device = model_runner.device
assert attn_metadata.num_prefills == 0
assert attn_metadata.num_decode_tokens > 0
assert torch.equal(
attn_metadata.seq_lens_tensor,
torch.tensor(padded_seq_lens, device=device, dtype=torch.int))
assert attn_metadata.seq_lens == padded_seq_lens
assert attn_metadata.max_prefill_seq_len == 0
assert attn_metadata.max_decode_seq_len == max(seq_lens)
# - Encoder attention metadata
assert attn_metadata.encoder_seq_lens == padded_encoder_seq_lens
assert torch.equal(
attn_metadata.encoder_seq_lens_tensor,
torch.tensor(padded_encoder_seq_lens, device=device, dtype=torch.int))
assert attn_metadata.max_encoder_seq_len == max(padded_encoder_seq_lens)
assert attn_metadata.num_encoder_tokens == sum(padded_encoder_seq_lens)

# Verify block tables are correct for prompts
# - Decoder self-attention. Pad the block tables as expected.
expected = [block_tables[0] for _ in range(batch_size)]
expected.extend([[] for _ in range(cuda_graph_pad_size)])
expected = make_tensor_with_pad(
expected,
max_len=64,
pad=0,
dtype=torch.int32,
device=model_runner.device,
)
assert torch.equal(
attn_metadata.block_tables,
expected,
)
# - Encoder/decoder cross-attention. Pad the cross-attention block tables
# as expected.
expected = [cross_block_table for _ in range(len(seq_group_metadata_list))]
expected.extend([[] for _ in range(cuda_graph_pad_size)])
expected = make_tensor_with_pad(
expected,
max_len=64,
pad=0,
dtype=torch.int32,
device=model_runner.device,
)
assert torch.equal(
attn_metadata.cross_block_tables,
expected,
)

# Model runner's CUDAGraph setting should be propagated to attention
# metadata.
assert attn_metadata.use_cuda_graph is True

# Verify the lengths of input tokens & positions
# - Decoder
assert len(input_tokens) == len(padded_seq_lens)
assert len(input_positions) == len(padded_seq_lens)
# -- An indirect check that model_input.input_tokens
# and model_input.input_positions are correct -
# by design of the test, the input tokens are
# equal to the input position values, so if
# the model_input data structure has the correct
# values then these two should be equal
assert torch.equal(
input_tokens,
input_positions,
)
# - Encoder
assert len(encoder_input_tokens) == 0
assert len(encoder_input_positions) == 0
# -- An indirect check that model_input.encoder_input_tokens
# and model_input.encoder_input_positions are correct -
# by design of the test, the input tokens are
# equal to the input position values, so if
# the model_input data structure has the correct
# values then these two should be equal
assert torch.equal(
encoder_input_tokens,
encoder_input_positions,
)
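The padding asserted in test_prepare_decode_cuda_graph comes from _get_graph_batch_size, imported from vllm/worker/model_runner.py but not shown in this diff. A rough sketch of its behavior around the time of this PR (the alignment constant is an assumption, not part of this change): batch sizes are rounded up to the nearest captured graph size of 1, 2, 4, or a multiple of 8.

# Sketch only: approximates vllm.worker.model_runner._get_graph_batch_size.
_BATCH_SIZE_ALIGNMENT = 8  # assumed to mirror the upstream constant


def _get_graph_batch_size_sketch(batch_size: int) -> int:
    if batch_size <= 2:
        return batch_size
    if batch_size <= 4:
        return 4
    return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
            _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)


# Example: a batch of 5 decode requests is padded to 8, so cuda_graph_pad_size
# is 3 and seq_lens / encoder_seq_lens each gain three dummy length-1 entries.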
17 changes: 13 additions & 4 deletions vllm/attention/backends/abstract.py
@@ -156,18 +156,27 @@ def graph_clone(self, batch_size: int) -> "AttentionState[T]":
...

@abstractmethod
def graph_capture_get_metadata_for_batch(self, batch_size: int) -> T:
def graph_capture_get_metadata_for_batch(
self,
batch_size: int,
is_encoder_decoder_model: bool = False) -> T:
"""Get attention metadata for CUDA graph capture of batch_size."""
...

@abstractmethod
def get_graph_input_buffers(self, attn_metadata: T) -> Dict[str, Any]:
def get_graph_input_buffers(
self,
attn_metadata: T,
is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
"""Get attention-specific input buffers for CUDA graph capture."""
...

@abstractmethod
def prepare_graph_input_buffers(self, input_buffers: Dict[str, Any],
attn_metadata: T) -> None:
def prepare_graph_input_buffers(
self,
input_buffers: Dict[str, Any],
attn_metadata: T,
is_encoder_decoder_model: bool = False) -> None:
"""In-place modify input buffers dict for CUDA graph replay."""
...

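The three AttentionState hooks above now accept an is_encoder_decoder_model flag. A schematic sketch of how a concrete backend state might honor it during capture and replay (class and helper names here are hypothetical; the PR's actual implementation lives in the backend utils and attention state classes, which are outside this excerpt):

# Hypothetical illustration, not code from this PR.
from typing import Any, Dict


class SketchAttentionState:

    def graph_capture_get_metadata_for_batch(
            self, batch_size: int, is_encoder_decoder_model: bool = False):
        metadata = self._decoder_only_capture_metadata(batch_size)  # assumed helper
        if is_encoder_decoder_model:
            # Also size encoder_seq_lens, cross_slot_mapping and
            # cross_block_tables to the padded graph batch so that capture
            # and replay see identical tensor shapes.
            self._add_enc_dec_capture_fields(metadata, batch_size)  # assumed helper
        return metadata

    def get_graph_input_buffers(
            self, attn_metadata,
            is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
        buffers = {"slot_mapping": attn_metadata.slot_mapping}
        if is_encoder_decoder_model:
            buffers["cross_slot_mapping"] = attn_metadata.cross_slot_mapping
        return buffers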