
Commit

Merge branch 'sgl-project:main' into main
haichuan1221 authored Jul 30, 2024
2 parents 096e0ce + ae5c0fc commit 3d65cb2
Showing 45 changed files with 180 additions and 108 deletions.
47 changes: 47 additions & 0 deletions .github/workflows/pr-e2e-test.yml
@@ -0,0 +1,47 @@
name: PR E2E Test

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]
  workflow_dispatch:

jobs:
  gpu-job:
    runs-on: self-hosted
    env:
      CUDA_VISIBLE_DEVICES: 6

    steps:
      - name: Checkout code
        uses: actions/checkout@v3

      - name: Install dependencies
        run: |
          pip install --upgrade pip
          pip install -e "python[all]"
          pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3/ --force-reinstall
          pip install --upgrade transformers

      - name: Launch server and run benchmark
        run: |
          python3 -m sglang.launch_server --model /home/lmzheng/zhyncs/Meta-Llama-3.1-8B-Instruct --port 8413 &
          echo "Waiting for server to start..."
          for i in {1..60}; do
            if curl -s http://127.0.0.1:8413/health; then
              echo "Server is up!"
              break
            fi
            if [ $i -eq 60 ]; then
              echo "Server failed to start within 60 seconds"
              exit 1
            fi
            sleep 1
          done
          python3 -m sglang.bench_serving --backend sglang --port 8413
          echo "Stopping server..."
          kill -9 $(ps aux | grep sglang | grep Meta-Llama-3.1-8B-Instruct | grep -v grep | awk '{print $2}')
3 changes: 2 additions & 1 deletion README.md
@@ -54,7 +54,8 @@ pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3/

### Method 2: From source
```
git clone https://github.com/sgl-project/sglang.git
# Use the stable rel branch
git clone -b rel https://github.com/sgl-project/sglang.git
cd sglang
pip install --upgrade pip
6 changes: 3 additions & 3 deletions docs/en/hyperparameter_tuning.md
@@ -29,7 +29,7 @@ If OOM happens during prefill, try to decrease `--max-prefill-tokens`.
If OOM happens during decoding, try to decrease `--max-running-requests`.
You can also try to decrease `--mem-fraction-static`, which reduces the memory usage of the KV cache memory pool and helps both prefill and decoding.

### (Minor) Tune `--schedule-heuristic`
If you have many shared prefixes, use the default `--schedule-heuristic lpm`. `lpm` stands for longest prefix match.
### (Minor) Tune `--schedule-policy`
If you have many shared prefixes, use the default `--schedule-policy lpm`. `lpm` stands for longest prefix match.
When you have no shared prefixes at all or you always send the requests with the shared prefixes together,
you can try `--schedule-heuristic fcfs`. `fcfs` stands for first come first serve.
you can try `--schedule-policy fcfs`. `fcfs` stands for first come first serve.
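
As a quick illustration of the renamed flag (a sketch, not part of this commit's diff; the model path is a placeholder and the flags mirror the workflow above):

```bash
# Hypothetical launch using the new flag name; fcfs skips prefix-match ordering.
python3 -m sglang.launch_server --model /path/to/Meta-Llama-3.1-8B-Instruct --schedule-policy fcfs --port 8413
```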
61 changes: 31 additions & 30 deletions python/sglang/__init__.py
@@ -1,4 +1,5 @@
# SGL API Components

from sglang.api import (
Runtime,
assistant,
@@ -22,46 +23,46 @@
video,
)

# Global Configurations
from sglang.global_config import global_config

# SGL Backends
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.utils import LazyImport
from sglang.version import __version__

Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI")


# public APIs management
# SGLang DSL APIs
__all__ = [
"global_config",
"Anthropic",
"LiteLLM",
"OpenAI",
"RuntimeEndpoint",
"VertexAI",
"function",
"Runtime",
"set_default_backend",
"assistant",
"assistant_begin",
"assistant_end",
"flush_cache",
"get_server_args",
"function",
"gen",
"gen_int",
"gen_string",
"get_server_args",
"image",
"video",
"select",
"set_default_backend",
"system",
"system_begin",
"system_end",
"user",
"assistant",
"user_begin",
"user_end",
"assistant_begin",
"assistant_end",
"system_begin",
"system_end",
"video",
]

# Global Configurations
from sglang.global_config import global_config

__all__ += ["global_config"]

from sglang.version import __version__

__all__ += ["__version__"]

# SGL Backends
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.utils import LazyImport

Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI")

__all__ += ["Anthropic", "LiteLLM", "OpenAI", "VertexAI", "RuntimeEndpoint"]
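
A quick sanity check of the reorganized public API (a sketch, assuming the package is installed in the current environment):

```bash
# Sketch: verify the version and the re-exported names survive the import reordering.
python3 -c "import sglang; print(sglang.__version__); print(sorted(sglang.__all__))"
```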
4 changes: 2 additions & 2 deletions python/sglang/bench_latency.py
@@ -37,9 +37,9 @@
import torch.distributed as dist

from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, Req
from sglang.srt.managers.controller.model_runner import ModelRunner
from sglang.srt.managers.schedule_batch import Batch, ForwardMode, Req
from sglang.srt.model_config import ModelConfig
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import suppress_other_loggers
10 changes: 9 additions & 1 deletion python/sglang/bench_serving.py
@@ -84,6 +84,9 @@ async def async_request_trt_llm(
"min_length": request_func_input.output_len,
"end_id": 1048576,
}
if args.disable_ignore_eos:
del payload["min_length"]
del payload["end_id"]
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len

@@ -149,7 +152,7 @@ async def async_request_openai_completions(
"best_of": 1,
"max_tokens": request_func_input.output_len,
"stream": not args.disable_stream,
"ignore_eos": True,
"ignore_eos": not args.disable_ignore_eos,
}
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}

@@ -969,6 +972,11 @@ def set_ulimit(target_soft_limit=65535):
action="store_true",
help="Disable streaming mode.",
)
parser.add_argument(
"--disable-ignore-eos",
action="store_true",
help="Disable ignoring EOS.",
)

set_ulimit()

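
A hedged usage sketch for the new flag, reusing the benchmark invocation from the workflow above (assumes a server is already listening on port 8413):

```bash
# Sketch only: let generation stop at EOS instead of forcing the requested output length.
python3 -m sglang.bench_serving --backend sglang --port 8413 --disable-ignore-eos
```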
19 changes: 17 additions & 2 deletions python/sglang/lang/interpreter.py
@@ -553,6 +553,7 @@ def _execute_select(self, expr: SglSelect):
"output_token_logprobs": output_token_logprobs,
}
self.variable_event[name].set()
self.stream_var_event[name].set()
self.text_ += decision

def _execute_variable(self, expr: SglVariable):
@@ -778,7 +779,14 @@ def text_iter(self, var_name: Optional[str] = None):
if self.stream_executor.is_finished:
break
else:
event = self.stream_executor.stream_var_event[var_name]
event = None
while not event:
if var_name in self.stream_executor.stream_var_event:
event = self.stream_executor.stream_var_event[var_name]
if self.stream_executor.is_finished:
yield ""
return

while True:
event.wait()
event.clear()
@@ -813,7 +821,14 @@ async def text_async_iter(
if self.stream_executor.is_finished:
break
else:
event = self.stream_executor.stream_var_event[var_name]
event = None
while not event:
if var_name in self.stream_executor.stream_var_event:
event = self.stream_executor.stream_var_event[var_name]
if self.stream_executor.is_finished:
yield ""
return

while True:
await loop.run_in_executor(None, event.wait)
event.clear()
2 changes: 1 addition & 1 deletion python/sglang/srt/layers/logits_processor.py
@@ -25,7 +25,7 @@
tensor_model_parallel_all_gather,
)

from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
from sglang.srt.model_executor.model_runner import ForwardMode, InputMetadata


@dataclasses.dataclass
2 changes: 1 addition & 1 deletion python/sglang/srt/layers/radix_attention.py
@@ -22,7 +22,7 @@
from sglang.global_config import global_config
from sglang.srt.layers.extend_attention import extend_attention_fwd
from sglang.srt.layers.token_attention import token_attention_fwd
from sglang.srt.managers.controller.model_runner import (
from sglang.srt.model_executor.model_runner import (
ForwardMode,
InputMetadata,
global_server_args_dict,
2 changes: 1 addition & 1 deletion python/sglang/srt/layers/token_attention.py
@@ -20,7 +20,7 @@
import triton
import triton.language as tl

from sglang.srt.managers.controller.infer_batch import global_server_args_dict
from sglang.srt.managers.schedule_batch import global_server_args_dict

if global_server_args_dict.get("attention_reduce_in_fp32", False):
REDUCE_TRITON_TYPE = tl.float32
@@ -27,7 +27,7 @@
import numpy as np
import zmq

from sglang.srt.managers.controller.manager_single import (
from sglang.srt.managers.controller_single import (
start_controller_process as start_controller_process_single,
)
from sglang.srt.managers.io_struct import (
@@ -22,7 +22,7 @@

import zmq

from sglang.srt.managers.controller.tp_worker import (
from sglang.srt.managers.tp_worker import (
ModelTpServer,
broadcast_recv_input,
launch_tp_servers,
2 changes: 1 addition & 1 deletion python/sglang/srt/managers/detokenizer_manager.py
@@ -25,8 +25,8 @@
import zmq.asyncio

from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR
from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.utils import find_printable_text, get_exception_traceback, graceful_registry

2 changes: 1 addition & 1 deletion python/sglang/srt/managers/io_struct.py
@@ -22,7 +22,7 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

from sglang.srt.managers.controller.infer_batch import BaseFinishReason
from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.sampling_params import SamplingParams


@@ -13,47 +13,47 @@
limitations under the License.
"""

"""Request scheduler heuristic."""
"""Request policy scheduler"""

import random
from collections import defaultdict


class ScheduleHeuristic:
class PolicyScheduler:
def __init__(
self,
schedule_heuristic,
policy,
max_running_seqs,
max_prefill_num_tokens,
max_total_num_tokens,
tree_cache,
):
if tree_cache.disable and schedule_heuristic == "lpm":
if tree_cache.disable and policy == "lpm":
# LPM is meaningless when the tree cache is disabled.
schedule_heuristic = "fcfs"
policy = "fcfs"

self.schedule_heuristic = schedule_heuristic
self.policy = policy
self.max_running_seqs = max_running_seqs
self.max_prefill_num_tokens = max_prefill_num_tokens
self.max_total_num_tokens = max_total_num_tokens
self.tree_cache = tree_cache

def get_priority_queue(self, waiting_queue):
if self.schedule_heuristic == "lpm":
if self.policy == "lpm":
# longest prefix match
waiting_queue.sort(key=lambda x: -len(x.prefix_indices))
return waiting_queue
elif self.schedule_heuristic == "fcfs":
elif self.policy == "fcfs":
# first come first serve
return waiting_queue
elif self.schedule_heuristic == "lof":
elif self.policy == "lof":
# longest output first
waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
return waiting_queue
elif self.schedule_heuristic == "random":
elif self.policy == "random":
random.shuffle(waiting_queue)
return waiting_queue
elif self.schedule_heuristic == "dfs-weight":
elif self.policy == "dfs-weight":
last_node_to_reqs = defaultdict(list)
for req in waiting_queue:
last_node_to_reqs[req.last_node].append(req)
@@ -70,7 +70,7 @@ def get_priority_queue(self, waiting_queue):
assert len(q) == len(waiting_queue)
return q
else:
raise ValueError(f"Unknown schedule_heuristic: {self.schedule_heuristic}")
raise ValueError(f"Unknown schedule_policy: {self.policy}")

def calc_weight(self, cur_node, node_to_weight):
for child in cur_node.children.values():
@@ -28,8 +28,8 @@
from sglang.global_config import global_config
from sglang.srt.constrained import RegexGuide
from sglang.srt.constrained.jump_forward import JumpForwardMap
from sglang.srt.managers.controller.radix_cache import RadixCache
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.mem_cache.radix_cache import RadixCache

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
