Skip to content

Commit

Permalink
Fix RuntimeEndpoint (#279)
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy authored Mar 11, 2024
1 parent d5ae2eb commit 13662fd
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 7 deletions.
3 changes: 3 additions & 0 deletions python/sglang/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,21 @@ def Runtime(*args, **kwargs):
def set_default_backend(backend: BaseBackend):
global_config.default_backend = backend


def flush_cache(backend: BaseBackend = None):
backend = backend or global_config.default_backend
if backend is None:
return False
return backend.flush_cache()


def get_server_args(backend: BaseBackend = None):
backend = backend or global_config.default_backend
if backend is None:
return None
return backend.get_server_args()


def gen(
name: Optional[str] = None,
max_tokens: Optional[int] = None,
Expand Down
14 changes: 10 additions & 4 deletions python/sglang/backend/runtime_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@


class RuntimeEndpoint(BaseBackend):
def __init__(self, base_url, auth_token=None, api_key=None, verify=None):
def __init__(
self,
base_url: str,
auth_token: Optional[str] = None,
api_key: Optional[str] = None,
verify: Optional[str] = None,
):
super().__init__()
self.support_concate_and_append = True

Expand Down Expand Up @@ -61,7 +67,7 @@ def cache_prefix(self, prefix_str: str):
self.base_url + "/generate",
json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
auth_token=self.auth_token,
api_key=self.api_key
api_key=self.api_key,
verify=self.verify,
)
assert res.status_code == 200
Expand All @@ -71,7 +77,7 @@ def commit_lazy_operations(self, s: StreamExecutor):
self.base_url + "/generate",
json={"text": s.text_, "sampling_params": {"max_new_tokens": 0}},
auth_token=self.auth_token,
api_key=self.api_key
api_key=self.api_key,
verify=self.verify,
)
assert res.status_code == 200
Expand Down Expand Up @@ -159,7 +165,7 @@ def generate_stream(
json=data,
stream=True,
auth_token=self.auth_token,
api_key=self.api_key
api_key=self.api_key,
verify=self.verify,
)
pos = 0
Expand Down
6 changes: 4 additions & 2 deletions python/sglang/srt/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
import uvicorn
import uvloop
from fastapi import FastAPI, HTTPException, Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
from fastapi.responses import Response, StreamingResponse
from pydantic import BaseModel
from sglang.backend.runtime_endpoint import RuntimeEndpoint
Expand Down Expand Up @@ -56,11 +54,14 @@
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import handle_port_init
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

API_KEY_HEADER_NAME = "X-API-Key"


class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
def __init__(self, app, api_key: str):
super().__init__(app)
Expand All @@ -77,6 +78,7 @@ async def dispatch(self, request: Request, call_next):
response = await call_next(request)
return response


app = FastAPI()
tokenizer_manager = None
chat_template_name = None
Expand Down
4 changes: 3 additions & 1 deletion python/sglang/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,9 @@ def status_code(self):
return self.resp.status


def http_request(url, json=None, stream=False, auth_token=None, api_key=None, verify=None):
def http_request(
url, json=None, stream=False, auth_token=None, api_key=None, verify=None
):
"""A faster version of requests.post with low-level urllib API."""
headers = {"Content-Type": "application/json; charset=utf-8"}

Expand Down

0 comments on commit 13662fd

Please sign in to comment.