feat: CUDAGraph compatibility of multi-level cascade inference APIs #586

Merged (2 commits, Nov 6, 2024)
69 changes: 63 additions & 6 deletions python/flashinfer/cascade.py
@@ -281,10 +281,22 @@ class MultiLevelCascadeAttentionWrapper:
...
>>> outputs[0].shape
torch.Size([7, 64, 128])

See Also
--------
BatchPrefillWithPagedKVCacheWrapper
"""

def __init__(
self, num_levels, float_workspace_buffer: torch.Tensor, kv_layout: str = "NHD"
self,
num_levels,
float_workspace_buffer: torch.Tensor,
kv_layout: str = "NHD",
use_cuda_graph: bool = False,
qo_indptr_buf_arr: Optional[list[torch.Tensor]] = None,
paged_kv_indptr_buf_arr: Optional[list[torch.Tensor]] = None,
paged_kv_indices_buf_arr: Optional[list[torch.Tensor]] = None,
paged_kv_last_page_len_buf_arr: Optional[list[torch.Tensor]] = None,
) -> None:
r"""Constructor of :class:`MultiLevelCascadeAttentionWrapper`.

@@ -298,14 +310,59 @@ def __init__(
buffer should be the same as the device of the input tensors.
kv_layout : str
The layout of the input k/v tensors, could be either ``NHD`` or ``HND``.
use_cuda_graph : bool
Whether to use CUDA graph to capture the kernels; if enabled, the auxiliary data structures
will be stored in the provided buffers.
qo_indptr_buf_arr : Optional[List[torch.Tensor]]
An array of qo indptr buffers, one per level; the array length should equal
the number of levels.
The last element of each tensor should be the total number of queries/outputs.
paged_kv_indptr_buf_arr : Optional[List[torch.Tensor]]
An array of paged kv-cache indptr buffers, one per level; the array length should
equal the number of levels.
paged_kv_indices_buf_arr : Optional[List[torch.Tensor]]
An array of paged kv-cache indices buffers, one per level; the array length should
equal the number of levels.
paged_kv_last_page_len_buf_arr : Optional[List[torch.Tensor]]
An array of paged kv-cache last page length buffers, one per level; the array length
should equal the number of levels.
"""
self._batch_prefill_wrappers = [
BatchPrefillWithPagedKVCacheWrapper(float_workspace_buffer, kv_layout)
for _ in range(num_levels)
]
self._use_cuda_graph = use_cuda_graph
if use_cuda_graph:
self._batch_prefill_wrappers = [
BatchPrefillWithPagedKVCacheWrapper(
float_workspace_buffer,
kv_layout,
use_cuda_graph=True,
qo_indptr_buf=qo_indptr_buf,
paged_kv_indptr_buf=paged_kv_indptr_buf,
paged_kv_indices_buf=paged_kv_indices_buf,
paged_kv_last_page_len_buf=paged_kv_last_page_len_buf,
)
for (
qo_indptr_buf,
paged_kv_indptr_buf,
paged_kv_indices_buf,
paged_kv_last_page_len_buf,
) in zip(
qo_indptr_buf_arr,
paged_kv_indptr_buf_arr,
paged_kv_indices_buf_arr,
paged_kv_last_page_len_buf_arr,
)
]
else:
self._batch_prefill_wrappers = [
BatchPrefillWithPagedKVCacheWrapper(float_workspace_buffer, kv_layout)
for _ in range(num_levels)
]
self._num_levels = num_levels
self._kv_layout = kv_layout

@property
def is_cuda_graph_enabled(self) -> bool:
return self._use_cuda_graph
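
For reference, a minimal usage sketch of the new constructor arguments when CUDA graph capture is desired. The batch size, page count, workspace size, and dtypes below are illustrative assumptions, not values taken from this PR; the only requirement is that the buffers keep a fixed shape for the lifetime of the wrapper.

    # Hedged usage sketch (not part of this diff): pre-allocate per-level
    # index buffers and enable CUDA graph mode. All sizes are assumptions.
    import torch
    import flashinfer

    num_levels = 2
    batch_size = 7        # assumed fixed batch size during graph capture
    max_num_pages = 128   # assumed upper bound on KV pages per level

    workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
    qo_indptr_bufs = [
        torch.empty(batch_size + 1, dtype=torch.int32, device="cuda")
        for _ in range(num_levels)
    ]
    kv_indptr_bufs = [
        torch.empty(batch_size + 1, dtype=torch.int32, device="cuda")
        for _ in range(num_levels)
    ]
    kv_indices_bufs = [
        torch.empty(max_num_pages, dtype=torch.int32, device="cuda")
        for _ in range(num_levels)
    ]
    kv_last_page_len_bufs = [
        torch.empty(batch_size, dtype=torch.int32, device="cuda")
        for _ in range(num_levels)
    ]

    wrapper = flashinfer.MultiLevelCascadeAttentionWrapper(
        num_levels,
        workspace,
        kv_layout="NHD",
        use_cuda_graph=True,
        qo_indptr_buf_arr=qo_indptr_bufs,
        paged_kv_indptr_buf_arr=kv_indptr_bufs,
        paged_kv_indices_buf_arr=kv_indices_bufs,
        paged_kv_last_page_len_buf_arr=kv_last_page_len_bufs,
    )

The usual pattern (not shown in this diff) would be to call plan() outside of graph capture with the fixed batch size and then capture only the run() calls, since the auxiliary index tensors live in the pre-allocated buffers above.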

def reset_workspace_buffer(
self,
float_workspace_buffer: torch.Tensor,
@@ -912,7 +969,7 @@ def forward(
k_shared: torch.Tensor,
v_shared: torch.Tensor,
unique_kv_cache: torch.Tensor,
causal: bool = True,
causal: bool = False,
allow_fp16_qk_reduction: bool = False,
sm_scale: Optional[float] = None,
rope_scale: Optional[float] = None,
12 changes: 6 additions & 6 deletions python/flashinfer/prefill.py
@@ -747,7 +747,7 @@ def __init__(

use_cuda_graph : bool
Whether to enable CUDA graph capture for the prefill kernels, if enabled, the
auxiliary data structures will be stored as provided buffers. The ``batch_size``
auxiliary data structures will be stored in provided buffers. The ``batch_size``
cannot change during the lifecycle of this wrapper when CUDAGraph is enabled.

qo_indptr_buf : Optional[torch.Tensor]
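
As a point of comparison with the cascade wrapper above, a minimal sketch of the single-level BatchPrefillWithPagedKVCacheWrapper with CUDA graph buffers, again with assumed sizes; this mirrors what the cascade wrapper now constructs internally for each level.

    # Hedged sketch with assumed sizes; not taken from this diff.
    import torch
    import flashinfer

    batch_size, max_num_pages = 7, 128
    workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
    prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
        workspace,
        "NHD",
        use_cuda_graph=True,
        qo_indptr_buf=torch.empty(batch_size + 1, dtype=torch.int32, device="cuda"),
        paged_kv_indptr_buf=torch.empty(batch_size + 1, dtype=torch.int32, device="cuda"),
        paged_kv_indices_buf=torch.empty(max_num_pages, dtype=torch.int32, device="cuda"),
        paged_kv_last_page_len_buf=torch.empty(batch_size, dtype=torch.int32, device="cuda"),
    )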
@@ -1095,7 +1095,7 @@ def forward(
self,
q: torch.Tensor,
paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
causal: bool = True,
causal: bool = False,
pos_encoding_mode: str = "NONE",
allow_fp16_qk_reduction: bool = False,
k_scale: Optional[float] = None,
@@ -1240,7 +1240,7 @@ def forward_return_lse(
self,
q: torch.Tensor,
paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
causal: bool = True,
causal: bool = False,
pos_encoding_mode: str = "NONE",
allow_fp16_qk_reduction: bool = False,
k_scale: Optional[float] = None,
@@ -1491,7 +1491,7 @@ def plan(
head_dim: int,
custom_mask: Optional[torch.Tensor] = None,
packed_custom_mask: Optional[torch.Tensor] = None,
causal: bool = True,
causal: bool = False,
pos_encoding_mode: str = "NONE",
allow_fp16_qk_reduction: bool = False,
window_left: int = -1,
@@ -1683,7 +1683,7 @@ def forward(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
causal: bool = True,
causal: bool = False,
pos_encoding_mode: str = "NONE",
allow_fp16_qk_reduction: bool = False,
window_left: int = -1,
@@ -1812,7 +1812,7 @@ def forward_return_lse(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
causal: bool = True,
causal: bool = False,
pos_encoding_mode: str = "NONE",
allow_fp16_qk_reduction: bool = False,
window_left: int = -1,
2 changes: 1 addition & 1 deletion tests/test_shared_prefix_kernels.py
@@ -29,7 +29,7 @@ def ceil_div(a, b):
@pytest.mark.parametrize("unique_kv_len", [37, 17])
@pytest.mark.parametrize("shared_kv_len", [128, 512, 2048])
@pytest.mark.parametrize("num_heads", [8, 16])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("causal", [False])
@pytest.mark.parametrize("head_dim", [128, 256])
@pytest.mark.parametrize("page_size", [1, 16])
def test_batch_attention_with_shared_prefix_paged_kv_cache(