From 215bba80972b68eece42281e9840d5de2b6f2e73 Mon Sep 17 00:00:00 2001 From: Keith Stevens Date: Mon, 29 Jan 2024 06:13:01 +0000 Subject: [PATCH 1/4] Adding rudimentary support for images in the openai compatible server --- README.md | 2 +- python/sglang/srt/conversation.py | 28 ++++++++-- python/sglang/srt/managers/openai_protocol.py | 38 +++++++++++++- .../sglang/srt/managers/tokenizer_manager.py | 11 ++-- python/sglang/srt/server.py | 5 ++ python/sglang/test/test_conversation.py | 46 +++++++++++++++++ python/sglang/test/test_openai_protocol.py | 51 +++++++++++++++++++ test/srt/test_openai_server.py | 31 +++++++++++ 8 files changed, 202 insertions(+), 10 deletions(-) create mode 100644 python/sglang/test/test_conversation.py create mode 100644 python/sglang/test/test_openai_protocol.py diff --git a/README.md b/README.md index 8381dd87e6..fe48f8b54d 100644 --- a/README.md +++ b/README.md @@ -323,7 +323,7 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port - Mistral - Mixtral - LLaVA - - `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000` + - `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000` - Qwen - AWQ quantization diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py index 24c84a5c96..64651ac4ce 100644 --- a/python/sglang/srt/conversation.py +++ b/python/sglang/srt/conversation.py @@ -2,7 +2,7 @@ # https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py import dataclasses from enum import IntEnum, auto -from typing import Dict, List, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union from sglang.srt.managers.openai_protocol import ChatCompletionRequest @@ -52,6 +52,7 @@ class Conversation: sep2: str = None # Stop criteria (the default one is EOS token) stop_str: Union[str, List[str]] = None + 
image_data: Optional[List[str]] = None def get_prompt(self) -> str: """Get the prompt for generation.""" @@ -251,6 +252,10 @@ def append_message(self, role: str, message: str): """Append a new message.""" self.messages.append([role, message]) + def append_image(self, image: str): + """Append a new image.""" + self.image_data.append(image) + def update_last_message(self, message: str): """Update the last output. @@ -341,18 +346,31 @@ def generate_chat_conv( sep=conv.sep, sep2=conv.sep2, stop_str=conv.stop_str, + image_data=[], ) if isinstance(request.messages, str): raise ValueError("The messages should be a list of dict.") for message in request.messages: - msg_role = message["role"] + msg_role = message.role if msg_role == "system": - conv.system_message = message["content"] + conv.system_message = message.content elif msg_role == "user": - conv.append_message(conv.roles[0], message["content"]) + # Handle the various types of chat request content here. + role = conv.roles[0] + if isinstance(message.content, str): + conv.append_message(conv.roles[0], message.content) + else: + real_content = "" + for content in message.content: + if content.type == "text": + real_content += content.text + if content.type == "image_url": + real_content += "" + conv.append_image(content.image_url.url) + conv.append_message(conv.roles[0], real_content) elif msg_role == "assistant": - conv.append_message(conv.roles[1], message["content"]) + conv.append_message(conv.roles[1], message.content) else: raise ValueError(f"Unknown role: {msg_role}") diff --git a/python/sglang/srt/managers/openai_protocol.py b/python/sglang/srt/managers/openai_protocol.py index 974e38a910..f4ef99dd95 100644 --- a/python/sglang/srt/managers/openai_protocol.py +++ b/python/sglang/srt/managers/openai_protocol.py @@ -1,5 +1,6 @@ import time from typing import Dict, List, Optional, Union +from typing_extensions import Literal from pydantic import BaseModel, Field @@ -68,9 +69,44 @@ class 
CompletionStreamResponse(BaseModel): usage: UsageInfo +class ChatCompletionMessageGenericParam(BaseModel): + role: Literal["system", "assistant"] + content: str + + +class ChatCompletionMessageContentTextPart(BaseModel): + type: Literal["text"] + text: str + + +class ChatCompletionMessageContentImageURL(BaseModel): + url: str + detail: Optional[Literal["auto", "low", "high"]] = "auto" + + +class ChatCompletionMessageContentImagePart(BaseModel): + type: Literal["image_url"] + image_url: ChatCompletionMessageContentImageURL + + +ChatCompletionMessageContentPart = Union[ + ChatCompletionMessageContentTextPart, ChatCompletionMessageContentImagePart +] + + +class ChatCompletionMessageUserParam(BaseModel): + role: Literal["user"] + content: Union[str, List[ChatCompletionMessageContentPart]] + + +ChatCompletionMessageParam = Union[ + ChatCompletionMessageGenericParam, ChatCompletionMessageUserParam +] + + class ChatCompletionRequest(BaseModel): model: str - messages: Union[str, List[Dict[str, str]]] + messages: Union[str, List[ChatCompletionMessageParam]] temperature: Optional[float] = 0.7 top_p: Optional[float] = 1.0 n: Optional[int] = 1 diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index d08b336346..b0c4d4e8a7 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -150,12 +150,17 @@ async def generate_request(self, obj: GenerateReqInput): if sampling_params.max_new_tokens != 0: sampling_params.normalize(self.tokenizer) sampling_params.verify() - if obj.image_data is None: - pixel_values, image_hash, image_size = None, None, None - else: + + if isinstance(obj.image_data, list) and len(obj.image_data) > 0: + pixel_values, image_hash, image_size = await self.get_pixel_values( + obj.image_data[0] + ) + elif isinstance(obj.image_data, str): pixel_values, image_hash, image_size = await self.get_pixel_values( obj.image_data ) + else: + pixel_values, 
image_hash, image_size = None, None, None tokenized_obj = TokenizedGenerateReqInput( rid=rid, input_text=obj.text, diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 9750c4e72c..b645dc1302 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -197,9 +197,11 @@ async def v1_chat_completions(raw_request: Request): request.messages, tokenize=False, add_generation_prompt=True ) stop = request.stop + image_data = None else: conv = generate_chat_conv(request, chat_template_name) prompt = conv.get_prompt() + image_data = conv.image_data stop = conv.stop_str or [] if request.stop: if isinstance(request.stop, str): @@ -210,9 +212,11 @@ async def v1_chat_completions(raw_request: Request): # Use the raw prompt and stop strings if the messages is already a string. prompt = request.messages stop = request.stop + image_data = None adapted_request = GenerateReqInput( text=prompt, + image_data=image_data, sampling_params={ "temperature": request.temperature, "max_new_tokens": request.max_tokens, @@ -303,6 +307,7 @@ def launch_server(server_args, pipe_finish_writer): # Load chat template if needed if server_args.chat_template is not None: + print(server_args.chat_template) if not chat_template_exists(server_args.chat_template): if not os.path.exists(server_args.chat_template): raise RuntimeError( diff --git a/python/sglang/test/test_conversation.py b/python/sglang/test/test_conversation.py new file mode 100644 index 0000000000..4f4f956fe4 --- /dev/null +++ b/python/sglang/test/test_conversation.py @@ -0,0 +1,46 @@ +from sglang.srt.conversation import generate_chat_conv +from sglang.srt.managers.openai_protocol import ( + ChatCompletionMessageGenericParam, + ChatCompletionMessageContentImagePart, + ChatCompletionMessageContentImageURL, + ChatCompletionMessageContentTextPart, + ChatCompletionMessageUserParam, + ChatCompletionRequest, +) + + +def test_chat_completion_to_conv_image(): + """Test that we can convert a chat image request 
to a convo""" + request = ChatCompletionRequest( + model="default", + messages=[ + ChatCompletionMessageGenericParam( + role="system", content="You are a helpful AI assistant" + ), + ChatCompletionMessageUserParam( + role="user", + content=[ + ChatCompletionMessageContentTextPart( + type="text", text="Describe this image" + ), + ChatCompletionMessageContentImagePart( + type="image_url", + image_url=ChatCompletionMessageContentImageURL( + url="https://someurl.com" + ), + ), + ], + ), + ], + ) + conv = generate_chat_conv(request, "vicuna_v1.1") + assert conv.messages == [ + ["USER", "Describe this image"], + ["ASSISTANT", None], + ] + assert conv.system_message == "You are a helpful AI assistant" + assert conv.image_data == ["https://someurl.com"] + + +if __name__ == "__main__": + test_chat_completion_to_conv_image() diff --git a/python/sglang/test/test_openai_protocol.py b/python/sglang/test/test_openai_protocol.py new file mode 100644 index 0000000000..ed18e428a9 --- /dev/null +++ b/python/sglang/test/test_openai_protocol.py @@ -0,0 +1,51 @@ +from sglang.srt.managers.openai_protocol import ( + ChatCompletionMessageGenericParam, + ChatCompletionMessageContentImagePart, + ChatCompletionMessageContentImageURL, + ChatCompletionMessageContentTextPart, + ChatCompletionMessageUserParam, + ChatCompletionRequest, +) + + +def test_chat_completion_request_image(): + """Test that Chat Completion Requests with images can be converted.""" + + image_request = { + "model": "default", + "messages": [ + {"role": "system", "content": "You are a helpful AI assistant"}, + { + "role": "user", + "content": [ + {"type": "text", "text": "Describe this image"}, + {"type": "image_url", "image_url": {"url": "https://someurl.com"}}, + ], + }, + ], + "temperature": 0, + "max_tokens": 64, + } + request = ChatCompletionRequest(**image_request) + assert len(request.messages) == 2 + assert request.messages[0] == ChatCompletionMessageGenericParam( + role="system", content="You are a helpful AI 
assistant" + ) + assert request.messages[1] == ChatCompletionMessageUserParam( + role="user", + content=[ + ChatCompletionMessageContentTextPart( + type="text", text="Describe this image" + ), + ChatCompletionMessageContentImagePart( + type="image_url", + image_url=ChatCompletionMessageContentImageURL( + url="https://someurl.com" + ), + ), + ], + ) + + +if __name__ == "__main__": + test_chat_completion_request_image() diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py index 33d5b0672f..cdfc8c05cf 100644 --- a/test/srt/test_openai_server.py +++ b/test/srt/test_openai_server.py @@ -71,6 +71,36 @@ def test_chat_completion(args): assert response.usage.total_tokens > 0 +def test_chat_completion_image(args): + client = openai.Client(api_key="EMPTY", base_url=args.base_url) + response = client.chat.completions.create( + model="default", + messages=[ + {"role": "system", "content": "You are a helpful AI assistant"}, + { + "role": "user", + "content": [ + {"type": "text", "text": "Describe this image"}, + { + "type": "image_url", + "image_url": { + "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/mixtral_8x7b.jpg" + }, + }, + ], + }, + ], + temperature=0, + max_tokens=32, + ) + print(response.choices[0].message.content) + assert response.id + assert response.created + assert response.usage.prompt_tokens > 0 + assert response.usage.completion_tokens > 0 + assert response.usage.total_tokens > 0 + + def test_chat_completion_stream(args): client = openai.Client(api_key="EMPTY", base_url=args.base_url) response = client.chat.completions.create( @@ -105,4 +135,5 @@ def test_chat_completion_stream(args): test_completion(args) test_completion_stream(args) test_chat_completion(args) + test_chat_completion_image(args) test_chat_completion_stream(args) From 2c0b8dcd9d760ee421cca1125ff4c94d338ef22e Mon Sep 17 00:00:00 2001 From: Keith Stevens Date: Tue, 30 Jan 2024 00:31:39 +0000 Subject: [PATCH 2/4] Addressing review comments --- 
python/sglang/srt/conversation.py | 2 +- python/sglang/srt/server.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py index 64651ac4ce..df872f77a4 100644 --- a/python/sglang/srt/conversation.py +++ b/python/sglang/srt/conversation.py @@ -365,7 +365,7 @@ def generate_chat_conv( for content in message.content: if content.type == "text": real_content += content.text - if content.type == "image_url": + elif content.type == "image_url": real_content += "" conv.append_image(content.image_url.url) conv.append_message(conv.roles[0], real_content) diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index b645dc1302..7e06f3ba5c 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -190,6 +190,11 @@ async def v1_chat_completions(raw_request: Request): # TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid. assert request.n == 1 + # Prep the data needed for the underlying GenerateReqInput: + # - prompt: The full prompt string. + # - stop: Custom stop tokens. + # - image_data: None or a list of image strings (URLs or base64 strings). + # None skips any image processing in GenerateReqInput. if not isinstance(request.messages, str): # Apply chat template and its stop strings. 
if chat_template_name is None: From 4e0e66d664dcb33491ffe860227145182967fcde Mon Sep 17 00:00:00 2001 From: Keith Stevens Date: Tue, 30 Jan 2024 01:45:06 +0000 Subject: [PATCH 3/4] Disabling image input tests by default --- test/srt/test_openai_server.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py index cdfc8c05cf..f0dc078e22 100644 --- a/test/srt/test_openai_server.py +++ b/test/srt/test_openai_server.py @@ -130,10 +130,14 @@ def test_chat_completion_stream(args): if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--base-url", type=str, default="http://127.0.0.1:30000/v1") + parser.add_argument( + "--test-image", action="store_true", help="Enables testing image inputs" + ) args = parser.parse_args() test_completion(args) test_completion_stream(args) test_chat_completion(args) - test_chat_completion_image(args) test_chat_completion_stream(args) + if args.test_image: + test_chat_completion_image(args) From 414044dcb2fc20062346a469e405ba28f379e934 Mon Sep 17 00:00:00 2001 From: Keith Stevens Date: Tue, 30 Jan 2024 01:53:05 +0000 Subject: [PATCH 4/4] Adding 503 when using structured requests and no chat template --- python/sglang/srt/server.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 7e06f3ba5c..b36ed55b39 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -16,7 +16,7 @@ import requests import uvicorn import uvloop -from fastapi import FastAPI, Request +from fastapi import FastAPI, HTTPException, Request from fastapi.responses import Response, StreamingResponse from sglang.backend.runtime_endpoint import RuntimeEndpoint from sglang.srt.conversation import ( @@ -198,6 +198,14 @@ async def v1_chat_completions(raw_request: Request): if not isinstance(request.messages, str): # Apply chat template and its stop strings. 
if chat_template_name is None: + # This flow doesn't support the full OpenAI spec. Verify messages + # has the right type before proceeding: + for m in request.messages: + if not isinstance(m.content, str): + raise HTTPException( + status_code=503, + detail="Structured content requests not supported with HuggingFace Chat Templates. Make sure the server specifies a sglang chat template.", + ) prompt = tokenizer_manager.tokenizer.apply_chat_template( request.messages, tokenize=False, add_generation_prompt=True )