Skip to content

Commit

Permalink
Add alternative choices selection methods
Browse files Browse the repository at this point in the history
  • Loading branch information
AidanCooper committed Jul 30, 2024
1 parent 1edd4e0 commit 2310247
Show file tree
Hide file tree
Showing 10 changed files with 414 additions and 48 deletions.
77 changes: 77 additions & 0 deletions docs/en/choices_methods.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Choices Methods in SGLang
This doc describes the choices methods supported by SGLang.

The optional `choices_method` arg determines how options supplied to SGLang's `choices` primitive are selected.

## Methods

### Token Length Normalized

Token length normalized is the default SGLang choices method. It selects the option with the highest average logprob across all of its tokens.

Usage example (alternatively, simply omit the `choices_method` arg):
```python
@sgl.function
def example(s):
s += sgl.user("What is the capital of France?")
s += sgl.assistant(
sgl.gen(
"answer",
choices=["London", "Paris", "Berlin"],
choices_method=sgl.token_length_normalized,
)
)
```


This can perform poorly if an option contains many tokens, because its later tokens may be predicted with high confidence conditioned on its earlier tokens, inflating the average logprob. For instance, even strong models will fail the above example if the specified options are `["Paris", "Antidisestablishmentarianism"]`.

### Greedy Token Selection

Greedy token selection simply selects the option with the highest logprob for its initial token. For overlapping options, where one option is a subset of a longer option, the shorter option's logprobs are extended using its average logprob so that it can be compared against the longer option.

Usage example:
```python
@sgl.function
def example(s):
s += sgl.user("What is the capital of France?")
s += sgl.assistant(
sgl.gen(
"answer",
choices=["London", "Paris", "Berlin"],
choices_method=sgl.greedy_token_selection,
)
)
```

This can perform poorly if an option misleads the model down a bad path based on an attractive initial token. For instance, greedy selection will result in an incorrect response for this example:
```python
@sgl.function
def us_president_example(s):
s += sgl.user("Name a US president.")
s += sgl.assistant(
sgl.gen(
"answer",
choices=["Donald Duck", "Millard Fillmore"],
choices_method=sgl.greedy_token_selection,
)
)
```

### Unconditional Likelihood Normalized

Unconditional likelihood normalized selects the option with the highest average token logprob once normalized by the unconditional token logprobs, as described in [this EleutherAI blogpost](https://blog.eleuther.ai/multiple-choice-normalization/). This method incurs an additional LLM call to obtain the unconditional likelihoods.

Usage example:
```python
@sgl.function
def example(s):
s += sgl.user("What is the capital of France?")
s += sgl.assistant(
sgl.gen(
"answer",
choices=["London", "Paris", "Berlin"],
choices_method=sgl.unconditional_likelihood_normalized,
)
)
```
8 changes: 8 additions & 0 deletions python/sglang/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@
user_end,
video,
)
from sglang.lang.choices import (
greedy_token_selection,
token_length_normalized,
unconditional_likelihood_normalized,
)

# SGLang DSL APIs
__all__ = [
Expand All @@ -45,6 +50,9 @@
"user_begin",
"user_end",
"video",
"greedy_token_selection",
"token_length_normalized",
"unconditional_likelihood_normalized",
]

# Global Configurations
Expand Down
9 changes: 7 additions & 2 deletions python/sglang/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from sglang.global_config import global_config
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.choices import ChoicesSamplingMethod
from sglang.lang.ir import (
SglExpr,
SglExprList,
Expand Down Expand Up @@ -73,12 +74,15 @@ def gen(
return_text_in_logprobs: Optional[bool] = None,
dtype: Optional[type] = None,
choices: Optional[List[str]] = None,
choices_method: Optional[ChoicesSamplingMethod] = None,
regex: Optional[str] = None,
):
"""Call the model to generate. See the meaning of the arguments in docs/en/sampling_params.md"""

if choices:
return SglSelect(name, choices, 0.0 if temperature is None else temperature)
return SglSelect(
name, choices, 0.0 if temperature is None else temperature, choices_method
)

# check regex is valid
if regex is not None:
Expand Down Expand Up @@ -186,9 +190,10 @@ def select(
name: Optional[str] = None,
choices: Optional[List[str]] = None,
temperature: float = 0.0,
choices_method: Optional[ChoicesSamplingMethod] = None,
):
assert choices is not None
return SglSelect(name, choices, temperature)
return SglSelect(name, choices, temperature, choices_method)


def _role_common(name: str, expr: Optional[SglExpr] = None):
Expand Down
2 changes: 2 additions & 0 deletions python/sglang/lang/backend/base_backend.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Callable, List, Optional, Union

from sglang.lang.chat_template import get_chat_template
from sglang.lang.choices import ChoicesSamplingMethod
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglSamplingParams

Expand Down Expand Up @@ -64,6 +65,7 @@ def select(
s: StreamExecutor,
choices: List[str],
temperature: float,
choices_method: Optional[ChoicesSamplingMethod] = None,
):
raise NotImplementedError()

Expand Down
8 changes: 7 additions & 1 deletion python/sglang/lang/backend/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import ChatTemplate, get_chat_template_by_model_path
from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglSamplingParams

Expand Down Expand Up @@ -296,12 +297,17 @@ def select(
s: StreamExecutor,
choices: List[str],
temperature: float,
):
choices_method: Optional[ChoicesSamplingMethod],
) -> ChoicesDecision:
if self.is_chat_model:
raise NotImplementedError(
"select/choices is not supported for chat models. "
"Please try to use a non-chat model such as gpt-3.5-turbo-instruct"
)
if choices_method:
raise NotImplementedError(
"choices_method is not supported for OpenAI backend. Leave as None."
)

n_choices = len(choices)
token_ids = [self.tokenizer.encode(x) for x in choices]
Expand Down
75 changes: 46 additions & 29 deletions python/sglang/lang/backend/runtime_endpoint.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
import json
from typing import List, Optional

import numpy as np

from sglang.global_config import global_config
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.lang.choices import (
ChoicesDecision,
ChoicesSamplingMethod,
token_length_normalized,
)
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglSamplingParams
from sglang.utils import http_request


class RuntimeEndpoint(BaseBackend):

def __init__(
self,
base_url: str,
Expand Down Expand Up @@ -216,21 +220,15 @@ def select(
s: StreamExecutor,
choices: List[str],
temperature: float,
):
choices_method: Optional[ChoicesSamplingMethod],
) -> ChoicesDecision:
assert temperature <= 1e-5
choices_method = choices_method or token_length_normalized

# Cache common prefix
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
self._add_images(s, data)
res = http_request(
self.base_url + "/generate",
json=data,
auth_token=self.auth_token,
api_key=self.api_key,
verify=self.verify,
)
self._assert_success(res)
prompt_len = res.json()["meta_info"]["prompt_tokens"]
obj = self._generate_http_request(s, data)
prompt_len = obj["meta_info"]["prompt_tokens"]

# Compute logprob
data = {
Expand All @@ -239,28 +237,35 @@ def select(
"return_logprob": True,
"logprob_start_len": max(prompt_len - 2, 0),
}
self._add_images(s, data)
res = http_request(
self.base_url + "/generate",
json=data,
auth_token=self.auth_token,
api_key=self.api_key,
verify=self.verify,
)
self._assert_success(res)
obj = res.json()
obj = self._generate_http_request(s, data)

normalized_prompt_logprobs = [
r["meta_info"]["normalized_prompt_logprob"] for r in obj
]
decision = choices[np.argmax(normalized_prompt_logprobs)]
input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj]
output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj]

return (
decision,
normalized_prompt_logprobs,
input_token_logprobs,
output_token_logprobs,
# Compute unconditional logprobs if required
if choices_method.requires_unconditional_logprobs:
input_ids = [[el[1] for el in subl] for subl in input_token_logprobs]
data = {
"input_ids": input_ids,
"sampling_params": {"max_new_tokens": 0},
"return_logprob": True,
}
obj = self._generate_http_request(s, data)
unconditional_token_logprobs = [
r["meta_info"]["input_token_logprobs"] for r in obj
]
else:
unconditional_token_logprobs = None

return choices_method(
choices=choices,
normalized_prompt_logprobs=normalized_prompt_logprobs,
input_token_logprobs=input_token_logprobs,
output_token_logprobs=output_token_logprobs,
unconditional_token_logprobs=unconditional_token_logprobs,
)

def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
Expand All @@ -273,6 +278,18 @@ def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
)
self._assert_success(res)

def _generate_http_request(self, s: StreamExecutor, data):
    """POST *data* to the runtime's /generate endpoint and return the parsed JSON body."""
    # Attach the executor's image payload (if any) to the request body first.
    self._add_images(s, data)
    endpoint = self.base_url + "/generate"
    response = http_request(
        endpoint,
        json=data,
        auth_token=self.auth_token,
        api_key=self.api_key,
        verify=self.verify,
    )
    # Raise on any non-success HTTP status before attempting to decode.
    self._assert_success(response)
    return response.json()

def _add_images(self, s: StreamExecutor, data):
if s.images_:
assert len(s.images_) == 1, "Only support one image."
Expand Down
Loading

0 comments on commit 2310247

Please sign in to comment.