[Benchmark] Add --async-engine option to benchmark_throughput.py #7964

Merged
113 changes: 109 additions & 4 deletions benchmarks/benchmark_throughput.py
@@ -6,13 +6,16 @@
from typing import List, Optional, Tuple

import torch
import uvloop
from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer,
PreTrainedTokenizerBase)

from vllm.engine.arg_utils import EngineArgs
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args)
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import FlexibleArgumentParser
from vllm.utils import FlexibleArgumentParser, merge_async_iterators


def sample_requests(
@@ -135,6 +138,93 @@ def run_vllm(
return end - start


async def run_vllm_async(
requests: List[Tuple[str, int, int]],
model: str,
tokenizer: str,
quantization: Optional[str],
tensor_parallel_size: int,
seed: int,
n: int,
use_beam_search: bool,
trust_remote_code: bool,
dtype: str,
max_model_len: Optional[int],
enforce_eager: bool,
kv_cache_dtype: str,
quantization_param_path: Optional[str],
device: str,
enable_prefix_caching: bool,
enable_chunked_prefill: bool,
max_num_batched_tokens: int,
distributed_executor_backend: Optional[str],
gpu_memory_utilization: float = 0.9,
num_scheduler_steps: int = 1,
use_v2_block_manager: bool = False,
download_dir: Optional[str] = None,
load_format: str = EngineArgs.load_format,
disable_async_output_proc: bool = False,
disable_frontend_multiprocessing: bool = False,
) -> float:
from vllm import SamplingParams
engine_args = AsyncEngineArgs(
model=model,
tokenizer=tokenizer,
quantization=quantization,
tensor_parallel_size=tensor_parallel_size,
seed=seed,
trust_remote_code=trust_remote_code,
dtype=dtype,
max_model_len=max_model_len,
gpu_memory_utilization=gpu_memory_utilization,
enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype,
quantization_param_path=quantization_param_path,
device=device,
enable_prefix_caching=enable_prefix_caching,
download_dir=download_dir,
enable_chunked_prefill=enable_chunked_prefill,
max_num_batched_tokens=max_num_batched_tokens,
distributed_executor_backend=distributed_executor_backend,
load_format=load_format,
num_scheduler_steps=num_scheduler_steps,
use_v2_block_manager=use_v2_block_manager,
disable_async_output_proc=disable_async_output_proc,
worker_use_ray=False,
engine_use_ray=False,
disable_log_requests=True,
)

async with build_async_engine_client_from_engine_args(
engine_args, disable_frontend_multiprocessing) as llm:

# Add the requests to the engine.
prompts: List[str] = []
sampling_params: List[SamplingParams] = []
for prompt, _, output_len in requests:
prompts.append(prompt)
sampling_params.append(
SamplingParams(
n=n,
temperature=0.0 if use_beam_search else 1.0,
top_p=1.0,
use_beam_search=use_beam_search,
ignore_eos=True,
max_tokens=output_len,
))

generators = []
start = time.perf_counter()
for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
generator = llm.generate(prompt, sp, request_id=f"test{i}")
generators.append(generator)
all_gens = merge_async_iterators(*generators)
async for i, res in all_gens:
pass
end = time.perf_counter()
return end - start


def run_hf(
requests: List[Tuple[str, int, int]],
model: str,
@@ -230,7 +320,7 @@ def main(args: argparse.Namespace):
args.output_len)

if args.backend == "vllm":
elapsed_time = run_vllm(
run_args = [
requests, args.model, args.tokenizer, args.quantization,
args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype, args.max_model_len,
@@ -240,7 +330,14 @@
args.max_num_batched_tokens, args.distributed_executor_backend,
args.gpu_memory_utilization, args.num_scheduler_steps,
args.use_v2_block_manager, args.download_dir, args.load_format,
args.disable_async_output_proc)
args.disable_async_output_proc
]

if args.async_engine:
run_args.append(args.disable_frontend_multiprocessing)
elapsed_time = uvloop.run(run_vllm_async(*run_args))
else:
elapsed_time = run_vllm(*run_args)
Comment on lines +336 to +340

Collaborator:
Nit: better to use kwargs

Member (Author):
Agree, but the list was already passed as regular args so this involved minimal changes.
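
For reference, a rough sketch of the keyword-argument form the reviewer suggests, assuming run_vllm accepts the same parameter names as run_vllm_async (minus disable_frontend_multiprocessing) and that the argparse attributes hidden in the collapsed part of the diff mirror those parameter names; this is an illustration, not part of the PR:

run_kwargs = dict(
    requests=requests,
    model=args.model,
    tokenizer=args.tokenizer,
    quantization=args.quantization,
    tensor_parallel_size=args.tensor_parallel_size,
    seed=args.seed,
    n=args.n,
    use_beam_search=args.use_beam_search,
    trust_remote_code=args.trust_remote_code,
    dtype=args.dtype,
    max_model_len=args.max_model_len,
    enforce_eager=args.enforce_eager,
    kv_cache_dtype=args.kv_cache_dtype,
    quantization_param_path=args.quantization_param_path,
    device=args.device,
    enable_prefix_caching=args.enable_prefix_caching,
    enable_chunked_prefill=args.enable_chunked_prefill,
    max_num_batched_tokens=args.max_num_batched_tokens,
    distributed_executor_backend=args.distributed_executor_backend,
    gpu_memory_utilization=args.gpu_memory_utilization,
    num_scheduler_steps=args.num_scheduler_steps,
    use_v2_block_manager=args.use_v2_block_manager,
    download_dir=args.download_dir,
    load_format=args.load_format,
    disable_async_output_proc=args.disable_async_output_proc,
)
if args.async_engine:
    elapsed_time = uvloop.run(
        run_vllm_async(
            **run_kwargs,
            disable_frontend_multiprocessing=args.disable_frontend_multiprocessing,
        ))
else:
    elapsed_time = run_vllm(**run_kwargs)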

elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -426,6 +523,14 @@ def main(args: argparse.Namespace):
action='store_true',
default=False,
help="Disable async output processor for vLLM backend.")
parser.add_argument("--async-engine",
action='store_true',
default=False,
help="Use vLLM async engine rather than LLM class.")
parser.add_argument("--disable-frontend-multiprocessing",
action='store_true',
default=False,
help="Disable decoupled async engine frontend.")
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model
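
A note on the async path added above: run_vllm_async submits one generate() stream per request and then drains all of them through merge_async_iterators, so the timed interval ends only when every request has finished. A minimal sketch of that merging pattern, for illustration only (vllm.utils.merge_async_iterators is the real implementation and may differ in detail):

import asyncio
from typing import AsyncIterator, Tuple, TypeVar

T = TypeVar("T")


async def merge_async_iterators_sketch(
        *iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]:
    # Yield (iterator_index, item) pairs as soon as any iterator produces
    # a value; stop once every iterator is exhausted.
    queue: asyncio.Queue = asyncio.Queue()
    done = object()

    async def drain(i: int, it: AsyncIterator[T]) -> None:
        async for item in it:
            await queue.put((i, item))
        await queue.put(done)

    tasks = [
        asyncio.create_task(drain(i, it)) for i, it in enumerate(iterators)
    ]
    remaining = len(tasks)
    while remaining:
        item = await queue.get()
        if item is done:
            remaining -= 1
        else:
            yield item

With the new flags, a hypothetical invocation along these lines should exercise the path (exact flags beyond the two added here depend on the existing script):

python benchmarks/benchmark_throughput.py --model <model> --input-len 256 --output-len 256 --async-engine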
45 changes: 30 additions & 15 deletions vllm/entrypoints/openai/api_server.py
@@ -67,7 +67,7 @@


def model_is_embedding(model_name: str, trust_remote_code: bool,
quantization: str) -> bool:
quantization: Optional[str]) -> bool:
return ModelConfig(model=model_name,
tokenizer=model_name,
tokenizer_mode="auto",
@@ -96,13 +96,6 @@ async def _force_log():
@asynccontextmanager
async def build_async_engine_client(
args: Namespace) -> AsyncIterator[Optional[AsyncEngineClient]]:
"""
Create AsyncEngineClient, either:
- in-process using the AsyncLLMEngine Directly
- multiprocess using AsyncLLMEngine RPC

Returns the Client or None if the creation failed.
"""

# Context manager to handle async_engine_client lifecycle
# Ensures everything is shutdown and cleaned up on error/exit
@@ -112,14 +105,37 @@ async def build_async_engine_client(
# Backend itself still global for the silly lil' health handler
global async_engine_client

async with build_async_engine_client_from_engine_args(
engine_args, args.disable_frontend_multiprocessing) as engine:

async_engine_client = engine # type: ignore[assignment]
yield engine


@asynccontextmanager
async def build_async_engine_client_from_engine_args(
engine_args: AsyncEngineArgs,
disable_frontend_multiprocessing: bool = False,
) -> AsyncIterator[Optional[AsyncEngineClient]]:
"""
Create AsyncEngineClient, either:
- in-process using the AsyncLLMEngine Directly
- multiprocess using AsyncLLMEngine RPC

Returns the Client or None if the creation failed.
"""

# If manually triggered or embedding model, use AsyncLLMEngine in process.
# TODO: support embedding model via RPC.
if (model_is_embedding(args.model, args.trust_remote_code,
args.quantization)
or args.disable_frontend_multiprocessing):
async_engine_client = AsyncLLMEngine.from_engine_args(
if (model_is_embedding(engine_args.model, engine_args.trust_remote_code,
engine_args.quantization)
or disable_frontend_multiprocessing):
engine_client = AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
yield async_engine_client
try:
yield engine_client
finally:
engine_client.shutdown_background_loop()
return

# Otherwise, use the multiprocessing AsyncLLMEngine.
@@ -148,7 +164,6 @@ async def build_async_engine_client(
# NOTE: Actually, this is not true yet. We still need to support
# embedding models via RPC (see TODO above)
rpc_client = AsyncEngineRPCClient(rpc_path)
async_engine_client = rpc_client # type: ignore

# Start RPCServer in separate process (holds the AsyncLLMEngine).
context = multiprocessing.get_context("spawn")
@@ -174,7 +189,7 @@
yield None
return

yield async_engine_client
yield rpc_client # type: ignore[misc]
finally:
# Ensure rpc server process was terminated
rpc_server_process.terminate()
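
Factoring the construction logic out into build_async_engine_client_from_engine_args is what lets the benchmark reuse it outside the API server. A rough standalone sketch of that usage, with a placeholder model name and prompt (assumptions for illustration, not code from this PR):

import uvloop

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.api_server import (
    build_async_engine_client_from_engine_args)


async def demo() -> None:
    engine_args = AsyncEngineArgs(
        model="facebook/opt-125m",  # placeholder model
        disable_log_requests=True)
    # disable_frontend_multiprocessing=True keeps the AsyncLLMEngine
    # in-process instead of spawning the RPC server process.
    async with build_async_engine_client_from_engine_args(
            engine_args, disable_frontend_multiprocessing=True) as client:
        final = None
        async for output in client.generate("Hello, my name is",
                                            SamplingParams(max_tokens=16),
                                            request_id="demo-0"):
            final = output  # streamed outputs are cumulative
        if final is not None:
            print(final.outputs[0].text)


uvloop.run(demo())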
4 changes: 4 additions & 0 deletions vllm/entrypoints/openai/rpc/client.py
@@ -7,6 +7,7 @@
import cloudpickle
import zmq
import zmq.asyncio
from zmq import Frame # type: ignore[attr-defined]
from zmq.asyncio import Socket

from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
@@ -214,6 +215,7 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,

# Await the data from the Server.
frame = await socket.recv(copy=False)
assert isinstance(frame, Frame)
data = pickle.loads(frame.buffer)

if isinstance(data, Exception):
@@ -247,6 +249,7 @@ async def do_rpc_call(socket: Socket, request: RPC_REQUEST_TYPE):
f"{self._data_timeout} ms")

frame = await socket.recv(copy=False)
assert isinstance(frame, Frame)
return pickle.loads(frame.buffer)

# Make a new socket connection.
@@ -395,6 +398,7 @@ async def generate(
# Stream back the results from the RPC Server.
while not finished:
message = await socket.recv(copy=False)
assert isinstance(message, Frame)
request_output = pickle.loads(message.buffer)

if isinstance(request_output, Exception):
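
Context for the added asserts: with copy=False, pyzmq's recv returns a zmq.Frame whose .buffer exposes the received message without copying, but the stubs do not narrow the return type on the copy flag, so the isinstance checks document the expectation and satisfy mypy. A small illustrative helper in the same spirit (not vLLM code):

import pickle

import zmq.asyncio
from zmq import Frame  # type: ignore[attr-defined]


async def recv_pickled(socket: zmq.asyncio.Socket):
    # recv(copy=False) yields a Frame; with copy=True it would be bytes.
    frame = await socket.recv(copy=False)
    assert isinstance(frame, Frame)  # narrows the type for the checker
    return pickle.loads(frame.buffer)  # deserialize from the zero-copy view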