Fuse top_k and top_p in the sampler (#1457)
merrymercy authored Sep 18, 2024
1 parent 1acccb3 commit 7f24ea9
Showing 3 changed files with 12 additions and 4 deletions.
1 change: 1 addition & 0 deletions docs/en/sampling_params.md
@@ -23,6 +23,7 @@ class GenerateReqInput:
     # Whether to return logprobs.
     return_logprob: Optional[Union[List[bool], bool]] = None
     # The start location of the prompt for return_logprob.
+    # By default, this value is "-1", which means it will only return logprobs for output tokens.
     logprob_start_len: Optional[Union[List[int], int]] = None
     # The number of top logprobs to return.
     top_logprobs_num: Optional[Union[List[int], int]] = None
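For orientation, here is a minimal request sketch showing where these fields sit in practice. The payload shape follows the GenerateReqInput fields above; the server address, prompt, and concrete values are illustrative, not part of the commit:

import requests

# Hypothetical local sglang server; payload keys mirror GenerateReqInput.
resp = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 8},
        "return_logprob": True,
        # logprob_start_len omitted -> documented default of -1,
        # i.e. logprobs are returned for output tokens only.
        "top_logprobs_num": 5,
    },
)
print(resp.json())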
11 changes: 9 additions & 2 deletions python/sglang/srt/layers/sampler.py
@@ -31,8 +31,11 @@ def forward(
         logits = logits.next_token_logits

         # Post process logits
         logits = logits.contiguous()
         logits.div_(sampling_info.temperatures)
-        probs = logits[:] = torch.softmax(logits, dim=-1)
+        probs = torch.softmax(logits, dim=-1)
+        logits = None
+        del logits

         if torch.any(torch.isnan(probs)):
             logger.warning("Detected errors during sampling! NaN in the probability.")
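This first hunk is a memory cleanup rather than part of the fusion itself: the old chained assignment allocated the softmax result and then also copied it back into the logits buffer, keeping both tensors alive. The rewrite binds only probs and drops the logits reference, so the scaled-logits buffer can be reclaimed before sampling.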
@@ -53,7 +56,11 @@ def forward(
             )
         else:
             batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
-                probs, uniform_samples, sampling_info.top_ks, sampling_info.top_ps
+                probs,
+                uniform_samples,
+                sampling_info.top_ks,
+                sampling_info.top_ps,
+                filter_apply_order="joint",
             )

         if not torch.all(success):
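top_k_top_p_sampling_from_probs is the fused sampling kernel (in sglang's setup it comes from flashinfer), and filter_apply_order="joint" asks it to evaluate the top-k and top-p cutoffs against the same sorted distribution in one pass, rather than applying top-k first and re-deriving top-p on the renormalized result. Below is a rough pure-PyTorch sketch of the joint semantics only; it is not the fused kernel, and it samples with torch.multinomial instead of the kernel's uniform-sample rejection scheme:

import torch

def joint_top_k_top_p_sample(
    probs: torch.Tensor,   # (batch, vocab), rows sum to 1
    top_k: torch.Tensor,   # (batch,) integer cutoffs
    top_p: torch.Tensor,   # (batch,) nucleus cutoffs in (0, 1]
) -> torch.Tensor:
    # Sort once; both cutoffs are evaluated on the same sorted distribution.
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    cum = sorted_probs.cumsum(dim=-1)
    # Nucleus filter: keep tokens whose preceding cumulative mass is < top_p.
    keep_p = (cum - sorted_probs) < top_p.unsqueeze(-1)
    # Top-k filter: keep the first top_k sorted positions.
    ranks = torch.arange(probs.shape[-1], device=probs.device)
    keep_k = ranks < top_k.unsqueeze(-1)
    # "joint": a token survives only if it passes *both* filters at once.
    keep = keep_p & keep_k
    mask = torch.zeros_like(probs, dtype=torch.bool).scatter(-1, sorted_idx, keep)
    filtered = torch.where(mask, probs, torch.zeros_like(probs))
    filtered = filtered / filtered.sum(dim=-1, keepdim=True)
    return torch.multinomial(filtered, num_samples=1).squeeze(-1)

# Example: batch of 2, vocab of 6.
probs = torch.softmax(torch.randn(2, 6), dim=-1)
ids = joint_top_k_top_p_sample(probs, torch.tensor([3, 2]), torch.tensor([0.9, 0.8]))

With the sequential order, top-p is measured on the distribution after top-k renormalization, so the two orders can admit different token sets; doing both cutoffs in a single fused pass also avoids a second sort and renormalization over the vocabulary.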
4 changes: 2 additions & 2 deletions python/sglang/srt/model_executor/model_runner.py
@@ -400,8 +400,8 @@ def init_memory_pool(
         )

         self.req_to_token_pool = ReqToTokenPool(
-            max_num_reqs,
-            self.model_config.context_len + 8,
+            max_num_reqs + 1,
+            self.model_config.context_len + 4,
         )
         if (
             self.model_config.attention_arch == AttentionArch.MLA
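Broadly, ReqToTokenPool maps request slots to per-position token-cache locations; the new arguments give the pool one spare slot beyond max_num_reqs and trim the per-request padding from 8 to 4 positions past the model's context length. A toy sketch of the shape arithmetic only, using an illustrative stand-in class rather than sglang's actual implementation:

import torch

class ToyReqToTokenPool:
    """Illustrative only: a 2-D index map from request slots to
    per-position cache locations, mirroring the shapes in the diff."""

    def __init__(self, size: int, max_context_len: int):
        # One row per request slot, one column per token position.
        self.req_to_token = torch.zeros((size, max_context_len), dtype=torch.int32)
        self.free_slots = list(range(size))

# After this commit: one spare slot, and 4 (rather than 8) positions of
# padding past the context length. Concrete sizes here are made up.
max_num_reqs, context_len = 256, 4096
pool = ToyReqToTokenPool(max_num_reqs + 1, context_len + 4)
print(pool.req_to_token.shape)  # torch.Size([257, 4100])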
