Commit f95e661
Fix max_tokens for OpenAI chat completion API (#766)
merrymercy authored Jul 27, 2024
1 parent de854fb commit f95e661
Showing 3 changed files with 24 additions and 21 deletions.
34 changes: 18 additions & 16 deletions python/sglang/srt/managers/controller/tp_worker.py
@@ -98,17 +98,21 @@ def __init__(
             if server_args.max_prefill_tokens is None
             else server_args.max_prefill_tokens
         )
-        self.max_running_requests = (
-            self.max_total_num_tokens // 2
-            if server_args.max_running_requests is None
-            else server_args.max_running_requests
-        )
         self.max_running_requests = min(
-            self.max_running_requests, self.model_runner.req_to_token_pool.size - 1
+            (
+                self.max_total_num_tokens // 2
+                if server_args.max_running_requests is None
+                else server_args.max_running_requests
+            ),
+            self.model_runner.req_to_token_pool.size - 1,
         )
         self.int_token_logit_bias = torch.tensor(
             get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
         )
+        self.max_req_input_len = min(
+            self.model_config.context_len - 1,
+            self.max_total_num_tokens - 1,
+        )
         set_random_seed(server_args.random_seed)

         # Print info
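
A note on the two new limits above: both are plain min() clamps. The sketch below redoes the arithmetic with made-up numbers standing in for ServerArgs, ModelConfig, and the req_to_token pool (none of these values come from the commit):

    # Hypothetical configuration values, chosen only for illustration.
    max_total_num_tokens = 32768      # assumed KV cache capacity, in tokens
    req_to_token_pool_size = 256      # assumed number of request slots
    context_len = 4096                # assumed model context window
    max_running_requests_arg = None   # --max-running-requests left unset

    # Concurrency is capped by the token budget and by the request slots.
    max_running_requests = min(
        (
            max_total_num_tokens // 2
            if max_running_requests_arg is None
            else max_running_requests_arg
        ),
        req_to_token_pool_size - 1,
    )
    print(max_running_requests)  # 255

    # A single prompt may not exceed the context window or the KV cache.
    max_req_input_len = min(context_len - 1, max_total_num_tokens - 1)
    print(max_req_input_len)  # 4095
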
@@ -295,18 +299,16 @@ def handle_generate_request(
         )

         # Truncate prompts that are too long
-        req.origin_input_ids = req.origin_input_ids[: self.model_config.context_len - 1]
+        if len(req.origin_input_ids) >= self.max_req_input_len:
+            logger.warn(
+                "Request length is longer than the KV cache pool size or "
+                "the max context length. Truncated!!!"
+            )
+            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
         req.sampling_params.max_new_tokens = min(
-            req.sampling_params.max_new_tokens,
-            self.model_config.context_len - 1 - len(req.origin_input_ids),
-            self.max_total_num_tokens - 128 - len(req.origin_input_ids),
+            req.sampling_params.max_new_tokens or 1 << 30,
+            self.max_req_input_len - 1 - len(req.origin_input_ids),
         )
-        if req.sampling_params.max_new_tokens < 0:
-            req.origin_input_ids = req.origin_input_ids[
-                : self.max_total_num_tokens - 128
-            ]
-            logger.error("Request longer than memory pool size, truncated!!!")

         self.forward_queue.append(req)

     def get_new_prefill_batch(self) -> Optional[Batch]:
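The request-time path now has to handle max_new_tokens arriving as None, since the OpenAI chat endpoint no longer injects a default of 16 (see the protocol.py change below). A hedged, stand-alone sketch of the same arithmetic with hypothetical numbers:

    max_req_input_len = 4095               # carried over from the sketch above
    origin_input_ids = list(range(1000))   # a made-up 1000-token prompt
    requested_max_new_tokens = None        # chat request that omitted max_tokens

    # Over-long prompts are truncated up front instead of failing later.
    if len(origin_input_ids) >= max_req_input_len:
        origin_input_ids = origin_input_ids[:max_req_input_len]

    # None falls through to a large sentinel (1 << 30), so the effective cap
    # is simply the room left next to the prompt.
    max_new_tokens = min(
        requested_max_new_tokens or 1 << 30,
        max_req_input_len - 1 - len(origin_input_ids),
    )
    print(max_new_tokens)  # 3094
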
2 changes: 1 addition & 1 deletion python/sglang/srt/openai_api/protocol.py
@@ -152,7 +152,7 @@ class ChatCompletionRequest(BaseModel):
     logit_bias: Optional[Dict[str, float]] = None
     logprobs: Optional[bool] = False
     top_logprobs: Optional[int] = None
-    max_tokens: Optional[int] = 16
+    max_tokens: Optional[int] = None
     n: Optional[int] = 1
     presence_penalty: Optional[float] = 0.0
     response_format: Optional[ResponseFormat] = None
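This one-line default change is the fix the commit title refers to: a chat completion request that omits max_tokens is no longer silently capped at 16 generated tokens. A minimal sketch of the effect, relying only on standard Pydantic field defaults (ChatRequestStub is a cut-down stand-in, not the real ChatCompletionRequest):

    from typing import Optional
    from pydantic import BaseModel

    class ChatRequestStub(BaseModel):
        # Mirrors the changed field: the default is now None instead of 16.
        max_tokens: Optional[int] = None
        n: Optional[int] = 1

    req = ChatRequestStub(n=1)     # client did not send max_tokens
    assert req.max_tokens is None  # with the old default this would have been 16
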
9 changes: 5 additions & 4 deletions python/sglang/srt/sampling_params.py
@@ -65,10 +65,11 @@ def verify(self):
             raise ValueError(
                 "presence_penalty must be in [-2, 2], got " f"{self.presence_penalty}."
             )
-        if self.max_new_tokens < 0:
-            raise ValueError(
-                f"max_new_tokens must be at least 0, got {self.max_new_tokens}."
-            )
+        if self.max_new_tokens is not None:
+            if self.max_new_tokens < 0:
+                raise ValueError(
+                    f"max_new_tokens must be at least 0, got {self.max_new_tokens}."
+                )

     def normalize(self, tokenizer):
         # Process stop strings
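verify() has to tolerate the None that can now reach it from the chat endpoint. A small stand-alone sketch of the relaxed check (SamplingParamsStub is an illustration, not the sglang class):

    class SamplingParamsStub:
        def __init__(self, max_new_tokens=None):
            self.max_new_tokens = max_new_tokens

        def verify(self):
            # None means "no explicit cap"; only concrete negatives are rejected.
            if self.max_new_tokens is not None:
                if self.max_new_tokens < 0:
                    raise ValueError(
                        f"max_new_tokens must be at least 0, got {self.max_new_tokens}."
                    )

    SamplingParamsStub().verify()                   # OK: None is skipped
    SamplingParamsStub(max_new_tokens=32).verify()  # OK
    try:
        SamplingParamsStub(max_new_tokens=-1).verify()
    except ValueError as err:
        print(err)  # max_new_tokens must be at least 0, got -1.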
