Disallow extra fields in OpenAI API
- Also improve type annotations in the process
DarkLight1337 committed Apr 25, 2024
1 parent 96e90fd commit 8cf1fb3
Showing 7 changed files with 90 additions and 52 deletions.
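The core of the change is a shared pydantic base class that sets `extra="forbid"`, so any request or response field not declared on a model is rejected with a validation error instead of being silently dropped. A minimal, standalone sketch of that behaviour (the `ExampleRequest` model here is hypothetical, not one of the vLLM schemas) might look like this:

    from pydantic import BaseModel, ConfigDict, ValidationError


    class OpenAIBaseModel(BaseModel):
        # OpenAI API does not allow extra fields
        model_config = ConfigDict(extra="forbid")


    class ExampleRequest(OpenAIBaseModel):
        model: str
        prompt: str


    # A well-formed request validates as before.
    ExampleRequest(model="dummy-model", prompt="Hello")

    # An unrecognized field now raises instead of being ignored.
    try:
        ExampleRequest(model="dummy-model", prompt="Hello", tempersture=0.7)  # typo on purpose
    except ValidationError as exc:
        print(exc)  # pydantic reports the unexpected 'tempersture' field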
1 change: 1 addition & 0 deletions requirements-common.txt
@@ -8,6 +8,7 @@ py-cpuinfo
transformers >= 4.40.0 # Required for StarCoder2 & Llava, Llama 3.
tokenizers >= 0.19.1 # Required for Llama 3.
fastapi
openai
uvicorn[standard]
pydantic >= 2.0 # Required for OpenAI server.
prometheus_client >= 0.18.0
1 change: 0 additions & 1 deletion requirements-dev.txt
@@ -21,7 +21,6 @@ pytest-rerunfailures
pytest-shard
httpx
einops # required for MPT
openai
requests
ray
peft
4 changes: 2 additions & 2 deletions vllm/entrypoints/openai/cli_args.py
@@ -9,7 +9,7 @@
import ssl

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.serving_engine import LoRA
from vllm.entrypoints.openai.serving_engine import LoRAModulePath


class LoRAParserAction(argparse.Action):
@@ -18,7 +18,7 @@ def __call__(self, parser, namespace, values, option_string=None):
lora_list = []
for item in values:
name, path = item.split('=')
lora_list.append(LoRA(name, path))
lora_list.append(LoRAModulePath(name, path))
setattr(namespace, self.dest, lora_list)


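Besides the rename from `LoRA` to `LoRAModulePath`, this file is unchanged: each `--lora-modules` entry of the form `name=path` is still parsed into one object per adapter. A self-contained sketch of the same pattern (the dataclass fields below are assumptions standing in for vLLM's actual class) could look like:

    import argparse
    from dataclasses import dataclass


    @dataclass
    class LoRAModulePath:  # stand-in for vllm.entrypoints.openai.serving_engine.LoRAModulePath
        name: str
        local_path: str


    class LoRAParserAction(argparse.Action):
        def __call__(self, parser, namespace, values, option_string=None):
            lora_list = []
            for item in values:
                # Each CLI item has the form "<adapter-name>=<local-path>".
                name, path = item.split('=')
                lora_list.append(LoRAModulePath(name, path))
            setattr(namespace, self.dest, lora_list)


    parser = argparse.ArgumentParser()
    parser.add_argument("--lora-modules", nargs="+", action=LoRAParserAction)
    args = parser.parse_args(["--lora-modules", "sql-lora=/path/to/sql_lora"])
    print(args.lora_modules)
    # [LoRAModulePath(name='sql-lora', local_path='/path/to/sql_lora')]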
64 changes: 35 additions & 29 deletions vllm/entrypoints/openai/protocol.py
@@ -4,22 +4,28 @@
from typing import Dict, List, Literal, Optional, Union

import torch
from pydantic import BaseModel, Field, model_validator
from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Annotated

from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid


class ErrorResponse(BaseModel):
class OpenAIBaseModel(BaseModel):
# OpenAI API does not allow extra fields
model_config = ConfigDict(extra="forbid")


class ErrorResponse(OpenAIBaseModel):
object: str = "error"
message: str
type: str
param: Optional[str] = None
code: int


class ModelPermission(BaseModel):
class ModelPermission(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
object: str = "model_permission"
created: int = Field(default_factory=lambda: int(time.time()))
@@ -34,7 +40,7 @@ class ModelPermission(BaseModel):
is_blocking: bool = False


class ModelCard(BaseModel):
class ModelCard(OpenAIBaseModel):
id: str
object: str = "model"
created: int = Field(default_factory=lambda: int(time.time()))
@@ -44,26 +50,26 @@ class ModelCard(BaseModel):
permission: List[ModelPermission] = Field(default_factory=list)


class ModelList(BaseModel):
class ModelList(OpenAIBaseModel):
object: str = "list"
data: List[ModelCard] = Field(default_factory=list)


class UsageInfo(BaseModel):
class UsageInfo(OpenAIBaseModel):
prompt_tokens: int = 0
total_tokens: int = 0
completion_tokens: Optional[int] = 0


class ResponseFormat(BaseModel):
class ResponseFormat(OpenAIBaseModel):
# type must be "json_object" or "text"
type: Literal["text", "json_object"]


class ChatCompletionRequest(BaseModel):
class ChatCompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/chat/create
messages: List[Dict[str, str]]
messages: List[ChatCompletionMessageParam]
model: str
frequency_penalty: Optional[float] = 0.0
logit_bias: Optional[Dict[str, float]] = None
@@ -204,7 +210,7 @@ def check_guided_decoding_count(cls, data):
return data


class CompletionRequest(BaseModel):
class CompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/completions/create
model: str
@@ -343,19 +349,19 @@ def check_guided_decoding_count(cls, data):
return data


class LogProbs(BaseModel):
class LogProbs(OpenAIBaseModel):
text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list)
tokens: List[str] = Field(default_factory=list)
top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None


class CompletionResponseChoice(BaseModel):
class CompletionResponseChoice(OpenAIBaseModel):
index: int
text: str
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None
stop_reason: Union[None, int, str] = Field(
finish_reason: Optional[str] = None
stop_reason: Optional[Union[int, str]] = Field(
default=None,
description=(
"The stop string or token id that caused the completion "
@@ -364,7 +370,7 @@ class CompletionResponseChoice(BaseModel):
)


class CompletionResponse(BaseModel):
class CompletionResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
@@ -373,12 +379,12 @@ class CompletionResponse(BaseModel):
usage: UsageInfo


class CompletionResponseStreamChoice(BaseModel):
class CompletionResponseStreamChoice(OpenAIBaseModel):
index: int
text: str
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None
stop_reason: Union[None, int, str] = Field(
finish_reason: Optional[str] = None
stop_reason: Optional[Union[int, str]] = Field(
default=None,
description=(
"The stop string or token id that caused the completion "
@@ -387,7 +393,7 @@ class CompletionResponseStreamChoice(BaseModel):
)


class CompletionStreamResponse(BaseModel):
class CompletionStreamResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
@@ -396,20 +402,20 @@ class CompletionStreamResponse(BaseModel):
usage: Optional[UsageInfo] = Field(default=None)


class ChatMessage(BaseModel):
class ChatMessage(OpenAIBaseModel):
role: str
content: str


class ChatCompletionResponseChoice(BaseModel):
class ChatCompletionResponseChoice(OpenAIBaseModel):
index: int
message: ChatMessage
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None
stop_reason: Union[None, int, str] = None
finish_reason: Optional[str] = None
stop_reason: Optional[Union[int, str]] = None


class ChatCompletionResponse(BaseModel):
class ChatCompletionResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
object: str = "chat.completion"
created: int = Field(default_factory=lambda: int(time.time()))
@@ -418,20 +424,20 @@ class ChatCompletionResponse(BaseModel):
usage: UsageInfo


class DeltaMessage(BaseModel):
class DeltaMessage(OpenAIBaseModel):
role: Optional[str] = None
content: Optional[str] = None


class ChatCompletionResponseStreamChoice(BaseModel):
class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
index: int
delta: DeltaMessage
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None
stop_reason: Union[None, int, str] = None
finish_reason: Optional[str] = None
stop_reason: Optional[Union[int, str]] = None


class ChatCompletionStreamResponse(BaseModel):
class ChatCompletionStreamResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
object: str = "chat.completion.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
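With `messages` now typed as `List[ChatCompletionMessageParam]`, the request schema reuses the message TypedDicts published by the `openai` client package instead of a loose `Dict[str, str]`. A hedged illustration of what a type checker accepts under that annotation (using types from the openai 1.x client; at runtime these are plain dicts):

    from typing import List

    from openai.types.chat import (ChatCompletionMessageParam,
                                   ChatCompletionSystemMessageParam,
                                   ChatCompletionUserMessageParam)

    # Constructing the TypedDicts explicitly...
    messages: List[ChatCompletionMessageParam] = [
        ChatCompletionSystemMessageParam(role="system", content="You are helpful."),
        ChatCompletionUserMessageParam(role="user", content="Hello!"),
    ]

    # ...or writing the equivalent dict literals; both satisfy the annotation.
    messages_as_dicts: List[ChatCompletionMessageParam] = [
        {"role": "system", "content": "You are helpful."},
        {"role": "user", "content": "Hello!"},
    ]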
55 changes: 44 additions & 11 deletions vllm/entrypoints/openai/serving_chat.py
@@ -1,16 +1,20 @@
import codecs
import time
from typing import AsyncGenerator, AsyncIterator, List, Optional, Union
from typing import (AsyncGenerator, AsyncIterator, Awaitable, Iterable, List,
Optional, Tuple, TypedDict, Union, final)

from fastapi import Request
from openai.types.chat import (ChatCompletionContentPartParam,
ChatCompletionRole)

from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
UsageInfo)
from vllm.entrypoints.openai.serving_engine import LoRA, OpenAIServing
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing)
from vllm.logger import init_logger
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
@@ -20,20 +24,41 @@
logger = init_logger(__name__)


@final # So that it should be compatible with Dict[str, str]
class ConversationMessage(TypedDict):
role: str
content: str


class OpenAIServingChat(OpenAIServing):

def __init__(self,
engine: AsyncLLMEngine,
served_model_names: List[str],
response_role: str,
lora_modules: Optional[List[LoRA]] = None,
chat_template=None):
lora_modules: Optional[List[LoRAModulePath]] = None,
chat_template: Optional[str] = None):
super().__init__(engine=engine,
served_model_names=served_model_names,
lora_modules=lora_modules)
self.response_role = response_role
self._load_chat_template(chat_template)

def _parse_chat_message_content(
self,
role: ChatCompletionRole,
content: Optional[Union[str,
Iterable[ChatCompletionContentPartParam]]],
) -> Tuple[List[ConversationMessage], List[Awaitable[object]]]:
if content is None:
return [], []
if isinstance(content, str):
return [ConversationMessage(role=role, content=content)], []

# To be implemented: https://github.com/vllm-project/vllm/pull/3467
# To be implemented: https://github.com/vllm-project/vllm/pull/4200
raise NotImplementedError("Complex input not supported yet")

async def create_chat_completion(
self, request: ChatCompletionRequest, raw_request: Request
) -> Union[ErrorResponse, AsyncGenerator[str, None],
@@ -52,10 +77,19 @@ async def create_chat_completion(
return error_check_ret

try:
conversation: List[ConversationMessage] = []

for m in request.messages:
messages, _ = self._parse_chat_message_content(
m["role"], m["content"])

conversation.extend(messages)

prompt = self.tokenizer.apply_chat_template(
conversation=request.messages,
conversation=conversation,
tokenize=False,
add_generation_prompt=request.add_generation_prompt)
add_generation_prompt=request.add_generation_prompt,
)
except Exception as e:
logger.error(
f"Error in applying chat template from request: {str(e)}")
@@ -106,9 +140,8 @@ def get_chat_request_role(self, request: ChatCompletionRequest) -> str:

async def chat_completion_stream_generator(
self, request: ChatCompletionRequest,
result_generator: AsyncIterator[RequestOutput], request_id: str
) -> Union[ErrorResponse, AsyncGenerator[str, None]]:

result_generator: AsyncIterator[RequestOutput],
request_id: str) -> AsyncGenerator[str, None]:
model_name = self.served_model_names[0]
created_time = int(time.time())
chunk_object_type = "chat.completion.chunk"
@@ -253,7 +286,7 @@ async def chat_completion_full_generator(

model_name = self.served_model_names[0]
created_time = int(time.time())
final_res: RequestOutput = None
final_res: Optional[RequestOutput] = None

async for res in result_generator:
if await raw_request.is_disconnected():
@@ -318,7 +351,7 @@ async def chat_completion_full_generator(

return response

def _load_chat_template(self, chat_template):
def _load_chat_template(self, chat_template: Optional[str]):
tokenizer = self.tokenizer

if chat_template is not None:
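The new `_parse_chat_message_content` helper normalizes each incoming message into zero or more `ConversationMessage` entries before the chat template is applied: string content passes through, `None` yields nothing, and multi-part content is deferred to the linked follow-up PRs. A simplified standalone sketch of that normalization (the surrounding class and the returned awaitables are dropped here) might be:

    from typing import Iterable, List, Optional, TypedDict, Union, final


    @final  # so instances can be treated like plain Dict[str, str]
    class ConversationMessage(TypedDict):
        role: str
        content: str


    def parse_chat_message_content(
        role: str,
        content: Optional[Union[str, Iterable[dict]]],
    ) -> List[ConversationMessage]:
        """Flatten one OpenAI-style message into role/content entries."""
        if content is None:
            return []
        if isinstance(content, str):
            return [ConversationMessage(role=role, content=content)]
        # Multi-part (e.g. image) content is not handled in this sketch either.
        raise NotImplementedError("Complex input not supported yet")


    conversation: List[ConversationMessage] = []
    for m in [{"role": "user", "content": "Hello"}]:
        conversation.extend(parse_chat_message_content(m["role"], m["content"]))
    print(conversation)  # [{'role': 'user', 'content': 'Hello'}]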
9 changes: 5 additions & 4 deletions vllm/entrypoints/openai/serving_completion.py
@@ -11,7 +11,8 @@
CompletionResponseStreamChoice,
CompletionStreamResponse,
LogProbs, UsageInfo)
from vllm.entrypoints.openai.serving_engine import LoRA, OpenAIServing
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing)
from vllm.logger import init_logger
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
@@ -54,7 +55,7 @@ class OpenAIServingCompletion(OpenAIServing):
def __init__(self,
engine: AsyncLLMEngine,
served_model_names: List[str],
lora_modules: Optional[List[LoRA]] = None):
lora_modules: Optional[List[LoRAModulePath]] = None):
super().__init__(engine=engine,
served_model_names=served_model_names,
lora_modules=lora_modules)
@@ -84,7 +85,7 @@ async def create_completion(self, request: CompletionRequest,
created_time = int(time.time())

# Schedule the request and get the result generator.
generators = []
generators: List[AsyncIterator[RequestOutput]] = []
try:
sampling_params = request.to_sampling_params()
lora_request = self._maybe_get_lora(request)
@@ -148,7 +149,7 @@ async def create_completion(self, request: CompletionRequest,
num_prompts=len(prompts))

# Non-streaming response
final_res_batch: RequestOutput = [None] * len(prompts)
final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
try:
async for i, res in result_generator:
if await raw_request.is_disconnected():
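The annotation fixes in this file make the declared types match the runtime values: the result batch starts as a list of `None` placeholders that are filled in as outputs arrive, so `List[Optional[RequestOutput]]` (rather than the old `RequestOutput`) is the honest type. A generic sketch of the pattern, with a placeholder class standing in for vLLM's `RequestOutput`:

    from typing import List, Optional


    class RequestOutput:  # placeholder for vllm.outputs.RequestOutput
        def __init__(self, text: str) -> None:
            self.text = text


    prompts = ["a", "b", "c"]

    # Slots start empty and are filled as results arrive, possibly out of order,
    # so each element is Optional until the whole batch has completed.
    final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)

    for i, prompt in enumerate(prompts):
        final_res_batch[i] = RequestOutput(text=prompt.upper())

    assert all(res is not None for res in final_res_batch)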