From 70f737f08ea95eb6ab0049e80e6ca5c124ff3d2d Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 23 Jan 2024 03:31:22 -0800 Subject: [PATCH 1/5] return logprob --- docs/sampling_params.md | 4 +- python/sglang/backend/runtime_endpoint.py | 6 +- python/sglang/srt/layers/logits_processor.py | 41 ++++++------- python/sglang/srt/managers/io_struct.py | 34 +++++------ .../sglang/srt/managers/router/infer_batch.py | 23 ++++---- .../sglang/srt/managers/router/model_rpc.py | 57 ++++++++++--------- .../srt/managers/router/model_runner.py | 26 ++++----- .../sglang/srt/managers/tokenizer_manager.py | 8 +-- 8 files changed, 98 insertions(+), 101 deletions(-) diff --git a/docs/sampling_params.md b/docs/sampling_params.md index 08d1844101..07d07853d0 100644 --- a/docs/sampling_params.md +++ b/docs/sampling_params.md @@ -9,8 +9,8 @@ class GenerateReqInput: image_data: Optional[Union[List[str], str]] = None sampling_params: Union[List[Dict], Dict] = None rid: Optional[Union[List[str], str]] = None - return_normalized_logprob: Optional[Union[List[bool], bool]] = None - normalized_logprob_start_len: Optional[Union[List[int], int]] = None + return_logprob: Optional[Union[List[bool], bool]] = None + logprob_start_len: Optional[Union[List[int], int]] = None stream: bool = False ``` diff --git a/python/sglang/backend/runtime_endpoint.py b/python/sglang/backend/runtime_endpoint.py index 2bb449dca6..fd059ebc66 100644 --- a/python/sglang/backend/runtime_endpoint.py +++ b/python/sglang/backend/runtime_endpoint.py @@ -150,13 +150,13 @@ def select( data = { "text": [s.text_ + c for c in choices], "sampling_params": {"max_new_tokens": 0}, - "return_normalized_logprob": True, - "normalized_logprob_start_len": prompt_len, + "return_logprob": True, + "logprob_start_len": prompt_len, } self._add_images(s, data) res = http_request(self.base_url + "/generate", json=data) assert res.status_code == 200 - logps = [r["meta_info"]["normalized_logprob"] for r in res.json()] + logps = [r["meta_info"]["normalized_prompt_logprob"] for r in res.json()] decision = choices[np.argmax(logps)] return decision, logps diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 7c819c34c1..dd5ad9f165 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -14,7 +14,7 @@ def __init__(self, config): self.tp_size = get_tensor_model_parallel_world_size() def forward(self, input_ids, hidden_states, weight, input_metadata): - if not input_metadata.return_normalized_logprob: + if not input_metadata.return_logprob: if input_metadata.forward_mode == ForwardMode.DECODE: last_hidden = hidden_states else: @@ -33,7 +33,7 @@ def forward(self, input_ids, hidden_states, weight, input_metadata): if self.tp_size > 1: last_logits = tensor_model_parallel_all_gather(last_logits) last_logits = last_logits[:, : self.config.vocab_size] - return last_logits, None + return last_logits, None, None else: assert input_metadata.forward_mode != ForwardMode.DECODE last_index = ( @@ -51,30 +51,21 @@ def forward(self, input_ids, hidden_states, weight, input_metadata): logits = logits[:, : self.config.vocab_size] all_logprobs = torch.log(torch.softmax(logits.float(), dim=-1) + 1e-6) - normalized_logprobs = compute_normalized_logprobs( - all_logprobs, - input_ids, - input_metadata.extend_seq_lens, - input_metadata.extend_start_loc, - ) - - last_logits = logits[last_index] - return last_logits, normalized_logprobs - + logprobs = all_logprobs[ + 
torch.arange(all_logprobs.shape[0], device="cuda"), + torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]), + ] + logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32) -def compute_normalized_logprobs(all_logprobs, input_ids, seq_lens, start_loc): - logprobs = all_logprobs[ - torch.arange(all_logprobs.shape[0], device="cuda"), - torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]), - ] - logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32) + start = input_metadata.extend_start_loc.clone() + end = start + input_metadata.extend_seq_lens - 2 + start.clamp_(min=0, max=logprobs.shape[0] - 1) + end.clamp_(min=0, max=logprobs.shape[0] - 1) + sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start] + normalized_logprobs = sum_logp / ((input_metadata.extend_seq_lens - 1).clamp(min=1)) - start = start_loc.clone() - end = start + seq_lens - 2 - start.clamp_(min=0, max=logprobs.shape[0] - 1) - end.clamp_(min=0, max=logprobs.shape[0] - 1) - sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start] - return sum_logp / ((seq_lens - 1).clamp(min=1)) + last_logits = logits[last_index] + return last_logits, logprobs, normalized_logprobs if __name__ == "__main__": @@ -105,4 +96,4 @@ def compute_normalized_logprobs(all_logprobs, input_ids, seq_lens, start_loc): print("logprobs", logprobs) print("start", start) print("end", end) - print("sum_logp", sum_logp) + print("sum_logp", sum_logp) \ No newline at end of file diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index a6dc1a3808..fc1fcc8a21 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -11,8 +11,8 @@ class GenerateReqInput: image_data: Optional[Union[List[str], str]] = None sampling_params: Union[List[Dict], Dict] = None rid: Optional[Union[List[str], str]] = None - return_normalized_logprob: Optional[Union[List[bool], bool]] = None - normalized_logprob_start_len: Optional[Union[List[int], int]] = None + return_logprob: Optional[Union[List[bool], bool]] = None + logprob_start_len: Optional[Union[List[int], int]] = None stream: bool = False def post_init(self): @@ -23,10 +23,10 @@ def post_init(self): self.sampling_params = {} if self.rid is None: self.rid = uuid.uuid4().hex - if self.return_normalized_logprob is None: - self.return_normalized_logprob = False - if self.normalized_logprob_start_len is None: - self.normalized_logprob_start_len = 0 + if self.return_logprob is None: + self.return_logprob = False + if self.logprob_start_len is None: + self.logprob_start_len = 0 else: num = len(self.text) @@ -45,16 +45,16 @@ def post_init(self): else: assert isinstance(self.rid, list) - if self.return_normalized_logprob is None: - self.return_normalized_logprob = [False] * num - elif not isinstance(self.return_normalized_logprob, list): - self.return_normalized_logprob = [self.return_normalized_logprob] * num + if self.return_logprob is None: + self.return_logprob = [False] * num + elif not isinstance(self.return_logprob, list): + self.return_logprob = [self.return_logprob] * num - if self.normalized_logprob_start_len is None: - self.normalized_logprob_start_len = [0] * num - elif not isinstance(self.normalized_logprob_start_len, list): - self.normalized_logprob_start_len = [ - self.normalized_logprob_start_len + if self.logprob_start_len is None: + self.logprob_start_len = [0] * num + elif not isinstance(self.logprob_start_len, list): + self.logprob_start_len = [ + 
self.logprob_start_len ] * num @@ -65,8 +65,8 @@ class TokenizedGenerateReqInput: pixel_values: List[float] image_hash: int sampling_params: SamplingParams - return_normalized_logprob: bool - normalized_logprob_start_len: int + return_logprob: bool + logprob_start_len: int stream: bool diff --git a/python/sglang/srt/managers/router/infer_batch.py b/python/sglang/srt/managers/router/infer_batch.py index 1e1a93c9ba..d665d86bc8 100644 --- a/python/sglang/srt/managers/router/infer_batch.py +++ b/python/sglang/srt/managers/router/infer_batch.py @@ -28,8 +28,8 @@ def __init__(self, rid): self.pixel_values = None self.image_offset = 0 self.sampling_params = None - self.return_normalized_logprob = False - self.normalized_logprob_start_len = 0 + self.return_logprob = False + self.logprob_start_len = 0 self.stream = False self.tokenizer = None @@ -37,10 +37,11 @@ def __init__(self, rid): self.finish_reason = None self.hit_stop_str = None - self.adjust_input_len = 0 + self.extend_input_len = 0 self.prefix_indices = [] self.last_node = None + self.logprob = None self.normalized_logprob = None # for constrained decoding @@ -99,7 +100,7 @@ class Batch: out_cache_loc: torch.Tensor = None out_cache_cont_start: torch.Tensor = None out_cache_cont_end: torch.Tensor = None - return_normalized_logprob: bool = False + return_logprob: bool = False # for multimodal pixel_values: List[torch.Tensor] = None @@ -119,14 +120,14 @@ class Batch: @classmethod def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache): - return_normalized_logprob = any(req.return_normalized_logprob for req in reqs) + return_logprob = any(req.return_logprob for req in reqs) return cls( reqs=reqs, req_to_token_pool=req_to_token_pool, token_to_kv_pool=token_to_kv_pool, tree_cache=tree_cache, - return_normalized_logprob=return_normalized_logprob, + return_logprob=return_logprob, ) def is_empty(self): @@ -257,7 +258,7 @@ def retract_decode(self): self.tree_cache.dec_ref_counter(req.last_node) req.prefix_indices = None req.last_node = None - req.adjust_input_len = 0 + req.extend_input_len = 0 req.output_ids = [] # TODO: apply more fine-grained retraction @@ -310,8 +311,8 @@ def filter_batch(self, unfinished_indices: List[int]): self.prefix_lens = None self.position_ids_offsets = self.position_ids_offsets[new_indices] self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None - self.return_normalized_logprob = any( - req.return_normalized_logprob for req in self.reqs + self.return_logprob = any( + req.return_logprob for req in self.reqs ) for item in [ @@ -336,8 +337,8 @@ def merge(self, other): [self.position_ids_offsets, other.position_ids_offsets] ) self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None - self.return_normalized_logprob = any( - req.return_normalized_logprob for req in self.reqs + self.return_logprob = any( + req.return_logprob for req in self.reqs ) for item in [ diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py index b4425cf008..5824d5d294 100644 --- a/python/sglang/srt/managers/router/model_rpc.py +++ b/python/sglang/srt/managers/router/model_rpc.py @@ -214,8 +214,8 @@ def handle_generate_request( req.input_ids, pad_value ) req.sampling_params = recv_req.sampling_params - req.return_normalized_logprob = recv_req.return_normalized_logprob - req.normalized_logprob_start_len = recv_req.normalized_logprob_start_len + req.return_logprob = recv_req.return_logprob + req.logprob_start_len = 
recv_req.logprob_start_len req.stream = recv_req.stream req.tokenizer = self.tokenizer @@ -240,9 +240,9 @@ def get_new_fill_batch(self): for req in self.forward_queue: prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids) - if req.return_normalized_logprob: - prefix_indices = prefix_indices[: req.normalized_logprob_start_len] - req.adjust_input_len = len(req.input_ids) - len(prefix_indices) + if req.return_logprob: + prefix_indices = prefix_indices[: req.logprob_start_len] + req.extend_input_len = len(req.input_ids) - len(prefix_indices) req.prefix_indices = prefix_indices req.last_node = last_node @@ -267,32 +267,32 @@ def get_new_fill_batch(self): ) for req in self.forward_queue: - if req.return_normalized_logprob: + if req.return_logprob: # Need at least two tokens to compute normalized logprob - if req.adjust_input_len < 2: - delta = 2 - req.adjust_input_len - req.adjust_input_len += delta + if req.extend_input_len < 2: + delta = 2 - req.extend_input_len + req.extend_input_len += delta req.prefix_indices = req.prefix_indices[:-delta] if req.image_offset is not None: req.image_offset += delta - if req.adjust_input_len == 0 and req.max_new_tokens() > 0: + if req.extend_input_len == 0 and req.max_new_tokens() > 0: # Need at least one token to compute logits - req.adjust_input_len = 1 + req.extend_input_len = 1 req.prefix_indices = req.prefix_indices[:-1] if req.image_offset is not None: req.image_offset += 1 if ( - req.adjust_input_len + req.max_new_tokens() + new_batch_total_tokens + req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens < available_size - and req.adjust_input_len + new_batch_input_tokens + and req.extend_input_len + new_batch_input_tokens < self.max_prefill_num_token ): delta = self.tree_cache.inc_ref_counter(req.last_node) available_size += delta if not ( - req.adjust_input_len + req.max_new_tokens() + new_batch_total_tokens + req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens < available_size ): delta = self.tree_cache.dec_ref_counter(req.last_node) @@ -301,9 +301,9 @@ def get_new_fill_batch(self): self.token_to_kv_pool.add_refs(req.prefix_indices) can_run_list.append(req) new_batch_total_tokens += ( - req.adjust_input_len + req.max_new_tokens() + req.extend_input_len + req.max_new_tokens() ) - new_batch_input_tokens += req.adjust_input_len + new_batch_input_tokens += req.extend_input_len if len(can_run_list) == 0: return None @@ -339,11 +339,12 @@ def forward_fill_batch(self, batch: Batch): if batch.extend_num_tokens != 0: # Forward - logits, normalized_logprobs = self.model_runner.forward( - batch, ForwardMode.EXTEND, batch.return_normalized_logprob + logits, logprobs, normalized_logprobs = self.model_runner.forward( + batch, ForwardMode.EXTEND, batch.return_logprob ) # print("extend logits", logits) - if normalized_logprobs is not None: + if logprobs is not None: + logprobs = logprobs.cpu().tolist() normalized_logprobs = normalized_logprobs.cpu().tolist() next_token_ids, next_token_probs = batch.sample(logits) @@ -354,12 +355,15 @@ def forward_fill_batch(self, batch: Batch): # Check finish condition reqs = batch.reqs - for i in range(len(reqs)): - reqs[i].output_ids = [next_token_ids[i]] - reqs[i].check_finished() + pt = 0 + for i, req in enumerqte(reqs): + req.output_ids = [next_token_ids[i]] + req.check_finished() - if normalized_logprobs is not None: - reqs[i].normalized_logprob = normalized_logprobs[i] + if logprobs is not None: + req.logprob = logprobs[pt:pt + req.extend_input_len-1] + req.normalized_logprob = 
normalized_logprobs[i] + pt += req.extend_input_len self.handle_finished_requests(batch) @@ -427,8 +431,9 @@ def handle_finished_requests(self, batch: Batch): "prompt_tokens": len(req.input_ids), "completion_tokens": len(req.output_ids), } - if req.return_normalized_logprob: - meta_info["normalized_logprob"] = req.normalized_logprob + if req.return_logprob: + meta_info["prompt_logprob"] = req.logprob + meta_info["normalized_prompt_logprob"] = req.normalized_logprob output_meta_info.append(meta_info) output_finished.append(req.finished) diff --git a/python/sglang/srt/managers/router/model_runner.py b/python/sglang/srt/managers/router/model_runner.py index d08796e4d5..70cf4bdd31 100644 --- a/python/sglang/srt/managers/router/model_runner.py +++ b/python/sglang/srt/managers/router/model_runner.py @@ -45,7 +45,7 @@ class InputMetadata: out_cache_cont_end: torch.Tensor = None other_kv_index: torch.Tensor = None - return_normalized_logprob: bool = False + return_logprob: bool = False # for flashinfer use_flashinfer: bool = False @@ -127,7 +127,7 @@ def create( out_cache_loc, out_cache_cont_start=None, out_cache_cont_end=None, - return_normalized_logprob=False, + return_logprob=False, ): batch_size = len(req_pool_indices) start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda") @@ -175,7 +175,7 @@ def create( out_cache_loc=out_cache_loc, out_cache_cont_start=out_cache_cont_start, out_cache_cont_end=out_cache_cont_end, - return_normalized_logprob=return_normalized_logprob, + return_logprob=return_logprob, other_kv_index=other_kv_index, ) @@ -337,7 +337,7 @@ def forward_prefill( prefix_lens, position_ids_offsets, out_cache_loc, - return_normalized_logprob, + return_logprob, ): input_metadata = InputMetadata.create( self, @@ -348,7 +348,7 @@ def forward_prefill( prefix_lens=prefix_lens, position_ids_offsets=position_ids_offsets, out_cache_loc=out_cache_loc, - return_normalized_logprob=return_normalized_logprob, + return_logprob=return_logprob, ) return self.model.forward(input_ids, input_metadata.positions, input_metadata) @@ -361,7 +361,7 @@ def forward_extend( prefix_lens, position_ids_offsets, out_cache_loc, - return_normalized_logprob, + return_logprob, ): input_metadata = InputMetadata.create( self, @@ -372,7 +372,7 @@ def forward_extend( prefix_lens=prefix_lens, position_ids_offsets=position_ids_offsets, out_cache_loc=out_cache_loc, - return_normalized_logprob=return_normalized_logprob, + return_logprob=return_logprob, ) return self.model.forward(input_ids, input_metadata.positions, input_metadata) @@ -415,7 +415,7 @@ def forward_extend_multi_modal( prefix_lens, position_ids_offsets, out_cache_loc, - return_normalized_logprob, + return_logprob, ): input_metadata = InputMetadata.create( self, @@ -426,7 +426,7 @@ def forward_extend_multi_modal( prefix_lens=prefix_lens, position_ids_offsets=position_ids_offsets, out_cache_loc=out_cache_loc, - return_normalized_logprob=return_normalized_logprob, + return_logprob=return_logprob, ) return self.model.forward( input_ids, @@ -437,7 +437,7 @@ def forward_extend_multi_modal( ) def forward( - self, batch: Batch, forward_mode: ForwardMode, return_normalized_logprob=False + self, batch: Batch, forward_mode: ForwardMode, return_logprob=False ): if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND: kwargs = { @@ -450,7 +450,7 @@ def forward( "position_ids_offsets": batch.position_ids_offsets, "out_cache_loc": batch.out_cache_loc, } - kwargs["return_normalized_logprob"] = return_normalized_logprob + kwargs["return_logprob"] = 
return_logprob return self.forward_extend_multi_modal(**kwargs) else: kwargs = { @@ -467,10 +467,10 @@ def forward( kwargs["out_cache_cont_end"] = batch.out_cache_cont_end return self.forward_decode(**kwargs) elif forward_mode == ForwardMode.EXTEND: - kwargs["return_normalized_logprob"] = return_normalized_logprob + kwargs["return_logprob"] = return_logprob return self.forward_extend(**kwargs) elif forward_mode == ForwardMode.PREFILL: - kwargs["return_normalized_logprob"] = return_normalized_logprob + kwargs["return_logprob"] = return_logprob return self.forward_prefill(**kwargs) else: raise ValueError(f"Invaid forward mode: {forward_mode}") diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 313b778c60..2b7e97925e 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -132,8 +132,8 @@ async def generate_request(self, obj: GenerateReqInput): pixel_values=pixel_values, image_hash=image_hash, sampling_params=sampling_params, - return_normalized_logprob=obj.return_normalized_logprob, - normalized_logprob_start_len=obj.normalized_logprob_start_len, + return_logprob=obj.return_logprob, + logprob_start_len=obj.logprob_start_len, stream=obj.stream, ) self.send_to_router.send_pyobj(tokenized_obj) @@ -173,8 +173,8 @@ async def generate_request(self, obj: GenerateReqInput): pixel_values=pixel_values, image_hash=image_hash, sampling_params=sampling_params, - return_normalized_logprob=obj.return_normalized_logprob[i], - normalized_logprob_start_len=obj.normalized_logprob_start_len[i], + return_logprob=obj.return_logprob[i], + logprob_start_len=obj.logprob_start_len[i], stream=obj.stream, ) self.send_to_router.send_pyobj(tokenized_obj) From d2d6c32f9cf5769f590ec525110d5e1a1907bd7b Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 23 Jan 2024 12:29:16 +0000 Subject: [PATCH 2/5] return logprob --- README.md | 4 ++++ python/sglang/srt/managers/router/model_rpc.py | 2 +- test/srt/test_httpserver_decode.py | 2 ++ 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index ab3f51f092..83aa713e22 100644 --- a/README.md +++ b/README.md @@ -69,6 +69,8 @@ state = multi_turn_question.run( for m in state.messages(): print(m["role"], ":", m["content"]) + +print(state["answer_1"]) ``` ### Using Local Models @@ -99,6 +101,8 @@ state = multi_turn_question.run( for m in state.messages(): print(m["role"], ":", m["content"]) + +print(state["answer_1"]) ``` ### More Examples diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py index 5824d5d294..ed79569e27 100644 --- a/python/sglang/srt/managers/router/model_rpc.py +++ b/python/sglang/srt/managers/router/model_rpc.py @@ -356,7 +356,7 @@ def forward_fill_batch(self, batch: Batch): # Check finish condition reqs = batch.reqs pt = 0 - for i, req in enumerqte(reqs): + for i, req in enumerate(reqs): req.output_ids = [next_token_ids[i]] req.check_finished() diff --git a/test/srt/test_httpserver_decode.py b/test/srt/test_httpserver_decode.py index a79ffb6e43..21ec0be6af 100644 --- a/test/srt/test_httpserver_decode.py +++ b/test/srt/test_httpserver_decode.py @@ -26,6 +26,8 @@ "temperature": 0, "max_new_tokens": 32, }, + # "return_logprob": True, + # "logprob_start_len": 0, }, ) print(response.json()) From 75adf150e6edf4754825474cec21b517b51b09c4 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 23 Jan 2024 12:57:10 +0000 Subject: [PATCH 3/5] add an example --- 
examples/usage/{async.py => async_io.py} | 0 examples/usage/choices_logprob.py | 42 +++++++++++++++++++ python/sglang/backend/openai.py | 2 +- python/sglang/backend/runtime_endpoint.py | 10 +++-- python/sglang/lang/interpreter.py | 8 +++- .../sglang/srt/managers/router/model_rpc.py | 2 +- 6 files changed, 56 insertions(+), 8 deletions(-) rename examples/usage/{async.py => async_io.py} (100%) create mode 100644 examples/usage/choices_logprob.py diff --git a/examples/usage/async.py b/examples/usage/async_io.py similarity index 100% rename from examples/usage/async.py rename to examples/usage/async_io.py diff --git a/examples/usage/choices_logprob.py b/examples/usage/choices_logprob.py new file mode 100644 index 0000000000..ad0d823f62 --- /dev/null +++ b/examples/usage/choices_logprob.py @@ -0,0 +1,42 @@ +""" +python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 +""" + +import sglang as sgl + + +@sgl.function +def tool_use(s, question): + s += "To answer this question: " + question + ", " + s += "I need to use a " + sgl.gen("tool", choices=["calculator", "search engine"]) + + +def main(): + # Run one case + question = "What is 5 + 5?" + state = tool_use.run(question) + print("questions:", question) + print("choice:", state["tool"]) + meta_info = state.get_meta_info("tool") + print("logprobs of choice 1", meta_info["prompt_logprob"][0]) + print("logprobs of choice 2", meta_info["prompt_logprob"][1]) + print('-' * 50) + + # Run a batch + questions = [ + "What is 5 + 6?", + "Who is Micheal Jordan?", + ] + states = tool_use.run_batch([{"question": q} for q in questions]) + for question, state in zip(questions, states): + print("questions:", question) + print("choice:", state["tool"]) + meta_info = state.get_meta_info("tool") + print("logprobs of choice 1", meta_info["prompt_logprob"][0]) + print("logprobs of choice 2", meta_info["prompt_logprob"][1]) + print('-' * 50) + + +if __name__ == "__main__": + sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000")) + main() diff --git a/python/sglang/backend/openai.py b/python/sglang/backend/openai.py index 7bd763ce92..a0bed33dfc 100644 --- a/python/sglang/backend/openai.py +++ b/python/sglang/backend/openai.py @@ -209,7 +209,7 @@ def select( prompt_tokens.append(ret_token) decision = choices[np.argmax(scores)] - return decision, scores + return decision, scores, scores def openai_completion(client, is_chat=None, prompt=None, **kwargs): diff --git a/python/sglang/backend/runtime_endpoint.py b/python/sglang/backend/runtime_endpoint.py index fd059ebc66..672f4427e9 100644 --- a/python/sglang/backend/runtime_endpoint.py +++ b/python/sglang/backend/runtime_endpoint.py @@ -151,15 +151,17 @@ def select( "text": [s.text_ + c for c in choices], "sampling_params": {"max_new_tokens": 0}, "return_logprob": True, - "logprob_start_len": prompt_len, + "logprob_start_len": max(prompt_len - 2, 0), } self._add_images(s, data) res = http_request(self.base_url + "/generate", json=data) assert res.status_code == 200 - logps = [r["meta_info"]["normalized_prompt_logprob"] for r in res.json()] + obj = res.json() + normalized_prompt_logprob = [r["meta_info"]["normalized_prompt_logprob"] for r in obj] + prompt_logprob = [r["meta_info"]["prompt_logprob"] for r in obj] - decision = choices[np.argmax(logps)] - return decision, logps + decision = choices[np.argmax(normalized_prompt_logprob)] + return decision, normalized_prompt_logprob, prompt_logprob def concatenate_and_append(self, src_rids: List[str], dst_rid: str): res = 
http_request( diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py index 23b8ca4bc8..91d62fca49 100644 --- a/python/sglang/lang/interpreter.py +++ b/python/sglang/lang/interpreter.py @@ -80,7 +80,7 @@ def run_program_batch( # Run all programs if num_threads == "auto": - num_threads = max(64, multiprocessing.cpu_count() * 8) + num_threads = max(96, multiprocessing.cpu_count() * 16) num_threads = min(num_threads, len(batch_arguments)) if num_threads == 1: @@ -364,10 +364,14 @@ def _execute_gen(self, expr: SglGen): self.stream_var_event[name].set() def _execute_select(self, expr: SglSelect): - decision, scores = self.backend.select(self, expr.choices, expr.temperature) + decision, normalized_prompt_logprob, prompt_logprob = self.backend.select(self, expr.choices, expr.temperature) if expr.name is not None: name = expr.name self.variables[name] = decision + self.meta_info[name] = { + "normalized_prompt_logprob": normalized_prompt_logprob, + "prompt_logprob": prompt_logprob, + } self.variable_event[name].set() self.text_ += decision diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py index ed79569e27..e830c77d1b 100644 --- a/python/sglang/srt/managers/router/model_rpc.py +++ b/python/sglang/srt/managers/router/model_rpc.py @@ -351,7 +351,7 @@ def forward_fill_batch(self, batch: Batch): next_token_ids = next_token_ids.cpu().tolist() else: next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs) - normalized_logprobs = None + logprobs = normalized_logprobs = None # Check finish condition reqs = batch.reqs From e3fe1923d2969cc806919156b8476e70d2fa21e1 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 23 Jan 2024 12:59:52 +0000 Subject: [PATCH 4/5] format --- examples/usage/choices_logprob.py | 2 +- python/sglang/backend/runtime_endpoint.py | 4 +++- python/sglang/lang/chat_template.py | 4 ++-- python/sglang/lang/interpreter.py | 4 +++- python/sglang/srt/layers/logits_processor.py | 6 ++++-- python/sglang/srt/managers/io_struct.py | 4 +--- python/sglang/srt/managers/router/infer_batch.py | 8 ++------ python/sglang/srt/managers/router/model_rpc.py | 2 +- python/sglang/srt/managers/router/model_runner.py | 4 +--- 9 files changed, 18 insertions(+), 20 deletions(-) diff --git a/examples/usage/choices_logprob.py b/examples/usage/choices_logprob.py index ad0d823f62..3b5254dd08 100644 --- a/examples/usage/choices_logprob.py +++ b/examples/usage/choices_logprob.py @@ -25,7 +25,7 @@ def main(): # Run a batch questions = [ "What is 5 + 6?", - "Who is Micheal Jordan?", + "Who is Michael Jordan?", ] states = tool_use.run_batch([{"question": q} for q in questions]) for question, state in zip(questions, states): diff --git a/python/sglang/backend/runtime_endpoint.py b/python/sglang/backend/runtime_endpoint.py index 672f4427e9..9e5b7ed63d 100644 --- a/python/sglang/backend/runtime_endpoint.py +++ b/python/sglang/backend/runtime_endpoint.py @@ -157,7 +157,9 @@ def select( res = http_request(self.base_url + "/generate", json=data) assert res.status_code == 200 obj = res.json() - normalized_prompt_logprob = [r["meta_info"]["normalized_prompt_logprob"] for r in obj] + normalized_prompt_logprob = [ + r["meta_info"]["normalized_prompt_logprob"] for r in obj + ] prompt_logprob = [r["meta_info"]["prompt_logprob"] for r in obj] decision = choices[np.argmax(normalized_prompt_logprob)] diff --git a/python/sglang/lang/chat_template.py b/python/sglang/lang/chat_template.py index 9b7347fd37..9c3ffd5470 100644 --- 
a/python/sglang/lang/chat_template.py +++ b/python/sglang/lang/chat_template.py @@ -1,6 +1,6 @@ from dataclasses import dataclass, field from enum import Enum, auto -from typing import Callable, Dict, List, Tuple, Optional +from typing import Callable, Dict, List, Optional, Tuple class ChatTemplateStyle(Enum): @@ -111,7 +111,7 @@ def get_chat_template_by_model_path(model_path): "assistant": ("<|im_start|>assistant\n", "\n<|im_end|>\n"), }, style=ChatTemplateStyle.PLAIN, - stop_str=('<|im_end|>',) + stop_str=("<|im_end|>",), ) ) diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py index 91d62fca49..4e70f942be 100644 --- a/python/sglang/lang/interpreter.py +++ b/python/sglang/lang/interpreter.py @@ -364,7 +364,9 @@ def _execute_gen(self, expr: SglGen): self.stream_var_event[name].set() def _execute_select(self, expr: SglSelect): - decision, normalized_prompt_logprob, prompt_logprob = self.backend.select(self, expr.choices, expr.temperature) + decision, normalized_prompt_logprob, prompt_logprob = self.backend.select( + self, expr.choices, expr.temperature + ) if expr.name is not None: name = expr.name self.variables[name] = decision diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index dd5ad9f165..35dcd8e383 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -62,7 +62,9 @@ def forward(self, input_ids, hidden_states, weight, input_metadata): start.clamp_(min=0, max=logprobs.shape[0] - 1) end.clamp_(min=0, max=logprobs.shape[0] - 1) sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start] - normalized_logprobs = sum_logp / ((input_metadata.extend_seq_lens - 1).clamp(min=1)) + normalized_logprobs = sum_logp / ( + (input_metadata.extend_seq_lens - 1).clamp(min=1) + ) last_logits = logits[last_index] return last_logits, logprobs, normalized_logprobs @@ -96,4 +98,4 @@ def forward(self, input_ids, hidden_states, weight, input_metadata): print("logprobs", logprobs) print("start", start) print("end", end) - print("sum_logp", sum_logp) \ No newline at end of file + print("sum_logp", sum_logp) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index fc1fcc8a21..c318d5f710 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -53,9 +53,7 @@ def post_init(self): if self.logprob_start_len is None: self.logprob_start_len = [0] * num elif not isinstance(self.logprob_start_len, list): - self.logprob_start_len = [ - self.logprob_start_len - ] * num + self.logprob_start_len = [self.logprob_start_len] * num @dataclass diff --git a/python/sglang/srt/managers/router/infer_batch.py b/python/sglang/srt/managers/router/infer_batch.py index d665d86bc8..f9cf9a6fe5 100644 --- a/python/sglang/srt/managers/router/infer_batch.py +++ b/python/sglang/srt/managers/router/infer_batch.py @@ -311,9 +311,7 @@ def filter_batch(self, unfinished_indices: List[int]): self.prefix_lens = None self.position_ids_offsets = self.position_ids_offsets[new_indices] self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None - self.return_logprob = any( - req.return_logprob for req in self.reqs - ) + self.return_logprob = any(req.return_logprob for req in self.reqs) for item in [ "temperatures", @@ -337,9 +335,7 @@ def merge(self, other): [self.position_ids_offsets, other.position_ids_offsets] ) self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = 
None - self.return_logprob = any( - req.return_logprob for req in self.reqs - ) + self.return_logprob = any(req.return_logprob for req in self.reqs) for item in [ "temperatures", diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py index e830c77d1b..0082c239ed 100644 --- a/python/sglang/srt/managers/router/model_rpc.py +++ b/python/sglang/srt/managers/router/model_rpc.py @@ -361,7 +361,7 @@ def forward_fill_batch(self, batch: Batch): req.check_finished() if logprobs is not None: - req.logprob = logprobs[pt:pt + req.extend_input_len-1] + req.logprob = logprobs[pt : pt + req.extend_input_len - 1] req.normalized_logprob = normalized_logprobs[i] pt += req.extend_input_len diff --git a/python/sglang/srt/managers/router/model_runner.py b/python/sglang/srt/managers/router/model_runner.py index 70cf4bdd31..bd035da220 100644 --- a/python/sglang/srt/managers/router/model_runner.py +++ b/python/sglang/srt/managers/router/model_runner.py @@ -436,9 +436,7 @@ def forward_extend_multi_modal( image_offsets, ) - def forward( - self, batch: Batch, forward_mode: ForwardMode, return_logprob=False - ): + def forward(self, batch: Batch, forward_mode: ForwardMode, return_logprob=False): if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND: kwargs = { "input_ids": batch.input_ids, From 50f824ba78299dab446fea77ad638d2f3e5ec45d Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 23 Jan 2024 13:06:05 +0000 Subject: [PATCH 5/5] update --- python/sglang/srt/layers/logits_processor.py | 4 ++-- python/sglang/srt/managers/router/model_rpc.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 35dcd8e383..0dbbc31da4 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -33,7 +33,7 @@ def forward(self, input_ids, hidden_states, weight, input_metadata): if self.tp_size > 1: last_logits = tensor_model_parallel_all_gather(last_logits) last_logits = last_logits[:, : self.config.vocab_size] - return last_logits, None, None + return last_logits, (None, None) else: assert input_metadata.forward_mode != ForwardMode.DECODE last_index = ( @@ -67,7 +67,7 @@ def forward(self, input_ids, hidden_states, weight, input_metadata): ) last_logits = logits[last_index] - return last_logits, logprobs, normalized_logprobs + return last_logits, (logprobs, normalized_logprobs) if __name__ == "__main__": diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py index 0082c239ed..8978ce43f5 100644 --- a/python/sglang/srt/managers/router/model_rpc.py +++ b/python/sglang/srt/managers/router/model_rpc.py @@ -339,7 +339,7 @@ def forward_fill_batch(self, batch: Batch): if batch.extend_num_tokens != 0: # Forward - logits, logprobs, normalized_logprobs = self.model_runner.forward( + logits, (logprobs, normalized_logprobs) = self.model_runner.forward( batch, ForwardMode.EXTEND, batch.return_logprob ) # print("extend logits", logits)