From 67e8c9764a8f0282bd63bedbed9fa6a587cdf059 Mon Sep 17 00:00:00 2001
From: lei
Date: Sun, 20 Oct 2024 18:54:34 -0700
Subject: [PATCH] Update llama.py

---
 python/sglang/srt/models/llama.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py
index 543703c230..7da8a7fdfe 100644
--- a/python/sglang/srt/models/llama.py
+++ b/python/sglang/srt/models/llama.py
@@ -44,6 +44,7 @@
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sageattention import sageattn
 
 
 class LlamaMLP(nn.Module):
     def __init__(
@@ -166,7 +167,10 @@ def forward(
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, forward_batch)
+
+        # attn_output = self.attn(q, k, v, forward_batch)
+        attn_output = sageattn(q, k, v, is_causal=False, smooth_k=True)
+
         output, _ = self.o_proj(attn_output)
         return output
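
A minimal sketch of the layout change the sageattn call above would likely need, assuming the SageAttention v1 API (sageattn(q, k, v, is_causal=..., smooth_k=...)), which takes FP16/BF16 tensors shaped (batch_size, head_num, seq_len, head_dim). In sglang's LlamaAttention.forward, q, k, and v arrive flattened to (num_tokens, num_heads * head_dim), so a drop-in call without reshaping would not match that layout. The helper name sage_attn_flat and the num_heads/head_dim parameters are hypothetical, introduced only for illustration; the sketch also ignores GQA checkpoints, where k and v carry num_kv_heads rather than num_heads.

from sageattention import sageattn

def sage_attn_flat(q, k, v, num_heads, head_dim):
    # sglang passes q/k/v as (num_tokens, num_heads * head_dim);
    # sageattn expects (batch_size, head_num, seq_len, head_dim).
    q = q.view(1, -1, num_heads, head_dim).transpose(1, 2)
    k = k.view(1, -1, num_heads, head_dim).transpose(1, 2)
    v = v.view(1, -1, num_heads, head_dim).transpose(1, 2)
    # Decoder-only prefill is causal: the patch passes is_causal=False,
    # but is_causal=True matches what the replaced self.attn computes
    # over a full prompt.
    o = sageattn(q, k, v, is_causal=True, smooth_k=True)
    # Restore the flattened layout that self.o_proj expects.
    return o.transpose(1, 2).reshape(-1, num_heads * head_dim)

Note also that the replaced self.attn(q, k, v, forward_batch) reads and writes the KV cache through forward_batch, while sageattn only attends over the tensors it is handed, so decode steps after prefill would lose access to cached history under this patch as written.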