Turn on flashinfer by default (#578)
Ying1123 authored Jul 2, 2024
1 parent 95dc093 commit 9380f50
Showing 5 changed files with 14 additions and 27 deletions.
7 changes: 6 additions & 1 deletion README.md
@@ -34,6 +34,8 @@ The core features include:
pip install "sglang[all]"
```

+Next, [install FlashInfer](https://docs.flashinfer.ai/installation.html) for attention CUDA kernels.

### Method 2: From source
```
git clone https://github.com/sgl-project/sglang.git
@@ -43,7 +45,11 @@ pip install --upgrade pip
pip install -e "python[all]"
```

+Next, [install FlashInfer](https://docs.flashinfer.ai/installation.html) for attention CUDA kernels.

### Notes
- If you see triton errors, please install the [Triton Nightly](https://triton-lang.org/main/getting-started/installation.html).
+- If you cannot install FlashInfer, you can use the slower triton kernels by adding `--disable-flashinfer` when launching the server.
- If you only need to use the OpenAI backend, you can avoid installing other dependencies by using `pip install "sglang[openai]"`

## Quick Start
@@ -363,7 +369,6 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --mem-fraction-static 0.7
```
-- See [flashinfer.md](docs/flashinfer.md) on accelerating inference using highly optimized CUDA kernels.
- See [hyperparameter_tuning.md](docs/hyperparameter_tuning.md) on tuning hyperparameters for better performance.

### Supported Models
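Taken together, the README changes flip the recommendation: FlashInfer is installed up front and the Triton kernels become the opt-out fallback. A minimal usage sketch of the new default, reusing the launch command already shown in this README:
```
# FlashInfer kernels are now on by default; no extra flag needed.
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000

# Explicitly opt out and fall back to the slower Triton kernels.
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --disable-flashinfer
```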
18 changes: 0 additions & 18 deletions docs/flashinfer.md

This file was deleted.

2 changes: 1 addition & 1 deletion python/sglang/srt/layers/radix_attention.py
@@ -26,7 +26,7 @@ def __init__(

        from sglang.srt.managers.controller.model_runner import global_server_args_dict

-        if global_server_args_dict.get("enable_flashinfer", False):
+        if not global_server_args_dict.get("disable_flashinfer", False):
            self.prefill_forward = self.prefill_forward_flashinfer
            self.extend_forward = self.prefill_forward_flashinfer
            self.decode_forward = self.decode_forward_flashinfer
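The pattern in this hunk binds the kernel backend once at construction time instead of branching on every forward call. A self-contained sketch of that idea, with hypothetical `_decode_*` helpers standing in for the real kernels (the actual class dispatches prefill, extend, and decode the same way):
```python
# Sketch: pick the kernel backend once at init instead of per call.
class AttentionDispatch:
    def __init__(self, server_args: dict):
        # FlashInfer is the default; "disable_flashinfer" opts out.
        if not server_args.get("disable_flashinfer", False):
            self.decode_forward = self._decode_flashinfer
        else:
            self.decode_forward = self._decode_triton

    def _decode_flashinfer(self, q):
        return f"flashinfer({q})"

    def _decode_triton(self, q):
        return f"triton({q})"

print(AttentionDispatch({}).decode_forward("q"))  # flashinfer(q)
print(AttentionDispatch({"disable_flashinfer": True}).decode_forward("q"))  # triton(q)
```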
6 changes: 3 additions & 3 deletions python/sglang/srt/managers/controller/model_runner.py
@@ -201,7 +201,7 @@ def create(
        if forward_mode == ForwardMode.EXTEND:
            ret.init_extend_args()

-        if global_server_args_dict.get("enable_flashinfer", False):
+        if not global_server_args_dict.get("disable_flashinfer", False):
            ret.init_flashinfer_args(
                model_runner.model_config.num_attention_heads // tp_size,
                model_runner.model_config.get_num_kv_heads(tp_size),
@@ -263,7 +263,7 @@ def __init__(
        # Set some global args
        global global_server_args_dict
        global_server_args_dict = {
-            "enable_flashinfer": server_args.enable_flashinfer,
+            "disable_flashinfer": server_args.disable_flashinfer,
            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
        }

@@ -359,7 +359,7 @@ def init_cublas(self):
        return c

    def init_flash_infer(self):
-        if global_server_args_dict.get("enable_flashinfer", False):
+        if not global_server_args_dict.get("disable_flashinfer", False):
            from flashinfer import (
                BatchPrefillWithPagedKVCacheWrapper,
                BatchDecodeWithPagedKVCacheWrapper,
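`init_flash_infer` keeps the FlashInfer import lazy, so the package is only required when it is actually in use. A hedged sketch of that optional-dependency pattern; the try/except and error message are assumptions added here, not part of the diff (the wrapper names are from the hunk above):
```python
def init_flash_infer(server_args: dict):
    """Import FlashInfer wrappers lazily; skip entirely when disabled."""
    if server_args.get("disable_flashinfer", False):
        return None  # Triton fallback path; flashinfer is never imported.
    try:
        from flashinfer import (
            BatchDecodeWithPagedKVCacheWrapper,
            BatchPrefillWithPagedKVCacheWrapper,
        )
    except ImportError as e:
        # Assumption: surface a hint now that FlashInfer is the default.
        raise ImportError(
            "FlashInfer is enabled by default; install it or pass --disable-flashinfer."
        ) from e
    return BatchPrefillWithPagedKVCacheWrapper, BatchDecodeWithPagedKVCacheWrapper
```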
8 changes: 4 additions & 4 deletions python/sglang/srt/server_args.py
@@ -50,7 +50,7 @@ class ServerArgs:
    load_balance_method: str = "round_robin"

    # Optimization/debug options
-    enable_flashinfer: bool = False
+    disable_flashinfer: bool = False
    attention_reduce_in_fp32: bool = False
    disable_radix_cache: bool = False
    disable_regex_jump_forward: bool = False
@@ -287,9 +287,9 @@ def add_cli_args(parser: argparse.ArgumentParser):

        # Optimization/debug options
        parser.add_argument(
-            "--enable-flashinfer",
+            "--disable-flashinfer",
            action="store_true",
-            help="Enable flashinfer inference kernels",
+            help="Disable flashinfer inference kernels",
        )
        parser.add_argument(
            "--attention-reduce-in-fp32",
@@ -322,7 +322,7 @@ def url(self):

    def print_mode_args(self):
        return (
-            f"enable_flashinfer={self.enable_flashinfer}, "
+            f"disable_flashinfer={self.disable_flashinfer}, "
            f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}, "
            f"disable_radix_cache={self.disable_radix_cache}, "
            f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
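The renamed flag is what actually flips the default: with `action="store_true"`, an absent flag parses to `False`, so `--disable-flashinfer` leaves FlashInfer on unless the user explicitly opts out. A minimal sketch of that argparse behavior:
```python
import argparse

parser = argparse.ArgumentParser()
# A store_true flag defaults to False, so FlashInfer stays enabled
# unless --disable-flashinfer is passed explicitly.
parser.add_argument(
    "--disable-flashinfer",
    action="store_true",
    help="Disable flashinfer inference kernels",
)

args = parser.parse_args([])  # no flags: FlashInfer on by default
assert args.disable_flashinfer is False
args = parser.parse_args(["--disable-flashinfer"])  # explicit opt-out
assert args.disable_flashinfer is True
```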
