Support Triton fp8 e5m2 kv cache (#1286)
Co-authored-by: Yineng Zhang <me@zhyncs.com>
ispobock and zhyncs authored Sep 1, 2024
1 parent 761b2ce commit 6cb32ef
Showing 2 changed files with 13 additions and 11 deletions.
python/sglang/srt/layers/extend_attention.py: 12 additions & 4 deletions
@@ -128,7 +128,7 @@ def _fwd_kernel(
         k = tl.load(K_Buffer + offs_buf_k, mask=mask_n[None, :], other=0.0)
 
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
-        qk += tl.dot(q, k)
+        qk += tl.dot(q.to(k.dtype), k)
         if BLOCK_DPE > 0:
             offs_kpe = (
                 offs_kv_loc[None, :] * stride_buf_kbs
@@ -140,7 +140,7 @@ def _fwd_kernel(
                 mask=mask_n[None, :],
                 other=0.0,
             )
-            qk += tl.dot(qpe, kpe)
+            qk += tl.dot(qpe.to(kpe.dtype), kpe)
         qk *= sm_scale
 
         if logit_cap > 0:
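Note: the two casts above are what let the Triton extend kernel consume an fp8 KV cache. tl.dot requires its operands to share a dtype, and with kv_cache_dtype="fp8_e5m2" the K buffer is float8_e5m2 while q arrives in the model dtype. A minimal torch sketch of the same dtype-matching idea (hypothetical illustration, not the kernel itself; plain torch cannot matmul fp8 operands, so it upcasts both sides instead of downcasting q the way tl.dot does):

import torch

# Illustrative only: K cached in fp8 e5m2, q kept in the model dtype (fp16).
q = torch.randn(16, 64, dtype=torch.float16)
k_cache = torch.randn(64, 32, dtype=torch.float16).to(torch.float8_e5m2)

# q @ k_cache raises a dtype-mismatch error; align the dtypes first, as the
# kernel now does with q.to(k.dtype). torch has no direct fp8 matmul, so
# this sketch upcasts to float32 instead.
qk = q.float() @ k_cache.float()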
@@ -276,9 +276,17 @@ def extend_attention_fwd(
     BLOCK_DV = Lv
 
     if CUDA_CAPABILITY[0] >= 9:
-        BLOCK_M, BLOCK_N = (128, 64)
+        if Lq <= 256:
+            BLOCK_M, BLOCK_N = (128, 64)
+        else:
+            BLOCK_M, BLOCK_N = (32, 64)
     elif CUDA_CAPABILITY[0] >= 8:
-        BLOCK_M, BLOCK_N = (128, 128) if Lq <= 128 else (64, 64)
+        if Lq <= 128:
+            BLOCK_M, BLOCK_N = (128, 128)
+        elif Lq <= 256:
+            BLOCK_M, BLOCK_N = (64, 64)
+        else:
+            BLOCK_M, BLOCK_N = (32, 64)
     else:
         BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
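Note: the reworked heuristic adds an Lq > 256 branch on sm90 and sm80 so that extra-wide heads (e.g. MLA-style layouts) fall back to smaller tiles instead of overrunning shared memory and registers. A standalone restatement of the table (hypothetical helper name; assumes only the CUDA major version and Lq drive the choice):

def choose_extend_blocks(cuda_major: int, Lq: int) -> tuple:
    # Mirrors the BLOCK_M/BLOCK_N selection above: smaller tiles for
    # wider heads keep per-CTA shared-memory and register use bounded.
    if cuda_major >= 9:  # Hopper
        return (128, 64) if Lq <= 256 else (32, 64)
    if cuda_major >= 8:  # Ampere
        if Lq <= 128:
            return (128, 128)
        return (64, 64) if Lq <= 256 else (32, 64)
    return (64, 64) if Lq <= 128 else (32, 32)  # pre-Ampere

assert choose_extend_blocks(9, 576) == (32, 64)
assert choose_extend_blocks(8, 128) == (128, 128)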

python/sglang/srt/model_executor/model_runner.py: 1 addition & 7 deletions
@@ -348,13 +348,7 @@ def init_memory_pool(
         if self.server_args.kv_cache_dtype == "auto":
             self.kv_cache_dtype = self.dtype
         elif self.server_args.kv_cache_dtype == "fp8_e5m2":
-            if self.server_args.disable_flashinfer or self.server_args.enable_mla:
-                logger.warning(
-                    "FP8 KV cache is not supported for Triton kernel now, using auto kv cache dtype"
-                )
-                self.kv_cache_dtype = self.dtype
-            else:
-                self.kv_cache_dtype = torch.float8_e5m2
+            self.kv_cache_dtype = torch.float8_e5m2
         else:
             raise ValueError(
                 f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
