From 54bc87e408f1aa27c56e75fc12965b7513ceaf56 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Fri, 12 Jul 2024 12:56:33 -0700
Subject: [PATCH] update the usage of flashinfer

---
 python/sglang/srt/layers/radix_attention.py  | 15 +---
 python/sglang/srt/layers/token_attention.py  |  8 +-
 .../srt/managers/controller/infer_batch.py   | 84 ++++++++-----------
 .../srt/managers/controller/model_runner.py  |  8 --
 python/sglang/srt/server_args.py             |  6 ++
 5 files changed, 46 insertions(+), 75 deletions(-)

diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py
index eab16d536e..2c3e91af10 100644
--- a/python/sglang/srt/layers/radix_attention.py
+++ b/python/sglang/srt/layers/radix_attention.py
@@ -31,21 +31,13 @@ def __init__(
         self.layer_id = layer_id
 
         if not global_server_args_dict.get("disable_flashinfer", False):
-            self.prefill_forward = self.prefill_forward_flashinfer
-            self.extend_forward = self.prefill_forward_flashinfer
+            self.extend_forward = self.extend_forward_flashinfer
             self.decode_forward = self.decode_forward_flashinfer
-            # flashinfer now accepts float logit_cap argument
-            self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
         else:
-            self.prefill_forward = self.prefill_forward_triton
             self.extend_forward = self.extend_forward_triton
             self.decode_forward = self.decode_forward_triton
-            self.logit_cap = logit_cap if logit_cap is not None else 0
 
-    def prefill_forward_triton(self, q, k, v, input_metadata: InputMetadata):
-        # In SGLang, we call both the typical "prefill" and "prefill with cache" as "extend".
-        # See the extend_forward_xxx functions.
-        raise NotImplementedError()
+        self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
 
     def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
         o = torch.empty_like(q)
@@ -86,7 +78,6 @@ def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata):
             input_metadata.start_loc,
             input_metadata.seq_lens,
             input_metadata.max_seq_len,
-            input_metadata.other_kv_index,
             input_metadata.total_num_tokens,
             sm_scale=self.scaling,
             logit_cap=self.logit_cap,
@@ -94,7 +85,7 @@ def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata):
 
         return o
 
-    def prefill_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
+    def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
         o1, s1 = input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
diff --git a/python/sglang/srt/layers/token_attention.py b/python/sglang/srt/layers/token_attention.py
index f9d58ae276..9d7bda145a 100644
--- a/python/sglang/srt/layers/token_attention.py
+++ b/python/sglang/srt/layers/token_attention.py
@@ -107,7 +107,6 @@ def _fwd_kernel_stage2(
     stride_obs,
     stride_oh,
     stride_req_to_token_b,
-    other_kv_index,  # To fix a NAN issue
     kv_group_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
@@ -138,7 +137,7 @@ def _fwd_kernel_stage2(
             + cur_batch_req_idx * stride_req_to_token_b
             + (start_n + offs_n),
             mask=(start_n + offs_n) < cur_batch_seq_len,
-            other=other_kv_index,
+            other=0,
         )
 
         qk = tl.load(
@@ -250,7 +249,6 @@ def _token_softmax_reducev_fwd(
     b_req_idx,
     b_start_loc,
     b_seq_len,
-    other_kv_index,
 ):
     BLOCK = 64
     batch, head = b_seq_len.shape[0], logics.shape[0]
@@ -277,7 +275,6 @@ def _token_softmax_reducev_fwd(
             o.stride(0),
             o.stride(1),
             req_to_tokens.stride(0),
-            other_kv_index,
         )
         return
 
@@ -295,7 +292,6 @@ def _token_softmax_reducev_fwd(
         o.stride(0),
         o.stride(1),
         req_to_tokens.stride(0),
-        other_kv_index,
         kv_group_num=kv_group_num,
         BLOCK_DMODEL=v_buffer.shape[-1],
         BLOCK_N=BLOCK,
@@ -315,7 +311,6 @@ def token_attention_fwd(
     b_start_loc,
     b_seq_len,
     max_len_in_batch,
-    other_kv_index,
     total_num_tokens,
     sm_scale=None,
     logit_cap=-1,
@@ -347,5 +342,4 @@ def token_attention_fwd(
         b_req_idx,
         b_start_loc,
         b_seq_len,
-        other_kv_index,
     )
diff --git a/python/sglang/srt/managers/controller/infer_batch.py b/python/sglang/srt/managers/controller/infer_batch.py
index 793262b6f8..27d041d1dd 100644
--- a/python/sglang/srt/managers/controller/infer_batch.py
+++ b/python/sglang/srt/managers/controller/infer_batch.py
@@ -729,7 +729,6 @@ class InputMetadata:
 
     out_cache_cont_start: torch.Tensor = None
     out_cache_cont_end: torch.Tensor = None
-    other_kv_index: torch.Tensor = None
 
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None
@@ -743,24 +742,19 @@ class InputMetadata:
     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
 
     def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
-        if (
-            self.forward_mode == ForwardMode.EXTEND
-        ):
+        if self.forward_mode == ForwardMode.DECODE:
+            paged_kernel_lens = self.seq_lens
+        else:
             paged_kernel_lens = self.prefix_lens
             self.no_prefix = torch.all(self.prefix_lens == 0)
-        else:
-            paged_kernel_lens = self.seq_lens
 
-        self.kv_indptr = torch.zeros(
+        kv_indptr = torch.zeros(
             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
         )
-        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-        self.kv_last_page_len = torch.ones(
-            (self.batch_size,), dtype=torch.int32, device="cuda"
-        )
+        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
 
         req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
         paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
-        self.kv_indices = torch.cat(
+        kv_indices = torch.cat(
             [
                 self.req_to_token_pool.req_to_token[
                     req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
@@ -769,18 +763,34 @@ def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
             ],
             dim=0,
         ).contiguous()
+        kv_last_page_len = torch.ones(
+            (self.batch_size,), dtype=torch.int32, device="cuda"
+        )
 
-        if self.forward_mode == ForwardMode.EXTEND:
+        if self.forward_mode == ForwardMode.DECODE:
+            self.flashinfer_decode_wrapper.end_forward()
+            self.flashinfer_decode_wrapper.begin_forward(
+                kv_indptr,
+                kv_indices,
+                kv_last_page_len,
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+                1,
+                pos_encoding_mode="NONE",
+                data_type=self.token_to_kv_pool.kv_data[0].dtype,
+            )
+        else:
             # extend part
-            self.qo_indptr = torch.zeros(
+            qo_indptr = torch.zeros(
                 (self.batch_size + 1,), dtype=torch.int32, device="cuda"
             )
-            self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
+            qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
 
             self.flashinfer_prefill_wrapper_ragged.end_forward()
             self.flashinfer_prefill_wrapper_ragged.begin_forward(
-                self.qo_indptr,
-                self.qo_indptr.clone(),
+                qo_indptr,
+                qo_indptr,
                 num_qo_heads,
                 num_kv_heads,
                 head_dim,
@@ -789,28 +799,15 @@ def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
             # cached part
             self.flashinfer_prefill_wrapper_paged.end_forward()
             self.flashinfer_prefill_wrapper_paged.begin_forward(
-                self.qo_indptr,
-                self.kv_indptr,
-                self.kv_indices,
-                self.kv_last_page_len,
+                qo_indptr,
+                kv_indptr,
+                kv_indices,
+                kv_last_page_len,
                 num_qo_heads,
                 num_kv_heads,
                 head_dim,
                 1,
             )
-        else:
-            self.flashinfer_decode_wrapper.end_forward()
-            self.flashinfer_decode_wrapper.begin_forward(
-                self.kv_indptr,
-                self.kv_indices,
-                self.kv_last_page_len,
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-                1,
-                pos_encoding_mode="NONE",
-                data_type=self.token_to_kv_pool.kv_data[0].dtype,
-            )
 
     def init_extend_args(self):
         self.extend_seq_lens = self.seq_lens - self.prefix_lens
@@ -822,7 +819,6 @@ def init_extend_args(self):
     def create(
         cls,
         model_runner,
-        tp_size,
         forward_mode,
         req_pool_indices,
         seq_lens,
@@ -833,9 +829,6 @@ def create(
         out_cache_cont_end=None,
         top_logprobs_nums=None,
         return_logprob=False,
-        flashinfer_prefill_wrapper_ragged=None,
-        flashinfer_prefill_wrapper_paged=None,
-        flashinfer_decode_wrapper=None,
     ):
         batch_size = len(req_pool_indices)
         start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
@@ -845,9 +838,6 @@ def create(
 
         if forward_mode == ForwardMode.DECODE:
             positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
-            other_kv_index = model_runner.req_to_token_pool.req_to_token[
-                req_pool_indices[0], seq_lens[0] - 1
-            ].item()
         else:
             seq_lens_cpu = seq_lens.cpu().numpy()
             prefix_lens_cpu = prefix_lens.cpu().numpy()
@@ -865,7 +855,6 @@ def create(
                 ),
                 device="cuda",
             )
-            other_kv_index = None
 
         ret = cls(
             forward_mode=forward_mode,
@@ -882,12 +871,11 @@ def create(
             out_cache_loc=out_cache_loc,
             out_cache_cont_start=out_cache_cont_start,
             out_cache_cont_end=out_cache_cont_end,
-            other_kv_index=other_kv_index,
             return_logprob=return_logprob,
             top_logprobs_nums=top_logprobs_nums,
-            flashinfer_prefill_wrapper_ragged=flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=flashinfer_decode_wrapper,
+            flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
         )
 
         if forward_mode == ForwardMode.EXTEND:
             ret.init_extend_args()
 
         if not global_server_args_dict.get("disable_flashinfer", False):
             ret.init_flashinfer_args(
-                model_runner.model_config.num_attention_heads // tp_size,
-                model_runner.model_config.get_num_kv_heads(tp_size),
+                model_runner.model_config.num_attention_heads // model_runner.tp_size,
+                model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
                 model_runner.model_config.head_dim,
             )
diff --git a/python/sglang/srt/managers/controller/model_runner.py b/python/sglang/srt/managers/controller/model_runner.py
index a439756cf1..30a8001e71 100644
--- a/python/sglang/srt/managers/controller/model_runner.py
+++ b/python/sglang/srt/managers/controller/model_runner.py
@@ -221,7 +221,6 @@ def forward_extend(self, batch: Batch):
         input_metadata = InputMetadata.create(
             self,
             forward_mode=ForwardMode.EXTEND,
-            tp_size=self.tp_size,
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
             prefix_lens=batch.prefix_lens,
@@ -229,9 +228,6 @@ def forward_extend(self, batch: Batch):
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
-            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -242,7 +238,6 @@ def forward_decode(self, batch: Batch):
         input_metadata = InputMetadata.create(
            self,
             forward_mode=ForwardMode.DECODE,
-            tp_size=self.tp_size,
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
             prefix_lens=batch.prefix_lens,
@@ -252,9 +247,6 @@ def forward_decode(self, batch: Batch):
             out_cache_cont_end=batch.out_cache_cont_end,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
-            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 698a7bcc0d..ef8b6d252a 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -53,6 +53,7 @@ class ServerArgs:
     disable_flashinfer: bool = False
     disable_radix_cache: bool = False
     disable_regex_jump_forward: bool = False
+    disable_cuda_graph: bool = False
     disable_disk_cache: bool = False
     attention_reduce_in_fp32: bool = False
     enable_p2p_check: bool = False
@@ -294,6 +295,11 @@ def add_cli_args(parser: argparse.ArgumentParser):
             action="store_true",
             help="Disable regex jump-forward",
         )
+        parser.add_argument(
+            "--disable-cuda-graph",
+            action="store_true",
+            help="Disable cuda graph.",
+        )
         parser.add_argument(
             "--disable-disk-cache",
             action="store_true",
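
Illustrative sketch (not part of the patch): the kv_indptr / kv_indices / kv_last_page_len arrays that init_flashinfer_args now builds and passes to the flashinfer begin_forward calls above can be reproduced with plain torch. The toy req_to_token table and lengths below are made-up example data; page size is 1, as in the diff.

import torch

# Hypothetical stand-in for req_to_token_pool.req_to_token: row i lists the
# KV-cache slots owned by request i.
req_to_token = torch.arange(32, dtype=torch.int32).reshape(2, 16)
req_pool_indices = [0, 1]
# For DECODE this is seq_lens; for EXTEND it is prefix_lens (see the patch).
paged_kernel_lens = torch.tensor([5, 3], dtype=torch.int32)
batch_size = len(req_pool_indices)

# kv_indptr: exclusive prefix sum over the per-request KV lengths.
kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)

# kv_indices: concatenation of each request's KV-cache slot ids.
lens = paged_kernel_lens.tolist()
kv_indices = torch.cat(
    [req_to_token[req_pool_indices[i], : lens[i]] for i in range(batch_size)],
    dim=0,
).contiguous()

# Page size is 1, so every request's last page holds exactly one token.
kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32)

print(kv_indptr.tolist())         # [0, 5, 8]
print(kv_indices.tolist())        # [0, 1, 2, 3, 4, 16, 17, 18]
print(kv_last_page_len.tolist())  # [1, 1]

Keeping these as plain local variables rather than self.kv_indptr / self.kv_indices / self.kv_last_page_len mirrors the patch: they are only needed long enough to hand off to the wrappers' begin_forward calls.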