From d8f10eeb212cd074d1d05b79a956011cc7aca985 Mon Sep 17 00:00:00 2001
From: tsu-bin
Date: Tue, 29 Oct 2024 21:05:41 +0800
Subject: [PATCH] update according to recent changes from upstream and some
 clean up

---
 python/flashinfer/__init__.py      |  2 +-
 python/flashinfer/decode.py        |  5 ++--
 python/flashinfer/jit/attention.py |  7 +++--
 src/flashinfer_ops.cuh             | 46 ++++++++++++++----------------
 4 files changed, 29 insertions(+), 31 deletions(-)

diff --git a/python/flashinfer/__init__.py b/python/flashinfer/__init__.py
index 3cb36a89..8ef7f046 100644
--- a/python/flashinfer/__init__.py
+++ b/python/flashinfer/__init__.py
@@ -24,7 +24,6 @@
     BatchDecodeWithSharedPrefixPagedKVCacheWrapper as BatchDecodeWithSharedPrefixPagedKVCacheWrapper,
     BatchPrefillWithSharedPrefixPagedKVCacheWrapper as BatchPrefillWithSharedPrefixPagedKVCacheWrapper,
     MultiLevelCascadeAttentionWrapper as MultiLevelCascadeAttentionWrapper,
-    BatchDecodeMlaWithPagedKVCacheWrapper as BatchDecodeMlaWithPagedKVCacheWrapper,
     merge_state as merge_state,
     merge_state_in_place as merge_state_in_place,
     merge_states as merge_states,
@@ -32,6 +31,7 @@
 from .decode import (
     BatchDecodeWithPagedKVCacheWrapper as BatchDecodeWithPagedKVCacheWrapper,
     CUDAGraphBatchDecodeWithPagedKVCacheWrapper as CUDAGraphBatchDecodeWithPagedKVCacheWrapper,
+    BatchDecodeMlaWithPagedKVCacheWrapper as BatchDecodeMlaWithPagedKVCacheWrapper,
     single_decode_with_kv_cache as single_decode_with_kv_cache,
 )
 from .gemm import (
diff --git a/python/flashinfer/decode.py b/python/flashinfer/decode.py
index a06ecbcd..2bb5d6b9 100644
--- a/python/flashinfer/decode.py
+++ b/python/flashinfer/decode.py
@@ -77,11 +77,10 @@
 def compile_batch_decode_mla_module(
     *args,
     verbose: bool = False,
 ):
-    gen_batch_decode_mla_cu(*args)
-    uri = get_batch_decode_mla_uri(*args)
+    uri, path = gen_batch_decode_mla_cu(*args)
     return load_cuda_ops(
         uri,
-        [FLASHINFER_GEN_SRC_DIR / f"{uri}.cu"],
+        [path],
         verbose=verbose,
     )
diff --git a/python/flashinfer/jit/attention.py b/python/flashinfer/jit/attention.py
index 78eb1c57..92163266 100644
--- a/python/flashinfer/jit/attention.py
+++ b/python/flashinfer/jit/attention.py
@@ -194,11 +194,14 @@ def gen_batch_decode_mla_cu(*args) -> None:
     gen_directory = FLASHINFER_GEN_SRC_DIR
     if not os.path.exists(gen_directory):
         os.makedirs(gen_directory)
-    file_name = f"{get_batch_decode_mla_uri(*args)}.cu"
+    uri = get_batch_decode_mla_uri(*args)
+    file_name = f"{uri}.cu"
+    path = gen_directory / file_name
     write_if_different(
-        gen_directory / file_name,
+        path,
         get_batch_decode_mla_cu_str(*args),
     )
+    return uri, path
 
 def get_single_prefill_cu_str(
     dtype_q: torch.dtype,
diff --git a/src/flashinfer_ops.cuh b/src/flashinfer_ops.cuh
index f14e5dd0..5d4e8a77 100644
--- a/src/flashinfer_ops.cuh
+++ b/src/flashinfer_ops.cuh
@@ -607,25 +607,23 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapperMLA(
   DISPATCH_head_dim(paged_kv.head_dim_ckv, HEAD_DIM_CKV, {
     // fixme: head_dim_ckv(kv_lora_rank) is 8 times the size of head_dim_kpe(qk_rope_head_dim) for all MLA model (DeepSeek-V2-Lite, DeepSeek-V2.5, MiniCPM3) at the time Oct.2024
     constexpr auto HEAD_DIM_KPE = HEAD_DIM_CKV/8;
-    // DISPATCH_head_dim(paged_kv.head_dim_kpe, HEAD_DIM_KPE, {
-      using ParamsT = BatchDecodeParamsMLA;
-      using AttentionVariant =
-          ComposedAttention;
-      ParamsT params(q_nope, q_pe, q_offset, paged_kv, o, lse, num_qo_heads,
-                     /*window_left=*/-1, /*logits_soft_cap=*/0.f, sm_scale, rope_scale,
-                     rope_theta);
-      params.request_indices = handler->GetRequestIndices();
-      params.kv_tile_indices = handler->GetKVTileIndices();
-      params.o_indptr = handler->GetOIndptr();
-      params.kv_chunk_size_ptr = handler->GetKVChunkSizePtr();
-      params.block_valid_mask = handler->GetBlockValidMask();
-      params.padded_batch_size = handler->GetPlanInfo().padded_batch_size;
-
-      return BatchDecodeWithPagedKVCacheDispatchedMLA(
-          params, handler->GetTmpV(), handler->GetTmpS(), stream);
-    // });
+    using ParamsT = BatchDecodeParamsMLA;
+    using AttentionVariant =
+        ComposedAttention;
+    ParamsT params(q_nope, q_pe, q_offset, paged_kv, o, lse, num_qo_heads,
+                   /*window_left=*/-1, /*logits_soft_cap=*/0.f, sm_scale, rope_scale,
+                   rope_theta);
+    params.request_indices = handler->GetRequestIndices();
+    params.kv_tile_indices = handler->GetKVTileIndices();
+    params.o_indptr = handler->GetOIndptr();
+    params.kv_chunk_size_ptr = handler->GetKVChunkSizePtr();
+    params.block_valid_mask = handler->GetBlockValidMask();
+    params.padded_batch_size = handler->GetPlanInfo().padded_batch_size;
+
+    return BatchDecodeWithPagedKVCacheDispatchedMLA(
+        params, handler->GetTmpV(), handler->GetTmpS(), stream);
   });
   return cudaSuccess;
 }
@@ -640,12 +638,10 @@ cudaError_t BatchDecodeHandlerPlanMLA(BatchDecodeHandler* handler, void* float_b
   DISPATCH_head_dim(head_dim_ckv, HEAD_DIM_CKV, {
     // fixme: head_dim_ckv(kv_lora_rank) is 8 times the size of head_dim_kpe(qk_rope_head_dim) for all MLA model (DeepSeek-V2-Lite, DeepSeek-V2.5, MiniCPM3) at the time Oct.2024
     constexpr auto HEAD_DIM_KPE = HEAD_DIM_CKV/8;
-    // DISPATCH_head_dim(head_dim_kpe, HEAD_DIM_KPE, {
-      return handler->PlanDispatchedMLA(
-          float_buffer, float_workspace_size_in_bytes, int_buffer, int_workspace_size_in_bytes,
-          indptr_h, last_page_len_h, batch_size, num_qo_heads, page_size);
-    // });
-  });
+    return handler->PlanDispatchedMLA(
+        float_buffer, float_workspace_size_in_bytes, int_buffer, int_workspace_size_in_bytes,
+        indptr_h, last_page_len_h, batch_size, num_qo_heads, page_size);
+});
 }
 
 }  // namespace flashinfer
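
Note on the decode.py / jit changes above: gen_batch_decode_mla_cu is now the
single source of truth for both the module URI and the generated .cu path, and
compile_batch_decode_mla_module consumes the returned (uri, path) pair instead
of recomputing the filename from the URI. A minimal self-contained sketch of
that pattern follows; the names gen_kernel_cu and GEN_SRC_DIR and all function
bodies are stand-ins for illustration, not flashinfer's actual implementation:

    from pathlib import Path

    GEN_SRC_DIR = Path("./generated")  # stand-in for FLASHINFER_GEN_SRC_DIR

    def write_if_different(path: Path, content: str) -> None:
        # Skip rewriting identical content so file mtimes stay stable
        # and the JIT build cache is not invalidated needlessly.
        if path.exists() and path.read_text() == content:
            return
        path.write_text(content)

    def gen_kernel_cu(uri: str, source: str):
        # Mirrors the shape of the patched gen_batch_decode_mla_cu:
        # emit the .cu file once, then hand both the URI and the
        # concrete path back to the caller.
        GEN_SRC_DIR.mkdir(parents=True, exist_ok=True)
        path = GEN_SRC_DIR / f"{uri}.cu"
        write_if_different(path, source)
        return uri, path

    # The caller no longer reassembles the path from the URI by hand:
    uri, path = gen_kernel_cu("batch_decode_mla_demo", "// kernel source ...")
    print(uri, path)

Returning the pair keeps the URI-to-filename convention in one place, so a
future change to the naming scheme cannot desynchronize the compile step.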
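Note on the HEAD_DIM_KPE = HEAD_DIM_CKV/8 shortcut in flashinfer_ops.cuh:
dropping the commented-out inner DISPATCH_head_dim relies on the fixme's
observation that kv_lora_rank is exactly 8x qk_rope_head_dim in the MLA
checkpoints named there; for example, DeepSeek-V2 publishes kv_lora_rank = 512
with qk_rope_head_dim = 64, and MiniCPM3 publishes 256 with 32. A hypothetical
sanity check one could run against a model config before relying on this
dispatch (check_mla_head_dims and the dict literals are illustrative and not
part of the patch):

    def check_mla_head_dims(config: dict) -> None:
        # The patch dispatches on head_dim_ckv (kv_lora_rank) alone and
        # derives head_dim_kpe (qk_rope_head_dim) as exactly one eighth.
        ckv = config["kv_lora_rank"]
        kpe = config["qk_rope_head_dim"]
        assert ckv == 8 * kpe, (
            f"MLA dispatch assumes kv_lora_rank == 8 * qk_rope_head_dim, "
            f"got {ckv} and {kpe}"
        )

    # Values as published in the models' configs (Oct. 2024):
    check_mla_head_dims({"kv_lora_rank": 512, "qk_rope_head_dim": 64})  # DeepSeek-V2 family
    check_mla_head_dims({"kv_lora_rank": 256, "qk_rope_head_dim": 32})  # MiniCPM3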