perf: improve plan performance by using non-blocking memcpy (#547)
yzh119 authored Oct 22, 2024
1 parent 021b585 commit 41ebe6d
Showing 7 changed files with 76 additions and 56 deletions.
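
The changes below all follow the same idea: the plan() phase used to move small metadata tensors (indptr arrays and similar) between host and device with blocking copies, which forced a stream synchronization on every call. This commit routes those transfers through non_blocking=True copies into pre-allocated buffers (alongside the existing pinned-memory int workspace buffer) instead. The sketch below is a minimal illustration of that pattern, not FlashInfer's actual API; the buffer name and size are assumptions.

```python
import torch

# Minimal sketch of the non-blocking copy pattern adopted in this commit.
# Names and sizes are hypothetical; only the copy semantics mirror the diff.
device = torch.device("cuda")

# Persistent device-side buffer, allocated once and reused across plan() calls.
paged_kv_indptr_buf = torch.zeros(1025, dtype=torch.int32, device=device)

def plan(indptr_cpu: torch.Tensor) -> None:
    # An H2D copy is only truly asynchronous when the source lives in pinned
    # (page-locked) host memory, so stage the data there first if necessary.
    src = indptr_cpu if indptr_cpu.is_pinned() else indptr_cpu.pin_memory()
    # Enqueue the copy on the current CUDA stream and return immediately,
    # rather than blocking the host on a synchronous cudaMemcpy.
    paged_kv_indptr_buf[: src.numel()].copy_(src, non_blocking=True)
```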
5 changes: 2 additions & 3 deletions flashinfer-aot/csrc_aot/batch_decode.cu
@@ -43,7 +43,7 @@ std::vector<int64_t> BatchDecodeWithPagedKVCachePlan(
int_workspace_buffer.size(0) * int_workspace_buffer.element_size();
auto device = float_workspace_buffer.device();
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
- indptr = indptr.to(torch::kCPU);
+ TORCH_CHECK(indptr.device() == torch::kCPU, "indptr must be on CPU");

DecodePlanInfo plan_info;

@@ -150,8 +150,7 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCacheRun(
paged_kv_t<DTypeKV, IdType> paged_kv(
num_kv_heads, page_size, HEAD_DIM, batch_size, kv_layout,
static_cast<DTypeKV*>(paged_k_cache.data_ptr()),
- static_cast<DTypeKV*>(paged_v_cache.data_ptr()),
- kv_cache_strides,
+ static_cast<DTypeKV*>(paged_v_cache.data_ptr()), kv_cache_strides,
static_cast<IdType*>(paged_kv_indices.data_ptr()),
static_cast<IdType*>(paged_kv_indptr.data_ptr()),
static_cast<IdType*>(paged_kv_last_page_len.data_ptr()));
4 changes: 2 additions & 2 deletions flashinfer-aot/csrc_aot/batch_prefill.cu
@@ -51,8 +51,8 @@ std::vector<int64_t> BatchPrefillWithKVCachePlan(

auto device = float_workspace_buffer.device();
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
- qo_indptr = qo_indptr.to(torch::kCPU);
- kv_indptr = kv_indptr.to(torch::kCPU);
+ TORCH_CHECK(qo_indptr.device() == torch::kCPU, "qo_indptr must be on CPU");
+ TORCH_CHECK(kv_indptr.device() == torch::kCPU, "kv_indptr must be on CPU");

PrefillPlanInfo plan_info;

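
In both C++ plan entry points above, the old indptr.to(torch::kCPU) silently performed a synchronous device-to-host copy whenever the caller passed a GPU tensor; the new TORCH_CHECK makes CPU residency a precondition and leaves the transfer to the Python wrappers. A hedged caller-side sketch (hypothetical helper, not the actual binding) of what the ops now expect:

```python
import torch

def call_plan(plan_op, qo_indptr: torch.Tensor, kv_indptr: torch.Tensor, *plan_args):
    # Hypothetical helper: the C++ ops above now raise via TORCH_CHECK if these
    # tensors are still on CUDA, so any device-to-host move happens here.
    qo_indptr_host = qo_indptr.to("cpu", non_blocking=True)
    kv_indptr_host = kv_indptr.to("cpu", non_blocking=True)
    # A non-blocking D2H copy must complete before the host-side planner reads
    # the values; synchronizing (or waiting on an event) guarantees that.
    if qo_indptr.is_cuda or kv_indptr.is_cuda:
        torch.cuda.synchronize()
    return plan_op(qo_indptr_host, kv_indptr_host, *plan_args)
```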
46 changes: 24 additions & 22 deletions python/flashinfer/decode.py
@@ -53,7 +53,8 @@ def compile_single_decode_module(
):
uri, path = gen_single_decode_cu(*args)
return load_cuda_ops(
- uri, [path],
+ uri,
+ [path],
verbose=verbose,
)

@@ -64,7 +65,8 @@ def compile_batch_decode_module(
):
uri, path = gen_batch_decode_cu(*args)
return load_cuda_ops(
- uri, [path],
+ uri,
+ [path],
verbose=verbose,
)

@@ -114,6 +116,7 @@ def get_batch_decode_module(*args):
_batch_decode_modules[args] = compile_batch_decode_module(*args)
return _batch_decode_modules[args]


def single_decode_with_kv_cache_with_jit_module(
jit_module: Any,
q: torch.Tensor,
@@ -123,8 +126,10 @@ def single_decode_with_kv_cache_with_jit_module(
kv_layout: str = "NHD",
window_left: int = -1,
):
- tmp = _get_cache_buf("single_decode_with_kv_cache_tmp", 32 * 1024 * 1024, q.device)
- return jit_module.run(q, k, v, tmp, TensorLayout[kv_layout].value, window_left, *args)
+ tmp = _get_cache_buf("single_decode_with_kv_cache_tmp", 32 * 1024 * 1024, q.device)
+ return jit_module.run(
+ q, k, v, tmp, TensorLayout[kv_layout].value, window_left, *args
+ )


def single_decode_with_kv_cache(
@@ -444,6 +449,7 @@ def __init__(

if use_tensor_cores:
if use_cuda_graph:
+ # NOTE(Zihao): if once created, no need to update it in plan/run
self._qo_indptr_buf = torch.arange(
self._fixed_batch_size + 1,
dtype=torch.int32,
@@ -555,8 +561,7 @@ def plan(
if logits_soft_cap is None:
logits_soft_cap = 0.0

- qo_indptr = _get_range_buf(batch_size + 1, indptr.device)
-
+ qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
if self.is_cuda_graph_enabled:
if batch_size != self._fixed_batch_size:
raise ValueError(
@@ -569,21 +574,18 @@ def plan(
raise ValueError(
"The size of indices should be less than or equal to the allocated buffer"
)
- self._paged_kv_indptr_buf.copy_(indptr)
- self._paged_kv_indices_buf[: len(indices)] = indices
- self._paged_kv_last_page_len_buf.copy_(last_page_len)
- if self.use_tensor_cores:
- self._qo_indptr_buf.copy_(qo_indptr)
+ self._paged_kv_indptr_buf.copy_(indptr, non_blocking=True)
+ self._paged_kv_indices_buf[: len(indices)].copy_(indices, non_blocking=True)
+ self._paged_kv_last_page_len_buf.copy_(last_page_len, non_blocking=True)
else:
- self._paged_kv_indptr_buf = indptr.to(self.device)
- self._paged_kv_indices_buf = indices.to(self.device)
- self._paged_kv_last_page_len_buf = last_page_len.to(self.device)
- if self.use_tensor_cores:
- self._qo_indptr_buf = qo_indptr.to(self.device)
-
- qo_indptr = qo_indptr.to("cpu", non_blocking=True)
- indptr = indptr.to("cpu", non_blocking=True)
+ self._paged_kv_indptr_buf = indptr.to(self.device, non_blocking=True)
+ self._paged_kv_indices_buf = indices.to(self.device, non_blocking=True)
+ self._paged_kv_last_page_len_buf = last_page_len.to(
+ self.device, non_blocking=True
+ )
+ self._qo_indptr_buf = qo_indptr_host.to(self.device, non_blocking=True)

+ indptr_host = indptr.to("cpu", non_blocking=True)
if data_type is not None:
q_data_type = data_type
kv_data_type = data_type
@@ -612,8 +614,8 @@ def plan(
self._float_workspace_buffer,
self._int_workspace_buffer,
self._pin_memory_int_workspace_buffer,
- qo_indptr,
- indptr,
+ qo_indptr_host,
+ indptr_host,
batch_size,
num_qo_heads,
num_kv_heads,
@@ -635,7 +637,7 @@ def plan(
self._float_workspace_buffer,
self._int_workspace_buffer,
self._pin_memory_int_workspace_buffer,
- indptr,
+ indptr_host,
batch_size,
num_qo_heads,
num_kv_heads,
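
Two details in the decode.py changes are worth calling out. First, under CUDA graph capture the wrapper now updates its pre-allocated device buffers in place with copy_(..., non_blocking=True) rather than rebinding them, since a captured graph keeps reading the original device pointers. Second, the planner consumes qo_indptr_host/indptr_host on the CPU while the kernels consume the device-side buffers. A simplified sketch of the CUDA-graph branch, with hypothetical buffer names and a fixed batch size chosen for illustration:

```python
import torch

device = torch.device("cuda")
fixed_batch_size = 128  # assumed; the real wrapper stores self._fixed_batch_size

# Allocated once (e.g. in __init__); its address must stay stable so that a
# captured CUDA graph keeps reading valid memory on every replay.
paged_kv_indptr_buf = torch.empty(
    fixed_batch_size + 1, dtype=torch.int32, device=device
)

def plan_with_cuda_graph(indptr: torch.Tensor) -> None:
    # In-place, asynchronous update of the persistent buffer: no new allocation,
    # no rebinding, and no host synchronization on the plan path.
    paged_kv_indptr_buf.copy_(indptr, non_blocking=True)
```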
2 changes: 1 addition & 1 deletion python/flashinfer/jit/batch_decode_templ.py
@@ -42,7 +42,7 @@
int_workspace_buffer.size(0) * int_workspace_buffer.element_size();
auto device = float_workspace_buffer.device();
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
- indptr = indptr.to(torch::kCPU);
+ TORCH_CHECK(indptr.device() == torch::kCPU, "indptr must be on CPU");
DecodePlanInfo plan_info;
4 changes: 2 additions & 2 deletions python/flashinfer/jit/batch_prefill_templ.py
@@ -49,8 +49,8 @@
auto device = float_workspace_buffer.device();
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
- qo_indptr = qo_indptr.to(torch::kCPU);
- kv_indptr = kv_indptr.to(torch::kCPU);
+ TORCH_CHECK(qo_indptr.device() == torch::kCPU, "qo_indptr must be on CPU");
+ TORCH_CHECK(kv_indptr.device() == torch::kCPU, "kv_indptr must be on CPU");
PrefillPlanInfo plan_info;
57 changes: 38 additions & 19 deletions python/flashinfer/prefill.py
@@ -57,7 +57,8 @@ def compile_single_prefill_module(
):
uri, path = gen_single_prefill_cu(*args)
return load_cuda_ops(
- uri, [path],
+ uri,
+ [path],
verbose=verbose,
)

@@ -68,7 +69,8 @@ def compile_batch_prefill_module(
):
uri, path = gen_batch_prefill_cu(*args)
return load_cuda_ops(
- uri, [path],
+ uri,
+ [path],
verbose=verbose,
)

@@ -125,6 +127,7 @@ def get_batch_prefill_module(*args):
_batch_prefill_modules[args] = compile_batch_prefill_module(*args)
return _batch_prefill_modules[args]


def single_prefill_with_kv_cache_with_jit_module(
jit_module: Any,
q: torch.Tensor,
@@ -137,7 +140,8 @@ def single_prefill_with_kv_cache_with_jit_module(
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
tmp = _get_cache_buf("single_prefill_with_kv_cache_tmp", 32 * 1024 * 1024, q.device)
out = jit_module.run(
- q, k, v, tmp, TensorLayout[kv_layout].value, window_left, return_lse, *args)
+ q, k, v, tmp, TensorLayout[kv_layout].value, window_left, return_lse, *args
+ )
return out if return_lse else out[0]


@@ -726,10 +730,14 @@ def plan(
"The length of paged_kv_indices exceeds the allocated buffer size."
)

- self._qo_indptr_buf.copy_(qo_indptr)
- self._paged_kv_indptr_buf.copy_(paged_kv_indptr)
- self._paged_kv_indices_buf[: len(paged_kv_indices)] = paged_kv_indices
- self._paged_kv_last_page_len_buf.copy_(paged_kv_last_page_len)
+ self._qo_indptr_buf.copy_(qo_indptr, non_blocking=True)
+ self._paged_kv_indptr_buf.copy_(paged_kv_indptr, non_blocking=True)
+ self._paged_kv_indices_buf[: len(paged_kv_indices)].copy_(
+ paged_kv_indices, non_blocking=True
+ )
+ self._paged_kv_last_page_len_buf.copy_(
+ paged_kv_last_page_len, non_blocking=True
+ )

if packed_custom_mask is not None:
if not torch.is_tensor(self._custom_mask_buf):
@@ -740,20 +748,31 @@ def plan(
raise ValueError(
"qk_indptr_buf must be initialized with a torch.Tensor in cuda graph mode if we use custom mask in attention computation."
)
- self._custom_mask_buf[: len(packed_custom_mask)] = packed_custom_mask
+ self._custom_mask_buf[: len(packed_custom_mask)].copy_(
+ packed_custom_mask, non_blocking=True
+ )
# NOTE(Zihao): qk_indptr has the same length as qo_indptr
- self._qk_indptr_buf.copy_(qk_indptr)
+ self._qk_indptr_buf.copy_(qk_indptr, non_blocking=True)
else:
- self._qo_indptr_buf = qo_indptr.to(self.device)
- self._paged_kv_indptr_buf = paged_kv_indptr.to(self.device)
- self._paged_kv_indices_buf = paged_kv_indices.to(self.device)
- self._paged_kv_last_page_len_buf = paged_kv_last_page_len.to(self.device)
+ self._qo_indptr_buf = qo_indptr.to(self.device, non_blocking=True)
+ self._paged_kv_indptr_buf = paged_kv_indptr.to(
+ self.device, non_blocking=True
+ )
+ self._paged_kv_indices_buf = paged_kv_indices.to(
+ self.device, non_blocking=True
+ )
+ self._paged_kv_last_page_len_buf = paged_kv_last_page_len.to(
+ self.device, non_blocking=True
+ )
if packed_custom_mask is not None:
- self._custom_mask_buf = packed_custom_mask.to(self.device)
- self._qk_indptr_buf = qk_indptr.to(self.device)
+ self._custom_mask_buf = packed_custom_mask.to(
+ self.device, non_blocking=True
+ )
+ self._qk_indptr_buf = qk_indptr.to(self.device, non_blocking=True)

qo_indptr = qo_indptr.to("cpu", non_blocking=True)
paged_kv_indptr = paged_kv_indptr.to("cpu", non_blocking=True)
# NOTE(Zihao): only required if qo_indptr/paged_kv_indptr are device tensors
qo_indptr_host = qo_indptr.to("cpu", non_blocking=True)
paged_kv_indptr_host = paged_kv_indptr.to("cpu", non_blocking=True)

if packed_custom_mask is not None:
mask_mode = MaskMode.CUSTOM.value
@@ -781,8 +800,8 @@ def plan(
self._float_workspace_buffer,
self._int_workspace_buffer,
self._pin_memory_int_workspace_buffer,
- qo_indptr,
- paged_kv_indptr,
+ qo_indptr_host,
+ paged_kv_indptr_host,
batch_size,
num_qo_heads,
num_kv_heads,
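
The prefill.py changes also replace slice assignment (buf[:n] = x) with an explicit buf[:n].copy_(x, non_blocking=True) when filling a prefix of an over-allocated buffer. Both write in place, but the explicit copy lets the transfer be enqueued asynchronously on the current stream. A small sketch with assumed capacity and names:

```python
import torch

device = torch.device("cuda")
paged_kv_indices_buf = torch.empty(1 << 20, dtype=torch.int32, device=device)  # assumed capacity

def update_indices(paged_kv_indices: torch.Tensor) -> None:
    if len(paged_kv_indices) > len(paged_kv_indices_buf):
        raise ValueError("The length of paged_kv_indices exceeds the allocated buffer size.")
    # Write only the used prefix of the persistent buffer, asynchronously.
    paged_kv_indices_buf[: len(paged_kv_indices)].copy_(
        paged_kv_indices, non_blocking=True
    )
```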
14 changes: 7 additions & 7 deletions python/flashinfer/sparse.py
@@ -257,7 +257,7 @@ def plan(
num_blocks_row = len(indptr) - 1
qo_indptr_host = R * torch.arange(num_blocks_row + 1, dtype=torch.int32)
qo_indptr_host[-1] = M
- qo_indptr = qo_indptr_host.to(indptr.device)
+ qo_indptr = qo_indptr_host.to(indptr.device, non_blocking=True)
if indices.max().item() * C > N:
raise ValueError("indices out of bound")
last_block_len = torch.full(
@@ -279,13 +279,13 @@ def plan(
mask.contiguous().view(-1), qk_indptr, bitorder="little"
)

- self._qo_indptr = qo_indptr.to(self.device)
- self._paged_kv_indptr_buf = indptr.to(self.device)
- self._paged_kv_indices_buf = indices.to(self.device)
- self._paged_kv_last_page_len = last_block_len.to(self.device)
+ self._qo_indptr = qo_indptr.to(self.device, non_blocking=True)
+ self._paged_kv_indptr_buf = indptr.to(self.device, non_blocking=True)
+ self._paged_kv_indices_buf = indices.to(self.device, non_blocking=True)
+ self._paged_kv_last_page_len = last_block_len.to(self.device, non_blocking=True)
if packed_mask is not None:
- self._packed_mask_buf = packed_mask.to(self.device)
- self._qk_indptr_buf = qk_indptr.to(self.device)
+ self._packed_mask_buf = packed_mask.to(self.device, non_blocking=True)
+ self._qk_indptr_buf = qk_indptr.to(self.device, non_blocking=True)
mask_mode = MaskMode.CUSTOM.value
else:
self._packed_mask_buf = None
