Enable MLA by default (#1447)
ispobock authored Sep 17, 2024
1 parent 90a26be · commit c6b6d2e
Showing 8 changed files with 16 additions and 18 deletions.

README.md (1 change: 0 additions & 1 deletion)
@@ -225,7 +225,6 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
 - To enable torch.compile acceleration, add `--enable-torch-compile`. It accelerates small models on small batch sizes.
 - To enable fp8 weight quantization, add `--quantization fp8` on a fp16 checkpoint or directly load a fp8 checkpoint without specifying any arguments.
 - To enable fp8 kv cache quantization, add `--kv-cache-dtype fp8_e5m2`.
-- To enable DeepSeek MLA acceleration, add `--enable-mla`.
 - If the model does not have a chat template in the Hugging Face tokenizer, you can specify a [custom chat template](docs/en/custom_chat_template.md).
 - To run tensor parallelism on multiple nodes, add `--nnodes 2`. If you have two nodes with two GPUs on each node and want to run TP=4, let `sgl-dev-0` be the hostname of the first node and `50000` be an available port.
 ```
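This documentation hunk (and the identical one in docs/en/backend.md below) drops the bullet because no flag is needed anymore: MLA is now the default for MLA-capable checkpoints, and `--disable-mla`, added further down in `server_args.py`, is the opt-out. A minimal launch sketch under that assumption; the model path and port are illustrative, not taken from this commit:

```python
import subprocess

# Minimal sketch (not from this commit): after this change, launching a
# DeepSeek-V2 checkpoint uses MLA automatically; no extra flag is needed.
# Model path and port here are illustrative assumptions.
subprocess.run([
    "python", "-m", "sglang.launch_server",
    "--model-path", "deepseek-ai/DeepSeek-V2-Lite",
    "--trust-remote-code",
    "--port", "30000",
    # To opt out of MLA, append "--disable-mla" in place of the removed
    # "--enable-mla" flag.
])
```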

docs/en/backend.md (1 change: 0 additions & 1 deletion)
@@ -81,7 +81,6 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
 - To enable torch.compile acceleration, add `--enable-torch-compile`. It accelerates small models on small batch sizes.
 - To enable fp8 weight quantization, add `--quantization fp8` on a fp16 checkpoint or directly load a fp8 checkpoint without specifying any arguments.
 - To enable fp8 kv cache quantization, add `--kv-cache-dtype fp8_e5m2`.
-- To enable DeepSeek MLA acceleration, add `--enable-mla`.
 - If the model does not have a chat template in the Hugging Face tokenizer, you can specify a [custom chat template](docs/en/custom_chat_template.md).
 - To run tensor parallelism on multiple nodes, add `--nnodes 2`. If you have two nodes with two GPUs on each node and want to run TP=4, let `sgl-dev-0` be the hostname of the first node and `50000` be an available port.
 ```

python/sglang/srt/managers/schedule_batch.py (2 changes: 1 addition & 1 deletion)
@@ -40,7 +40,7 @@
 "attention_backend": ServerArgs.attention_backend,
 "sampling_backend": ServerArgs.sampling_backend,
 "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
-"enable_mla": ServerArgs.enable_mla,
+"disable_mla": ServerArgs.disable_mla,
 "torchao_config": ServerArgs.torchao_config,
 }


python/sglang/srt/model_executor/model_runner.py (6 changes: 3 additions & 3 deletions)
@@ -91,7 +91,7 @@ def __init__(
 "attention_backend": server_args.attention_backend,
 "sampling_backend": server_args.sampling_backend,
 "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
-"enable_mla": server_args.enable_mla,
+"disable_mla": server_args.disable_mla,
 "torchao_config": server_args.torchao_config,
 }
 )
@@ -329,7 +329,7 @@ def profile_max_num_token(self, total_gpu_memory: int):
 )
 if (
 self.model_config.attention_arch == AttentionArch.MLA
-and self.server_args.enable_mla
+and not self.server_args.disable_mla
 ):
 cell_size = (
 (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
@@ -397,7 +397,7 @@ def init_memory_pool(
 )
 if (
 self.model_config.attention_arch == AttentionArch.MLA
-and self.server_args.enable_mla
+and not self.server_args.disable_mla
 ):
 self.token_to_kv_pool = MLATokenToKVPool(
 self.max_total_num_tokens,
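The `profile_max_num_token` and `init_memory_pool` hunks now take the compact MLA KV-cache path unless it is disabled. The branch shown above sizes each cache cell from `kv_lora_rank + qk_rope_head_dim` per layer, independent of the number of heads, which is where the memory saving comes from. A back-of-the-envelope sketch; the DeepSeek-V2-like dimensions used here (kv_lora_rank=512, qk_rope_head_dim=64, 128 heads, 60 layers, fp16) are assumptions for illustration, not values read from this diff:

```python
# Rough KV-cache cell-size comparison mirroring the branch above.
# All model dimensions below are illustrative assumptions (DeepSeek-V2-like).
kv_lora_rank = 512
qk_rope_head_dim = 64
num_layers = 60
dtype_bytes = 2  # fp16/bf16

# MLA: one compressed latent per token per layer, shared across heads.
mla_cell = (kv_lora_rank + qk_rope_head_dim) * num_layers * dtype_bytes

# Plain multi-head attention would cache full K and V for every head.
num_heads = 128
k_head_dim = 128 + 64   # qk_nope_head_dim + qk_rope_head_dim
v_head_dim = 128
mha_cell = num_heads * (k_head_dim + v_head_dim) * num_layers * dtype_bytes

print(f"MLA bytes/token: {mla_cell}")            # 69120
print(f"MHA bytes/token: {mha_cell}")            # 4915200
print(f"reduction: {mha_cell / mla_cell:.0f}x")  # 71x
```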

python/sglang/srt/models/deepseek_v2.py (4 changes: 2 additions & 2 deletions)
@@ -507,7 +507,7 @@ def __init__(
 rope_theta = getattr(config, "rope_theta", 10000)
 rope_scaling = getattr(config, "rope_scaling", None)
 max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
-if global_server_args_dict["enable_mla"]:
+if not global_server_args_dict["disable_mla"]:
 self.self_attn = DeepseekV2AttentionMLA(
 config=config,
 hidden_size=self.hidden_size,
@@ -732,7 +732,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
 )
 weight_loader(param, loaded_weight)

-if global_server_args_dict["enable_mla"]:
+if not global_server_args_dict["disable_mla"]:
 for layer_id in range(self.config.num_hidden_layers):
 self_attn = self.model.layers[layer_id].self_attn
 w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
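The truncated `load_weights` hunk above runs the MLA weight split whenever MLA is not disabled: the fused `kv_b_proj` weight is separated into per-head K-side and V-side factors (`w_kc`, `w_vc`) so attention can operate in the compressed latent space. A standalone sketch of that kind of split; the shapes and the exact reshape are assumptions based on DeepSeek-V2-like dimensions, not code copied from this commit:

```python
import torch

# Illustrative dimensions (DeepSeek-V2-like); assumptions, not config values.
num_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank = 128, 128, 128, 512

# kv_b_proj maps the compressed latent (kv_lora_rank) to per-head K_nope and V.
kv_b_proj_weight = torch.randn(
    num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank
)

# Split the fused output dimension into per-head K-side and V-side factors.
w_kc, w_vc = kv_b_proj_weight.unflatten(
    0, (num_heads, qk_nope_head_dim + v_head_dim)
).split([qk_nope_head_dim, v_head_dim], dim=1)

print(w_kc.shape)  # torch.Size([128, 128, 512])
print(w_vc.shape)  # torch.Size([128, 128, 512])
```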

python/sglang/srt/models/minicpm3.py (4 changes: 2 additions & 2 deletions)
@@ -419,7 +419,7 @@ def __init__(
 rope_theta = getattr(config, "rope_theta", 10000)
 rope_scaling = getattr(config, "rope_scaling", None)
 max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
-if global_server_args_dict["enable_mla"]:
+if not global_server_args_dict["disable_mla"]:
 self.self_attn = MiniCPM3AttentionMLA(
 config=config,
 hidden_size=self.hidden_size,
@@ -653,7 +653,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
 )
 weight_loader(param, loaded_weight)

-if global_server_args_dict["enable_mla"]:
+if not global_server_args_dict["disable_mla"]:
 for layer_id in range(self.config.num_hidden_layers):
 self_attn = self.model.layers[layer_id].self_attn
 w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(

python/sglang/srt/server_args.py (14 changes: 7 additions & 7 deletions)
@@ -108,12 +108,12 @@ class ServerArgs:
 disable_cuda_graph_padding: bool = False
 disable_disk_cache: bool = False
 disable_custom_all_reduce: bool = False
+disable_mla: bool = False
 enable_mixed_chunk: bool = False
 enable_torch_compile: bool = False
 max_torch_compile_bs: int = 32
 torchao_config: str = ""
 enable_p2p_check: bool = False
-enable_mla: bool = False
 triton_attention_reduce_in_fp32: bool = False

 # LoRA
@@ -173,7 +173,7 @@ def __post_init__(self):
 self.sampling_backend = "pytorch"

 # Default kernel backends
-if self.enable_mla:
+if not self.disable_mla:
 logger.info("MLA optimization is tunred on. Use triton backend.")
 self.attention_backend = "triton"

@@ -514,6 +514,11 @@ def add_cli_args(parser: argparse.ArgumentParser):
 default=False,
 help="Disable the custom all-reduce kernel and fall back to NCCL.",
 )
+parser.add_argument(
+"--disable-mla",
+action="store_true",
+help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
+)
 parser.add_argument(
 "--enable-mixed-chunk",
 action="store_true",
@@ -541,11 +546,6 @@ def add_cli_args(parser: argparse.ArgumentParser):
 action="store_true",
 help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
 )
-parser.add_argument(
-"--enable-mla",
-action="store_true",
-help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
-)
 parser.add_argument(
 "--triton-attention-reduce-in-fp32",
 action="store_true",
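The net effect of the `server_args.py` changes is a flipped flag polarity: a plain launch now takes the MLA path, and `--disable-mla` restores the previous behaviour. A self-contained sketch of just that flag handling (a stand-in parser, not sglang's full argument set):

```python
import argparse

# Stand-in parser showing the flipped polarity: the opt-in --enable-mla
# flag is gone; --disable-mla is the new opt-out.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--disable-mla",
    action="store_true",
    help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
)

args = parser.parse_args([])                 # no flag: MLA stays enabled
assert not args.disable_mla

args = parser.parse_args(["--disable-mla"])  # explicit opt-out
assert args.disable_mla
```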

test/srt/test_nightly_gsm8k_eval.py (2 changes: 1 addition & 1 deletion)
@@ -52,7 +52,7 @@ def launch_server(self, model, is_fp8, is_tp2):
 if is_tp2:
 other_args.extend(["--tp", "2"])
 if "DeepSeek" in model:
-other_args.extend(["--enable-mla", "--mem-frac", "0.85"])
+other_args.extend(["--mem-frac", "0.85"])

 self.process = popen_launch_server(
 model,
