diff --git a/README.md b/README.md
index 82934dc0a4..a7b32628b4 100644
--- a/README.md
+++ b/README.md
@@ -34,6 +34,8 @@ The core features include:
 pip install "sglang[all]"
 ```

+Next, [install FlashInfer](https://docs.flashinfer.ai/installation.html) for attention CUDA kernels.
+
 ### Method 2: From source
 ```
 git clone https://github.com/sgl-project/sglang.git
@@ -43,7 +45,11 @@ pip install --upgrade pip
 pip install -e "python[all]"
 ```

+Next, [install FlashInfer](https://docs.flashinfer.ai/installation.html) for attention CUDA kernels.
+
 ### Notes
+- If you see triton errors, please install the [Triton Nightly](https://triton-lang.org/main/getting-started/installation.html).
+- If you cannot install FlashInfer, you can use the slower triton kernels by adding `--disable-flashinfer` when launching the server.
 - If you only need to use the OpenAI backend, you can avoid installing other dependencies by using `pip install "sglang[openai]"`

 ## Quick Start
@@ -363,7 +369,6 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
 ```
 python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --mem-fraction-static 0.7
 ```
-- See [flashinfer.md](docs/flashinfer.md) on accelerating inference using highly optimized CUDA kernels.
 - See [hyperparameter_tuning.md](docs/hyperparameter_tuning.md) on tuning hyperparameters for better performance.

 ### Supported Models
diff --git a/docs/flashinfer.md b/docs/flashinfer.md
deleted file mode 100644
index 7acd083020..0000000000
--- a/docs/flashinfer.md
+++ /dev/null
@@ -1,18 +0,0 @@
-## Flashinfer Mode
-
-[flashinfer](https://github.com/flashinfer-ai/flashinfer) is a kernel library for LLM serving.
-It can be used in SGLang runtime to accelerate attention computation.
-
-### Install flashinfer
-
-See https://docs.flashinfer.ai/installation.html.
-
-### Run a Server With Flashinfer Mode
-
-Add `--enable-flashinfer` argument to enable flashinfer when launching a server.
-
-Example:
-
-```bash
-python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --enable-flashinfer
-```
diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py
index 824fee8bc6..66d206082e 100644
--- a/python/sglang/srt/layers/radix_attention.py
+++ b/python/sglang/srt/layers/radix_attention.py
@@ -26,7 +26,7 @@ def __init__(

         from sglang.srt.managers.controller.model_runner import global_server_args_dict

-        if global_server_args_dict.get("enable_flashinfer", False):
+        if not global_server_args_dict.get("disable_flashinfer", False):
             self.prefill_forward = self.prefill_forward_flashinfer
             self.extend_forward = self.prefill_forward_flashinfer
             self.decode_forward = self.decode_forward_flashinfer
diff --git a/python/sglang/srt/managers/controller/model_runner.py b/python/sglang/srt/managers/controller/model_runner.py
index 1450abd1df..e415147065 100644
--- a/python/sglang/srt/managers/controller/model_runner.py
+++ b/python/sglang/srt/managers/controller/model_runner.py
@@ -201,7 +201,7 @@ def create(
         if forward_mode == ForwardMode.EXTEND:
             ret.init_extend_args()

-        if global_server_args_dict.get("enable_flashinfer", False):
+        if not global_server_args_dict.get("disable_flashinfer", False):
             ret.init_flashinfer_args(
                 model_runner.model_config.num_attention_heads // tp_size,
                 model_runner.model_config.get_num_kv_heads(tp_size),
@@ -263,7 +263,7 @@ def __init__(
         # Set some global args
         global global_server_args_dict
         global_server_args_dict = {
-            "enable_flashinfer": server_args.enable_flashinfer,
+            "disable_flashinfer": server_args.disable_flashinfer,
             "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
         }

@@ -359,7 +359,7 @@ def init_cublas(self):
         return c

     def init_flash_infer(self):
-        if global_server_args_dict.get("enable_flashinfer", False):
+        if not global_server_args_dict.get("disable_flashinfer", False):
             from flashinfer import (
                 BatchPrefillWithPagedKVCacheWrapper,
                 BatchDecodeWithPagedKVCacheWrapper,
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 4b7daf5f90..18305a2eb8 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -50,7 +50,7 @@ class ServerArgs:
     load_balance_method: str = "round_robin"

     # Optimization/debug options
-    enable_flashinfer: bool = False
+    disable_flashinfer: bool = False
     attention_reduce_in_fp32: bool = False
     disable_radix_cache: bool = False
     disable_regex_jump_forward: bool = False
@@ -287,9 +287,9 @@ def add_cli_args(parser: argparse.ArgumentParser):

         # Optimization/debug options
         parser.add_argument(
-            "--enable-flashinfer",
+            "--disable-flashinfer",
             action="store_true",
-            help="Enable flashinfer inference kernels",
+            help="Disable flashinfer inference kernels",
         )
         parser.add_argument(
             "--attention-reduce-in-fp32",
@@ -322,7 +322,7 @@ def url(self):

     def print_mode_args(self):
         return (
-            f"enable_flashinfer={self.enable_flashinfer}, "
+            f"disable_flashinfer={self.disable_flashinfer}, "
             f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}, "
             f"disable_radix_cache={self.disable_radix_cache}, "
             f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
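With this patch, FlashInfer kernels are used by default and `--disable-flashinfer` opts out. A minimal usage sketch of the new behavior (the model path and port are reused from the removed `docs/flashinfer.md` example):

```bash
# FlashInfer kernels are used by default; no extra flag is needed.
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000

# If FlashInfer cannot be installed, fall back to the slower Triton kernels.
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --disable-flashinfer
```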