From 58fcc8545a149c9c5b1f91f417a68f5ba1fdabf3 Mon Sep 17 00:00:00 2001 From: Adam Lugowski Date: Mon, 9 Sep 2024 11:16:37 -0700 Subject: [PATCH] [Frontend] Add progress reporting to run_batch.py (#8060) Co-authored-by: Adam Lugowski --- vllm/entrypoints/openai/run_batch.py | 54 ++++++++++++++++++++++++---- 1 file changed, 48 insertions(+), 6 deletions(-) diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py index 32bbade256973..278be8cd11a12 100644 --- a/vllm/entrypoints/openai/run_batch.py +++ b/vllm/entrypoints/openai/run_batch.py @@ -1,9 +1,11 @@ import asyncio from io import StringIO -from typing import Awaitable, Callable, List +from typing import Awaitable, Callable, List, Optional import aiohttp +import torch from prometheus_client import start_http_server +from tqdm import tqdm from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str from vllm.engine.async_llm_engine import AsyncLLMEngine @@ -78,6 +80,38 @@ def parse_args(): return parser.parse_args() +# explicitly use pure text format, with a newline at the end +# this makes it impossible to see the animation in the progress bar +# but will avoid messing up with ray or multiprocessing, which wraps +# each line of output with some prefix. +_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501 + + +class BatchProgressTracker: + + def __init__(self): + self._total = 0 + self._pbar: Optional[tqdm] = None + + def submitted(self): + self._total += 1 + + def completed(self): + if self._pbar: + self._pbar.update() + + def pbar(self) -> tqdm: + enable_tqdm = not torch.distributed.is_initialized( + ) or torch.distributed.get_rank() == 0 + self._pbar = tqdm(total=self._total, + unit="req", + desc="Running batch", + mininterval=5, + disable=not enable_tqdm, + bar_format=_BAR_FORMAT) + return self._pbar + + async def read_file(path_or_url: str) -> str: if path_or_url.startswith("http://") or path_or_url.startswith("https://"): async with aiohttp.ClientSession() as session, \ @@ -102,7 +136,8 @@ async def write_file(path_or_url: str, data: str) -> None: async def run_request(serving_engine_func: Callable, - request: BatchRequestInput) -> BatchRequestOutput: + request: BatchRequestInput, + tracker: BatchProgressTracker) -> BatchRequestOutput: response = await serving_engine_func(request.body) if isinstance(response, (ChatCompletionResponse, EmbeddingResponse)): @@ -125,6 +160,7 @@ async def run_request(serving_engine_func: Callable, else: raise ValueError("Request must not be sent in stream mode") + tracker.completed() return batch_output @@ -164,6 +200,9 @@ async def main(args): request_logger=request_logger, ) + tracker = BatchProgressTracker() + logger.info("Reading batch from %s...", args.input_file) + # Submit all requests in the file to the engine "concurrently". response_futures: List[Awaitable[BatchRequestOutput]] = [] for request_json in (await read_file(args.input_file)).strip().split("\n"): @@ -178,16 +217,19 @@ async def main(args): if request.url == "/v1/chat/completions": response_futures.append( run_request(openai_serving_chat.create_chat_completion, - request)) + request, tracker)) + tracker.submitted() elif request.url == "/v1/embeddings": response_futures.append( - run_request(openai_serving_embedding.create_embedding, - request)) + run_request(openai_serving_embedding.create_embedding, request, + tracker)) + tracker.submitted() else: raise ValueError("Only /v1/chat/completions and /v1/embeddings are" "supported in the batch endpoint.") - responses = await asyncio.gather(*response_futures) + with tracker.pbar(): + responses = await asyncio.gather(*response_futures) output_buffer = StringIO() for response in responses: