From 7223b3aed5a978b89f3fd8062434b005fd95342f Mon Sep 17 00:00:00 2001 From: Ying Sheng Date: Mon, 18 Nov 2024 06:14:14 +0000 Subject: [PATCH] add session control --- .../srt/managers/detokenizer_manager.py | 1 + python/sglang/srt/managers/io_struct.py | 27 ++++ python/sglang/srt/managers/schedule_batch.py | 3 + python/sglang/srt/managers/scheduler.py | 66 +++++++-- .../sglang/srt/managers/session_controller.py | 62 ++++++++ .../sglang/srt/managers/tokenizer_manager.py | 38 +++++ python/sglang/srt/server.py | 26 ++++ test/srt/test_session_id.py | 133 ++++++++++++++++++ 8 files changed, 348 insertions(+), 8 deletions(-) create mode 100644 python/sglang/srt/managers/session_controller.py create mode 100644 test/srt/test_session_id.py diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 5db8ce4f1a..036ac3a789 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -175,6 +175,7 @@ def event_loop(self): output_strs=output_strs, meta_info=recv_obj.meta_info, finished_reason=recv_obj.finished_reason, + session_ids=recv_obj.session_ids, ) ) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 96009ffb26..f778124191 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -56,6 +56,10 @@ class GenerateReqInput: # LoRA related lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None + # Session id info for continual prompting + session_id: Optional[Union[List[str], str]] = None + session_rid: Optional[Union[List[str], str]] = None + def normalize_batch_and_arguments(self): if (self.text is None and self.input_ids is None) or ( self.text is not None and self.input_ids is not None @@ -200,6 +204,10 @@ class TokenizedGenerateReqInput: # LoRA related lora_path: Optional[str] = None # None means just use the base model + # Session id info for 
continual prompting + session_id: Optional[int] = None + session_rid: Optional[str] = None + @dataclass class EmbeddingReqInput: @@ -293,6 +301,8 @@ class BatchTokenIDOut: meta_info: List[Dict] finished_reason: List[BaseFinishReason] no_stop_trim: List[bool] + # The updated session unique id + session_ids: List[str] @dataclass @@ -305,6 +315,8 @@ class BatchStrOut: meta_info: List[Dict] # The finish reason finished_reason: List[BaseFinishReason] + # The updated session unique id + session_ids: List[str] @dataclass @@ -357,3 +369,18 @@ class GetMemPoolSizeReq: @dataclass class GetMemPoolSizeReqOutput: size: int + + +@dataclass +class OpenSessionReqInput: + capacity_of_str_len: int + + +@dataclass +class CloseSessionReqInput: + session_id: str + + +@dataclass +class OpenSessionReqOutput: + session_id: str diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index ca08a6af30..026b98a333 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -180,6 +180,7 @@ def __init__( origin_input_ids: Tuple[int], sampling_params: SamplingParams, lora_path: Optional[str] = None, + session_id: Optional[str] = None, ): # Input and output info self.rid = rid @@ -188,6 +189,8 @@ def __init__( self.origin_input_ids = origin_input_ids self.output_ids = [] # Each decode stage's output ids self.fill_ids = None # fill_ids = origin_input_ids + output_ids + self.session_id = session_id + self.sampling_params = sampling_params self.lora_path = lora_path diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index a411e7af7d..be0a0f699c 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -37,9 +37,12 @@ AbortReq, BatchEmbeddingOut, BatchTokenIDOut, + CloseSessionReqInput, FlushCacheReq, GetMemPoolSizeReq, GetMemPoolSizeReqOutput, + OpenSessionReqInput, + OpenSessionReqOutput, ProfileReq,
TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, @@ -59,6 +62,7 @@ PrefillAdder, SchedulePolicy, ) +from sglang.srt.managers.session_controller import Session from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient from sglang.srt.mem_cache.chunk_cache import ChunkCache @@ -106,6 +110,9 @@ def __init__( self.skip_tokenizer_init = server_args.skip_tokenizer_init self.enable_metrics = server_args.enable_metrics + # Session info + self.sessions = {} + # Init inter-process communication context = zmq.Context(2) @@ -509,6 +516,11 @@ def process_input_requests(self, recv_reqs: List): self.start_profile() else: self.stop_profile() + elif isinstance(recv_req, OpenSessionReqInput): + session_id = self.open_session(recv_req) + self.send_to_tokenizer.send_pyobj(OpenSessionReqOutput(session_id)) + elif isinstance(recv_req, CloseSessionReqInput): + self.close_session(recv_req) elif isinstance(recv_req, GetMemPoolSizeReq): self.send_to_tokenizer.send_pyobj( GetMemPoolSizeReqOutput(self.max_total_num_tokens) @@ -520,14 +532,30 @@ def handle_generate_request( self, recv_req: TokenizedGenerateReqInput, ): - req = Req( - recv_req.rid, - recv_req.input_text, - recv_req.input_ids, - recv_req.sampling_params, - lora_path=recv_req.lora_path, - ) - req.tokenizer = self.tokenizer + if recv_req.session_id is None or recv_req.session_id not in self.sessions: + req = Req( + recv_req.rid, + recv_req.input_text, + recv_req.input_ids, + recv_req.sampling_params, + lora_path=recv_req.lora_path, + ) + req.tokenizer = self.tokenizer + if recv_req.session_id is not None: + req.finished_reason = FINISH_ABORT( + f"Invalid request: session id {recv_req.session_id} does not exist" + ) + self.waiting_queue.append(req) + return + else: + # Handle sessions + session = self.sessions[recv_req.session_id] + req, new_session_id = session.create_req(recv_req, self.tokenizer) + del self.sessions[recv_req.session_id] + 
self.sessions[new_session_id] = session + if isinstance(req.finished_reason, FINISH_ABORT): + self.waiting_queue.append(req) + return # Image inputs if recv_req.image_inputs is not None: @@ -1151,6 +1179,7 @@ def stream_output(self, reqs: List[Req]): output_skip_special_tokens = [] output_spaces_between_special_tokens = [] output_no_stop_trim = [] + output_session_ids = [] else: # embedding or reward model output_embeddings = [] @@ -1178,6 +1207,7 @@ def stream_output(self, reqs: List[Req]): req.sampling_params.spaces_between_special_tokens ) output_no_stop_trim.append(req.sampling_params.no_stop_trim) + output_session_ids.append(req.session_id) meta_info = { "prompt_tokens": len(req.origin_input_ids), @@ -1228,6 +1258,7 @@ def stream_output(self, reqs: List[Req]): output_meta_info, output_finished_reason, output_no_stop_trim, + output_session_ids, ) ) else: # embedding or reward model @@ -1330,6 +1361,25 @@ def stop_profile(self) -> None: ) logger.info("Profiler is done") + def open_session(self, recv_req: OpenSessionReqInput) -> str: + # handle error + session_id = recv_req.session_id + if session_id in self.sessions: + logger.warning(f"session id {session_id} already exist, cannot open.") + else: + self.sessions[session_id] = Session( + recv_req.capacity_of_str_len, session_id + ) + return session_id + + def close_session(self, recv_req: CloseSessionReqInput): + # handle error + session_id = recv_req.session_id + if session_id not in self.sessions: + logger.warning(f"session id {session_id} does not exist, cannot delete.") + else: + del self.sessions[session_id] + def run_scheduler_process( server_args: ServerArgs, diff --git a/python/sglang/srt/managers/session_controller.py b/python/sglang/srt/managers/session_controller.py new file mode 100644 index 0000000000..0eeab39e11 --- /dev/null +++ b/python/sglang/srt/managers/session_controller.py @@ -0,0 +1,62 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); 
+you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import copy +import uuid +from dataclasses import dataclass +from typing import Optional + +from sglang.srt.managers.io_struct import TokenizedGenerateReqInput +from sglang.srt.managers.schedule_batch import FINISH_ABORT, List, Req + + +class Session: + def __init__(self, capacity_of_str_len: int, session_id: str = None): + self.session_id = session_id if session_id is not None else uuid.uuid4().hex + self.capacity_of_str_len = capacity_of_str_len + self.reqs: List[Req] = [] + + def create_req(self, req: TokenizedGenerateReqInput, tokenizer): + # renew session id + self.session_id = uuid.uuid4().hex + if req.session_rid is not None: + while len(self.reqs) > 0: + if self.reqs[-1].rid == req.session_rid: + break + self.reqs = self.reqs[:-1] + if len(self.reqs) > 0: + input_ids = ( + self.reqs[-1].origin_input_ids + + self.reqs[-1].output_ids[ + : self.reqs[-1].sampling_params.max_new_tokens + ] + + req.input_ids + ) + else: + input_ids = req.input_ids + new_req = Req( + req.rid, + None, + input_ids, + req.sampling_params, + lora_path=req.lora_path, + session_id=self.session_id, + ) + new_req.tokenizer = tokenizer + if req.session_rid is not None and len(self.reqs) == 0: + new_req.finished_reason = FINISH_ABORT( + f"Invalid request: requested session rid {req.session_rid} does not exist in the session history" + ) + else: + self.reqs.append(new_req) + return new_req, self.session_id diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 
1db60ef49a..a0475f8b46 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -23,6 +23,7 @@ import signal import sys import time +import uuid from typing import Dict, List, Optional, Tuple, Union import fastapi @@ -42,11 +43,14 @@ BatchEmbeddingOut, BatchStrOut, BatchTokenIDOut, + CloseSessionReqInput, EmbeddingReqInput, FlushCacheReq, GenerateReqInput, GetMemPoolSizeReq, GetMemPoolSizeReqOutput, + OpenSessionReqInput, + OpenSessionReqOutput, ProfileReq, TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, @@ -146,6 +150,9 @@ def __init__( self.model_update_lock = asyncio.Lock() self.model_update_result = None + # For session info + self.session_futures = {} # session_id -> asyncio.Future + # Others self.gracefully_exit = False @@ -211,6 +218,8 @@ async def _tokenize_one_request( return_logprob = obj.return_logprob logprob_start_len = obj.logprob_start_len top_logprobs_num = obj.top_logprobs_num + session_id = obj.session_id + session_rid = obj.session_rid if len(input_ids) >= self.context_len: raise ValueError( @@ -236,6 +245,8 @@ async def _tokenize_one_request( top_logprobs_num, obj.stream, obj.lora_path, + session_id=session_id, + session_rid=session_rid, ) elif isinstance(obj, EmbeddingReqInput): tokenized_obj = TokenizedEmbeddingReqInput( @@ -451,6 +462,26 @@ async def update_weights( else: return False, "Another update is in progress. Please try again later."
+ async def open_session( + self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None + ): + if self.to_create_loop: + self.create_handle_loop() + + session_id = uuid.uuid4().hex + obj.session_id = session_id + self.send_to_scheduler.send_pyobj(obj) + self.session_futures[session_id] = asyncio.Future() + session_id = await self.session_futures[session_id] + del self.session_futures[session_id] + return session_id + + async def close_session( + self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None + ): + assert not self.to_create_loop, "close session should not be the first request" + await self.send_to_scheduler.send_pyobj(obj) + def create_abort_task(self, obj: GenerateReqInput): # Abort the request if the client is disconnected. async def abort_request(): @@ -521,6 +552,11 @@ async def handle_loop(self): if len(self.mem_pool_size_tmp) == self.server_args.dp_size: self.mem_pool_size.set_result(self.mem_pool_size_tmp) continue + elif isinstance(recv_obj, OpenSessionReqOutput): + self.session_futures[recv_obj.session_id].set_result( + recv_obj.session_id + ) + continue assert isinstance( recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut) @@ -536,11 +572,13 @@ async def handle_loop(self): out_dict = { "text": recv_obj.output_strs[i], "meta_info": recv_obj.meta_info[i], + "session_id": recv_obj.session_ids[i], } elif isinstance(recv_obj, BatchTokenIDOut): out_dict = { "token_ids": recv_obj.output_ids[i], "meta_info": recv_obj.meta_info[i], + "session_id": recv_obj.session_ids[i], } else: assert isinstance(recv_obj, BatchEmbeddingOut) diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 1ebaf16d98..5621933a6f 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -50,8 +50,10 @@ ) from sglang.srt.managers.detokenizer_manager import run_detokenizer_process from sglang.srt.managers.io_struct import ( + CloseSessionReqInput, EmbeddingReqInput, GenerateReqInput, + 
OpenSessionReqInput, UpdateWeightReqInput, ) from sglang.srt.managers.scheduler import run_scheduler_process @@ -215,6 +217,30 @@ async def update_weights(obj: UpdateWeightReqInput, request: Request): ) +@app.api_route("/open_session", methods=["GET", "POST"]) +async def open_session(obj: OpenSessionReqInput, request: Request): + """Open a session, and return its unique session id.""" + try: + session_id = await tokenizer_manager.open_session(obj, request) + return session_id + except Exception as e: + return ORJSONResponse( + {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST + ) + + +@app.api_route("/close_session", methods=["GET", "POST"]) +async def close_session(obj: CloseSessionReqInput, request: Request): + """Close the session""" + try: + await tokenizer_manager.close_session(obj, request) + return Response(status_code=200) + except Exception as e: + return ORJSONResponse( + {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST + ) + + @time_func_latency async def generate_request(obj: GenerateReqInput, request: Request): """Handle a generate request.""" diff --git a/test/srt/test_session_id.py b/test/srt/test_session_id.py new file mode 100644 index 0000000000..ae56e844b3 --- /dev/null +++ b/test/srt/test_session_id.py @@ -0,0 +1,133 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +# FIXME: Make it a CI test + +import requests + +from sglang.srt.hf_transformers_utils import get_tokenizer + +url = "http://localhost:30000" + +# Open a session +response = requests.post( + url + "/open_session", + json={"capacity_of_str_len": 1000}, +) +session_id = response.json() +print("session_id", session_id, "\n") + +# Prefill only +prompt = "chunk 1" +response = requests.post( + url + "/generate", + json={ + "text": prompt, + "session_id": session_id, + "sampling_params": { + "temperature": 0, + "max_new_tokens": 0, + }, + }, +) +print(response.json(), "\n") +session_id = response.json()["session_id"] + +# Generate +prompt = "Chunk 2" +response = requests.post( + url + "/generate", + json={ + "text": prompt, + "session_id": session_id, + "sampling_params": { + "temperature": 0, + "max_new_tokens": 16, + }, + }, +) +print(response.json(), "\n") +session_id = response.json()["session_id"] +rid = response.json()["meta_info"]["id"] + +# Generate +prompt = "Chunk 3" +response = requests.post( + url + "/generate", + json={ + "text": prompt, + "session_id": session_id, + "sampling_params": { + "temperature": 0, + "max_new_tokens": 2, + }, + }, +) +print(response.json(), "\n") +session_id = response.json()["session_id"] +rid_to_del = response.json()["meta_info"]["id"] + +# Interrupt and re-generate +prompt = "Chunk 4" +response = requests.post( + url + "/generate", + json={ + "text": prompt, + "session_id": session_id, + "session_rid": rid, + "sampling_params": { + "temperature": 0, + "max_new_tokens": 16, + }, + }, +) +print(response.json(), "\n") +session_id = response.json()["session_id"] + +# Query a session based on a deleted request, should see finish reason abort +prompt = "Chunk 4" +response = requests.post( + url + "/generate", + json={ + "text": prompt, + "session_id": session_id, + "session_rid": rid_to_del, + "sampling_params": { + "temperature": 0, + "max_new_tokens": 16, + }, + }, +) +print(response.json(), "\n") + +# Close session +ret = 
requests.post( + url + "/close_session", + json={"session_id": session_id}, +) +print(ret, "\n") + +# Query a deleted session, should see finish reason abort +prompt = "chunk 1" +response = requests.post( + url + "/generate", + json={ + "text": prompt, + "session_id": session_id, + "sampling_params": { + "temperature": 0, + "max_new_tokens": 0, + }, + }, +) +print(response.json(), "\n")