[Feature] Adds basic support for image content in OpenAI chat routes #113

Merged (5 commits) on Jan 30, 2024
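For context, here is a minimal request that exercises the new image support, adapted from the test added in this PR (`test/srt/test_openai_server.py`). The base URL, model name, and image URL below are placeholders for a locally launched LLaVA server:

```python
import openai

# Assumes an sglang server launched with a LLaVA model and an sglang chat
# template (see the README change below), listening on port 30000.
client = openai.Client(api_key="EMPTY", base_url="http://127.0.0.1:30000/v1")

response = client.chat.completions.create(
    model="default",
    messages=[
        {"role": "system", "content": "You are a helpful AI assistant"},
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image"},
                {
                    "type": "image_url",
                    "image_url": {"url": "https://example.com/some-image.jpg"},
                },
            ],
        },
    ],
    temperature=0,
    max_tokens=32,
)
print(response.choices[0].message.content)
```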
2 changes: 1 addition & 1 deletion README.md
@@ -323,7 +323,7 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
- Mistral
- Mixtral
- LLaVA
- `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000`
- `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000`
- Qwen / Qwen 2
- AWQ quantization

28 changes: 23 additions & 5 deletions python/sglang/srt/conversation.py
@@ -2,7 +2,7 @@
# https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
import dataclasses
from enum import IntEnum, auto
from typing import Dict, List, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union

from sglang.srt.managers.openai_protocol import ChatCompletionRequest

@@ -52,6 +52,7 @@ class Conversation:
sep2: str = None
# Stop criteria (the default one is EOS token)
stop_str: Union[str, List[str]] = None
image_data: Optional[List[str]] = None

def get_prompt(self) -> str:
"""Get the prompt for generation."""
@@ -251,6 +252,10 @@ def append_message(self, role: str, message: str):
"""Append a new message."""
self.messages.append([role, message])

def append_image(self, image: str):
"""Append a new message."""
self.image_data.append(image)

def update_last_message(self, message: str):
"""Update the last output.

@@ -341,18 +346,31 @@ def generate_chat_conv(
sep=conv.sep,
sep2=conv.sep2,
stop_str=conv.stop_str,
image_data=[],
)

if isinstance(request.messages, str):
raise ValueError("The messages should be a list of dict.")
for message in request.messages:
msg_role = message["role"]
msg_role = message.role
if msg_role == "system":
conv.system_message = message["content"]
conv.system_message = message.content
elif msg_role == "user":
conv.append_message(conv.roles[0], message["content"])
# Handle the various types of Chat Request content types here.
role = conv.roles[0]
if isinstance(message.content, str):
conv.append_message(conv.roles[0], message.content)
else:
real_content = ""
for content in message.content:
if content.type == "text":
real_content += content.text
elif content.type == "image_url":
real_content += "<image>"
conv.append_image(content.image_url.url)
conv.append_message(conv.roles[0], real_content)
elif msg_role == "assistant":
conv.append_message(conv.roles[1], message["content"])
conv.append_message(conv.roles[1], message.content)
else:
raise ValueError(f"Unknown role: {msg_role}")

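For clarity, a short sketch of how `generate_chat_conv` flattens a mixed text/image user message; it mirrors the unit test added below in `python/sglang/test/test_conversation.py`, and the image URL is a placeholder:

```python
from sglang.srt.conversation import generate_chat_conv
from sglang.srt.managers.openai_protocol import ChatCompletionRequest

request = ChatCompletionRequest(
    model="default",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image"},
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.jpg"}},
            ],
        },
    ],
)

conv = generate_chat_conv(request, "vicuna_v1.1")
# Text parts are concatenated, each image part becomes an "<image>" placeholder
# in the prompt, and the raw URL is collected in conv.image_data.
assert conv.messages[0] == ["USER", "Describe this image<image>"]
assert conv.image_data == ["https://example.com/cat.jpg"]
```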
38 changes: 37 additions & 1 deletion python/sglang/srt/managers/openai_protocol.py
@@ -1,5 +1,6 @@
import time
from typing import Dict, List, Optional, Union
from typing_extensions import Literal

from pydantic import BaseModel, Field

@@ -68,9 +69,44 @@ class CompletionStreamResponse(BaseModel):
usage: UsageInfo


class ChatCompletionMessageGenericParam(BaseModel):
role: Literal["system", "assistant"]
content: str


class ChatCompletionMessageContentTextPart(BaseModel):
type: Literal["text"]
text: str


class ChatCompletionMessageContentImageURL(BaseModel):
url: str
detail: Optional[Literal["auto", "low", "high"]] = "auto"


class ChatCompletionMessageContentImagePart(BaseModel):
type: Literal["image_url"]
image_url: ChatCompletionMessageContentImageURL


ChatCompletionMessageContentPart = Union[
ChatCompletionMessageContentTextPart, ChatCompletionMessageContentImagePart
]


class ChatCompletionMessageUserParam(BaseModel):
role: Literal["user"]
content: Union[str, List[ChatCompletionMessageContentPart]]


ChatCompletionMessageParam = Union[
ChatCompletionMessageGenericParam, ChatCompletionMessageUserParam
]


class ChatCompletionRequest(BaseModel):
model: str
messages: Union[str, List[Dict[str, str]]]
messages: Union[str, List[ChatCompletionMessageParam]]
temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0
n: Optional[int] = 1
11 changes: 8 additions & 3 deletions python/sglang/srt/managers/tokenizer_manager.py
@@ -150,12 +150,17 @@ async def generate_request(self, obj: GenerateReqInput):
if sampling_params.max_new_tokens != 0:
sampling_params.normalize(self.tokenizer)
sampling_params.verify()
if obj.image_data is None:
pixel_values, image_hash, image_size = None, None, None
else:

if isinstance(obj.image_data, list) and len(obj.image_data) > 0:
pixel_values, image_hash, image_size = await self.get_pixel_values(
obj.image_data[0]
)
elif isinstance(obj.image_data, str):
pixel_values, image_hash, image_size = await self.get_pixel_values(
obj.image_data
)
else:
pixel_values, image_hash, image_size = None, None, None
tokenized_obj = TokenizedGenerateReqInput(
rid=rid,
input_text=obj.text,
20 changes: 19 additions & 1 deletion python/sglang/srt/server.py
@@ -16,7 +16,7 @@
import requests
import uvicorn
import uvloop
from fastapi import FastAPI, Request
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import Response, StreamingResponse
from sglang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.conversation import (
@@ -190,16 +190,31 @@ async def v1_chat_completions(raw_request: Request):
# TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
assert request.n == 1

# Prep the data needed for the underlying GenerateReqInput:
# - prompt: The full prompt string.
# - stop: Custom stop tokens.
# - image_data: None or a list of image strings (URLs or base64 strings).
# None skips any image processing in GenerateReqInput.
if not isinstance(request.messages, str):
# Apply chat template and its stop strings.
if chat_template_name is None:
# This flow doesn't support the full OpenAI spec. Verify messages
# has the right type before proceeding:
for m in request.messages:
if not isinstance(m.content, str):
raise HTTPException(
status_code=503,
detail="Structured content requests not supported with HuggingFace Chat Templates. Make sure the server specifies a sglang chat template.",
)
prompt = tokenizer_manager.tokenizer.apply_chat_template(
request.messages, tokenize=False, add_generation_prompt=True
)
stop = request.stop
image_data = None
else:
conv = generate_chat_conv(request, chat_template_name)
prompt = conv.get_prompt()
image_data = conv.image_data
stop = conv.stop_str or []
if request.stop:
if isinstance(request.stop, str):
@@ -210,9 +225,11 @@ async def v1_chat_completions(raw_request: Request):
# Use the raw prompt and stop strings if the messages is already a string.
prompt = request.messages
stop = request.stop
image_data = None

adapted_request = GenerateReqInput(
text=prompt,
image_data=image_data,
sampling_params={
"temperature": request.temperature,
"max_new_tokens": request.max_tokens,
@@ -303,6 +320,7 @@ def launch_server(server_args, pipe_finish_writer):

# Load chat template if needed
if server_args.chat_template is not None:
print(server_args.chat_template)
if not chat_template_exists(server_args.chat_template):
if not os.path.exists(server_args.chat_template):
raise RuntimeError(
46 changes: 46 additions & 0 deletions python/sglang/test/test_conversation.py
@@ -0,0 +1,46 @@
from sglang.srt.conversation import generate_chat_conv
from sglang.srt.managers.openai_protocol import (
ChatCompletionMessageGenericParam,
ChatCompletionMessageContentImagePart,
ChatCompletionMessageContentImageURL,
ChatCompletionMessageContentTextPart,
ChatCompletionMessageUserParam,
ChatCompletionRequest,
)


def test_chat_completion_to_conv_image():
"""Test that we can convert a chat image request to a convo"""
request = ChatCompletionRequest(
model="default",
messages=[
ChatCompletionMessageGenericParam(
role="system", content="You are a helpful AI assistant"
),
ChatCompletionMessageUserParam(
role="user",
content=[
ChatCompletionMessageContentTextPart(
type="text", text="Describe this image"
),
ChatCompletionMessageContentImagePart(
type="image_url",
image_url=ChatCompletionMessageContentImageURL(
url="https://someurl.com"
),
),
],
),
],
)
conv = generate_chat_conv(request, "vicuna_v1.1")
assert conv.messages == [
["USER", "Describe this image<image>"],
["ASSISTANT", None],
]
assert conv.system_message == "You are a helpful AI assistant"
assert conv.image_data == ["https://someurl.com"]


if __name__ == "__main__":
test_chat_completion_to_conv_image()
51 changes: 51 additions & 0 deletions python/sglang/test/test_openai_protocol.py
@@ -0,0 +1,51 @@
from sglang.srt.managers.openai_protocol import (
ChatCompletionMessageGenericParam,
ChatCompletionMessageContentImagePart,
ChatCompletionMessageContentImageURL,
ChatCompletionMessageContentTextPart,
ChatCompletionMessageUserParam,
ChatCompletionRequest,
)


def test_chat_completion_request_image():
"""Test that Chat Completion Requests with images can be converted."""

image_request = {
"model": "default",
"messages": [
{"role": "system", "content": "You are a helpful AI assistant"},
{
"role": "user",
"content": [
{"type": "text", "text": "Describe this image"},
{"type": "image_url", "image_url": {"url": "https://someurl.com"}},
],
},
],
"temperature": 0,
"max_tokens": 64,
}
request = ChatCompletionRequest(**image_request)
assert len(request.messages) == 2
assert request.messages[0] == ChatCompletionMessageGenericParam(
role="system", content="You are a helpful AI assistant"
)
assert request.messages[1] == ChatCompletionMessageUserParam(
role="user",
content=[
ChatCompletionMessageContentTextPart(
type="text", text="Describe this image"
),
ChatCompletionMessageContentImagePart(
type="image_url",
image_url=ChatCompletionMessageContentImageURL(
url="https://someurl.com"
),
),
],
)


if __name__ == "__main__":
test_chat_completion_request_image()
35 changes: 35 additions & 0 deletions test/srt/test_openai_server.py
@@ -71,6 +71,36 @@ def test_chat_completion(args):
assert response.usage.total_tokens > 0


def test_chat_completion_image(args):
client = openai.Client(api_key="EMPTY", base_url=args.base_url)
response = client.chat.completions.create(
model="default",
messages=[
{"role": "system", "content": "You are a helpful AI assistant"},
{
"role": "user",
"content": [
{"type": "text", "text": "Describe this image"},
{
"type": "image_url",
"image_url": {
"url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/mixtral_8x7b.jpg"
},
},
],
},
],
temperature=0,
max_tokens=32,
)
print(response.choices[0].message.content)
assert response.id
assert response.created
assert response.usage.prompt_tokens > 0
assert response.usage.completion_tokens > 0
assert response.usage.total_tokens > 0


def test_chat_completion_stream(args):
client = openai.Client(api_key="EMPTY", base_url=args.base_url)
response = client.chat.completions.create(
@@ -100,9 +130,14 @@ def test_chat_completion_stream(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--base-url", type=str, default="http://127.0.0.1:30000/v1")
parser.add_argument(
"--test-image", action="store_true", help="Enables testing image inputs"
)
args = parser.parse_args()

test_completion(args)
test_completion_stream(args)
test_chat_completion(args)
test_chat_completion_stream(args)
if args.test_image:
test_chat_completion_image(args)