Skip to content

Commit

Permalink
[Frontend] OpenAI server: propagate usage accounting to FastAPI middl…
Browse files Browse the repository at this point in the history
…eware layer (vllm-project#8672)
  • Loading branch information
agt authored and siddharth9820 committed Sep 30, 2024
1 parent 9db1b59 commit a23722a
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 11 deletions.
5 changes: 5 additions & 0 deletions vllm/entrypoints/openai/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,11 @@ class UsageInfo(OpenAIBaseModel):
completion_tokens: Optional[int] = 0


class RequestResponseMetadata(BaseModel):
request_id: str
final_usage_info: Optional[UsageInfo] = None


class JsonSchemaResponseFormat(OpenAIBaseModel):
name: str
description: Optional[str] = None
Expand Down
26 changes: 23 additions & 3 deletions vllm/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
DeltaToolCall, ErrorResponse, FunctionCall, ToolCall, UsageInfo)
DeltaToolCall, ErrorResponse, FunctionCall, RequestResponseMetadata,
ToolCall, UsageInfo)
from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
LoRAModulePath,
OpenAIServing,
Expand Down Expand Up @@ -175,6 +176,11 @@ async def create_chat_completion(
"--enable-auto-tool-choice and --tool-call-parser to be set")

request_id = f"chat-{random_uuid()}"

request_metadata = RequestResponseMetadata(request_id=request_id)
if raw_request:
raw_request.state.request_metadata = request_metadata

try:
guided_decode_logits_processor = (
await self._guided_decode_logits_processor(request, tokenizer))
Expand Down Expand Up @@ -241,11 +247,13 @@ async def create_chat_completion(
# Streaming response
if request.stream:
return self.chat_completion_stream_generator(
request, result_generator, request_id, conversation, tokenizer)
request, result_generator, request_id, conversation, tokenizer,
request_metadata)

try:
return await self.chat_completion_full_generator(
request, result_generator, request_id, conversation, tokenizer)
request, result_generator, request_id, conversation, tokenizer,
request_metadata)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
Expand All @@ -262,6 +270,7 @@ async def chat_completion_stream_generator(
request_id: str,
conversation: List[ConversationMessage],
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
) -> AsyncGenerator[str, None]:
model_name = self.base_model_paths[0].name
created_time = int(time.time())
Expand Down Expand Up @@ -580,6 +589,13 @@ async def chat_completion_stream_generator(
exclude_unset=True, exclude_none=True))
yield f"data: {final_usage_data}\n\n"

# report to FastAPI middleware aggregate usage across all choices
num_completion_tokens = sum(previous_num_tokens)
request_metadata.final_usage_info = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_completion_tokens,
total_tokens=num_prompt_tokens + num_completion_tokens)

except ValueError as e:
# TODO: Use a vllm-specific Validation Error
logger.error("error in chat completion stream generator: %s", e)
Expand All @@ -595,6 +611,7 @@ async def chat_completion_full_generator(
request_id: str,
conversation: List[ConversationMessage],
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
) -> Union[ErrorResponse, ChatCompletionResponse]:

model_name = self.base_model_paths[0].name
Expand Down Expand Up @@ -714,6 +731,9 @@ async def chat_completion_full_generator(
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)

request_metadata.final_usage_info = usage

response = ChatCompletionResponse(
id=request_id,
created=created_time,
Expand Down
37 changes: 29 additions & 8 deletions vllm/entrypoints/openai/serving_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse,
ErrorResponse, UsageInfo)
ErrorResponse,
RequestResponseMetadata,
UsageInfo)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
LoRAModulePath,
Expand Down Expand Up @@ -94,6 +96,10 @@ async def create_completion(
request_id = f"cmpl-{random_uuid()}"
created_time = int(time.time())

request_metadata = RequestResponseMetadata(request_id=request_id)
if raw_request:
raw_request.state.request_metadata = request_metadata

# Schedule the request and get the result generator.
generators: List[AsyncGenerator[RequestOutput, None]] = []
try:
Expand Down Expand Up @@ -165,13 +171,15 @@ async def create_completion(

# Streaming response
if stream:
return self.completion_stream_generator(request,
result_generator,
request_id,
created_time,
model_name,
num_prompts=len(prompts),
tokenizer=tokenizer)
return self.completion_stream_generator(
request,
result_generator,
request_id,
created_time,
model_name,
num_prompts=len(prompts),
tokenizer=tokenizer,
request_metadata=request_metadata)

# Non-streaming response
final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
Expand All @@ -198,6 +206,7 @@ async def create_completion(
created_time,
model_name,
tokenizer,
request_metadata,
)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
Expand Down Expand Up @@ -227,6 +236,7 @@ async def completion_stream_generator(
model_name: str,
num_prompts: int,
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
) -> AsyncGenerator[str, None]:
num_choices = 1 if request.n is None else request.n
previous_text_lens = [0] * num_choices * num_prompts
Expand Down Expand Up @@ -346,6 +356,14 @@ async def completion_stream_generator(
exclude_unset=False, exclude_none=True))
yield f"data: {final_usage_data}\n\n"

# report to FastAPI middleware aggregate usage across all choices
total_prompt_tokens = sum(num_prompt_tokens)
total_completion_tokens = sum(previous_num_tokens)
request_metadata.final_usage_info = UsageInfo(
prompt_tokens=total_prompt_tokens,
completion_tokens=total_completion_tokens,
total_tokens=total_prompt_tokens + total_completion_tokens)

except ValueError as e:
# TODO: Use a vllm-specific Validation Error
data = self.create_streaming_error_response(str(e))
Expand All @@ -360,6 +378,7 @@ def request_output_to_completion_response(
created_time: int,
model_name: str,
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
) -> CompletionResponse:
choices: List[CompletionResponseChoice] = []
num_prompt_tokens = 0
Expand Down Expand Up @@ -433,6 +452,8 @@ def request_output_to_completion_response(
total_tokens=num_prompt_tokens + num_generated_tokens,
)

request_metadata.final_usage_info = usage

return CompletionResponse(
id=request_id,
created=created_time,
Expand Down

0 comments on commit a23722a

Please sign in to comment.