perf: use cuda-core implementation for io-bound block-sparse attention (#560)

When operational intensity is low, select cuda-core implementations for
block-sparse attention.
yzh119 authored Oct 26, 2024
1 parent ea86f81 commit 3fbf028
Showing 1 changed file with 109 additions and 47 deletions.
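The commit message states the selection criterion only briefly, so here is a standalone sketch (not part of the diff) of the dispatch heuristic that plan() gains below. The enum values and the helper name prefers_cuda_core are illustrative assumptions; in the library the check is written inline in plan() and uses flashinfer's own MaskMode.

from enum import IntEnum

class MaskMode(IntEnum):
    # Values assumed for illustration; the real enum is defined in flashinfer.
    NON_CAUSAL = 0
    CAUSAL = 1
    CUSTOM = 2

def prefers_cuda_core(R: int, num_qo_heads: int, num_kv_heads: int, mask_mode: int) -> bool:
    """Dispatch heuristic mirrored from this commit's plan():

    Each loaded KV block is reused by roughly R query rows times the GQA group
    size (num_qo_heads // num_kv_heads). When that reuse factor is below 4 the
    kernel is IO-bound, so the cuda-core (decode) implementation is preferred;
    otherwise the tensor-core (prefill) implementation is used. Masks are only
    supported on the tensor-core path, hence the NON_CAUSAL check.
    """
    group_size = num_qo_heads // num_kv_heads
    return R * group_size < 4 and mask_mode == MaskMode.NON_CAUSAL

# Example: 1-row blocks with MHA (group size 1) -> IO-bound -> cuda-core path.
assert prefers_cuda_core(R=1, num_qo_heads=32, num_kv_heads=32,
                         mask_mode=MaskMode.NON_CAUSAL)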
156 changes: 109 additions & 47 deletions python/flashinfer/sparse.py
@@ -18,6 +18,7 @@
from typing import Optional, Union, Tuple
import logging
import torch
+from .decode import get_batch_decode_module
from .prefill import _compute_page_qk_indptr, get_batch_prefill_module
from .quantization import segment_packbits
from .utils import (
@@ -299,31 +300,65 @@ def plan(

        kv_indptr_host = indptr.to("cpu", non_blocking=True)

-        self._cached_module = get_batch_prefill_module(
-            q_data_type,
-            kv_data_type,
-            q_data_type,
-            indptr.dtype,
-            head_dim,
-            PosEncodingMode[pos_encoding_mode].value,
-            mask_mode,
-            False,  # use_sliding_window
-            logits_soft_cap > 0,  # use_logits_soft_cap
-            allow_fp16_qk_reduction,
-        )
+        # NOTE(Zihao): masks are not yet supported in the cuda-core implementation,
+        # but support should be easy to add if needed; it is left as future work.
+        # For now, when a mask is provided, we use the tensor-core implementation.
+        if (
+            R * (num_qo_heads // num_kv_heads) < 4
+            and mask_mode == MaskMode.NON_CAUSAL.value
+        ):
+            # If the operation is not compute-bound, we use the cuda-core implementation
+            self._use_tensor_cores = False
+            self._cached_module = get_batch_decode_module(
+                q_data_type,
+                kv_data_type,
+                q_data_type,
+                indptr.dtype,
+                head_dim,
+                PosEncodingMode[pos_encoding_mode].value,
+                False,  # use_sliding_window
+                logits_soft_cap > 0,  # use_logits_soft_cap
+            )

-        self._plan_info = self._cached_module.plan(
-            self._float_workspace_buffer,
-            self._int_workspace_buffer,
-            self._pin_memory_int_workspace_buffer,
-            qo_indptr_host,
-            kv_indptr_host,
-            num_blocks_row,
-            num_qo_heads,
-            num_kv_heads,
-            C,
-            False,  # is_cuda_graph_enabled
-        )
+            self._plan_info = self._cached_module.plan(
+                self._float_workspace_buffer,
+                self._int_workspace_buffer,
+                self._pin_memory_int_workspace_buffer,
+                kv_indptr_host,
+                num_blocks_row,
+                num_qo_heads,
+                num_kv_heads,
+                C,
+                False,  # is_cuda_graph_enabled
+            )
+        else:
+            # if the operation is compute-bound, we use the tensor-core implementation
+            self._use_tensor_cores = True
+            self._cached_module = get_batch_prefill_module(
+                q_data_type,
+                kv_data_type,
+                q_data_type,
+                indptr.dtype,
+                head_dim,
+                PosEncodingMode[pos_encoding_mode].value,
+                mask_mode,
+                False,  # use_sliding_window
+                logits_soft_cap > 0,  # use_logits_soft_cap
+                allow_fp16_qk_reduction,
+            )
+
+            self._plan_info = self._cached_module.plan(
+                self._float_workspace_buffer,
+                self._int_workspace_buffer,
+                self._pin_memory_int_workspace_buffer,
+                qo_indptr_host,
+                kv_indptr_host,
+                num_blocks_row,
+                num_qo_heads,
+                num_kv_heads,
+                C,
+                False,  # is_cuda_graph_enabled
+            )

        self._pos_encoding_mode = pos_encoding_mode
        self._allow_fp16_qk_reduction = allow_fp16_qk_reduction
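Usage sketch (not part of the diff): the block below shows a plan() call that, with 1-row blocks, group size 1, and no mask, would take the new cuda-core path. It assumes the public wrapper is flashinfer.BlockSparseAttentionWrapper and that plan() takes the BSR layout as (indptr, indices, M, N, R, C, num_qo_heads, num_kv_heads, head_dim); all sizes and the dense block pattern are made up for illustration.

import torch
import flashinfer

# Hypothetical sizes: a 16 x 4096 attention matrix tiled into 1 x 16 blocks.
M, N, R, C = 16, 4096, 1, 16
num_qo_heads = num_kv_heads = 8  # group size 1, so R * 1 < 4 -> cuda-core path
head_dim = 128

num_blocks_row, num_blocks_col = M // R, N // C
# Fully dense block pattern in BSR form, just to keep the example short.
indptr = torch.arange(
    0, num_blocks_row * num_blocks_col + 1, num_blocks_col,
    dtype=torch.int32, device="cuda",
)
indices = torch.arange(num_blocks_col, dtype=torch.int32, device="cuda").repeat(num_blocks_row)

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BlockSparseAttentionWrapper(workspace)
wrapper.plan(indptr, indices, M, N, R, C, num_qo_heads, num_kv_heads, head_dim)

q = torch.randn(M, num_qo_heads, head_dim, dtype=torch.half, device="cuda")
k = torch.randn(N, num_kv_heads, head_dim, dtype=torch.half, device="cuda")
v = torch.randn(N, num_kv_heads, head_dim, dtype=torch.half, device="cuda")
out = wrapper.run(q, k, v)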
@@ -404,30 +439,57 @@ def run(
        k = k.reshape(-1, self.C, *k.shape[-2:]).contiguous()
        v = v.reshape(-1, self.C, *v.shape[-2:]).contiguous()

-        out = self._cached_module.paged_run(
-            self._float_workspace_buffer,
-            self._int_workspace_buffer,
-            self._plan_info,
-            q,
-            k,
-            v,
-            self._packed_mask_buf,
-            _get_cache_alibi_slopes_buf(q.shape[1], self.device),
-            self._qo_indptr,
-            self._paged_kv_indptr_buf,
-            self._paged_kv_indices_buf,
-            self._paged_kv_last_page_len,
-            self._qk_indptr_buf,
-            TensorLayout[self._kv_layout].value,
-            -1,  # window_left
-            logits_soft_cap,
-            sm_scale,
-            rope_scale,
-            rope_theta,
-            return_lse,
-        )
+        lse = None
+        if return_lse:
+            lse = torch.empty(
+                (q.size(0), q.size(1)), dtype=torch.float32, device=q.device
+            )
+
+        if self._use_tensor_cores:
+            out = self._cached_module.paged_run(
+                self._float_workspace_buffer,
+                self._int_workspace_buffer,
+                self._plan_info,
+                q,
+                k,
+                v,
+                self._packed_mask_buf,
+                _get_cache_alibi_slopes_buf(q.shape[1], self.device),
+                self._qo_indptr,
+                self._paged_kv_indptr_buf,
+                self._paged_kv_indices_buf,
+                self._paged_kv_last_page_len,
+                self._qk_indptr_buf,
+                TensorLayout[self._kv_layout].value,
+                -1,  # window_left
+                logits_soft_cap,
+                sm_scale,
+                rope_scale,
+                rope_theta,
+                lse,
+            )
+        else:
+            out = self._cached_module.run(
+                self._float_workspace_buffer,
+                self._int_workspace_buffer,
+                self._plan_info,
+                q,
+                k,
+                v,
+                self._paged_kv_indptr_buf,
+                self._paged_kv_indices_buf,
+                self._paged_kv_last_page_len,
+                _get_cache_alibi_slopes_buf(q.shape[1], self.device),
+                TensorLayout[self._kv_layout].value,
+                -1,  # window_left
+                logits_soft_cap,
+                sm_scale,
+                rope_scale,
+                rope_theta,
+                lse,
+            )

-        return out if return_lse else out[0]
+        return (out, lse) if return_lse else out

    def end_forward(self) -> None:
        r"""Warning: This method is deprecated and has no effect."""
