diff --git a/examples/frontend_language/quick_start/shortfin_example_chat.py b/examples/frontend_language/quick_start/shortfin_example_chat.py
new file mode 100644
index 0000000000..3e943477dc
--- /dev/null
+++ b/examples/frontend_language/quick_start/shortfin_example_chat.py
@@ -0,0 +1,65 @@
+"""
+Usage:
+# Prior to running this script, you need to have a Shortfin server running.
+# Build:
+# https://github.com/nod-ai/SHARK-Platform/blob/main/shortfin/README.md
+# Run:
+# https://github.com/nod-ai/SHARK-Platform/blob/main/shortfin/python/shortfin_apps/llm/README.md
+
+python3 shortfin_example_chat.py --base_url <base_url>
+"""
+
+import argparse
+
+import sglang as sgl
+
+
+@sgl.function
+def multi_turn_question(s, question_1, question_2):
+    s += sgl.system("You are a helpful assistant.")
+    s += sgl.user(question_1)
+    s += sgl.assistant(sgl.gen("answer_1", max_tokens=256))
+    s += sgl.user(question_2)
+    s += sgl.assistant(sgl.gen("answer_2", max_tokens=256))
+
+
+def single():
+    state = multi_turn_question.run(
+        question_1="What is the capital of the United States?",
+        question_2="List two local attractions.",
+    )
+
+    for m in state.messages():
+        print(m["role"], ":", m["content"])
+
+    print("\n-- answer_1 --\n", state["answer_1"])
+
+
+def stream():
+    state = multi_turn_question.run(
+        question_1="What is the capital of the United States?",
+        question_2="List two local attractions.",
+        stream=True,
+    )
+
+    for out in state.text_iter():
+        print(out, end="", flush=True)
+    print()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--base_url", default="http://localhost:8000")
+    args = parser.parse_args()
+    base_url = args.base_url
+
+    backend = sgl.Shortfin(base_url=base_url)
+    sgl.set_default_backend(backend)
+
+    # Run a single request
+    print("\n========== single ==========\n")
+    single()
+
+    # Stream output
+    print("\n========== stream ==========\n")
+    stream()
diff --git a/python/sglang/__init__.py b/python/sglang/__init__.py
index 3c4457c983..c4d5108c57 100644
--- a/python/sglang/__init__.py
+++ b/python/sglang/__init__.py
@@ -74,5 +74,6 @@
 LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
 OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
 VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI")
+Shortfin = LazyImport("sglang.lang.backend.shortfin", "Shortfin")
 
-__all__ += ["Anthropic", "LiteLLM", "OpenAI", "VertexAI", "RuntimeEndpoint"]
+__all__ += ["Anthropic", "LiteLLM", "OpenAI", "VertexAI", "Shortfin", "RuntimeEndpoint"]
diff --git a/python/sglang/lang/backend/shortfin.py b/python/sglang/lang/backend/shortfin.py
new file mode 100644
index 0000000000..d67a49e0d1
--- /dev/null
+++ b/python/sglang/lang/backend/shortfin.py
@@ -0,0 +1,86 @@
+from typing import Optional
+
+from sglang.lang.backend.base_backend import BaseBackend
+from sglang.lang.chat_template import get_chat_template_by_model_path
+from sglang.lang.interpreter import StreamExecutor
+from sglang.lang.ir import SglSamplingParams
+from sglang.utils import http_request
+
+
+class Shortfin(BaseBackend):
+    """Backend for a Shortfin (SHARK-Platform) LLM server reached over HTTP."""
+
+    def __init__(
+        self,
+        chat_template=None,
+        base_url: Optional[str] = None,
+        timeout: Optional[float] = None,
+    ):
+        super().__init__()
+
+        if base_url is None:
+            raise ValueError("`base_url` is required for Shortfin backend")
+
+        self.chat_template = chat_template or get_chat_template_by_model_path("default")
+
+        self.client_params = {"base_url": base_url, "timeout": timeout}
+
+    def _make_generate_request(self, shortfin_kwargs, stream=False):
+        """POST to the server's /generate endpoint and fail fast on non-200."""
+        resp = http_request(
+            f"{self.client_params['base_url']}/generate",
+            json=shortfin_kwargs,
+            timeout=self.client_params["timeout"],
+            stream=stream,
+        )
+        self._assert_success(resp)
+        return resp
+
+    def _assert_success(self, res):
+        if res.status_code != 200:
+            raise RuntimeError(res.json())
+
+    def _clean_response_message(self, text):
+        # Strip the SSE "data: " framing and trailing newlines from raw bytes.
+        return text.replace(b"data: ", b"").strip(b"\n")
+
+    def get_chat_template(self):
+        return self.chat_template
+
+    def generate(
+        self,
+        s: StreamExecutor,
+        sampling_params: SglSamplingParams,
+    ):
+        """Run one non-streaming generation; returns (text, meta_info)."""
+        shortfin_kwargs = sampling_params.to_shortfin_kwargs()
+        shortfin_kwargs["text"] = s.text_
+
+        resp = self._make_generate_request(shortfin_kwargs)
+
+        response_message = self._clean_response_message(resp.resp.read())
+        return response_message.decode("utf-8"), {}
+
+    def generate_stream(
+        self,
+        s: StreamExecutor,
+        sampling_params: SglSamplingParams,
+    ):
+        """Stream a generation, yielding (text_delta, meta_info) chunks."""
+        shortfin_kwargs = sampling_params.to_shortfin_kwargs()
+        shortfin_kwargs["stream"] = True
+        shortfin_kwargs["text"] = s.text_
+
+        resp = self._make_generate_request(shortfin_kwargs, stream=True)
+
+        # NOTE(review): the prefix bookkeeping assumes each SSE event repeats the
+        # cumulative text so far; only the unseen suffix is yielded — TODO confirm.
+        prefix = b""
+        for chunk in resp:
+            if chunk == b"data: [DONE]\n\n":
+                break
+            text = chunk[len(prefix) :]
+            prefix += text.strip(b"\n")
+            text = self._clean_response_message(text)
+            if text:  # a bytes slice is never None; skip empty deltas instead
+                yield text.decode("utf-8"), {}
diff --git a/python/sglang/lang/ir.py b/python/sglang/lang/ir.py
index d3c010108e..86b5a6e7ae 100644
--- a/python/sglang/lang/ir.py
+++ b/python/sglang/lang/ir.py
@@ -112,6 +112,23 @@ def to_litellm_kwargs(self):
             "presence_penalty": self.presence_penalty,
         }
 
+    def to_shortfin_kwargs(self):
+        kwargs = {
+            "return_logprob": self.return_logprob,
+            "logprob_start_len": self.logprob_start_len,
+            "top_logprobs_num": self.top_logprobs_num,
+        }
+        kwargs["return_text_in_logprobs"] = (
+            self.return_text_in_logprobs
+            if self.return_text_in_logprobs is not None
+            else False
+        )
+        kwargs["sampling_params"] = {
+            "max_tokens": self.max_new_tokens,
+            "temperature": self.temperature,
+        }
+        return kwargs
+
     def to_srt_kwargs(self):
         return {
             "max_new_tokens": self.max_new_tokens,
diff --git a/python/sglang/utils.py b/python/sglang/utils.py
index 139a01c42e..b4fc37b2a7 100644
--- a/python/sglang/utils.py
+++ b/python/sglang/utils.py
@@ -79,7 +79,7 @@ def status_code(self):
         return self.resp.status
 
 
-def http_request(url, json=None, stream=False, api_key=None, verify=None):
+def http_request(url, json=None, stream=False, api_key=None, verify=None, timeout=None):
     """A faster version of requests.post with low-level urllib API."""
 
     headers = {"Content-Type": "application/json; charset=utf-8"}
@@ -88,7 +88,9 @@ def http_request(url, json=None, stream=False, api_key=None, verify=None):
         headers["Authorization"] = f"Bearer {api_key}"
 
     if stream:
-        return requests.post(url, json=json, stream=True, headers=headers)
+        return requests.post(
+            url, json=json, stream=True, headers=headers, timeout=timeout
+        )
     else:
         req = urllib.request.Request(url, headers=headers)
         if json is None:
@@ -97,7 +99,9 @@ def http_request(url, json=None, stream=False, api_key=None, verify=None):
             data = bytes(dumps(json), encoding="utf-8")
 
         try:
-            resp = urllib.request.urlopen(req, data=data, cafile=verify)
+            resp = urllib.request.urlopen(
+                req, data=data, cafile=verify, timeout=timeout
+            )
             return HttpResponse(resp)
         except urllib.error.HTTPError as e:
             return HttpResponse(e)
diff --git a/test/lang/test_shortfin_backend.py b/test/lang/test_shortfin_backend.py
new file mode 100644
index 0000000000..806a358a38
--- /dev/null
+++ b/test/lang/test_shortfin_backend.py
@@ -0,0 +1,25 @@
+import os
+import unittest
+
+from sglang import Shortfin, set_default_backend
+from sglang.test.test_programs import test_mt_bench, test_stream
+
+
+class TestShortfinBackend(unittest.TestCase):
+    chat_backend = None
+
+    @classmethod
+    def setUpClass(cls):
+        base_url = os.environ["SHORTFIN_BASE_URL"]
+        cls.chat_backend = Shortfin(base_url=base_url)
+        set_default_backend(cls.chat_backend)
+
+    def test_mt_bench(self):
+        test_mt_bench()
+
+    def test_stream(self):
+        test_stream()
+
+
+if __name__ == "__main__":
+    unittest.main()