
[Minor] Improve the code style in TokenizerManager (#767)
merrymercy authored Jul 27, 2024
1 parent 3fdab91 commit 0736b27
Showing 3 changed files with 60 additions and 37 deletions.
python/sglang/srt/managers/controller/infer_batch.py (2 changes: 1 addition & 1 deletion)
@@ -376,7 +376,7 @@ def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor
logit_bias = torch.zeros(
(bs, vocab_size), dtype=torch.float32, device=device
)
logit_bias[i] = int_token_logit_bias
logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias

# Set fields
self.input_ids = torch.tensor(
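Note on the infer_batch.py change: the slice assignment presumably guards against a size mismatch, since int_token_logit_bias is built from the tokenizer vocabulary while vocab_size comes from the model config and can be larger (for example, a padded vocabulary). A minimal sketch under that assumption, with illustrative sizes that are not the real model's numbers:

```python
import torch

# Illustrative sizes (assumptions): the bias vector covers the tokenizer
# vocabulary, while the logits row is sized by the model's padded vocab_size.
vocab_size = 32004
int_token_logit_bias = torch.full((32000,), -1e5)

logit_bias = torch.zeros((1, vocab_size), dtype=torch.float32)

# The old assignment required the shapes to match exactly and would raise
# a RuntimeError here:
# logit_bias[0] = int_token_logit_bias

# The new slice assignment writes only the covered prefix and leaves the
# padded tail at zero.
logit_bias[0][: len(int_token_logit_bias)] = int_token_logit_bias
assert logit_bias[0, -1].item() == 0.0
```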
python/sglang/srt/managers/tokenizer_manager.py (89 changes: 54 additions & 35 deletions)
@@ -133,24 +133,10 @@ async def generate_request(self, obj: GenerateReqInput, request=None):
async for response in self._handle_batch_request(obj, request):
yield response

async def _handle_single_request(self, obj, request, index=None, is_prefill=False):
if is_prefill:
if isinstance(obj.text, list):
input_text = obj.text[index]
rid = obj.rid[index]
else:
input_text = obj.text
rid = obj.rid[0]
input_ids = self.tokenizer.encode(input_text)
sampling_params = SamplingParams(**obj.sampling_params[0])
sampling_params.max_new_tokens = 0
pixel_values, image_hash, image_size = await self._get_pixel_values(
obj.image_data[0]
)
return_logprob = obj.return_logprob[0]
logprob_start_len = obj.logprob_start_len[0]
top_logprobs_num = obj.top_logprobs_num[0]
else:
async def _handle_single_request(
self, obj, request, index=None, is_cache_for_prefill=False
):
if not is_cache_for_prefill:
rid = obj.rid if index is None else obj.rid[index]
input_text = obj.text if index is None else obj.text[index]
input_ids = (
@@ -177,6 +163,22 @@ async def _handle_single_request(self, obj, request, index=None, is_prefill=False):
top_logprobs_num = (
obj.top_logprobs_num if index is None else obj.top_logprobs_num[index]
)
else:
if isinstance(obj.text, list):
input_text = obj.text[index]
rid = obj.rid[index]
else:
input_text = obj.text
rid = obj.rid[0]
input_ids = self.tokenizer.encode(input_text)
sampling_params = SamplingParams(**obj.sampling_params[0])
sampling_params.max_new_tokens = 0
pixel_values, image_hash, image_size = await self._get_pixel_values(
obj.image_data[0]
)
return_logprob = obj.return_logprob[0]
logprob_start_len = obj.logprob_start_len[0]
top_logprobs_num = obj.top_logprobs_num[0]

tokenized_obj = TokenizedGenerateReqInput(
rid,
@@ -196,26 +198,26 @@ async def _handle_single_request(self, obj, request, index=None, is_prefill=False):
event = asyncio.Event()
state = ReqState([], False, event)
self.rid_to_state[rid] = state
if is_prefill:
await self._wait_for_prefill_response(event, state, obj, request, rid)
yield input_ids
else:
if not is_cache_for_prefill:
async for response in self._wait_for_response(
event, state, obj, rid, request
):
yield response
else:
await self._wait_for_cache_prefill_response(event, state, obj, rid, request)
yield input_ids

async def _handle_batch_request(self, obj, request):
async def _handle_batch_request(self, obj: GenerateReqInput, request):
batch_size = obj.batch_size
parallel_sample_num = obj.sampling_params[0].get("n", 1)

if parallel_sample_num != 1:
## send prefill requests
# Send prefill requests to cache the common input
parallel_sample_num += 1
input_id_result = [] if obj.input_ids is None else None
for i in range(batch_size):
async for input_id in self._handle_single_request(
obj, request, index=i, is_prefill=True
obj, request, index=i, is_cache_for_prefill=True
):
if input_id_result is not None:
input_id_result.append(input_id)
@@ -224,6 +226,7 @@ async def _handle_batch_request(self, obj, request):
obj.input_ids = input_id_result
elif input_id_result is not None:
obj.input_ids = input_id_result[0]

# First send out all requests
for i in range(batch_size):
for j in range(parallel_sample_num):
@@ -308,17 +311,15 @@ async def _handle_batch_request(self, obj, request):

yield output_list

def _validate_input_length(self, input_ids):
def _validate_input_length(self, input_ids: List[int]):
if len(input_ids) >= self.context_len:
raise ValueError(
f"The input ({len(input_ids)} tokens) is longer than the "
f"model's context length ({self.context_len} tokens)."
)

def _get_sampling_params(self, sampling_params_data, max_new_tokens=None):
def _get_sampling_params(self, sampling_params_data: dict):
sampling_params = SamplingParams(**sampling_params_data)
if max_new_tokens is not None:
sampling_params.max_new_tokens = max_new_tokens
if sampling_params.max_new_tokens != 0:
sampling_params.normalize(self.tokenizer)
sampling_params.verify()
@@ -332,7 +333,14 @@ async def _get_pixel_values(self, image_data):
else:
return None, None, None

async def _wait_for_response(self, event, state, obj, rid, request):
async def _wait_for_response(
self,
event: asyncio.Event,
state: ReqState,
obj: GenerateReqInput,
rid: str,
request,
):
while True:
try:
await asyncio.wait_for(event.wait(), timeout=4)
@@ -361,7 +369,14 @@ async def _wait_for_response(self, event, state, obj, rid, request):
event.clear()
yield out

async def _wait_for_prefill_response(self, event, state, obj, request, rid):
async def _wait_for_cache_prefill_response(
self,
event: asyncio.Event,
state: ReqState,
obj: GenerateReqInput,
rid: str,
request,
):
while True:
try:
await asyncio.wait_for(state.event.wait(), timeout=4)
@@ -380,7 +395,7 @@ def flush_cache(self):
req = FlushCacheReq()
self.send_to_router.send_pyobj(req)

def abort_request(self, rid):
def abort_request(self, rid: str):
if rid not in self.rid_to_state:
return
del self.rid_to_state[rid]
@@ -426,7 +441,11 @@ async def handle_loop(self):
state.event.set()

def convert_logprob_style(
self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs
self,
ret: dict,
return_logprob: bool,
top_logprobs_num: int,
return_text_in_logprobs: bool,
):
if return_logprob:
ret["meta_info"]["prefill_token_logprobs"] = self.detokenize_logprob_tokens(
@@ -450,7 +469,7 @@ def convert_logprob_style(
)
return ret

def detokenize_logprob_tokens(self, token_logprobs, decode_to_text):
def detokenize_logprob_tokens(self, token_logprobs, decode_to_text: bool):
if not decode_to_text:
return [(logprob, token_id, None) for logprob, token_id in token_logprobs]

@@ -461,7 +480,7 @@ def detokenize_logprob_tokens(self, token_logprobs, decode_to_text):
for (logprob, token_id), token_text, in zip(token_logprobs, token_texts)
]

def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text):
def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
for i, t in enumerate(top_logprobs):
if t:
top_logprobs[i] = self.detokenize_logprob_tokens(t, decode_to_text)
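Note on the tokenizer_manager.py change: the is_prefill flag is renamed to is_cache_for_prefill, which better describes what the branch does: for parallel sampling (n > 1), each prompt is first sent as a prefill-only request (max_new_tokens = 0) so the common input can be cached before the real sampling requests go out. A much-simplified sketch of that flow, using hypothetical helpers (send_request, handle_batch) rather than the actual TokenizerManager API:

```python
import asyncio
from typing import List, Optional

async def send_request(prompt: str, max_new_tokens: Optional[int] = None) -> None:
    # Stand-in for tokenizing the prompt and dispatching it to the router.
    kind = "cache-for-prefill" if max_new_tokens == 0 else "generation"
    print(f"{kind}: {prompt!r}")

async def handle_batch(prompts: List[str], n: int) -> None:
    # With parallel sampling (n > 1), first send one prefill-only request per
    # prompt so the common input is cached, then send the n real sampling
    # requests that reuse the cached prefix.
    if n > 1:
        for prompt in prompts:
            await send_request(prompt, max_new_tokens=0)
    for prompt in prompts:
        for _ in range(n):
            await send_request(prompt)

asyncio.run(handle_batch(["List three colors."], n=2))
```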
python/sglang/test/test_programs.py (6 changes: 5 additions & 1 deletion)
@@ -118,7 +118,11 @@ def decode_json(s):
s += "}"

ret = decode_json.run()
js_obj = json.loads(ret["json_output"])
try:
js_obj = json.loads(ret["json_output"])
except json.decoder.JSONDecodeError:
print(ret["json_output"])
raise
assert isinstance(js_obj["name"], str)
assert isinstance(js_obj["population"], int)

