Return logprob for choices #87

Merged (5 commits) on Jan 23, 2024
4 changes: 4 additions & 0 deletions README.md
@@ -69,6 +69,8 @@ state = multi_turn_question.run(

for m in state.messages():
print(m["role"], ":", m["content"])

print(state["answer_1"])
```

### Using Local Models
@@ -99,6 +101,8 @@ state = multi_turn_question.run(

for m in state.messages():
print(m["role"], ":", m["content"])

print(state["answer_1"])
```

### More Examples
4 changes: 2 additions & 2 deletions docs/sampling_params.md
@@ -9,8 +9,8 @@ class GenerateReqInput:
image_data: Optional[Union[List[str], str]] = None
sampling_params: Union[List[Dict], Dict] = None
rid: Optional[Union[List[str], str]] = None
return_normalized_logprob: Optional[Union[List[bool], bool]] = None
normalized_logprob_start_len: Optional[Union[List[int], int]] = None
return_logprob: Optional[Union[List[bool], bool]] = None
logprob_start_len: Optional[Union[List[int], int]] = None
stream: bool = False
```

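The renamed fields can also be exercised directly over HTTP, which is how the runtime endpoint's `select` below uses them. A minimal sketch, assuming a server is already running on port 30000 and that the response exposes a `meta_info` object like the client code in this PR reads; the final print is illustrative:

```python
# Hedged sketch of scoring a prompt via the /generate endpoint with the renamed fields.
import requests

resp = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "To answer this question: What is 5 + 5?, I need to use a calculator",
        "sampling_params": {"max_new_tokens": 0},  # score only, generate nothing
        "return_logprob": True,                    # was: return_normalized_logprob
        "logprob_start_len": 0,                    # was: normalized_logprob_start_len
    },
)
print(resp.json()["meta_info"])
```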
File renamed without changes.
42 changes: 42 additions & 0 deletions examples/usage/choices_logprob.py
@@ -0,0 +1,42 @@
"""
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
"""

import sglang as sgl


@sgl.function
def tool_use(s, question):
s += "To answer this question: " + question + ", "
s += "I need to use a " + sgl.gen("tool", choices=["calculator", "search engine"])


def main():
# Run one case
question = "What is 5 + 5?"
state = tool_use.run(question)
print("questions:", question)
print("choice:", state["tool"])
meta_info = state.get_meta_info("tool")
print("logprobs of choice 1", meta_info["prompt_logprob"][0])
print("logprobs of choice 2", meta_info["prompt_logprob"][1])
print('-' * 50)

# Run a batch
questions = [
"What is 5 + 6?",
"Who is Michael Jordan?",
]
states = tool_use.run_batch([{"question": q} for q in questions])
for question, state in zip(questions, states):
print("questions:", question)
print("choice:", state["tool"])
meta_info = state.get_meta_info("tool")
print("logprobs of choice 1", meta_info["prompt_logprob"][0])
print("logprobs of choice 2", meta_info["prompt_logprob"][1])
print('-' * 50)


if __name__ == "__main__":
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
main()
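Note that the entries of `meta_info["prompt_logprob"]` are ordered like `choices`, so index 0 corresponds to "calculator" and index 1 to "search engine" in this example. The same `meta_info` dict also carries `"normalized_prompt_logprob"`, the length-normalized scores the backend actually uses to pick the winning choice (see the `interpreter.py` and `runtime_endpoint.py` changes below).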
2 changes: 1 addition & 1 deletion python/sglang/backend/openai.py
@@ -209,7 +209,7 @@ def select(
prompt_tokens.append(ret_token)

decision = choices[np.argmax(scores)]
return decision, scores
return decision, scores, scores


def openai_completion(client, is_chat=None, prompt=None, **kwargs):
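Returning `scores` twice keeps the OpenAI backend's `select` aligned with the new three-value contract the interpreter now unpacks (decision, normalized logprobs, per-token prompt logprobs); this backend produces a single score list per choice set, so the same list fills both slots.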
14 changes: 9 additions & 5 deletions python/sglang/backend/runtime_endpoint.py
@@ -150,16 +150,20 @@ def select(
data = {
"text": [s.text_ + c for c in choices],
"sampling_params": {"max_new_tokens": 0},
"return_normalized_logprob": True,
"normalized_logprob_start_len": prompt_len,
"return_logprob": True,
"logprob_start_len": max(prompt_len - 2, 0),
}
self._add_images(s, data)
res = http_request(self.base_url + "/generate", json=data)
assert res.status_code == 200
logps = [r["meta_info"]["normalized_logprob"] for r in res.json()]
obj = res.json()
normalized_prompt_logprob = [
r["meta_info"]["normalized_prompt_logprob"] for r in obj
]
prompt_logprob = [r["meta_info"]["prompt_logprob"] for r in obj]

decision = choices[np.argmax(logps)]
return decision, logps
decision = choices[np.argmax(normalized_prompt_logprob)]
return decision, normalized_prompt_logprob, prompt_logprob

def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
res = http_request(
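The rewritten `select` asks the server to score each candidate continuation and picks the one with the highest length-normalized prompt logprob, so longer choices are not penalized simply for containing more tokens; the raw per-token logprobs are returned alongside for inspection. A self-contained sketch of the same selection rule (illustrative names and numbers, plain NumPy, not the endpoint code itself):

```python
import numpy as np

def select_by_normalized_logprob(choices, per_choice_token_logprobs):
    # per_choice_token_logprobs[i]: logprob of each token of choice i,
    # scored as a continuation of the shared prompt.
    normalized = [sum(lps) / max(len(lps), 1) for lps in per_choice_token_logprobs]
    best = choices[int(np.argmax(normalized))]
    return best, normalized, per_choice_token_logprobs

# Illustrative numbers only.
print(select_by_normalized_logprob(
    ["calculator", "search engine"],
    [[-0.2, -1.1], [-0.9, -2.3, -0.4]],
))
```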
4 changes: 2 additions & 2 deletions python/sglang/lang/chat_template.py
@@ -1,6 +1,6 @@
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Callable, Dict, List, Tuple, Optional
from typing import Callable, Dict, List, Optional, Tuple


class ChatTemplateStyle(Enum):
@@ -111,7 +111,7 @@ def get_chat_template_by_model_path(model_path):
"assistant": ("<|im_start|>assistant\n", "\n<|im_end|>\n"),
},
style=ChatTemplateStyle.PLAIN,
stop_str=('<|im_end|>',)
stop_str=("<|im_end|>",),
)
)

10 changes: 8 additions & 2 deletions python/sglang/lang/interpreter.py
@@ -80,7 +80,7 @@ def run_program_batch(

# Run all programs
if num_threads == "auto":
num_threads = max(64, multiprocessing.cpu_count() * 8)
num_threads = max(96, multiprocessing.cpu_count() * 16)
num_threads = min(num_threads, len(batch_arguments))

if num_threads == 1:
@@ -364,10 +364,16 @@ def _execute_gen(self, expr: SglGen):
self.stream_var_event[name].set()

def _execute_select(self, expr: SglSelect):
decision, scores = self.backend.select(self, expr.choices, expr.temperature)
decision, normalized_prompt_logprob, prompt_logprob = self.backend.select(
self, expr.choices, expr.temperature
)
if expr.name is not None:
name = expr.name
self.variables[name] = decision
self.meta_info[name] = {
"normalized_prompt_logprob": normalized_prompt_logprob,
"prompt_logprob": prompt_logprob,
}
self.variable_event[name].set()
self.text_ += decision

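With this change, `_execute_select` stores both arrays in the interpreter's meta info under the variable's name, which is what `state.get_meta_info(name)` in the new example reads back; the chosen text is still appended to the running prompt as before.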
39 changes: 16 additions & 23 deletions python/sglang/srt/layers/logits_processor.py
@@ -14,7 +14,7 @@ def __init__(self, config):
self.tp_size = get_tensor_model_parallel_world_size()

def forward(self, input_ids, hidden_states, weight, input_metadata):
if not input_metadata.return_normalized_logprob:
if not input_metadata.return_logprob:
if input_metadata.forward_mode == ForwardMode.DECODE:
last_hidden = hidden_states
else:
@@ -33,7 +33,7 @@ def forward(self, input_ids, hidden_states, weight, input_metadata):
if self.tp_size > 1:
last_logits = tensor_model_parallel_all_gather(last_logits)
last_logits = last_logits[:, : self.config.vocab_size]
return last_logits, None
return last_logits, (None, None)
else:
assert input_metadata.forward_mode != ForwardMode.DECODE
last_index = (
@@ -51,30 +51,23 @@ def forward(self, input_ids, hidden_states, weight, input_metadata):
logits = logits[:, : self.config.vocab_size]
all_logprobs = torch.log(torch.softmax(logits.float(), dim=-1) + 1e-6)

normalized_logprobs = compute_normalized_logprobs(
all_logprobs,
input_ids,
input_metadata.extend_seq_lens,
input_metadata.extend_start_loc,
logprobs = all_logprobs[
torch.arange(all_logprobs.shape[0], device="cuda"),
torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
]
logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)

start = input_metadata.extend_start_loc.clone()
end = start + input_metadata.extend_seq_lens - 2
start.clamp_(min=0, max=logprobs.shape[0] - 1)
end.clamp_(min=0, max=logprobs.shape[0] - 1)
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
normalized_logprobs = sum_logp / (
(input_metadata.extend_seq_lens - 1).clamp(min=1)
)

last_logits = logits[last_index]
return last_logits, normalized_logprobs


def compute_normalized_logprobs(all_logprobs, input_ids, seq_lens, start_loc):
logprobs = all_logprobs[
torch.arange(all_logprobs.shape[0], device="cuda"),
torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
]
logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)

start = start_loc.clone()
end = start + seq_lens - 2
start.clamp_(min=0, max=logprobs.shape[0] - 1)
end.clamp_(min=0, max=logprobs.shape[0] - 1)
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
return sum_logp / ((seq_lens - 1).clamp(min=1))
return last_logits, (logprobs, normalized_logprobs)


if __name__ == "__main__":
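The inlined computation keeps the old helper's cumulative-sum trick: each request's summed token logprob over its extended segment is the difference of two cumsum entries plus the first element, and dividing by the token count yields the normalized value. A CPU-only sketch of the idea, assuming flattened per-token logprobs with per-request start offsets and lengths (illustrative, not the exact kernel):

```python
import torch

def normalized_from_cumsum(token_logprobs, extend_start_loc, extend_seq_lens):
    # token_logprobs: 1-D tensor of per-token logprobs for the whole batch,
    # laid out request after request.
    cumsum = torch.cumsum(token_logprobs, dim=0, dtype=torch.float32)
    start = extend_start_loc.clone()
    end = start + extend_seq_lens - 2
    start.clamp_(min=0, max=token_logprobs.shape[0] - 1)
    end.clamp_(min=0, max=token_logprobs.shape[0] - 1)
    total = cumsum[end] - cumsum[start] + token_logprobs[start]
    return total / (extend_seq_lens - 1).clamp(min=1)

# Two requests of lengths 3 and 4 packed into one tensor (illustrative numbers).
lp = torch.tensor([-0.1, -0.2, -0.3, -0.4, -0.5, -0.6, -0.7])
print(normalized_from_cumsum(lp, torch.tensor([0, 3]), torch.tensor([3, 4])))
```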
34 changes: 16 additions & 18 deletions python/sglang/srt/managers/io_struct.py
@@ -11,8 +11,8 @@ class GenerateReqInput:
image_data: Optional[Union[List[str], str]] = None
sampling_params: Union[List[Dict], Dict] = None
rid: Optional[Union[List[str], str]] = None
return_normalized_logprob: Optional[Union[List[bool], bool]] = None
normalized_logprob_start_len: Optional[Union[List[int], int]] = None
return_logprob: Optional[Union[List[bool], bool]] = None
logprob_start_len: Optional[Union[List[int], int]] = None
stream: bool = False

def post_init(self):
@@ -23,10 +23,10 @@ def post_init(self):
self.sampling_params = {}
if self.rid is None:
self.rid = uuid.uuid4().hex
if self.return_normalized_logprob is None:
self.return_normalized_logprob = False
if self.normalized_logprob_start_len is None:
self.normalized_logprob_start_len = 0
if self.return_logprob is None:
self.return_logprob = False
if self.logprob_start_len is None:
self.logprob_start_len = 0
else:
num = len(self.text)

@@ -45,17 +45,15 @@ def post_init(self):
else:
assert isinstance(self.rid, list)

if self.return_normalized_logprob is None:
self.return_normalized_logprob = [False] * num
elif not isinstance(self.return_normalized_logprob, list):
self.return_normalized_logprob = [self.return_normalized_logprob] * num
if self.return_logprob is None:
self.return_logprob = [False] * num
elif not isinstance(self.return_logprob, list):
self.return_logprob = [self.return_logprob] * num

if self.normalized_logprob_start_len is None:
self.normalized_logprob_start_len = [0] * num
elif not isinstance(self.normalized_logprob_start_len, list):
self.normalized_logprob_start_len = [
self.normalized_logprob_start_len
] * num
if self.logprob_start_len is None:
self.logprob_start_len = [0] * num
elif not isinstance(self.logprob_start_len, list):
self.logprob_start_len = [self.logprob_start_len] * num


@dataclass
Expand All @@ -65,8 +63,8 @@ class TokenizedGenerateReqInput:
pixel_values: List[float]
image_hash: int
sampling_params: SamplingParams
return_normalized_logprob: bool
normalized_logprob_start_len: int
return_logprob: bool
logprob_start_len: int
stream: bool


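`post_init` treats the renamed fields the same way as the existing per-request fields: in batch mode a scalar is replicated across all requests, a list is used as given, and `None` falls back to a default. The rule, distilled as a small standalone helper (illustrative, not part of the diff):

```python
from typing import Any, List, Optional, Union

def broadcast(value: Optional[Union[Any, List[Any]]], num: int, default: Any) -> List[Any]:
    # None -> per-request defaults; scalar -> replicated; list -> used as-is.
    if value is None:
        return [default] * num
    if not isinstance(value, list):
        return [value] * num
    return value

print(broadcast(True, 3, False))  # [True, True, True]
print(broadcast(None, 2, 0))      # [0, 0]
```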
23 changes: 10 additions & 13 deletions python/sglang/srt/managers/router/infer_batch.py
@@ -28,19 +28,20 @@ def __init__(self, rid):
self.pixel_values = None
self.image_offset = 0
self.sampling_params = None
self.return_normalized_logprob = False
self.normalized_logprob_start_len = 0
self.return_logprob = False
self.logprob_start_len = 0
self.stream = False

self.tokenizer = None
self.finished = False
self.finish_reason = None
self.hit_stop_str = None

self.adjust_input_len = 0
self.extend_input_len = 0
self.prefix_indices = []
self.last_node = None

self.logprob = None
self.normalized_logprob = None

# for constrained decoding
@@ -99,7 +100,7 @@ class Batch:
out_cache_loc: torch.Tensor = None
out_cache_cont_start: torch.Tensor = None
out_cache_cont_end: torch.Tensor = None
return_normalized_logprob: bool = False
return_logprob: bool = False

# for multimodal
pixel_values: List[torch.Tensor] = None
@@ -119,14 +120,14 @@ class Batch:

@classmethod
def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
return_normalized_logprob = any(req.return_normalized_logprob for req in reqs)
return_logprob = any(req.return_logprob for req in reqs)

return cls(
reqs=reqs,
req_to_token_pool=req_to_token_pool,
token_to_kv_pool=token_to_kv_pool,
tree_cache=tree_cache,
return_normalized_logprob=return_normalized_logprob,
return_logprob=return_logprob,
)

def is_empty(self):
@@ -257,7 +258,7 @@ def retract_decode(self):
self.tree_cache.dec_ref_counter(req.last_node)
req.prefix_indices = None
req.last_node = None
req.adjust_input_len = 0
req.extend_input_len = 0
req.output_ids = []
# TODO: apply more fine-grained retraction

@@ -310,9 +311,7 @@ def filter_batch(self, unfinished_indices: List[int]):
self.prefix_lens = None
self.position_ids_offsets = self.position_ids_offsets[new_indices]
self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
self.return_normalized_logprob = any(
req.return_normalized_logprob for req in self.reqs
)
self.return_logprob = any(req.return_logprob for req in self.reqs)

for item in [
"temperatures",
@@ -336,9 +335,7 @@ def merge(self, other):
[self.position_ids_offsets, other.position_ids_offsets]
)
self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
self.return_normalized_logprob = any(
req.return_normalized_logprob for req in self.reqs
)
self.return_logprob = any(req.return_logprob for req in self.reqs)

for item in [
"temperatures",