Clean up the usage of flashinfer #610

Merged · 1 commit · Jul 12, 2024
15 changes: 3 additions & 12 deletions python/sglang/srt/layers/radix_attention.py
@@ -31,21 +31,13 @@ def __init__(
self.layer_id = layer_id

if not global_server_args_dict.get("disable_flashinfer", False):
self.prefill_forward = self.prefill_forward_flashinfer
self.extend_forward = self.prefill_forward_flashinfer
self.extend_forward = self.extend_forward_flashinfer
self.decode_forward = self.decode_forward_flashinfer
# flashinfer now accepts float logit_cap argument
self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
else:
self.prefill_forward = self.prefill_forward_triton
self.extend_forward = self.extend_forward_triton
self.decode_forward = self.decode_forward_triton
self.logit_cap = logit_cap if logit_cap is not None else 0

def prefill_forward_triton(self, q, k, v, input_metadata: InputMetadata):
# In SGLang, we call both the typical "prefill" and "prefill with cache" as "extend".
# See the extend_forward_xxx functions.
raise NotImplementedError()
self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0

def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
o = torch.empty_like(q)
@@ -86,15 +78,14 @@ def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata):
input_metadata.start_loc,
input_metadata.seq_lens,
input_metadata.max_seq_len,
input_metadata.other_kv_index,
input_metadata.total_num_tokens,
sm_scale=self.scaling,
logit_cap=self.logit_cap,
)

return o

def prefill_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
o1, s1 = input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
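The hunk above folds the separate "prefill" entry points into the extend path (in SGLang, both plain prefill and prefill-with-cache are called "extend") and makes the flashinfer and triton branches normalize logit_cap identically. A minimal sketch of the resulting dispatch pattern, with illustrative names rather than the actual sglang class:

class AttentionDispatch:
    def __init__(self, use_flashinfer: bool, logit_cap=None):
        # Bind the forward entry points to one backend at construction time.
        if use_flashinfer:
            self.extend_forward = self.extend_forward_flashinfer
            self.decode_forward = self.decode_forward_flashinfer
        else:
            self.extend_forward = self.extend_forward_triton
            self.decode_forward = self.decode_forward_triton
        # Both branches now share one rule: a positive float enables logit
        # capping, anything else disables it.
        self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0

    def extend_forward_flashinfer(self, q, k, v, meta): ...
    def decode_forward_flashinfer(self, q, k, v, meta): ...
    def extend_forward_triton(self, q, k, v, meta): ...
    def decode_forward_triton(self, q, k, v, meta): ...
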
8 changes: 1 addition & 7 deletions python/sglang/srt/layers/token_attention.py
@@ -107,7 +107,6 @@ def _fwd_kernel_stage2(
stride_obs,
stride_oh,
stride_req_to_token_b,
other_kv_index, # To fix a NAN issue
kv_group_num: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
@@ -138,7 +137,7 @@ def _fwd_kernel_stage2(
+ cur_batch_req_idx * stride_req_to_token_b
+ (start_n + offs_n),
mask=(start_n + offs_n) < cur_batch_seq_len,
other=other_kv_index,
other=0,
)

qk = tl.load(
@@ -250,7 +249,6 @@ def _token_softmax_reducev_fwd(
b_req_idx,
b_start_loc,
b_seq_len,
other_kv_index,
):
BLOCK = 64
batch, head = b_seq_len.shape[0], logics.shape[0]
@@ -277,7 +275,6 @@ def _token_softmax_reducev_fwd(
o.stride(0),
o.stride(1),
req_to_tokens.stride(0),
other_kv_index,
)
return

@@ -295,7 +292,6 @@ def _token_softmax_reducev_fwd(
o.stride(0),
o.stride(1),
req_to_tokens.stride(0),
other_kv_index,
kv_group_num=kv_group_num,
BLOCK_DMODEL=v_buffer.shape[-1],
BLOCK_N=BLOCK,
@@ -315,7 +311,6 @@ def token_attention_fwd(
b_start_loc,
b_seq_len,
max_len_in_batch,
other_kv_index,
total_num_tokens,
sm_scale=None,
logit_cap=-1,
@@ -347,5 +342,4 @@ def token_attention_fwd(
b_req_idx,
b_start_loc,
b_seq_len,
other_kv_index,
)
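
The removed other_kv_index parameter existed only to give the masked lanes of the tl.load a known-valid fallback token index; the diff replaces it with the constant other=0. A plain-PyTorch sketch of that masked-load-with-fallback semantics (not the Triton kernel itself; names are illustrative):

import torch

def masked_gather(req_to_token_row, offs, seq_len, other=0):
    # Positions past the sequence length get the fallback value `other`
    # instead of reading an arbitrary (possibly uninitialized) slot.
    mask = offs < seq_len
    safe_offs = torch.where(mask, offs, torch.zeros_like(offs)).long()
    vals = req_to_token_row[safe_offs]
    return torch.where(mask, vals, torch.full_like(vals, other))

Masked lanes carry effectively zero attention weight after the reduce stage, so any in-bounds fallback slot works, which appears to be why the dedicated other_kv_index plumbing could be dropped.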
84 changes: 36 additions & 48 deletions python/sglang/srt/managers/controller/infer_batch.py
@@ -729,7 +729,6 @@ class InputMetadata:
out_cache_cont_start: torch.Tensor = None
out_cache_cont_end: torch.Tensor = None

other_kv_index: torch.Tensor = None
return_logprob: bool = False
top_logprobs_nums: List[int] = None

@@ -743,24 +742,19 @@ class InputMetadata:
flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None

def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
if (
self.forward_mode == ForwardMode.EXTEND
):
if self.forward_mode == ForwardMode.DECODE:
paged_kernel_lens = self.seq_lens
else:
paged_kernel_lens = self.prefix_lens
self.no_prefix = torch.all(self.prefix_lens == 0)
else:
paged_kernel_lens = self.seq_lens

self.kv_indptr = torch.zeros(
kv_indptr = torch.zeros(
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
)
self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
self.kv_last_page_len = torch.ones(
(self.batch_size,), dtype=torch.int32, device="cuda"
)
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
self.kv_indices = torch.cat(
kv_indices = torch.cat(
[
self.req_to_token_pool.req_to_token[
req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
@@ -769,18 +763,34 @@ def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
],
dim=0,
).contiguous()
kv_last_page_len = torch.ones(
(self.batch_size,), dtype=torch.int32, device="cuda"
)

if self.forward_mode == ForwardMode.EXTEND:
if self.forward_mode == ForwardMode.DECODE:
self.flashinfer_decode_wrapper.end_forward()
self.flashinfer_decode_wrapper.begin_forward(
kv_indptr,
kv_indices,
kv_last_page_len,
num_qo_heads,
num_kv_heads,
head_dim,
1,
pos_encoding_mode="NONE",
data_type=self.token_to_kv_pool.kv_data[0].dtype,
)
else:
# extend part
self.qo_indptr = torch.zeros(
qo_indptr = torch.zeros(
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
)
self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)

self.flashinfer_prefill_wrapper_ragged.end_forward()
self.flashinfer_prefill_wrapper_ragged.begin_forward(
self.qo_indptr,
self.qo_indptr.clone(),
qo_indptr,
qo_indptr,
num_qo_heads,
num_kv_heads,
head_dim,
@@ -789,28 +799,15 @@ def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
# cached part
self.flashinfer_prefill_wrapper_paged.end_forward()
self.flashinfer_prefill_wrapper_paged.begin_forward(
self.qo_indptr,
self.kv_indptr,
self.kv_indices,
self.kv_last_page_len,
qo_indptr,
kv_indptr,
kv_indices,
kv_last_page_len,
num_qo_heads,
num_kv_heads,
head_dim,
1,
)
else:
self.flashinfer_decode_wrapper.end_forward()
self.flashinfer_decode_wrapper.begin_forward(
self.kv_indptr,
self.kv_indices,
self.kv_last_page_len,
num_qo_heads,
num_kv_heads,
head_dim,
1,
pos_encoding_mode="NONE",
data_type=self.token_to_kv_pool.kv_data[0].dtype,
)

def init_extend_args(self):
self.extend_seq_lens = self.seq_lens - self.prefix_lens
@@ -822,7 +819,6 @@ def init_extend_args(self):
def create(
cls,
model_runner,
tp_size,
forward_mode,
req_pool_indices,
seq_lens,
@@ -833,9 +829,6 @@ def create(
out_cache_cont_end=None,
top_logprobs_nums=None,
return_logprob=False,
flashinfer_prefill_wrapper_ragged=None,
flashinfer_prefill_wrapper_paged=None,
flashinfer_decode_wrapper=None,
):
batch_size = len(req_pool_indices)
start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
@@ -845,9 +838,6 @@ def create(

if forward_mode == ForwardMode.DECODE:
positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
other_kv_index = model_runner.req_to_token_pool.req_to_token[
req_pool_indices[0], seq_lens[0] - 1
].item()
else:
seq_lens_cpu = seq_lens.cpu().numpy()
prefix_lens_cpu = prefix_lens.cpu().numpy()
@@ -865,7 +855,6 @@ def create(
),
device="cuda",
)
other_kv_index = None

ret = cls(
forward_mode=forward_mode,
@@ -882,21 +871,20 @@ def create(
out_cache_loc=out_cache_loc,
out_cache_cont_start=out_cache_cont_start,
out_cache_cont_end=out_cache_cont_end,
other_kv_index=other_kv_index,
return_logprob=return_logprob,
top_logprobs_nums=top_logprobs_nums,
flashinfer_prefill_wrapper_ragged=flashinfer_prefill_wrapper_ragged,
flashinfer_prefill_wrapper_paged=flashinfer_prefill_wrapper_paged,
flashinfer_decode_wrapper=flashinfer_decode_wrapper,
flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
)

if forward_mode == ForwardMode.EXTEND:
ret.init_extend_args()

if not global_server_args_dict.get("disable_flashinfer", False):
ret.init_flashinfer_args(
model_runner.model_config.num_attention_heads // tp_size,
model_runner.model_config.get_num_kv_heads(tp_size),
model_runner.model_config.num_attention_heads // model_runner.tp_size,
model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
model_runner.model_config.head_dim,
)

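For reference, the paged-KV tensors that init_flashinfer_args now builds as locals follow the usual flashinfer layout: kv_indptr is a prefix sum of the per-request KV lengths, kv_indices concatenates each request's token slots from the req_to_token table, and kv_last_page_len is all ones because this code path uses page size 1. A standalone sketch under those assumptions (the helper name is illustrative):

import torch

def build_paged_kv_args(req_to_token, req_pool_indices, kv_lens):
    # kv_lens is seq_lens for decode and prefix_lens for extend,
    # matching the branch at the top of init_flashinfer_args.
    batch_size = len(req_pool_indices)
    kv_indptr = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda")
    kv_indptr[1:] = torch.cumsum(kv_lens, dim=0)
    req_idx = req_pool_indices.cpu().tolist()
    lens = kv_lens.cpu().tolist()
    kv_indices = torch.cat(
        [req_to_token[req_idx[i], : lens[i]] for i in range(batch_size)]
    ).contiguous()
    kv_last_page_len = torch.ones(batch_size, dtype=torch.int32, device="cuda")
    return kv_indptr, kv_indices, kv_last_page_len

These three tensors, plus the head counts and head_dim, are what the decode wrapper's begin_forward and the paged prefill wrapper's begin_forward receive in the diff above.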
8 changes: 0 additions & 8 deletions python/sglang/srt/managers/controller/model_runner.py
@@ -221,17 +221,13 @@ def forward_extend(self, batch: Batch):
input_metadata = InputMetadata.create(
self,
forward_mode=ForwardMode.EXTEND,
tp_size=self.tp_size,
req_pool_indices=batch.req_pool_indices,
seq_lens=batch.seq_lens,
prefix_lens=batch.prefix_lens,
position_ids_offsets=batch.position_ids_offsets,
out_cache_loc=batch.out_cache_loc,
top_logprobs_nums=batch.top_logprobs_nums,
return_logprob=batch.return_logprob,
flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
)
return self.model.forward(
batch.input_ids, input_metadata.positions, input_metadata
@@ -242,7 +238,6 @@ def forward_decode(self, batch: Batch):
input_metadata = InputMetadata.create(
self,
forward_mode=ForwardMode.DECODE,
tp_size=self.tp_size,
req_pool_indices=batch.req_pool_indices,
seq_lens=batch.seq_lens,
prefix_lens=batch.prefix_lens,
Expand All @@ -252,9 +247,6 @@ def forward_decode(self, batch: Batch):
out_cache_cont_end=batch.out_cache_cont_end,
top_logprobs_nums=batch.top_logprobs_nums,
return_logprob=batch.return_logprob,
flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
)
return self.model.forward(
batch.input_ids, input_metadata.positions, input_metadata
6 changes: 6 additions & 0 deletions python/sglang/srt/server_args.py
@@ -53,6 +53,7 @@ class ServerArgs:
disable_flashinfer: bool = False
disable_radix_cache: bool = False
disable_regex_jump_forward: bool = False
disable_cuda_graph: bool = False
disable_disk_cache: bool = False
attention_reduce_in_fp32: bool = False
enable_p2p_check: bool = False
@@ -294,6 +295,11 @@ def add_cli_args(parser: argparse.ArgumentParser):
action="store_true",
help="Disable regex jump-forward",
)
parser.add_argument(
"--disable-cuda-graph",
action="store_true",
help="Disable cuda graph.",
)
parser.add_argument(
"--disable-disk-cache",
action="store_true",
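The new option mirrors the existing disable_* flags: a store_true argument that surfaces as a boolean on ServerArgs. A quick way to exercise it, assuming add_cli_args is the static registration helper shown in the hunk header above (the model path is a placeholder):

import argparse
from sglang.srt.server_args import ServerArgs

parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)   # registers --disable-cuda-graph among the others
args = parser.parse_args(["--model-path", "dummy-model", "--disable-cuda-graph"])
print(args.disable_cuda_graph)    # True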