Commit c85cf2a: add openai api

Ying1123 committed Aug 9, 2024
1 parent e040a24 commit c85cf2a

Showing 8 changed files with 135 additions and 19 deletions.
10 changes: 6 additions & 4 deletions python/sglang/srt/managers/io_struct.py

@@ -194,7 +194,8 @@ def post_init(self):
         if is_single:
             if self.rid is None:
                 self.rid = uuid.uuid4().hex
-            self.sampling_params = {"max_new_tokens": 0}
+            if self.sampling_params is None:
+                self.sampling_params = {"max_new_tokens": 1}
         else:
             # support select operation
             self.batch_size = (
@@ -205,9 +206,10 @@ def post_init(self):
             else:
                 if not isinstance(self.rid, list):
                     raise ValueError("The rid should be a list.")
-            self.sampling_params = [
-                {"max_new_tokens": 0} for _ in range(self.batch_size)
-            ]
+            if self.sampling_params is None:
+                self.sampling_params = [
+                    {"max_new_tokens": 1} for _ in range(self.batch_size)
+                ]


 @dataclass
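Why the default moved from 0 to 1: the tp_worker.py change below appends a single dummy output token for embedding requests, and a budget of max_new_tokens=1 is exactly what lets the scheduler's check_finished() mark such a request done right after prefill. The new None guard also stops post_init from clobbering sampling params a caller passed explicitly. A minimal standalone sketch of the resulting defaulting rule (the function name and signature are illustrative, not part of the commit):

def default_sampling_params(params, is_single, batch_size):
    # Caller-supplied params now survive post_init untouched.
    if params is not None:
        return params
    if is_single:
        # One dummy token for embedding models (see tp_worker.py below).
        return {"max_new_tokens": 1}
    return [{"max_new_tokens": 1} for _ in range(batch_size)]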
1 change: 1 addition & 0 deletions python/sglang/srt/managers/tokenizer_manager.py

@@ -262,6 +262,7 @@ async def _handle_single_request(
             ):
                 yield response
         else:
+            assert self.is_generation
             await self._wait_for_cache_prefill_response(event, state, obj, rid, request)
             yield input_ids
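The assert changes no behavior; it makes explicit that the cache-prefill branch is reachable only for generation models, so an embedding request that strays into it now fails loudly instead of proceeding silently.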
2 changes: 2 additions & 0 deletions python/sglang/srt/managers/tp_worker.py

@@ -499,6 +499,8 @@ def forward_prefill_batch(self, batch: ScheduleBatch):
                 req.embedding = embeddings[i]
                 if req is not self.current_inflight_req:
                     # Inflight reqs' prefill is not finished
+                    # dummy output token for embedding models
+                    req.output_ids.append(0)
                     req.check_finished()

                 if req.finished():
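Together with the io_struct.py change above, every embedding request now produces exactly the one output token its default max_new_tokens=1 budget expects, so check_finished() can mark it finished as soon as prefill has produced the embedding.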
32 changes: 21 additions & 11 deletions python/sglang/srt/openai_api/adapter.py

@@ -34,7 +34,7 @@
     generate_chat_conv,
     register_conv_template,
 )
-from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
 from sglang.srt.openai_api.protocol import (
     BatchRequest,
     BatchResponse,
@@ -52,6 +52,7 @@
     CompletionResponseStreamChoice,
     CompletionStreamResponse,
     DeltaMessage,
+    EmbeddingObject,
     EmbeddingRequest,
     EmbeddingResponse,
     ErrorResponse,
@@ -1016,10 +1017,10 @@ async def generate_stream_resp():
 def v1_embedding_request(all_requests, tokenizer_manager):
     prompts = []
     sampling_params_list = []
-    first_prompt_type = type(all_requests[0].prompt)
+    first_prompt_type = type(all_requests[0].input)

     for request in all_requests:
-        prompt = request.prompt
+        prompt = request.input
         assert (
             type(prompt) == first_prompt_type
         ), "All prompts must be of the same type in file input settings"
@@ -1046,17 +1047,26 @@ def v1_embedding_request(all_requests, tokenizer_manager):
     return adapted_request, all_requests


-def v1_embedding_response(request, ret, to_file=False):
-    response = []
+def v1_embedding_response(ret, model_path, to_file=False):
+    embedding_objects = []
+    prompt_tokens = 0
     for idx, ret_item in enumerate(ret):
-        response.append(
-            EmbeddingResponse(
+        embedding_objects.append(
+            EmbeddingObject(
+                embedding=ret[idx]["embedding"],
                 index=idx,
-                embedding=ret[idx],
+                object="embedding",
             )
         )
-    return response
+        prompt_tokens += ret[idx]["meta_info"]["prompt_tokens"]
+
+    return EmbeddingResponse(
+        data=embedding_objects,
+        model=model_path,
+        usage=UsageInfo(
+            prompt_tokens=prompt_tokens,
+            total_tokens=prompt_tokens,
+        ),
+    )


 async def v1_embeddings(tokenizer_manager, raw_request: Request):
@@ -1074,7 +1084,7 @@ async def v1_embeddings(tokenizer_manager, raw_request: Request):
     if not isinstance(ret, list):
         ret = [ret]

-    response = v1_embedding_response(request, ret)
+    response = v1_embedding_response(ret, tokenizer_manager.model_path)

     return response
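The shape of the engine results that v1_embedding_response consumes can be read off the subscripts above: each item carries an "embedding" vector plus a "meta_info" dict with a per-request token count. A hedged sketch with illustrative values:

# Illustrative input; only the keys used above are assumed.
ret = [
    {
        "embedding": [0.0123, -0.0456, 0.0789],
        "meta_info": {"prompt_tokens": 6},
    }
]
response = v1_embedding_response(ret, "intfloat/e5-mistral-7b-instruct")
# response.data[0].index == 0 and response.usage.prompt_tokens == 6;
# total_tokens equals prompt_tokens because embeddings decode no tokens.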
12 changes: 9 additions & 3 deletions python/sglang/srt/openai_api/protocol.py

@@ -319,8 +319,14 @@ class EmbeddingRequest(BaseModel):
     user: Optional[str] = None


-class EmbeddingResponse(BaseModel):
-    index: str
-    embedding: List[float] = None
+class EmbeddingObject(BaseModel):
+    embedding: List[float]
+    index: int
+    object: str = "embedding"
+
+
+class EmbeddingResponse(BaseModel):
+    data: List[EmbeddingObject]
+    model: str
+    object: str = "list"
+    usage: Optional[UsageInfo] = None
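Serialized, the new models produce an OpenAI-style embeddings payload along these lines (values illustrative; UsageInfo may expose further optional fields beyond the two the adapter sets):

{
    "object": "list",
    "model": "intfloat/e5-mistral-7b-instruct",
    "data": [
        {"object": "embedding", "index": 0, "embedding": [0.0123, -0.0456]}
    ],
    "usage": {"prompt_tokens": 6, "total_tokens": 6},
}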
9 changes: 8 additions & 1 deletion python/sglang/srt/server.py

@@ -60,6 +60,7 @@
     v1_chat_completions,
     v1_completions,
     v1_delete_file,
+    v1_embeddings,
     v1_files_create,
     v1_retrieve_batch,
     v1_retrieve_file,
@@ -174,6 +175,12 @@ async def openai_v1_chat_completions(raw_request: Request):
     return await v1_chat_completions(tokenizer_manager, raw_request)


+@app.post("/v1/embeddings")
+async def openai_v1_embeddings(raw_request: Request):
+    response = await v1_embeddings(tokenizer_manager, raw_request)
+    return response
+
+
 @app.get("/v1/models")
 def available_models():
     """Show available models."""
@@ -406,7 +413,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):

     # Send a warmup request
     request_name = "/generate" if model_info["is_generation"] else "/encode"
-    max_new_tokens = 8 if model_info["is_generation"] else 0
+    max_new_tokens = 8 if model_info["is_generation"] else 1
     try:
         for _ in range(server_args.dp_size):
             res = requests.post(
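With the route registered, the endpoint can be exercised over plain HTTP; a hypothetical smoke test (host, port, and key are illustrative, not part of this commit):

import requests

resp = requests.post(
    "http://127.0.0.1:8157/v1/embeddings",
    headers={"Authorization": "Bearer sk-123456"},
    json={
        "model": "intfloat/e5-mistral-7b-instruct",
        "input": "The capital of France is",
    },
)
print(resp.json()["data"][0]["embedding"][:4])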
1 change: 1 addition & 0 deletions test/srt/run_suite.py

@@ -6,6 +6,7 @@
 suites = {
     "minimal": [
         "test_eval_accuracy.py",
+        "test_embedding_openai_server.py",
         "test_openai_server.py",
         "test_vision_openai_server.py",
         "test_chunked_prefill.py",
87 changes: 87 additions & 0 deletions test/srt/test_embedding_openai_server.py

@@ -0,0 +1,87 @@
+import json
+import time
+import unittest
+
+import openai
+
+from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.openai_api.protocol import EmbeddingObject
+from sglang.srt.utils import kill_child_process
+from sglang.test.test_utils import popen_launch_server
+
+
+class TestOpenAIServer(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        cls.model = "intfloat/e5-mistral-7b-instruct"
+        cls.base_url = "http://127.0.0.1:8157"
+        cls.api_key = "sk-123456"
+        cls.process = popen_launch_server(
+            cls.model, cls.base_url, timeout=300, api_key=cls.api_key
+        )
+        cls.base_url += "/v1"
+        cls.tokenizer = get_tokenizer(cls.model)
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_child_process(cls.process.pid)
+
+    def run_embedding(self, use_list_input, token_input):
+        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+        prompt = "The capital of France is"
+        if token_input:
+            prompt_input = self.tokenizer.encode(prompt)
+            num_prompt_tokens = len(prompt_input)
+        else:
+            prompt_input = prompt
+            num_prompt_tokens = len(self.tokenizer.encode(prompt))
+
+        if use_list_input:
+            prompt_arg = [prompt_input, prompt_input]
+            num_prompts = len(prompt_arg)
+        else:
+            prompt_arg = prompt_input
+            num_prompts = 1
+
+        response = client.embeddings.create(
+            input=prompt_arg,
+            model=self.model,
+        )
+
+        assert len(response.data) == num_prompts
+        assert isinstance(response.data, list)
+        assert response.data[0].embedding
+        assert response.data[0].index is not None
+        assert response.data[0].object == "embedding"
+        assert response.model == self.model
+        assert response.object == "list"
+        assert (
+            response.usage.prompt_tokens == num_prompt_tokens
+        ), f"{response.usage.prompt_tokens} vs {num_prompt_tokens}"
+        assert (
+            response.usage.total_tokens == num_prompt_tokens
+        ), f"{response.usage.total_tokens} vs {num_prompt_tokens}"
+
+    def run_batch(self):
+        # FIXME not implemented
+        pass
+
+    def test_embedding(self):
+        # TODO the fields of encoding_format, dimensions, user are skipped
+        # TODO support use_list_input
+        for use_list_input in [False]:
+            for token_input in [False, True]:
+                self.run_embedding(use_list_input, token_input)
+
+    def test_batch(self):
+        self.run_batch()
+
+
+if __name__ == "__main__":
+    unittest.main(warnings="ignore")
+
+    # t = TestOpenAIServer()
+    # t.setUpClass()
+    # t.test_embedding()
+    # t.tearDownClass()

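The __main__ block lets the file run standalone, e.g. python3 test/srt/test_embedding_openai_server.py, in addition to being picked up via the minimal suite registered in run_suite.py above.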