Support stream=True in v1/completions #49

Merged (6 commits) on Jan 19, 2024
README.md: 18 changes (17 additions & 1 deletion)
@@ -238,9 +238,25 @@ curl http://localhost:30000/generate \
}
}'
```

Learn more about the argument format [here](docs/sampling_params.md).

### OpenAI Compatible API

In addition, the server supports an experimental OpenAI-compatible API.

```python
import openai
client = openai.Client(
base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
response = client.completions.create(
model="default",
prompt="The capital of France is",
temperature=0,
max_tokens=32,
)
print(response)
```
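Since this PR adds `stream=True` to the completions endpoint, the same call also works in streaming mode. The sketch below mirrors the usage exercised in `test/srt/test_openai_server.py` and assumes the server above is running locally.

```python
import openai

client = openai.Client(
    base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

# With stream=True the client returns an iterator of chunks instead of a
# single response; each chunk carries the newly generated text delta.
response = client.completions.create(
    model="default",
    prompt="The capital of France is",
    temperature=0,
    max_tokens=32,
    stream=True,
)
for chunk in response:
    print(chunk.choices[0].text, end="", flush=True)
print()
```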

### Additional Arguments
- Add `--tp 2` to enable tensor parallelism.
```
python/pyproject.toml: 2 changes (1 addition & 1 deletion)
@@ -19,7 +19,7 @@ dependencies = [

[project.optional-dependencies]
srt = ["fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn", "zmq", "vllm>=0.2.5",
"interegular", "lark", "numba"]
"interegular", "lark", "numba", "pydantic"]
openai = ["openai>=1.0"]
anthropic = ["anthropic"]
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
python/sglang/backend/runtime_endpoint.py: 9 changes (6 additions & 3 deletions)
@@ -116,9 +116,12 @@ def generate_stream(
pos = 0

incomplete_text = ""
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode())
for chunk in response.iter_lines(decode_unicode=False):
chunk = chunk.decode("utf-8")
if chunk and chunk.startswith("data:"):
if chunk == "data: [DONE]":
break
data = json.loads(chunk[5:].strip("\n"))
text = find_printable_text(data["text"][pos:])
meta_info = data["meta_info"]
pos += len(text)
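The change above switches the client-side parser from NUL-delimited JSON chunks to SSE-style `data: ...` frames terminated by `data: [DONE]`, matching what `/generate` now emits. A minimal sketch of consuming that stream directly, assuming the server from the README is running locally, could look like this:

```python
import json

import requests

response = requests.post(
    "http://127.0.0.1:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"temperature": 0, "max_new_tokens": 32},
        "stream": True,
    },
    stream=True,
)

prev = 0
for chunk in response.iter_lines(decode_unicode=False):
    chunk = chunk.decode("utf-8")
    if not chunk.startswith("data:"):
        continue
    if chunk == "data: [DONE]":
        break
    # Each frame carries the cumulative generated text; print only the new part.
    data = json.loads(chunk[len("data:"):].strip())
    text = data["text"]
    print(text[prev:], end="", flush=True)
    prev = len(text)
print()
```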
python/sglang/srt/managers/openai_protocol.py: 71 changes (63 additions & 8 deletions)
@@ -1,12 +1,67 @@
from dataclasses import dataclass
from typing import Any, List, Optional, Union
import time
from typing import Dict, List, Optional, Union

from pydantic import BaseModel, Field

@dataclass
class CompletionRequest:
prompt: Union[str, List[Any]]
model: str = "default"
temperature: Optional[float] = 0.7

class LogProbs(BaseModel):
text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list)
tokens: List[str] = Field(default_factory=list)
top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)


class UsageInfo(BaseModel):
prompt_tokens: int = 0
total_tokens: int = 0
completion_tokens: Optional[int] = 0


class CompletionRequest(BaseModel):
model: str
prompt: Union[str, List[str]]
suffix: Optional[str] = None
max_tokens: Optional[int] = 16
temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0
n: Optional[int] = 1
stop: Optional[Union[str, List[str]]] = None
stream: Optional[bool] = False
logprobs: Optional[int] = None
echo: Optional[bool] = False
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
best_of: Optional[int] = None
logit_bias: Optional[Dict[str, float]] = None
user: Optional[str] = None


class CompletionResponseChoice(BaseModel):
index: int
text: str
logprobs: Optional[LogProbs] = None
finish_reason: Optional[str] = None


class CompletionResponse(BaseModel):
id: str
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[CompletionResponseChoice]
usage: UsageInfo


class CompletionResponseStreamChoice(BaseModel):
index: int
text: str
logprobs: Optional[LogProbs] = None
finish_reason: Optional[str] = None


class CompletionStreamResponse(BaseModel):
id: str
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[CompletionResponseStreamChoice]
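These Pydantic models mirror the OpenAI completions wire format. As a rough illustration (the field values here are made up), a single streaming chunk can be built and rendered into an SSE frame the same way the `/v1/completions` handler in `server.py` does below:

```python
from sglang.srt.managers.openai_protocol import (
    CompletionResponseStreamChoice,
    CompletionStreamResponse,
)

# Hypothetical values, for illustration only.
choice = CompletionResponseStreamChoice(
    index=0, text=" Paris", logprobs=None, finish_reason=None)
chunk = CompletionStreamResponse(
    id="cmpl-123", object="text_completion", model="default", choices=[choice])

# exclude_unset drops fields that were never explicitly set (here `created`),
# matching how the streaming handler serializes each frame.
sse_frame = f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
print(sse_frame)
```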
python/sglang/srt/server.py: 107 changes (85 additions & 22 deletions)
@@ -1,7 +1,5 @@
"""SRT: SGLang Runtime"""
import argparse
import asyncio
import dataclasses
import json
import multiprocessing as mp
import sys
@@ -16,12 +14,19 @@
import requests
import uvicorn
import uvloop
from fastapi import FastAPI
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from sglang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.openai_protocol import CompletionRequest
from sglang.srt.managers.openai_protocol import (
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse,
UsageInfo
)
from sglang.srt.managers.router.manager import start_router_process
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.server_args import PortArgs, ServerArgs
@@ -41,39 +46,97 @@ async def get_model_info():
}
return result

async def stream_generator(obj):
async for out in tokenizer_manager.generate_request(obj):
yield out


@app.post("/generate")
async def generate_request(obj: GenerateReqInput):
obj.post_init()
result_generator = tokenizer_manager.generate_request(obj)

if obj.stream:

async def stream_results():
async for out in result_generator:
yield (json.dumps(out) + "\0").encode("utf-8")

async for out in stream_generator(obj):
yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n"

return StreamingResponse(stream_results(), media_type="text/event-stream")
else:
ret = await result_generator.__anext__()
return ret

ret = await tokenizer_manager.generate_request(obj).__anext__()
return ret


@app.post("/v1/completions")
async def v1_completions(obj: CompletionRequest):
assert obj.n == 1
obj = GenerateReqInput(
text=obj.prompt,
async def v1_completions(raw_request: Request):
request_json = await raw_request.json()
request = CompletionRequest(**request_json)

# TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
assert request.n == 1

adapted_request = GenerateReqInput(
text=request.prompt,
sampling_params={
"temperature": obj.temperature,
"max_new_tokens": obj.max_tokens,
"stop": obj.stop,
"temperature": request.temperature,
"max_new_tokens": request.max_tokens,
"stop": request.stop,
"top_p": request.top_p,
"presence_penalty": request.presence_penalty,
"frequency_penalty": request.frequency_penalty,
},
stream=request.stream,
)
ret = await generate_request(obj)
return {
"choices": [{"text": ret["text"]}],
}
adapted_request.post_init()

if adapted_request.stream:
async def generate_stream_resp():
stream_buffer = ""
async for content in stream_generator(adapted_request):
text = content["text"]
delta = text[len(stream_buffer):]
stream_buffer = text
choice_data = CompletionResponseStreamChoice(
index=0,
text=delta,
logprobs=None,
finish_reason=None,
)
chunk = CompletionStreamResponse(
id=content["meta_info"]["id"],
object="text_completion",
choices=[choice_data],
model=request.model,
)
yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"

return StreamingResponse(generate_stream_resp(), media_type="text/event-stream")


# Non-streaming response.
ret = await generate_request(adapted_request)

choice_data = CompletionResponseChoice(
index=0,
text=ret["text"],
logprobs=None,
finish_reason=None, # TODO(comaniac): Add finish reason.
)

prompt_tokens = ret["meta_info"]["prompt_tokens"]
completion_tokens = ret["meta_info"]["completion_tokens"]
response = CompletionResponse(
id=ret["meta_info"]["id"],
model=request.model,
choices=[choice_data],
usage=UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
),
)
return response


def launch_server(server_args, pipe_finish_writer):
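For reference, the new streaming branch of `/v1/completions` can also be exercised without the OpenAI SDK. This sketch, assuming the same local server, just prints the raw SSE frames, each of which carries a `CompletionStreamResponse` serialized as JSON:

```python
import requests

response = requests.post(
    "http://127.0.0.1:30000/v1/completions",
    json={
        "model": "default",
        "prompt": "The capital of France is",
        "temperature": 0,
        "max_tokens": 32,
        "stream": True,
    },
    stream=True,
)

for line in response.iter_lines(decode_unicode=True):
    # Non-empty lines look like: data: {"id": "...", "choices": [...], ...}
    if line:
        print(line)
```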
test/srt/test_httpserver_decode_stream.py: 11 changes (7 additions & 4 deletions)
@@ -25,17 +25,20 @@
"text": "The capital of France is",
"sampling_params": {
"temperature": 0,
"max_new_tokens": 1024,
"max_new_tokens": 512,
},
"stream": True,
},
stream=True,
)

prev = 0
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode())
for chunk in response.iter_lines(decode_unicode=False):
chunk = chunk.decode("utf-8")
if chunk and chunk.startswith("data:"):
if chunk == "data: [DONE]":
break
data = json.loads(chunk[5:].strip("\n"))
output = data["text"].strip()
print(output[prev:], end="", flush=True)
prev = len(output)
test/srt/test_openai_server.py: 54 changes (54 additions & 0 deletions)
@@ -0,0 +1,54 @@
"""
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000

Output:
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
"""

import argparse

import openai


def test_completion(args):
client = openai.Client(api_key="EMPTY", base_url=args.base_url)
response = client.completions.create(
model="default",
prompt="The capital of France is",
temperature=0,
max_tokens=32,
)
print(response.choices[0].text)
assert response.id
assert response.created
assert response.usage.prompt_tokens > 0
assert response.usage.completion_tokens > 0
assert response.usage.total_tokens > 0


def test_completion_stream(args):
client = openai.Client(api_key="EMPTY", base_url=args.base_url)
response = client.completions.create(
model="default",
prompt="The capital of France is",
temperature=0,
max_tokens=32,
stream=True,
)
for r in response:
print(r.choices[0].text, end="", flush=True)
assert r.id
assert r.created
assert r.usage.prompt_tokens > 0
assert r.usage.completion_tokens > 0
assert r.usage.total_tokens > 0
print()


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--base-url", type=str, default="http://127.0.0.1:30000/v1")
args = parser.parse_args()

test_completion(args)
test_completion_stream(args)