[V1][Frontend] Coalesce bunched RequestOutputs (vllm-project#12298)
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
2 people authored and LucasWilkinson committed Jan 24, 2025
1 parent adf9471 commit cb32ece
Showing 3 changed files with 65 additions and 18 deletions.
51 changes: 35 additions & 16 deletions tests/v1/engine/test_async_llm.py
@@ -1,11 +1,13 @@
 import asyncio
+from contextlib import ExitStack
 from typing import List, Tuple

 import pytest

 from vllm import SamplingParams
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.platforms import current_platform
+from vllm.sampling_params import RequestOutputKind
 from vllm.v1.engine.async_llm import AsyncLLM

 if not current_platform.is_cuda():
@@ -18,28 +20,39 @@


 async def generate(engine: AsyncLLM, request_id: str,
+                   output_kind: RequestOutputKind,
                    max_tokens: int) -> Tuple[int, str]:
     count = 0
-    async for _ in engine.generate(request_id=request_id,
-                                   prompt="Hello my name is Robert and",
-                                   sampling_params=SamplingParams(
-                                       max_tokens=max_tokens, temperature=0)):
+    sampling_params = SamplingParams(max_tokens=max_tokens,
+                                     output_kind=output_kind,
+                                     temperature=0)
+    async for out in engine.generate(request_id=request_id,
+                                     prompt="Hello my name is Robert and",
+                                     sampling_params=sampling_params):

-        count += 1
+        num_tokens = len(out.outputs[0].token_ids)
+        if output_kind == RequestOutputKind.DELTA:
+            count += num_tokens
+        else:
+            count = num_tokens
+
         await asyncio.sleep(0.)

     return count, request_id


+@pytest.mark.parametrize(
+    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
 @pytest.mark.asyncio
-async def test_load(monkeypatch):
+async def test_load(monkeypatch, output_kind: RequestOutputKind):
     # TODO(rickyx): Remove monkeypatch once we have a better way to test V1
     # so that in the future when we switch, we don't have to change all the
     # tests.
-    with monkeypatch.context() as m:
+    with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")

         engine = AsyncLLM.from_engine_args(ENGINE_ARGS)
+        after.callback(engine.shutdown)

         NUM_REQUESTS = 10000
         NUM_EXPECTED_TOKENS = 10
@@ -51,26 +64,33 @@ async def test_load(monkeypatch):
         for request_id in request_ids:
             tasks.append(
                 asyncio.create_task(
-                    generate(engine, request_id, NUM_EXPECTED_TOKENS)))
+                    generate(engine, request_id, output_kind,
+                             NUM_EXPECTED_TOKENS)))

         # Confirm that we got all the EXPECTED tokens from the requests.
-        for task in tasks:
+        done, pending = await asyncio.wait(tasks,
+                                           return_when=asyncio.FIRST_EXCEPTION)
+        for task in pending:
+            task.cancel()
+        for task in done:
             num_generated_tokens, request_id = await task
             assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
                 f"{request_id} generated {num_generated_tokens} but "
                 f"expected {NUM_EXPECTED_TOKENS}")

         assert not engine.output_processor.has_unfinished_requests()
-        engine.shutdown()


+@pytest.mark.parametrize(
+    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
 @pytest.mark.asyncio
-async def test_abort(monkeypatch):
+async def test_abort(monkeypatch, output_kind: RequestOutputKind):

-    with monkeypatch.context() as m:
+    with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")

         engine = AsyncLLM.from_engine_args(ENGINE_ARGS)
+        after.callback(engine.shutdown)

         NUM_REQUESTS = 100
         NUM_EXPECTED_TOKENS = 100
@@ -83,7 +103,8 @@ async def test_abort(monkeypatch):
         for request_id in request_ids:
             tasks.append(
                 asyncio.create_task(
-                    generate(engine, request_id, NUM_EXPECTED_TOKENS)))
+                    generate(engine, request_id, output_kind,
+                             NUM_EXPECTED_TOKENS)))

         # API server cancels requests when they disconnect.
         for idx in REQUEST_IDS_TO_ABORT:
@@ -108,9 +129,7 @@ async def test_abort(monkeypatch):
         # Confirm we can do another generation.
         request_id = f"request-{REQUEST_IDS_TO_ABORT[0]}"
         task = asyncio.create_task(
-            generate(engine, request_id, NUM_EXPECTED_TOKENS))
+            generate(engine, request_id, output_kind, NUM_EXPECTED_TOKENS))
         num_generated_tokens, request_id = await task
         assert num_generated_tokens == NUM_EXPECTED_TOKENS
         assert not engine.output_processor.has_unfinished_requests()
-
-        engine.shutdown()
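
A note on the counting change in generate() above, as a minimal standalone sketch (count_tokens and the literal token ids are illustrative, not part of the diff): with RequestOutputKind.DELTA each RequestOutput carries only the tokens produced since the previous output, and after this change a single output may carry several steps' worth of coalesced deltas, so the test sums the lengths; with FINAL_ONLY the output carries the full sequence, so the latest length simply replaces the count.

from vllm.sampling_params import RequestOutputKind

def count_tokens(token_id_batches, output_kind):
    # Mirrors the accumulation logic in the test's generate() helper.
    count = 0
    for token_ids in token_id_batches:
        if output_kind == RequestOutputKind.DELTA:
            count += len(token_ids)  # deltas: sum the newly produced tokens
        else:
            count = len(token_ids)   # cumulative/final: latest length wins
    return count

# Three delta outputs, the last one a coalesced pair -> 4 tokens in total.
assert count_tokens([[7], [8], [9, 10]], RequestOutputKind.DELTA) == 4
# A single final output already holding all 4 tokens.
assert count_tokens([[7, 8, 9, 10]], RequestOutputKind.FINAL_ONLY) == 4
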
22 changes: 21 additions & 1 deletion vllm/outputs.py
@@ -1,6 +1,6 @@
 import time
 from dataclasses import dataclass
-from typing import Dict, Generic, List, Optional
+from typing import Dict, Generic, List, MutableSequence, Optional
 from typing import Sequence as GenericSequence
 from typing import Union

@@ -162,6 +162,26 @@ def new(
             finished=finished,
         )

+    def add(self, next_output: "RequestOutput") -> None:
+        """Merge subsequent RequestOutput into this one"""
+
+        self.prompt = next_output.prompt
+        self.prompt_token_ids = next_output.prompt_token_ids
+        self.prompt_logprobs = next_output.prompt_logprobs
+        self.finished |= next_output.finished
+
+        #TODO assuming n == 1 for now
+        completion = self.outputs[0]
+        next_completion = next_output.outputs[0]
+        completion.text += next_completion.text
+        if not isinstance(completion.token_ids, MutableSequence):
+            completion.token_ids = list(completion.token_ids)
+        completion.token_ids.extend(next_completion.token_ids)
+        if next_completion.logprobs:
+            assert completion.logprobs is not None
+            completion.logprobs.extend(next_completion.logprobs)
+        completion.cumulative_logprob = next_completion.cumulative_logprob
+
     @classmethod
     def from_seq_group(
         cls, seq_group: SequenceGroup, use_cache: bool,
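
For illustration, the merge semantics of the new RequestOutput.add can be reproduced with simplified stand-in dataclasses (MiniCompletion and MiniOutput are hypothetical stand-ins, not the real vLLM CompletionOutput/RequestOutput): the later delta's text and token ids are appended to the earlier output, and the newer cumulative logprob and finished flag win.

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class MiniCompletion:
    text: str
    token_ids: List[int]
    cumulative_logprob: Optional[float] = None

@dataclass
class MiniOutput:
    outputs: List[MiniCompletion]
    finished: bool = False

    def add(self, next_output: "MiniOutput") -> None:
        # Same idea as RequestOutput.add above, assuming n == 1.
        self.finished |= next_output.finished
        completion, nxt = self.outputs[0], next_output.outputs[0]
        completion.text += nxt.text
        completion.token_ids.extend(nxt.token_ids)
        completion.cumulative_logprob = nxt.cumulative_logprob

first = MiniOutput([MiniCompletion(" Smith", [3575])])
second = MiniOutput([MiniCompletion(" and I", [290, 314])], finished=True)
first.add(second)
assert first.outputs[0].text == " Smith and I"
assert first.outputs[0].token_ids == [3575, 290, 314]
assert first.finished
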
10 changes: 9 additions & 1 deletion vllm/v1/engine/async_llm.py
@@ -15,7 +15,7 @@
 from vllm.outputs import RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.usage.usage_lib import UsageContext
@@ -214,6 +214,14 @@ async def generate(
                 # task switching under load which helps performance).
                 out = q.get_nowait() if not q.empty() else await q.get()

+                # Coalesce any additional queued outputs
+                while not q.empty():
+                    next_out = q.get_nowait()
+                    if sampling_params.output_kind == RequestOutputKind.DELTA:
+                        out.add(next_out)
+                    else:
+                        out = next_out
+
                 # Note: both OutputProcessor and EngineCore handle their
                 # own request cleanup based on finished.
                 finished = out.finished
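
The loop added above is the core of the change: when the consumer falls behind, several RequestOutputs may be sitting in the per-request queue, and they are drained and either merged into one delta (DELTA) or replaced by the newest snapshot (other kinds). Below is a minimal standalone sketch of that drain-and-coalesce pattern on an asyncio.Queue, with plain lists standing in for RequestOutput objects (drain_coalesced and demo are illustrative names, not vLLM APIs).

import asyncio

async def drain_coalesced(q: asyncio.Queue, is_delta: bool):
    """Toy version of the coalescing loop: one awaited get, then merge or
    replace with anything else already queued."""
    out = q.get_nowait() if not q.empty() else await q.get()
    while not q.empty():
        next_out = q.get_nowait()
        if is_delta:
            out = out + next_out  # merge deltas (stand-in for out.add(next_out))
        else:
            out = next_out        # cumulative/final: newest snapshot wins
    return out

async def demo():
    q = asyncio.Queue()
    for delta in ([1], [2, 3], [4]):  # three outputs queued between consumer wakeups
        q.put_nowait(delta)
    merged = await drain_coalesced(q, is_delta=True)
    assert merged == [1, 2, 3, 4]     # consumer sees one coalesced delta

asyncio.run(demo())
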
