Code structure refactor #807

Merged 9 commits on Jul 30, 2024
6 changes: 3 additions & 3 deletions docs/en/hyperparameter_tuning.md

@@ -29,7 +29,7 @@ If OOM happens during prefill, try to decrease `--max-prefill-tokens`.
 If OOM happens during decoding, try to decrease `--max-running-requests`.
 You can also try to decrease `--mem-fraction-static`, which reduces the memory usage of the KV cache memory pool and helps both prefill and decoding.
 
-### (Minor) Tune `--schedule-heuristic`
-If you have many shared prefixes, use the default `--schedule-heuristic lpm`. `lpm` stands for longest prefix match.
+### (Minor) Tune `--schedule-policy`
+If you have many shared prefixes, use the default `--schedule-policy lpm`. `lpm` stands for longest prefix match.
 When you have no shared prefixes at all or you always send the requests with the shared prefixes together,
-you can try `--schedule-heuristic fcfs`. `fcfs` stands for first come first serve.
+you can try `--schedule-policy fcfs`. `fcfs` stands for first come first serve.
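
As a usage note, the renamed flag can also be set when launching a runtime from Python. A minimal sketch, assuming `Runtime` forwards keyword arguments to `ServerArgs` (the model path below is just a placeholder):

```python
import sglang as sgl

# "lpm" (the default) favors requests with long shared prefixes;
# "fcfs" is better when requests share no prefixes at all.
runtime = sgl.Runtime(
    model_path="meta-llama/Llama-2-7b-chat-hf",  # placeholder model
    schedule_policy="lpm",
)
sgl.set_default_backend(runtime)
```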
61 changes: 31 additions & 30 deletions python/sglang/__init__.py

@@ -1,4 +1,5 @@
 # SGL API Components
+
 from sglang.api import (
     Runtime,
     assistant,
@@ -22,46 +23,46 @@
     video,
 )
 
+# Global Configurations
+from sglang.global_config import global_config
+
+# SGL Backends
+from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
+from sglang.utils import LazyImport
+from sglang.version import __version__
+
+Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
+LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
+OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
+VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI")
+
+
+# public APIs management
-# SGLang DSL APIs
 __all__ = [
+    "global_config",
+    "Anthropic",
+    "LiteLLM",
+    "OpenAI",
+    "RuntimeEndpoint",
+    "VertexAI",
-    "function",
     "Runtime",
-    "set_default_backend",
+    "assistant",
+    "assistant_begin",
+    "assistant_end",
     "flush_cache",
-    "get_server_args",
+    "function",
     "gen",
     "gen_int",
     "gen_string",
+    "get_server_args",
     "image",
-    "video",
     "select",
+    "set_default_backend",
     "system",
+    "system_begin",
+    "system_end",
     "user",
-    "assistant",
     "user_begin",
     "user_end",
-    "assistant_begin",
-    "assistant_end",
-    "system_begin",
-    "system_end",
+    "video",
 ]
-
-# Global Configurations
-from sglang.global_config import global_config
-
-__all__ += ["global_config"]
-
-from sglang.version import __version__
-
-__all__ += ["__version__"]
-
-# SGL Backends
-from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
-from sglang.utils import LazyImport
-
-Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
-LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
-OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
-VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI")
-
-__all__ += ["Anthropic", "LiteLLM", "OpenAI", "VertexAI", "RuntimeEndpoint"]
4 changes: 2 additions & 2 deletions python/sglang/bench_latency.py

@@ -37,9 +37,9 @@
 import torch.distributed as dist
 
 from sglang.srt.hf_transformers_utils import get_tokenizer
-from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, Req
-from sglang.srt.managers.controller.model_runner import ModelRunner
+from sglang.srt.managers.schedule_batch import Batch, ForwardMode, Req
 from sglang.srt.model_config import ModelConfig
+from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import suppress_other_loggers
2 changes: 1 addition & 1 deletion python/sglang/srt/layers/logits_processor.py

@@ -25,7 +25,7 @@
     tensor_model_parallel_all_gather,
 )
 
-from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
+from sglang.srt.model_executor.model_runner import ForwardMode, InputMetadata
 
 
 @dataclasses.dataclass
2 changes: 1 addition & 1 deletion python/sglang/srt/layers/radix_attention.py

@@ -22,7 +22,7 @@
 from sglang.global_config import global_config
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
-from sglang.srt.managers.controller.model_runner import (
+from sglang.srt.model_executor.model_runner import (
     ForwardMode,
     InputMetadata,
     global_server_args_dict,
2 changes: 1 addition & 1 deletion python/sglang/srt/layers/token_attention.py

@@ -20,7 +20,7 @@
 import triton
 import triton.language as tl
 
-from sglang.srt.managers.controller.infer_batch import global_server_args_dict
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 
 if global_server_args_dict.get("attention_reduce_in_fp32", False):
     REDUCE_TRITON_TYPE = tl.float32
@@ -27,7 +27,7 @@
 import numpy as np
 import zmq
 
-from sglang.srt.managers.controller.manager_single import (
+from sglang.srt.managers.controller_single import (
     start_controller_process as start_controller_process_single,
 )
 from sglang.srt.managers.io_struct import (
python/sglang/srt/managers/{controller/manager_single.py → controller_single.py}

@@ -22,7 +22,7 @@
 
 import zmq
 
-from sglang.srt.managers.controller.tp_worker import (
+from sglang.srt.managers.tp_worker import (
     ModelTpServer,
     broadcast_recv_input,
     launch_tp_servers,
2 changes: 1 addition & 1 deletion python/sglang/srt/managers/detokenizer_manager.py

@@ -25,8 +25,8 @@
 import zmq.asyncio
 
 from sglang.srt.hf_transformers_utils import get_tokenizer
-from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR
 from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
+from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.utils import find_printable_text, get_exception_traceback, graceful_registry
 
2 changes: 1 addition & 1 deletion python/sglang/srt/managers/io_struct.py

@@ -22,7 +22,7 @@
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Union
 
-from sglang.srt.managers.controller.infer_batch import BaseFinishReason
+from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.sampling_params import SamplingParams
 
 
python/sglang/srt/managers/{controller/schedule_heuristic.py → policy_scheduler.py}

@@ -13,47 +13,47 @@
 limitations under the License.
 """
 
-"""Request scheduler heuristic."""
+"""Request policy scheduler"""
 
 import random
 from collections import defaultdict
 
 
-class ScheduleHeuristic:
+class PolicyScheduler:
     def __init__(
         self,
-        schedule_heuristic,
+        policy,
         max_running_seqs,
         max_prefill_num_tokens,
         max_total_num_tokens,
         tree_cache,
     ):
-        if tree_cache.disable and schedule_heuristic == "lpm":
+        if tree_cache.disable and policy == "lpm":
             # LPM is meaningless when the tree cache is disabled.
-            schedule_heuristic = "fcfs"
+            policy = "fcfs"
 
-        self.schedule_heuristic = schedule_heuristic
+        self.policy = policy
         self.max_running_seqs = max_running_seqs
         self.max_prefill_num_tokens = max_prefill_num_tokens
         self.max_total_num_tokens = max_total_num_tokens
         self.tree_cache = tree_cache
 
     def get_priority_queue(self, waiting_queue):
-        if self.schedule_heuristic == "lpm":
+        if self.policy == "lpm":
             # longest prefix match
             waiting_queue.sort(key=lambda x: -len(x.prefix_indices))
             return waiting_queue
-        elif self.schedule_heuristic == "fcfs":
+        elif self.policy == "fcfs":
             # first come first serve
             return waiting_queue
-        elif self.schedule_heuristic == "lof":
+        elif self.policy == "lof":
             # longest output first
             waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
             return waiting_queue
-        elif self.schedule_heuristic == "random":
+        elif self.policy == "random":
             random.shuffle(waiting_queue)
             return waiting_queue
-        elif self.schedule_heuristic == "dfs-weight":
+        elif self.policy == "dfs-weight":
             last_node_to_reqs = defaultdict(list)
             for req in waiting_queue:
                 last_node_to_reqs[req.last_node].append(req)
@@ -70,7 +70,7 @@ def get_priority_queue(self, waiting_queue):
             assert len(q) == len(waiting_queue)
             return q
         else:
-            raise ValueError(f"Unknown schedule_heuristic: {self.schedule_heuristic}")
+            raise ValueError(f"Unknown schedule_policy: {self.policy}")
 
     def calc_weight(self, cur_node, node_to_weight):
         for child in cur_node.children.values():
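To make the `lpm` branch concrete: sorting by descending `len(prefix_indices)` schedules the requests that reuse the most radix-cache tokens first. A self-contained illustration, where `DemoReq` is a hypothetical stand-in for sglang's `Req`:

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class DemoReq:
    # Stand-in for Req: prefix_indices holds the KV-cache token
    # indices already matched in the radix tree for this request.
    rid: str
    prefix_indices: List[int] = field(default_factory=list)


waiting_queue = [
    DemoReq("a", prefix_indices=[0, 1]),
    DemoReq("b", prefix_indices=[0, 1, 2, 3, 4]),
    DemoReq("c"),  # cold request, no cached prefix
]

# "lpm": longest prefix match first.
waiting_queue.sort(key=lambda x: -len(x.prefix_indices))
assert [r.rid for r in waiting_queue] == ["b", "a", "c"]
```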
python/sglang/srt/managers/{controller/infer_batch.py → schedule_batch.py}

@@ -28,8 +28,8 @@
 from sglang.global_config import global_config
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
-from sglang.srt.managers.controller.radix_cache import RadixCache
-from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
+from sglang.srt.mem_cache.radix_cache import RadixCache
 
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
python/sglang/srt/managers/{controller → }/tp_worker.py

@@ -29,23 +29,23 @@
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
-from sglang.srt.managers.controller.infer_batch import (
-    FINISH_ABORT,
-    BaseFinishReason,
-    Batch,
-    ForwardMode,
-    Req,
-)
-from sglang.srt.managers.controller.model_runner import ModelRunner
-from sglang.srt.managers.controller.radix_cache import RadixCache
-from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchTokenIDOut,
     FlushCacheReq,
     TokenizedGenerateReqInput,
 )
+from sglang.srt.managers.policy_scheduler import PolicyScheduler
+from sglang.srt.managers.schedule_batch import (
+    FINISH_ABORT,
+    BaseFinishReason,
+    Batch,
+    ForwardMode,
+    Req,
+)
+from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.model_config import ModelConfig
+from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_int_token_logit_bias,
@@ -74,7 +74,7 @@ def __init__(
         self.tp_rank = tp_rank
         self.tp_size = server_args.tp_size
         self.dp_size = server_args.dp_size
-        self.schedule_heuristic = server_args.schedule_heuristic
+        self.schedule_policy = server_args.schedule_policy
         self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
 
         # Chunked prefill
@@ -150,8 +150,8 @@ def __init__(
             disable=server_args.disable_radix_cache,
         )
         self.tree_cache_metrics = {"total": 0, "hit": 0}
-        self.scheduler = ScheduleHeuristic(
-            self.schedule_heuristic,
+        self.scheduler = PolicyScheduler(
+            self.schedule_policy,
             self.max_running_requests,
             self.max_prefill_tokens,
             self.max_total_num_tokens,
python/sglang/srt/{ → mem_cache}/flush_cache.py

@@ -17,7 +17,7 @@
 Flush the KV cache.
 
 Usage:
-python3 -m sglang.srt.flush_cache --url http://localhost:30000
+python3 -m sglang.srt.mem_cache.flush_cache --url http://localhost:30000
 """
 
 import argparse
python/sglang/srt/{managers/controller → model_executor}/cuda_graph_runner.py

@@ -29,7 +29,7 @@
     LogitsMetadata,
     LogitsProcessor,
 )
-from sglang.srt.managers.controller.infer_batch import (
+from sglang.srt.managers.schedule_batch import (
     Batch,
     ForwardMode,
     InputMetadata,
python/sglang/srt/{managers/controller → model_executor}/model_runner.py

@@ -40,13 +40,13 @@
 from vllm.model_executor.models import ModelRegistry
 
 from sglang.global_config import global_config
-from sglang.srt.managers.controller.infer_batch import (
+from sglang.srt.managers.schedule_batch import (
     Batch,
     ForwardMode,
     InputMetadata,
     global_server_args_dict,
 )
-from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
@@ -273,7 +273,7 @@ def init_flash_infer(self):
         )
 
     def init_cuda_graphs(self):
-        from sglang.srt.managers.controller.cuda_graph_runner import CudaGraphRunner
+        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 
         if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
             self.cuda_graph_runner = None
2 changes: 1 addition & 1 deletion python/sglang/srt/models/chatglm.py

@@ -45,7 +45,7 @@
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.controller.model_runner import InputMetadata
+from sglang.srt.model_executor.model_runner import InputMetadata
 
 LoraConfig = None
 
2 changes: 1 addition & 1 deletion python/sglang/srt/models/commandr.py

@@ -64,7 +64,7 @@
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.controller.model_runner import InputMetadata
+from sglang.srt.model_executor.model_runner import InputMetadata
 
 
 @torch.compile
2 changes: 1 addition & 1 deletion python/sglang/srt/models/dbrx.py

@@ -45,7 +45,7 @@
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.controller.model_runner import InputMetadata
+from sglang.srt.model_executor.model_runner import InputMetadata
 
 
 class DbrxRouter(nn.Module):
2 changes: 1 addition & 1 deletion python/sglang/srt/models/deepseek.py

@@ -46,7 +46,7 @@
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.controller.infer_batch import InputMetadata
+from sglang.srt.managers.schedule_batch import InputMetadata
 
 
 class DeepseekMLP(nn.Module):
2 changes: 1 addition & 1 deletion python/sglang/srt/models/deepseek_v2.py

@@ -45,7 +45,7 @@
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.controller.model_runner import InputMetadata
+from sglang.srt.model_executor.model_runner import InputMetadata
 
 
 class DeepseekV2MLP(nn.Module):
2 changes: 1 addition & 1 deletion python/sglang/srt/models/gemma.py

@@ -37,7 +37,7 @@
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.controller.model_runner import InputMetadata
+from sglang.srt.model_executor.model_runner import InputMetadata
 
 
 class GemmaMLP(nn.Module):