Simplify tokenizer manager (#1899)
merrymercy authored Nov 3, 2024
1 parent efbc116 commit 838dcda
Showing 3 changed files with 50 additions and 49 deletions.
13 changes: 11 additions & 2 deletions docs/references/custom_chat_template.md
@@ -11,8 +11,10 @@ If needed, you can also override the chat template when launching the server:
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --chat-template llama-2
```

If the chat template you are looking for is missing, you are welcome to contribute it.
Meanwhile, you can also temporarily register your chat template as follows:
If the chat template you are looking for is missing, you are welcome to contribute it or load it from a file.

## JSON Format
You can load a chat template in the JSON format defined by `conversation.py`.

```json
{
@@ -28,4 +30,11 @@ Meanwhile, you can also temporarily register your chat template as follows:

```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --chat-template ./my_model_template.json
```

## Jinja Format
You can also use the Jinja template format defined by Hugging Face Transformers: https://huggingface.co/docs/transformers/main/en/chat_templating

```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --chat-template ./my_model_template.jinja
```
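
To sanity-check a custom Jinja template before pointing the server at it, you can render it locally with the `transformers` tokenizer API. A minimal sketch, assuming a placeholder template path and any chat-capable tokenizer:

```python
from transformers import AutoTokenizer

# Load any tokenizer and override its chat template with your own Jinja file.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
with open("./my_model_template.jinja") as f:
    tokenizer.chat_template = f.read()

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]
print(tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))
```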
16 changes: 12 additions & 4 deletions python/sglang/srt/managers/io_struct.py
@@ -114,8 +114,7 @@ def post_init(self):
if self.parallel_sample_num == 1:
num = self.batch_size
else:
# FIXME support cascade inference
# first bs samples are used for caching the prefix for parallel sampling
# The first bs samples are used for caching the prefix for parallel sampling
num = self.batch_size + self.parallel_sample_num * self.batch_size

if self.image_data is None:
@@ -196,6 +195,9 @@ class EmbeddingReqInput:
# Dummy sampling params for compatibility
sampling_params: Union[List[Dict], Dict] = None

# Whether it is a single request or a batch request
is_single: bool = True

def post_init(self):
if (self.text is None and self.input_ids is None) or (
self.text is not None and self.input_ids is not None
@@ -241,15 +243,21 @@ class TokenizedEmbeddingReqInput:
sampling_params: SamplingParams


RewardReqConv = Union[List[List[Dict]], List[Dict], str, List[str]]


@dataclass
class RewardReqInput:
# The input prompt in the chat format. It can be a single prompt or a batch of prompts.
conv: Union[List[List[Dict]], List[Dict]]
# The input prompt. It can be a single prompt or a batch of prompts, given either in the chat format or as a plain string.
conv: RewardReqConv
# The request id.
rid: Optional[Union[List[str], str]] = None
# Dummy sampling params for compatibility
sampling_params: Union[List[Dict], Dict] = None

# Whether it is a single request or a batch request
is_single: bool = True

def post_init(self):
self.is_single = isinstance(self.conv[0], dict)

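For reference, the shapes accepted by the new `RewardReqConv` alias look like this (an illustrative sketch, not part of the commit itself):

```python
from sglang.srt.managers.io_struct import RewardReqConv

# A single conversation in the chat format (post_init() keeps is_single=True).
single_chat: RewardReqConv = [{"role": "user", "content": "Is this answer correct?"}]

# A batch of conversations in the chat format (post_init() sets is_single=False).
batch_chat: RewardReqConv = [
    [{"role": "user", "content": "Question 1"}],
    [{"role": "user", "content": "Question 2"}],
]

# Plain strings are also accepted, as a single prompt or a batch of prompts.
single_str: RewardReqConv = "Question: ... Answer: ..."
batch_str: RewardReqConv = ["Prompt 1", "Prompt 2"]
```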
70 changes: 27 additions & 43 deletions python/sglang/srt/managers/tokenizer_manager.py
@@ -51,6 +51,7 @@
GetMemPoolSizeReq,
GetMemPoolSizeReqOutput,
ProfileReq,
RewardReqConv,
RewardReqInput,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
@@ -89,6 +90,7 @@ def __init__(
server_args: ServerArgs,
port_args: PortArgs,
):
# Parse args
self.server_args = server_args

# Init inter-process communication
@@ -114,6 +116,7 @@ def __init__(
self.context_len = server_args.context_length or get_context_length(
self.hf_config
)

# Create image processor placeholder
self.image_processor = get_dummy_image_processor()

@@ -165,7 +168,8 @@ async def generate_request(

if isinstance(obj, EmbeddingReqInput) and self.is_generation:
raise ValueError(
"This model does not appear to be an embedding model by default. Please add `--is-embedding` when launching the server or try another model."
"This model does not appear to be an embedding model by default. "
"Please add `--is-embedding` when launching the server or try another model."
)

obj.post_init()
@@ -187,12 +191,8 @@ async def _send_single_request(
if not is_cache_for_prefill: # The normal case with a single prompt
if index is None:
rid = obj.rid
if hasattr(obj, "conv"):
# reward model
conv = obj.conv
input_text = self.tokenizer.apply_chat_template(
conv, tokenize=False
)
if isinstance(obj, RewardReqInput):
input_text = self._apply_chat_template(obj.conv)
input_ids = self.tokenizer.encode(input_text)
elif obj.input_ids is None:
input_text = obj.text
@@ -213,12 +213,8 @@
top_logprobs_num = obj.top_logprobs_num
else:
rid = obj.rid[index]
if hasattr(obj, "conv"):
# reward model
conv = obj.conv[index]
input_text = self.tokenizer.apply_chat_template(
conv, tokenize=False
)
if isinstance(obj, RewardReqInput):
input_text = self._apply_chat_template(obj.conv[input_id_index])
input_ids = self.tokenizer.encode(input_text)
elif obj.input_ids is None:
input_text = obj.text[input_id_index]
@@ -349,8 +345,9 @@ async def _handle_single_request(
async for response in self._wait_for_response(state, obj, rid, request):
yield response
else:
assert self.is_generation
await self._wait_for_cache_prefill_response(state, obj, rid, request)
await state.event.wait()
assert state.finished
del self.rid_to_state[rid]
yield input_ids

async def _handle_batch_request(
@@ -456,6 +453,15 @@ def _get_sampling_params(self, sampling_params_data: dict):
sampling_params.verify()
return sampling_params

def _apply_chat_template(self, conv: RewardReqConv) -> Union[str, List[str]]:
if isinstance(conv, str):
return conv
elif isinstance(conv, list):
if isinstance(conv[0], str):
return conv
else:
return self.tokenizer.apply_chat_template(conv, tokenize=False)

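Outside the manager, the same reward-request path can be reproduced with a plain Hugging Face tokenizer: render the chat-format conversation to a prompt string, then encode it. A rough standalone sketch (the model name is only a placeholder):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

conv = [{"role": "user", "content": "Rate the following answer."}]
input_text = tokenizer.apply_chat_template(conv, tokenize=False)  # chat format -> prompt string
input_ids = tokenizer.encode(input_text)                          # prompt string -> token ids
```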
async def _wait_for_response(
self,
state: ReqState,
@@ -491,40 +497,18 @@ async def _wait_for_response(

out["index"] = response_index

# Log requests
if self.server_args.log_requests and state.finished:
logger.info(f"in={obj}, out={out}")

state.out_list = []
if state.finished:
# Log requests
if self.server_args.log_requests:
logger.info(f"in={obj}, out={out}")
del self.rid_to_state[rid]
yield out
break

state.event.clear()
yield out

async def _wait_for_cache_prefill_response(
self,
state: ReqState,
obj: GenerateReqInput,
rid: str,
request: Optional[fastapi.Request] = None,
):
while True:
try:
await asyncio.wait_for(state.event.wait(), timeout=4)
break
except asyncio.TimeoutError:
if request is not None and await request.is_disconnected():
for rid in obj.rid:
self.abort_request(rid)
raise ValueError(f"Abort request {rid}")
continue

assert state.finished
del self.rid_to_state[rid]

def flush_cache(self):
req = FlushCacheReq()
self.send_to_scheduler.send_pyobj(req)
@@ -553,6 +537,7 @@ async def get_memory_pool_size(self):
self.send_to_scheduler.send_pyobj(req)
self.mem_pool_size = asyncio.Future()

# FIXME: Each request should have its own future instead of using `self.mem_pool_size`.
if self.server_args.dp_size == 1:
res = await self.mem_pool_size
return res.size
@@ -638,7 +623,7 @@ async def sigterm_watchdog(self):
while True:
remain_num_req = len(self.rid_to_state)
logger.info(
f"gracefully exiting... remaining number of requests {remain_num_req}"
f"Gracefully exiting... remaining number of requests {remain_num_req}"
)
if remain_num_req > 0:
await asyncio.sleep(5)
@@ -695,7 +680,6 @@ async def handle_loop(self):
"token_ids": recv_obj.output_ids[i],
"meta_info": recv_obj.meta_info[i],
}

else:
assert isinstance(recv_obj, BatchEmbeddingOut)
out_dict = {
@@ -747,7 +731,7 @@ def detokenize_logprob_tokens(
token_texts = self.tokenizer.batch_decode(token_ids)
return [
(logprob, token_id, token_text)
for (logprob, token_id), token_text, in zip(token_logprobs, token_texts)
for (logprob, token_id), token_text in zip(token_logprobs, token_texts)
]

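For clarity, the pairing performed by the comprehension above looks like this (illustrative values only):

```python
token_logprobs = [(-0.05, 1234), (-1.20, 5678)]  # (logprob, token_id) pairs
token_texts = ["Hello", " world"]                # decoded text for each token id

result = [
    (logprob, token_id, token_text)
    for (logprob, token_id), token_text in zip(token_logprobs, token_texts)
]
# result == [(-0.05, 1234, "Hello"), (-1.2, 5678, " world")]
```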
def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
