Update run_batch interface and max_prefill_tokens #574

Merged · 1 commit · Jul 1, 2024
9 changes: 2 additions & 7 deletions docs/test_process.md

````diff
@@ -1,13 +1,8 @@
 ## SRT Unit Tests
 
-### Low-level API
+### Latency Alignment
 ```
-cd sglang/test/srt/model
-
-python3 test_llama_low_api.py
-python3 test_llama_extend.py
-python3 test_llava_low_api.py
-python3 bench_llama_low_api.py
+python -m sglang.bench_latency --model-path meta-llama/Llama-2-7b-chat-hf --mem-fraction-static 0.8 --batch 32 --input-len 512 --output-len 256
 ```
 
 ### High-level API
````
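With these flags the benchmark measures a prefill of 32 × 512 = 16,384 prompt tokens followed by 256 decode steps per request against Llama-2-7b-chat.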
14 changes: 13 additions & 1 deletion python/sglang/lang/ir.py

```diff
@@ -120,6 +120,7 @@ def __init__(self, func, num_api_spec_tokens=None, bind_arguments=None):
         argspec = inspect.getfullargspec(func)
         assert argspec.args[0] == "s", 'The first argument must be "s"'
         self.arg_names = argspec.args[1:]
+        self.arg_defaults = argspec.defaults if argspec.defaults is not None else []
 
     def bind(self, **kwargs):
         assert all(key in self.arg_names for key in kwargs)
```
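For context on the new `arg_defaults` field: `inspect.getfullargspec` aligns `defaults` with the trailing parameters of the function, so `len(self.arg_names) - len(self.arg_defaults)` gives the minimum number of positional values a caller must supply. A minimal stdlib-only illustration (the function `f` is hypothetical):

```python
import inspect

def f(s, question, temperature=0.7, stop=None):
    pass

spec = inspect.getfullargspec(f)
print(spec.args)      # ['s', 'question', 'temperature', 'stop']
print(spec.defaults)  # (0.7, None), aligned with the last two parameters

arg_names = spec.args[1:]            # drop the leading "s", as the constructor above does
arg_defaults = spec.defaults if spec.defaults is not None else []
min_required = len(arg_names) - len(arg_defaults)
print(min_required)   # 1, i.e. only "question" must be supplied
```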
```diff
@@ -178,7 +179,18 @@ def run_batch(
         assert isinstance(batch_kwargs, (list, tuple))
         if len(batch_kwargs) == 0:
             return []
-        assert isinstance(batch_kwargs[0], dict)
+        if not isinstance(batch_kwargs[0], dict):
+            num_programs = len(batch_kwargs)
+            # Map each list of positional argument values to a dict of arg_name -> arg_value
+            batch_kwargs = [
+                {self.arg_names[i]: v for i, v in enumerate(arg_values)}
+                for arg_values in batch_kwargs
+                if isinstance(arg_values, (list, tuple)) and
+                len(self.arg_names) - len(self.arg_defaults) <= len(arg_values) <= len(self.arg_names)
+            ]
+            # Raise if any entry was dropped because its length does not match the signature
+            if len(batch_kwargs) != num_programs:
+                raise Exception("Given arguments mismatch the SGL function signature")
 
         default_sampling_para = SglSamplingParams(
             max_new_tokens=max_new_tokens,
```
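The net effect is that `run_batch` now accepts a list of positional argument tuples in addition to the existing list of keyword dicts, with trailing defaulted arguments optional. A minimal sketch of both call styles; the `qa` function and the endpoint URL are assumptions for illustration, not part of this PR:

```python
import sglang as sgl

# Assumes an SGLang server is already running at this hypothetical address
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

@sgl.function
def qa(s, question, hint="Answer briefly."):
    s += hint + "\nQ: " + question + "\nA: " + sgl.gen("answer", max_tokens=64)

# Existing style: one dict of keyword arguments per program
states = qa.run_batch([
    {"question": "What is 2 + 2?"},
    {"question": "Name a prime number.", "hint": "Answer with one word."},
])

# New style: positional values in declaration order; "hint" may be omitted
# because it has a default (required <= len(args) <= total is enforced)
states = qa.run_batch([
    ("What is 2 + 2?",),
    ("Name a prime number.", "Answer with one word."),
])

# A length outside that range now raises instead of failing silently:
# qa.run_batch([()])  # Exception: Given arguments mismatch the SGL function signature
```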
10 changes: 4 additions & 6 deletions python/sglang/srt/managers/controller/tp_worker.py

```diff
@@ -98,10 +98,7 @@ def __init__(
         )
         self.max_total_num_tokens = self.model_runner.max_total_num_tokens
         self.max_prefill_tokens = (
-            max(
-                self.model_config.context_len,
-                min(self.max_total_num_tokens // 6, 32768),
-            )
+            4096
             if server_args.max_prefill_tokens is None
             else server_args.max_prefill_tokens
         )
```
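For a sense of scale, the removed heuristic tied the default prefill budget to the KV cache size, while the new default is a flat 4096 unless `server_args.max_prefill_tokens` is set. A quick before/after comparison:

```python
# Both numbers below are assumptions for illustration, not values from this PR
context_len = 4096              # model context window
max_total_num_tokens = 120_000  # KV-cache capacity reported by the model runner

old_default = max(context_len, min(max_total_num_tokens // 6, 32768))
new_default = 4096  # flat default unless server_args.max_prefill_tokens is set

print(old_default, new_default)  # 20000 4096
```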
```diff
@@ -371,8 +368,9 @@ def get_new_fill_batch(self) -> Optional[Batch]:
             if (
                 req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                 < available_size
-                and req.extend_input_len + new_batch_input_tokens
-                < self.max_prefill_tokens
+                and (req.extend_input_len + new_batch_input_tokens
+                <= self.max_prefill_tokens
+                or len(can_run_list) == 0)
             ):
                 delta = self.tree_cache.inc_lock_ref(req.last_node)
                 available_size += delta
```
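The switch from `<` to `<=` plus the new `or len(can_run_list) == 0` clause guarantees forward progress: a request whose extend length alone exceeds `max_prefill_tokens` can still be admitted as the sole request of a fill batch instead of being skipped forever. A simplified stand-alone sketch of just the token-budget part of that admission rule (not the actual scheduler code):

```python
from typing import List

def admit(extend_lens: List[int], max_prefill_tokens: int) -> List[int]:
    """Greedy token-budget admission for one fill batch (simplified)."""
    can_run_list: List[int] = []
    new_batch_input_tokens = 0
    for n in extend_lens:
        # Admit while the budget holds, or unconditionally when the batch is
        # still empty -- otherwise an oversized request would starve forever.
        if (n + new_batch_input_tokens <= max_prefill_tokens
                or len(can_run_list) == 0):
            can_run_list.append(n)
            new_batch_input_tokens += n
    return can_run_list

print(admit([6000, 100], 4096))  # [6000]: oversized request runs alone
print(admit([100, 300], 4096))   # [100, 300]: both fit within the budget
```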