Skip to content

Commit

Permalink
add session control
Browse files Browse the repository at this point in the history
  • Loading branch information
Ying1123 committed Nov 20, 2024
1 parent 63a395b commit 7223b3a
Show file tree
Hide file tree
Showing 8 changed files with 348 additions and 8 deletions.
1 change: 1 addition & 0 deletions python/sglang/srt/managers/detokenizer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ def event_loop(self):
output_strs=output_strs,
meta_info=recv_obj.meta_info,
finished_reason=recv_obj.finished_reason,
session_ids=recv_obj.session_ids,
)
)

Expand Down
27 changes: 27 additions & 0 deletions python/sglang/srt/managers/io_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ class GenerateReqInput:
# LoRA related
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None

# Session id info for continual prompting
session_id: Optional[Union[List[str], str]] = None
session_rid: Optional[Union[List[str], str]] = None

def normalize_batch_and_arguments(self):
if (self.text is None and self.input_ids is None) or (
self.text is not None and self.input_ids is not None
Expand Down Expand Up @@ -200,6 +204,10 @@ class TokenizedGenerateReqInput:
# LoRA related
lora_path: Optional[str] = None # None means just use the base model

# Session id info for continual prompting
session_id: Optional[int] = None
session_rid: Optional[str] = None


@dataclass
class EmbeddingReqInput:
Expand Down Expand Up @@ -293,6 +301,8 @@ class BatchTokenIDOut:
meta_info: List[Dict]
finished_reason: List[BaseFinishReason]
no_stop_trim: List[bool]
# The updated session unique id
session_ids: List[str]


@dataclass
Expand All @@ -305,6 +315,8 @@ class BatchStrOut:
meta_info: List[Dict]
# The finish reason
finished_reason: List[BaseFinishReason]
# The updated session unique id
session_ids: List[str]


@dataclass
Expand Down Expand Up @@ -357,3 +369,18 @@ class GetMemPoolSizeReq:
@dataclass
class GetMemPoolSizeReqOutput:
size: int


@dataclass
class OpenSessionReqInput:
    """Tokenizer -> scheduler request to open a continual-prompting session."""

    # Capacity of the session, measured in string length.
    capacity_of_str_len: int
    # Unique id of the session. Left as None by the caller; the tokenizer
    # manager generates and assigns it (obj.session_id = ...) before
    # forwarding, and the scheduler reads it when creating the Session.
    session_id: Optional[str] = None


@dataclass
class CloseSessionReqInput:
    """Tokenizer -> scheduler request to delete an open session."""

    # Id of the session to remove from the scheduler's session table.
    session_id: str


@dataclass
class OpenSessionReqOutput:
    """Scheduler -> tokenizer reply to an OpenSessionReqInput."""

    # The session id, echoed back so the tokenizer manager can resolve
    # the matching pending future in `session_futures`.
    session_id: str
3 changes: 3 additions & 0 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ def __init__(
origin_input_ids: Tuple[int],
sampling_params: SamplingParams,
lora_path: Optional[str] = None,
session_id: Optional[str] = None,
):
# Input and output info
self.rid = rid
Expand All @@ -188,6 +189,8 @@ def __init__(
self.origin_input_ids = origin_input_ids
self.output_ids = [] # Each decode stage's output ids
self.fill_ids = None # fill_ids = origin_input_ids + output_ids
self.session_id = session_id

self.sampling_params = sampling_params
self.lora_path = lora_path

Expand Down
66 changes: 58 additions & 8 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,12 @@
AbortReq,
BatchEmbeddingOut,
BatchTokenIDOut,
CloseSessionReqInput,
FlushCacheReq,
GetMemPoolSizeReq,
GetMemPoolSizeReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
Expand All @@ -59,6 +62,7 @@
PrefillAdder,
SchedulePolicy,
)
from sglang.srt.managers.session_controller import Session
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.mem_cache.chunk_cache import ChunkCache
Expand Down Expand Up @@ -106,6 +110,9 @@ def __init__(
self.skip_tokenizer_init = server_args.skip_tokenizer_init
self.enable_metrics = server_args.enable_metrics

# Session info
self.sessions = {}

# Init inter-process communication
context = zmq.Context(2)

Expand Down Expand Up @@ -509,6 +516,11 @@ def process_input_requests(self, recv_reqs: List):
self.start_profile()
else:
self.stop_profile()
elif isinstance(recv_req, OpenSessionReqInput):
session_id = self.open_session(recv_req)
self.send_to_tokenizer.send_pyobj(OpenSessionReqOutput(session_id))
elif isinstance(recv_req, CloseSessionReqInput):
self.close_session(recv_req)
elif isinstance(recv_req, GetMemPoolSizeReq):
self.send_to_tokenizer.send_pyobj(
GetMemPoolSizeReqOutput(self.max_total_num_tokens)
Expand All @@ -520,14 +532,30 @@ def handle_generate_request(
self,
recv_req: TokenizedGenerateReqInput,
):
req = Req(
recv_req.rid,
recv_req.input_text,
recv_req.input_ids,
recv_req.sampling_params,
lora_path=recv_req.lora_path,
)
req.tokenizer = self.tokenizer
if recv_req.session_id is None or recv_req.session_id not in self.sessions:
req = Req(
recv_req.rid,
recv_req.input_text,
recv_req.input_ids,
recv_req.sampling_params,
lora_path=recv_req.lora_path,
)
req.tokenizer = self.tokenizer
if recv_req.session_id is not None:
req.finished_reason = FINISH_ABORT(
f"Invalid request: session id {recv_req.session_id} does not exist"
)
self.waiting_queue.append(req)
return
else:
# Handle sessions
session = self.sessions[recv_req.session_id]
req, new_session_id = session.create_req(recv_req, self.tokenizer)
del self.sessions[recv_req.session_id]
self.sessions[new_session_id] = session
if isinstance(req.finished_reason, FINISH_ABORT):
self.waiting_queue.append(req)
return

# Image inputs
if recv_req.image_inputs is not None:
Expand Down Expand Up @@ -1151,6 +1179,7 @@ def stream_output(self, reqs: List[Req]):
output_skip_special_tokens = []
output_spaces_between_special_tokens = []
output_no_stop_trim = []
output_session_ids = []
else: # embedding or reward model
output_embeddings = []

Expand Down Expand Up @@ -1178,6 +1207,7 @@ def stream_output(self, reqs: List[Req]):
req.sampling_params.spaces_between_special_tokens
)
output_no_stop_trim.append(req.sampling_params.no_stop_trim)
output_session_ids.append(req.session_id)

meta_info = {
"prompt_tokens": len(req.origin_input_ids),
Expand Down Expand Up @@ -1228,6 +1258,7 @@ def stream_output(self, reqs: List[Req]):
output_meta_info,
output_finished_reason,
output_no_stop_trim,
output_session_ids,
)
)
else: # embedding or reward model
Expand Down Expand Up @@ -1330,6 +1361,25 @@ def stop_profile(self) -> None:
)
logger.info("Profiler is done")

def open_session(self, recv_req: OpenSessionReqInput) -> str:
    """Register a new Session under the id carried by `recv_req`.

    If a session with that id is already registered, nothing is created
    and only a warning is logged; the id is returned in either case.
    """
    session_id = recv_req.session_id
    if session_id in self.sessions:
        # Duplicate open: keep the existing session untouched.
        logger.warning(f"session id {session_id} already exist, cannot open.")
        return session_id
    new_session = Session(recv_req.capacity_of_str_len, session_id)
    self.sessions[session_id] = new_session
    return session_id

def close_session(self, recv_req: CloseSessionReqInput):
    """Delete the session named by `recv_req.session_id`, if it exists."""
    session_id = recv_req.session_id
    if session_id not in self.sessions:
        # Unknown id: log and leave the session table unchanged.
        logger.warning(f"session id {session_id} does not exist, cannot delete.")
        return
    del self.sessions[session_id]


def run_scheduler_process(
server_args: ServerArgs,
Expand Down
62 changes: 62 additions & 0 deletions python/sglang/srt/managers/session_controller.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import copy
import uuid
from dataclasses import dataclass
from typing import Optional

from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
from sglang.srt.managers.schedule_batch import FINISH_ABORT, List, Req


class Session:
    """Holds the request history for one continual-prompting session.

    The scheduler keeps one Session per open session id and calls
    create_req() for every generate request targeting that session.
    """

    def __init__(self, capacity_of_str_len: int, session_id: Optional[str] = None):
        # Generate a fresh id when the caller does not supply one.
        self.session_id = session_id if session_id is not None else uuid.uuid4().hex
        # Stored for bookkeeping; not enforced anywhere in this class.
        self.capacity_of_str_len = capacity_of_str_len
        # Chain of requests issued under this session, oldest first.
        self.reqs: List[Req] = []

    def create_req(self, req: TokenizedGenerateReqInput, tokenizer):
        """Build a Req that continues this session's conversation.

        Returns (new_req, new_session_id). The session id is rotated on
        every call, so the caller must re-key its session table with the
        returned id. If req.session_rid names a request that is not in
        the history, the returned Req carries a FINISH_ABORT
        finished_reason and is not appended to the history.
        """
        # renew session id
        self.session_id = uuid.uuid4().hex
        if req.session_rid is not None:
            # Roll the history back so that the request whose rid equals
            # session_rid becomes the last entry; everything issued after
            # it is discarded. NOTE(review): if no entry matches, this
            # loop empties the entire history, which triggers the abort
            # branch below — the discarded history is not restored.
            while len(self.reqs) > 0:
                if self.reqs[-1].rid == req.session_rid:
                    break
                self.reqs = self.reqs[:-1]
        if len(self.reqs) > 0:
            # Continue from the last kept request: its prompt, then its
            # generated output (clipped to max_new_tokens), then the new
            # input tokens.
            input_ids = (
                self.reqs[-1].origin_input_ids
                + self.reqs[-1].output_ids[
                    : self.reqs[-1].sampling_params.max_new_tokens
                ]
                + req.input_ids
            )
        else:
            # Empty history: the session starts from the new input alone.
            input_ids = req.input_ids
        new_req = Req(
            req.rid,
            None,
            input_ids,
            req.sampling_params,
            lora_path=req.lora_path,
            session_id=self.session_id,
        )
        new_req.tokenizer = tokenizer
        if req.session_rid is not None and len(self.reqs) == 0:
            # session_rid was given but never found in the history.
            new_req.finished_reason = FINISH_ABORT(
                f"Invalid request: requested session rid {req.session_rid} does not exist in the session history"
            )
        else:
            self.reqs.append(new_req)
        return new_req, self.session_id
38 changes: 38 additions & 0 deletions python/sglang/srt/managers/tokenizer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import signal
import sys
import time
import uuid
from typing import Dict, List, Optional, Tuple, Union

import fastapi
Expand All @@ -42,11 +43,14 @@
BatchEmbeddingOut,
BatchStrOut,
BatchTokenIDOut,
CloseSessionReqInput,
EmbeddingReqInput,
FlushCacheReq,
GenerateReqInput,
GetMemPoolSizeReq,
GetMemPoolSizeReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
Expand Down Expand Up @@ -146,6 +150,9 @@ def __init__(
self.model_update_lock = asyncio.Lock()
self.model_update_result = None

# For session info
self.session_futures = {} # session_id -> asyncio event

# Others
self.gracefully_exit = False

Expand Down Expand Up @@ -211,6 +218,8 @@ async def _tokenize_one_request(
return_logprob = obj.return_logprob
logprob_start_len = obj.logprob_start_len
top_logprobs_num = obj.top_logprobs_num
session_id = obj.session_id
session_rid = obj.session_rid

if len(input_ids) >= self.context_len:
raise ValueError(
Expand All @@ -236,6 +245,8 @@ async def _tokenize_one_request(
top_logprobs_num,
obj.stream,
obj.lora_path,
session_id=session_id,
session_rid=session_rid,
)
elif isinstance(obj, EmbeddingReqInput):
tokenized_obj = TokenizedEmbeddingReqInput(
Expand Down Expand Up @@ -451,6 +462,26 @@ async def update_weights(
else:
return False, "Another update is in progress. Please try again later."

async def open_session(
    self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
):
    """Ask the scheduler to open a new session and return its id.

    Generates the session id here, registers a pending future keyed by
    it, then forwards the request to the scheduler; the receive loop
    resolves the future when the scheduler's OpenSessionReqOutput
    arrives.
    """
    if self.to_create_loop:
        self.create_handle_loop()

    session_id = uuid.uuid4().hex
    obj.session_id = session_id
    # Register the future *before* sending, so the scheduler's reply can
    # never be handled ahead of the registration.
    future = asyncio.Future()
    self.session_futures[session_id] = future
    self.send_to_scheduler.send_pyobj(obj)
    try:
        result = await future
    finally:
        # Always drop the pending entry, even if the await is cancelled,
        # so session_futures cannot leak abandoned futures.
        del self.session_futures[session_id]
    return result

async def close_session(
    self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
):
    """Forward a close-session request to the scheduler (fire-and-forget)."""
    assert not self.to_create_loop, "close session should not be the first request"
    # Every other send on this socket (e.g. in open_session) is a plain
    # synchronous send_pyobj call; awaiting its None return value would
    # raise TypeError, so the `await` that was here has been removed.
    self.send_to_scheduler.send_pyobj(obj)

def create_abort_task(self, obj: GenerateReqInput):
# Abort the request if the client is disconnected.
async def abort_request():
Expand Down Expand Up @@ -521,6 +552,11 @@ async def handle_loop(self):
if len(self.mem_pool_size_tmp) == self.server_args.dp_size:
self.mem_pool_size.set_result(self.mem_pool_size_tmp)
continue
elif isinstance(recv_obj, OpenSessionReqOutput):
self.session_futures[recv_obj.session_id].set_result(
recv_obj.session_id
)
continue

assert isinstance(
recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
Expand All @@ -536,11 +572,13 @@ async def handle_loop(self):
out_dict = {
"text": recv_obj.output_strs[i],
"meta_info": recv_obj.meta_info[i],
"session_id": recv_obj.session_ids[i],
}
elif isinstance(recv_obj, BatchTokenIDOut):
out_dict = {
"token_ids": recv_obj.output_ids[i],
"meta_info": recv_obj.meta_info[i],
"session_id": recv_obj.session_ids[i],
}
else:
assert isinstance(recv_obj, BatchEmbeddingOut)
Expand Down
Loading

0 comments on commit 7223b3a

Please sign in to comment.