perf: speedup jit compilation of prefill attention kernels #632

Merged: 1 commit merged on Nov 24, 2024
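This change splits the JIT prefill attention templates so that each mask mode (MaskMode::kNone, MaskMode::kCausal, MaskMode::kCustom) is instantiated in its own small .cu translation unit, while the run/pybind sources only forward-declare the dispatched kernels instead of pulling in the full prefill.cuh instantiation machinery. Smaller, independent translation units keep per-file compile time down and allow the JIT build to compile them concurrently. The sketch below is illustrative only — the nvcc invocation and the executor are assumptions, not FlashInfer's actual build code — and shows how the per-suffix sources generated from these templates could be compiled in parallel:

# Minimal sketch, assuming the generated sources follow the suffix lists in this
# diff (e.g. "<uri>_ragged_kernel_mask_0.cu"). Not FlashInfer's real build path.
import subprocess
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path


def compile_one(src: Path) -> Path:
    # Compile one generated source into an object file; flags are placeholders.
    obj = src.with_suffix(".o")
    subprocess.run(
        ["nvcc", "-O3", "-std=c++17", "-c", str(src), "-o", str(obj)], check=True
    )
    return obj


def compile_generated_sources(gen_dir: Path, uri: str, suffixes: list[str]) -> list[Path]:
    # One small translation unit per suffix, so the expensive kernel
    # instantiations (one mask mode each) build concurrently.
    sources = [gen_dir / f"{uri}{suffix}" for suffix in suffixes]
    with ThreadPoolExecutor() as pool:
        return list(pool.map(compile_one, sources))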
129 changes: 123 additions & 6 deletions python/flashinfer/jit/batch_prefill_templ.py
@@ -14,15 +14,102 @@
limitations under the License.
"""

import itertools

batch_prefill_suffix = [
"_plan.cu",
*[f"_ragged_kernel_mask_{mask_mode}.cu" for mask_mode in [0, 1, 2]],
"_ragged_run.cu",
*[f"_paged_kernel_mask_{mask_mode}.cu" for mask_mode in [0, 1, 2]],
"_paged_run.cu",
"_pybind.cc",
]


def ragged_prefill_inst_templ(mask_mode: str) -> str:
return (
r"""#include <flashinfer/attention/prefill.cuh>
#include <flashinfer/attention/prefill_params.cuh>
#include <flashinfer/attention/variants.cuh>

namespace flashinfer {

{% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %}
using RaggedParamsT = BatchPrefillRaggedParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}, {{ dtype_idx }}>;
constexpr bool use_custom_mask = """
+ mask_mode
+ r""" == MaskMode::kCustom;
using RaggedAttentionVariant = ComposedAttention<RaggedParamsT, get_variant_code(use_custom_mask, {{ use_sliding_window }}, {{ use_logits_soft_cap }}, {{ use_alibi }})>;

template
cudaError_t BatchPrefillWithRaggedKVCacheDispatched</*cta_tile_q=*/16, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, """
+ mask_mode
+ r""", RaggedAttentionVariant>(
typename RaggedAttentionVariant::ParamsT params,
typename RaggedAttentionVariant::DTypeO* tmp_v,
float* tmp_s, cudaStream_t stream);

template
cudaError_t BatchPrefillWithRaggedKVCacheDispatched</*cta_tile_q=*/64, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, """
+ mask_mode
+ r""", RaggedAttentionVariant>(
typename RaggedAttentionVariant::ParamsT params,
typename RaggedAttentionVariant::DTypeO* tmp_v,
float* tmp_s, cudaStream_t stream);

template
cudaError_t BatchPrefillWithRaggedKVCacheDispatched</*cta_tile_q=*/128, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, """
+ mask_mode
+ r""", RaggedAttentionVariant>(
typename RaggedAttentionVariant::ParamsT params,
typename RaggedAttentionVariant::DTypeO* tmp_v,
float* tmp_s, cudaStream_t stream);
}
"""
)


def paged_prefill_inst_templ(mask_mode: str) -> str:
return (
r"""#include <flashinfer/attention/prefill.cuh>
#include <flashinfer/attention/prefill_params.cuh>
#include <flashinfer/attention/variants.cuh>

namespace flashinfer {

{% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %}
using PagedParamsT = BatchPrefillPagedParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}, {{ dtype_idx }}>;
constexpr bool use_custom_mask = """
+ mask_mode
+ r""" == MaskMode::kCustom;
using PagedAttentionVariant = ComposedAttention<PagedParamsT, get_variant_code(use_custom_mask, {{ use_sliding_window }}, {{ use_logits_soft_cap }}, {{ use_alibi }})>;

template
cudaError_t BatchPrefillWithPagedKVCacheDispatched</*cta_tile_q=*/16, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, """
+ mask_mode
+ r""", PagedAttentionVariant>(
typename PagedAttentionVariant::ParamsT params,
typename PagedAttentionVariant::DTypeO* tmp_v,
float* tmp_s, cudaStream_t stream);

template
cudaError_t BatchPrefillWithPagedKVCacheDispatched</*cta_tile_q=*/64, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, """
+ mask_mode
+ r""", PagedAttentionVariant>(
typename PagedAttentionVariant::ParamsT params,
typename PagedAttentionVariant::DTypeO* tmp_v,
float* tmp_s, cudaStream_t stream);

template
cudaError_t BatchPrefillWithPagedKVCacheDispatched</*cta_tile_q=*/128, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, """
+ mask_mode
+ r""", PagedAttentionVariant>(
typename PagedAttentionVariant::ParamsT params,
typename PagedAttentionVariant::DTypeO* tmp_v,
float* tmp_s, cudaStream_t stream);
}
"""
)


batch_prefill_templ = [
r"""#include <flashinfer/attention/scheduler.cuh>
#include "pytorch_extension_utils.h"
@@ -60,10 +147,15 @@
return plan_info.ToVector();
}
""",
*[
ragged_prefill_inst_templ(mask_mode)
for mask_mode in ["MaskMode::kNone", "MaskMode::kCausal", "MaskMode::kCustom"]
],
r"""
#include <optional>
#include <flashinfer/pos_enc.cuh>
#include <flashinfer/attention/scheduler.cuh>
#include <flashinfer/attention/prefill.cuh>
#include <flashinfer/attention/mask.cuh>
#include <flashinfer/attention/prefill_params.cuh>
#include <flashinfer/attention/variants.cuh>
#include "pytorch_extension_utils.h"
@@ -73,6 +165,16 @@
{% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %}
using RaggedParamsT = BatchPrefillRaggedParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}, {{ dtype_idx }}>;

namespace flashinfer {

template <uint32_t CTA_TILE_Q, uint32_t HEAD_DIM, PosEncodingMode POS_ENCODING_MODE,
bool ALLOW_FP16_QK_REDUCTION, MaskMode MASK_MODE, typename AttentionVariant>
cudaError_t BatchPrefillWithRaggedKVCacheDispatched(typename AttentionVariant::ParamsT params,
typename AttentionVariant::DTypeO* tmp_v,
float* tmp_s, cudaStream_t stream);

};

void BatchPrefillWithRaggedKVCacheRun(
unsigned int mask_mode_code,
at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
@@ -153,7 +255,7 @@
constexpr bool use_custom_mask = MASK_MODE == MaskMode::kCustom;
using RaggedAttentionVariant = ComposedAttention<RaggedParamsT, get_variant_code(use_custom_mask, {{ use_sliding_window }}, {{ use_logits_soft_cap }}, {{ use_alibi }})>;
DISPATCH_CTA_TILE_Q(plan_info.cta_tile_q, CTA_TILE_Q, {
status = BatchPrefillWithRaggedKVCacheDispatched<
status = flashinfer::BatchPrefillWithRaggedKVCacheDispatched<
CTA_TILE_Q, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, MASK_MODE, RaggedAttentionVariant>(
params, tmp_v, tmp_s, stream);
});
@@ -162,9 +264,14 @@
TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithRaggedKVCache failed with error ", cudaGetErrorString(status));
}
""",
*[
paged_prefill_inst_templ(mask_mode)
for mask_mode in ["MaskMode::kNone", "MaskMode::kCausal", "MaskMode::kCustom"]
],
r"""#include <optional>
#include <flashinfer/pos_enc.cuh>
#include <flashinfer/attention/scheduler.cuh>
#include <flashinfer/attention/prefill.cuh>
#include <flashinfer/attention/mask.cuh>
#include <flashinfer/attention/prefill_params.cuh>
#include <flashinfer/attention/variants.cuh>
#include "pytorch_extension_utils.h"
@@ -174,6 +281,16 @@
{% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %}
using PagedParamsT = BatchPrefillPagedParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}, {{ dtype_idx }}>;

namespace flashinfer {

template <uint32_t CTA_TILE_Q, uint32_t HEAD_DIM, PosEncodingMode POS_ENCODING_MODE,
bool ALLOW_FP16_QK_REDUCTION, MaskMode MASK_MODE, typename AttentionVariant>
cudaError_t BatchPrefillWithPagedKVCacheDispatched(typename AttentionVariant::ParamsT params,
typename AttentionVariant::DTypeO* tmp_v,
float* tmp_s, cudaStream_t stream);

};

void BatchPrefillWithPagedKVCacheRun(
unsigned int mask_mode_code,
at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
@@ -274,7 +391,7 @@
constexpr bool use_custom_mask = MASK_MODE == MaskMode::kCustom;
using PagedAttentionVariant = ComposedAttention<PagedParamsT, get_variant_code(use_custom_mask, {{ use_sliding_window }}, {{ use_logits_soft_cap }}, {{ use_alibi }})>;
DISPATCH_CTA_TILE_Q(plan_info.cta_tile_q, CTA_TILE_Q, {
status = BatchPrefillWithPagedKVCacheDispatched<
status = flashinfer::BatchPrefillWithPagedKVCacheDispatched<
CTA_TILE_Q, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, MASK_MODE, PagedAttentionVariant>(
params, tmp_v, tmp_s, stream);
});
113 changes: 101 additions & 12 deletions python/flashinfer/jit/single_prefill_templ.py
@@ -15,19 +15,12 @@
"""

single_prefill_suffix = [
*[f"_kernel_mask_{mask_mode}.cu" for mask_mode in [0, 1, 2]],
".cu",
"_pybind.cc",
]

customizable_single_prefill_templ = [
r"""
#include <optional>
#include <flashinfer/attention/prefill.cuh>
#include "pytorch_extension_utils.h"

using namespace flashinfer;


customizable_struct_templ = r"""
struct SinglePrefillParams {
using DTypeQ = {{ dtype_q }};
using DTypeKV = {{ dtype_kv }};
@@ -82,10 +75,63 @@
return kv_len;
}
};
"""


def customizable_single_prefill_inst_templ(mask_mode: str) -> str:
return (
r"""#include <flashinfer/attention/prefill.cuh>

using namespace flashinfer;
"""
+ customizable_struct_templ
+ r"""{{ variant_decl }}
using ParamsT = SinglePrefillParams;
using AttentionVariant = {{ variant_name }}<ParamsT>;

namespace flashinfer {

template
cudaError_t SinglePrefillWithKVCacheDispatched<{{ head_dim }}, PosEncodingMode::kNone, false, """
f"{mask_mode}"
r""", AttentionVariant>(
typename AttentionVariant::ParamsT params,
typename AttentionVariant::DTypeO* tmp,
cudaStream_t stream);

};
"""
)


customizable_single_prefill_templ = [
*[
customizable_single_prefill_inst_templ(mask_mode)
for mask_mode in ["MaskMode::kNone", "MaskMode::kCausal", "MaskMode::kCustom"]
],
r"""
#include <optional>
#include <flashinfer/pos_enc.cuh>
#include <flashinfer/attention/mask.cuh>
#include "pytorch_extension_utils.h"

using namespace flashinfer;

"""
+ customizable_struct_templ
+ r"""
{{ variant_decl }}

namespace flashinfer {

template <uint32_t HEAD_DIM, PosEncodingMode POS_ENCODING_MODE, bool ALLOW_FP16_QK_REDUCTION,
MaskMode MASK_MODE, typename AttentionVariant>
cudaError_t SinglePrefillWithKVCacheDispatched(typename AttentionVariant::ParamsT params,
typename AttentionVariant::DTypeO* tmp,
cudaStream_t stream);

}

at::Tensor single_prefill_with_kv_cache(
unsigned int mask_mode_code, at::Tensor q, at::Tensor k, at::Tensor v,
at::Tensor tmp, at::Tensor o, unsigned int layout, int32_t window_left,
@@ -155,10 +201,43 @@
""",
]


def single_prefill_inst_templ(mask_mode: str) -> str:
return (
r"""#include <flashinfer/attention/prefill.cuh>
#include <flashinfer/attention/prefill_params.cuh>
#include <flashinfer/attention/variants.cuh>

namespace flashinfer {

{% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %}
using ParamsT = SinglePrefillParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}>;
constexpr bool use_custom_mask = """
f"{mask_mode}"
r"""== MaskMode::kCustom;
using AttentionVariant = ComposedAttention<ParamsT, get_variant_code(use_custom_mask, {{ use_sliding_window }}, {{ use_logits_soft_cap }}, {{ use_alibi }})>;

template
cudaError_t SinglePrefillWithKVCacheDispatched<{{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, """
f"{mask_mode}"
r""", AttentionVariant>(
typename AttentionVariant::ParamsT params,
typename AttentionVariant::DTypeO* tmp,
cudaStream_t stream);

}
"""
)


single_prefill_templ = [
r"""
#include <optional>
#include <flashinfer/attention/prefill.cuh>
*[
single_prefill_inst_templ(mask_mode)
for mask_mode in ["MaskMode::kNone", "MaskMode::kCausal", "MaskMode::kCustom"]
],
r"""#include <optional>
#include <flashinfer/pos_enc.cuh>
#include <flashinfer/attention/mask.cuh>
#include <flashinfer/attention/variants.cuh>
#include <flashinfer/attention/prefill_params.cuh>
#include "pytorch_extension_utils.h"
@@ -168,6 +247,16 @@
{% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %}
using ParamsT = SinglePrefillParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}>;

namespace flashinfer {

template <uint32_t HEAD_DIM, PosEncodingMode POS_ENCODING_MODE, bool ALLOW_FP16_QK_REDUCTION,
MaskMode MASK_MODE, typename AttentionVariant>
cudaError_t SinglePrefillWithKVCacheDispatched(typename AttentionVariant::ParamsT params,
typename AttentionVariant::DTypeO* tmp,
cudaStream_t stream);

}

void single_prefill_with_kv_cache(
unsigned int mask_mode_code,
at::Tensor q, at::Tensor k, at::Tensor v, std::optional<at::Tensor> maybe_packed_custom_mask,
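For reference, each per-mask-mode file is just one of the instantiation templates above rendered with the kernel configuration. A minimal sketch, assuming jinja2 and example parameter values (the actual render call and configuration plumbing live elsewhere in the JIT module):

# Illustrative only: render one per-mask-mode instantiation template with jinja2.
# Parameter values here are examples; the JIT passes its own configuration.
import jinja2

source = jinja2.Template(single_prefill_inst_templ("MaskMode::kCausal")).render(
    dtype_q="half",
    dtype_kv="half",
    dtype_o="half",
    head_dim=128,
    pos_encoding_mode="PosEncodingMode::kNone",
    use_fp16_qk_reduction="false",
    use_sliding_window="false",
    use_logits_soft_cap="false",
)
# The result is a small .cu unit that only instantiates the kCausal kernel,
# matching the "_kernel_mask_1.cu" entry in single_prefill_suffix.
print(source)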