Skip to content

Commit

Permalink
Support extra field regex in OpenAI API (#172)
Browse files Browse the repository at this point in the history
  • Loading branch information
comaniac authored Feb 11, 2024
1 parent 4d303c4 commit 50afed4
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 0 deletions.
6 changes: 6 additions & 0 deletions python/sglang/srt/managers/openai_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ class CompletionRequest(BaseModel):
logit_bias: Optional[Dict[str, float]] = None
user: Optional[str] = None

# Extra parameters for SRT backend only and will be ignored by OpenAI models.
regex: Optional[str] = None


class CompletionResponseChoice(BaseModel):
index: int
Expand Down Expand Up @@ -119,6 +122,9 @@ class ChatCompletionRequest(BaseModel):
user: Optional[str] = None
best_of: Optional[int] = None

# Extra parameters for SRT backend only and will be ignored by OpenAI models.
regex: Optional[str] = None


class ChatMessage(BaseModel):
role: Optional[str] = None
Expand Down
2 changes: 2 additions & 0 deletions python/sglang/srt/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ async def v1_completions(raw_request: Request):
"top_p": request.top_p,
"presence_penalty": request.presence_penalty,
"frequency_penalty": request.frequency_penalty,
"regex": request.regex,
},
return_logprob=request.logprobs is not None,
stream=request.stream,
Expand Down Expand Up @@ -304,6 +305,7 @@ async def v1_chat_completions(raw_request: Request):
"top_p": request.top_p,
"presence_penalty": request.presence_penalty,
"frequency_penalty": request.frequency_penalty,
"regex": request.regex,
},
stream=request.stream,
)
Expand Down
25 changes: 25 additions & 0 deletions test/srt/test_openai_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""

import argparse
import json

import openai

Expand Down Expand Up @@ -151,6 +152,29 @@ def test_chat_completion_stream(args):
print()


def test_regex(args):
client = openai.Client(api_key="EMPTY", base_url=args.base_url)

regex = (r"""\{\n"""
+ r""" "name": "[\w]+",\n"""
+ r""" "population": "[\w\d\s]+"\n"""
+ r"""\}"""
)

response = client.chat.completions.create(
model="default",
messages=[
{"role": "system", "content": "You are a helpful AI assistant"},
{"role": "user", "content": "Introduce the capital of France."},
],
temperature=0,
max_tokens=128,
extra_body={"regex": regex},
)
text = response.choices[0].message.content
print(json.loads(text))


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--base-url", type=str, default="http://127.0.0.1:30000/v1")
Expand All @@ -169,5 +193,6 @@ def test_chat_completion_stream(args):
test_completion_stream(args, echo=True, logprobs=True)
test_chat_completion(args)
test_chat_completion_stream(args)
test_regex(args)
if args.test_image:
test_chat_completion_image(args)

0 comments on commit 50afed4

Please sign in to comment.