Add max_prefill_num_token into server arguments (#133)
Ying1123 authored Feb 3, 2024
1 parent 67be11c commit e095b16
Showing 3 changed files with 12 additions and 2 deletions.
3 changes: 2 additions & 1 deletion python/sglang/srt/managers/router/model_rpc.py
@@ -82,7 +82,8 @@ def exposed_init_model(
         self.max_total_num_token = self.model_runner.max_total_num_token
         self.max_num_running_seq = self.max_total_num_token // 2
         self.max_prefill_num_token = max(
-            self.model_config.context_len, self.max_total_num_token // 6
+            self.model_config.context_len,
+            self.max_total_num_token // 6 if server_args.max_prefill_num_token is None else server_args.max_prefill_num_token,
         )
         self.int_token_logit_bias = torch.tensor(
             get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
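
For readers skimming the hunk above: the new argument only replaces the old one-sixth heuristic, and the result is still clamped to at least the model's context length. A minimal standalone sketch of that bound (the function name and the example numbers are illustrative, not part of the commit):

def effective_max_prefill_num_token(context_len, max_total_num_token, max_prefill_num_token=None):
    # Fall back to one sixth of the total token budget when no override is given,
    # then never go below the model's context length.
    requested = max_total_num_token // 6 if max_prefill_num_token is None else max_prefill_num_token
    return max(context_len, requested)

# Example: a 4096-token-context model with a 60000-token KV pool and no override
# gets max(4096, 10000) = 10000 tokens per prefill batch; an explicit 2048 would
# still be raised to 4096 by the clamp.
assert effective_max_prefill_num_token(4096, 60000) == 10000
assert effective_max_prefill_num_token(4096, 60000, 2048) == 4096
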
4 changes: 3 additions & 1 deletion python/sglang/srt/server.py
@@ -430,7 +430,8 @@ def __init__(
         load_format: str = "auto",
         tokenizer_mode: str = "auto",
         trust_remote_code: bool = True,
-        mem_fraction_static: float = 0.9,
+        mem_fraction_static: float = ServerArgs.mem_fraction_static,
+        max_prefill_num_token: int = ServerArgs.max_prefill_num_token,
         tp_size: int = 1,
         model_mode: List[str] = (),
         schedule_heuristic: str = "lpm",
@@ -451,6 +452,7 @@
             tokenizer_mode=tokenizer_mode,
             trust_remote_code=trust_remote_code,
             mem_fraction_static=mem_fraction_static,
+            max_prefill_num_token=max_prefill_num_token,
             tp_size=tp_size,
             model_mode=model_mode,
             schedule_heuristic=schedule_heuristic,
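
Because the new keyword is forwarded straight into ServerArgs, the cap can also be set when constructing the server programmatically. A hedged sketch, assuming the constructor above is sglang.srt.server.Runtime and that model_path is one of its other parameters (neither is visible in this hunk):

from sglang.srt.server import Runtime

# Cap each prefill batch at 8192 tokens; everything else keeps its default.
runtime = Runtime(
    model_path="meta-llama/Llama-2-7b-chat-hf",
    max_prefill_num_token=8192,
)
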
7 changes: 7 additions & 0 deletions python/sglang/srt/server_args.py
@@ -15,6 +15,7 @@ class ServerArgs:
     chat_template: Optional[str] = None
     trust_remote_code: bool = True
     mem_fraction_static: Optional[float] = None
+    max_prefill_num_token: Optional[int] = None
     tp_size: int = 1
     model_mode: List[str] = ()
     schedule_heuristic: str = "lpm"
@@ -109,6 +110,12 @@ def add_cli_args(parser: argparse.ArgumentParser):
             default=ServerArgs.mem_fraction_static,
             help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
         )
+        parser.add_argument(
+            "--max-prefill-num-token",
+            type=int,
+            default=ServerArgs.max_prefill_num_token,
+            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length."
+        )
         parser.add_argument(
             "--tp-size",
             type=int,
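
Since the option goes through argparse like the flags around it, it can also be supplied at launch time. An illustrative invocation, assuming the usual sglang.launch_server entry point and an arbitrary model path (both outside this diff):

python -m sglang.launch_server \
    --model-path meta-llama/Llama-2-7b-chat-hf \
    --max-prefill-num-token 8192
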
