Enable torch.compile for triton backend #1422

Merged · 10 commits · Sep 14, 2024
Changes from 5 commits

23 changes: 12 additions & 11 deletions README.md
@@ -205,28 +205,29 @@ print(response)
It supports streaming, vision, and most features of the Chat/Completions/Models/Batch endpoints specified by the [OpenAI API Reference](https://platform.openai.com/docs/api-reference/).

### Additional Server Arguments
- Add `--tp 2` to enable multi-GPU tensor parallelism. If it reports the error "peer access is not supported between these two devices", add `--enable-p2p-check` to the server launch command.
- To enable multi-GPU tensor parallelism, add `--tp 2`. If it reports the error "peer access is not supported between these two devices", add `--enable-p2p-check` to the server launch command.
```
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 --tp 2
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 2
```
- Add `--dp 2` to enable multi-GPU data parallelism. Data parallelism is better for throughput if there is enough memory. It can also be used together with tensor parallelism. The following command uses 4 GPUs in total.
- To enable multi-GPU data parallelism, add `--dp 2`. Data parallelism is better for throughput if there is enough memory. It can also be used together with tensor parallelism. The following command uses 4 GPUs in total.
```
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 --dp 2 --tp 2
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --dp 2 --tp 2
```
- If you see out-of-memory errors during serving, try to reduce the memory usage of the KV cache pool by setting a smaller value of `--mem-fraction-static`. The default value is `0.9`.
```
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 --mem-fraction-static 0.7
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --mem-fraction-static 0.7
```
- See [hyperparameter_tuning.md](docs/en/hyperparameter_tuning.md) on tuning hyperparameters for better performance.
- If you see out-of-memory errors during prefill for long prompts, try to set a smaller chunked prefill size.
```
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 --chunked-prefill-size 4096
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --chunked-prefill-size 4096
```
- To enable torch.compile support, you can add `--enable-torch-compile`. It accelerates small models on small batch sizes.
- To enable fp8 weight quantization, you can add `--quantization fp8` on an fp16 checkpoint or directly load an fp8 checkpoint without specifying any arguments.
- To enable fp8 kv cache quantization, you can add `--kv-cache-dtype fp8_e5m2`.
- If the model does not have a template in the Hugging Face tokenizer, you can specify a [custom chat template](docs/en/custom_chat_template.md).
- Add `--nnodes 2` to run tensor parallelism on multiple nodes. If you have two nodes with two GPUs on each node and want to run TP=4, let `sgl-dev-0` be the hostname of the first node and `50000` be an available port.
- To enable torch.compile support, add `--enable-torch-compile`. It accelerates small models on small batch sizes.
- To enable fp8 weight quantization, add `--quantization fp8` on an fp16 checkpoint or directly load an fp8 checkpoint without specifying any arguments.
- To enable fp8 kv cache quantization, add `--kv-cache-dtype fp8_e5m2`.
- To enable DeepSeek MLA acceleration, add `--enable-mla`.
- If the model does not have a chat template in the Hugging Face tokenizer, you can specify a [custom chat template](docs/en/custom_chat_template.md).
- To run tensor parallelism on multiple nodes, add `--nnodes 2`. If you have two nodes with two GPUs on each node and want to run TP=4, let `sgl-dev-0` be the hostname of the first node and `50000` be an available port.
```
# Node 0
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 4 --nccl-init sgl-dev-0:50000 --nnodes 2 --node-rank 0
```
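
For readers who want to try the combination this PR targets, here is a minimal launch sketch. It is not part of the diff: the `--attention-backend triton` flag and the model path are assumptions based on the triton-backend test later in this PR, and the sketch simply shells out to the documented CLI.

```python
import subprocess

# Flags mirror the README bullets above; "--attention-backend triton" is an assumption
# based on the triton-backend benchmark test in this PR, not a flag shown in this hunk.
cmd = [
    "python", "-m", "sglang.launch_server",
    "--model-path", "meta-llama/Meta-Llama-3-8B-Instruct",
    "--attention-backend", "triton",
    "--enable-torch-compile",
]
server = subprocess.Popen(cmd)  # long-running server process
# ... send requests to the OpenAI-compatible endpoint, then:
server.terminate()
```
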
29 changes: 16 additions & 13 deletions python/sglang/srt/layers/attention_backend.py
@@ -22,7 +22,7 @@

from sglang.global_config import global_config
from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata

if TYPE_CHECKING:
@@ -332,7 +332,6 @@ class TritonAttnBackend(AttentionBackend):
def __init__(self, model_runner: ModelRunner):
# Lazy import to avoid the initialization of cuda context
from sglang.srt.layers.triton_attention.decode_attention import (
REDUCE_TORCH_TYPE,
decode_attention_fwd,
)
from sglang.srt.layers.triton_attention.extend_attention import (
@@ -343,9 +342,13 @@ def __init__(self, model_runner: ModelRunner):

self.decode_attention_fwd = decode_attention_fwd
self.extend_attention_fwd = extend_attention_fwd
self.REDUCE_TORCH_TYPE = REDUCE_TORCH_TYPE
self.num_head = model_runner.model_config.num_attention_heads

if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
self.reduce_type = torch.float32
else:
self.reduce_type = torch.float16

self.forward_metadata = None

self.cuda_graph_max_seq_len = model_runner.model_config.context_len
@@ -362,7 +365,7 @@ def init_forward_metadata(
total_num_tokens = torch.sum(input_metadata.seq_lens).item()
attn_logits = torch.empty(
(self.num_head, total_num_tokens),
dtype=self.REDUCE_TORCH_TYPE,
dtype=self.reduce_type,
device="cuda",
)

@@ -382,8 +385,11 @@ def init_cuda_graph_state(self, max_bs: int):
(max_bs,), dtype=torch.int32, device="cuda"
)
self.cuda_graph_attn_logits = torch.empty(
(self.num_head, self.cuda_graph_max_total_num_tokens),
dtype=self.REDUCE_TORCH_TYPE,
(
self.num_head,
self.cuda_graph_max_total_num_tokens,
),
dtype=self.reduce_type,
device="cuda",
)

@@ -403,13 +409,6 @@ def init_forward_metadata_replay_cuda_graph(
self.cuda_graph_start_loc.zero_()
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)

self.forward_metadata = (
self.cuda_graph_start_loc,
self.cuda_graph_attn_logits,
self.cuda_graph_max_seq_len,
None,
)

def get_cuda_graph_seq_len_fill_value(self):
return 1

@@ -444,6 +443,10 @@ def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadat
return o

def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
# During torch.compile, there is a bug in rotary_emb that causes the
# output value to have a 3D tensor shape. This reshapes the output correctly.
q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)

# TODO: reuse the buffer across layers
if layer.qk_head_dim != layer.v_head_dim:
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
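
To make the comment in `forward_decode` concrete, the standalone sketch below shows the shape fix in isolation: a 3D query produced under torch.compile is flattened back to the 2D layout the Triton decode kernel expects. The shapes are made-up illustrative values, not taken from the diff.

```python
import torch

# Hypothetical shapes for illustration only.
num_tokens, tp_q_head_num, qk_head_dim = 8, 32, 128

# Under torch.compile, rotary_emb can hand q back as (num_tokens, heads, head_dim) ...
q_3d = torch.randn(num_tokens, tp_q_head_num, qk_head_dim)

# ... while the Triton decode kernel expects the flat (num_tokens, heads * head_dim) layout.
q_2d = q_3d.reshape(-1, tp_q_head_num * qk_head_dim)
assert q_2d.shape == (num_tokens, tp_q_head_num * qk_head_dim)

# The same reshape is a no-op on a tensor that already has the 2D layout,
# so it is safe to apply unconditionally.
assert q_2d.reshape(-1, tp_q_head_num * qk_head_dim).shape == q_2d.shape
```
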
36 changes: 19 additions & 17 deletions python/sglang/srt/layers/triton_attention/decode_attention.py
@@ -21,19 +21,11 @@
# Adapted from
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py
import torch
import triton
import triton.language as tl

from sglang.srt.managers.schedule_batch import global_server_args_dict

if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
REDUCE_TRITON_TYPE = tl.float32
REDUCE_TORCH_TYPE = torch.float32
else:
REDUCE_TRITON_TYPE = tl.float16
REDUCE_TORCH_TYPE = torch.float16


@triton.jit
def tanh(x):
@@ -63,6 +55,11 @@ def _fwd_kernel_stage1(
logit_cap: tl.constexpr,
Lk: tl.constexpr,
):
if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
reduce_type = tl.float32
else:
reduce_type = tl.float16

cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
start_n = tl.program_id(2)
@@ -85,7 +82,7 @@
block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0)

for start_mark in range(0, block_mask, 1):
q = tl.load(Q + off_q + start_mark).to(REDUCE_TRITON_TYPE)
q = tl.load(Q + off_q + start_mark).to(reduce_type)
offs_n_new = cur_batch_start_index + offs_n
k_loc = tl.load(
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,
@@ -101,7 +98,7 @@
K_Buffer + offs_buf_k,
mask=(offs_n_new[:, None] < cur_batch_end_index) & (offs_d[None, :] < Lk),
other=0.0,
).to(REDUCE_TRITON_TYPE)
).to(reduce_type)
att_value = tl.sum(q[None, :] * k, 1)
att_value *= sm_scale

@@ -198,7 +195,7 @@ def _decode_att_m_fwd(
logit_cap,
):
BLOCK = 32
Lq, Lk = q.shape[-1], k_buffer.shape[-1]
Lk = k_buffer.shape[-1]

batch, head_num = B_req_idx.shape[0], q.shape[1]

@@ -304,6 +301,11 @@ def _fwd_grouped_kernel_stage1(
logit_cap: tl.constexpr,
Lk: tl.constexpr,
):
if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
reduce_type = tl.float32
else:
reduce_type = tl.float16

cur_batch = tl.program_id(0)
cur_kv_head = tl.program_id(1)
start_n = tl.program_id(2)
@@ -336,7 +338,7 @@ def _fwd_grouped_kernel_stage1(
for start_mark in range(0, block_mask, 1):
q = tl.load(
Q + offs_q + start_mark, mask=(mask_h[:, None]) & (offs_d[None, :] < Lk)
).to(REDUCE_TRITON_TYPE)
).to(reduce_type)
offs_n_new = cur_batch_start_index + offs_n
k_loc = tl.load(
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,
@@ -352,11 +354,11 @@
K_Buffer + offs_buf_k,
mask=(offs_n_new[None, :] < cur_batch_end_index) & (offs_d[:, None] < Lk),
other=0.0,
).to(REDUCE_TRITON_TYPE)
).to(reduce_type)
qk = tl.dot(q, k)
if BLOCK_DPE > 0:
qpe = tl.load(Q + off_qpe + start_mark, mask=mask_h[:, None]).to(
REDUCE_TRITON_TYPE
reduce_type
)
offs_buf_kpe = (
k_loc[None, :] * stride_buf_kbs
@@ -367,7 +369,7 @@
K_Buffer + offs_buf_kpe,
mask=offs_n_new[None, :] < cur_batch_end_index,
other=0.0,
).to(REDUCE_TRITON_TYPE)
).to(reduce_type)
qk += tl.dot(qpe, kpe)
qk *= sm_scale

@@ -477,8 +479,8 @@ def _decode_grouped_att_m_fwd(
sm_scale,
logit_cap,
):
BLOCK = 32
Lq, Lk = q.shape[-1], k_buffer.shape[-1]
BLOCK = 64
Lk = k_buffer.shape[-1]

if Lk == 576:
BLOCK_DMODEL = 512
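
As background on how a reduce dtype can reach a Triton kernel, the standalone sketch below shows one generic pattern: choosing the accumulation dtype on the Python side and passing it into the kernel as a `tl.constexpr` argument. It is an illustration only, not code from this PR, and the kernel it defines is a toy row-sum.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _sum_rows_kernel(X, Out, n_cols, stride_x,
                     REDUCE_TYPE: tl.constexpr, BLOCK: tl.constexpr):
    # Each program sums one row, accumulating in REDUCE_TYPE (fp16 or fp32).
    row = tl.program_id(0)
    offs = tl.arange(0, BLOCK)
    mask = offs < n_cols
    x = tl.load(X + row * stride_x + offs, mask=mask, other=0.0).to(REDUCE_TYPE)
    tl.store(Out + row, tl.sum(x, 0).to(tl.float32))


def sum_rows(x: torch.Tensor, reduce_in_fp32: bool) -> torch.Tensor:
    # Pick the accumulation dtype once, on the host side, and hand it to the kernel.
    reduce_type = tl.float32 if reduce_in_fp32 else tl.float16
    out = torch.empty(x.shape[0], device=x.device, dtype=torch.float32)
    BLOCK = triton.next_power_of_2(x.shape[1])
    _sum_rows_kernel[(x.shape[0],)](
        x, out, x.shape[1], x.stride(0), REDUCE_TYPE=reduce_type, BLOCK=BLOCK
    )
    return out


# Example (requires a CUDA device):
# y = sum_rows(torch.randn(4, 1000, device="cuda", dtype=torch.float16), reduce_in_fp32=True)
```
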
8 changes: 7 additions & 1 deletion python/sglang/test/test_utils.py
@@ -30,7 +30,13 @@
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1 = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8,neuralmagic/Mistral-7B-Instruct-v0.3-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8,neuralmagic/gemma-2-2b-it-FP8"
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2 = "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8,neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8,neuralmagic/Qwen2-72B-Instruct-FP8,neuralmagic/Qwen2-57B-A14B-Instruct-FP8"

if os.getenv("SGLANG_IS_IN_CI", "false") == "true":

def is_in_ci():
"""Return whether it is in CI runner."""
return os.getenv("SGLANG_IS_IN_CI", "false") == "true"


if is_in_ci():
DEFAULT_PORT_FOR_SRT_TEST_RUNNER = 5157
DEFAULT_URL_FOR_TEST = "http://127.0.0.1:6157"
else:
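
The new `is_in_ci` helper reads the `SGLANG_IS_IN_CI` environment variable at call time, so it can be toggled within a process. A quick sketch, assuming sglang is installed and this module path is importable:

```python
import os

from sglang.test.test_utils import is_in_ci  # module path as in the diff above

os.environ["SGLANG_IS_IN_CI"] = "true"
assert is_in_ci()

os.environ["SGLANG_IS_IN_CI"] = "false"
assert not is_in_ci()

# Typical use in a test: only enforce performance thresholds on CI runners.
throughput = 2700.0  # hypothetical measured value
if is_in_ci():
    assert throughput > 2600
```
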
6 changes: 3 additions & 3 deletions test/srt/test_bench_latency.py
@@ -1,11 +1,11 @@
import os
import subprocess
import unittest

from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_MOE_MODEL_NAME_FOR_TEST,
is_in_ci,
)


@@ -38,7 +38,7 @@ def test_default(self):
lastline = output.split("\n")[-3]
value = float(lastline.split(" ")[-2])

if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
if is_in_ci():
assert value > 130
finally:
kill_child_process(process.pid)
@@ -73,7 +73,7 @@ def test_moe_default(self):
lastline = output.split("\n")[-3]
value = float(lastline.split(" ")[-2])

if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
if is_in_ci():
assert value > 125
finally:
kill_child_process(process.pid)
16 changes: 8 additions & 8 deletions test/srt/test_bench_serving.py
@@ -1,9 +1,9 @@
import os
import unittest

from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_MOE_MODEL_NAME_FOR_TEST,
is_in_ci,
run_bench_serving,
)

@@ -18,7 +18,7 @@ def test_offline_throughput_default(self):
other_server_args=[],
)

if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
if is_in_ci():
assert res["output_throughput"] > 2600

def test_offline_throughput_without_radix_cache(self):
@@ -29,7 +29,7 @@ def test_offline_throughput_without_radix_cache(self):
other_server_args=["--disable-radix-cache"],
)

if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
if is_in_ci():
assert res["output_throughput"] > 2800

def test_offline_throughput_without_chunked_prefill(self):
@@ -40,7 +40,7 @@ def test_offline_throughput_without_chunked_prefill(self):
other_server_args=["--chunked-prefill-size", "-1"],
)

if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
if is_in_ci():
assert res["output_throughput"] > 2600

def test_offline_throughput_with_triton_attention_backend(self):
@@ -56,7 +56,7 @@ def test_offline_throughput_with_triton_attention_backend(self):
],
)

if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
if is_in_ci():
assert res["output_throughput"] > 2600

def test_online_latency_default(self):
@@ -67,7 +67,7 @@ def test_online_latency_default(self):
other_server_args=[],
)

if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
if is_in_ci():
assert res["median_e2e_latency_ms"] < 12000
assert res["median_ttft_ms"] < 80
assert res["median_itl_ms"] < 12
@@ -80,7 +80,7 @@ def test_moe_offline_throughput_default(self):
other_server_args=["--tp", "2"],
)

if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
if is_in_ci():
assert res["output_throughput"] > 1850

def test_moe_offline_throughput_without_radix_cache(self):
@@ -91,7 +91,7 @@ def test_moe_offline_throughput_without_radix_cache(self):
other_server_args=["--tp", "2", "--disable-radix-cache"],
)

if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
if is_in_ci():
assert res["output_throughput"] > 1950

