Skip to content

Commit

Permalink
Fix the case when max_new_tokens is too large (#1025)
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy authored Aug 11, 2024
1 parent 7b6a533 commit d785412
Showing 1 changed file with 22 additions and 6 deletions.
28 changes: 22 additions & 6 deletions python/sglang/srt/managers/policy_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,16 @@
import random
from collections import defaultdict
from contextlib import contextmanager
from typing import Dict, List
from typing import Dict, List, Optional

from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.radix_cache import TreeNode

# Clip the max new tokens for requests whose max_new_tokens is very large.
# Without clipping, such a request would reserve its full max_new_tokens budget,
# making the server's admission decisions overly conservative.
CLIP_MAX_NEW_TOKENS = 4096


class PolicyScheduler:
def __init__(self, policy: str, tree_cache: BasePrefixCache):
Expand Down Expand Up @@ -98,7 +102,7 @@ def __init__(
tree_cache: BasePrefixCache,
rem_total_tokens: int,
rem_input_tokens: int,
rem_chunk_tokens: int,
rem_chunk_tokens: Optional[int],
):
self.tree_cache = tree_cache
self.rem_total_tokens = rem_total_tokens
Expand Down Expand Up @@ -126,7 +130,11 @@ def remove_running_tokens(
):
self.rem_total_tokens -= sum(
[
(r.sampling_params.max_new_tokens - len(r.output_ids)) * new_token_ratio
min(
(r.sampling_params.max_new_tokens - len(r.output_ids)),
CLIP_MAX_NEW_TOKENS,
)
* new_token_ratio
for r in running_batch.reqs
]
)
Expand All @@ -151,7 +159,11 @@ def add_inflight_req(self, req: Req):
self._prefill_one_req(
len(req.prefix_indices),
req.extend_input_len,
req.sampling_params.max_new_tokens if not truncated else 0,
(
min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
if not truncated
else 0
),
)

# Return if chunked prefill not finished
Expand All @@ -168,7 +180,9 @@ def _lock_node(self, last_node: TreeNode):
self.rem_total_tokens += delta

def add_one_req(self, req: Req):
total_tokens = req.extend_input_len + req.sampling_params.max_new_tokens
total_tokens = req.extend_input_len + min(
req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS
)
input_tokens = req.extend_input_len
prefix_len = len(req.prefix_indices)

Expand All @@ -191,7 +205,9 @@ def add_one_req(self, req: Req):
self.can_run_list.append(req)
self.tree_cache.inc_lock_ref(req.last_node)
self._prefill_one_req(
prefix_len, input_tokens, req.sampling_params.max_new_tokens
prefix_len,
input_tokens,
min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
)
else:
# Chunked prefill
Expand Down

0 comments on commit d785412

Please sign in to comment.