Clean up the usage of flashinfer (#610)
merrymercy authored Jul 12, 2024
1 parent 519e20c commit af4e791
Showing 5 changed files with 46 additions and 75 deletions.
15 changes: 3 additions & 12 deletions python/sglang/srt/layers/radix_attention.py
@@ -31,21 +31,13 @@ def __init__(
self.layer_id = layer_id

if not global_server_args_dict.get("disable_flashinfer", False):
self.prefill_forward = self.prefill_forward_flashinfer
self.extend_forward = self.prefill_forward_flashinfer
self.extend_forward = self.extend_forward_flashinfer
self.decode_forward = self.decode_forward_flashinfer
# flashinfer now accepts float logit_cap argument
self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
else:
self.prefill_forward = self.prefill_forward_triton
self.extend_forward = self.extend_forward_triton
self.decode_forward = self.decode_forward_triton
self.logit_cap = logit_cap if logit_cap is not None else 0

def prefill_forward_triton(self, q, k, v, input_metadata: InputMetadata):
# In SGLang, we call both the typical "prefill" and "prefill with cache" as "extend".
# See the extend_forward_xxx functions.
raise NotImplementedError()
self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0

def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
o = torch.empty_like(q)
@@ -86,15 +78,14 @@ def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata):
input_metadata.start_loc,
input_metadata.seq_lens,
input_metadata.max_seq_len,
input_metadata.other_kv_index,
input_metadata.total_num_tokens,
sm_scale=self.scaling,
logit_cap=self.logit_cap,
)

return o

def prefill_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
o1, s1 = input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
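For context, a minimal self-contained sketch (not the actual SGLang class) of the dispatch pattern this hunk moves to: only the extend and decode entry points remain, chosen once at construction time, and logit_cap is normalized in a single place.

from typing import Optional

class AttentionDispatchSketch:
    def __init__(self, disable_flashinfer: bool, logit_cap: Optional[float]):
        if not disable_flashinfer:
            self.extend_forward = self._extend_flashinfer
            self.decode_forward = self._decode_flashinfer
        else:
            self.extend_forward = self._extend_triton
            self.decode_forward = self._decode_triton
        # One place to normalize logit_cap: None or a non-positive value means "no cap".
        self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0

    def _extend_flashinfer(self, q, k, v):
        return "flashinfer extend"

    def _decode_flashinfer(self, q, k, v):
        return "flashinfer decode"

    def _extend_triton(self, q, k, v):
        return "triton extend"

    def _decode_triton(self, q, k, v):
        return "triton decode"

layer = AttentionDispatchSketch(disable_flashinfer=False, logit_cap=None)
print(layer.extend_forward(None, None, None))  # flashinfer extend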
8 changes: 1 addition & 7 deletions python/sglang/srt/layers/token_attention.py
@@ -107,7 +106,6 @@ def _fwd_kernel_stage2(
stride_obs,
stride_oh,
stride_req_to_token_b,
other_kv_index, # To fix a NAN issue
kv_group_num: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
@@ -138,7 +137,7 @@ def _fwd_kernel_stage2(
+ cur_batch_req_idx * stride_req_to_token_b
+ (start_n + offs_n),
mask=(start_n + offs_n) < cur_batch_seq_len,
other=other_kv_index,
other=0,
)

qk = tl.load(
@@ -250,7 +249,6 @@ def _token_softmax_reducev_fwd(
b_req_idx,
b_start_loc,
b_seq_len,
other_kv_index,
):
BLOCK = 64
batch, head = b_seq_len.shape[0], logics.shape[0]
@@ -277,7 +275,6 @@ def _token_softmax_reducev_fwd(
o.stride(0),
o.stride(1),
req_to_tokens.stride(0),
other_kv_index,
)
return

@@ -295,7 +292,6 @@ def _token_softmax_reducev_fwd(
o.stride(0),
o.stride(1),
req_to_tokens.stride(0),
other_kv_index,
kv_group_num=kv_group_num,
BLOCK_DMODEL=v_buffer.shape[-1],
BLOCK_N=BLOCK,
@@ -315,7 +311,6 @@ def token_attention_fwd(
b_start_loc,
b_seq_len,
max_len_in_batch,
other_kv_index,
total_num_tokens,
sm_scale=None,
logit_cap=-1,
@@ -347,5 +342,4 @@ def token_attention_fwd(
b_req_idx,
b_start_loc,
b_seq_len,
other_kv_index,
)
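For context, a toy PyTorch analogue (not SGLang code) of what the Triton masked load's other= argument does in the hunk above: out-of-bounds lanes previously read the slot given by other_kv_index (which the removed comment says was added to avoid a NaN from reading bad memory); after this change they simply read slot 0, presumably an always-allocated slot in the KV pool.

import torch

req_to_token = torch.tensor([5, 9, 2, 7])  # KV slot per position for one request
seq_len = 3                                # only the first 3 positions are valid
offs_n = torch.arange(4)
mask = offs_n < seq_len

# Rough equivalent of: tl.load(req_to_tokens_ptr + offs_n, mask=mask, other=0)
kv_loc = torch.where(mask, req_to_token, torch.zeros_like(req_to_token))
print(kv_loc)  # tensor([5, 9, 2, 0]) -- the out-of-range lane falls back to slot 0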
84 changes: 36 additions & 48 deletions python/sglang/srt/managers/controller/infer_batch.py
@@ -729,7 +729,6 @@ class InputMetadata:
out_cache_cont_start: torch.Tensor = None
out_cache_cont_end: torch.Tensor = None

other_kv_index: torch.Tensor = None
return_logprob: bool = False
top_logprobs_nums: List[int] = None

@@ -743,24 +742,19 @@ class InputMetadata:
flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None

def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
if (
self.forward_mode == ForwardMode.EXTEND
):
if self.forward_mode == ForwardMode.DECODE:
paged_kernel_lens = self.seq_lens
else:
paged_kernel_lens = self.prefix_lens
self.no_prefix = torch.all(self.prefix_lens == 0)
else:
paged_kernel_lens = self.seq_lens

self.kv_indptr = torch.zeros(
kv_indptr = torch.zeros(
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
)
self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
self.kv_last_page_len = torch.ones(
(self.batch_size,), dtype=torch.int32, device="cuda"
)
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
self.kv_indices = torch.cat(
kv_indices = torch.cat(
[
self.req_to_token_pool.req_to_token[
req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
@@ -769,18 +763,34 @@ def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
],
dim=0,
).contiguous()
kv_last_page_len = torch.ones(
(self.batch_size,), dtype=torch.int32, device="cuda"
)

if self.forward_mode == ForwardMode.EXTEND:
if self.forward_mode == ForwardMode.DECODE:
self.flashinfer_decode_wrapper.end_forward()
self.flashinfer_decode_wrapper.begin_forward(
kv_indptr,
kv_indices,
kv_last_page_len,
num_qo_heads,
num_kv_heads,
head_dim,
1,
pos_encoding_mode="NONE",
data_type=self.token_to_kv_pool.kv_data[0].dtype,
)
else:
# extend part
self.qo_indptr = torch.zeros(
qo_indptr = torch.zeros(
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
)
self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)

self.flashinfer_prefill_wrapper_ragged.end_forward()
self.flashinfer_prefill_wrapper_ragged.begin_forward(
self.qo_indptr,
self.qo_indptr.clone(),
qo_indptr,
qo_indptr,
num_qo_heads,
num_kv_heads,
head_dim,
Expand All @@ -789,28 +799,15 @@ def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
# cached part
self.flashinfer_prefill_wrapper_paged.end_forward()
self.flashinfer_prefill_wrapper_paged.begin_forward(
self.qo_indptr,
self.kv_indptr,
self.kv_indices,
self.kv_last_page_len,
qo_indptr,
kv_indptr,
kv_indices,
kv_last_page_len,
num_qo_heads,
num_kv_heads,
head_dim,
1,
)
else:
self.flashinfer_decode_wrapper.end_forward()
self.flashinfer_decode_wrapper.begin_forward(
self.kv_indptr,
self.kv_indices,
self.kv_last_page_len,
num_qo_heads,
num_kv_heads,
head_dim,
1,
pos_encoding_mode="NONE",
data_type=self.token_to_kv_pool.kv_data[0].dtype,
)

def init_extend_args(self):
self.extend_seq_lens = self.seq_lens - self.prefix_lens
@@ -822,7 +819,6 @@ def init_extend_args(self):
def create(
cls,
model_runner,
tp_size,
forward_mode,
req_pool_indices,
seq_lens,
@@ -833,9 +829,6 @@ def create(
out_cache_cont_end=None,
top_logprobs_nums=None,
return_logprob=False,
flashinfer_prefill_wrapper_ragged=None,
flashinfer_prefill_wrapper_paged=None,
flashinfer_decode_wrapper=None,
):
batch_size = len(req_pool_indices)
start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
@@ -845,9 +838,6 @@ def create(

if forward_mode == ForwardMode.DECODE:
positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
other_kv_index = model_runner.req_to_token_pool.req_to_token[
req_pool_indices[0], seq_lens[0] - 1
].item()
else:
seq_lens_cpu = seq_lens.cpu().numpy()
prefix_lens_cpu = prefix_lens.cpu().numpy()
@@ -865,7 +855,6 @@ def create(
),
device="cuda",
)
other_kv_index = None

ret = cls(
forward_mode=forward_mode,
@@ -882,21 +871,20 @@ def create(
out_cache_loc=out_cache_loc,
out_cache_cont_start=out_cache_cont_start,
out_cache_cont_end=out_cache_cont_end,
other_kv_index=other_kv_index,
return_logprob=return_logprob,
top_logprobs_nums=top_logprobs_nums,
flashinfer_prefill_wrapper_ragged=flashinfer_prefill_wrapper_ragged,
flashinfer_prefill_wrapper_paged=flashinfer_prefill_wrapper_paged,
flashinfer_decode_wrapper=flashinfer_decode_wrapper,
flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
)

if forward_mode == ForwardMode.EXTEND:
ret.init_extend_args()

if not global_server_args_dict.get("disable_flashinfer", False):
ret.init_flashinfer_args(
model_runner.model_config.num_attention_heads // tp_size,
model_runner.model_config.get_num_kv_heads(tp_size),
model_runner.model_config.num_attention_heads // model_runner.tp_size,
model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
model_runner.model_config.head_dim,
)

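For context, a standalone sketch (toy shapes, CPU tensors instead of the real CUDA pools) of how the refactored init_flashinfer_args builds its now-local kv_indptr, kv_indices, and kv_last_page_len before handing them to the begin_forward calls: with page size 1, kv_indptr is a prefix sum of per-request KV lengths with a leading zero, and kv_indices is the flat concatenation of each request's token slots.

import torch

req_to_token = torch.arange(32).reshape(4, 8)   # toy req_to_token pool
req_pool_indices = torch.tensor([0, 2, 3])      # requests in this batch
paged_kernel_lens = torch.tensor([3, 5, 2])     # KV length per request
batch_size = len(req_pool_indices)

kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)

kv_indices = torch.cat(
    [
        req_to_token[r, :l]
        for r, l in zip(req_pool_indices.tolist(), paged_kernel_lens.tolist())
    ],
    dim=0,
).contiguous()

kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32)  # page size 1

print(kv_indptr)   # tensor([ 0,  3,  8, 10], dtype=torch.int32)
print(kv_indices)  # tensor([ 0,  1,  2, 16, 17, 18, 19, 20, 24, 25])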
8 changes: 0 additions & 8 deletions python/sglang/srt/managers/controller/model_runner.py
@@ -221,17 +221,13 @@ def forward_extend(self, batch: Batch):
input_metadata = InputMetadata.create(
self,
forward_mode=ForwardMode.EXTEND,
tp_size=self.tp_size,
req_pool_indices=batch.req_pool_indices,
seq_lens=batch.seq_lens,
prefix_lens=batch.prefix_lens,
position_ids_offsets=batch.position_ids_offsets,
out_cache_loc=batch.out_cache_loc,
top_logprobs_nums=batch.top_logprobs_nums,
return_logprob=batch.return_logprob,
flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
)
return self.model.forward(
batch.input_ids, input_metadata.positions, input_metadata
@@ -242,7 +238,6 @@ def forward_decode(self, batch: Batch):
input_metadata = InputMetadata.create(
self,
forward_mode=ForwardMode.DECODE,
tp_size=self.tp_size,
req_pool_indices=batch.req_pool_indices,
seq_lens=batch.seq_lens,
prefix_lens=batch.prefix_lens,
@@ -252,9 +247,6 @@ def forward_decode(self, batch: Batch):
out_cache_cont_end=batch.out_cache_cont_end,
top_logprobs_nums=batch.top_logprobs_nums,
return_logprob=batch.return_logprob,
flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
)
return self.model.forward(
batch.input_ids, input_metadata.positions, input_metadata
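For context, a minimal sketch (hypothetical class names, not the real ModelRunner or InputMetadata) of the call-site simplification in this file: the runner passes itself to InputMetadata.create, which now reads tp_size and the flashinfer wrappers from the runner's attributes instead of receiving them as keyword arguments at every call site.

from dataclasses import dataclass

@dataclass
class RunnerSketch:                  # stand-in for the model runner
    tp_size: int
    flashinfer_decode_wrapper: object = None

@dataclass
class MetadataSketch:                # stand-in for the input metadata
    tp_size: int
    flashinfer_decode_wrapper: object

    @classmethod
    def create(cls, model_runner):
        # Pulled from the runner instead of being threaded through each caller.
        return cls(
            tp_size=model_runner.tp_size,
            flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
        )

meta = MetadataSketch.create(RunnerSketch(tp_size=2))
print(meta.tp_size)  # 2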
6 changes: 6 additions & 0 deletions python/sglang/srt/server_args.py
@@ -53,6 +53,7 @@ class ServerArgs:
disable_flashinfer: bool = False
disable_radix_cache: bool = False
disable_regex_jump_forward: bool = False
disable_cuda_graph: bool = False
disable_disk_cache: bool = False
attention_reduce_in_fp32: bool = False
enable_p2p_check: bool = False
@@ -294,6 +295,11 @@ def add_cli_args(parser: argparse.ArgumentParser):
action="store_true",
help="Disable regex jump-forward",
)
parser.add_argument(
"--disable-cuda-graph",
action="store_true",
help="Disable cuda graph.",
)
parser.add_argument(
"--disable-disk-cache",
action="store_true",
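For context, a self-contained sketch of how the new flag is wired from argparse into the dataclass field added above. The launch command in the comment is an assumed invocation for illustration, not taken from this diff.

# Assumed usage (check the SGLang docs for the exact command):
#   python -m sglang.launch_server --model-path <model> --disable-cuda-graph
import argparse
from dataclasses import dataclass

@dataclass
class ServerArgsSketch:              # stand-in for ServerArgs
    disable_cuda_graph: bool = False

parser = argparse.ArgumentParser()
parser.add_argument(
    "--disable-cuda-graph",
    action="store_true",
    help="Disable cuda graph.",
)
args = parser.parse_args(["--disable-cuda-graph"])
print(ServerArgsSketch(disable_cuda_graph=args.disable_cuda_graph))
# ServerArgsSketch(disable_cuda_graph=True)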
