From a23722acb62f359744da8ccd0dd8f44182b27857 Mon Sep 17 00:00:00 2001 From: Adam Tilghman Date: Wed, 25 Sep 2024 00:49:26 -0700 Subject: [PATCH] [Frontend] OpenAI server: propagate usage accounting to FastAPI middleware layer (#8672) --- vllm/entrypoints/openai/protocol.py | 5 +++ vllm/entrypoints/openai/serving_chat.py | 26 +++++++++++-- vllm/entrypoints/openai/serving_completion.py | 37 +++++++++++++++---- 3 files changed, 57 insertions(+), 11 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 7e9f53b1816d1..40d27f984fbaa 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -107,6 +107,11 @@ class UsageInfo(OpenAIBaseModel): completion_tokens: Optional[int] = 0 +class RequestResponseMetadata(BaseModel): + request_id: str + final_usage_info: Optional[UsageInfo] = None + + class JsonSchemaResponseFormat(OpenAIBaseModel): name: str description: Optional[str] = None diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 1ee4b3ce17cfa..0321ea98ec742 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -22,7 +22,8 @@ ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage, - DeltaToolCall, ErrorResponse, FunctionCall, ToolCall, UsageInfo) + DeltaToolCall, ErrorResponse, FunctionCall, RequestResponseMetadata, + ToolCall, UsageInfo) from vllm.entrypoints.openai.serving_engine import (BaseModelPath, LoRAModulePath, OpenAIServing, @@ -175,6 +176,11 @@ async def create_chat_completion( "--enable-auto-tool-choice and --tool-call-parser to be set") request_id = f"chat-{random_uuid()}" + + request_metadata = RequestResponseMetadata(request_id=request_id) + if raw_request: + raw_request.state.request_metadata = request_metadata + try: guided_decode_logits_processor = ( await self._guided_decode_logits_processor(request, tokenizer)) @@ -241,11 +247,13 @@ async def create_chat_completion( # Streaming response if request.stream: return self.chat_completion_stream_generator( - request, result_generator, request_id, conversation, tokenizer) + request, result_generator, request_id, conversation, tokenizer, + request_metadata) try: return await self.chat_completion_full_generator( - request, result_generator, request_id, conversation, tokenizer) + request, result_generator, request_id, conversation, tokenizer, + request_metadata) except ValueError as e: # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) @@ -262,6 +270,7 @@ async def chat_completion_stream_generator( request_id: str, conversation: List[ConversationMessage], tokenizer: AnyTokenizer, + request_metadata: RequestResponseMetadata, ) -> AsyncGenerator[str, None]: model_name = self.base_model_paths[0].name created_time = int(time.time()) @@ -580,6 +589,13 @@ async def chat_completion_stream_generator( exclude_unset=True, exclude_none=True)) yield f"data: {final_usage_data}\n\n" + # report to FastAPI middleware aggregate usage across all choices + num_completion_tokens = sum(previous_num_tokens) + request_metadata.final_usage_info = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=num_completion_tokens, + total_tokens=num_prompt_tokens + num_completion_tokens) + except ValueError as e: # TODO: Use a vllm-specific Validation Error logger.error("error in chat completion stream generator: %s", e) @@ -595,6 +611,7 @@ async def chat_completion_full_generator( request_id: str, conversation: List[ConversationMessage], tokenizer: AnyTokenizer, + request_metadata: RequestResponseMetadata, ) -> Union[ErrorResponse, ChatCompletionResponse]: model_name = self.base_model_paths[0].name @@ -714,6 +731,9 @@ async def chat_completion_full_generator( completion_tokens=num_generated_tokens, total_tokens=num_prompt_tokens + num_generated_tokens, ) + + request_metadata.final_usage_info = usage + response = ChatCompletionResponse( id=request_id, created=created_time, diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 9abd74d0561d0..0e8609002e39e 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -18,7 +18,9 @@ CompletionResponseChoice, CompletionResponseStreamChoice, CompletionStreamResponse, - ErrorResponse, UsageInfo) + ErrorResponse, + RequestResponseMetadata, + UsageInfo) # yapf: enable from vllm.entrypoints.openai.serving_engine import (BaseModelPath, LoRAModulePath, @@ -94,6 +96,10 @@ async def create_completion( request_id = f"cmpl-{random_uuid()}" created_time = int(time.time()) + request_metadata = RequestResponseMetadata(request_id=request_id) + if raw_request: + raw_request.state.request_metadata = request_metadata + # Schedule the request and get the result generator. generators: List[AsyncGenerator[RequestOutput, None]] = [] try: @@ -165,13 +171,15 @@ async def create_completion( # Streaming response if stream: - return self.completion_stream_generator(request, - result_generator, - request_id, - created_time, - model_name, - num_prompts=len(prompts), - tokenizer=tokenizer) + return self.completion_stream_generator( + request, + result_generator, + request_id, + created_time, + model_name, + num_prompts=len(prompts), + tokenizer=tokenizer, + request_metadata=request_metadata) # Non-streaming response final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts) @@ -198,6 +206,7 @@ async def create_completion( created_time, model_name, tokenizer, + request_metadata, ) except asyncio.CancelledError: return self.create_error_response("Client disconnected") @@ -227,6 +236,7 @@ async def completion_stream_generator( model_name: str, num_prompts: int, tokenizer: AnyTokenizer, + request_metadata: RequestResponseMetadata, ) -> AsyncGenerator[str, None]: num_choices = 1 if request.n is None else request.n previous_text_lens = [0] * num_choices * num_prompts @@ -346,6 +356,14 @@ async def completion_stream_generator( exclude_unset=False, exclude_none=True)) yield f"data: {final_usage_data}\n\n" + # report to FastAPI middleware aggregate usage across all choices + total_prompt_tokens = sum(num_prompt_tokens) + total_completion_tokens = sum(previous_num_tokens) + request_metadata.final_usage_info = UsageInfo( + prompt_tokens=total_prompt_tokens, + completion_tokens=total_completion_tokens, + total_tokens=total_prompt_tokens + total_completion_tokens) + except ValueError as e: # TODO: Use a vllm-specific Validation Error data = self.create_streaming_error_response(str(e)) @@ -360,6 +378,7 @@ def request_output_to_completion_response( created_time: int, model_name: str, tokenizer: AnyTokenizer, + request_metadata: RequestResponseMetadata, ) -> CompletionResponse: choices: List[CompletionResponseChoice] = [] num_prompt_tokens = 0 @@ -433,6 +452,8 @@ def request_output_to_completion_response( total_tokens=num_prompt_tokens + num_generated_tokens, ) + request_metadata.final_usage_info = usage + return CompletionResponse( id=request_id, created=created_time,