Support min-p sampling #1167

Merged 3 commits on Aug 21, 2024
2 changes: 2 additions & 0 deletions docs/en/sampling_params.md
@@ -45,6 +45,8 @@ temperature: float = 1.0,
top_p: float = 1.0,
# Top-k sampling
top_k: int = -1,
# Min-p sampling
min_p: float = 0.0,
# Whether to ignore EOS token.
ignore_eos: bool = False,
# Whether to skip the special tokens during detokenization.
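For context, `min_p` is threaded through the frontend `gen` calls in `python/sglang/api.py` below. A minimal usage sketch (the endpoint URL and prompt are illustrative and assume a running SGLang server; `min_p=0.05` keeps only tokens whose probability is at least 5% of the most likely token's probability):

```python
import sglang as sgl

# Hypothetical endpoint; point this at your own running SGLang server.
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

@sgl.function
def tell_joke(s):
    s += "Tell me a short joke about programming.\n"
    # min_p keeps only tokens whose probability is at least
    # min_p * P(most likely token); min_p=0.0 disables the filter.
    s += sgl.gen("joke", max_tokens=64, temperature=0.8, min_p=0.05)

state = tell_joke.run()
print(state["joke"])
```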
6 changes: 6 additions & 0 deletions python/sglang/api.py
@@ -66,6 +66,7 @@ def gen(
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
min_p: Optional[float] = None,
frequency_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
ignore_eos: Optional[bool] = None,
@@ -103,6 +104,7 @@ def gen(
temperature,
top_p,
top_k,
min_p,
frequency_penalty,
presence_penalty,
ignore_eos,
@@ -123,6 +125,7 @@ def gen_int(
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
min_p: Optional[float] = None,
frequency_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
ignore_eos: Optional[bool] = None,
@@ -139,6 +142,7 @@ def gen_int(
temperature,
top_p,
top_k,
min_p,
frequency_penalty,
presence_penalty,
ignore_eos,
@@ -159,6 +163,7 @@ def gen_string(
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
min_p: Optional[float] = None,
frequency_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
ignore_eos: Optional[bool] = None,
@@ -175,6 +180,7 @@ def gen_string(
temperature,
top_p,
top_k,
min_p,
frequency_penalty,
presence_penalty,
ignore_eos,
4 changes: 4 additions & 0 deletions python/sglang/lang/compiler.py
@@ -130,6 +130,7 @@ def run(
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
min_p: float = 0.0,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
backend=None,
@@ -145,6 +146,7 @@ def run(
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
)
@@ -160,6 +162,7 @@ def run_batch(
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
min_p: float = 0.0,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
backend=None,
@@ -178,6 +181,7 @@ def run_batch(
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
)
1 change: 1 addition & 0 deletions python/sglang/lang/interpreter.py
@@ -663,6 +663,7 @@ def _resolve_sampling_params(self, sampling_params):
"temperature",
"top_p",
"top_k",
"min_p",
"frequency_penalty",
"presence_penalty",
"ignore_eos",
9 changes: 9 additions & 0 deletions python/sglang/lang/ir.py
@@ -22,6 +22,7 @@ class SglSamplingParams:
temperature: float = 1.0
top_p: float = 1.0
top_k: int = -1 # -1 means disable
min_p: float = 0.0
frequency_penalty: float = 0.0
presence_penalty: float = 0.0
ignore_eos: bool = False
@@ -42,6 +43,7 @@ def clone(self):
self.temperature,
self.top_p,
self.top_k,
self.min_p,
self.frequency_penalty,
self.presence_penalty,
self.ignore_eos,
@@ -114,6 +116,7 @@ def to_srt_kwargs(self):
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
"min_p": self.min_p,
"frequency_penalty": self.frequency_penalty,
"presence_penalty": self.presence_penalty,
"ignore_eos": self.ignore_eos,
@@ -149,6 +152,7 @@ def run(
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
min_p: float = 0.0,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
ignore_eos: bool = False,
@@ -169,6 +173,7 @@ def run(
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
ignore_eos=ignore_eos,
@@ -190,6 +195,7 @@ def run_batch(
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
min_p: float = 0.0,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
ignore_eos: bool = False,
@@ -228,6 +234,7 @@ def run_batch(
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
ignore_eos=ignore_eos,
@@ -408,6 +415,7 @@ def __init__(
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
min_p: Optional[float] = None,
frequency_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
ignore_eos: Optional[bool] = None,
@@ -428,6 +436,7 @@ def __init__(
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
ignore_eos=ignore_eos,
41 changes: 32 additions & 9 deletions python/sglang/srt/managers/schedule_batch.py
@@ -21,7 +21,12 @@

import torch
import torch.distributed as dist
from flashinfer.sampling import top_k_top_p_sampling_from_probs
from flashinfer.sampling import (
min_p_sampling_from_probs,
top_k_renorm_prob,
top_k_top_p_sampling_from_probs,
top_p_renorm_prob,
)
from vllm.distributed import get_tensor_model_parallel_group

import sglang.srt.sampling.penaltylib as penaltylib
@@ -339,6 +344,7 @@ class ScheduleBatch:
temperatures: torch.Tensor = None
top_ps: torch.Tensor = None
top_ks: torch.Tensor = None
min_ps: torch.Tensor = None
penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
logit_bias: torch.Tensor = None

@@ -403,6 +409,9 @@ def batch_sampling_params(self, vocab_size):
self.top_ks = torch.tensor(
[r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
)
self.min_ps = torch.tensor(
[r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
)

# Each penalizers will do nothing if they evaluate themselves as not required by looking at
# the sampling_params of the requests (See {_is_required()} of each penalizers). So this
@@ -701,6 +710,7 @@ def filter_batch(self, unfinished_indices: List[int]):
"temperatures",
"top_ps",
"top_ks",
"min_ps",
"logit_bias",
]:
self_val = getattr(self, item, None)
@@ -730,6 +740,7 @@ def merge(self, other: "ScheduleBatch"):
"temperatures",
"top_ps",
"top_ks",
"min_ps",
]:
self_val = getattr(self, item, None)
other_val = getattr(other, item, None)
@@ -780,13 +791,20 @@ def sample(self, logits: torch.Tensor):
uniform_samples = torch.rand(
(max_top_k_round, batch_size), device=probs.device
)
batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
probs, uniform_samples, self.top_ks, self.top_ps
)
if self.min_ps.any():
probs = top_k_renorm_prob(probs, self.top_ks)
probs = top_p_renorm_prob(probs, self.top_ps)
batch_next_token_ids, success = min_p_sampling_from_probs(
probs, uniform_samples, self.min_ps
)
else:
batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
probs, uniform_samples, self.top_ks, self.top_ps
)
else:
# Here we provide a slower fallback implementation.
batch_next_token_ids, success = top_k_top_p_sampling_from_probs_torch(
probs, self.top_ks, self.top_ps
batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
probs, self.top_ks, self.top_ps, self.min_ps
)

if not torch.all(success):
@@ -810,17 +828,22 @@ def sample(self, logits: torch.Tensor):
return batch_next_token_ids


def top_k_top_p_sampling_from_probs_torch(
probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor
def top_k_top_p_min_p_sampling_from_probs_torch(
probs: torch.Tensor,
top_ks: torch.Tensor,
top_ps: torch.Tensor,
min_ps: torch.Tensor,
):
"""A top-k and top-k sampling implementation with native pytorch operations."""
"""A top-k, top-p and min-p sampling implementation with native pytorch operations."""
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
min_p_thresholds = probs_sort[:, 0] * min_ps
probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
probs_sort[
torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
>= top_ks.view(-1, 1)
] = 0.0
probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
try:
sampled_index = torch.multinomial(probs_sort, num_samples=1)
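To make the min-p step in the PyTorch fallback above concrete, here is a self-contained sketch for a single, already-softmaxed distribution (batching, top-k, and top-p are omitted, and the numbers are illustrative): tokens below `min_p` times the top token's probability are zeroed out before multinomial sampling.

```python
import torch

probs = torch.tensor([0.50, 0.30, 0.15, 0.04, 0.01])  # illustrative distribution
min_p = 0.1

# min-p keeps tokens with probability >= min_p * max(probs).
threshold = probs.max() * min_p                      # 0.05 here
kept = torch.where(probs >= threshold, probs, torch.zeros_like(probs))
kept = kept / kept.sum()                             # renormalize before sampling

next_token = torch.multinomial(kept, num_samples=1).item()
print(threshold.item(), kept.tolist(), next_token)
```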
4 changes: 4 additions & 0 deletions python/sglang/srt/sampling_params.py
@@ -30,6 +30,7 @@ def __init__(
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
min_p: float = 0.0,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
repetition_penalty: float = 1.0,
@@ -42,6 +43,7 @@ def __init__(
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
self.min_p = min_p
self.frequency_penalty = frequency_penalty
self.presence_penalty = presence_penalty
self.repetition_penalty = repetition_penalty
@@ -69,6 +71,8 @@ def verify(self):
)
if not 0.0 < self.top_p <= 1.0:
raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
if not 0.0 <= self.min_p <= 1.0:
raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.")
if self.top_k < -1 or self.top_k == 0:
raise ValueError(
f"top_k must be -1 (disable), or at least 1, " f"got {self.top_k}."
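A hypothetical sanity check of the new argument and the range validation added above (assuming `SamplingParams` is importable as shown in the diff):

```python
from sglang.srt.sampling_params import SamplingParams

params = SamplingParams(temperature=0.8, top_p=0.95, min_p=0.05)
params.verify()  # passes: 0.0 <= min_p <= 1.0

bad = SamplingParams(min_p=1.5)
try:
    bad.verify()
except ValueError as err:
    print(err)  # min_p must be in [0, 1], got 1.5.
```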