refactor: Break up _kernels into multiple modules #428

Merged · 2 commits · Aug 8, 2024
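The net effect of the diff below is that the single flashinfer_ops pybind11 extension no longer carries the decode and prefill bindings; those move into two new translation units, flashinfer_ops_decode.cu and flashinfer_ops_prefill.cu, each defining its own TORCH_EXTENSION_NAME module. The build configuration is not part of this diff, so the following is only a minimal sketch of how the split extensions might be declared with torch's CUDAExtension; the extension names and source lists are illustrative, not the project's actual setup.py.

# Illustrative sketch only -- not part of this PR and not flashinfer's actual setup.py.
# Shows how the pybind11 modules produced by this refactor might be built as
# separate torch extensions; module names and source lists are assumptions.
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="flashinfer",
    ext_modules=[
        # Remaining general ops (cascade merge, paged KV append, RoPE, packbits, GEMM, ...)
        CUDAExtension(
            name="flashinfer._kernels",            # placeholder module name
            sources=["csrc/flashinfer_ops.cu"],    # plus the remaining kernel sources
        ),
        # Decode-only ops, bound in the new flashinfer_ops_decode.cu
        CUDAExtension(
            name="flashinfer._decode_kernels",     # placeholder module name
            sources=["csrc/flashinfer_ops_decode.cu",
                     "csrc/batch_decode.cu"],      # plus the other decode kernel sources
        ),
        # Prefill-only ops, bound in the new flashinfer_ops_prefill.cu
        CUDAExtension(
            name="flashinfer._prefill_kernels",    # placeholder module name
            sources=["csrc/flashinfer_ops_prefill.cu",
                     "csrc/batch_prefill.cu"],     # plus the other prefill kernel sources
        ),
    ],
    cmdclass={"build_ext": BuildExtension},
)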
python/csrc/batch_decode.cu (1 addition, 1 deletion)
@@ -15,7 +15,7 @@
*/
#include <flashinfer/decode_attention_decl.cuh>

#include "flashinfer_ops.h"
#include "flashinfer_ops_decode.h"
#include "pytorch_extension_utils.h"

using namespace flashinfer;
python/csrc/batch_prefill.cu (1 addition, 1 deletion)
@@ -15,7 +15,7 @@
*/
#include <flashinfer/prefill_attention_decl.cuh>

#include "flashinfer_ops.h"
#include "flashinfer_ops_prefill.h"
#include "pytorch_extension_utils.h"

using namespace flashinfer;
python/csrc/flashinfer_ops.cu (0 additions, 37 deletions)
@@ -18,13 +18,6 @@
#include "flashinfer_ops.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("single_decode_with_kv_cache", &single_decode_with_kv_cache,
"Single-request decode with KV-Cache operator");
m.def("single_prefill_with_kv_cache", &single_prefill_with_kv_cache,
"Single-request prefill with KV-Cache operator, return logsumexp");
m.def(
"single_prefill_with_kv_cache_custom_mask", &single_prefill_with_kv_cache_custom_mask,
"Single-request prefill with KV-Cache operator, user defined custom mask, return logsumexp");
m.def("append_paged_kv_cache", &append_paged_kv_cache, "Append paged KV-Cache operator");
m.def("merge_state", &merge_state, "Merge two self-attention states");
m.def("merge_state_in_place", &merge_state_in_place,
@@ -50,36 +43,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("apply_llama31_rope", &apply_llama31_rope, "Apply Llama 3.1 style RoPE");
m.def("packbits", &packbits, "GPU packbits operator");
m.def("segment_packbits", &segment_packbits, "GPU segment packbits operator");
py::class_<BatchDecodeWithPagedKVCachePyTorchWrapper>(m,
"BatchDecodeWithPagedKVCachePyTorchWrapper")
.def(py::init<unsigned int, bool, unsigned int>())
.def("begin_forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward)
.def("end_forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::EndForward)
.def("is_cuda_graph_enabled", &BatchDecodeWithPagedKVCachePyTorchWrapper::IsCUDAGraphEnabled)
.def("update_page_locked_buffer_size",
&BatchDecodeWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize)
.def("forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::Forward);
py::class_<BatchPrefillWithPagedKVCachePyTorchWrapper>(
m, "BatchPrefillWithPagedKVCachePyTorchWrapper")
.def(py::init<unsigned int, bool>())
.def("begin_forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward)
.def("end_forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::EndForward)
.def("is_cuda_graph_enabled", &BatchPrefillWithPagedKVCachePyTorchWrapper::IsCUDAGraphEnabled)
.def("update_page_locked_buffer_size",
&BatchPrefillWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize)
.def("forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::Forward)
.def("forward_custom_mask", &BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCustomMask);
py::class_<BatchPrefillWithRaggedKVCachePyTorchWrapper>(
m, "BatchPrefillWithRaggedKVCachePyTorchWrapper")
.def(py::init<unsigned int, bool>())
.def("begin_forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward)
.def("end_forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::EndForward)
.def("is_cuda_graph_enabled",
&BatchPrefillWithRaggedKVCachePyTorchWrapper::IsCUDAGraphEnabled)
.def("update_page_locked_buffer_size",
&BatchPrefillWithRaggedKVCachePyTorchWrapper::UpdatePageLockedBufferSize)
.def("forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward)
.def("forward_custom_mask", &BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardCustomMask);
py::class_<CutlassSegmentGEMMPyTorchWrapper>(m, "CutlassSegmentGEMMPyTorchWrapper")
.def(py::init<torch::Tensor>())
.def("register_workspace", &CutlassSegmentGEMMPyTorchWrapper::RegisterWorkspaceBuffer)
python/csrc/flashinfer_ops.h (0 additions, 113 deletions)
@@ -16,29 +16,10 @@
#pragma once
#include <torch/extension.h>

#include <flashinfer/attention/handler.cuh>
#include <flashinfer/group_gemm/handler.cuh>
#include <flashinfer/layout.cuh>
#include <memory>

torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torch::Tensor v,
torch::Tensor tmp, unsigned int pos_encoding_mode,
unsigned int layout, int window_left,
float logits_soft_cap, float sm_scale, float rope_scale,
float rope_theta);

std::vector<torch::Tensor> single_prefill_with_kv_cache(
torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor tmp, bool causal,
unsigned int layout, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction,
int window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta,
bool return_lse);

std::vector<torch::Tensor> single_prefill_with_kv_cache_custom_mask(
torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor packed_custom_mask,
torch::Tensor tmp, unsigned int layout, unsigned int pos_encoding_mode,
bool allow_fp16_qk_reduction, int window_left, float logits_soft_cap, float sm_scale,
float rope_scale, float rope_theta, bool return_lse);

void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value,
torch::Tensor append_indptr, std::optional<torch::Tensor> paged_kv_cache,
std::optional<torch::Tensor> paged_k_cache,
@@ -106,100 +87,6 @@ torch::Tensor packbits(torch::Tensor x, const std::string& bitorder);
torch::Tensor segment_packbits(torch::Tensor x, torch::Tensor input_indptr,
torch::Tensor output_indptr, const std::string& bitorder);

class BatchDecodeWithPagedKVCachePyTorchWrapper {
public:
void BeginForward(torch::Tensor workspace_buffer, torch::Tensor indptr,
torch::Tensor last_page_len, unsigned int batch_size, unsigned int num_qo_heads,
unsigned int num_kv_heads, unsigned int head_dim, unsigned int page_size,
unsigned int pos_encoding_mode, float logits_soft_cap,
torch::Tensor empty_q_data, torch::Tensor empty_kv_data);
void EndForward();
void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes);
bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); }
std::vector<torch::Tensor> Forward(torch::Tensor q, std::optional<torch::Tensor> paged_kv_cache,
std::optional<torch::Tensor> paged_k_cache,
std::optional<torch::Tensor> paged_v_cache,
torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
torch::Tensor paged_kv_last_page_len,
unsigned int pos_encoding_mode, int window_left,
float logits_soft_cap, float sm_scale, float rope_scale,
float rope_theta, bool return_lse);
BatchDecodeWithPagedKVCachePyTorchWrapper(
std::shared_ptr<flashinfer::BatchDecodeHandler> handler_ptr, flashinfer::QKVLayout kv_layout)
: handler_(handler_ptr), kv_layout_(kv_layout) {}
BatchDecodeWithPagedKVCachePyTorchWrapper(unsigned int layout, bool enable_cuda_graph,
unsigned int fixed_batch_size)
: kv_layout_(flashinfer::QKVLayout(layout)),
handler_(std::make_shared<flashinfer::BatchDecodeHandler>(enable_cuda_graph,
fixed_batch_size)) {}

protected:
std::shared_ptr<flashinfer::BatchDecodeHandler> handler_;
flashinfer::QKVLayout kv_layout_;
};

class BatchPrefillWithPagedKVCachePyTorchWrapper {
public:
void BeginForward(torch::Tensor workspace_buffer, torch::Tensor qo_indptr,
torch::Tensor page_kv_indptr, unsigned int batch_size,
unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim,
unsigned page_size, torch::Tensor empty_q_data);
void EndForward();
bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); }
void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes);
std::vector<torch::Tensor> Forward(torch::Tensor q, torch::Tensor qo_indptr,
std::optional<torch::Tensor> paged_kv_cache,
std::optional<torch::Tensor> paged_k_cache,
std::optional<torch::Tensor> paged_v_cache,
torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
torch::Tensor paged_kv_last_page_len, bool causal,
unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction,
int window_left, float logits_soft_cap, float sm_scale,
float rope_scale, float rope_theta, bool return_lse);
std::vector<torch::Tensor> ForwardCustomMask(
torch::Tensor q, torch::Tensor qo_indptr, std::optional<torch::Tensor> paged_kv_cache,
std::optional<torch::Tensor> paged_k_cache, std::optional<torch::Tensor> paged_v_cache,
torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
torch::Tensor paged_kv_last_page_len, torch::Tensor packed_custom_mask,
torch::Tensor qk_indptr, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction,
int window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta,
bool return_lse);
BatchPrefillWithPagedKVCachePyTorchWrapper(unsigned int layout, bool enable_cuda_graph)
: kv_layout_(flashinfer::QKVLayout(layout)),
handler_(std::make_shared<flashinfer::BatchPrefillHandler>(enable_cuda_graph)) {}

private:
std::shared_ptr<flashinfer::BatchPrefillHandler> handler_;
flashinfer::QKVLayout kv_layout_;
};

class BatchPrefillWithRaggedKVCachePyTorchWrapper {
public:
void BeginForward(torch::Tensor workspace_buffer, torch::Tensor qo_indptr,
torch::Tensor kv_indptr, unsigned int batch_size, unsigned int num_qo_heads,
unsigned int num_kv_heads, unsigned int head_dim, torch::Tensor empty_q_data);
void EndForward();
bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); }
void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes);
std::vector<torch::Tensor> Forward(torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k,
torch::Tensor v, torch::Tensor kv_indptr, bool causal,
unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction,
int window_left, float logits_soft_cap, float sm_scale,
float rope_scale, float rope_theta, bool return_lse);
std::vector<torch::Tensor> ForwardCustomMask(
torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k, torch::Tensor v,
torch::Tensor kv_indptr, torch::Tensor packed_custom_mask, torch::Tensor qk_indptr,
unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, int window_left,
float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, bool return_lse);
BatchPrefillWithRaggedKVCachePyTorchWrapper(unsigned int layout, bool enable_cuda_graph)
: kv_layout_(flashinfer::QKVLayout(layout)),
handler_(std::make_shared<flashinfer::BatchPrefillHandler>(enable_cuda_graph)) {}

private:
std::shared_ptr<flashinfer::BatchPrefillHandler> handler_;
flashinfer::QKVLayout kv_layout_;
};

class CutlassSegmentGEMMPyTorchWrapper {
public:
void RegisterWorkspaceBuffer(torch::Tensor workspace_buffer);
python/csrc/flashinfer_ops_decode.cu (new file, 32 additions)
@@ -0,0 +1,32 @@
/*
* Copyright (c) 2023 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/extension.h>

#include "flashinfer_ops_decode.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("single_decode_with_kv_cache", &single_decode_with_kv_cache,
"Single-request decode with KV-Cache operator");
py::class_<BatchDecodeWithPagedKVCachePyTorchWrapper>(m,
"BatchDecodeWithPagedKVCachePyTorchWrapper")
.def(py::init<unsigned int, bool, unsigned int>())
.def("begin_forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward)
.def("end_forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::EndForward)
.def("is_cuda_graph_enabled", &BatchDecodeWithPagedKVCachePyTorchWrapper::IsCUDAGraphEnabled)
.def("update_page_locked_buffer_size",
&BatchDecodeWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize)
.def("forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::Forward);
}
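For orientation, a hedged sketch of how the decode module's free function could be called from Python once this extension is built. The import name, the scratch-buffer size, and the enum values for pos_encoding_mode and layout are assumptions rather than part of this diff; only the argument order follows the declaration in flashinfer_ops_decode.h below.

# Illustrative sketch only; `_decode_kernels`, the enum values, and the scratch
# buffer size are assumptions, not flashinfer's documented Python API.
import torch
from flashinfer import _decode_kernels  # placeholder import name (build-dependent)

num_qo_heads, num_kv_heads, head_dim, kv_len = 32, 32, 128, 1024
q = torch.randn(num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(kv_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")
v = torch.randn(kv_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")
tmp = torch.empty(16 * 1024 * 1024, dtype=torch.float32, device="cuda")  # scratch (size assumed)

# Argument order mirrors the declaration in flashinfer_ops_decode.h:
# (q, k, v, tmp, pos_encoding_mode, layout, window_left,
#  logits_soft_cap, sm_scale, rope_scale, rope_theta)
out = _decode_kernels.single_decode_with_kv_cache(
    q, k, v, tmp,
    0,                  # pos_encoding_mode (0 assumed to mean "none")
    0,                  # layout (0 assumed to mean NHD)
    -1,                 # window_left: -1 assumed to disable sliding-window attention
    0.0,                # logits_soft_cap: 0 assumed to disable the cap
    head_dim ** -0.5,   # sm_scale
    1.0,                # rope_scale
    1e4,                # rope_theta
)
print(out.shape)  # expected (num_qo_heads, head_dim)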
python/csrc/flashinfer_ops_decode.h (new file, 59 additions)
@@ -0,0 +1,59 @@
/*
* Copyright (c) 2023 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <torch/extension.h>

#include <flashinfer/attention/handler.cuh>
#include <flashinfer/layout.cuh>
#include <memory>

torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torch::Tensor v,
torch::Tensor tmp, unsigned int pos_encoding_mode,
unsigned int layout, int window_left,
float logits_soft_cap, float sm_scale, float rope_scale,
float rope_theta);

class BatchDecodeWithPagedKVCachePyTorchWrapper {
public:
void BeginForward(torch::Tensor workspace_buffer, torch::Tensor indptr,
torch::Tensor last_page_len, unsigned int batch_size, unsigned int num_qo_heads,
unsigned int num_kv_heads, unsigned int head_dim, unsigned int page_size,
unsigned int pos_encoding_mode, float logits_soft_cap,
torch::Tensor empty_q_data, torch::Tensor empty_kv_data);
void EndForward();
void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes);
bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); }
std::vector<torch::Tensor> Forward(torch::Tensor q, std::optional<torch::Tensor> paged_kv_cache,
std::optional<torch::Tensor> paged_k_cache,
std::optional<torch::Tensor> paged_v_cache,
torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
torch::Tensor paged_kv_last_page_len,
unsigned int pos_encoding_mode, int window_left,
float logits_soft_cap, float sm_scale, float rope_scale,
float rope_theta, bool return_lse);
BatchDecodeWithPagedKVCachePyTorchWrapper(
std::shared_ptr<flashinfer::BatchDecodeHandler> handler_ptr, flashinfer::QKVLayout kv_layout)
: handler_(handler_ptr), kv_layout_(kv_layout) {}
BatchDecodeWithPagedKVCachePyTorchWrapper(unsigned int layout, bool enable_cuda_graph,
unsigned int fixed_batch_size)
: kv_layout_(flashinfer::QKVLayout(layout)),
handler_(std::make_shared<flashinfer::BatchDecodeHandler>(enable_cuda_graph,
fixed_batch_size)) {}

protected:
std::shared_ptr<flashinfer::BatchDecodeHandler> handler_;
flashinfer::QKVLayout kv_layout_;
};
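The class above keeps the plan/run/teardown split of the handler-based API. Below is a hedged Python-side sketch of that lifecycle for the bound wrapper; the module name, page-table contents, enum values, and KV-cache layout are placeholders chosen to match the argument order declared here, not verified flashinfer usage.

# Illustrative lifecycle sketch; names, shapes, and enum values are assumptions.
import torch
from flashinfer import _decode_kernels  # placeholder import name

batch_size, num_qo_heads, num_kv_heads = 4, 32, 32
head_dim, page_size, num_pages = 128, 16, 4

wrapper = _decode_kernels.BatchDecodeWithPagedKVCachePyTorchWrapper(
    0,           # kv layout (0 assumed to mean NHD)
    False,       # enable_cuda_graph
    batch_size,  # fixed_batch_size (only meaningful when CUDA graphs are enabled)
)

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")  # size assumed
kv_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda")    # one page per request
kv_indices = torch.arange(num_pages, dtype=torch.int32, device="cuda")
last_page_len = torch.full((batch_size,), page_size, dtype=torch.int32, device="cuda")
empty_q = torch.empty(0, dtype=torch.float16, device="cuda")   # assumed to carry dtype info for dispatch
empty_kv = torch.empty(0, dtype=torch.float16, device="cuda")

wrapper.begin_forward(workspace, kv_indptr, last_page_len,
                      batch_size, num_qo_heads, num_kv_heads, head_dim, page_size,
                      0,    # pos_encoding_mode
                      0.0,  # logits_soft_cap
                      empty_q, empty_kv)

q = torch.randn(batch_size, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
paged_kv = torch.randn(num_pages, 2, page_size, num_kv_heads, head_dim,
                       dtype=torch.float16, device="cuda")  # combined KV cache (layout assumed)

outputs = wrapper.forward(q, paged_kv, None, None,
                          kv_indptr, kv_indices, last_page_len,
                          0,                 # pos_encoding_mode
                          -1,                # window_left
                          0.0,               # logits_soft_cap
                          head_dim ** -0.5,  # sm_scale
                          1.0, 1e4,          # rope_scale, rope_theta
                          False)             # return_lse
out = outputs[0]  # assuming a single output tensor when return_lse is False
wrapper.end_forward()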
python/csrc/flashinfer_ops_prefill.cu (new file, 47 additions)
@@ -0,0 +1,47 @@
/*
* Copyright (c) 2023 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/extension.h>

#include "flashinfer_ops_prefill.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("single_prefill_with_kv_cache", &single_prefill_with_kv_cache,
"Single-request prefill with KV-Cache operator, return logsumexp");
m.def(
"single_prefill_with_kv_cache_custom_mask", &single_prefill_with_kv_cache_custom_mask,
"Single-request prefill with KV-Cache operator, user defined custom mask, return logsumexp");
py::class_<BatchPrefillWithPagedKVCachePyTorchWrapper>(
m, "BatchPrefillWithPagedKVCachePyTorchWrapper")
.def(py::init<unsigned int, bool>())
.def("begin_forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward)
.def("end_forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::EndForward)
.def("is_cuda_graph_enabled", &BatchPrefillWithPagedKVCachePyTorchWrapper::IsCUDAGraphEnabled)
.def("update_page_locked_buffer_size",
&BatchPrefillWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize)
.def("forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::Forward)
.def("forward_custom_mask", &BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCustomMask);
py::class_<BatchPrefillWithRaggedKVCachePyTorchWrapper>(
m, "BatchPrefillWithRaggedKVCachePyTorchWrapper")
.def(py::init<unsigned int, bool>())
.def("begin_forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward)
.def("end_forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::EndForward)
.def("is_cuda_graph_enabled",
&BatchPrefillWithRaggedKVCachePyTorchWrapper::IsCUDAGraphEnabled)
.def("update_page_locked_buffer_size",
&BatchPrefillWithRaggedKVCachePyTorchWrapper::UpdatePageLockedBufferSize)
.def("forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward)
.def("forward_custom_mask", &BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardCustomMask);
}
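And a matching hedged sketch for the prefill module's single-request entry point; again the import name, enum values, and scratch-buffer size are assumptions, with the argument order following the single_prefill_with_kv_cache declaration removed from flashinfer_ops.h above.

# Illustrative sketch only; `_prefill_kernels` and the flag values are assumptions.
import torch
from flashinfer import _prefill_kernels  # placeholder import name

qo_len, kv_len, num_heads, head_dim = 512, 512, 32, 128
q = torch.randn(qo_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(kv_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
v = torch.randn(kv_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
tmp = torch.empty(16 * 1024 * 1024, dtype=torch.float32, device="cuda")  # scratch (size assumed)

# (q, k, v, tmp, causal, layout, pos_encoding_mode, allow_fp16_qk_reduction,
#  window_left, logits_soft_cap, sm_scale, rope_scale, rope_theta, return_lse)
out, lse = _prefill_kernels.single_prefill_with_kv_cache(
    q, k, v, tmp,
    True,               # causal
    0,                  # layout (0 assumed to mean NHD)
    0,                  # pos_encoding_mode (0 assumed to mean "none")
    False,              # allow_fp16_qk_reduction
    -1,                 # window_left: no sliding window (assumed)
    0.0,                # logits_soft_cap
    head_dim ** -0.5,   # sm_scale
    1.0,                # rope_scale
    1e4,                # rope_theta
    True,               # return_lse -> assume the binding returns [output, logsumexp]
)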