# Turn on flashinfer by default #578

Merged (1 commit) on Jul 2, 2024
README.md: 6 additions & 1 deletion
@@ -34,6 +34,8 @@ The core features include:
pip install "sglang[all]"
```

Next, [install FlashInfer](https://docs.flashinfer.ai/installation.html) for attention CUDA kernels.

### Method 2: From source
```
git clone https://github.com/sgl-project/sglang.git
@@ -43,7 +45,11 @@ pip install --upgrade pip
pip install -e "python[all]"
```

Next, [install FlashInfer](https://docs.flashinfer.ai/installation.html) for attention CUDA kernels.

### Notes
- If you see Triton errors, please install the [Triton Nightly](https://triton-lang.org/main/getting-started/installation.html).
- If you cannot install FlashInfer, you can use the slower Triton kernels by adding `--disable-flashinfer` when launching the server, as shown below.
- If you only need to use the OpenAI backend, you can avoid installing other dependencies by using `pip install "sglang[openai]"`.
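
For illustration, a minimal launch command using the Triton fallback mentioned above (the model path and port follow the Quick Start example later in this README; treat this as a sketch rather than the project's canonical command):

```
# Run without FlashInfer, falling back to the slower Triton attention kernels
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --disable-flashinfer
```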

## Quick Start
@@ -363,7 +369,6 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --mem-fraction-static 0.7
```
- See [flashinfer.md](docs/flashinfer.md) on accelerating inference using highly optimized CUDA kernels.
- See [hyperparameter_tuning.md](docs/hyperparameter_tuning.md) on tuning hyperparameters for better performance.

### Supported Models
docs/flashinfer.md: 0 additions & 18 deletions

This file was deleted.

python/sglang/srt/layers/radix_attention.py: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ def __init__(

from sglang.srt.managers.controller.model_runner import global_server_args_dict

if global_server_args_dict.get("enable_flashinfer", False):
if not global_server_args_dict.get("disable_flashinfer", False):
self.prefill_forward = self.prefill_forward_flashinfer
self.extend_forward = self.prefill_forward_flashinfer
self.decode_forward = self.decode_forward_flashinfer
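
The change above inverts the opt-in check: the FlashInfer forward paths are now bound unless the server was started with `--disable-flashinfer`. A standalone sketch of this bind-once dispatch pattern (the class and method names here are illustrative, not the repo's actual API):

```
class Attention:
    def __init__(self, disable_flashinfer: bool = False):
        # Bind the forward implementations once at construction time,
        # so the hot path never re-checks the flag.
        if not disable_flashinfer:
            self.decode_forward = self.decode_forward_flashinfer
        else:
            self.decode_forward = self.decode_forward_triton

    def decode_forward_flashinfer(self, q, k, v):
        ...  # optimized FlashInfer CUDA kernels

    def decode_forward_triton(self, q, k, v):
        ...  # slower Triton fallback kernels
```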
python/sglang/srt/managers/controller/model_runner.py: 3 additions & 3 deletions
@@ -201,7 +201,7 @@ def create(
if forward_mode == ForwardMode.EXTEND:
ret.init_extend_args()

if global_server_args_dict.get("enable_flashinfer", False):
if not global_server_args_dict.get("disable_flashinfer", False):
ret.init_flashinfer_args(
model_runner.model_config.num_attention_heads // tp_size,
model_runner.model_config.get_num_kv_heads(tp_size),
@@ -263,7 +263,7 @@ def __init__(
# Set some global args
global global_server_args_dict
global_server_args_dict = {
"enable_flashinfer": server_args.enable_flashinfer,
"disable_flashinfer": server_args.disable_flashinfer,
"attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
}

@@ -359,7 +359,7 @@ def init_cublas(self):
return c

def init_flash_infer(self):
if global_server_args_dict.get("enable_flashinfer", False):
if not global_server_args_dict.get("disable_flashinfer", False):
from flashinfer import (
BatchPrefillWithPagedKVCacheWrapper,
BatchDecodeWithPagedKVCacheWrapper,
python/sglang/srt/server_args.py: 4 additions & 4 deletions
@@ -50,7 +50,7 @@ class ServerArgs:
load_balance_method: str = "round_robin"

# Optimization/debug options
- enable_flashinfer: bool = False
+ disable_flashinfer: bool = False
attention_reduce_in_fp32: bool = False
disable_radix_cache: bool = False
disable_regex_jump_forward: bool = False
@@ -287,9 +287,9 @@ def add_cli_args(parser: argparse.ArgumentParser):

# Optimization/debug options
parser.add_argument(
"--enable-flashinfer",
"--disable-flashinfer",
action="store_true",
help="Enable flashinfer inference kernels",
help="Disable flashinfer inference kernels",
)
parser.add_argument(
"--attention-reduce-in-fp32",
@@ -322,7 +322,7 @@ def url(self):

def print_mode_args(self):
return (
f"enable_flashinfer={self.enable_flashinfer}, "
f"disable_flashinfer={self.disable_flashinfer}, "
f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}, "
f"disable_radix_cache={self.disable_radix_cache}, "
f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
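
Taken together, the server_args.py changes swap the opt-in `--enable-flashinfer` flag for an opt-out `--disable-flashinfer` flag, so the FlashInfer kernels are used unless explicitly turned off. A self-contained sketch of how that inversion behaves (the flag name matches the diff; the surrounding script is illustrative only):

```
import argparse

parser = argparse.ArgumentParser()
# A store_true flag defaults to False, so FlashInfer stays enabled
# unless the user explicitly opts out.
parser.add_argument(
    "--disable-flashinfer",
    action="store_true",
    help="Disable flashinfer inference kernels",
)

args = parser.parse_args([])  # no flags passed: FlashInfer on
assert args.disable_flashinfer is False

args = parser.parse_args(["--disable-flashinfer"])  # explicit opt-out
assert args.disable_flashinfer is True
```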