Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Triton fp8 e5m2 kv cache #1286

Merged
merged 4 commits into from
Sep 1, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions python/sglang/srt/layers/extend_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def _fwd_kernel(
k = tl.load(K_Buffer + offs_buf_k, mask=mask_n[None, :], other=0.0)

qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk += tl.dot(q.to(k.dtype), k)
if BLOCK_DPE > 0:
offs_kpe = (
offs_kv_loc[None, :] * stride_buf_kbs
Expand All @@ -140,7 +140,7 @@ def _fwd_kernel(
mask=mask_n[None, :],
other=0.0,
)
qk += tl.dot(qpe, kpe)
qk += tl.dot(qpe.to(kpe.dtype), kpe)
qk *= sm_scale

if logit_cap > 0:
Expand Down Expand Up @@ -278,7 +278,12 @@ def extend_attention_fwd(
if CUDA_CAPABILITY[0] >= 9:
BLOCK_M, BLOCK_N = (128, 64)
elif CUDA_CAPABILITY[0] >= 8:
BLOCK_M, BLOCK_N = (128, 128) if Lq <= 128 else (64, 64)
if Lq <= 128:
BLOCK_M, BLOCK_N = (128, 128)
elif Lq <= 256:
BLOCK_M, BLOCK_N = (64, 64)
else:
BLOCK_M, BLOCK_N = (32, 64)
else:
BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)

Expand Down
8 changes: 1 addition & 7 deletions python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,13 +348,7 @@ def init_memory_pool(
if self.server_args.kv_cache_dtype == "auto":
self.kv_cache_dtype = self.dtype
elif self.server_args.kv_cache_dtype == "fp8_e5m2":
if self.server_args.disable_flashinfer or self.server_args.enable_mla:
logger.warning(
"FP8 KV cache is not supported for Triton kernel now, using auto kv cache dtype"
)
self.kv_cache_dtype = self.dtype
else:
self.kv_cache_dtype = torch.float8_e5m2
self.kv_cache_dtype = torch.float8_e5m2
else:
raise ValueError(
f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
Expand Down
Loading