From daecfd38057b9d7ef9438f855293171f7b96422f Mon Sep 17 00:00:00 2001 From: Ying Sheng Date: Mon, 1 Jul 2024 00:24:13 +0000 Subject: [PATCH] update run_batch interface and max_prefill_tokens --- docs/test_process.md | 9 ++------- python/sglang/lang/ir.py | 14 +++++++++++++- python/sglang/srt/managers/controller/tp_worker.py | 10 ++++------ 3 files changed, 19 insertions(+), 14 deletions(-) diff --git a/docs/test_process.md b/docs/test_process.md index 1c1bc0f2c5..e7aff5b5a0 100644 --- a/docs/test_process.md +++ b/docs/test_process.md @@ -1,13 +1,8 @@ ## SRT Unit Tests -### Low-level API +### Latency Alignment ``` -cd sglang/test/srt/model - -python3 test_llama_low_api.py -python3 test_llama_extend.py -python3 test_llava_low_api.py -python3 bench_llama_low_api.py +python -m sglang.bench_latency --model-path meta-llama/Llama-2-7b-chat-hf --mem-fraction-static 0.8 --batch 32 --input-len 512 --output-len 256 ``` ### High-level API diff --git a/python/sglang/lang/ir.py b/python/sglang/lang/ir.py index 0567689e0e..eb7d87c24d 100644 --- a/python/sglang/lang/ir.py +++ b/python/sglang/lang/ir.py @@ -120,6 +120,7 @@ def __init__(self, func, num_api_spec_tokens=None, bind_arguments=None): argspec = inspect.getfullargspec(func) assert argspec.args[0] == "s", 'The first argument must be "s"' self.arg_names = argspec.args[1:] + self.arg_defaults = argspec.defaults if argspec.defaults is not None else [] def bind(self, **kwargs): assert all(key in self.arg_names for key in kwargs) @@ -178,7 +179,18 @@ def run_batch( assert isinstance(batch_kwargs, (list, tuple)) if len(batch_kwargs) == 0: return [] - assert isinstance(batch_kwargs[0], dict) + if not isinstance(batch_kwargs[0], dict): + num_programs = len(batch_kwargs) + # change the list of argument values to dict of arg_name -> arg_value + batch_kwargs = [ + {self.arg_names[i]: v for i, v in enumerate(arg_values)} + for arg_values in batch_kwargs + if isinstance(arg_values, (list, tuple)) and + len(self.arg_names) - len(self.arg_defaults) <= len(arg_values) <= len(self.arg_names) + ] + # Ensure to raise an exception if the number of arguments mismatch + if len(batch_kwargs) != num_programs: + raise Exception("Given arguments mismatch the SGL function signature") default_sampling_para = SglSamplingParams( max_new_tokens=max_new_tokens, diff --git a/python/sglang/srt/managers/controller/tp_worker.py b/python/sglang/srt/managers/controller/tp_worker.py index 7ee1e50794..a788118ec9 100644 --- a/python/sglang/srt/managers/controller/tp_worker.py +++ b/python/sglang/srt/managers/controller/tp_worker.py @@ -98,10 +98,7 @@ def __init__( ) self.max_total_num_tokens = self.model_runner.max_total_num_tokens self.max_prefill_tokens = ( - max( - self.model_config.context_len, - min(self.max_total_num_tokens // 6, 32768), - ) + 4096 if server_args.max_prefill_tokens is None else server_args.max_prefill_tokens ) @@ -371,8 +368,9 @@ def get_new_fill_batch(self) -> Optional[Batch]: if ( req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens < available_size - and req.extend_input_len + new_batch_input_tokens - < self.max_prefill_tokens + and (req.extend_input_len + new_batch_input_tokens + <= self.max_prefill_tokens + or len(can_run_list) == 0) ): delta = self.tree_cache.inc_lock_ref(req.last_node) available_size += delta