Simplify flashinfer dispatch (#1552)
hnyls2002 authored Oct 1, 2024
1 parent 619bb6d commit 100f5b8
Showing 5 changed files with 97 additions and 76 deletions.
26 changes: 21 additions & 5 deletions python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -14,7 +14,10 @@

from sglang.global_config import global_config
from sglang.srt.layers.attention import AttentionBackend
-from sglang.srt.layers.attention.flashinfer_utils import update_flashinfer_indices
+from sglang.srt.layers.attention.flashinfer_utils import (
+    WrapperDispatch,
+    update_flashinfer_indices,
+)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import is_hip

@@ -53,10 +56,19 @@ def __init__(self, model_runner: ModelRunner):
device="cuda",
)

+        assert not (
+            model_runner.sliding_window_size is not None
+            and model_runner.has_cross_attention
+        ), "Sliding window and cross attention are not supported together"
+
+        self.num_wrappers = 1
+        self.dispatch_reason = None
        if model_runner.sliding_window_size is not None:
            self.num_wrappers = 2
-        else:
-            self.num_wrappers = 1
+            self.dispatch_reason = WrapperDispatch.SLIDING_WINDOW
+        elif model_runner.has_cross_attention:
+            self.num_wrappers = 2
+            self.dispatch_reason = WrapperDispatch.CROSS_ATTENTION

# NOTE: we do not use ragged attention when there are multiple wrappers
self.prefill_wrapper_ragged = (
@@ -88,8 +100,12 @@ def _get_wrapper_idx(self, layer: nn.Module):
if self.num_wrappers == 1:
return 0

-        # TODO: make sure the idx is related to sliding window size
-        return layer.sliding_window_size == -1
+        if self.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
+            return layer.sliding_window_size == -1
+        if self.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
+            return layer.is_cross_attention
+
+        raise ValueError(f"Unknown dispatch reason: {self.dispatch_reason}")

def init_forward_metadata(self, forward_batch: ForwardBatch):
if forward_batch.forward_mode.is_decode():
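
To make the wrapper selection above concrete, here is a minimal, self-contained sketch (not sglang code) of how a two-wrapper backend can map each attention layer to a wrapper index from the dispatch reason, mirroring _get_wrapper_idx. DummyLayer and pick_wrapper_idx are illustrative names introduced only for this example.

# Minimal sketch, assuming the dispatch logic shown in the diff above.
from dataclasses import dataclass
from enum import Enum, auto


class WrapperDispatch(Enum):
    SLIDING_WINDOW = auto()
    CROSS_ATTENTION = auto()


@dataclass
class DummyLayer:
    sliding_window_size: int = -1   # -1 means "no sliding window" in this sketch
    is_cross_attention: bool = False


def pick_wrapper_idx(dispatch_reason, layer, num_wrappers=2):
    if num_wrappers == 1:
        return 0
    if dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
        # bool is a subclass of int: full-attention layers (-1) go to wrapper 1,
        # sliding-window layers to wrapper 0 -- the same trick as the code above.
        return int(layer.sliding_window_size == -1)
    if dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
        return int(layer.is_cross_attention)
    raise ValueError(f"Unknown dispatch reason: {dispatch_reason}")


print(pick_wrapper_idx(WrapperDispatch.SLIDING_WINDOW, DummyLayer(sliding_window_size=4096)))  # 0
print(pick_wrapper_idx(WrapperDispatch.SLIDING_WINDOW, DummyLayer()))                          # 1
print(pick_wrapper_idx(WrapperDispatch.CROSS_ATTENTION, DummyLayer(is_cross_attention=True)))  # 1
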
136 changes: 67 additions & 69 deletions python/sglang/srt/layers/attention/flashinfer_utils.py
@@ -1,8 +1,15 @@
from enum import Enum, auto

import torch
import triton
import triton.language as tl


class WrapperDispatch(Enum):
SLIDING_WINDOW = auto()
CROSS_ATTENTION = auto()


@triton.jit
def create_flashinfer_kv_indices_triton(
req_to_token_ptr, # [max_batch, max_context_len]
@@ -80,67 +87,6 @@ def __init__(
(self.batch_size,), dtype=torch.int32, device="cuda"
)

def _init_indices_no_sliding_window(self):
if self.use_ragged:
paged_kernel_lens = self.prefix_lens
else:
paged_kernel_lens = self.seq_lens

self.kv_indptr = torch.zeros(
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
)
self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
self.kv_indices = torch.empty(
self.kv_indptr[-1], dtype=torch.int32, device="cuda"
)

create_flashinfer_kv_indices_triton[(self.batch_size,)](
self.model_runner.req_to_token_pool.req_to_token,
self.req_pool_indices,
paged_kernel_lens,
self.kv_indptr,
None,
self.kv_indices,
self.model_runner.req_to_token_pool.req_to_token.size(1),
)

def _init_indices_sliding_window(self, wrapper_id):
if wrapper_id == 0:
# window attention use paged only
if self.forward_mode.is_decode():
paged_kernel_lens = torch.minimum(
self.seq_lens,
torch.tensor(self.model_runner.sliding_window_size + 1),
)
else:
paged_kernel_lens = torch.minimum(
self.seq_lens,
torch.tensor(self.model_runner.sliding_window_size)
+ self.seq_lens
- self.prefix_lens,
)
else:
# full attention
paged_kernel_lens = self.seq_lens

kv_start_idx = self.seq_lens - paged_kernel_lens
self.kv_indptr = torch.zeros(
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
)
self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
self.kv_indices = torch.empty(
self.kv_indptr[-1], dtype=torch.int32, device="cuda"
)
create_flashinfer_kv_indices_triton[(self.batch_size,)](
self.model_runner.req_to_token_pool.req_to_token,
self.req_pool_indices,
paged_kernel_lens,
self.kv_indptr,
kv_start_idx,
self.kv_indices,
self.model_runner.req_to_token_pool.req_to_token.size(1),
)

def _update_decode_indices(self, decode_wrapper):
assert not isinstance(decode_wrapper, list)
decode_wrapper.end_forward()
@@ -189,8 +135,53 @@ def _update_extend_indices(self, ragged_wrapper, paged_wrapper):
1,
)

-    def update_indices_no_sliding_window(self):
-        self._init_indices_no_sliding_window()
+    def _get_indices(self, dispatch_reason: WrapperDispatch = None, wrapper_id=0):
if dispatch_reason is None:
if self.use_ragged:
paged_kernel_lens = self.prefix_lens
else:
paged_kernel_lens = self.seq_lens
self.kv_start_idx = None
elif dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
if wrapper_id == 0:
# window attention use paged only
if self.forward_mode.is_decode():
paged_kernel_lens = torch.minimum(
self.seq_lens,
torch.tensor(self.model_runner.sliding_window_size + 1),
)
else:
paged_kernel_lens = torch.minimum(
self.seq_lens,
torch.tensor(self.model_runner.sliding_window_size)
+ self.seq_lens
- self.prefix_lens,
)
else:
# full attention
paged_kernel_lens = self.seq_lens
self.kv_start_idx = self.seq_lens - paged_kernel_lens

self.kv_indptr = torch.zeros(
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
)
self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
self.kv_indices = torch.empty(
self.kv_indptr[-1], dtype=torch.int32, device="cuda"
)

create_flashinfer_kv_indices_triton[(self.batch_size,)](
self.model_runner.req_to_token_pool.req_to_token,
self.req_pool_indices,
paged_kernel_lens,
self.kv_indptr,
self.kv_start_idx,
self.kv_indices,
self.model_runner.req_to_token_pool.req_to_token.size(1),
)

def _update_indicess_single_wrapper(self):
self._get_indices()

if self.forward_mode.is_decode():
self._update_decode_indices(self.decode_wrappers[0])
@@ -200,11 +191,13 @@ def update_indices_no_sliding_window(self):
self.prefill_wrappers_paged[0],
)

-    def update_indices_sliding_window(self):
-        assert self.use_ragged is False
+    def _update_indices_cross_attention(self):
+        pass
+
+    def _update_indices_sliding_window(self):
+        assert self.use_ragged is False
        for wrapper_id in range(2):
-            self._init_indices_sliding_window(wrapper_id)
+            self._get_indices(WrapperDispatch.SLIDING_WINDOW, wrapper_id)
if self.forward_mode.is_decode():
self._update_decode_indices(self.decode_wrappers[wrapper_id])
else:
@@ -233,7 +226,12 @@ def update_flashinfer_indices(
use_ragged,
)

-    if model_runner.sliding_window_size is None:
-        updater.update_indices_no_sliding_window()
+    dispatch_reason = model_runner.attn_backend.dispatch_reason
+
+    if dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
+        updater._update_indices_sliding_window()
+    elif dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
+        updater._update_indices_cross_attention()
    else:
-        updater.update_indices_sliding_window()
+        assert model_runner.attn_backend.num_wrappers == 1
+        updater._update_indicess_single_wrapper()
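
As a rough illustration of what the new _get_indices path computes, the following CPU-only PyTorch sketch reproduces the index layout outside of sglang: kv_indptr is the cumulative sum of the per-request KV lengths, and kv_indices gathers each request's token slots from the req_to_token table, optionally skipping kv_start_idx slots (the sliding-window case). The toy tensors and sizes are made up for this example; the real code launches create_flashinfer_kv_indices_triton on the GPU instead of this Python loop.

import torch

# Toy stand-ins for the pools used in the diff (shapes and values are invented).
req_to_token = torch.arange(32, dtype=torch.int32).reshape(4, 8)  # [max_batch, max_context_len]
req_pool_indices = torch.tensor([2, 0])                           # rows used by this batch
paged_kernel_lens = torch.tensor([5, 3], dtype=torch.int32)       # per-request KV lengths
kv_start_idx = torch.zeros_like(paged_kernel_lens)                # nonzero only for sliding window

# kv_indptr: prefix sums of the per-request lengths.
kv_indptr = torch.zeros(len(paged_kernel_lens) + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)

# kv_indices: the token slots of every request, laid out back to back.
kv_indices = torch.empty(int(kv_indptr[-1]), dtype=torch.int32)
for i in range(len(paged_kernel_lens)):
    row = int(req_pool_indices[i])
    start = int(kv_start_idx[i])
    length = int(paged_kernel_lens[i])
    lo, hi = int(kv_indptr[i]), int(kv_indptr[i + 1])
    kv_indices[lo:hi] = req_to_token[row, start : start + length]

print(kv_indptr)   # tensor([0, 5, 8], dtype=torch.int32)
print(kv_indices)  # slots 16..20 of request row 2, then slots 0..2 of row 0
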
4 changes: 3 additions & 1 deletion python/sglang/srt/layers/radix_attention.py
@@ -32,9 +32,10 @@ def __init__(
        scaling: float,
        num_kv_heads: int,
        layer_id: int,
-        sliding_window_size: int = -1,
        logit_cap: float = 0.0,
        v_head_dim: int = -1,
+        sliding_window_size: int = -1,
+        is_cross_attention: bool = False,
):
super().__init__()
self.tp_q_head_num = num_heads
@@ -47,6 +48,7 @@ def __init__(
self.layer_id = layer_id
self.logit_cap = logit_cap
self.sliding_window_size = sliding_window_size or -1
self.is_cross_attention = is_cross_attention

def forward(self, q, k, v, forward_batch: ForwardBatch):
if k is not None:
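
A usage sketch for the reordered constructor (it assumes an sglang installation; the numeric values are placeholders, and the leading num_heads/head_dim parameters are not shown in this hunk and are assumed from the surrounding code). Because sliding_window_size now sits after v_head_dim and is_cross_attention is new, callers are safest passing both by keyword, as the gemma2.py change below does.

from sglang.srt.layers.radix_attention import RadixAttention

# Placeholder values; in a real model these come from the model config.
cross_attn = RadixAttention(
    num_heads=32,
    head_dim=128,
    scaling=128**-0.5,
    num_kv_heads=8,
    layer_id=0,
    logit_cap=0.0,
    v_head_dim=-1,
    sliding_window_size=-1,     # -1 keeps full attention for this layer
    is_cross_attention=True,    # lets the flashinfer backend route it to the cross-attention wrapper
)
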
5 changes: 5 additions & 0 deletions python/sglang/srt/model_executor/model_runner.py
@@ -231,6 +231,7 @@ def load_model(self):
if hasattr(self.model, "get_attention_sliding_window_size")
else None
)
self.has_cross_attention = getattr(self.model, "has_cross_attention", False)
self.is_generation = is_generation_model(
self.model_config.hf_config.architectures, self.server_args.is_embedding
)
@@ -453,6 +454,10 @@ def init_attention_backend(self):
"Window attention is not supported in the triton attention backend. "
"Please use `--attention-backend flashinfer`."
)
assert not self.has_cross_attention, (
"Cross attention is not supported in the triton attention backend. "
"Please use `--attention-backend flashinfer`."
)
self.attn_backend = TritonAttnBackend(self)
else:
raise ValueError(
2 changes: 1 addition & 1 deletion python/sglang/srt/models/gemma2.py
@@ -163,12 +163,12 @@ def __init__(
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_idx,
-            logit_cap=self.config.attn_logit_softcapping,
            sliding_window_size=(
                get_attention_sliding_window_size(config)
                if use_sliding_window
                else None
            ),
+            logit_cap=self.config.attn_logit_softcapping,
)

def forward(
