diff --git a/requirements-common.txt b/requirements-common.txt
index 3cc7bba8f84db..e9db261c6aec9 100644
--- a/requirements-common.txt
+++ b/requirements-common.txt
@@ -8,6 +8,7 @@ py-cpuinfo
 transformers >= 4.40.0 # Required for StarCoder2 & Llava, Llama 3.
 tokenizers >= 0.19.1 # Required for Llama 3.
 fastapi
+openai
 uvicorn[standard]
 pydantic >= 2.0 # Required for OpenAI server.
 prometheus_client >= 0.18.0
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 1317e51b2dd11..6dd8125786b75 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -21,7 +21,6 @@ pytest-rerunfailures
 pytest-shard
 httpx
 einops # required for MPT
-openai
 requests
 ray
 peft
diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py
index 5c361b4d184ee..16c5b6c08d37f 100644
--- a/vllm/entrypoints/openai/cli_args.py
+++ b/vllm/entrypoints/openai/cli_args.py
@@ -9,7 +9,7 @@ import ssl

 from vllm.engine.arg_utils import AsyncEngineArgs
-from vllm.entrypoints.openai.serving_engine import LoRA
+from vllm.entrypoints.openai.serving_engine import LoRAModulePath


 class LoRAParserAction(argparse.Action):
@@ -18,7 +18,7 @@ def __call__(self, parser, namespace, values, option_string=None):
         lora_list = []
         for item in values:
             name, path = item.split('=')
-            lora_list.append(LoRA(name, path))
+            lora_list.append(LoRAModulePath(name, path))
         setattr(namespace, self.dest, lora_list)
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index d9763d024eb83..0a949f9867754 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -4,14 +4,20 @@
 from typing import Dict, List, Literal, Optional, Union

 import torch
-from pydantic import BaseModel, Field, model_validator
+from openai.types.chat import ChatCompletionMessageParam
+from pydantic import BaseModel, ConfigDict, Field, model_validator
 from typing_extensions import Annotated

 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid


-class ErrorResponse(BaseModel):
+class OpenAIBaseModel(BaseModel):
+    # OpenAI API does not allow extra fields
+    model_config = ConfigDict(extra="forbid")
+
+
+class ErrorResponse(OpenAIBaseModel):
     object: str = "error"
     message: str
     type: str
@@ -19,7 +25,7 @@ class ErrorResponse(BaseModel):
     code: int


-class ModelPermission(BaseModel):
+class ModelPermission(OpenAIBaseModel):
     id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
     object: str = "model_permission"
     created: int = Field(default_factory=lambda: int(time.time()))
@@ -34,7 +40,7 @@ class ModelPermission(BaseModel):
     is_blocking: bool = False


-class ModelCard(BaseModel):
+class ModelCard(OpenAIBaseModel):
     id: str
     object: str = "model"
     created: int = Field(default_factory=lambda: int(time.time()))
@@ -44,26 +50,26 @@ class ModelCard(BaseModel):
     permission: List[ModelPermission] = Field(default_factory=list)


-class ModelList(BaseModel):
+class ModelList(OpenAIBaseModel):
     object: str = "list"
     data: List[ModelCard] = Field(default_factory=list)


-class UsageInfo(BaseModel):
+class UsageInfo(OpenAIBaseModel):
     prompt_tokens: int = 0
     total_tokens: int = 0
     completion_tokens: Optional[int] = 0


-class ResponseFormat(BaseModel):
+class ResponseFormat(OpenAIBaseModel):
     # type must be "json_object" or "text"
     type: Literal["text", "json_object"]


-class ChatCompletionRequest(BaseModel):
+class ChatCompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/chat/create
-    messages: List[Dict[str, str]]
+    messages: List[ChatCompletionMessageParam]
     model: str
     frequency_penalty: Optional[float] = 0.0
     logit_bias: Optional[Dict[str, float]] = None
@@ -204,7 +210,7 @@ def check_guided_decoding_count(cls, data):
         return data


-class CompletionRequest(BaseModel):
+class CompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/completions/create
     model: str
@@ -343,19 +349,19 @@ def check_guided_decoding_count(cls, data):
         return data


-class LogProbs(BaseModel):
+class LogProbs(OpenAIBaseModel):
     text_offset: List[int] = Field(default_factory=list)
     token_logprobs: List[Optional[float]] = Field(default_factory=list)
     tokens: List[str] = Field(default_factory=list)
     top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None


-class CompletionResponseChoice(BaseModel):
+class CompletionResponseChoice(OpenAIBaseModel):
     index: int
     text: str
     logprobs: Optional[LogProbs] = None
-    finish_reason: Optional[Literal["stop", "length"]] = None
-    stop_reason: Union[None, int, str] = Field(
+    finish_reason: Optional[str] = None
+    stop_reason: Optional[Union[int, str]] = Field(
         default=None,
         description=(
             "The stop string or token id that caused the completion "
@@ -364,7 +370,7 @@ class CompletionResponseChoice(BaseModel):
     )


-class CompletionResponse(BaseModel):
+class CompletionResponse(OpenAIBaseModel):
     id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
     object: str = "text_completion"
     created: int = Field(default_factory=lambda: int(time.time()))
@@ -373,12 +379,12 @@ class CompletionResponse(BaseModel):
     usage: UsageInfo


-class CompletionResponseStreamChoice(BaseModel):
+class CompletionResponseStreamChoice(OpenAIBaseModel):
     index: int
     text: str
     logprobs: Optional[LogProbs] = None
-    finish_reason: Optional[Literal["stop", "length"]] = None
-    stop_reason: Union[None, int, str] = Field(
+    finish_reason: Optional[str] = None
+    stop_reason: Optional[Union[int, str]] = Field(
         default=None,
         description=(
             "The stop string or token id that caused the completion "
@@ -387,7 +393,7 @@ class CompletionResponseStreamChoice(BaseModel):
     )


-class CompletionStreamResponse(BaseModel):
+class CompletionStreamResponse(OpenAIBaseModel):
     id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
     object: str = "text_completion"
     created: int = Field(default_factory=lambda: int(time.time()))
@@ -396,20 +402,20 @@ class CompletionStreamResponse(BaseModel):
     usage: Optional[UsageInfo] = Field(default=None)


-class ChatMessage(BaseModel):
+class ChatMessage(OpenAIBaseModel):
     role: str
     content: str


-class ChatCompletionResponseChoice(BaseModel):
+class ChatCompletionResponseChoice(OpenAIBaseModel):
     index: int
     message: ChatMessage
     logprobs: Optional[LogProbs] = None
-    finish_reason: Optional[Literal["stop", "length"]] = None
-    stop_reason: Union[None, int, str] = None
+    finish_reason: Optional[str] = None
+    stop_reason: Optional[Union[int, str]] = None


-class ChatCompletionResponse(BaseModel):
+class ChatCompletionResponse(OpenAIBaseModel):
     id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
     object: str = "chat.completion"
     created: int = Field(default_factory=lambda: int(time.time()))
@@ -418,20 +424,20 @@ class ChatCompletionResponse(BaseModel):
     usage: UsageInfo


-class DeltaMessage(BaseModel):
+class DeltaMessage(OpenAIBaseModel):
     role: Optional[str] = None
     content: Optional[str] = None


-class ChatCompletionResponseStreamChoice(BaseModel):
+class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
     index: int
     delta: DeltaMessage
     logprobs: Optional[LogProbs] = None
-    finish_reason: Optional[Literal["stop", "length"]] = None
-    stop_reason: Union[None, int, str] = None
+    finish_reason: Optional[str] = None
+    stop_reason: Optional[Union[int, str]] = None


-class ChatCompletionStreamResponse(BaseModel):
+class ChatCompletionStreamResponse(OpenAIBaseModel):
     id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
     object: str = "chat.completion.chunk"
     created: int = Field(default_factory=lambda: int(time.time()))
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 2ff335eb71073..197b2321b791d 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -1,8 +1,11 @@
 import codecs
 import time
-from typing import AsyncGenerator, AsyncIterator, List, Optional, Union
+from typing import (AsyncGenerator, AsyncIterator, Awaitable, Iterable, List,
+                    Optional, Tuple, TypedDict, Union, final)

 from fastapi import Request
+from openai.types.chat import (ChatCompletionContentPartParam,
+                               ChatCompletionRole)

 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.protocol import (
@@ -10,7 +13,8 @@
     ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
     ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
     UsageInfo)
-from vllm.entrypoints.openai.serving_engine import LoRA, OpenAIServing
+from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
+                                                    OpenAIServing)
 from vllm.logger import init_logger
 from vllm.model_executor.guided_decoding import (
     get_guided_decoding_logits_processor)
@@ -20,20 +24,41 @@
 logger = init_logger(__name__)


+@final  # So that it should be compatible with Dict[str, str]
+class ConversationMessage(TypedDict):
+    role: str
+    content: str
+
+
 class OpenAIServingChat(OpenAIServing):

     def __init__(self,
                  engine: AsyncLLMEngine,
                  served_model_names: List[str],
                  response_role: str,
-                 lora_modules: Optional[List[LoRA]] = None,
-                 chat_template=None):
+                 lora_modules: Optional[List[LoRAModulePath]] = None,
+                 chat_template: Optional[str] = None):
         super().__init__(engine=engine,
                          served_model_names=served_model_names,
                          lora_modules=lora_modules)
         self.response_role = response_role
         self._load_chat_template(chat_template)

+    def _parse_chat_message_content(
+        self,
+        role: ChatCompletionRole,
+        content: Optional[Union[str,
+                                Iterable[ChatCompletionContentPartParam]]],
+    ) -> Tuple[List[ConversationMessage], List[Awaitable[object]]]:
+        if content is None:
+            return [], []
+        if isinstance(content, str):
+            return [ConversationMessage(role=role, content=content)], []
+
+        # To be implemented: https://github.com/vllm-project/vllm/pull/3467
+        # To be implemented: https://github.com/vllm-project/vllm/pull/4200
+        raise NotImplementedError("Complex input not supported yet")
+
     async def create_chat_completion(
         self, request: ChatCompletionRequest, raw_request: Request
     ) -> Union[ErrorResponse, AsyncGenerator[str, None],
@@ -52,10 +77,19 @@ async def create_chat_completion(
             return error_check_ret

         try:
+            conversation: List[ConversationMessage] = []
+
+            for m in request.messages:
+                messages, _ = self._parse_chat_message_content(
+                    m["role"], m["content"])
+
+                conversation.extend(messages)
+
             prompt = self.tokenizer.apply_chat_template(
-                conversation=request.messages,
+                conversation=conversation,
                 tokenize=False,
-                add_generation_prompt=request.add_generation_prompt)
+                add_generation_prompt=request.add_generation_prompt,
+            )
         except Exception as e:
             logger.error(
                 f"Error in applying chat template from request: {str(e)}")
@@ -106,9 +140,8 @@ def get_chat_request_role(self, request: ChatCompletionRequest) -> str:

     async def chat_completion_stream_generator(
             self, request: ChatCompletionRequest,
-            result_generator: AsyncIterator[RequestOutput], request_id: str
-    ) -> Union[ErrorResponse, AsyncGenerator[str, None]]:
-
+            result_generator: AsyncIterator[RequestOutput],
+            request_id: str) -> AsyncGenerator[str, None]:
         model_name = self.served_model_names[0]
         created_time = int(time.time())
         chunk_object_type = "chat.completion.chunk"
@@ -253,7 +286,7 @@ async def chat_completion_full_generator(

         model_name = self.served_model_names[0]
         created_time = int(time.time())
-        final_res: RequestOutput = None
+        final_res: Optional[RequestOutput] = None

         async for res in result_generator:
             if await raw_request.is_disconnected():
@@ -318,7 +351,7 @@ async def chat_completion_full_generator(

         return response

-    def _load_chat_template(self, chat_template):
+    def _load_chat_template(self, chat_template: Optional[str]):
         tokenizer = self.tokenizer

         if chat_template is not None:
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index 211b2e0424c3e..7904bb698c45a 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -11,7 +11,8 @@
     CompletionResponseStreamChoice, CompletionStreamResponse, LogProbs,
     UsageInfo)
-from vllm.entrypoints.openai.serving_engine import LoRA, OpenAIServing
+from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
+                                                    OpenAIServing)
 from vllm.logger import init_logger
 from vllm.model_executor.guided_decoding import (
     get_guided_decoding_logits_processor)
@@ -54,7 +55,7 @@ class OpenAIServingCompletion(OpenAIServing):
     def __init__(self,
                  engine: AsyncLLMEngine,
                  served_model_names: List[str],
-                 lora_modules: Optional[List[LoRA]] = None):
+                 lora_modules: Optional[List[LoRAModulePath]] = None):
         super().__init__(engine=engine,
                          served_model_names=served_model_names,
                          lora_modules=lora_modules)
@@ -84,7 +85,7 @@ async def create_completion(self, request: CompletionRequest,
         created_time = int(time.time())

         # Schedule the request and get the result generator.
-        generators = []
+        generators: List[AsyncIterator[RequestOutput]] = []
         try:
             sampling_params = request.to_sampling_params()
             lora_request = self._maybe_get_lora(request)
@@ -148,7 +149,7 @@ async def create_completion(self, request: CompletionRequest,
                 num_prompts=len(prompts))

         # Non-streaming response
-        final_res_batch: RequestOutput = [None] * len(prompts)
+        final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
         try:
             async for i, res in result_generator:
                 if await raw_request.is_disconnected():
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 31da27a447c6c..f5abc005d528e 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -22,17 +22,15 @@


 @dataclass
-class LoRA:
+class LoRAModulePath:
     name: str
     local_path: str


 class OpenAIServing:

-    def __init__(self,
-                 engine: AsyncLLMEngine,
-                 served_model_names: List[str],
-                 lora_modules=Optional[List[LoRA]]):
+    def __init__(self, engine: AsyncLLMEngine, served_model_names: List[str],
+                 lora_modules: Optional[List[LoRAModulePath]]):
         self.engine = engine
         self.served_model_names = served_model_names
         if lora_modules is None: