
Commit

upd
yzh119 committed Jun 2, 2024
1 parent 580201d commit 1a09125
Showing 1 changed file with 164 additions and 61 deletions.
225 changes: 164 additions & 61 deletions python/flashinfer/prefill.py
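This commit adds CUDA graph support to ``BatchPrefillWithPagedKVCacheWrapper``: when ``enable_cuda_graph`` is ``True``, the wrapper keeps its auxiliary metadata in caller-provided, fixed-size buffers so the prefill kernels can be captured and replayed. A minimal construction sketch of how the new constructor arguments in the diff below might be used (the parameter names come from the new signature; the buffer sizes and int32 dtypes are illustrative assumptions, not library defaults):

import torch
import flashinfer

# Upper bounds chosen by the caller; the buffers must hold the largest problem
# the wrapper will ever see during its lifetime (values here are assumptions).
max_batch_size = 16
max_num_pages = 4096

workspace_buffer = torch.empty(16 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
    workspace_buffer,
    kv_layout="NHD",
    enable_cuda_graph=True,
    # All four index buffers must be torch.Tensors when enable_cuda_graph=True.
    qo_indptr_buf=torch.empty(max_batch_size + 1, dtype=torch.int32, device="cuda"),
    paged_kv_indptr_buf=torch.empty(max_batch_size + 1, dtype=torch.int32, device="cuda"),
    paged_kv_indices_buf=torch.empty(max_num_pages, dtype=torch.int32, device="cuda"),
    paged_kv_last_page_len_buf=torch.empty(max_batch_size, dtype=torch.int32, device="cuda"),
)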
@@ -473,7 +473,18 @@ class BatchPrefillWithPagedKVCacheWrapper:
wrapper class manages the lifecycle of these data structures.
"""

def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD"):
def __init__(
self,
workspace_buffer: torch.Tensor,
kv_layout: str = "NHD",
enable_cuda_graph: bool = False,
qo_indptr_buf: Optional[torch.Tensor] = None,
paged_kv_indptr_buf: Optional[torch.Tensor] = None,
paged_kv_indices_buf: Optional[torch.Tensor] = None,
paged_kv_last_page_len_buf: Optional[torch.Tensor] = None,
custom_mask_buf: Optional[torch.Tensor] = None,
qk_indptr_buf: Optional[torch.Tensor] = None,
):
r"""Constructor of :class:`BatchDecodeWithPagedKVCacheWrapper`.
Parameters
@@ -482,22 +493,87 @@ def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD"):
            The user reserved workspace buffer used to store auxiliary data structures;
            the recommended size is 16MB, and the workspace buffer should be on the same
            device as the input tensors.
kv_layout : str
The layout of the input k/v tensors, could be either ``NHD`` or ``HND``.
enable_cuda_graph : bool
            Whether to enable CUDA graph capture for the prefill kernels. If enabled, the
            auxiliary data structures will be stored in the provided buffers.
qo_indptr_buf : Optional[torch.Tensor]
The user reserved buffer to store the ``qo_indptr`` array, should be large
enough to store the maximum possible size of the ``qo_indptr`` array during the
lifetime of the wrapper. This argument is only effective when ``enable_cuda_graph``
is set to ``True``.
paged_kv_indptr_buf : Optional[torch.Tensor]
The user reserved buffer to store the ``paged_kv_indptr`` array, should be large
enough to store the maximum possible size of the ``paged_kv_indptr`` array during
the lifetime of the wrapper. This argument is only effective when ``enable_cuda_graph``
is set to ``True``.
paged_kv_indices_buf : Optional[torch.Tensor]
The user reserved buffer to store the ``paged_kv_indices`` array, should be large
enough to store the maximum possible size of the ``paged_kv_indices`` array during
the lifetime of the wrapper. This argument is only effective when ``enable_cuda_graph``
is set to ``True``.
paged_kv_last_page_len_buf : Optional[torch.Tensor]
The user reserved buffer to store the ``paged_kv_last_page_len`` array, should be
large enough to store the maximum possible size of the ``paged_kv_last_page_len`` array
during the lifetime of the wrapper. This argument is only effective when ``enable_cuda_graph``
is set to ``True``.
custom_mask_buf : Optional[torch.Tensor]
The user reserved buffer to store the custom mask tensor, should be large enough to
store the maximum possible size of the custom mask tensor during the lifetime of the
wrapper. This argument is only effective when ``enable_cuda_graph`` is set to ``True``
and the custom mask will be used in attention computation.
qk_indptr_buf : Optional[torch.Tensor]
The user reserved buffer to store the ``qk_indptr`` array, should be large enough to
store the maximum possible size of the ``qk_indptr`` array during the lifetime of the
wrapper. This argument is only effective when ``enable_cuda_graph`` is set to ``True``
and the custom mask will be used in attention computation.
"""
check_kv_layout(kv_layout)
self._kv_layout = kv_layout
self._workspace_buffer = workspace_buffer
self._wrapper = _kernels.BatchPrefillWithPagedKVCachePyTorchWrapper(
TensorLayout[kv_layout].value,
workspace_buffer.numel() * workspace_buffer.element_size(),
enable_cuda_graph,
)
        if enable_cuda_graph:
            if not torch.is_tensor(qo_indptr_buf):
                raise ValueError(
                    "qo_indptr_buf should be a torch.Tensor in CUDA graph mode"
                )
            if not torch.is_tensor(paged_kv_indptr_buf):
                raise ValueError(
                    "paged_kv_indptr_buf should be a torch.Tensor in CUDA graph mode"
                )
            if not torch.is_tensor(paged_kv_indices_buf):
                raise ValueError(
                    "paged_kv_indices_buf should be a torch.Tensor in CUDA graph mode"
                )
            if not torch.is_tensor(paged_kv_last_page_len_buf):
                raise ValueError(
                    "paged_kv_last_page_len_buf should be a torch.Tensor in CUDA graph mode"
                )
# NOTE(Zihao): do not check custom_mask_buf and qk_indptr_buf here, as they are optional

self._qo_indptr_buf = qo_indptr_buf
self._paged_kv_indptr_buf = paged_kv_indptr_buf
self._paged_kv_indices_buf = paged_kv_indices_buf
self._paged_kv_last_page_len_buf = paged_kv_last_page_len_buf
self._custom_mask_buf = custom_mask_buf
self._qk_indptr_buf = qk_indptr_buf

@property
def is_cuda_graph_enabled(self):
return self._wrapper.is_cuda_graph_enabled

def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor):
r"""Reset the workspace buffer.
@@ -563,18 +639,44 @@ def begin_forward(
`grouped query attention <https://arxiv.org/abs/2305.13245>`_.
"""
batch_size = len(qo_indptr) - 1
        if self.is_cuda_graph_enabled:
            self._qo_indptr_buf[: len(qo_indptr)] = qo_indptr
            self._paged_kv_indptr_buf[: len(paged_kv_indptr)] = paged_kv_indptr
            self._paged_kv_indices_buf[: len(paged_kv_indices)] = paged_kv_indices
            self._paged_kv_last_page_len_buf[: len(paged_kv_last_page_len)] = (
                paged_kv_last_page_len
            )

if custom_mask is not None:
if not torch.is_tensor(self._custom_mask_buf):
raise ValueError(
"custom_mask_buf must be initialized with a torch.Tensor in cuda graph mode if we use custom mask in attention computation."
)
if not torch.is_tensor(self._qk_indptr_buf):
raise ValueError(
"qk_indptr_buf must be initialized with a torch.Tensor in cuda graph mode if we use custom mask in attention computation."
)
self._custom_mask_buf[: len(custom_mask)] = custom_mask
# NOTE(Zihao): qk_indptr has the same length as qo_indptr
self._qk_indptr_buf[: len(qo_indptr)] = _compute_page_qk_indptr(
qo_indptr,
paged_kv_indptr,
paged_kv_last_page_len,
page_size,
)
        else:
            self._qo_indptr_buf = qo_indptr
            self._paged_kv_indptr_buf = paged_kv_indptr
            self._paged_kv_indices_buf = paged_kv_indices
            self._paged_kv_last_page_len_buf = paged_kv_last_page_len
            if custom_mask is not None:
                self._custom_mask_buf = custom_mask
                self._qk_indptr_buf = _compute_page_qk_indptr(
                    qo_indptr,
                    paged_kv_indptr,
                    paged_kv_last_page_len,
                    page_size,
                )
self._wrapper.begin_forward(
self._workspace_buffer,
qo_indptr,
@@ -586,12 +688,13 @@ def begin_forward

def end_forward(self):
r"""Clear the auxiliary data structures created by :meth:`begin_forward`."""
        if not self.is_cuda_graph_enabled:
            self._qo_indptr_buf = None
            self._paged_kv_indptr_buf = None
            self._paged_kv_indices_buf = None
            self._paged_kv_last_page_len_buf = None
            self._custom_mask_buf = None
            self._qk_indptr_buf = None
self._wrapper.end_forward()

def forward(
@@ -660,11 +763,11 @@ def forward(
        if self._custom_mask_buf is None:
            return self._wrapper.forward(
                q,
                self._qo_indptr_buf,
                paged_kv_data,
                self._paged_kv_indptr_buf,
                self._paged_kv_indices_buf,
                self._paged_kv_last_page_len_buf,
causal,
PosEncodingMode[pos_encoding_mode].value,
allow_fp16_qk_reduction,
@@ -676,13 +779,13 @@ def forward(
else:
return self._wrapper.forward_custom_mask(
q,
                self._qo_indptr_buf,
                paged_kv_data,
                self._paged_kv_indptr_buf,
                self._paged_kv_indices_buf,
                self._paged_kv_last_page_len_buf,
                self._custom_mask_buf,
                self._qk_indptr_buf,
PosEncodingMode[pos_encoding_mode].value,
allow_fp16_qk_reduction,
sm_scale,
@@ -758,11 +861,11 @@ def forward_return_lse(
        if self._custom_mask_buf is None:
            return self._wrapper.forward(
                q,
                self._qo_indptr_buf,
                paged_kv_data,
                self._paged_kv_indptr_buf,
                self._paged_kv_indices_buf,
                self._paged_kv_last_page_len_buf,
causal,
PosEncodingMode[pos_encoding_mode].value,
allow_fp16_qk_reduction,
@@ -774,13 +877,13 @@ def forward_return_lse(
else:
            return self._wrapper.forward_custom_mask(
                q,
                self._qo_indptr_buf,
                paged_kv_data,
                self._paged_kv_indptr_buf,
                self._paged_kv_indices_buf,
                self._paged_kv_last_page_len_buf,
                self._custom_mask_buf,
                self._qk_indptr_buf,
PosEncodingMode[pos_encoding_mode].value,
allow_fp16_qk_reduction,
sm_scale,
@@ -938,13 +1041,13 @@ def __init__(
The user reserved GPU buffer to store the custom mask tensor, should be large
enough to store the maximum possible size of the custom mask tensor during the
lifetime of the wrapper. This argument is only effective when ``enable_cuda_graph``
            is ``True`` and the custom mask will be used in attention computation.
qk_indptr_buf : Optional[torch.Tensor]
The user reserved GPU buffer to store the ``qk_indptr`` array, should be large
enough to store the maximum possible size of the ``qk_indptr`` array during the
lifetime of the wrapper. This argument is only effective when ``enable_cuda_graph``
            is ``True`` and the custom mask will be used in attention computation.
"""
check_kv_layout(kv_layout)
@@ -957,20 +1060,20 @@ def __init__(
)
if enable_cuda_graph:
if not torch.is_tensor(qo_indptr_buf):
raise ValueError("qo_indptr_buf should be a torch.Tensor")
raise ValueError(
"qo_indptr_buf should be a torch.Tensor in cuda graph mode"
)
if not torch.is_tensor(kv_indptr_buf):
raise ValueError("kv_indptr_buf should be a torch.Tensor")
raise ValueError(
"kv_indptr_buf should be a torch.Tensor in cuda graph mode"
)
            # NOTE(Zihao): do not check custom_mask_buf and qk_indptr_buf here,
            # as they may not be used.

self._qo_indptr_buf = qo_indptr_buf
self._kv_indptr_buf = kv_indptr_buf
self._custom_mask_buf = custom_mask_buf
self._qk_indptr_buf = qk_indptr_buf
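        # Illustration (a sketch, not part of this diff): constructing this ragged-KV
        # wrapper in CUDA graph mode mirrors the paged wrapper above. The class name,
        # buffer sizes, and dtypes below are assumptions for illustration only:
        #
        #     ragged_wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
        #         workspace_buffer,
        #         kv_layout="NHD",
        #         enable_cuda_graph=True,
        #         qo_indptr_buf=torch.empty(max_batch_size + 1, dtype=torch.int32, device="cuda"),
        #         kv_indptr_buf=torch.empty(max_batch_size + 1, dtype=torch.int32, device="cuda"),
        #     )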

@property
def is_cuda_graph_enabled(self):
@@ -1036,11 +1139,11 @@ def begin_forward(
if custom_mask is not None:
if not torch.is_tensor(self._custom_mask_buf):
raise ValueError(
"custom_mask_buf must be initialized with a torch.Tensor in cuda graph mode if we use custom mask in attention."
"custom_mask_buf must be initialized with a torch.Tensor in cuda graph mode if we use custom mask in attention computation."
)
if not torch.is_tensor(self._qk_indptr_buf):
raise ValueError(
"qk_indptr_buf must be initialized with a torch.Tensor in cuda graph mode if we use custom mask in the attention."
"qk_indptr_buf must be initialized with a torch.Tensor in cuda graph mode if we use custom mask in the attention computation."
)
self._custom_mask_buf[: len(custom_mask)] = custom_mask
self._qk_indptr_buf[: len(qo_indptr)] = _compute_qk_indptr(
@@ -1248,8 +1351,8 @@ def forward_return_lse(
k,
v,
self._kv_indptr_buf,
                self._custom_mask_buf,
                self._qk_indptr_buf,
PosEncodingMode[pos_encoding_mode].value,
allow_fp16_qk_reduction,
sm_scale,
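For context, a plausible end-to-end flow with the CUDA-graph-enabled paged wrapper is sketched below, reusing the ``wrapper`` constructed in the sketch near the top of this page. The ``begin_forward``/``forward``/``end_forward`` pattern and the argument order of ``begin_forward`` follow the library's documented usage of this class; the shapes, head counts, metadata values, and capture details are assumptions rather than part of this commit.

import torch

num_qo_heads, num_kv_heads, head_dim, page_size = 32, 8, 128, 16

# Paged KV cache in NHD layout: [max_num_pages, 2, page_size, num_kv_heads, head_dim].
paged_kv_data = torch.randn(
    4096, 2, page_size, num_kv_heads, head_dim, dtype=torch.float16, device="cuda"
)

# One request with 7 query tokens whose KV cache spans 3 pages, 5 entries valid
# on the last page (illustrative metadata).
qo_indptr = torch.tensor([0, 7], dtype=torch.int32, device="cuda")
paged_kv_indptr = torch.tensor([0, 3], dtype=torch.int32, device="cuda")
paged_kv_indices = torch.tensor([0, 1, 2], dtype=torch.int32, device="cuda")
paged_kv_last_page_len = torch.tensor([5], dtype=torch.int32, device="cuda")
q = torch.randn(7, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")

# Plan outside the graph; with enable_cuda_graph=True this copies the metadata
# into the pre-allocated *_buf tensors instead of retaining the Python tensors.
wrapper.begin_forward(
    qo_indptr,
    paged_kv_indptr,
    paged_kv_indices,
    paged_kv_last_page_len,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    page_size,
)

# Warm-up run, then capture the attention kernel into a CUDA graph.
out = wrapper.forward(q, paged_kv_data, causal=True)
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    out = wrapper.forward(q, paged_kv_data, causal=True)

# Later steps refresh the metadata buffers via another begin_forward call and
# replay the captured kernels in place of re-launching them.
g.replay()
wrapper.end_forward()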
