perf: use packed bit array for attention mask #308

Merged: 7 commits, Jun 16, 2024
2 changes: 1 addition & 1 deletion cmake/config.cmake
@@ -27,7 +27,7 @@ set(FLASHINFER_GEN_HEAD_DIMS 64 128 256)
set(FLASHINFER_GEN_KV_LAYOUTS 0 1)
set(FLASHINFER_GEN_POS_ENCODING_MODES 0 1 2)
set(FLASHINFER_GEN_ALLOW_FP16_QK_REDUCTIONS "false" "true")
set(FLASHINFER_GEN_MASK_MODES 0 1)
set(FLASHINFER_GEN_MASK_MODES 0 1 2)

# Set target cuda architectures for tests/benchmarks, defaults to native.
# "native" is a special value for CMAKE_CUDA_ARCHITECTURES which means use the architectures of the host's GPU.
14 changes: 14 additions & 0 deletions docs/api/python/quantization.rst
@@ -0,0 +1,14 @@
.. _apiquantization:

flashinfer.quantization
=======================

Quantization related kernels.

.. currentmodule:: flashinfer.quantization

.. autosummary::
:toctree: _generate

packbits
segment_packbits
1 change: 1 addition & 0 deletions docs/index.rst
@@ -34,3 +34,4 @@ FlashInfer is a library for Language Languages Models that provides high-perform
api/python/sampling
api/python/group_gemm
api/python/norm
api/python/quantization
8 changes: 8 additions & 0 deletions docs/tutorials/kv_layout.rst
@@ -75,13 +75,21 @@ to store the start offset of each request's mask in the flattened mask array: ``
``mask_data`` has shape ``(qk_indptr[-1],)``, we can use ``mask_data[qk_indptr[i]:qk_indptr[i+1]]`` to slice the flattened
mask of request ``i``.

To save memory, we can further pack the flattened boolean mask array into a bit-packed array (1 bit per element, with 8 elements
packed together into a ``uint8``) using the "little" bit-order (see `numpy.packbits <https://numpy.org/doc/stable/reference/generated/numpy.packbits.html>`_
for more details). FlashInfer accepts both boolean masks and bit-packed masks. If a boolean mask is provided, FlashInfer packs it into a
bit-packed array internally.
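
A minimal sketch of this packing scheme (using ``numpy``; the mask values below are made up for illustration):

.. code-block:: python

   import numpy as np

   # flattened boolean mask of a single request with qo_len=2, kv_len=4 (row-major)
   mask = np.array([1, 1, 0, 0,
                    1, 1, 1, 0], dtype=bool)

   packed = np.packbits(mask, bitorder="little")
   # 8 booleans are packed into 1 byte; element i lands in bit (i % 8) of byte i // 8
   # packed == array([115], dtype=uint8), i.e. 0b01110011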

FlashInfer APIs
~~~~~~~~~~~~~~~

:class:`flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper` and :class:`flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper`
allow the user to specify ``qo_indptr``, ``kv_indptr`` and a custom attention mask ``custom_mask`` in their ``begin_forward`` functions;
the mask data is added to the attention scores before softmax (and after softmax scaling) in the attention kernel.

:meth:`flashinfer.quantization.packbits` and :meth:`flashinfer.quantization.segment_packbits` are utility functions
that pack a boolean mask into a bit-packed array.
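
A rough usage sketch (assuming the signatures ``packbits(x, bitorder)`` and ``segment_packbits(x, indptr, bitorder)``; see :ref:`apiquantization` for the exact interface):

.. code-block:: python

   import torch
   import flashinfer

   # boolean masks of two requests, flattened and concatenated
   mask = torch.tensor([1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1], dtype=torch.bool, device="cuda")
   qk_indptr = torch.tensor([0, 8, 11], dtype=torch.int32, device="cuda")

   # pack the whole array at once (1 bit per element, little bit-order)
   packed = flashinfer.quantization.packbits(mask, bitorder="little")

   # pack each request's segment independently; also returns the indptr of the packed array
   packed_seg, packed_indptr = flashinfer.quantization.segment_packbits(
       mask, qk_indptr, bitorder="little")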

.. _page-layout:

Page Table
49 changes: 26 additions & 23 deletions include/flashinfer/attention/prefill.cuh
@@ -547,7 +547,7 @@ template <bool partition_kv, MaskMode mask_mode, uint32_t num_warps, uint32_t nu
__device__ __forceinline__ void mask_s(const uint32_t qo_packed_idx_base,
const uint32_t kv_idx_base, const uint32_t qo_len,
const uint32_t kv_len, const uint32_t chunk_end,
const uint_fastdiv group_size, float* custom_mask,
const uint_fastdiv group_size, uint8_t* custom_mask,
DTypeQKAccum (*s_frag)[num_frags_z][8]) {
const uint32_t tx = threadIdx.x;
#pragma unroll
@@ -565,11 +565,11 @@ __device__ __forceinline__ void mask_s(const uint32_t qo_packed_idx_base,
? (kv_idx > kv_len + q_idx - qo_len || (partition_kv && kv_idx >= chunk_end))
: kv_idx >= chunk_end);
s_frag[fx][fz][reg_id] =
out_of_boundary ? DTypeQKAccum(-5e4)
: s_frag[fx][fz][reg_id] +
DTypeQKAccum((mask_mode == MaskMode::kCustom && q_idx < qo_len)
? custom_mask[q_idx * kv_len + kv_idx]
: 0.f);
(out_of_boundary ||
((mask_mode == MaskMode::kCustom && q_idx < qo_len &&
!(custom_mask[(q_idx * kv_len + kv_idx) / 8] >> ((q_idx * kv_len + kv_idx) % 8)))))
? DTypeQKAccum(-5e4)
: s_frag[fx][fz][reg_id];
}
}
}
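
For reference, the bit addressing used by the new ``kCustom`` branch above can be sketched as follows (a Python illustration of the intended lookup, not part of the diff):

    def mask_bit(custom_mask: bytes, q_idx: int, kv_idx: int, kv_len: int) -> bool:
        # element index in the flattened per-request mask
        idx = q_idx * kv_len + kv_idx
        # 8 mask elements per byte, little bit-order: element idx -> bit (idx % 8) of byte idx // 8
        return bool((custom_mask[idx // 8] >> (idx % 8)) & 1)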
@@ -891,7 +891,7 @@ template <LogitsPostHook logits_post_hook, bool partition_kv, MaskMode mask_mode
typename DTypeQKAccum, typename DTypeOut>
__global__ void SinglePrefillWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn* __restrict__ k,
DTypeIn* __restrict__ v,
float* __restrict__ custom_mask,
uint8_t* __restrict__ custom_mask,
DTypeOut* __restrict__ o, void* __restrict__ tmp,
float* __restrict__ lse, const uint32_t qo_len,
const uint32_t kv_len, const uint_fastdiv group_size,
@@ -1081,9 +1081,10 @@ __global__ void SinglePrefillWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn*
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
for (uint32_t j = 0; j < 2; ++j) {
const uint32_t qo_head_idx =
kv_head_idx * group_size + (qo_packed_idx_base + tx / 4 + j * 8 + fx * 16) % group_size;
const uint32_t qo_idx = (qo_packed_idx_base + tx / 4 + j * 8 + fx * 16) / group_size;
uint32_t q, r;
group_size.divmod(qo_packed_idx_base + tx / 4 + j * 8 + fx * 16, q, r);
const uint32_t qo_head_idx = kv_head_idx * group_size + r;
const uint32_t qo_idx = q;
if (qo_idx < qo_len) {
if constexpr (partition_kv) {
float* tmp_lse =
@@ -1106,7 +1107,7 @@ template <LogitsPostHook logits_post_hook, MaskMode mask_mode, QKVLayout kv_layo
__global__ void BatchPrefillWithRaggedKVCacheKernel(
DTypeIn* __restrict__ q, IdType* __restrict__ request_indices,
IdType* __restrict__ tile_indices, IdType* __restrict__ qo_indptr, DTypeIn* __restrict__ k,
DTypeIn* __restrict__ v, IdType* __restrict__ kv_indptr, float* __restrict__ custom_mask,
DTypeIn* __restrict__ v, IdType* __restrict__ kv_indptr, uint8_t* __restrict__ custom_mask,
IdType* __restrict__ qk_indptr, IdType* __restrict__ q_offset,
IdType* __restrict__ k_rope_pos_offset, DTypeOut* __restrict__ o, float* __restrict__ tmp,
float* __restrict__ lse, uint32_t batch_size, const uint_fastdiv group_size, float sm_scale,
@@ -1303,9 +1304,10 @@ __global__ void BatchPrefillWithRaggedKVCacheKernel(
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
for (uint32_t j = 0; j < 2; ++j) {
const uint32_t qo_head_idx =
kv_head_idx * group_size + (qo_packed_idx_base + tx / 4 + j * 8 + fx * 16) % group_size;
const uint32_t qo_idx = (qo_packed_idx_base + tx / 4 + j * 8 + fx * 16) / group_size;
uint32_t q, r;
group_size.divmod(qo_packed_idx_base + tx / 4 + j * 8 + fx * 16, q, r);
const uint32_t qo_head_idx = kv_head_idx * group_size + r;
const uint32_t qo_idx = q;
if (qo_idx < qo_len) {
lse[(qo_indptr[request_idx] + qo_idx) * num_qo_heads + qo_head_idx] =
math::ptx_log2(d[fx][j]) + float(m[fx][j]);
@@ -1322,9 +1324,9 @@ template <LogitsPostHook logits_post_hook, MaskMode mask_mode, PosEncodingMode p
__global__ void BatchPrefillWithPagedKVCacheKernel(
IdType* __restrict__ request_indices, IdType* __restrict__ tile_indices,
DTypeIn* __restrict__ q, paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
IdType* __restrict__ qo_indptr, float* __restrict__ custom_mask, IdType* __restrict__ qk_indptr,
IdType* __restrict__ q_offset, DTypeOut* __restrict__ o, float* __restrict__ tmp,
float* __restrict__ lse, const uint_fastdiv group_size, float sm_scale,
IdType* __restrict__ qo_indptr, uint8_t* __restrict__ custom_mask,
IdType* __restrict__ qk_indptr, IdType* __restrict__ q_offset, DTypeOut* __restrict__ o,
float* __restrict__ tmp, float* __restrict__ lse, const uint_fastdiv group_size, float sm_scale,
float log2_rope_rcp_scale, float log2_rope_rcp_theta) {
static_assert(sizeof(DTypeIn) == 2);
static_assert(sizeof(DTypeOut) == 2);
@@ -1515,9 +1517,10 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
for (uint32_t j = 0; j < 2; ++j) {
const uint32_t qo_head_idx =
kv_head_idx * group_size + (qo_packed_idx_base + tx / 4 + j * 8 + fx * 16) % group_size;
const uint32_t qo_idx = (qo_packed_idx_base + tx / 4 + j * 8 + fx * 16) / group_size;
uint32_t q, r;
group_size.divmod(qo_packed_idx_base + tx / 4 + j * 8 + fx * 16, q, r);
const uint32_t qo_head_idx = kv_head_idx * group_size + r;
const uint32_t qo_idx = q;
if (qo_idx < qo_upper_bound) {
lse[(qo_indptr[request_idx] + qo_idx) * num_qo_heads + qo_head_idx] =
math::ptx_log2(d[fx][j]) + float(m[fx][j]);
@@ -1531,7 +1534,7 @@ template <uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK, QKVLayout KV_LAYOU
PosEncodingMode pos_encoding_mode, bool ALLOW_FP16_QK_REDUCTION, MaskMode MASK_MODE,
typename DTypeIn, typename DTypeOut>
cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v,
float* custom_mask, DTypeOut* o, float* tmp,
uint8_t* custom_mask, DTypeOut* o, float* tmp,
float* lse, uint32_t num_qo_heads,
uint32_t num_kv_heads, uint32_t qo_len,
uint32_t kv_len, float sm_scale, float rope_scale,
@@ -1671,7 +1674,7 @@ template <uint32_t num_frags_x, uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HO
MaskMode MASK_MODE, typename DTypeIn, typename DTypeOut, typename IdType>
cudaError_t BatchPrefillWithRaggedKVCacheDispatched(
DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, DTypeIn* k,
DTypeIn* v, IdType* kv_indptr, float* custom_mask, IdType* qk_indptr, IdType* q_offset,
DTypeIn* v, IdType* kv_indptr, uint8_t* custom_mask, IdType* qk_indptr, IdType* q_offset,
IdType* k_rope_pos_offset, DTypeOut* o, float* tmp, float* lse, const uint32_t batch_size,
const uint32_t num_qo_heads, const uint32_t num_qo_tiles, const uint32_t num_kv_heads,
const float sm_scale, const float rope_scale, const float rope_theta,
@@ -1755,7 +1758,7 @@ template <PageStorage page_storage, uint32_t num_frags_x, uint32_t HEAD_DIM,
typename IdType>
cudaError_t BatchPrefillWithPagedKVCacheDispatched(
DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, IdType* q_offset,
paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv, float* custom_mask,
paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv, uint8_t* custom_mask,
IdType* qk_indptr, DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_heads,
uint32_t num_qo_tiles, float sm_scale, float rope_scale, float rope_theta,
cudaStream_t stream) {
10 changes: 5 additions & 5 deletions include/flashinfer/prefill_attention_decl.cuh
@@ -32,7 +32,7 @@ template <uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK, QKVLayout KV_LAYOU
PosEncodingMode POS_ENCODING_MODE, bool ALLOW_FP16_QK_REDUCTION, MaskMode MASK_MODE,
typename DTypeIn, typename DTypeOut>
cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v,
float* custom_mask, DTypeOut* o, float* tmp,
uint8_t* custom_mask, DTypeOut* o, float* tmp,
float* lse, uint32_t num_qo_heads,
uint32_t num_kv_heads, uint32_t qo_len,
uint32_t kv_len, float sm_scale, float rope_scale,
@@ -43,7 +43,7 @@ template <uint32_t NUM_FRAGS_X, uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HO
MaskMode MASK_MODE, typename DTypeIn, typename DTypeOut, typename IdType>
cudaError_t BatchPrefillWithRaggedKVCacheDispatched(
DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, DTypeIn* k,
DTypeIn* v, IdType* kv_indptr, float* custom_mask, IdType* qk_indptr, IdType* q_offset,
DTypeIn* v, IdType* kv_indptr, uint8_t* custom_mask, IdType* qk_indptr, IdType* q_offset,
IdType* k_rope_pos_offset, DTypeOut* o, float* tmp, float* lse, uint32_t batch_size,
uint32_t num_qo_tiles, uint32_t num_qo_heads, uint32_t num_kv_heads, float sm_scale,
float rope_scale, float rope_theta, cudaStream_t stream = nullptr);
@@ -54,7 +54,7 @@ template <PageStorage PAGE_STORAGE, uint32_t NUM_FRAGS_X, uint32_t HEAD_DIM,
typename IdType>
cudaError_t BatchPrefillWithPagedKVCacheDispatched(
DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, IdType* q_offset,
paged_kv_t<PAGE_STORAGE, KV_LAYOUT, DTypeIn, IdType> paged_kv, float* custom_mask,
paged_kv_t<PAGE_STORAGE, KV_LAYOUT, DTypeIn, IdType> paged_kv, uint8_t* custom_mask,
IdType* qk_indptr, DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_tiles,
uint32_t num_qo_heads, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream);

@@ -63,7 +63,7 @@ template <PageStorage PAGE_STORAGE, uint32_t HEAD_DIM, LogitsPostHook LOGITS_POS
MaskMode MASK_MODE, typename DTypeIn, typename DTypeOut, typename IdType>
cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched(
BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_offset,
paged_kv_t<PAGE_STORAGE, KV_LAYOUT, DTypeIn, IdType> paged_kv, float* custom_mask,
paged_kv_t<PAGE_STORAGE, KV_LAYOUT, DTypeIn, IdType> paged_kv, uint8_t* custom_mask,
IdType* qk_indptr, DTypeOut* o, float* lse, uint32_t num_qo_heads, float sm_scale,
float rope_scale, float rope_theta, cudaStream_t stream) {
float* tmp = nullptr;
@@ -98,7 +98,7 @@ template <uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK, QKVLayout KV_LAYOU
typename DTypeIn, typename DTypeOut, typename IdType>
cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched(
BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, DTypeIn* k, DTypeIn* v,
IdType* kv_indptr, float* custom_mask, IdType* qk_indptr, IdType* q_offset,
IdType* kv_indptr, uint8_t* custom_mask, IdType* qk_indptr, IdType* q_offset,
IdType* k_rope_pos_offset, DTypeOut* o, float* lse, uint32_t batch_size, uint32_t num_qo_heads,
uint32_t num_kv_heads, float sm_scale, float rope_scale, float rope_theta,
cudaStream_t stream) {
114 changes: 114 additions & 0 deletions include/flashinfer/quantization.cuh
@@ -0,0 +1,114 @@
/*
* Copyright (c) 2024 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FLASHINFER_QUANTIZATION_CUH_
#define FLASHINFER_QUANTIZATION_CUH_
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>

#include <cub/cub.cuh>

#include "utils.cuh"

namespace flashinfer {
namespace quantization {

enum class BitOrder { kBig = 0U, kLittle = 1U };

#define DISPATCH_BITORDER(bitorder, BITORDER, ...) \
if (bitorder == BitOrder::kBig) { \
constexpr BitOrder BITORDER = BitOrder::kBig; \
__VA_ARGS__ \
} else { \
constexpr BitOrder BITORDER = BitOrder::kLittle; \
__VA_ARGS__ \
}

template <BitOrder BITORDER>
__global__ void PackBitsKernel(bool* input, uint8_t* output, int64_t num_elements) {
int64_t start_offset = blockIdx.x * blockDim.x * 8, tx = threadIdx.x;
uint8_t ret = 0;
bool input_vec[8];
typedef cub::BlockLoad<bool, 256, 8, cub::BLOCK_LOAD_VECTORIZE> BlockLoad;
__shared__ typename BlockLoad::TempStorage temp_storage;
BlockLoad(temp_storage)
.Load(input + start_offset, input_vec, num_elements - start_offset, /*default=*/0);

if constexpr (BITORDER == BitOrder::kBig) {
ret = (input_vec[0] << 7) | (input_vec[1] << 6) | (input_vec[2] << 5) | (input_vec[3] << 4) |
(input_vec[4] << 3) | (input_vec[5] << 2) | (input_vec[6] << 1) | input_vec[7];
} else {
ret = (input_vec[7] << 7) | (input_vec[6] << 6) | (input_vec[5] << 5) | (input_vec[4] << 4) |
(input_vec[3] << 3) | (input_vec[2] << 2) | (input_vec[1] << 1) | input_vec[0];
}
if (start_offset + tx * 8 < num_elements) output[start_offset / 8 + tx] = ret;
}

template <BitOrder BITORDER, typename IdType>
__global__ void SegmentPackBitsKernel(bool* input, uint8_t* output, IdType* input_indptr,
IdType* output_indptr) {
int64_t bx = blockIdx.x, tx = threadIdx.x;
bool input_vec[8];
typedef cub::BlockLoad<bool, 256, 8, cub::BLOCK_LOAD_VECTORIZE> BlockLoad;
__shared__ typename BlockLoad::TempStorage temp_storage;
int64_t num_elements = input_indptr[bx + 1] - input_indptr[bx];
for (uint32_t start_offset = 0; start_offset < num_elements; start_offset += 8 * blockDim.x) {
uint8_t ret = 0;
BlockLoad(temp_storage)
.Load(input + input_indptr[bx] + start_offset, input_vec, num_elements - start_offset,
/*default=*/0);

if constexpr (BITORDER == BitOrder::kBig) {
ret = (input_vec[0] << 7) | (input_vec[1] << 6) | (input_vec[2] << 5) | (input_vec[3] << 4) |
(input_vec[4] << 3) | (input_vec[5] << 2) | (input_vec[6] << 1) | input_vec[7];
} else {
ret = (input_vec[7] << 7) | (input_vec[6] << 6) | (input_vec[5] << 5) | (input_vec[4] << 4) |
(input_vec[3] << 3) | (input_vec[2] << 2) | (input_vec[1] << 1) | input_vec[0];
}
if (start_offset + tx * 8 < num_elements)
output[output_indptr[bx] + start_offset / 8 + tx] = ret;
}
}

cudaError_t PackBits(bool* input, uint8_t* output, int64_t num_elements, BitOrder bitorder,
cudaStream_t stream) {
DISPATCH_BITORDER(bitorder, BITORDER, {
auto kernel = PackBitsKernel<BITORDER>;
const dim3 nthrs(256);
const dim3 nblks(ceil_div(num_elements, nthrs.x * 8));
void* args[] = {&input, &output, &num_elements};
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream));
});
return cudaSuccess;
}

template <typename IdType>
cudaError_t SegmentPackBits(bool* input, uint8_t* output, IdType* input_indptr,
IdType* output_indptr, uint32_t batch_size, BitOrder bitorder,
cudaStream_t stream) {
DISPATCH_BITORDER(bitorder, BITORDER, {
auto kernel = SegmentPackBitsKernel<BITORDER, IdType>;
const dim3 nthrs(256);
const dim3 nblks(batch_size);
void* args[] = {&input, &output, &input_indptr, &output_indptr};
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream));
});
return cudaSuccess;
}

} // namespace quantization
} // namespace flashinfer

#endif // FLASHINFER_QUANTIZATION_CUH_
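
For reference, a rough host-side NumPy sketch of what ``SegmentPackBits`` computes (assuming little bit-order; an illustration only, not the library's API):

    import numpy as np

    def segment_packbits_ref(mask: np.ndarray, indptr: np.ndarray):
        # Pack each segment mask[indptr[i]:indptr[i+1]] independently, padding each
        # segment to a multiple of 8 bits, mirroring SegmentPackBitsKernel above.
        out, out_indptr = [], [0]
        for i in range(len(indptr) - 1):
            seg = mask[indptr[i]:indptr[i + 1]]
            out.append(np.packbits(seg, bitorder="little"))
            out_indptr.append(out_indptr[-1] + len(out[-1]))
        return np.concatenate(out), np.array(out_indptr)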