Enable torch.compile for triton backend #1422

Merged · 10 commits · Sep 14, 2024
23 changes: 12 additions & 11 deletions README.md
@@ -205,28 +205,29 @@ print(response)
It supports streaming, vision, and most features of the Chat/Completions/Models/Batch endpoints specified by the [OpenAI API Reference](https://platform.openai.com/docs/api-reference/).

### Additional Server Arguments
- Add `--tp 2` to enable multi-GPU tensor parallelism. If it reports the error "peer access is not supported between these two devices", add `--enable-p2p-check` to the server launch command.
- To enable multi-GPU tensor parallelism, add `--tp 2`. If it reports the error "peer access is not supported between these two devices", add `--enable-p2p-check` to the server launch command.
```
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 --tp 2
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 2
```
- Add `--dp 2` to enable multi-GPU data parallelism. Data parallelism is better for throughput if there is enough memory. It can also be used together with tensor parallelism. The following command uses 4 GPUs in total.
- To enable multi-GPU data parallelism, add `--dp 2`. Data parallelism is better for throughput if there is enough memory. It can also be used together with tensor parallelism. The following command uses 4 GPUs in total.
```
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 --dp 2 --tp 2
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --dp 2 --tp 2
```
- If you see out-of-memory errors during serving, try to reduce the memory usage of the KV cache pool by setting a smaller value of `--mem-fraction-static`. The default value is `0.9`.
```
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 --mem-fraction-static 0.7
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --mem-fraction-static 0.7
```
- See [hyperparameter_tuning.md](docs/en/hyperparameter_tuning.md) on tuning hyperparameters for better performance.
- If you see out-of-memory errors during prefill for long prompts, try to set a smaller chunked prefill size.
```
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 --chunked-prefill-size 4096
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --chunked-prefill-size 4096
```
- To enable torch.compile support, you can add `--enable-torch-compile`. It accelerates small models on small batch sizes.
- To enable fp8 weight quantization, you can add `--quantization fp8` on a fp16 checkpoint or directly load a fp8 checkpoint without specifying any arguments.
- To enable fp8 kv cache quanzation, you can add `--kv-cache-dtype fp8_e5m2`.
- If the model does not have a template in the Hugging Face tokenizer, you can specify a [custom chat template](docs/en/custom_chat_template.md).
- Add `--nnodes 2` to run tensor parallelism on multiple nodes. If you have two nodes with two GPUs on each node and want to run TP=4, let `sgl-dev-0` be the hostname of the first node and `50000` be an available port.
- To enable torch.compile acceleration, add `--enable-torch-compile`. It accelerates small models on small batch sizes.
- To enable fp8 weight quantization, add `--quantization fp8` on a fp16 checkpoint or directly load a fp8 checkpoint without specifying any arguments.
- To enable fp8 kv cache quantization, add `--kv-cache-dtype fp8_e5m2`.
- To enable DeepSeek MLA acceleration, add `--enable-mla`.
- If the model does not have a chat template in the Hugging Face tokenizer, you can specify a [custom chat template](docs/en/custom_chat_template.md).
- To run tensor parallelism on multiple nodes, add `--nnodes 2`. If you have two nodes with two GPUs on each node and want to run TP=4, let `sgl-dev-0` be the hostname of the first node and `50000` be an available port.
```
# Node 0
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 4 --nccl-init sgl-dev-0:50000 --nnodes 2 --node-rank 0
```
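
A minimal launch sketch (not part of the diff) that exercises the new torch.compile path. It only uses flags shown in the README above; the flag that selects the Triton attention backend differs across sglang versions, so it is omitted here.
```
# Sketch: combine the documented torch.compile flag with the model path used in the README examples.
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --enable-torch-compile
```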
3 changes: 2 additions & 1 deletion python/sglang/bench_latency.py
@@ -479,7 +479,8 @@ def main(server_args, bench_args):


if __name__ == "__main__":
# TODO(kevin85421): Make the parser setup unit testable.
multiprocessing.set_start_method("spawn", force=True)

parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
BenchArgs.add_cli_args(parser)
29 changes: 16 additions & 13 deletions python/sglang/srt/layers/attention_backend.py
@@ -22,7 +22,7 @@

from sglang.global_config import global_config
from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata

if TYPE_CHECKING:
@@ -332,7 +332,6 @@ class TritonAttnBackend(AttentionBackend):
def __init__(self, model_runner: ModelRunner):
# Lazy import to avoid the initialization of cuda context
from sglang.srt.layers.triton_attention.decode_attention import (
REDUCE_TORCH_TYPE,
decode_attention_fwd,
)
from sglang.srt.layers.triton_attention.extend_attention import (
@@ -343,9 +342,13 @@ def __init__(self, model_runner: ModelRunner):

self.decode_attention_fwd = decode_attention_fwd
self.extend_attention_fwd = extend_attention_fwd
self.REDUCE_TORCH_TYPE = REDUCE_TORCH_TYPE
self.num_head = model_runner.model_config.num_attention_heads

if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
self.reduce_dtype = torch.float32
else:
self.reduce_dtype = torch.float16

self.forward_metadata = None

self.cuda_graph_max_seq_len = model_runner.model_config.context_len
@@ -362,7 +365,7 @@ def init_forward_metadata(
total_num_tokens = torch.sum(input_metadata.seq_lens).item()
attn_logits = torch.empty(
(self.num_head, total_num_tokens),
dtype=self.REDUCE_TORCH_TYPE,
dtype=self.reduce_dtype,
device="cuda",
)

@@ -382,8 +385,11 @@ def init_cuda_graph_state(self, max_bs: int):
(max_bs,), dtype=torch.int32, device="cuda"
)
self.cuda_graph_attn_logits = torch.empty(
(self.num_head, self.cuda_graph_max_total_num_tokens),
dtype=self.REDUCE_TORCH_TYPE,
(
self.num_head,
self.cuda_graph_max_total_num_tokens,
),
dtype=self.reduce_dtype,
device="cuda",
)

@@ -403,13 +409,6 @@ def init_forward_metadata_replay_cuda_graph(
self.cuda_graph_start_loc.zero_()
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)

self.forward_metadata = (
self.cuda_graph_start_loc,
self.cuda_graph_attn_logits,
self.cuda_graph_max_seq_len,
None,
)

def get_cuda_graph_seq_len_fill_value(self):
return 1

@@ -444,6 +443,10 @@ def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadat
return o

def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
# During torch.compile, there is a bug in rotary_emb that causes the
# output value to have a 3D tensor shape. This reshapes the output correctly.
q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)

# TODO: reuse the buffer across layers
if layer.qk_head_dim != layer.v_head_dim:
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
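
The backend change above replaces the import-time `REDUCE_TORCH_TYPE` constant with a per-instance `reduce_dtype` attribute derived from the server arguments, so the kernel module no longer has to read `global_server_args_dict` when it is imported. Below is a minimal sketch of that pattern, assuming a plain dict named `server_args` as a stand-in for sglang's `global_server_args_dict`.
```python
import torch

# Hypothetical stand-in for sglang's global_server_args_dict.
server_args = {"triton_attention_reduce_in_fp32": False}


class TritonBackendSketch:
    """Illustrative only: selects the reduction dtype per instance, not at import time."""

    def __init__(self, num_attention_heads: int):
        self.num_head = num_attention_heads
        # fp32 reduction is more accurate; fp16 is the faster default.
        if server_args.get("triton_attention_reduce_in_fp32", False):
            self.reduce_dtype = torch.float32
        else:
            self.reduce_dtype = torch.float16

    def alloc_attn_logits(self, total_num_tokens: int) -> torch.Tensor:
        # Intermediate buffers inherit the instance dtype, mirroring the
        # attn_logits allocation in init_forward_metadata above.
        return torch.empty((self.num_head, total_num_tokens), dtype=self.reduce_dtype)
```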
30 changes: 11 additions & 19 deletions python/sglang/srt/layers/triton_attention/decode_attention.py
@@ -21,19 +21,9 @@
# Adapted from
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py
import torch
import triton
import triton.language as tl

from sglang.srt.managers.schedule_batch import global_server_args_dict

if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
REDUCE_TRITON_TYPE = tl.float32
REDUCE_TORCH_TYPE = torch.float32
else:
REDUCE_TRITON_TYPE = tl.float16
REDUCE_TORCH_TYPE = torch.float16


@triton.jit
def tanh(x):
@@ -67,6 +57,7 @@ def _fwd_kernel_stage1(
cur_head = tl.program_id(1)
start_n = tl.program_id(2)

reduce_dtype = Att_Out.dtype.element_ty
cur_kv_head = cur_head // kv_group_num

offs_d = tl.arange(0, BLOCK_DMODEL)
@@ -85,7 +76,7 @@
block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0)

for start_mark in range(0, block_mask, 1):
q = tl.load(Q + off_q + start_mark).to(REDUCE_TRITON_TYPE)
q = tl.load(Q + off_q + start_mark).to(reduce_dtype)
offs_n_new = cur_batch_start_index + offs_n
k_loc = tl.load(
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,
@@ -101,7 +92,7 @@
K_Buffer + offs_buf_k,
mask=(offs_n_new[:, None] < cur_batch_end_index) & (offs_d[None, :] < Lk),
other=0.0,
).to(REDUCE_TRITON_TYPE)
).to(reduce_dtype)
att_value = tl.sum(q[None, :] * k, 1)
att_value *= sm_scale

@@ -198,7 +189,7 @@ def _decode_att_m_fwd(
logit_cap,
):
BLOCK = 32
Lq, Lk = q.shape[-1], k_buffer.shape[-1]
Lk = k_buffer.shape[-1]

batch, head_num = B_req_idx.shape[0], q.shape[1]

@@ -308,6 +299,7 @@ def _fwd_grouped_kernel_stage1(
cur_kv_head = tl.program_id(1)
start_n = tl.program_id(2)

reduce_dtype = Att_Out.dtype.element_ty
cur_head = cur_kv_head * kv_group_num + tl.arange(0, BLOCK_H)
mask_h = cur_head < (cur_kv_head + 1) * kv_group_num
mask_h = mask_h & (cur_head < q_head_num)
@@ -336,7 +328,7 @@
for start_mark in range(0, block_mask, 1):
q = tl.load(
Q + offs_q + start_mark, mask=(mask_h[:, None]) & (offs_d[None, :] < Lk)
).to(REDUCE_TRITON_TYPE)
).to(reduce_dtype)
offs_n_new = cur_batch_start_index + offs_n
k_loc = tl.load(
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,
@@ -352,11 +344,11 @@
K_Buffer + offs_buf_k,
mask=(offs_n_new[None, :] < cur_batch_end_index) & (offs_d[:, None] < Lk),
other=0.0,
).to(REDUCE_TRITON_TYPE)
).to(reduce_dtype)
qk = tl.dot(q, k)
if BLOCK_DPE > 0:
qpe = tl.load(Q + off_qpe + start_mark, mask=mask_h[:, None]).to(
REDUCE_TRITON_TYPE
reduce_dtype
)
offs_buf_kpe = (
k_loc[None, :] * stride_buf_kbs
@@ -367,7 +359,7 @@
K_Buffer + offs_buf_kpe,
mask=offs_n_new[None, :] < cur_batch_end_index,
other=0.0,
).to(REDUCE_TRITON_TYPE)
).to(reduce_dtype)
qk += tl.dot(qpe, kpe)
qk *= sm_scale

@@ -477,8 +469,8 @@ def _decode_grouped_att_m_fwd(
sm_scale,
logit_cap,
):
BLOCK = 32
Lq, Lk = q.shape[-1], k_buffer.shape[-1]
BLOCK = 64
Lk = k_buffer.shape[-1]

if Lk == 576:
BLOCK_DMODEL = 512
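
Inside the Triton kernels, the reduction dtype is now read from the output pointer (`Att_Out.dtype.element_ty`) instead of the removed module-level `REDUCE_TRITON_TYPE`. The standalone kernel below is an illustrative sketch of that idiom, not sglang code; the caller controls the working dtype simply by choosing the dtype of the output buffer it allocates.
```python
import torch
import triton
import triton.language as tl


@triton.jit
def _cast_copy_kernel(X, Out, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    # Same idiom as `reduce_dtype = Att_Out.dtype.element_ty` above: the working
    # dtype comes from the output buffer, not from a module-level constant.
    x = tl.load(X + offs, mask=mask, other=0.0).to(Out.dtype.element_ty)
    tl.store(Out + offs, x, mask=mask)


def cast_copy(x: torch.Tensor, out_dtype: torch.dtype) -> torch.Tensor:
    # Choosing out_dtype here plays the role of the triton_attention_reduce_in_fp32
    # server argument: float32 for a more accurate reduction, float16 for speed.
    x = x.contiguous()
    out = torch.empty(x.shape, dtype=out_dtype, device=x.device)
    n = x.numel()
    grid = (triton.cdiv(n, 1024),)
    _cast_copy_kernel[grid](x, out, n, BLOCK=1024)
    return out
```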
40 changes: 39 additions & 1 deletion python/sglang/test/test_utils.py
@@ -30,7 +30,13 @@
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1 = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8,neuralmagic/Mistral-7B-Instruct-v0.3-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8,neuralmagic/gemma-2-2b-it-FP8"
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2 = "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8,neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8,neuralmagic/Qwen2-72B-Instruct-FP8,neuralmagic/Qwen2-57B-A14B-Instruct-FP8"

if os.getenv("SGLANG_IS_IN_CI", "false") == "true":

def is_in_ci():
"""Return whether it is in CI runner."""
return os.getenv("SGLANG_IS_IN_CI", "false") == "true"


if is_in_ci():
DEFAULT_PORT_FOR_SRT_TEST_RUNNER = 5157
DEFAULT_URL_FOR_TEST = "http://127.0.0.1:6157"
else:
@@ -547,3 +553,35 @@ def run_bench_serving(model, num_prompts, request_rate, other_server_args):

assert res["completed"] == num_prompts
return res


def run_bench_latency(model, other_args):
command = [
"python3",
"-m",
"sglang.bench_latency",
"--model-path",
model,
"--batch-size",
"1",
"--input",
"128",
"--output",
"8",
*other_args,
]
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

try:
stdout, stderr = process.communicate()
output = stdout.decode()
error = stderr.decode()
print(f"Output: {output}", flush=True)
print(f"Error: {error}", flush=True)

lastline = output.split("\n")[-3]
output_throughput = float(lastline.split(" ")[-2])
finally:
kill_child_process(process.pid)

return output_throughput
71 changes: 9 additions & 62 deletions test/srt/test_bench_latency.py
@@ -1,82 +1,29 @@
import os
import subprocess
import unittest

from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_MOE_MODEL_NAME_FOR_TEST,
is_in_ci,
run_bench_latency,
)


class TestBenchLatency(unittest.TestCase):
def test_default(self):
command = [
"python3",
"-m",
"sglang.bench_latency",
"--model-path",
DEFAULT_MODEL_NAME_FOR_TEST,
"--batch-size",
"1",
"--input",
"128",
"--output",
"8",
]
process = subprocess.Popen(
command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)

try:
stdout, stderr = process.communicate()
output = stdout.decode()
error = stderr.decode()
print(f"Output: {output}")
print(f"Error: {error}")

lastline = output.split("\n")[-3]
value = float(lastline.split(" ")[-2])
output_throughput = run_bench_latency(DEFAULT_MODEL_NAME_FOR_TEST, [])

if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
assert value > 130
finally:
kill_child_process(process.pid)
if is_in_ci():
assert output_throughput > 130, f"{output_throughput=}"

def test_moe_default(self):
command = [
"python3",
"-m",
"sglang.bench_latency",
"--model",
DEFAULT_MOE_MODEL_NAME_FOR_TEST,
"--batch-size",
"1",
"--input",
"128",
"--output",
"8",
"--tp",
"2",
]
process = subprocess.Popen(
command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
output_throughput = run_bench_latency(
DEFAULT_MOE_MODEL_NAME_FOR_TEST, ["--tp", "2"]
)

try:
stdout, stderr = process.communicate()
output = stdout.decode()
error = stderr.decode()
print(f"Output: {output}")
print(f"Error: {error}")

lastline = output.split("\n")[-3]
value = float(lastline.split(" ")[-2])

if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
assert value > 125
finally:
kill_child_process(process.pid)
if is_in_ci():
assert output_throughput > 125, f"{output_throughput=}"


if __name__ == "__main__":