From da0183f21035916901763ce3571ca50befe06a95 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Sun, 21 Jan 2024 22:55:40 +0000
Subject: [PATCH] Fix openai backend

---
 README.md                         |  3 ++-
 python/pyproject.toml             |  9 +++++----
 python/sglang/backend/openai.py   | 10 +++++++++-
 python/sglang/lang/interpreter.py | 20 ++++++++++----------
 4 files changed, 26 insertions(+), 16 deletions(-)

diff --git a/README.md b/README.md
index 9292e04a17..6fe3689589 100644
--- a/README.md
+++ b/README.md
@@ -164,7 +164,8 @@ def image_qa(s, image_file, question):
 ```
 
 ### Constrained Decoding
-Use `regex=` to specify a regular expression as a decoding constraint.
+Use `regex` to specify a regular expression as a decoding constraint.
+This is only supported for local models.
 
 ```python
 @sgl.function
diff --git a/python/pyproject.toml b/python/pyproject.toml
index 1966470a24..0cf288d608 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -18,10 +18,11 @@ dependencies = [
 ]
 
 [project.optional-dependencies]
-srt = ["fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn", "zmq", "vllm>=0.2.5",
-       "interegular", "lark", "numba", "pydantic", "diskcache", "cloudpickle"]
-openai = ["openai>=1.0"]
-anthropic = ["anthropic"]
+srt = ["fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn",
+       "zmq", "vllm>=0.2.5", "interegular", "lark", "numba",
+       "pydantic", "diskcache", "cloudpickle"]
+openai = ["openai>=1.0", "numpy"]
+anthropic = ["anthropic", "numpy"]
 all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
 
 [project.urls]
diff --git a/python/sglang/backend/openai.py b/python/sglang/backend/openai.py
index c0908c15ec..d34605ecd4 100644
--- a/python/sglang/backend/openai.py
+++ b/python/sglang/backend/openai.py
@@ -77,7 +77,9 @@ def generate(
     ):
         if sampling_params.dtype is None:
             if self.is_chat_model:
-                assert s.text_.endswith("ASSISTANT:")
+                if not s.text_.endswith("ASSISTANT:"):
+                    raise RuntimeError("This use case is not supported. "
+                        "For OpenAI chat models, sgl.gen must be right after sgl.assistant")
                 prompt = s.messages_
             else:
                 prompt = s.text_
@@ -149,6 +151,12 @@ def select(
         choices: List[str],
         temperature: float,
     ):
+        if self.is_chat_model:
+            raise NotImplementedError(
+                "select/choices is not supported for chat models. "
+                "Please try to use a non-chat model such as gpt-3.5-turbo-instruct"
+            )
+
         n_choices = len(choices)
         token_ids = [self.tokenizer.encode(x) for x in choices]
         scores = [0] * n_choices
diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py
index f7427ddeab..2803eedae9 100644
--- a/python/sglang/lang/interpreter.py
+++ b/python/sglang/lang/interpreter.py
@@ -197,16 +197,7 @@ def __init__(
         self.stream_var_event = None
 
     def submit(self, expr: SglExpr):
-        if isinstance(expr, (SglGen, SglSelect, SglVarScopeBegin)):
-            self.variable_event[expr.name] = threading.Event()
-            if self.stream:
-                self.stream_var_event[expr.name] = threading.Event()
-        elif isinstance(expr, SglExprList):
-            for e in expr.expr_list:
-                if isinstance(e, (SglGen, SglSelect, SglVarScopeBegin)):
-                    self.variable_event[e.name] = threading.Event()
-                    if self.stream:
-                        self.stream_var_event[e.name] = threading.Event()
+        self._init_var_event(expr)
 
         if self.use_thread:
             self.queue.put(expr)
@@ -467,6 +458,15 @@ def _execute_concatenate_and_append_kv_cache(self, expr: SglConcateAndAppend):
         src_rids = [state.stream_executor.sid for state in expr.states]
         self.backend.concatenate_and_append(src_rids, self.sid)
 
+    def _init_var_event(self, expr):
+        if isinstance(expr, (SglGen, SglSelect, SglVarScopeBegin)):
+            self.variable_event[expr.name] = threading.Event()
+            if self.stream:
+                self.stream_var_event[expr.name] = threading.Event()
+        elif isinstance(expr, SglExprList):
+            for e in expr.expr_list:
+                self._init_var_event(e)
+
     def _resolve_sampling_params(self, sampling_params):
         clone = None
         for item in [