From 7f24ea95c344ae85c6633d47083722ebc5377f07 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Wed, 18 Sep 2024 04:35:35 -0700
Subject: [PATCH] Fuse top_k and top_p in the sampler (#1457)

---
 docs/en/sampling_params.md                       |  1 +
 python/sglang/srt/layers/sampler.py              | 11 +++++++++--
 python/sglang/srt/model_executor/model_runner.py |  4 ++--
 3 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/docs/en/sampling_params.md b/docs/en/sampling_params.md
index 0e1c13e4bdf..690b206d3a7 100644
--- a/docs/en/sampling_params.md
+++ b/docs/en/sampling_params.md
@@ -23,6 +23,7 @@ class GenerateReqInput:
     # Whether to return logprobs.
     return_logprob: Optional[Union[List[bool], bool]] = None
     # The start location of the prompt for return_logprob.
+    # By default, this value is "-1", which means it will only return logprobs for output tokens.
     logprob_start_len: Optional[Union[List[int], int]] = None
     # The number of top logprobs to return.
     top_logprobs_num: Optional[Union[List[int], int]] = None
diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py
index 88ae1322a21..ad7f0a1f3c4 100644
--- a/python/sglang/srt/layers/sampler.py
+++ b/python/sglang/srt/layers/sampler.py
@@ -31,8 +31,11 @@ def forward(
         logits = logits.next_token_logits

         # Post process logits
+        logits = logits.contiguous()
         logits.div_(sampling_info.temperatures)
-        probs = logits[:] = torch.softmax(logits, dim=-1)
+        probs = torch.softmax(logits, dim=-1)
+        logits = None
+        del logits

         if torch.any(torch.isnan(probs)):
             logger.warning("Detected errors during sampling! NaN in the probability.")
@@ -53,7 +56,11 @@ def forward(
             )
         else:
             batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
-                probs, uniform_samples, sampling_info.top_ks, sampling_info.top_ps
+                probs,
+                uniform_samples,
+                sampling_info.top_ks,
+                sampling_info.top_ps,
+                filter_apply_order="joint",
             )

             if not torch.all(success):
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index dc8dcd4ed39..049a43840eb 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -400,8 +400,8 @@ def init_memory_pool(
         )

         self.req_to_token_pool = ReqToTokenPool(
-            max_num_reqs,
-            self.model_config.context_len + 8,
+            max_num_reqs + 1,
+            self.model_config.context_len + 4,
         )
         if (
             self.model_config.attention_arch == AttentionArch.MLA
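
Note: for readers unfamiliar with filter_apply_order="joint", the sketch below illustrates in plain
PyTorch what the fused (joint) top-k/top-p order this patch switches to means: both masks are computed
against the same sorted distribution and intersected, rather than applying top-k first and then top-p to
the renormalized survivors. This is a minimal reference sketch, not flashinfer's kernel; the helper name
joint_top_k_top_p_sample is hypothetical, and torch.multinomial stands in for the kernel's
uniform-sample-based rejection sampling (the uniform_samples argument in the patched call).

    # Illustrative sketch only -- hypothetical helper, not flashinfer's implementation.
    import torch

    def joint_top_k_top_p_sample(
        probs: torch.Tensor,   # [batch, vocab], rows already softmaxed
        top_k: torch.Tensor,   # [batch] int64, per-request k
        top_p: torch.Tensor,   # [batch] float, per-request p
    ) -> torch.Tensor:
        sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
        # Rank of each sorted token; rank 0 is the most probable token.
        ranks = torch.arange(probs.shape[-1], device=probs.device).expand_as(probs)
        keep_k = ranks < top_k.unsqueeze(-1)
        # Keep a token if the probability mass strictly before it is < top_p,
        # so the most probable token always survives.
        cum_before = sorted_probs.cumsum(dim=-1) - sorted_probs
        keep_p = cum_before < top_p.unsqueeze(-1)
        # Joint filtering: a token must pass BOTH filters on the same distribution.
        masked = torch.where(keep_k & keep_p, sorted_probs, torch.zeros_like(sorted_probs))
        masked = masked / masked.sum(dim=-1, keepdim=True)
        # The real kernel draws via rejection sampling from uniform samples;
        # torch.multinomial stands in for that here.
        choice = torch.multinomial(masked, num_samples=1)
        return sorted_idx.gather(-1, choice).squeeze(-1)

The joint order can admit a different token set than the sequential order, because sequential top-p is
evaluated on the already-truncated, renormalized top-k distribution rather than on the original one.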