diff --git a/README.md b/README.md
index d82567202f..aeaa43bb54 100644
--- a/README.md
+++ b/README.md
@@ -205,28 +205,29 @@ print(response)
 It supports streaming, vision, and most features of the Chat/Completions/Models/Batch endpoints specified by the [OpenAI API Reference](https://platform.openai.com/docs/api-reference/).

 ### Additional Server Arguments
-- Add `--tp 2` to enable multi-GPU tensor parallelism. If it reports the error "peer access is not supported between these two devices", add `--enable-p2p-check` to the server launch command.
+- To enable multi-GPU tensor parallelism, add `--tp 2`. If the server reports the error "peer access is not supported between these two devices", add `--enable-p2p-check` to the server launch command.
 ```
-python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 --tp 2
+python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 2
 ```
-- Add `--dp 2` to enable multi-GPU data parallelism. Data parallelism is better for throughput if there is enough memory. It can also be used together with tensor parallelism. The following command uses 4 GPUs in total.
+- To enable multi-GPU data parallelism, add `--dp 2`. Data parallelism is better for throughput if there is enough memory. It can also be used together with tensor parallelism. The following command uses 4 GPUs in total.
 ```
-python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 --dp 2 --tp 2
+python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --dp 2 --tp 2
 ```
 - If you see out-of-memory errors during serving, try to reduce the memory usage of the KV cache pool by setting a smaller value of `--mem-fraction-static`. The default value is `0.9`.
 ```
-python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 --mem-fraction-static 0.7
+python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --mem-fraction-static 0.7
 ```
 - See [hyperparameter_tuning.md](docs/en/hyperparameter_tuning.md) on tuning hyperparameters for better performance.
 - If you see out-of-memory errors during prefill for long prompts, try to set a smaller chunked prefill size.
 ```
-python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 --chunked-prefill-size 4096
+python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --chunked-prefill-size 4096
 ```
-- To enable torch.compile support, you can add `--enable-torch-compile`. It accelerates small models on small batch sizes.
-- To enable fp8 weight quantization, you can add `--quantization fp8` on a fp16 checkpoint or directly load a fp8 checkpoint without specifying any arguments.
-- To enable fp8 kv cache quanzation, you can add `--kv-cache-dtype fp8_e5m2`.
-- If the model does not have a template in the Hugging Face tokenizer, you can specify a [custom chat template](docs/en/custom_chat_template.md).
-- Add `--nnodes 2` to run tensor parallelism on multiple nodes. If you have two nodes with two GPUs on each node and want to run TP=4, let `sgl-dev-0` be the hostname of the first node and `50000` be an available port.
+- To enable torch.compile acceleration, add `--enable-torch-compile`. It accelerates small models on small batch sizes.
+- To enable fp8 weight quantization, add `--quantization fp8` on an fp16 checkpoint or directly load an fp8 checkpoint without specifying any arguments.
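+For example, the command below (an illustrative invocation, reusing the model path from the examples above) applies fp8 weight quantization at load time:
+```
+python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --quantization fp8
+```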
+- To enable fp8 kv cache quantization, add `--kv-cache-dtype fp8_e5m2`. +- To enable DeepSeek MLA acceleration, add `--enable-mla`. +- If the model does not have a chat template in the Hugging Face tokenizer, you can specify a [custom chat template](docs/en/custom_chat_template.md). +- To run tensor parallelism on multiple nodes, add `--nnodes 2`. If you have two nodes with two GPUs on each node and want to run TP=4, let `sgl-dev-0` be the hostname of the first node and `50000` be an available port. ``` # Node 0 python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 4 --nccl-init sgl-dev-0:50000 --nnodes 2 --node-rank 0 diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py index b1ac43e9dc..93fcc01158 100644 --- a/python/sglang/bench_latency.py +++ b/python/sglang/bench_latency.py @@ -479,7 +479,8 @@ def main(server_args, bench_args): if __name__ == "__main__": - # TODO(kevin85421): Make the parser setup unit testable. + multiprocessing.set_start_method("spawn", force=True) + parser = argparse.ArgumentParser() ServerArgs.add_cli_args(parser) BenchArgs.add_cli_args(parser) diff --git a/python/sglang/srt/layers/attention_backend.py b/python/sglang/srt/layers/attention_backend.py index af3986bc2c..835664fb63 100644 --- a/python/sglang/srt/layers/attention_backend.py +++ b/python/sglang/srt/layers/attention_backend.py @@ -22,7 +22,7 @@ from sglang.global_config import global_config from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices -from sglang.srt.managers.schedule_batch import ScheduleBatch +from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata if TYPE_CHECKING: @@ -332,7 +332,6 @@ class TritonAttnBackend(AttentionBackend): def __init__(self, model_runner: ModelRunner): # Lazy import to avoid the initialization of cuda context from sglang.srt.layers.triton_attention.decode_attention import ( - REDUCE_TORCH_TYPE, decode_attention_fwd, ) from sglang.srt.layers.triton_attention.extend_attention import ( @@ -343,9 +342,13 @@ def __init__(self, model_runner: ModelRunner): self.decode_attention_fwd = decode_attention_fwd self.extend_attention_fwd = extend_attention_fwd - self.REDUCE_TORCH_TYPE = REDUCE_TORCH_TYPE self.num_head = model_runner.model_config.num_attention_heads + if global_server_args_dict.get("triton_attention_reduce_in_fp32", False): + self.reduce_dtype = torch.float32 + else: + self.reduce_dtype = torch.float16 + self.forward_metadata = None self.cuda_graph_max_seq_len = model_runner.model_config.context_len @@ -362,7 +365,7 @@ def init_forward_metadata( total_num_tokens = torch.sum(input_metadata.seq_lens).item() attn_logits = torch.empty( (self.num_head, total_num_tokens), - dtype=self.REDUCE_TORCH_TYPE, + dtype=self.reduce_dtype, device="cuda", ) @@ -382,8 +385,11 @@ def init_cuda_graph_state(self, max_bs: int): (max_bs,), dtype=torch.int32, device="cuda" ) self.cuda_graph_attn_logits = torch.empty( - (self.num_head, self.cuda_graph_max_total_num_tokens), - dtype=self.REDUCE_TORCH_TYPE, + ( + self.num_head, + self.cuda_graph_max_total_num_tokens, + ), + dtype=self.reduce_dtype, device="cuda", ) @@ -403,13 +409,6 @@ def init_forward_metadata_replay_cuda_graph( self.cuda_graph_start_loc.zero_() self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0) - self.forward_metadata = ( - self.cuda_graph_start_loc, - self.cuda_graph_attn_logits, - self.cuda_graph_max_seq_len, - None, 
- ) - def get_cuda_graph_seq_len_fill_value(self): return 1 @@ -444,6 +443,10 @@ def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadat return o def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata): + # During torch.compile, there is a bug in rotary_emb that causes the + # output value to have a 3D tensor shape. This reshapes the output correctly. + q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim) + # TODO: reuse the buffer across layers if layer.qk_head_dim != layer.v_head_dim: o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim)) diff --git a/python/sglang/srt/layers/triton_attention/decode_attention.py b/python/sglang/srt/layers/triton_attention/decode_attention.py index 5d8eb9ae4a..9e06b068cb 100644 --- a/python/sglang/srt/layers/triton_attention/decode_attention.py +++ b/python/sglang/srt/layers/triton_attention/decode_attention.py @@ -21,19 +21,9 @@ # Adapted from # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py -import torch import triton import triton.language as tl -from sglang.srt.managers.schedule_batch import global_server_args_dict - -if global_server_args_dict.get("triton_attention_reduce_in_fp32", False): - REDUCE_TRITON_TYPE = tl.float32 - REDUCE_TORCH_TYPE = torch.float32 -else: - REDUCE_TRITON_TYPE = tl.float16 - REDUCE_TORCH_TYPE = torch.float16 - @triton.jit def tanh(x): @@ -67,6 +57,7 @@ def _fwd_kernel_stage1( cur_head = tl.program_id(1) start_n = tl.program_id(2) + reduce_dtype = Att_Out.dtype.element_ty cur_kv_head = cur_head // kv_group_num offs_d = tl.arange(0, BLOCK_DMODEL) @@ -85,7 +76,7 @@ def _fwd_kernel_stage1( block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0) for start_mark in range(0, block_mask, 1): - q = tl.load(Q + off_q + start_mark).to(REDUCE_TRITON_TYPE) + q = tl.load(Q + off_q + start_mark).to(reduce_dtype) offs_n_new = cur_batch_start_index + offs_n k_loc = tl.load( Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new, @@ -101,7 +92,7 @@ def _fwd_kernel_stage1( K_Buffer + offs_buf_k, mask=(offs_n_new[:, None] < cur_batch_end_index) & (offs_d[None, :] < Lk), other=0.0, - ).to(REDUCE_TRITON_TYPE) + ).to(reduce_dtype) att_value = tl.sum(q[None, :] * k, 1) att_value *= sm_scale @@ -198,7 +189,7 @@ def _decode_att_m_fwd( logit_cap, ): BLOCK = 32 - Lq, Lk = q.shape[-1], k_buffer.shape[-1] + Lk = k_buffer.shape[-1] batch, head_num = B_req_idx.shape[0], q.shape[1] @@ -308,6 +299,7 @@ def _fwd_grouped_kernel_stage1( cur_kv_head = tl.program_id(1) start_n = tl.program_id(2) + reduce_dtype = Att_Out.dtype.element_ty cur_head = cur_kv_head * kv_group_num + tl.arange(0, BLOCK_H) mask_h = cur_head < (cur_kv_head + 1) * kv_group_num mask_h = mask_h & (cur_head < q_head_num) @@ -336,7 +328,7 @@ def _fwd_grouped_kernel_stage1( for start_mark in range(0, block_mask, 1): q = tl.load( Q + offs_q + start_mark, mask=(mask_h[:, None]) & (offs_d[None, :] < Lk) - ).to(REDUCE_TRITON_TYPE) + ).to(reduce_dtype) offs_n_new = cur_batch_start_index + offs_n k_loc = tl.load( Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new, @@ -352,11 +344,11 @@ def _fwd_grouped_kernel_stage1( K_Buffer + offs_buf_k, mask=(offs_n_new[None, :] < cur_batch_end_index) & (offs_d[:, None] < Lk), other=0.0, - 
).to(REDUCE_TRITON_TYPE) + ).to(reduce_dtype) qk = tl.dot(q, k) if BLOCK_DPE > 0: qpe = tl.load(Q + off_qpe + start_mark, mask=mask_h[:, None]).to( - REDUCE_TRITON_TYPE + reduce_dtype ) offs_buf_kpe = ( k_loc[None, :] * stride_buf_kbs @@ -367,7 +359,7 @@ def _fwd_grouped_kernel_stage1( K_Buffer + offs_buf_kpe, mask=offs_n_new[None, :] < cur_batch_end_index, other=0.0, - ).to(REDUCE_TRITON_TYPE) + ).to(reduce_dtype) qk += tl.dot(qpe, kpe) qk *= sm_scale @@ -477,8 +469,8 @@ def _decode_grouped_att_m_fwd( sm_scale, logit_cap, ): - BLOCK = 32 - Lq, Lk = q.shape[-1], k_buffer.shape[-1] + BLOCK = 64 + Lk = k_buffer.shape[-1] if Lk == 576: BLOCK_DMODEL = 512 diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index a816bb7fa1..51eae5613b 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -30,7 +30,13 @@ DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1 = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8,neuralmagic/Mistral-7B-Instruct-v0.3-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8,neuralmagic/gemma-2-2b-it-FP8" DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2 = "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8,neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8,neuralmagic/Qwen2-72B-Instruct-FP8,neuralmagic/Qwen2-57B-A14B-Instruct-FP8" -if os.getenv("SGLANG_IS_IN_CI", "false") == "true": + +def is_in_ci(): + """Return whether it is in CI runner.""" + return os.getenv("SGLANG_IS_IN_CI", "false") == "true" + + +if is_in_ci(): DEFAULT_PORT_FOR_SRT_TEST_RUNNER = 5157 DEFAULT_URL_FOR_TEST = "http://127.0.0.1:6157" else: @@ -547,3 +553,35 @@ def run_bench_serving(model, num_prompts, request_rate, other_server_args): assert res["completed"] == num_prompts return res + + +def run_bench_latency(model, other_args): + command = [ + "python3", + "-m", + "sglang.bench_latency", + "--model-path", + model, + "--batch-size", + "1", + "--input", + "128", + "--output", + "8", + *other_args, + ] + process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + + try: + stdout, stderr = process.communicate() + output = stdout.decode() + error = stderr.decode() + print(f"Output: {output}", flush=True) + print(f"Error: {error}", flush=True) + + lastline = output.split("\n")[-3] + output_throughput = float(lastline.split(" ")[-2]) + finally: + kill_child_process(process.pid) + + return output_throughput diff --git a/test/srt/test_bench_latency.py b/test/srt/test_bench_latency.py index 2c893ee66f..4d2042ccf8 100644 --- a/test/srt/test_bench_latency.py +++ b/test/srt/test_bench_latency.py @@ -1,4 +1,3 @@ -import os import subprocess import unittest @@ -6,77 +5,25 @@ from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MOE_MODEL_NAME_FOR_TEST, + is_in_ci, + run_bench_latency, ) class TestBenchLatency(unittest.TestCase): def test_default(self): - command = [ - "python3", - "-m", - "sglang.bench_latency", - "--model-path", - DEFAULT_MODEL_NAME_FOR_TEST, - "--batch-size", - "1", - "--input", - "128", - "--output", - "8", - ] - process = subprocess.Popen( - command, stdout=subprocess.PIPE, stderr=subprocess.PIPE - ) - - try: - stdout, stderr = process.communicate() - output = stdout.decode() - error = stderr.decode() - print(f"Output: {output}") - print(f"Error: {error}") - - lastline = output.split("\n")[-3] - value = float(lastline.split(" ")[-2]) + output_throughput = run_bench_latency(DEFAULT_MODEL_NAME_FOR_TEST, []) - if os.getenv("SGLANG_IS_IN_CI", "false") == "true": - assert value > 130 - finally: - kill_child_process(process.pid) 
+ if is_in_ci(): + assert output_throughput > 130, f"{output_throughput=}" def test_moe_default(self): - command = [ - "python3", - "-m", - "sglang.bench_latency", - "--model", - DEFAULT_MOE_MODEL_NAME_FOR_TEST, - "--batch-size", - "1", - "--input", - "128", - "--output", - "8", - "--tp", - "2", - ] - process = subprocess.Popen( - command, stdout=subprocess.PIPE, stderr=subprocess.PIPE + output_throughput = run_bench_latency( + DEFAULT_MOE_MODEL_NAME_FOR_TEST, ["--tp", "2"] ) - try: - stdout, stderr = process.communicate() - output = stdout.decode() - error = stderr.decode() - print(f"Output: {output}") - print(f"Error: {error}") - - lastline = output.split("\n")[-3] - value = float(lastline.split(" ")[-2]) - - if os.getenv("SGLANG_IS_IN_CI", "false") == "true": - assert value > 125 - finally: - kill_child_process(process.pid) + if is_in_ci(): + assert output_throughput > 125, f"{output_throughput=}" if __name__ == "__main__": diff --git a/test/srt/test_bench_serving.py b/test/srt/test_bench_serving.py index a196b76761..eee6d7701b 100644 --- a/test/srt/test_bench_serving.py +++ b/test/srt/test_bench_serving.py @@ -1,9 +1,9 @@ -import os import unittest from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MOE_MODEL_NAME_FOR_TEST, + is_in_ci, run_bench_serving, ) @@ -18,7 +18,7 @@ def test_offline_throughput_default(self): other_server_args=[], ) - if os.getenv("SGLANG_IS_IN_CI", "false") == "true": + if is_in_ci(): assert res["output_throughput"] > 2600 def test_offline_throughput_without_radix_cache(self): @@ -29,7 +29,7 @@ def test_offline_throughput_without_radix_cache(self): other_server_args=["--disable-radix-cache"], ) - if os.getenv("SGLANG_IS_IN_CI", "false") == "true": + if is_in_ci(): assert res["output_throughput"] > 2800 def test_offline_throughput_without_chunked_prefill(self): @@ -40,7 +40,7 @@ def test_offline_throughput_without_chunked_prefill(self): other_server_args=["--chunked-prefill-size", "-1"], ) - if os.getenv("SGLANG_IS_IN_CI", "false") == "true": + if is_in_ci(): assert res["output_throughput"] > 2600 def test_offline_throughput_with_triton_attention_backend(self): @@ -56,7 +56,7 @@ def test_offline_throughput_with_triton_attention_backend(self): ], ) - if os.getenv("SGLANG_IS_IN_CI", "false") == "true": + if is_in_ci(): assert res["output_throughput"] > 2600 def test_online_latency_default(self): @@ -67,7 +67,7 @@ def test_online_latency_default(self): other_server_args=[], ) - if os.getenv("SGLANG_IS_IN_CI", "false") == "true": + if is_in_ci(): assert res["median_e2e_latency_ms"] < 12000 assert res["median_ttft_ms"] < 80 assert res["median_itl_ms"] < 12 @@ -80,7 +80,7 @@ def test_moe_offline_throughput_default(self): other_server_args=["--tp", "2"], ) - if os.getenv("SGLANG_IS_IN_CI", "false") == "true": + if is_in_ci(): assert res["output_throughput"] > 1850 def test_moe_offline_throughput_without_radix_cache(self): @@ -91,7 +91,7 @@ def test_moe_offline_throughput_without_radix_cache(self): other_server_args=["--tp", "2", "--disable-radix-cache"], ) - if os.getenv("SGLANG_IS_IN_CI", "false") == "true": + if is_in_ci(): assert res["output_throughput"] > 1950 diff --git a/test/srt/test_moe_eval_accuracy_large.py b/test/srt/test_moe_eval_accuracy_large.py index b6027b61cb..b15308dcec 100644 --- a/test/srt/test_moe_eval_accuracy_large.py +++ b/test/srt/test_moe_eval_accuracy_large.py @@ -42,7 +42,7 @@ def test_mmlu(self): ) metrics = run_eval(args) - assert metrics["score"] >= 0.625, f"{metrics}" + assert metrics["score"] >= 0.62, f"{metrics}" 
def test_human_eval(self): args = SimpleNamespace( @@ -54,7 +54,7 @@ def test_human_eval(self): ) metrics = run_eval(args) - assert metrics["score"] >= 0.425, f"{metrics}" + assert metrics["score"] >= 0.42, f"{metrics}" def test_mgsm_en(self): args = SimpleNamespace( @@ -66,7 +66,7 @@ def test_mgsm_en(self): ) metrics = run_eval(args) - assert metrics["score"] >= 0.625, f"{metrics}" + assert metrics["score"] >= 0.62, f"{metrics}" if __name__ == "__main__": diff --git a/test/srt/test_triton_attn_backend.py b/test/srt/test_triton_attn_backend.py index 9c6519d911..f0200a916c 100644 --- a/test/srt/test_triton_attn_backend.py +++ b/test/srt/test_triton_attn_backend.py @@ -1,3 +1,4 @@ +import subprocess import unittest from types import SimpleNamespace @@ -7,37 +8,49 @@ DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, + is_in_ci, popen_launch_server, + run_bench_latency, ) class TestTritonAttnBackend(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.model = DEFAULT_MODEL_NAME_FOR_TEST - cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=["--attention-backend", "triton"], + def test_latency(self): + output_throughput = run_bench_latency( + DEFAULT_MODEL_NAME_FOR_TEST, + [ + "--attention-backend", + "triton", + "--enable-torch-compile", + ], ) - @classmethod - def tearDownClass(cls): - kill_child_process(cls.process.pid) + if is_in_ci(): + assert output_throughput > 155, f"{output_throughput=}" def test_mmlu(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mmlu", - num_examples=64, - num_threads=32, + model = DEFAULT_MODEL_NAME_FOR_TEST + base_url = DEFAULT_URL_FOR_TEST + process = popen_launch_server( + model, + base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=["--attention-backend", "triton"], ) - metrics = run_eval(args) - assert metrics["score"] >= 0.65 + try: + args = SimpleNamespace( + base_url=base_url, + model=model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + ) + + metrics = run_eval(args) + assert metrics["score"] >= 0.65 + finally: + kill_child_process(process.pid) if __name__ == "__main__":
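Across the attention-backend and Triton-kernel hunks above, the reduction dtype is no longer a module-level constant fixed at import time; it is derived from the server arguments and then recovered inside the kernel from the output buffer (`Att_Out.dtype.element_ty`). The standalone sketch below illustrates only that dtype flow, using a hypothetical class name that is not part of sglang:

```
import torch


class TinyTritonLikeBackend:
    """Minimal sketch (hypothetical class, not the sglang API) of deriving the
    attention-reduction dtype from server arguments instead of an import-time
    module-level constant."""

    def __init__(self, server_args: dict, num_heads: int):
        # fp32 reduction is opt-in via a flag; fp16 is the default, mirroring
        # the branch added in attention_backend.py above.
        if server_args.get("triton_attention_reduce_in_fp32", False):
            self.reduce_dtype = torch.float32
        else:
            self.reduce_dtype = torch.float16
        self.num_heads = num_heads

    def alloc_attn_logits(self, total_num_tokens: int) -> torch.Tensor:
        # The intermediate logits buffer carries the dtype, so a kernel can
        # read it back from the buffer it writes to (Att_Out.dtype.element_ty
        # in Triton) rather than importing a global REDUCE_* constant.
        return torch.empty((self.num_heads, total_num_tokens), dtype=self.reduce_dtype)


if __name__ == "__main__":
    backend = TinyTritonLikeBackend({"triton_attention_reduce_in_fp32": True}, num_heads=8)
    print(backend.alloc_attn_logits(16).dtype)  # torch.float32
```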