From a8f482dd47cca999e9cf8f2fc158e4caae3d9b65 Mon Sep 17 00:00:00 2001 From: comaniac Date: Fri, 9 Feb 2024 17:51:06 +0000 Subject: [PATCH 1/2] Support extra field regex in OpenAI API --- python/sglang/srt/managers/openai_protocol.py | 6 +++++ python/sglang/srt/server.py | 2 ++ test/srt/test_openai_server.py | 24 +++++++++++++++++++ 3 files changed, 32 insertions(+) diff --git a/python/sglang/srt/managers/openai_protocol.py b/python/sglang/srt/managers/openai_protocol.py index 320eab42b76..1cf1fed73f7 100644 --- a/python/sglang/srt/managers/openai_protocol.py +++ b/python/sglang/srt/managers/openai_protocol.py @@ -36,6 +36,9 @@ class CompletionRequest(BaseModel): logit_bias: Optional[Dict[str, float]] = None user: Optional[str] = None + # Extra parameters for SRT backend only and will be ignored by OpenAI models. + regex: Optional[str] = None + class CompletionResponseChoice(BaseModel): index: int @@ -119,6 +122,9 @@ class ChatCompletionRequest(BaseModel): user: Optional[str] = None best_of: Optional[int] = None + # Extra parameters for SRT backend only and will be ignored by OpenAI models. + regex: Optional[str] = None + class ChatMessage(BaseModel): role: Optional[str] = None diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 0a9e6d24bbe..e5b066769f0 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -151,6 +151,7 @@ async def v1_completions(raw_request: Request): "top_p": request.top_p, "presence_penalty": request.presence_penalty, "frequency_penalty": request.frequency_penalty, + "regex": request.regex, }, return_logprob=request.logprobs is not None, stream=request.stream, @@ -304,6 +305,7 @@ async def v1_chat_completions(raw_request: Request): "top_p": request.top_p, "presence_penalty": request.presence_penalty, "frequency_penalty": request.frequency_penalty, + "regex": request.regex, }, stream=request.stream, ) diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py index 01aa53e5bc6..58f1d21e86f 100644 --- a/test/srt/test_openai_server.py +++ b/test/srt/test_openai_server.py @@ -14,6 +14,7 @@ """ import argparse +import json import openai @@ -151,6 +152,29 @@ def test_chat_completion_stream(args): print() +def test_regex(args): + client = openai.Client(api_key="EMPTY", base_url=args.base_url) + + regex = (r"""\{\n""" + + r""" "name": "[\w]+",\n""" + + r""" "population": "[\w\d\s]+"\n""" + + r"""\}""" + ) + + response = client.chat.completions.create( + model="default", + messages=[ + {"role": "system", "content": "You are a helpful AI assistant"}, + {"role": "user", "content": "Introduce the capital of France."}, + ], + temperature=0, + max_tokens=128, + extra_body={"regex": regex}, + ) + text = response.choices[0].message.content + print(json.loads(text)) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--base-url", type=str, default="http://127.0.0.1:30000/v1") From f58b2d016f489009b1a2cbe1c87d9bded880dbe9 Mon Sep 17 00:00:00 2001 From: comaniac Date: Fri, 9 Feb 2024 18:02:36 +0000 Subject: [PATCH 2/2] run test --- test/srt/test_openai_server.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py index 58f1d21e86f..2bff16960e2 100644 --- a/test/srt/test_openai_server.py +++ b/test/srt/test_openai_server.py @@ -193,5 +193,6 @@ def test_regex(args): test_completion_stream(args, echo=True, logprobs=True) test_chat_completion(args) test_chat_completion_stream(args) + test_regex(args) if args.test_image: test_chat_completion_image(args)