diff --git a/README.md b/README.md
index 9a343316cb..eb0bb485fc 100644
--- a/README.md
+++ b/README.md
@@ -238,9 +238,25 @@ curl http://localhost:30000/generate \
   }
 }'
 ```
-
 Learn more about the argument format [here](docs/sampling_params.md).
 
+### OpenAI Compatible API
+
+In addition, the server supports an experimental OpenAI-compatible API.
+
+```python
+import openai
+client = openai.Client(
+    base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
+response = client.completions.create(
+    model="default",
+    prompt="The capital of France is",
+    temperature=0,
+    max_tokens=32,
+)
+print(response)
+```
+
 ### Additional Arguments
 - Add `--tp 2` to enable tensor parallelism.
 ```
diff --git a/python/pyproject.toml b/python/pyproject.toml
index a356802e65..0df9414608 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -19,7 +19,7 @@ dependencies = [
 
 [project.optional-dependencies]
 srt = ["fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn", "zmq", "vllm>=0.2.5",
-       "interegular", "lark", "numba"]
+       "interegular", "lark", "numba", "pydantic"]
 openai = ["openai>=1.0"]
 anthropic = ["anthropic"]
 all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
diff --git a/python/sglang/backend/runtime_endpoint.py b/python/sglang/backend/runtime_endpoint.py
index 1a58cf75da..2bb449dca6 100644
--- a/python/sglang/backend/runtime_endpoint.py
+++ b/python/sglang/backend/runtime_endpoint.py
@@ -116,9 +116,12 @@ def generate_stream(
         pos = 0
         incomplete_text = ""
 
-        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
-            if chunk:
-                data = json.loads(chunk.decode())
+        for chunk in response.iter_lines(decode_unicode=False):
+            chunk = chunk.decode("utf-8")
+            if chunk and chunk.startswith("data:"):
+                if chunk == "data: [DONE]":
+                    break
+                data = json.loads(chunk[5:].strip("\n"))
                 text = find_printable_text(data["text"][pos:])
                 meta_info = data["meta_info"]
                 pos += len(text)
diff --git a/python/sglang/srt/managers/openai_protocol.py b/python/sglang/srt/managers/openai_protocol.py
index daa4ac9dc1..e80b1441ce 100644
--- a/python/sglang/srt/managers/openai_protocol.py
+++ b/python/sglang/srt/managers/openai_protocol.py
@@ -1,12 +1,67 @@
-from dataclasses import dataclass
-from typing import Any, List, Optional, Union
+import time
+from typing import Dict, List, Optional, Union
 
+from pydantic import BaseModel, Field
 
-@dataclass
-class CompletionRequest:
-    prompt: Union[str, List[Any]]
-    model: str = "default"
-    temperature: Optional[float] = 0.7
+
+class LogProbs(BaseModel):
+    text_offset: List[int] = Field(default_factory=list)
+    token_logprobs: List[Optional[float]] = Field(default_factory=list)
+    tokens: List[str] = Field(default_factory=list)
+    top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
+
+
+class UsageInfo(BaseModel):
+    prompt_tokens: int = 0
+    total_tokens: int = 0
+    completion_tokens: Optional[int] = 0
+
+
+class CompletionRequest(BaseModel):
+    model: str
+    prompt: Union[str, List[str]]
+    suffix: Optional[str] = None
     max_tokens: Optional[int] = 16
+    temperature: Optional[float] = 0.7
+    top_p: Optional[float] = 1.0
     n: Optional[int] = 1
-    stop: Optional[Union[str, List[str]]] = None
+    stream: Optional[bool] = False
+    logprobs: Optional[int] = None
+    echo: Optional[bool] = False
+    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
+    presence_penalty: Optional[float] = 0.0
+    frequency_penalty: Optional[float] = 0.0
+    best_of: Optional[int] = None
+    logit_bias: Optional[Dict[str, float]] = None
+    user: Optional[str] = None
+
+
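+# The models below mirror the OpenAI text completion response objects,
+# covering both the non-streaming and the streaming (SSE chunk) cases.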
+class CompletionResponseChoice(BaseModel):
+    index: int
+    text: str
+    logprobs: Optional[LogProbs] = None
+    finish_reason: Optional[str] = None
+
+
+class CompletionResponse(BaseModel):
+    id: str
+    object: str = "text_completion"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[CompletionResponseChoice]
+    usage: UsageInfo
+
+
+class CompletionResponseStreamChoice(BaseModel):
+    index: int
+    text: str
+    logprobs: Optional[LogProbs] = None
+    finish_reason: Optional[str] = None
+
+
+class CompletionStreamResponse(BaseModel):
+    id: str
+    object: str = "text_completion"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[CompletionResponseStreamChoice]
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index 5f2c2f2890..5a171aa9ad 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -1,7 +1,5 @@
 """SRT: SGLang Runtime"""
-import argparse
 import asyncio
-import dataclasses
 import json
 import multiprocessing as mp
 import sys
@@ -16,12 +14,19 @@ import requests
 import uvicorn
 import uvloop
-from fastapi import FastAPI
+from fastapi import FastAPI, Request
 from fastapi.responses import StreamingResponse
 
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
 from sglang.srt.managers.io_struct import GenerateReqInput
-from sglang.srt.managers.openai_protocol import CompletionRequest
+from sglang.srt.managers.openai_protocol import (
+    CompletionRequest,
+    CompletionResponse,
+    CompletionResponseChoice,
+    CompletionResponseStreamChoice,
+    CompletionStreamResponse,
+    UsageInfo
+)
 from sglang.srt.managers.router.manager import start_router_process
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.server_args import PortArgs, ServerArgs
@@ -41,39 +46,97 @@ async def get_model_info():
     }
     return result
 
 
+async def stream_generator(obj):
+    async for out in tokenizer_manager.generate_request(obj):
+        yield out
+
 @app.post("/generate")
 async def generate_request(obj: GenerateReqInput):
     obj.post_init()
-    result_generator = tokenizer_manager.generate_request(obj)
 
     if obj.stream:
 
         async def stream_results():
-            async for out in result_generator:
-                yield (json.dumps(out) + "\0").encode("utf-8")
-
+            async for out in stream_generator(obj):
+                yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
+            yield "data: [DONE]\n\n"
+
         return StreamingResponse(stream_results(), media_type="text/event-stream")
-    else:
-        ret = await result_generator.__anext__()
-        return ret
+
+    ret = await tokenizer_manager.generate_request(obj).__anext__()
+    return ret
 
 
 @app.post("/v1/completions")
-async def v1_completions(obj: CompletionRequest):
-    assert obj.n == 1
-    obj = GenerateReqInput(
-        text=obj.prompt,
+async def v1_completions(raw_request: Request):
+    request_json = await raw_request.json()
+    request = CompletionRequest(**request_json)
+
+    # TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
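+    # Only n == 1 is supported for now; the OpenAI-style fields are mapped
+    # onto SGLang's internal GenerateReqInput below.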
+    assert request.n == 1
+
+    adapted_request = GenerateReqInput(
+        text=request.prompt,
         sampling_params={
-            "temperature": obj.temperature,
-            "max_new_tokens": obj.max_tokens,
-            "stop": obj.stop,
+            "temperature": request.temperature,
+            "max_new_tokens": request.max_tokens,
+            "stop": request.stop,
+            "top_p": request.top_p,
+            "presence_penalty": request.presence_penalty,
+            "frequency_penalty": request.frequency_penalty,
         },
+        stream=request.stream,
     )
-    ret = await generate_request(obj)
-    return {
-        "choices": [{"text": ret["text"]}],
-    }
+    adapted_request.post_init()
+
+    if adapted_request.stream:
+        async def generate_stream_resp():
+            stream_buffer = ""
+            async for content in stream_generator(adapted_request):
+                text = content["text"]
+                delta = text[len(stream_buffer):]
+                stream_buffer = text
+                choice_data = CompletionResponseStreamChoice(
+                    index=0,
+                    text=delta,
+                    logprobs=None,
+                    finish_reason=None,
+                )
+                chunk = CompletionStreamResponse(
+                    id=content["meta_info"]["id"],
+                    object="text_completion",
+                    choices=[choice_data],
+                    model=request.model,
+                )
+                yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+
+        return StreamingResponse(generate_stream_resp(), media_type="text/event-stream")
+
+
+    # Non-streaming response.
+    ret = await generate_request(adapted_request)
+
+    choice_data = CompletionResponseChoice(
+        index=0,
+        text=ret["text"],
+        logprobs=None,
+        finish_reason=None,  # TODO(comaniac): Add finish reason.
+    )
+
+    prompt_tokens = ret["meta_info"]["prompt_tokens"]
+    completion_tokens = ret["meta_info"]["completion_tokens"]
+    response = CompletionResponse(
+        id=ret["meta_info"]["id"],
+        model=request.model,
+        choices=[choice_data],
+        usage=UsageInfo(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        ),
+    )
+    return response
 
 
 def launch_server(server_args, pipe_finish_writer):
diff --git a/test/srt/test_httpserver_decode_stream.py b/test/srt/test_httpserver_decode_stream.py
index 048ee363f6..e397f137da 100644
--- a/test/srt/test_httpserver_decode_stream.py
+++ b/test/srt/test_httpserver_decode_stream.py
@@ -25,7 +25,7 @@
         "text": "The capital of France is",
         "sampling_params": {
             "temperature": 0,
-            "max_new_tokens": 1024,
+            "max_new_tokens": 512,
         },
         "stream": True,
     },
@@ -33,9 +33,12 @@
 )
 
 prev = 0
-for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
-    if chunk:
-        data = json.loads(chunk.decode())
+for chunk in response.iter_lines(decode_unicode=False):
+    chunk = chunk.decode("utf-8")
+    if chunk and chunk.startswith("data:"):
+        if chunk == "data: [DONE]":
+            break
+        data = json.loads(chunk[5:].strip("\n"))
         output = data["text"].strip()
         print(output[prev:], end="", flush=True)
         prev = len(output)
diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py
new file mode 100644
index 0000000000..f5db747cec
--- /dev/null
+++ b/test/srt/test_openai_server.py
@@ -0,0 +1,54 @@
+"""
+python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
+
+Output:
+The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
+"""
+
+import argparse
+
+import openai
+
+
+def test_completion(args):
+    client = openai.Client(api_key="EMPTY", base_url=args.base_url)
+    response = client.completions.create(
+        model="default",
+        prompt="The capital of France is",
+        temperature=0,
+        max_tokens=32,
+    )
+    print(response.choices[0].text)
+    assert response.id
+    assert response.created
+    assert response.usage.prompt_tokens > 0
+    assert response.usage.completion_tokens > 0
+    assert response.usage.total_tokens > 0
+
+
+def test_completion_stream(args):
+    client = openai.Client(api_key="EMPTY", base_url=args.base_url)
+    response = client.completions.create(
+        model="default",
+        prompt="The capital of France is",
+        temperature=0,
+        max_tokens=32,
+        stream=True,
+    )
+    for r in response:
+        print(r.choices[0].text, end="", flush=True)
+        assert r.id
+        assert r.created
+        assert r.usage.prompt_tokens > 0
+        assert r.usage.completion_tokens > 0
+        assert r.usage.total_tokens > 0
+    print()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--base-url", type=str, default="http://127.0.0.1:30000/v1")
+    args = parser.parse_args()
+
+    test_completion(args)
+    test_completion_stream(args)
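For reference, the new `/v1/completions` route also streams through the standard OpenAI client. A minimal sketch, mirroring `test_completion_stream` above and assuming a server is already running on `http://127.0.0.1:30000` as in the README example:

```python
import openai

# Point the OpenAI client at the local SGLang server.
client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

# With stream=True the server emits CompletionStreamResponse chunks as
# server-sent events; each chunk's choices[0].text holds only the new delta.
stream = client.completions.create(
    model="default",
    prompt="The capital of France is",
    temperature=0,
    max_tokens=32,
    stream=True,
)
for chunk in stream:
    print(chunk.choices[0].text, end="", flush=True)
print()
```

Non-streaming calls go through the same route and return a single `CompletionResponse` with usage counts, as exercised by `test_completion` above.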