feat: support cuda graph for batched multi-query (prefill/append) attention #275

Merged · 6 commits · Jun 2, 2024
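
Background for the change set below: CUDA graph capture records a fixed sequence of kernel launches and replays it, so every launch in the captured region must keep its grid dimensions and device pointers stable across replays. That is why the reworked handlers pre-allocate buffers for a maximum batch size instead of sizing them per call. A minimal capture/replay sketch follows; the kernel is a stand-in, not FlashInfer's attention kernel, and `cudaGraphInstantiate` uses the CUDA 12 signature:

```cuda
#include <cuda_runtime.h>

// Stand-in kernel: in the real PR the captured work is the batched decode /
// prefill attention launch issued by the wrapper's Forward call.
__global__ void DummyAttentionKernel(float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = out[i] * 0.5f + 1.0f;
}

int main() {
  const int n = 1024;
  float* buf = nullptr;
  cudaMalloc(&buf, n * sizeof(float));  // fixed buffer, allocated before capture

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // Capture: launches on `stream` are recorded into a graph, not executed.
  // Planning work (the handlers' BeginForward) stays outside this region.
  cudaGraph_t graph;
  cudaGraphExec_t graph_exec;
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  DummyAttentionKernel<<<(n + 255) / 256, 256, 0, stream>>>(buf, n);
  cudaStreamEndCapture(stream, &graph);
  cudaGraphInstantiate(&graph_exec, graph, 0);  // CUDA 12 signature

  // Replay: the recorded launch sequence runs each step with minimal CPU
  // overhead; inputs are refreshed in place in the pre-allocated buffers.
  for (int step = 0; step < 8; ++step) {
    cudaGraphLaunch(graph_exec, stream);
  }
  cudaStreamSynchronize(stream);

  cudaGraphExecDestroy(graph_exec);
  cudaGraphDestroy(graph);
  cudaStreamDestroy(stream);
  cudaFree(buf);
  return 0;
}
```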
1 change: 0 additions & 1 deletion include/flashinfer/attention/decode.cuh
@@ -36,7 +36,6 @@
#include "../utils.cuh"
#include "../vec_dtypes.cuh"
#include "cascade.cuh"
#include "handler.cuh"
#include "state.cuh"

namespace flashinfer {
209 changes: 107 additions & 102 deletions include/flashinfer/attention/handler.cuh

Large diffs are not rendered by default.
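
The handler.cuh rewrite itself is collapsed on this page, but the call sites in python/csrc/batch_decode.cu and python/csrc/flashinfer_ops.h below pin down its new shape: a single BatchDecodeHandler constructed with (max_batch_size, enable_cuda_graph), an IsCUDAGraphEnabled() query, and one BeginForwardDispatched template serving both modes. A rough reconstruction under those assumptions (member names and the elided template parameters are guesses, not the actual header):

```cpp
#include <cstddef>
#include <cstdint>
#include <cuda_runtime.h>

namespace flashinfer {

// Sketch only: inferred from the visible call sites, not the real handler.cuh.
class BatchDecodeHandler {
 public:
  BatchDecodeHandler(uint32_t max_batch_size, bool enable_cuda_graph)
      : max_batch_size_(max_batch_size), enable_cuda_graph_(enable_cuda_graph) {}

  bool IsCUDAGraphEnabled() const { return enable_cuda_graph_; }

  // Single entry point for both eager and graph-capturable execution. When
  // CUDA graphs are enabled, the scheduling metadata must fit into buffers
  // sized for max_batch_size_ so kernel launch parameters stay fixed.
  template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, /* page storage, kv layout,
            pos encoding mode, */ typename DTypeIn, typename DTypeOut,
            typename IdType>
  cudaError_t BeginForwardDispatched(void* workspace_buffer,
                                     size_t workspace_size_in_bytes,
                                     IdType* indptr, IdType* last_page_len,
                                     uint32_t batch_size, uint32_t num_qo_heads,
                                     uint32_t page_size);

 private:
  uint32_t max_batch_size_;
  bool enable_cuda_graph_;
};

}  // namespace flashinfer
```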

2 changes: 0 additions & 2 deletions include/flashinfer/attention/prefill.cuh
@@ -35,9 +35,7 @@
#include "../pos_enc.cuh"
#include "../utils.cuh"
#include "cascade.cuh"
#include "handler.cuh"
#include "mask.cuh"
#include "state.cuh"

namespace flashinfer {

1 change: 0 additions & 1 deletion include/flashinfer/sampling.cuh
@@ -679,7 +679,6 @@ __global__ void ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token
auto& temp_storage = reinterpret_cast<
SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem);

bool rejected = false;
uint32_t pos = 0;
for (pos = 0; pos < num_speculative_tokens; ++pos) {
IdType draft_id = draft_token_ids[row_idx * num_speculative_tokens + pos];
74 changes: 22 additions & 52 deletions python/csrc/batch_decode.cu
@@ -141,32 +141,17 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward(
return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] {
return DISPATCH_pos_encoding_mode(
PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] {
if (handler_->IsCUDAGraphMode()) {
// NOTE(Zihao): use runtime dispatch because template function is not virtual
auto cuda_graph_handler_ =
dynamic_cast<CUDAGraphBatchDecodeHandler*>(handler_.get());
cudaError_t status = cuda_graph_handler_->CUDAGraphBeginForwardDispatched<
GROUP_SIZE, HEAD_DIM, PageStorage::kIndices, KV_LAYOUT, POS_ENCODING_MODE,
c_type, nv_half, int32_t>(static_cast<void*>(workspace_buffer.data_ptr()),
workspace_size_in_bytes,
static_cast<int32_t*>(indptr.data_ptr()),
static_cast<int32_t*>(last_page_len.data_ptr()),
batch_size, num_qo_heads, page_size);
TORCH_CHECK(status == cudaSuccess,
"BatchDecodeWithPagedKVCache (CUDAGraph Mode) failed with error ",
cudaGetErrorString(status));
} else {
cudaError_t status = handler_->BeginForwardDispatched<
GROUP_SIZE, HEAD_DIM, PageStorage::kIndices, KV_LAYOUT, POS_ENCODING_MODE,
c_type, nv_half, int32_t>(static_cast<void*>(workspace_buffer.data_ptr()),
workspace_size_in_bytes,
static_cast<int32_t*>(indptr.data_ptr()),
static_cast<int32_t*>(last_page_len.data_ptr()),
batch_size, num_qo_heads, page_size);
TORCH_CHECK(status == cudaSuccess,
"BatchDecodeWithPagedKVCache failed with error ",
cudaGetErrorString(status));
}
cudaError_t status =
handler_->BeginForwardDispatched<GROUP_SIZE, HEAD_DIM, PageStorage::kIndices,
KV_LAYOUT, POS_ENCODING_MODE, c_type,
nv_half, int32_t>(
static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
static_cast<int32_t*>(indptr.data_ptr()),
static_cast<int32_t*>(last_page_len.data_ptr()), batch_size, num_qo_heads,
page_size);
TORCH_CHECK(status == cudaSuccess,
"BatchDecodeWithPagedKVCache failed with error ",
cudaGetErrorString(status));
return true;
});
});
@@ -180,32 +165,17 @@
return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] {
return DISPATCH_pos_encoding_mode(
PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] {
if (handler_->IsCUDAGraphMode()) {
// NOTE(Zihao): use runtime dispatch because template function is not virtual
auto cuda_graph_handler_ =
dynamic_cast<CUDAGraphBatchDecodeHandler*>(handler_.get());
auto status = cuda_graph_handler_->CUDAGraphBeginForwardDispatched<
GROUP_SIZE, HEAD_DIM, PageStorage::kIndices, KV_LAYOUT, POS_ENCODING_MODE,
c_type, c_type, int32_t>(static_cast<void*>(workspace_buffer.data_ptr()),
workspace_size_in_bytes,
static_cast<int32_t*>(indptr.data_ptr()),
static_cast<int32_t*>(last_page_len.data_ptr()),
batch_size, num_qo_heads, page_size);
TORCH_CHECK(status == cudaSuccess,
"BatchDecodeWithPagedKVCache (CUDAGraph Mode) failed with error ",
cudaGetErrorString(status));
} else {
cudaError_t status = handler_->BeginForwardDispatched<
GROUP_SIZE, HEAD_DIM, PageStorage::kIndices, KV_LAYOUT, POS_ENCODING_MODE,
c_type, c_type, int32_t>(static_cast<void*>(workspace_buffer.data_ptr()),
workspace_size_in_bytes,
static_cast<int32_t*>(indptr.data_ptr()),
static_cast<int32_t*>(last_page_len.data_ptr()),
batch_size, num_qo_heads, page_size);
TORCH_CHECK(status == cudaSuccess,
"BatchDecodeWithPagedKVCache failed with error ",
cudaGetErrorString(status));
}
cudaError_t status =
handler_->BeginForwardDispatched<GROUP_SIZE, HEAD_DIM, PageStorage::kIndices,
KV_LAYOUT, POS_ENCODING_MODE, c_type, c_type,
int32_t>(
static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
static_cast<int32_t*>(indptr.data_ptr()),
static_cast<int32_t*>(last_page_len.data_ptr()), batch_size, num_qo_heads,
page_size);
TORCH_CHECK(status == cudaSuccess,
"BatchDecodeWithPagedKVCache failed with error ",
cudaGetErrorString(status));
return true;
});
});
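
The net effect of the batch_decode.cu change above: the CUDA-graph decision is no longer encoded in the handler's dynamic type (CUDAGraphBatchDecodeHandler plus a dynamic_cast on every BeginForward) but in a constructor flag, leaving one dispatch path. A small construction-site sketch, consistent with the flashinfer_ops.h diff further down (the include path assumes building against the repo's include/ directory):

```cpp
#include <memory>

#include <flashinfer/attention/handler.cuh>

// Sketch of the construction-site change. Previously graph mode required
// constructing flashinfer::CUDAGraphBatchDecodeHandler and downcasting via
// dynamic_cast before calling CUDAGraphBeginForwardDispatched; now one handler
// type covers both modes.
std::shared_ptr<flashinfer::BatchDecodeHandler> MakeDecodeHandler(
    unsigned int max_batch_size, bool enable_cuda_graph) {
  return std::make_shared<flashinfer::BatchDecodeHandler>(max_batch_size,
                                                          enable_cuda_graph);
}
```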
18 changes: 7 additions & 11 deletions python/csrc/flashinfer_ops.cu
@@ -44,34 +44,30 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("rmsnorm", &rmsnorm, "Root mean square normalization");
py::class_<BatchDecodeWithPagedKVCachePyTorchWrapper>(m,
"BatchDecodeWithPagedKVCachePyTorchWrapper")
.def(py::init<unsigned int, unsigned int>())
.def(py::init<unsigned int, unsigned int, bool>())
.def("begin_forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward)
.def("end_forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::EndForward)
.def("is_cuda_graph_enabled", &BatchDecodeWithPagedKVCachePyTorchWrapper::IsCUDAGraphEnabled)
.def("update_page_locked_buffer_size",
&BatchDecodeWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize)
.def("forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::Forward);
py::class_<CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper>(
m, "CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper")
.def(py::init<unsigned int, unsigned int>())
.def("begin_forward", &CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward)
.def("end_forward", &CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper::EndForward)
.def("update_page_locked_buffer_size",
&CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize)
.def("forward", &CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper::Forward);
py::class_<BatchPrefillWithPagedKVCachePyTorchWrapper>(
m, "BatchPrefillWithPagedKVCachePyTorchWrapper")
.def(py::init<unsigned int, unsigned int>())
.def(py::init<unsigned int, unsigned int, bool>())
.def("begin_forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward)
.def("end_forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::EndForward)
.def("is_cuda_graph_enabled", &BatchPrefillWithPagedKVCachePyTorchWrapper::IsCUDAGraphEnabled)
.def("update_page_locked_buffer_size",
&BatchPrefillWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize)
.def("forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::Forward)
.def("forward_custom_mask", &BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCustomMask);
py::class_<BatchPrefillWithRaggedKVCachePyTorchWrapper>(
m, "BatchPrefillWithRaggedKVCachePyTorchWrapper")
.def(py::init<unsigned int, unsigned int>())
.def(py::init<unsigned int, unsigned int, bool>())
.def("begin_forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward)
.def("end_forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::EndForward)
.def("is_cuda_graph_enabled",
&BatchPrefillWithRaggedKVCachePyTorchWrapper::IsCUDAGraphEnabled)
.def("update_page_locked_buffer_size",
&BatchPrefillWithRaggedKVCachePyTorchWrapper::UpdatePageLockedBufferSize)
.def("forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward)
32 changes: 15 additions & 17 deletions python/csrc/flashinfer_ops.h
@@ -84,6 +84,7 @@ class BatchDecodeWithPagedKVCachePyTorchWrapper {
unsigned int pos_encoding_mode, torch::Tensor empty_data);
void EndForward();
void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes);
bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); }
std::vector<torch::Tensor> Forward(torch::Tensor q, torch::Tensor paged_kv_data,
torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
torch::Tensor paged_kv_last_page_len,
@@ -92,32 +93,24 @@
BatchDecodeWithPagedKVCachePyTorchWrapper(
std::shared_ptr<flashinfer::BatchDecodeHandler> handler_ptr, flashinfer::QKVLayout kv_layout)
: handler_(handler_ptr), kv_layout_(kv_layout) {}
BatchDecodeWithPagedKVCachePyTorchWrapper(unsigned int layout,
unsigned int max_workspace_size_in_bytes)
BatchDecodeWithPagedKVCachePyTorchWrapper(unsigned int layout, unsigned int max_batch_size,
bool enable_cuda_graph)
: kv_layout_(flashinfer::QKVLayout(layout)),
handler_(std::make_shared<flashinfer::BatchDecodeHandler>(max_workspace_size_in_bytes)) {}
handler_(
std::make_shared<flashinfer::BatchDecodeHandler>(max_batch_size, enable_cuda_graph)) {}

protected:
std::shared_ptr<flashinfer::BatchDecodeHandler> handler_;
flashinfer::QKVLayout kv_layout_;
};

class CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper
: public BatchDecodeWithPagedKVCachePyTorchWrapper {
public:
CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper(unsigned int layout,
unsigned int max_batch_size)
: BatchDecodeWithPagedKVCachePyTorchWrapper(
std::make_shared<flashinfer::CUDAGraphBatchDecodeHandler>(max_batch_size),
flashinfer::QKVLayout(layout)) {}
};

class BatchPrefillWithPagedKVCachePyTorchWrapper {
public:
void BeginForward(torch::Tensor workspace_buffer, torch::Tensor qo_indptr,
unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
unsigned int head_dim);
void EndForward();
bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); }
void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes);
std::vector<torch::Tensor> Forward(torch::Tensor q, torch::Tensor qo_indptr,
torch::Tensor paged_kv_data, torch::Tensor paged_kv_indptr,
@@ -133,9 +126,11 @@ class BatchPrefillWithPagedKVCachePyTorchWrapper {
unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float sm_scale,
float rope_scale, float rope_theta, bool return_lse);
BatchPrefillWithPagedKVCachePyTorchWrapper(unsigned int layout,
unsigned int max_workspace_size_in_bytes)
unsigned int max_workspace_size_in_bytes,
bool enable_cuda_graph)
: kv_layout_(flashinfer::QKVLayout(layout)),
handler_(std::make_shared<flashinfer::BatchPrefillHandler>(max_workspace_size_in_bytes)) {}
handler_(std::make_shared<flashinfer::BatchPrefillHandler>(max_workspace_size_in_bytes,
enable_cuda_graph)) {}

private:
std::shared_ptr<flashinfer::BatchPrefillHandler> handler_;
@@ -148,6 +143,7 @@ class BatchPrefillWithRaggedKVCachePyTorchWrapper {
unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
unsigned int head_dim);
void EndForward();
bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); }
void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes);
std::vector<torch::Tensor> Forward(torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k,
torch::Tensor v, torch::Tensor kv_indptr, bool causal,
@@ -162,9 +158,11 @@
bool allow_fp16_qk_reduction, float sm_scale,
float rope_scale, float rope_theta, bool return_lse);
BatchPrefillWithRaggedKVCachePyTorchWrapper(unsigned int layout,
unsigned int max_workspace_size_in_bytes)
unsigned int max_workspace_size_in_bytes,
bool enable_cuda_graph)
: kv_layout_(flashinfer::QKVLayout(layout)),
handler_(std::make_shared<flashinfer::BatchPrefillHandler>(max_workspace_size_in_bytes)) {}
handler_(std::make_shared<flashinfer::BatchPrefillHandler>(max_workspace_size_in_bytes,
enable_cuda_graph)) {}

private:
std::shared_ptr<flashinfer::BatchPrefillHandler> handler_;
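
With the header changes above, every wrapper fixes its CUDA-graph setting at construction time and exposes it through IsCUDAGraphEnabled() (surfaced to Python as is_cuda_graph_enabled). A short usage sketch against the updated constructors; the layout value and workspace size below are placeholders, not values taken from the PR:

```cpp
#include "flashinfer_ops.h"

// Placeholder values: layout 0 and a 128 MiB prefill workspace are illustrative.
void BuildWrappers() {
  BatchPrefillWithPagedKVCachePyTorchWrapper prefill_wrapper(
      /*layout=*/0, /*max_workspace_size_in_bytes=*/128u * 1024u * 1024u,
      /*enable_cuda_graph=*/true);

  BatchDecodeWithPagedKVCachePyTorchWrapper decode_wrapper(
      /*layout=*/0, /*max_batch_size=*/256, /*enable_cuda_graph=*/true);

  // Callers (e.g. the Python wrappers) can branch on the new query to decide
  // whether buffers must be kept at their graph-capture sizes.
  if (prefill_wrapper.IsCUDAGraphEnabled() && decode_wrapper.IsCUDAGraphEnabled()) {
    // graph-capturable path
  }
}
```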