
Commit 560af6f
feat: add an option non_blocking to plan function (#622)
Use non-blocking memcpy in plan functions only when this option is turned on.
yzh119 authored Nov 20, 2024
1 parent f236f70 commit 560af6f
Showing 4 changed files with 57 additions and 26 deletions.
python/flashinfer/decode.py (29 changes: 22 additions & 7 deletions)
@@ -646,6 +646,7 @@ def plan(
         sm_scale: Optional[float] = None,
         rope_scale: Optional[float] = None,
         rope_theta: Optional[float] = None,
+        non_blocking: bool = False,
     ) -> None:
         r"""Plan batch decode for given problem specification.
@@ -687,6 +688,10 @@ def plan(
         data_type: Optional[Union[str, torch.dtype]]
             The data type of both the query and key/value tensors. Defaults to torch.float16.
             data_type is deprecated, please use q_data_type and kv_data_type instead.
+        non_blocking : bool
+            Whether to copy the input tensors to the device asynchronously, defaults to ``False``.
+            If ``True``, user should synchronize before calling :meth:`run` or cuda graph replay.
 
         Note
         ----
@@ -717,16 +722,26 @@ def plan(
                 raise ValueError(
                     "The size of indices should be less than or equal to the allocated buffer"
                 )
-            self._paged_kv_indptr_buf.copy_(indptr, non_blocking=True)
-            self._paged_kv_indices_buf[: len(indices)].copy_(indices, non_blocking=True)
-            self._paged_kv_last_page_len_buf.copy_(last_page_len, non_blocking=True)
+            self._paged_kv_indptr_buf.copy_(indptr, non_blocking=non_blocking)
+            self._paged_kv_indices_buf[: len(indices)].copy_(
+                indices, non_blocking=non_blocking
+            )
+            self._paged_kv_last_page_len_buf.copy_(
+                last_page_len, non_blocking=non_blocking
+            )
         else:
-            self._paged_kv_indptr_buf = indptr.to(self.device, non_blocking=True)
-            self._paged_kv_indices_buf = indices.to(self.device, non_blocking=True)
+            self._paged_kv_indptr_buf = indptr.to(
+                self.device, non_blocking=non_blocking
+            )
+            self._paged_kv_indices_buf = indices.to(
+                self.device, non_blocking=non_blocking
+            )
             self._paged_kv_last_page_len_buf = last_page_len.to(
-                self.device, non_blocking=True
+                self.device, non_blocking=non_blocking
             )
-        self._qo_indptr_buf = qo_indptr_host.to(self.device, non_blocking=True)
+        self._qo_indptr_buf = qo_indptr_host.to(
+            self.device, non_blocking=non_blocking
+        )
 
         indptr_host = indptr.to("cpu")
         if data_type is not None:
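For context, a hedged usage sketch of the new flag from the caller's side (the problem sizes below are illustrative assumptions, not taken from this commit). With non_blocking=True, the host-side index tensors passed to plan() should be in pinned memory for the copies to actually run asynchronously, and the caller must synchronize before run() or CUDA graph replay:

import torch
import flashinfer

# Illustrative problem sizes (assumptions for this sketch).
batch_size, pages_per_seq, page_size = 4, 8, 16
num_qo_heads, num_kv_heads, head_dim = 32, 8, 128

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace, "NHD")

# Pinned host tensors let non_blocking host-to-device copies overlap with other work.
kv_indptr = (torch.arange(batch_size + 1, dtype=torch.int32) * pages_per_seq).pin_memory()
kv_indices = torch.arange(batch_size * pages_per_seq, dtype=torch.int32).pin_memory()
kv_last_page_len = torch.full((batch_size,), page_size, dtype=torch.int32).pin_memory()

wrapper.plan(
    kv_indptr,
    kv_indices,
    kv_last_page_len,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    page_size,
    non_blocking=True,  # copies are only enqueued here
)
torch.cuda.synchronize()  # per the new docstring: sync before run() or graph replay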
python/flashinfer/prefill.py (30 changes: 18 additions & 12 deletions)
@@ -883,6 +883,7 @@ def plan(
         rope_theta: Optional[float] = None,
         q_data_type: Union[str, torch.dtype] = "float16",
         kv_data_type: Optional[Union[str, torch.dtype]] = None,
+        non_blocking: bool = False,
     ) -> None:
         r"""Plan batch prefill/append attention on Paged KV-Cache for given problem specification.
@@ -952,6 +953,9 @@ def plan(
             The data type of the query tensor, defaults torch.float16.
         kv_data_type : Optional[Union[str, torch.dtype]]
             The data type of the key/value tensor. If None, will be set to :attr:`q_data_type`.
+        non_blocking : bool
+            Whether to copy the input tensors to the device asynchronously, defaults to ``False``.
+            If ``True``, user should synchronize before calling :meth:`run` or cuda graph replay.
 
         Note
         ----
@@ -1003,13 +1007,13 @@ def plan(
                 "The length of paged_kv_indices exceeds the allocated buffer size."
             )
 
-            self._qo_indptr_buf.copy_(qo_indptr, non_blocking=True)
-            self._paged_kv_indptr_buf.copy_(paged_kv_indptr, non_blocking=True)
+            self._qo_indptr_buf.copy_(qo_indptr, non_blocking=non_blocking)
+            self._paged_kv_indptr_buf.copy_(paged_kv_indptr, non_blocking=non_blocking)
             self._paged_kv_indices_buf[: len(paged_kv_indices)].copy_(
-                paged_kv_indices, non_blocking=True
+                paged_kv_indices, non_blocking=non_blocking
             )
             self._paged_kv_last_page_len_buf.copy_(
-                paged_kv_last_page_len, non_blocking=True
+                paged_kv_last_page_len, non_blocking=non_blocking
             )
 
             if packed_custom_mask is not None:
@@ -1022,26 +1026,28 @@ def plan(
                         "qk_indptr_buf must be initialized with a torch.Tensor in cuda graph mode if we use custom mask in attention computation."
                     )
                 self._custom_mask_buf[: len(packed_custom_mask)].copy_(
-                    packed_custom_mask, non_blocking=True
+                    packed_custom_mask, non_blocking=non_blocking
                 )
                 # NOTE(Zihao): qk_indptr has the same length as qo_indptr
-                self._qk_indptr_buf.copy_(qk_indptr, non_blocking=True)
+                self._qk_indptr_buf.copy_(qk_indptr, non_blocking=non_blocking)
         else:
-            self._qo_indptr_buf = qo_indptr.to(self.device, non_blocking=True)
+            self._qo_indptr_buf = qo_indptr.to(self.device, non_blocking=non_blocking)
             self._paged_kv_indptr_buf = paged_kv_indptr.to(
-                self.device, non_blocking=True
+                self.device, non_blocking=non_blocking
             )
             self._paged_kv_indices_buf = paged_kv_indices.to(
-                self.device, non_blocking=True
+                self.device, non_blocking=non_blocking
            )
             self._paged_kv_last_page_len_buf = paged_kv_last_page_len.to(
-                self.device, non_blocking=True
+                self.device, non_blocking=non_blocking
             )
             if packed_custom_mask is not None:
                 self._custom_mask_buf = packed_custom_mask.to(
-                    self.device, non_blocking=True
+                    self.device, non_blocking=non_blocking
                 )
-                self._qk_indptr_buf = qk_indptr.to(self.device, non_blocking=True)
+                self._qk_indptr_buf = qk_indptr.to(
+                    self.device, non_blocking=non_blocking
+                )
 
         # NOTE(Zihao): only required if qo_indptr/paged_kv_indptr are device tensors
         qo_indptr_host = qo_indptr.to("cpu")
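The prefill change mirrors the decode one: in CUDA graph mode plan() copies into preallocated device buffers with copy_ (so the device addresses captured by the graph stay valid), while outside graph mode it allocates new device tensors with .to(); the new flag is simply threaded through both branches. A minimal standalone sketch of that pattern, with illustrative names rather than flashinfer internals:

import torch

class PlannerSketch:
    """Toy illustration of the copy_-into-buffers vs. .to() pattern used in plan()."""

    def __init__(self, max_num_pages: int, use_cuda_graph: bool, device: str = "cuda"):
        self.device = torch.device(device)
        self.use_cuda_graph = use_cuda_graph
        if use_cuda_graph:
            # Preallocate once so device addresses stay fixed across graph replays.
            self._indptr_buf = torch.empty(max_num_pages + 1, dtype=torch.int32, device=self.device)
            self._indices_buf = torch.empty(max_num_pages, dtype=torch.int32, device=self.device)

    def plan(self, indptr: torch.Tensor, indices: torch.Tensor, non_blocking: bool = False) -> None:
        if self.use_cuda_graph:
            # In-place copies reuse the captured buffers; asynchronous only if non_blocking=True.
            self._indptr_buf[: len(indptr)].copy_(indptr, non_blocking=non_blocking)
            self._indices_buf[: len(indices)].copy_(indices, non_blocking=non_blocking)
        else:
            # Outside graph mode, allocating fresh device tensors is fine.
            self._indptr_buf = indptr.to(self.device, non_blocking=non_blocking)
            self._indices_buf = indices.to(self.device, non_blocking=non_blocking)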
python/flashinfer/sparse.py (23 changes: 16 additions & 7 deletions)
@@ -184,6 +184,7 @@ def plan(
         rope_theta: Optional[float] = None,
         q_data_type: Union[str, torch.dtype] = "float16",
         kv_data_type: Optional[Union[str, torch.dtype]] = None,
+        non_blocking: bool = False,
     ) -> None:
         r"""Create auxiliary data structures for block sparse attention.
@@ -241,6 +242,10 @@ def plan(
             The data type of the query tensor.
         kv_data_type : Optional[Union[str, torch.dtype]]
             The data type of the key/value tensor. If None, will be set to :attr:`q_data_type`.
+        non_blocking : bool
+            Whether to copy the input tensors to the device asynchronously, defaults to ``False``.
+            If ``True``, user should synchronize before calling :meth:`run` or cuda graph replay.
 
         The :meth:`plan` method should be called before any :meth:`run` or
         :meth:`run_return_lse` calls, auxiliary data structures will be created
@@ -261,7 +266,7 @@ def plan(
         num_blocks_row = len(indptr) - 1
         qo_indptr_host = R * torch.arange(num_blocks_row + 1, dtype=torch.int32)
         qo_indptr_host[-1] = M
-        qo_indptr = qo_indptr_host.to(indptr.device, non_blocking=True)
+        qo_indptr = qo_indptr_host.to(indptr.device, non_blocking=non_blocking)
         if indices.max().item() * C > N:
             raise ValueError("indices out of bound")
         last_block_len = torch.full(
@@ -283,13 +288,17 @@ def plan(
                 mask.contiguous().view(-1), qk_indptr, bitorder="little"
             )
 
-        self._qo_indptr = qo_indptr.to(self.device, non_blocking=True)
-        self._paged_kv_indptr_buf = indptr.to(self.device, non_blocking=True)
-        self._paged_kv_indices_buf = indices.to(self.device, non_blocking=True)
-        self._paged_kv_last_page_len = last_block_len.to(self.device, non_blocking=True)
+        self._qo_indptr = qo_indptr.to(self.device, non_blocking=non_blocking)
+        self._paged_kv_indptr_buf = indptr.to(self.device, non_blocking=non_blocking)
+        self._paged_kv_indices_buf = indices.to(self.device, non_blocking=non_blocking)
+        self._paged_kv_last_page_len = last_block_len.to(
+            self.device, non_blocking=non_blocking
+        )
         if packed_mask is not None:
-            self._packed_mask_buf = packed_mask.to(self.device, non_blocking=True)
-            self._qk_indptr_buf = qk_indptr.to(self.device, non_blocking=True)
+            self._packed_mask_buf = packed_mask.to(
+                self.device, non_blocking=non_blocking
+            )
+            self._qk_indptr_buf = qk_indptr.to(self.device, non_blocking=non_blocking)
             mask_mode = MaskMode.CUSTOM.value
         else:
             self._packed_mask_buf = None
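Whether these non-blocking copies actually overlap anything depends on the source tensor: a host-to-device copy_ or .to() with non_blocking=True is asynchronous only when the host tensor is in pinned (page-locked) memory, otherwise PyTorch falls back to a synchronous copy. This is general PyTorch/CUDA behavior rather than something added by this commit; a small self-contained illustration:

import time
import torch

assert torch.cuda.is_available()
n = 1 << 24
dst = torch.empty(n, dtype=torch.float32, device="cuda")

pageable_src = torch.randn(n)              # ordinary host memory
pinned_src = torch.randn(n).pin_memory()   # page-locked host memory

def time_h2d(src: torch.Tensor) -> tuple[float, float]:
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    dst.copy_(src, non_blocking=True)      # returns immediately only for pinned sources
    enqueue = time.perf_counter() - t0
    torch.cuda.synchronize()               # wait for the copy to actually finish
    total = time.perf_counter() - t0
    return enqueue, total

print("pageable (enqueue, total):", time_h2d(pageable_src))
print("pinned   (enqueue, total):", time_h2d(pinned_src))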
tests/test_batch_decode_kernels.py (1 change: 1 addition & 0 deletions)
@@ -294,6 +294,7 @@ def test_batch_decode_with_tuple_paged_kv_cache(
 @pytest.mark.parametrize(
     "kv_dtype", [torch.float16, torch.float8_e4m3fn, torch.float8_e5m2]
 )
+@pytest.mark.parametrize("contiguous_kv", [True, False])
 def test_cuda_graph_batch_decode_with_paged_kv_cache(
     batch_size,
     kv_len,