diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py
index 7e9fd0f726..fbaa63a8f2 100644
--- a/python/sglang/srt/models/llama.py
+++ b/python/sglang/srt/models/llama.py
@@ -45,6 +45,7 @@
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import make_layers
+from sageattention import sageattn
 
 
 class LlamaMLP(nn.Module):
     def __init__(
@@ -167,7 +168,10 @@ def forward(
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, forward_batch)
+
+        # attn_output = self.attn(q, k, v, forward_batch)
+        attn_output = sageattn(q, k, v, is_causal=False, smooth_k=True)
+
         output, _ = self.o_proj(attn_output)
         return output
 
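A note on the sageattn call introduced above: SageAttention's kernel works on 4-D (batch, heads, seq_len, head_dim) tensors, while q, k, and v at this point in LlamaAttention.forward are flattened per-token tensors of shape (num_tokens, num_heads * head_dim). The snippet below is a minimal sketch of the reshaping such a swap would likely require, assuming batch size 1, a prefill-only path, and no grouped-query attention; the sage_attention helper and its parameters are illustrative and not part of the patch.

# Illustrative sketch only (not from the patch): adapt SGLang's flattened
# per-token q/k/v to the 4-D layout sageattn expects, then flatten back
# for o_proj. Assumes batch size 1 (single-sequence prefill) and
# num_kv_heads == num_heads.
from sageattention import sageattn

def sage_attention(q, k, v, num_heads, head_dim):
    # (num_tokens, num_heads * head_dim) -> (1, num_heads, num_tokens, head_dim)
    q = q.view(1, -1, num_heads, head_dim).transpose(1, 2).contiguous()
    k = k.view(1, -1, num_heads, head_dim).transpose(1, 2).contiguous()
    v = v.view(1, -1, num_heads, head_dim).transpose(1, 2).contiguous()
    # Same flags as the patch; smooth_k enables SageAttention's key smoothing.
    out = sageattn(q, k, v, is_causal=False, smooth_k=True)
    # Back to (num_tokens, num_heads * head_dim) for the output projection.
    return out.transpose(1, 2).reshape(-1, num_heads * head_dim)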