perf: improve plan performance by using non-blocking memcpy #547

Merged: 5 commits, Oct 22, 2024
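For context, the core technique of this PR: PyTorch's `copy_` and `.to()` accept a `non_blocking=True` flag that enqueues an asynchronous `cudaMemcpyAsync` on the current stream instead of stalling the host, provided the host-side tensor is pinned (page-locked). A minimal sketch of the difference, independent of this PR (tensor names are illustrative):

```python
import torch

indptr_dev = torch.tensor([0, 3, 7, 12], dtype=torch.int32, device="cuda")

# Blocking: .to("cpu") synchronizes the host with the GPU before returning.
indptr_host_sync = indptr_dev.to("cpu")

# Non-blocking: with a pinned destination, copy_ enqueues cudaMemcpyAsync on
# the current stream and returns immediately, so the host keeps issuing work.
indptr_host = torch.empty(indptr_dev.shape, dtype=indptr_dev.dtype, pin_memory=True)
indptr_host.copy_(indptr_dev, non_blocking=True)
```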
5 changes: 2 additions & 3 deletions flashinfer-aot/csrc_aot/batch_decode.cu
@@ -43,7 +43,7 @@ std::vector<int64_t> BatchDecodeWithPagedKVCachePlan(
int_workspace_buffer.size(0) * int_workspace_buffer.element_size();
auto device = float_workspace_buffer.device();
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
- indptr = indptr.to(torch::kCPU);
+ TORCH_CHECK(indptr.device() == torch::kCPU, "indptr must be on CPU");

DecodePlanInfo plan_info;

@@ -150,8 +150,7 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCacheRun(
paged_kv_t<DTypeKV, IdType> paged_kv(
num_kv_heads, page_size, HEAD_DIM, batch_size, kv_layout,
static_cast<DTypeKV*>(paged_k_cache.data_ptr()),
- static_cast<DTypeKV*>(paged_v_cache.data_ptr()),
- kv_cache_strides,
+ static_cast<DTypeKV*>(paged_v_cache.data_ptr()), kv_cache_strides,
static_cast<IdType*>(paged_kv_indices.data_ptr()),
static_cast<IdType*>(paged_kv_indptr.data_ptr()),
static_cast<IdType*>(paged_kv_last_page_len.data_ptr()));
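The change above inverts the old contract: instead of `plan` itself calling `indptr.to(torch::kCPU)` (a blocking device-to-host copy), it now asserts that the caller already supplies a CPU tensor, so the copy can be issued asynchronously from Python. A hedged sketch of the new caller-side contract; `plan` here is a stand-in, not the actual binding:

```python
import torch

def plan(indptr: torch.Tensor) -> None:
    # Mirrors the TORCH_CHECK above: the planner assumes host-resident
    # metadata rather than copying it over synchronously itself.
    assert indptr.device.type == "cpu", "indptr must be on CPU"
    # ...scheduling logic reads indptr values on the host here...

indptr_dev = torch.tensor([0, 4, 9], dtype=torch.int32, device="cuda")
indptr_host = indptr_dev.to("cpu", non_blocking=True)  # copy issued by the caller
torch.cuda.current_stream().synchronize()  # ensure the copy landed before host reads
plan(indptr_host)
```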
4 changes: 2 additions & 2 deletions flashinfer-aot/csrc_aot/batch_prefill.cu
@@ -51,8 +51,8 @@ std::vector<int64_t> BatchPrefillWithKVCachePlan(

auto device = float_workspace_buffer.device();
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
- qo_indptr = qo_indptr.to(torch::kCPU);
- kv_indptr = kv_indptr.to(torch::kCPU);
+ TORCH_CHECK(qo_indptr.device() == torch::kCPU, "qo_indptr must be on CPU");
+ TORCH_CHECK(kv_indptr.device() == torch::kCPU, "kv_indptr must be on CPU");

PrefillPlanInfo plan_info;

46 changes: 24 additions & 22 deletions python/flashinfer/decode.py
@@ -53,7 +53,8 @@ def compile_single_decode_module(
):
uri, path = gen_single_decode_cu(*args)
return load_cuda_ops(
- uri, [path],
+ uri,
+ [path],
verbose=verbose,
)

@@ -64,7 +65,8 @@ def compile_batch_decode_module(
):
uri, path = gen_batch_decode_cu(*args)
return load_cuda_ops(
- uri, [path],
+ uri,
+ [path],
verbose=verbose,
)

@@ -114,6 +116,7 @@ def get_batch_decode_module(*args):
_batch_decode_modules[args] = compile_batch_decode_module(*args)
return _batch_decode_modules[args]

+
def single_decode_with_kv_cache_with_jit_module(
jit_module: Any,
q: torch.Tensor,
@@ -123,8 +126,8 @@ def single_decode_with_kv_cache_with_jit_module(
kv_layout: str = "NHD",
window_left: int = -1,
):
- tmp = _get_cache_buf("single_decode_with_kv_cache_tmp", 32 * 1024 * 1024, q.device)
- return jit_module.run(q, k, v, tmp, TensorLayout[kv_layout].value, window_left, *args)
+ tmp = _get_cache_buf("single_decode_with_kv_cache_tmp", 32 * 1024 * 1024, q.device)
+ return jit_module.run(
+     q, k, v, tmp, TensorLayout[kv_layout].value, window_left, *args
+ )


def single_decode_with_kv_cache(
@@ -444,6 +449,7 @@ def __init__(

if use_tensor_cores:
if use_cuda_graph:
+ # NOTE(Zihao): if once created, no need to update it in plan/run
self._qo_indptr_buf = torch.arange(
self._fixed_batch_size + 1,
dtype=torch.int32,
@@ -555,8 +561,7 @@ def plan(
if logits_soft_cap is None:
logits_soft_cap = 0.0

- qo_indptr = _get_range_buf(batch_size + 1, indptr.device)
-
+ qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
if self.is_cuda_graph_enabled:
if batch_size != self._fixed_batch_size:
raise ValueError(
@@ -569,21 +574,18 @@
raise ValueError(
"The size of indices should be less than or equal to the allocated buffer"
)
- self._paged_kv_indptr_buf.copy_(indptr)
- self._paged_kv_indices_buf[: len(indices)] = indices
- self._paged_kv_last_page_len_buf.copy_(last_page_len)
- if self.use_tensor_cores:
-     self._qo_indptr_buf.copy_(qo_indptr)
+ self._paged_kv_indptr_buf.copy_(indptr, non_blocking=True)
+ self._paged_kv_indices_buf[: len(indices)].copy_(indices, non_blocking=True)
+ self._paged_kv_last_page_len_buf.copy_(last_page_len, non_blocking=True)
else:
- self._paged_kv_indptr_buf = indptr.to(self.device)
- self._paged_kv_indices_buf = indices.to(self.device)
- self._paged_kv_last_page_len_buf = last_page_len.to(self.device)
- if self.use_tensor_cores:
-     self._qo_indptr_buf = qo_indptr.to(self.device)
+ self._paged_kv_indptr_buf = indptr.to(self.device, non_blocking=True)
+ self._paged_kv_indices_buf = indices.to(self.device, non_blocking=True)
+ self._paged_kv_last_page_len_buf = last_page_len.to(
+     self.device, non_blocking=True
+ )
+ self._qo_indptr_buf = qo_indptr_host.to(self.device, non_blocking=True)

- qo_indptr = qo_indptr.to("cpu", non_blocking=True)
- indptr = indptr.to("cpu", non_blocking=True)
+ indptr_host = indptr.to("cpu", non_blocking=True)
if data_type is not None:
q_data_type = data_type
kv_data_type = data_type
@@ -612,8 +614,8 @@
self._float_workspace_buffer,
self._int_workspace_buffer,
self._pin_memory_int_workspace_buffer,
- qo_indptr,
- indptr,
+ qo_indptr_host,
+ indptr_host,
batch_size,
num_qo_heads,
num_kv_heads,
@@ -635,7 +637,7 @@
self._float_workspace_buffer,
self._int_workspace_buffer,
self._pin_memory_int_workspace_buffer,
- indptr,
+ indptr_host,
batch_size,
num_qo_heads,
num_kv_heads,
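A general caveat with the pattern above (a CUDA/PyTorch property, not specific to this diff): a `non_blocking=True` device-to-host copy only truly overlaps when the destination is pinned, and the host must not read the values until the copy has completed on the stream. A minimal sketch:

```python
import torch

indptr_dev = torch.tensor([0, 4, 9], dtype=torch.int32, device="cuda")
indptr_host = torch.empty(indptr_dev.shape, dtype=indptr_dev.dtype, pin_memory=True)
indptr_host.copy_(indptr_dev, non_blocking=True)

# Reading indptr_host before the async copy finishes is a race; a stream
# synchronize (or a recorded CUDA event) establishes the ordering first.
torch.cuda.current_stream().synchronize()
print(indptr_host.tolist())  # now consistent with indptr_dev
```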
2 changes: 1 addition & 1 deletion python/flashinfer/jit/batch_decode_templ.py
@@ -42,7 +42,7 @@
int_workspace_buffer.size(0) * int_workspace_buffer.element_size();
auto device = float_workspace_buffer.device();
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
- indptr = indptr.to(torch::kCPU);
+ TORCH_CHECK(indptr.device() == torch::kCPU, "indptr must be on CPU");

DecodePlanInfo plan_info;

4 changes: 2 additions & 2 deletions python/flashinfer/jit/batch_prefill_templ.py
@@ -49,8 +49,8 @@

auto device = float_workspace_buffer.device();
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
- qo_indptr = qo_indptr.to(torch::kCPU);
- kv_indptr = kv_indptr.to(torch::kCPU);
+ TORCH_CHECK(qo_indptr.device() == torch::kCPU, "qo_indptr must be on CPU");
+ TORCH_CHECK(kv_indptr.device() == torch::kCPU, "kv_indptr must be on CPU");

PrefillPlanInfo plan_info;

57 changes: 38 additions & 19 deletions python/flashinfer/prefill.py
@@ -57,7 +57,8 @@ def compile_single_prefill_module(
):
uri, path = gen_single_prefill_cu(*args)
return load_cuda_ops(
- uri, [path],
+ uri,
+ [path],
verbose=verbose,
)

@@ -68,7 +69,8 @@ def compile_batch_prefill_module(
):
uri, path = gen_batch_prefill_cu(*args)
return load_cuda_ops(
- uri, [path],
+ uri,
+ [path],
verbose=verbose,
)

@@ -125,6 +127,7 @@ def get_batch_prefill_module(*args):
_batch_prefill_modules[args] = compile_batch_prefill_module(*args)
return _batch_prefill_modules[args]

+
def single_prefill_with_kv_cache_with_jit_module(
jit_module: Any,
q: torch.Tensor,
@@ -137,7 +140,8 @@ def single_prefill_with_kv_cache_with_jit_module(
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
tmp = _get_cache_buf("single_prefill_with_kv_cache_tmp", 32 * 1024 * 1024, q.device)
out = jit_module.run(
- q, k, v, tmp, TensorLayout[kv_layout].value, window_left, return_lse, *args)
+ q, k, v, tmp, TensorLayout[kv_layout].value, window_left, return_lse, *args
+ )
return out if return_lse else out[0]


@@ -726,10 +730,14 @@ def plan(
"The length of paged_kv_indices exceeds the allocated buffer size."
)

- self._qo_indptr_buf.copy_(qo_indptr)
- self._paged_kv_indptr_buf.copy_(paged_kv_indptr)
- self._paged_kv_indices_buf[: len(paged_kv_indices)] = paged_kv_indices
- self._paged_kv_last_page_len_buf.copy_(paged_kv_last_page_len)
+ self._qo_indptr_buf.copy_(qo_indptr, non_blocking=True)
+ self._paged_kv_indptr_buf.copy_(paged_kv_indptr, non_blocking=True)
+ self._paged_kv_indices_buf[: len(paged_kv_indices)].copy_(
+     paged_kv_indices, non_blocking=True
+ )
+ self._paged_kv_last_page_len_buf.copy_(
+     paged_kv_last_page_len, non_blocking=True
+ )

if packed_custom_mask is not None:
if not torch.is_tensor(self._custom_mask_buf):
@@ -740,20 +748,31 @@
raise ValueError(
"qk_indptr_buf must be initialized with a torch.Tensor in cuda graph mode if we use custom mask in attention computation."
)
- self._custom_mask_buf[: len(packed_custom_mask)] = packed_custom_mask
+ self._custom_mask_buf[: len(packed_custom_mask)].copy_(
+     packed_custom_mask, non_blocking=True
+ )
# NOTE(Zihao): qk_indptr has the same length as qo_indptr
- self._qk_indptr_buf.copy_(qk_indptr)
+ self._qk_indptr_buf.copy_(qk_indptr, non_blocking=True)
else:
- self._qo_indptr_buf = qo_indptr.to(self.device)
- self._paged_kv_indptr_buf = paged_kv_indptr.to(self.device)
- self._paged_kv_indices_buf = paged_kv_indices.to(self.device)
- self._paged_kv_last_page_len_buf = paged_kv_last_page_len.to(self.device)
+ self._qo_indptr_buf = qo_indptr.to(self.device, non_blocking=True)
+ self._paged_kv_indptr_buf = paged_kv_indptr.to(
+     self.device, non_blocking=True
+ )
+ self._paged_kv_indices_buf = paged_kv_indices.to(
+     self.device, non_blocking=True
+ )
+ self._paged_kv_last_page_len_buf = paged_kv_last_page_len.to(
+     self.device, non_blocking=True
+ )
if packed_custom_mask is not None:
-     self._custom_mask_buf = packed_custom_mask.to(self.device)
-     self._qk_indptr_buf = qk_indptr.to(self.device)
+     self._custom_mask_buf = packed_custom_mask.to(
+         self.device, non_blocking=True
+     )
+     self._qk_indptr_buf = qk_indptr.to(self.device, non_blocking=True)

- qo_indptr = qo_indptr.to("cpu", non_blocking=True)
- paged_kv_indptr = paged_kv_indptr.to("cpu", non_blocking=True)
+ # NOTE(Zihao): only required if qo_indptr/paged_kv_indptr are device tensors
+ qo_indptr_host = qo_indptr.to("cpu", non_blocking=True)
+ paged_kv_indptr_host = paged_kv_indptr.to("cpu", non_blocking=True)

if packed_custom_mask is not None:
mask_mode = MaskMode.CUSTOM.value
@@ -781,8 +800,8 @@
self._float_workspace_buffer,
self._int_workspace_buffer,
self._pin_memory_int_workspace_buffer,
- qo_indptr,
- paged_kv_indptr,
+ qo_indptr_host,
+ paged_kv_indptr_host,
batch_size,
num_qo_heads,
num_kv_heads,
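In the CUDA-graph branch of `plan` above, buffers are updated with in-place `copy_(..., non_blocking=True)` rather than rebound with `.to(self.device)`. A short sketch of why that distinction matters under graph capture (buffer names are illustrative, not flashinfer internals):

```python
import torch

# Kernels captured in a CUDA graph hold raw device pointers, so parameter
# updates must land in the same allocation the graph was captured with.
qo_indptr_buf = torch.zeros(9, dtype=torch.int32, device="cuda")  # captured buffer

new_qo_indptr = torch.arange(9, dtype=torch.int32).pin_memory()
qo_indptr_buf.copy_(new_qo_indptr, non_blocking=True)  # in place: address unchanged
# qo_indptr_buf = new_qo_indptr.to("cuda")  # new tensor, new address: not graph-safe
```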
14 changes: 7 additions & 7 deletions python/flashinfer/sparse.py
@@ -257,7 +257,7 @@ def plan(
num_blocks_row = len(indptr) - 1
qo_indptr_host = R * torch.arange(num_blocks_row + 1, dtype=torch.int32)
qo_indptr_host[-1] = M
- qo_indptr = qo_indptr_host.to(indptr.device)
+ qo_indptr = qo_indptr_host.to(indptr.device, non_blocking=True)
if indices.max().item() * C > N:
raise ValueError("indices out of bound")
last_block_len = torch.full(
@@ -279,13 +279,13 @@
mask.contiguous().view(-1), qk_indptr, bitorder="little"
)

- self._qo_indptr = qo_indptr.to(self.device)
- self._paged_kv_indptr_buf = indptr.to(self.device)
- self._paged_kv_indices_buf = indices.to(self.device)
- self._paged_kv_last_page_len = last_block_len.to(self.device)
+ self._qo_indptr = qo_indptr.to(self.device, non_blocking=True)
+ self._paged_kv_indptr_buf = indptr.to(self.device, non_blocking=True)
+ self._paged_kv_indices_buf = indices.to(self.device, non_blocking=True)
+ self._paged_kv_last_page_len = last_block_len.to(self.device, non_blocking=True)
if packed_mask is not None:
- self._packed_mask_buf = packed_mask.to(self.device)
- self._qk_indptr_buf = qk_indptr.to(self.device)
+ self._packed_mask_buf = packed_mask.to(self.device, non_blocking=True)
+ self._qk_indptr_buf = qk_indptr.to(self.device, non_blocking=True)
mask_mode = MaskMode.CUSTOM.value
else:
self._packed_mask_buf = None
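The same pinning caveat applies in the host-to-device direction used throughout `sparse.py`: `non_blocking=True` only helps when the source tensor is pinned; from pageable memory the driver stages the transfer and the copy is effectively synchronous. A sketch:

```python
import torch

indices_pageable = torch.arange(1024, dtype=torch.int32)
indices_pinned = indices_pageable.pin_memory()  # page-locked copy of the data

# With a pinned source, cudaMemcpyAsync can run truly asynchronously and
# overlap with unrelated GPU work already queued on the stream.
indices_dev = indices_pinned.to("cuda", non_blocking=True)
```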