From ef2569060e9dae4e0d8506350790a6de9697a3ae Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sat, 15 Jun 2024 20:15:12 +0000 Subject: [PATCH 1/7] upd --- include/flashinfer/utils.cuh | 2 -- 1 file changed, 2 deletions(-) diff --git a/include/flashinfer/utils.cuh b/include/flashinfer/utils.cuh index 849dae19..9775283b 100644 --- a/include/flashinfer/utils.cuh +++ b/include/flashinfer/utils.cuh @@ -18,8 +18,6 @@ #include #include -#include -#include #include #define STR_HELPER(x) #x From b17ea15c0efd4fa83fa6216f9cfaa0000cc24db5 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sat, 15 Jun 2024 20:17:08 +0000 Subject: [PATCH 2/7] fix --- include/flashinfer/utils.cuh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/include/flashinfer/utils.cuh b/include/flashinfer/utils.cuh index 9775283b..849dae19 100644 --- a/include/flashinfer/utils.cuh +++ b/include/flashinfer/utils.cuh @@ -18,6 +18,8 @@ #include #include +#include +#include #include #define STR_HELPER(x) #x From 134469bf45e8e1062ebb6d24344391a1c84aaab6 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sat, 15 Jun 2024 23:53:35 +0000 Subject: [PATCH 3/7] upd --- include/flashinfer/attention/prefill.cuh | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index 737d96c6..e98972c6 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -1081,9 +1081,10 @@ __global__ void SinglePrefillWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn* for (uint32_t fx = 0; fx < num_frags_x; ++fx) { #pragma unroll for (uint32_t j = 0; j < 2; ++j) { - const uint32_t qo_head_idx = - kv_head_idx * group_size + (qo_packed_idx_base + tx / 4 + j * 8 + fx * 16) % group_size; - const uint32_t qo_idx = (qo_packed_idx_base + tx / 4 + j * 8 + fx * 16) / group_size; + uint32_t q, r; + group_size.divmod(qo_packed_idx_base + tx / 4 + j * 8 + fx * 16, q, r); + const uint32_t qo_head_idx = kv_head_idx * group_size + r; + const uint32_t qo_idx = q; if (qo_idx < qo_len) { if constexpr (partition_kv) { float* tmp_lse = @@ -1303,9 +1304,10 @@ __global__ void BatchPrefillWithRaggedKVCacheKernel( for (uint32_t fx = 0; fx < num_frags_x; ++fx) { #pragma unroll for (uint32_t j = 0; j < 2; ++j) { - const uint32_t qo_head_idx = - kv_head_idx * group_size + (qo_packed_idx_base + tx / 4 + j * 8 + fx * 16) % group_size; - const uint32_t qo_idx = (qo_packed_idx_base + tx / 4 + j * 8 + fx * 16) / group_size; + uint32_t q, r; + group_size.divmod(qo_packed_idx_base + tx / 4 + j * 8 + fx * 16, q, r); + const uint32_t qo_head_idx = kv_head_idx * group_size + r; + const uint32_t qo_idx = q; if (qo_idx < qo_len) { lse[(qo_indptr[request_idx] + qo_idx) * num_qo_heads + qo_head_idx] = math::ptx_log2(d[fx][j]) + float(m[fx][j]); @@ -1515,9 +1517,10 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( for (uint32_t fx = 0; fx < num_frags_x; ++fx) { #pragma unroll for (uint32_t j = 0; j < 2; ++j) { - const uint32_t qo_head_idx = - kv_head_idx * group_size + (qo_packed_idx_base + tx / 4 + j * 8 + fx * 16) % group_size; - const uint32_t qo_idx = (qo_packed_idx_base + tx / 4 + j * 8 + fx * 16) / group_size; + uint32_t q, r; + group_size.divmod(qo_packed_idx_base + tx / 4 + j * 8 + fx * 16, q, r); + const uint32_t qo_head_idx = kv_head_idx * group_size + r; + const uint32_t qo_idx = q; if (qo_idx < qo_upper_bound) { lse[(qo_indptr[request_idx] + qo_idx) * num_qo_heads + qo_head_idx] = math::ptx_log2(d[fx][j]) + float(m[fx][j]); From 
347568899e6232d0cd933dea689751bd390c7f77 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sun, 16 Jun 2024 04:39:57 +0000 Subject: [PATCH 4/7] upd --- cmake/config.cmake | 2 +- docs/api/python/quantization.rst | 13 +++ docs/index.rst | 1 + include/flashinfer/quantization.cuh | 119 ++++++++++++++++++++++++++++ python/csrc/flashinfer_ops.cu | 2 + python/csrc/flashinfer_ops.h | 5 ++ python/csrc/quantization.cu | 64 +++++++++++++++ python/flashinfer/__init__.py | 1 + python/flashinfer/quantization.py | 75 ++++++++++++++++++ python/setup.py | 1 + python/tests/test_quantization.py | 53 +++++++++++++ src/bench_single_prefill.cu | 37 ++++++--- src/flashinfer_ops.cuh | 25 ++++++ 13 files changed, 388 insertions(+), 10 deletions(-) create mode 100644 docs/api/python/quantization.rst create mode 100644 include/flashinfer/quantization.cuh create mode 100644 python/csrc/quantization.cu create mode 100644 python/flashinfer/quantization.py create mode 100644 python/tests/test_quantization.py diff --git a/cmake/config.cmake b/cmake/config.cmake index 4854c5af..75ea4fc0 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -27,7 +27,7 @@ set(FLASHINFER_GEN_HEAD_DIMS 64 128 256) set(FLASHINFER_GEN_KV_LAYOUTS 0 1) set(FLASHINFER_GEN_POS_ENCODING_MODES 0 1 2) set(FLASHINFER_GEN_ALLOW_FP16_QK_REDUCTIONS "false" "true") -set(FLASHINFER_GEN_MASK_MODES 0 1) +set(FLASHINFER_GEN_MASK_MODES 0 1 2) # Set target cuda architectures for tests/benchmarks, defaults to native. # "native" is a special value for CMAKE_CUDA_ARCHITECTURES which means use the architectures of the host's GPU. diff --git a/docs/api/python/quantization.rst b/docs/api/python/quantization.rst new file mode 100644 index 00000000..28be0826 --- /dev/null +++ b/docs/api/python/quantization.rst @@ -0,0 +1,13 @@ +.. _apiquantization: + +flashinfer.quantization +======================= + +Quantization related kernels. + +.. currentmodule:: flashinfer.quantization + +.. autosummary:: + :toctree: _generate + + packbits diff --git a/docs/index.rst b/docs/index.rst index 8851b7ff..f4b0b7a0 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -34,3 +34,4 @@ FlashInfer is a library for Language Languages Models that provides high-perform api/python/sampling api/python/group_gemm api/python/norm + api/python/quantization diff --git a/include/flashinfer/quantization.cuh b/include/flashinfer/quantization.cuh new file mode 100644 index 00000000..9f7a5e1f --- /dev/null +++ b/include/flashinfer/quantization.cuh @@ -0,0 +1,119 @@ +/* + * Copyright (c) 2024 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef FLASHINFER_QUANTIZATION_CUH_ +#define FLASHINFER_QUANTIZATION_CUH_ +#include +#include + +#include "utils.cuh" + +namespace flashinfer { +namespace quantization { + +enum class BitOrder { kBig = 0U, kLittle = 1U }; + +#define DISPATCH_BITORDER(bitorder, BITORDER, ...) 
\ + if (bitorder == BitOrder::kBig) { \ + constexpr BitOrder BITORDER = BitOrder::kBig; \ + __VA_ARGS__ \ + } else { \ + constexpr BitOrder BITORDER = BitOrder::kLittle; \ + __VA_ARGS__ \ + } + +template +__global__ void PackBitsKernel(bool* input, uint8_t* output, int64_t num_elements) { + int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + uint8_t ret = 0; + uint8_t input_vec[8]; + for (uint32_t i = 0; i < 8; ++i) { + input_vec[i] = 0; + } + if ((idx + 1) * 8 <= num_elements) { + *(uint2*)input_vec = *(uint2*)(input + idx * 8); + } else { +#pragma unroll + for (uint32_t i = 0; i < 8; ++i) { + input_vec[i] = (idx * 8 + i < num_elements) ? input[idx * 8 + i] : false; + } + } + + if constexpr (BITORDER == BitOrder::kBig) { + ret = (input_vec[0] << 7) | (input_vec[1] << 6) | (input_vec[2] << 5) | (input_vec[3] << 4) | + (input_vec[4] << 3) | (input_vec[5] << 2) | (input_vec[6] << 1) | input_vec[7]; + } else { + ret = (input_vec[7] << 7) | (input_vec[6] << 6) | (input_vec[5] << 5) | (input_vec[4] << 4) | + (input_vec[3] << 3) | (input_vec[2] << 2) | (input_vec[1] << 1) | input_vec[0]; + } + output[idx] = ret; +} + +// NOTE(Zihao): this implementation is not efficient, but this kernel is not a bottleneck +// at the moment. We can optimize it later if needed. +template +__global__ void SegmentPackBitsKernel(bool* input, uint8_t* output, IdType* input_indptr, + IdType* output_indptr) { + int64_t bx = blockIdx.x, tx = threadIdx.x; + for (uint32_t j = tx; j < output_indptr[bx + 1] - output_indptr[bx]; j += blockDim.x) { + int64_t num_elements = input_indptr[bx + 1] - input_indptr[bx]; + uint8_t ret = 0; + uint8_t input_vec[8]; +#pragma unroll + for (uint32_t i = 0; i < 8; ++i) { + input_vec[i] = (j * 8 + i < num_elements) ? input[input_indptr[bx] + j * 8 + i] : false; + } + + if constexpr (BITORDER == BitOrder::kBig) { + ret = (input_vec[0] << 7) | (input_vec[1] << 6) | (input_vec[2] << 5) | (input_vec[3] << 4) | + (input_vec[4] << 3) | (input_vec[5] << 2) | (input_vec[6] << 1) | input_vec[7]; + } else { + ret = (input_vec[7] << 7) | (input_vec[6] << 6) | (input_vec[5] << 5) | (input_vec[4] << 4) | + (input_vec[3] << 3) | (input_vec[2] << 2) | (input_vec[1] << 1) | input_vec[0]; + } + output[output_indptr[bx] + j] = ret; + } +} + +cudaError_t PackBits(bool* input, uint8_t* output, int64_t num_elements, BitOrder bitorder, + cudaStream_t stream) { + DISPATCH_BITORDER(bitorder, BITORDER, { + auto kernel = PackBitsKernel; + const dim3 nthrs(256); + const dim3 nblks(ceil_div(num_elements, nthrs.x * 8)); + void* args[] = {&input, &output, &num_elements}; + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream)); + }); + return cudaSuccess; +} + +template +cudaError_t SegmentPackBits(bool* input, uint8_t* output, IdType* input_indptr, + IdType* output_indptr, uint32_t batch_size, BitOrder bitorder, + cudaStream_t stream) { + DISPATCH_BITORDER(bitorder, BITORDER, { + auto kernel = SegmentPackBitsKernel; + const dim3 nthrs(256); + const dim3 nblks(batch_size); + void* args[] = {&input, &output, &input_indptr, &output_indptr}; + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream)); + }); + return cudaSuccess; +} + +} // namespace quantization +} // namespace flashinfer + +#endif // FLASHINFER_QUANTIZATION_CUH_ \ No newline at end of file diff --git a/python/csrc/flashinfer_ops.cu b/python/csrc/flashinfer_ops.cu index 91174d62..f3a5f62d 100644 --- a/python/csrc/flashinfer_ops.cu +++ b/python/csrc/flashinfer_ops.cu @@ -42,6 +42,8 @@ 
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("chain_speculative_sampling", &chain_speculative_sampling, "Speculative sampling from sequence of probabilities"); m.def("rmsnorm", &rmsnorm, "Root mean square normalization"); + m.def("packbits", &packbits, "GPU packbits operator"); + m.def("segment_packbits", &segment_packbits, "GPU segment packbits operator"); py::class_(m, "BatchDecodeWithPagedKVCachePyTorchWrapper") .def(py::init()) diff --git a/python/csrc/flashinfer_ops.h b/python/csrc/flashinfer_ops.h index bbfebd21..707f174b 100644 --- a/python/csrc/flashinfer_ops.h +++ b/python/csrc/flashinfer_ops.h @@ -74,6 +74,11 @@ torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tenso torch::Tensor rmsnorm(torch::Tensor x, torch::Tensor w, double eps); +torch::Tensor packbits(torch::Tensor x, const std::string& bitorder); + +torch::Tensor segment_packbits(torch::Tensor x, torch::Tensor input_indptr, torch::Tensor output_indptr, + const std::string& bitorder); + class BatchDecodeWithPagedKVCachePyTorchWrapper { public: void BeginForward(torch::Tensor workspace_buffer, torch::Tensor indptr, diff --git a/python/csrc/quantization.cu b/python/csrc/quantization.cu new file mode 100644 index 00000000..90c0b1fa --- /dev/null +++ b/python/csrc/quantization.cu @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2024 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include "flashinfer_ops.h" +#include "pytorch_extension_utils.h" + +using namespace flashinfer; + +torch::Tensor packbits(torch::Tensor x, const std::string& bitorder) { + CHECK_INPUT(x); + TORCH_CHECK(bitorder == "big" || bitorder == "little", "bitorder must be 'big' or 'little'"); + x = x.to(torch::kBool); + cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(); + + int64_t num_elements = x.numel(); + int64_t num_output_elements = (num_elements + 7) / 8; + + auto y = torch::empty({num_output_elements}, x.options().dtype(torch::kUInt8)); + + cudaError_t status = quantization::PackBits( + static_cast(x.data_ptr()), static_cast(y.data_ptr()), num_elements, + bitorder == "big" ? 
quantization::BitOrder::kBig : quantization::BitOrder::kLittle, + torch_current_stream); + + TORCH_CHECK(status == cudaSuccess, + "PackBits failed with error code " + std::string(cudaGetErrorString(status))); + return y; +} + +torch::Tensor segment_packbits(torch::Tensor x, torch::Tensor input_indptr, + torch::Tensor output_indptr, const std::string& bitorder) { + CHECK_INPUT(x); + CHECK_INPUT(input_indptr); + CHECK_INPUT(output_indptr); + TORCH_CHECK(bitorder == "big" || bitorder == "little", "bitorder must be 'big' or 'little'"); + unsigned int batch_size = input_indptr.size(0) - 1; + CHECK_EQ(output_indptr.size(0), batch_size + 1); + input_indptr = input_indptr.to(torch::kInt32); + output_indptr = output_indptr.to(torch::kInt32); + int64_t output_nnz = output_indptr[batch_size].item(); + auto y = torch::empty({output_nnz}, x.options().dtype(torch::kUInt8)); + + cudaError_t status = quantization::SegmentPackBits( + static_cast(x.data_ptr()), static_cast(y.data_ptr()), + static_cast(input_indptr.data_ptr()), + static_cast(output_indptr.data_ptr()), batch_size, + bitorder == "big" ? quantization::BitOrder::kBig : quantization::BitOrder::kLittle, + c10::cuda::getCurrentCUDAStream()); + return y; +} \ No newline at end of file diff --git a/python/flashinfer/__init__.py b/python/flashinfer/__init__.py index 9dbf9dbe..2a3b3c9c 100644 --- a/python/flashinfer/__init__.py +++ b/python/flashinfer/__init__.py @@ -46,6 +46,7 @@ ) from .norm import rmsnorm from .group_gemm import SegmentGEMMWrapper +from .quantization import packbits, segment_packbits try: from ._build_meta import __version__ as __version__ diff --git a/python/flashinfer/quantization.py b/python/flashinfer/quantization.py new file mode 100644 index 00000000..ca48133a --- /dev/null +++ b/python/flashinfer/quantization.py @@ -0,0 +1,75 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch + +try: + from . import _kernels +except ImportError as e: + import os + import logging + + if os.environ.get("BUILD_DOC", "0") == "1": + _kernels = None + logging.warning("Kernels are not loaded in documentation build mode.") + else: + raise e + +def packbits(x: torch.Tensor, bitorder: str = "big"): + r"""Pack the elements of a binary-valued array into bits in a uint8 array. + + See `numpy.packbits `_ for more details. + + Parameters + ---------- + x: torch.Tensor + The 1D binary-valued array to pack. + bitorder: str + The bit-order ("bit"/"little") of the output. Default is "big". + + Returns + ------- + y: torch.Tensor + An uint8 packed array, shape ``((x.size(0) + 7) / 8),)``. + """ + return _kernels.packbits(x, bitorder) + +def segment_packbits(x: torch.Tensor, indptr: torch.Tensor, bitorder: str="big"): + r"""Pack a batch elements of a binary-valued array into bits in a uint8 array. + + Parameters + ---------- + x: torch.Tensor + The 1D binary-valued array to pack. + indptr: torch.Tensor + The index pointer of the first element of each segment in :attr:`x`. 
+ The i-th segment in :attr:`x` is ``x[indptr[i]:indptr[i+1]]``. + bitorder: str + The bit-order ("bit"/"little") of the output. Default is "big". + + Returns + ------- + y: torch.Tensor + An uint8 packed array, shape: ``(new_indptr[-1],)``. + The ``y[new_indptr[i]:new_indptr[i+1]]`` contains the packed bits ``x[indptr[i]:indptr[i+1]]``. + new_indptr: torch.Tensor + The new index pointer of the first element of each packed segment in :attr:`y`. + It's guaranteed that ``new_indptr[i+1] - new_indptr[i] == (indptr[i+1] - indptr[i] + 7) // 8``. + """ + seglen = indptr[1:] - indptr[:-1] + packed_len = (seglen + 7) // 8 + indptr_new = torch.empty(len(indptr) + 1, dtype=indptr.dtype, device=indptr.device) + indptr_new[1:] = torch.cumsum(packed_len, 0) diff --git a/python/setup.py b/python/setup.py index 33c1dd31..52c575b5 100644 --- a/python/setup.py +++ b/python/setup.py @@ -345,6 +345,7 @@ def __init__(self, *args, **kwargs) -> None: "csrc/sampling.cu", "csrc/norm.cu", "csrc/group_gemm.cu", + "csrc/quantization.cu", ] + get_instantiation_cu(), include_dirs=[ diff --git a/python/tests/test_quantization.py b/python/tests/test_quantization.py new file mode 100644 index 00000000..5918bc91 --- /dev/null +++ b/python/tests/test_quantization.py @@ -0,0 +1,53 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import numpy +import torch +import pytest +import flashinfer + +def numpy_packbits_ref(x_cpu: torch.Tensor, bitorder: str): + x_np = x_cpu.numpy() + x_packed = numpy.packbits(x_np, bitorder=bitorder) + return torch.tensor(x_packed) + + +@pytest.mark.parametrize("num_elements", [1, 10, 99, 128, 999, 5000, 131072, 999999]) +@pytest.mark.parametrize("bitorder", ["big", "little"]) +def test_packbits(num_elements, bitorder): + x_cpu = torch.rand(num_elements) < 0.5 + x_gpu = x_cpu.to(0) + x_packed_ref = numpy_packbits_ref(x_cpu, bitorder) + x_packed = flashinfer.packbits(x_gpu, bitorder) + + assert torch.equal(x_packed_ref.cpu(), x_packed.cpu()) + +@pytest.mark.parametrize("batch_size", [1, 10, 99, 128, 777, 999]) +@pytest.mark.parametrize("bitorder", ["big", "little"]) +def test_segment_packbits(batch_size, bitorder): + old_indptr = torch.cumsum(torch.arange(batch_size + 1), 0) + num_elements = old_indptr[-1].item() + x_cpu = torch.rand(num_elements) < 0.5 + x_gpu = x_cpu.to(0) + + y_gpu, new_indptr = flashinfer.segment_packbits(x_gpu, old_indptr, bitorder) + + for i in range(batch_size): + x_segment_i = x_gpu[old_indptr[i]:old_indptr[i+1]] + y_segment_i_ref = flashinfer.packbits(x_segment_i, bitorder) + assert torch.equal(y_gpu[new_indptr[i]:new_indptr[i+1]], y_segment_i_ref) + + \ No newline at end of file diff --git a/src/bench_single_prefill.cu b/src/bench_single_prefill.cu index 6028e19e..56af6af0 100644 --- a/src/bench_single_prefill.cu +++ b/src/bench_single_prefill.cu @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ +#include #include #include @@ -39,11 +40,13 @@ void bench_flashinfer_single_prefill(nvbench::state& state) { size_t kv_layout = state.get_int64("kv_layout"); bool causal = state.get_int64("causal"); bool cooperative = state.get_int64("cooperative"); + bool custom_mask = state.get_int64("custom_mask"); bool allow_fp16_qk_reduction = state.get_int64("allow_fp16_qk_reduction"); // Allocate input data: thrust::device_vector Q(qo_len * num_qo_heads * head_dim); thrust::device_vector K(kv_len * num_kv_heads * head_dim); thrust::device_vector V(kv_len * num_kv_heads * head_dim); + thrust::device_vector mask(qo_len * kv_len); thrust::device_vector O(qo_len * num_qo_heads * head_dim); thrust::device_vector tmp(8 * 1024 * 1024); @@ -54,15 +57,29 @@ void bench_flashinfer_single_prefill(nvbench::state& state) { state.exec(nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) { timer.start(); - cudaError_t status = flashinfer::SinglePrefillWithKVCache( - thrust::raw_pointer_cast(Q.data()), thrust::raw_pointer_cast(K.data()), - thrust::raw_pointer_cast(V.data()), thrust::raw_pointer_cast(O.data()), - /*tmp=*/cooperative ? thrust::raw_pointer_cast(tmp.data()) : nullptr, - /*lse=*/nullptr, num_qo_heads, num_kv_heads, qo_len, kv_len, head_dim, causal, - QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode), allow_fp16_qk_reduction, - /*maybe_sm_scale=*/std::nullopt, - /*rope_scale=*/1.f, - /*rope_theta=*/1e4, launch.get_stream()); + cudaError_t status; + if (custom_mask) { + status = flashinfer::SinglePrefillWithKVCacheCustomMask( + thrust::raw_pointer_cast(Q.data()), thrust::raw_pointer_cast(K.data()), + thrust::raw_pointer_cast(V.data()), thrust::raw_pointer_cast(mask.data()), + thrust::raw_pointer_cast(O.data()), + /*tmp=*/cooperative ? thrust::raw_pointer_cast(tmp.data()) : nullptr, + /*lse=*/nullptr, num_qo_heads, num_kv_heads, qo_len, kv_len, head_dim, + QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode), allow_fp16_qk_reduction, + /*maybe_sm_scale=*/std::nullopt, + /*rope_scale=*/1.f, + /*rope_theta=*/1e4, launch.get_stream()); + } else { + status = flashinfer::SinglePrefillWithKVCache( + thrust::raw_pointer_cast(Q.data()), thrust::raw_pointer_cast(K.data()), + thrust::raw_pointer_cast(V.data()), thrust::raw_pointer_cast(O.data()), + /*tmp=*/cooperative ? 
thrust::raw_pointer_cast(tmp.data()) : nullptr, + /*lse=*/nullptr, num_qo_heads, num_kv_heads, qo_len, kv_len, head_dim, causal, + QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode), allow_fp16_qk_reduction, + /*maybe_sm_scale=*/std::nullopt, + /*rope_scale=*/1.f, + /*rope_theta=*/1e4, launch.get_stream()); + } if (status != cudaSuccess) { state.skip("CUDA error: " + std::string(cudaGetErrorString(status))); } @@ -99,6 +116,7 @@ void bench_flashinfer_single_prefill(nvbench::state& state) { .add_int64_axis("kv_layout", {0, 1}) \ .add_int64_axis("pos_encoding_mode", {0, 1}) \ .add_int64_axis("allow_fp16_qk_reduction", {0, 1}) \ + .add_int64_axis("custom_mask", {0}) \ .add_int64_axis("cooperative", {1}) #define BENCH_FLASHINFER_APPEND_PREFILL(dtype_in, dtype_out) \ @@ -115,6 +133,7 @@ void bench_flashinfer_single_prefill(nvbench::state& state) { .add_int64_axis("kv_layout", {0, 1}) \ .add_int64_axis("pos_encoding_mode", {0, 1}) \ .add_int64_axis("allow_fp16_qk_reduction", {0, 1}) \ + .add_int64_axis("custom_mask", {0}) \ .add_int64_axis("cooperative", {0, 1}) BENCH_FLASHINFER_PREFILL(half, half); diff --git a/src/flashinfer_ops.cuh b/src/flashinfer_ops.cuh index c7cb4759..394d114f 100644 --- a/src/flashinfer_ops.cuh +++ b/src/flashinfer_ops.cuh @@ -18,10 +18,35 @@ #include #include "flashinfer/attention/logits_post_hook.cuh" +#include "flashinfer/attention/mask.cuh" #include "utils.h" namespace flashinfer { +template +cudaError_t SinglePrefillWithKVCacheCustomMask( + DTypeIn* q, DTypeIn* k, DTypeIn* v, float* custom_mask, DTypeOut* o, float* tmp, float* lse, + uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len, + uint32_t head_dim, QKVLayout kv_layout = QKVLayout::kNHD, + PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, + bool allow_fp16_qk_reduction = false, std::optional maybe_sm_scale = std::nullopt, + float rope_scale = 1.f, float rope_theta = 1e4, cudaStream_t stream = nullptr) { + const float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(head_dim))); + DISPATCH_allow_fp16_qk_reduction( + allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, + {DISPATCH_head_dim( + head_dim, HEAD_DIM, + {DISPATCH_pos_encoding_mode( + pos_encoding_mode, POS_ENCODING_MODE, {DISPATCH_kv_layout(kv_layout, KV_LAYOUT, { + return SinglePrefillWithKVCacheDispatched< + HEAD_DIM, LogitsPostHook::kNone, KV_LAYOUT, POS_ENCODING_MODE, + ALLOW_FP16_QK_REDUCTION, MaskMode::kCustom>( + q, k, v, custom_mask, o, tmp, lse, num_qo_heads, num_kv_heads, qo_len, kv_len, + sm_scale, rope_scale, rope_theta, stream); + })})})}); + return cudaSuccess; +} + /*! * \brief FlashAttention prefill CUDA function for a single request. 
* \tparam DTypeIn The data type of input From d6a8cfb637deac40f3c3bcfc05238ecf212fc3e6 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sun, 16 Jun 2024 06:15:07 +0000 Subject: [PATCH 5/7] wip --- include/flashinfer/quantization.cuh | 43 +++++++++++++---------------- python/csrc/flashinfer_ops.h | 4 +-- python/flashinfer/quantization.py | 11 +++++--- python/tests/test_quantization.py | 19 +++++++++---- 4 files changed, 41 insertions(+), 36 deletions(-) diff --git a/include/flashinfer/quantization.cuh b/include/flashinfer/quantization.cuh index 9f7a5e1f..780e2b99 100644 --- a/include/flashinfer/quantization.cuh +++ b/include/flashinfer/quantization.cuh @@ -18,6 +18,8 @@ #include #include +#include + #include "utils.cuh" namespace flashinfer { @@ -36,20 +38,13 @@ enum class BitOrder { kBig = 0U, kLittle = 1U }; template __global__ void PackBitsKernel(bool* input, uint8_t* output, int64_t num_elements) { - int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + int64_t start_offset = blockIdx.x * blockDim.x * 8, tx = threadIdx.x; uint8_t ret = 0; - uint8_t input_vec[8]; - for (uint32_t i = 0; i < 8; ++i) { - input_vec[i] = 0; - } - if ((idx + 1) * 8 <= num_elements) { - *(uint2*)input_vec = *(uint2*)(input + idx * 8); - } else { -#pragma unroll - for (uint32_t i = 0; i < 8; ++i) { - input_vec[i] = (idx * 8 + i < num_elements) ? input[idx * 8 + i] : false; - } - } + bool input_vec[8]; + typedef cub::BlockLoad BlockLoad; + __shared__ typename BlockLoad::TempStorage temp_storage; + BlockLoad(temp_storage) + .Load(input + start_offset, input_vec, num_elements - start_offset, /*default=*/0); if constexpr (BITORDER == BitOrder::kBig) { ret = (input_vec[0] << 7) | (input_vec[1] << 6) | (input_vec[2] << 5) | (input_vec[3] << 4) | @@ -58,23 +53,22 @@ __global__ void PackBitsKernel(bool* input, uint8_t* output, int64_t num_element ret = (input_vec[7] << 7) | (input_vec[6] << 6) | (input_vec[5] << 5) | (input_vec[4] << 4) | (input_vec[3] << 3) | (input_vec[2] << 2) | (input_vec[1] << 1) | input_vec[0]; } - output[idx] = ret; + if (start_offset + tx * 8 < num_elements) output[start_offset / 8 + tx] = ret; } -// NOTE(Zihao): this implementation is not efficient, but this kernel is not a bottleneck -// at the moment. We can optimize it later if needed. template __global__ void SegmentPackBitsKernel(bool* input, uint8_t* output, IdType* input_indptr, IdType* output_indptr) { int64_t bx = blockIdx.x, tx = threadIdx.x; - for (uint32_t j = tx; j < output_indptr[bx + 1] - output_indptr[bx]; j += blockDim.x) { - int64_t num_elements = input_indptr[bx + 1] - input_indptr[bx]; + bool input_vec[8]; + typedef cub::BlockLoad BlockLoad; + __shared__ typename BlockLoad::TempStorage temp_storage; + int64_t num_elements = input_indptr[bx + 1] - input_indptr[bx]; + for (uint32_t start_offset = 0; start_offset < num_elements; start_offset += 8 * blockDim.x) { uint8_t ret = 0; - uint8_t input_vec[8]; -#pragma unroll - for (uint32_t i = 0; i < 8; ++i) { - input_vec[i] = (j * 8 + i < num_elements) ? 
input[input_indptr[bx] + j * 8 + i] : false; - } + BlockLoad(temp_storage) + .Load(input + input_indptr[bx] + start_offset, input_vec, num_elements - start_offset, + /*default=*/0); if constexpr (BITORDER == BitOrder::kBig) { ret = (input_vec[0] << 7) | (input_vec[1] << 6) | (input_vec[2] << 5) | (input_vec[3] << 4) | @@ -83,7 +77,8 @@ __global__ void SegmentPackBitsKernel(bool* input, uint8_t* output, IdType* inpu ret = (input_vec[7] << 7) | (input_vec[6] << 6) | (input_vec[5] << 5) | (input_vec[4] << 4) | (input_vec[3] << 3) | (input_vec[2] << 2) | (input_vec[1] << 1) | input_vec[0]; } - output[output_indptr[bx] + j] = ret; + if (start_offset + tx * 8 < num_elements) + output[output_indptr[bx] + start_offset / 8 + tx] = ret; } } diff --git a/python/csrc/flashinfer_ops.h b/python/csrc/flashinfer_ops.h index 707f174b..2f838615 100644 --- a/python/csrc/flashinfer_ops.h +++ b/python/csrc/flashinfer_ops.h @@ -76,8 +76,8 @@ torch::Tensor rmsnorm(torch::Tensor x, torch::Tensor w, double eps); torch::Tensor packbits(torch::Tensor x, const std::string& bitorder); -torch::Tensor segment_packbits(torch::Tensor x, torch::Tensor input_indptr, torch::Tensor output_indptr, - const std::string& bitorder); +torch::Tensor segment_packbits(torch::Tensor x, torch::Tensor input_indptr, + torch::Tensor output_indptr, const std::string& bitorder); class BatchDecodeWithPagedKVCachePyTorchWrapper { public: diff --git a/python/flashinfer/quantization.py b/python/flashinfer/quantization.py index ca48133a..602c515d 100644 --- a/python/flashinfer/quantization.py +++ b/python/flashinfer/quantization.py @@ -28,9 +28,10 @@ else: raise e + def packbits(x: torch.Tensor, bitorder: str = "big"): r"""Pack the elements of a binary-valued array into bits in a uint8 array. - + See `numpy.packbits `_ for more details. Parameters @@ -47,7 +48,8 @@ def packbits(x: torch.Tensor, bitorder: str = "big"): """ return _kernels.packbits(x, bitorder) -def segment_packbits(x: torch.Tensor, indptr: torch.Tensor, bitorder: str="big"): + +def segment_packbits(x: torch.Tensor, indptr: torch.Tensor, bitorder: str = "big"): r"""Pack a batch elements of a binary-valued array into bits in a uint8 array. Parameters @@ -59,7 +61,7 @@ def segment_packbits(x: torch.Tensor, indptr: torch.Tensor, bitorder: str="big") The i-th segment in :attr:`x` is ``x[indptr[i]:indptr[i+1]]``. bitorder: str The bit-order ("bit"/"little") of the output. Default is "big". 
- + Returns ------- y: torch.Tensor @@ -71,5 +73,6 @@ def segment_packbits(x: torch.Tensor, indptr: torch.Tensor, bitorder: str="big") """ seglen = indptr[1:] - indptr[:-1] packed_len = (seglen + 7) // 8 - indptr_new = torch.empty(len(indptr) + 1, dtype=indptr.dtype, device=indptr.device) + indptr_new = torch.zeros(len(indptr), dtype=indptr.dtype, device=indptr.device) indptr_new[1:] = torch.cumsum(packed_len, 0) + return _kernels.segment_packbits(x, indptr, indptr_new, bitorder), indptr_new diff --git a/python/tests/test_quantization.py b/python/tests/test_quantization.py index 5918bc91..c727ca43 100644 --- a/python/tests/test_quantization.py +++ b/python/tests/test_quantization.py @@ -19,6 +19,7 @@ import pytest import flashinfer + def numpy_packbits_ref(x_cpu: torch.Tensor, bitorder: str): x_np = x_cpu.numpy() x_packed = numpy.packbits(x_np, bitorder=bitorder) @@ -28,6 +29,7 @@ def numpy_packbits_ref(x_cpu: torch.Tensor, bitorder: str): @pytest.mark.parametrize("num_elements", [1, 10, 99, 128, 999, 5000, 131072, 999999]) @pytest.mark.parametrize("bitorder", ["big", "little"]) def test_packbits(num_elements, bitorder): + torch.manual_seed(42) x_cpu = torch.rand(num_elements) < 0.5 x_gpu = x_cpu.to(0) x_packed_ref = numpy_packbits_ref(x_cpu, bitorder) @@ -35,19 +37,24 @@ def test_packbits(num_elements, bitorder): assert torch.equal(x_packed_ref.cpu(), x_packed.cpu()) + @pytest.mark.parametrize("batch_size", [1, 10, 99, 128, 777, 999]) @pytest.mark.parametrize("bitorder", ["big", "little"]) def test_segment_packbits(batch_size, bitorder): - old_indptr = torch.cumsum(torch.arange(batch_size + 1), 0) - num_elements = old_indptr[-1].item() + torch.manual_seed(42) + old_indptr = torch.cumsum(torch.arange(batch_size + 1), 0).to(0) + num_elements = old_indptr[-1].item() x_cpu = torch.rand(num_elements) < 0.5 x_gpu = x_cpu.to(0) y_gpu, new_indptr = flashinfer.segment_packbits(x_gpu, old_indptr, bitorder) - + for i in range(batch_size): - x_segment_i = x_gpu[old_indptr[i]:old_indptr[i+1]] + x_segment_i = x_gpu[old_indptr[i] : old_indptr[i + 1]] y_segment_i_ref = flashinfer.packbits(x_segment_i, bitorder) - assert torch.equal(y_gpu[new_indptr[i]:new_indptr[i+1]], y_segment_i_ref) + assert torch.equal(y_gpu[new_indptr[i] : new_indptr[i + 1]], y_segment_i_ref) + - \ No newline at end of file +if __name__ == "__main__": + test_packbits(999999, "big") + test_segment_packbits(77, "little") From 1f05cd7665df80068b7ce1b9ee3e3734cdf81031 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sun, 16 Jun 2024 07:49:07 +0000 Subject: [PATCH 6/7] upd --- docs/api/python/quantization.rst | 1 + docs/tutorials/kv_layout.rst | 8 + include/flashinfer/attention/prefill.cuh | 28 +-- include/flashinfer/prefill_attention_decl.cuh | 10 +- python/csrc/batch_prefill.cu | 4 +- python/csrc/flashinfer_ops.h | 23 +- python/csrc/single_prefill.cu | 15 +- python/flashinfer/prefill.py | 199 ++++++++++++------ python/flashinfer/quantization.py | 20 +- python/generate_batch_paged_prefill_inst.py | 2 +- python/generate_batch_ragged_prefill_inst.py | 2 +- python/generate_single_prefill_inst.py | 2 +- python/tests/test_batch_prefill_kernels.py | 12 +- src/bench_single_prefill.cu | 4 +- src/flashinfer_ops.cuh | 2 +- 15 files changed, 207 insertions(+), 125 deletions(-) diff --git a/docs/api/python/quantization.rst b/docs/api/python/quantization.rst index 28be0826..d284d32c 100644 --- a/docs/api/python/quantization.rst +++ b/docs/api/python/quantization.rst @@ -11,3 +11,4 @@ Quantization related kernels. 
:toctree: _generate packbits + segment_packbits diff --git a/docs/tutorials/kv_layout.rst b/docs/tutorials/kv_layout.rst index 973aa7c8..c29edcff 100644 --- a/docs/tutorials/kv_layout.rst +++ b/docs/tutorials/kv_layout.rst @@ -75,6 +75,11 @@ to store the start offset of each request's mask in the flattened mask array: `` ``mask_data`` has shape ``(qk_indptr[-1],)``, we can use ``mask_data[qk_indptr[i]:qk_indptr[i+1]]`` to slice the flattened mask of request ``i``. +To save memory, we can further packes the boolean flattened boolean mask array into a bit-packed array (1 bit per element, 8 elements +are packed together as a `uint8`) with "little" bit-order (see `numpy.packbits `_ +for more details). FlashInfer accepts both boolean mask and bit-packed mask. If boolean mask is provided, FlashInfer will pack it into bit-packed +array internally. + FlashInfer APIs ~~~~~~~~~~~~~~~ @@ -82,6 +87,9 @@ FlashInfer APIs allow user to specify ``qo_indptr``, ``kv_indptr`` and custom attention mask ``custom_mask`` in ``begin_forward`` functions, the mask data will be added to the attention score before softmax (and after softmax scaling) in the attention kernel. +:meth:`flashinfer.quantization.packbits` and :meth:`flashinfer.quantization.segment_packbits` are the utility functions +to pack boolean mask into bit-packed array. + .. _page-layout: Page Table diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index e98972c6..4b772831 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -547,7 +547,7 @@ template kv_len + q_idx - qo_len || (partition_kv && kv_idx >= chunk_end)) : kv_idx >= chunk_end); s_frag[fx][fz][reg_id] = - out_of_boundary ? DTypeQKAccum(-5e4) - : s_frag[fx][fz][reg_id] + - DTypeQKAccum((mask_mode == MaskMode::kCustom && q_idx < qo_len) - ? custom_mask[q_idx * kv_len + kv_idx] - : 0.f); + (out_of_boundary || + ((mask_mode == MaskMode::kCustom && q_idx < qo_len && + !(custom_mask[(q_idx * kv_len + kv_idx) / 8] >> ((q_idx * kv_len + kv_idx) % 8))))) + ? 
DTypeQKAccum(-5e4) + : s_frag[fx][fz][reg_id]; } } } @@ -891,7 +891,7 @@ template __global__ void SinglePrefillWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn* __restrict__ k, DTypeIn* __restrict__ v, - float* __restrict__ custom_mask, + uint8_t* __restrict__ custom_mask, DTypeOut* __restrict__ o, void* __restrict__ tmp, float* __restrict__ lse, const uint32_t qo_len, const uint32_t kv_len, const uint_fastdiv group_size, @@ -1107,7 +1107,7 @@ template paged_kv, - IdType* __restrict__ qo_indptr, float* __restrict__ custom_mask, IdType* __restrict__ qk_indptr, - IdType* __restrict__ q_offset, DTypeOut* __restrict__ o, float* __restrict__ tmp, - float* __restrict__ lse, const uint_fastdiv group_size, float sm_scale, + IdType* __restrict__ qo_indptr, uint8_t* __restrict__ custom_mask, + IdType* __restrict__ qk_indptr, IdType* __restrict__ q_offset, DTypeOut* __restrict__ o, + float* __restrict__ tmp, float* __restrict__ lse, const uint_fastdiv group_size, float sm_scale, float log2_rope_rcp_scale, float log2_rope_rcp_theta) { static_assert(sizeof(DTypeIn) == 2); static_assert(sizeof(DTypeOut) == 2); @@ -1534,7 +1534,7 @@ template cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, - float* custom_mask, DTypeOut* o, float* tmp, + uint8_t* custom_mask, DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len, float sm_scale, float rope_scale, @@ -1674,7 +1674,7 @@ template cudaError_t BatchPrefillWithRaggedKVCacheDispatched( DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, DTypeIn* k, - DTypeIn* v, IdType* kv_indptr, float* custom_mask, IdType* qk_indptr, IdType* q_offset, + DTypeIn* v, IdType* kv_indptr, uint8_t* custom_mask, IdType* qk_indptr, IdType* q_offset, IdType* k_rope_pos_offset, DTypeOut* o, float* tmp, float* lse, const uint32_t batch_size, const uint32_t num_qo_heads, const uint32_t num_qo_tiles, const uint32_t num_kv_heads, const float sm_scale, const float rope_scale, const float rope_theta, @@ -1758,7 +1758,7 @@ template cudaError_t BatchPrefillWithPagedKVCacheDispatched( DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, IdType* q_offset, - paged_kv_t paged_kv, float* custom_mask, + paged_kv_t paged_kv, uint8_t* custom_mask, IdType* qk_indptr, DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_heads, uint32_t num_qo_tiles, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) { diff --git a/include/flashinfer/prefill_attention_decl.cuh b/include/flashinfer/prefill_attention_decl.cuh index 54da43e0..47257864 100644 --- a/include/flashinfer/prefill_attention_decl.cuh +++ b/include/flashinfer/prefill_attention_decl.cuh @@ -32,7 +32,7 @@ template cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, - float* custom_mask, DTypeOut* o, float* tmp, + uint8_t* custom_mask, DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len, float sm_scale, float rope_scale, @@ -43,7 +43,7 @@ template cudaError_t BatchPrefillWithRaggedKVCacheDispatched( DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, DTypeIn* k, - DTypeIn* v, IdType* kv_indptr, float* custom_mask, IdType* qk_indptr, IdType* q_offset, + DTypeIn* v, IdType* kv_indptr, uint8_t* custom_mask, IdType* qk_indptr, IdType* q_offset, IdType* k_rope_pos_offset, DTypeOut* o, float* tmp, float* lse, uint32_t batch_size, uint32_t num_qo_tiles, uint32_t 
num_qo_heads, uint32_t num_kv_heads, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream = nullptr); @@ -54,7 +54,7 @@ template cudaError_t BatchPrefillWithPagedKVCacheDispatched( DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, IdType* q_offset, - paged_kv_t paged_kv, float* custom_mask, + paged_kv_t paged_kv, uint8_t* custom_mask, IdType* qk_indptr, DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_tiles, uint32_t num_qo_heads, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); @@ -63,7 +63,7 @@ template cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched( BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_offset, - paged_kv_t paged_kv, float* custom_mask, + paged_kv_t paged_kv, uint8_t* custom_mask, IdType* qk_indptr, DTypeOut* o, float* lse, uint32_t num_qo_heads, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) { float* tmp = nullptr; @@ -98,7 +98,7 @@ template cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched( BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, DTypeIn* k, DTypeIn* v, - IdType* kv_indptr, float* custom_mask, IdType* qk_indptr, IdType* q_offset, + IdType* kv_indptr, uint8_t* custom_mask, IdType* qk_indptr, IdType* q_offset, IdType* k_rope_pos_offset, DTypeOut* o, float* lse, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) { diff --git a/python/csrc/batch_prefill.cu b/python/csrc/batch_prefill.cu index 035de3ef..9a25a128 100644 --- a/python/csrc/batch_prefill.cu +++ b/python/csrc/batch_prefill.cu @@ -232,7 +232,7 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu handler_.get(), static_cast(q.data_ptr()), static_cast(qo_indptr.data_ptr()), /*q_offset=*/nullptr, paged_kv, - static_cast(custom_mask.data_ptr()), + static_cast(custom_mask.data_ptr()), static_cast(qk_indptr.data_ptr()), static_cast(o.data_ptr()), /*lse=*/return_lse ? 
static_cast(lse.data_ptr()) : nullptr, @@ -434,7 +434,7 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardC static_cast(qo_indptr.data_ptr()), static_cast(k.data_ptr()), static_cast(v.data_ptr()), static_cast(kv_indptr.data_ptr()), - static_cast(custom_mask.data_ptr()), + static_cast(custom_mask.data_ptr()), static_cast(qk_indptr.data_ptr()), /*q_offset=*/nullptr, /*k_rope_pos_offset=*/nullptr, static_cast(o.data_ptr()), diff --git a/python/csrc/flashinfer_ops.h b/python/csrc/flashinfer_ops.h index 2f838615..0a64b324 100644 --- a/python/csrc/flashinfer_ops.h +++ b/python/csrc/flashinfer_ops.h @@ -33,8 +33,8 @@ std::vector single_prefill_with_kv_cache( bool return_lse); std::vector single_prefill_with_kv_cache_custom_mask( - torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor custom_mask, torch::Tensor tmp, - unsigned int layout, unsigned int pos_encoding_mode, bool logits_cap, + torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor packed_custom_mask, + torch::Tensor tmp, unsigned int layout, unsigned int pos_encoding_mode, bool logits_cap, bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, float rope_theta, bool return_lse); @@ -127,9 +127,10 @@ class BatchPrefillWithPagedKVCachePyTorchWrapper { std::vector ForwardCustomMask( torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor paged_kv_data, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, - torch::Tensor paged_kv_last_page_len, torch::Tensor custom_mask, torch::Tensor qk_indptr, - unsigned int pos_encoding_mode, bool logits_cap, bool allow_fp16_qk_reduction, float sm_scale, - float rope_scale, float rope_theta, bool return_lse); + torch::Tensor paged_kv_last_page_len, torch::Tensor packed_custom_mask, + torch::Tensor qk_indptr, unsigned int pos_encoding_mode, bool logits_cap, + bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, float rope_theta, + bool return_lse); BatchPrefillWithPagedKVCachePyTorchWrapper(unsigned int layout, bool enable_cuda_graph) : kv_layout_(flashinfer::QKVLayout(layout)), handler_(std::make_shared(enable_cuda_graph)) {} @@ -152,13 +153,11 @@ class BatchPrefillWithRaggedKVCachePyTorchWrapper { unsigned int pos_encoding_mode, bool logits_cap, bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, float rope_theta, bool return_lse); - std::vector ForwardCustomMask(torch::Tensor q, torch::Tensor qo_indptr, - torch::Tensor k, torch::Tensor v, - torch::Tensor kv_indptr, torch::Tensor custom_mask, - torch::Tensor qk_indptr, - unsigned int pos_encoding_mode, bool logits_cap, - bool allow_fp16_qk_reduction, float sm_scale, - float rope_scale, float rope_theta, bool return_lse); + std::vector ForwardCustomMask( + torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k, torch::Tensor v, + torch::Tensor kv_indptr, torch::Tensor packed_custom_mask, torch::Tensor qk_indptr, + unsigned int pos_encoding_mode, bool logits_cap, bool allow_fp16_qk_reduction, float sm_scale, + float rope_scale, float rope_theta, bool return_lse); BatchPrefillWithRaggedKVCachePyTorchWrapper(unsigned int layout, bool enable_cuda_graph) : kv_layout_(flashinfer::QKVLayout(layout)), handler_(std::make_shared(enable_cuda_graph)) {} diff --git a/python/csrc/single_prefill.cu b/python/csrc/single_prefill.cu index d1ee0bad..bedeac3c 100644 --- a/python/csrc/single_prefill.cu +++ b/python/csrc/single_prefill.cu @@ -102,20 +102,23 @@ std::vector single_prefill_with_kv_cache( } std::vector single_prefill_with_kv_cache_custom_mask( - torch::Tensor q, torch::Tensor k, 
torch::Tensor v, torch::Tensor custom_mask, torch::Tensor tmp, - unsigned int layout, unsigned int pos_encoding_mode, bool logits_cap, + torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor packed_custom_mask, + torch::Tensor tmp, unsigned int layout, unsigned int pos_encoding_mode, bool logits_cap, bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, float rope_theta, bool return_lse) { CHECK_INPUT(q); CHECK_INPUT(k); CHECK_INPUT(v); - CHECK_INPUT(custom_mask); + CHECK_INPUT(packed_custom_mask); CHECK_DIM(3, q); CHECK_DIM(3, k); CHECK_DIM(3, v); - CHECK_DIM(2, custom_mask); + CHECK_DIM(1, packed_custom_mask); CHECK_SHAPE(k, v); CHECK_EQ(q.size(2), k.size(2)); + // packed_custom_mask must be uint8 + TORCH_CHECK(packed_custom_mask.scalar_type() == torch::kUInt8, + "packed_custom_mask must be uint8"); unsigned int head_dim = q.size(2); unsigned int kv_len, qo_len, num_kv_heads, num_qo_heads; QKVLayout kv_layout = static_cast(layout); @@ -130,8 +133,6 @@ std::vector single_prefill_with_kv_cache_custom_mask( num_kv_heads = k.size(0); num_qo_heads = q.size(0); } - CHECK_EQ(custom_mask.size(0), qo_len); - CHECK_EQ(custom_mask.size(1), kv_len); CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(); auto o = torch::empty_like(q, q.options()); @@ -157,7 +158,7 @@ std::vector single_prefill_with_kv_cache_custom_mask( ALLOW_FP16_QK_REDUCTION, MASK_MODE>( static_cast(q.data_ptr()), static_cast(k.data_ptr()), static_cast(v.data_ptr()), - static_cast(custom_mask.data_ptr()), + static_cast(packed_custom_mask.data_ptr()), static_cast(o.data_ptr()), static_cast(tmp.data_ptr()), /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, num_qo_heads, num_kv_heads, qo_len, kv_len, sm_scale, rope_scale, diff --git a/python/flashinfer/prefill.py b/python/flashinfer/prefill.py index cc501b50..db6b8afe 100644 --- a/python/flashinfer/prefill.py +++ b/python/flashinfer/prefill.py @@ -39,6 +39,7 @@ check_kv_layout, is_float8, ) +from .quantization import packbits, segment_packbits _cache_buf = {} @@ -58,6 +59,7 @@ def single_prefill_with_kv_cache( k: torch.Tensor, v: torch.Tensor, custom_mask: Optional[torch.Tensor] = None, + packed_custom_mask: Optional[torch.Tensor] = None, causal: bool = False, kv_layout: str = "NHD", pos_encoding_mode: str = "NONE", @@ -83,9 +85,17 @@ def single_prefill_with_kv_cache( is ``NHD``, ``[num_kv_heads, kv_len, head_dim]`` if :attr:`kv_layout` is ``HND``. custom_mask : Optional[torch.Tensor] - The custom mask tensor, shape: ``[qo_len, kv_len]``. - If provided, the custom mask will be added to the attention matrix before - softmax and after scaling, and the :attr:`causal` parameter will be ignored. + The custom boolean mask tensor, shape: ``[qo_len, kv_len]``. + The elements in the mask tensor should be either ``True`` or ``False``, + where ``False`` means the corresponding element in the attention matrix will be + masked out. + + When :attr:`custom_mask` is provided, and :attr:`packed_custom_mask` is not, the + function will pack the custom mask tensor into a 1D packed mask tensor, which introduces + additional overhead. + packed_custom_mask : Optional[torch.Tensor] + The 1D packed uint8 mask tensor, if provided, the :attr:`custom_mask` will be ignored. + The packed mask tensor is generated by :func:`flashinfer.quantization.packbits`. causal : bool Whether to apply causal mask to the attention matrix. This is only effective when :attr:`custom_mask` is not provided. 
@@ -132,18 +142,18 @@ def single_prefill_with_kv_cache( allow_fp16_qk_reduction=True) >>> o.shape torch.Size([128, 32, 128]) - >>> mask = torch.triu( - >>> torch.full((qo_len, kv_len), -float("inf"), dtype=torch.float32, device="cuda:0"), - >>> diagonal=(kv_len - qo_len + 1), + >>> mask = torch.tril( + >>> torch.full((qo_len, kv_len), True, device="cuda:0"), + >>> diagonal=(kv_len - qo_len), >>> ) >>> mask - tensor([[0., 0., 0., ..., -inf, -inf, -inf], - [0., 0., 0., ..., -inf, -inf, -inf], - [0., 0., 0., ..., -inf, -inf, -inf], - ..., - [0., 0., 0., ..., 0., -inf, -inf], - [0., 0., 0., ..., 0., 0., -inf], - [0., 0., 0., ..., 0., 0., 0.]], device='cuda:0') + tensor([[ True, True, True, ..., False, False, False], + [ True, True, True, ..., False, False, False], + [ True, True, True, ..., False, False, False], + ..., + [ True, True, True, ..., True, False, False], + [ True, True, True, ..., True, True, False], + [ True, True, True, ..., True, True, True]], device='cuda:0') >>> o_custom = flashinfer.single_prefill_with_kv_cache(q, k, v, custom_mask=mask) >>> torch.allclose(o, o_custom, rtol=1e-3, atol=1e-3) True @@ -163,12 +173,17 @@ def single_prefill_with_kv_cache( rope_scale = 1.0 if rope_theta is None: rope_theta = 1e4 - if custom_mask is not None: + if custom_mask is not None and packed_custom_mask is None: + # create packed custom mask from custom mask + packed_custom_mask = packbits( + custom_mask.contiguous().view(-1), bitorder="little" + ) + if packed_custom_mask is not None: return _kernels.single_prefill_with_kv_cache_custom_mask( q, k, v, - custom_mask, + packed_custom_mask, tmp, TensorLayout[kv_layout].value, PosEncodingMode[pos_encoding_mode].value, @@ -227,9 +242,17 @@ def single_prefill_with_kv_cache_return_lse( is ``NHD``, or ``[num_kv_heads, kv_len, head_dim]`` if :attr:`kv_layout` is ``HND``. custom_mask : Optional[torch.Tensor] - The custom_mask tensor, shape: ``[qo_len, kv_len]``. - If provided, the custom mask will be added to the attention matrix before - softmax and after scaling, and the :attr:`causal` parameter will be ignored. + The custom bool mask tensor, shape: ``[qo_len, kv_len]``. + The elements in the mask tensor should be either ``True`` or ``False``, + where ``False`` means the corresponding element in the attention matrix will be + masked out. + + When :attr:`custom_mask` is provided, and :attr:`packed_custom_mask` is not, the + function will pack the custom mask tensor into a 1D packed mask tensor, which introduces + additional overhead. + packed_custom_mask : Optional[torch.Tensor] + The 1D packed mask tensor, if provided, the :attr:`custom_mask` will be ignored. + The packed mask tensor is generated by :func:`flashinfer.quantization.packbits`. causal : bool Whether to apply causal mask to the attention matrix. This is only effective when :attr:`custom_mask` is not provided. 
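For reference, the bit layout consumed by the packed-mask kernels in this series: with "little" bit-order, flattened mask element i = q_idx * kv_len + kv_idx is stored at bit i % 8 of byte i // 8. A small host-side sanity check using only numpy (values illustrative):

import numpy as np

qo_len, kv_len = 5, 13
mask = np.random.rand(qo_len, kv_len) < 0.5              # boolean mask
packed = np.packbits(mask.reshape(-1), bitorder="little")

q_idx, kv_idx = 3, 7
i = q_idx * kv_len + kv_idx
# Recover the (q_idx, kv_idx) entry from the packed buffer, using the same
# byte/bit indexing as the packed-mask reads in prefill.cuh.
bit = (packed[i // 8] >> (i % 8)) & 1
assert bool(bit) == bool(mask[q_idx, kv_idx])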
@@ -279,18 +302,18 @@ def single_prefill_with_kv_cache_return_lse( torch.Size([128, 32, 128]) >>> S.shape torch.Size([128, 32]) - >>> mask = torch.triu( - >>> torch.full((qo_len, kv_len), -float("inf"), dtype=torch.float32, device="cuda:0"), - >>> diagonal=(kv_len - qo_len + 1), + >>> mask = torch.tril( + >>> torch.full((qo_len, kv_len), True, device="cuda:0"), + >>> diagonal=(kv_len - qo_len), >>> ) >>> mask - tensor([[0., 0., 0., ..., -inf, -inf, -inf], - [0., 0., 0., ..., -inf, -inf, -inf], - [0., 0., 0., ..., -inf, -inf, -inf], - ..., - [0., 0., 0., ..., 0., -inf, -inf], - [0., 0., 0., ..., 0., 0., -inf], - [0., 0., 0., ..., 0., 0., 0.]], device='cuda:0') + tensor([[ True, True, True, ..., False, False, False], + [ True, True, True, ..., False, False, False], + [ True, True, True, ..., False, False, False], + ..., + [ True, True, True, ..., True, False, False], + [ True, True, True, ..., True, True, False], + [ True, True, True, ..., True, True, True]], device='cuda:0') >>> V_custom, S_custom = flashinfer.single_prefill_with_kv_cache_return_lse(q, k, v, custom_mask=mask) >>> torch.allclose(V, V_custom, rtol=1e-3, atol=1e-3) True @@ -325,12 +348,17 @@ def single_prefill_with_kv_cache_return_lse( q = q.to(torch.float16) k = k.to(torch.float16) v = v.to(torch.float16) - if custom_mask is not None: + if custom_mask is not None and packed_custom_mask is None: + # convert custom mask to packed mask + packed_custom_mask = packbits( + custom_mask.contiguous().view(-1), bitorder="little" + ) + if packed_custom_mask is not None: return _kernels.single_prefill_with_kv_cache_custom_mask( q, k, v, - custom_mask, + packed_custom_mask, tmp, TensorLayout[kv_layout].value, PosEncodingMode[pos_encoding_mode].value, @@ -451,9 +479,9 @@ class BatchPrefillWithPagedKVCacheWrapper: >>> qo_len = (qo_indptr[1:] - qo_indptr[:-1]).cpu().tolist() >>> kv_len = (page_size * (paged_kv_indptr[1:] - paged_kv_indptr[:-1] - 1) + paged_kv_last_page_len).cpu().tolist() >>> for i in range(batch_size): - ... mask_i = torch.triu( - ... torch.full((qo_len[i], kv_len[i]), -float("inf"), dtype=torch.float32, device="cuda:0"), - ... diagonal=(kv_len[i] - qo_len[i] + 1), + ... mask_i = torch.tril( + ... torch.full((qo_len[i], kv_len[i]), True, device="cuda:0"), + ... diagonal=(kv_len[i] - qo_len[i]), ... ) ... mask_arr.append(mask_i) ... @@ -543,8 +571,8 @@ def __init__( custom_mask_buf : Optional[torch.Tensor] The user reserved buffer to store the custom mask tensor, should be large enough to - store the maximum possible size of the custom mask tensor during the lifetime of the - wrapper. This argument is only effective when ``enable_cuda_graph`` is set to ``True`` + store the maximum possible size of the packed custom mask tensor during the lifetime of + the wrapper. This argument is only effective when ``enable_cuda_graph`` is set to ``True`` and the custom mask will be used in attention computation. qk_indptr_buf : Optional[torch.Tensor] @@ -623,6 +651,7 @@ def begin_forward( head_dim: int, page_size: int, custom_mask: Optional[torch.Tensor] = None, + packed_custom_mask: Optional[torch.Tensor] = None, ): r"""Create auxiliary data structures for batch prefill/append attention for multiple forward calls within the same prefill/append step. @@ -647,13 +676,21 @@ def begin_forward( page_size : int The size of each page in the paged kv-cache. custom_mask : Optional[torch.Tensor] - The flattened mask tensor, shape: ``(sum(q_len[i] * k_len[i] for i in range(batch_size))``. 
- If provided, the custom mask will be added to the attention matrix before softmax - and after scaling. The mask tensor should be in the same device as the input tensors. + The flattened boolean mask tensor, shape: ``(sum(q_len[i] * k_len[i] for i in range(batch_size))``. + The elements in the mask tensor should be either ``True`` or ``False``, + where ``False`` means the corresponding element in the attention matrix will be + masked out. Please refer to the :ref:`mask layout ` for more details about flattened layout of mask tensor. + When :attr:`custom_mask` is provided, and :attr:`packed_custom_mask` is not, the + function will pack the custom mask tensor into a 1D packed mask tensor, which introduces + additional overhead. + packed_custom_mask : Optional[torch.Tensor] + The 1D packed uint8 mask tensor, if provided, the :attr:`custom_mask` will be ignored. + The packed mask tensor is generated by :func:`flashinfer.quantization.packbits`. + Notes ----- The :meth:`begin_forward` method should be called before any :meth:`forward` or @@ -666,6 +703,21 @@ def begin_forward( """ batch_size = len(qo_indptr) - 1 + if custom_mask is not None or packed_custom_mask is not None: + qk_indptr = _compute_page_qk_indptr( + qo_indptr, + paged_kv_indptr, + paged_kv_last_page_len, + page_size, + ) + if packed_custom_mask is None and custom_mask is not None: + # create packed custom mask from custom mask + packed_custom_mask, qk_indptr = segment_packbits( + custom_mask.contiguous().view(-1), + qk_indptr, + bitorder="little", + ) + if self.is_cuda_graph_enabled: if batch_size != self._fixed_batch_size: raise ValueError( @@ -685,7 +737,7 @@ def begin_forward( self._paged_kv_indices_buf[: len(paged_kv_indices)] = paged_kv_indices self._paged_kv_last_page_len_buf.copy_(paged_kv_last_page_len) - if custom_mask is not None: + if packed_custom_mask is not None: if not torch.is_tensor(self._custom_mask_buf): raise ValueError( "custom_mask_buf must be initialized with a torch.Tensor in cuda graph mode if we use custom mask in attention computation." @@ -694,29 +746,17 @@ def begin_forward( raise ValueError( "qk_indptr_buf must be initialized with a torch.Tensor in cuda graph mode if we use custom mask in attention computation." ) - self._custom_mask_buf[: len(custom_mask)] = custom_mask + self._custom_mask_buf[: len(packed_custom_mask)] = packed_custom_mask # NOTE(Zihao): qk_indptr has the same length as qo_indptr - self._qk_indptr_buf.copy_( - _compute_page_qk_indptr( - qo_indptr, - paged_kv_indptr, - paged_kv_last_page_len, - page_size, - ) - ) + self._qk_indptr_buf.copy_(qk_indptr) else: self._qo_indptr_buf = qo_indptr self._paged_kv_indptr_buf = paged_kv_indptr self._paged_kv_indices_buf = paged_kv_indices self._paged_kv_last_page_len_buf = paged_kv_last_page_len - if custom_mask is not None: - self._custom_mask = custom_mask - self._qk_indptr = _compute_page_qk_indptr( - qo_indptr, - paged_kv_indptr, - paged_kv_last_page_len, - page_size, - ) + if packed_custom_mask is not None: + self._custom_mask = packed_custom_mask + self._qk_indptr = qk_indptr self._wrapper.begin_forward( self._workspace_buffer, qo_indptr, @@ -1021,9 +1061,9 @@ class BatchPrefillWithRaggedKVCacheWrapper: >>> qo_len = (qo_indptr[1:] - qo_indptr[:-1]).cpu().tolist() >>> kv_len = (kv_indptr[1:] - kv_indptr[:-1]).cpu().tolist() >>> for i in range(batch_size): - ... mask_i = torch.triu( - ... torch.full((qo_len[i], kv_len[i]), -float("inf"), dtype=torch.float32, device="cuda:0"), - ... diagonal=(kv_len[i] - qo_len[i] + 1), + ... 
mask_i = torch.tril( + ... torch.full((qo_len[i], kv_len[i]), True, device="cuda:0"), + ... diagonal=(kv_len[i] - qo_len[i]), ... ) ... mask_arr.append(mask_i.flatten()) ... @@ -1095,7 +1135,7 @@ def __init__( custom_mask_buf : Optional[torch.Tensor] The user reserved GPU buffer to store the custom mask tensor, should be large - enough to store the maximum possible size of the custom mask tensor during the + enough to store the maximum possible size of the packed custom mask tensor during the lifetime of the wrapper. This argument is only effective when ``enable_cuda_graph`` is ``True`` and custom mask will be used in attention computation. @@ -1159,6 +1199,7 @@ def begin_forward( num_kv_heads: int, head_dim: int, custom_mask: Optional[torch.Tensor] = None, + packed_custom_mask: Optional[torch.Tensor] = None, ): r"""Create auxiliary data structures for batch prefill/append attention for multiple forward calls within the same prefill/append step. @@ -1176,13 +1217,24 @@ def begin_forward( head_dim : int The dimension of the heads. custom_mask : Optional[torch.Tensor] - The flattened mask tensor, shape: ``(sum(q_len[i] * k_len[i] for i in range(batch_size))``. - If provided, the custom mask will be added to the attention matrix before softmax - and after scaling. The mask tensor should be in the same device as the input tensors. + The flattened boolean mask tensor, shape: ``(sum(q_len[i] * k_len[i] for i in range(batch_size))``. + The elements in the mask tensor should be either ``True`` or ``False``, + where ``False`` means the corresponding element in the attention matrix will be + masked out. Please refer to the :ref:`mask layout ` for more details about flattened layout of mask tensor. + When :attr:`custom_mask` is provided, and :attr:`packed_custom_mask` is not, the + function will pack the custom mask tensor into a 1D packed mask tensor, which introduces + additional overhead. + packed_custom_mask : Optional[torch.Tensor] + The 1D packed uint8 mask tensor, if provided, the :attr:`custom_mask` will be ignored. + The packed mask tensor is generated by :func:`flashinfer.quantization.packbits`. + + If provided, the custom mask will be added to the attention matrix before softmax + and after scaling. The mask tensor should be in the same device as the input tensors. + Notes ----- The :meth:`begin_forward` method should be called before any :meth:`forward` or @@ -1198,6 +1250,15 @@ def begin_forward( raise ValueError( "The kv_indptr length should be equal to qk_indptr length." ) + if custom_mask is not None or packed_custom_mask is not None: + qk_indptr = _compute_qk_indptr(qo_indptr, kv_indptr) + if packed_custom_mask is None and custom_mask is not None: + # create packed custom mask from custom mask + packed_custom_mask, qk_indptr = segment_packbits( + custom_mask.contiguous().view(-1), + qk_indptr, + bitorder="little", + ) if self.is_cuda_graph_enabled: if batch_size != self._fixed_batch_size: @@ -1209,7 +1270,7 @@ def begin_forward( ) self._qo_indptr_buf.copy_(qo_indptr) self._kv_indptr_buf.copy_(kv_indptr) - if custom_mask is not None: + if packed_custom_mask is not None: if not torch.is_tensor(self._custom_mask_buf): raise ValueError( "custom_mask_buf must be initialized with a torch.Tensor in cuda graph mode if we use custom mask in attention computation." @@ -1218,14 +1279,14 @@ def begin_forward( raise ValueError( "qk_indptr_buf must be initialized with a torch.Tensor in cuda graph mode if we use custom mask in the attention computation." 
) - self._custom_mask_buf[: len(custom_mask)] = custom_mask - self._qk_indptr_buf.copy_(_compute_qk_indptr(qo_indptr, kv_indptr)) + self._custom_mask_buf[: len(packed_custom_mask)] = packed_custom_mask + self._qk_indptr_buf.copy_(qk_indptr) else: self._qo_indptr_buf = qo_indptr self._kv_indptr_buf = kv_indptr - if custom_mask is not None: - self._custom_mask_buf = custom_mask - self._qk_indptr_buf = _compute_qk_indptr(qo_indptr, kv_indptr) + if packed_custom_mask is not None: + self._custom_mask_buf = packed_custom_mask + self._qk_indptr_buf = qk_indptr self._wrapper.begin_forward( self._workspace_buffer, qo_indptr, diff --git a/python/flashinfer/quantization.py b/python/flashinfer/quantization.py index 602c515d..28ee827f 100644 --- a/python/flashinfer/quantization.py +++ b/python/flashinfer/quantization.py @@ -32,7 +32,7 @@ def packbits(x: torch.Tensor, bitorder: str = "big"): r"""Pack the elements of a binary-valued array into bits in a uint8 array. - See `numpy.packbits `_ for more details. + The semantics of this function is the same as `numpy.packbits `_. Parameters ---------- @@ -45,19 +45,25 @@ def packbits(x: torch.Tensor, bitorder: str = "big"): ------- y: torch.Tensor An uint8 packed array, shape ``((x.size(0) + 7) / 8),)``. + + See Also + -------- + segment_packbits """ return _kernels.packbits(x, bitorder) def segment_packbits(x: torch.Tensor, indptr: torch.Tensor, bitorder: str = "big"): - r"""Pack a batch elements of a binary-valued array into bits in a uint8 array. + r"""Pack a batch of binary-valued segments into bits in a uint8 array. + + For each segment, the semantics of this function is the same as `numpy.packbits `_. Parameters ---------- x: torch.Tensor - The 1D binary-valued array to pack. + The 1D binary-valued array to pack, shape ``(indptr[-1],)``. indptr: torch.Tensor - The index pointer of the first element of each segment in :attr:`x`. + The index pointer of each segment in :attr:`x`, shape ``(batch_size + 1,)``. The i-th segment in :attr:`x` is ``x[indptr[i]:indptr[i+1]]``. bitorder: str The bit-order ("bit"/"little") of the output. Default is "big". @@ -68,8 +74,12 @@ def segment_packbits(x: torch.Tensor, indptr: torch.Tensor, bitorder: str = "big An uint8 packed array, shape: ``(new_indptr[-1],)``. The ``y[new_indptr[i]:new_indptr[i+1]]`` contains the packed bits ``x[indptr[i]:indptr[i+1]]``. new_indptr: torch.Tensor - The new index pointer of the first element of each packed segment in :attr:`y`. + The new index pointer of each packed segment in :attr:`y`, shape ``(batch_size + 1,)``. It's guaranteed that ``new_indptr[i+1] - new_indptr[i] == (indptr[i+1] - indptr[i] + 7) // 8``. 
+ + See Also + -------- + packbits """ seglen = indptr[1:] - indptr[:-1] packed_len = (seglen + 7) // 8 diff --git a/python/generate_batch_paged_prefill_inst.py b/python/generate_batch_paged_prefill_inst.py index fceaf03a..988ba1b3 100644 --- a/python/generate_batch_paged_prefill_inst.py +++ b/python/generate_batch_paged_prefill_inst.py @@ -46,7 +46,7 @@ def get_cu_file_str( {dtype_in}* q, {idtype}* request_indices, {idtype}* tile_indices, {idtype}* qo_indptr, {idtype}* q_offset, paged_kv_t paged_kv, - float* custom_mask, {idtype}* qk_indptr, + uint8_t* custom_mask, {idtype}* qk_indptr, {dtype_out}* o, float* tmp, float* lse, uint32_t num_qo_tiles, uint32_t num_qo_heads, float sm_scale, float rope_scale, diff --git a/python/generate_batch_ragged_prefill_inst.py b/python/generate_batch_ragged_prefill_inst.py index abc7f142..b83a39bd 100644 --- a/python/generate_batch_ragged_prefill_inst.py +++ b/python/generate_batch_ragged_prefill_inst.py @@ -44,7 +44,7 @@ def get_cu_file_str( """template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{num_frags_x}, {head_dim}, {logits_hook}, {kv_layout}, {pos_encoding_mode}, {allow_fp16_qk_reduction}, {mask_mode}, {dtype_in}, {dtype_out}, {idtype}>( {dtype_in}* q, {idtype}* request_indices, {idtype}* tile_indices, {idtype}* qo_indptr, {dtype_in}* k, {dtype_in}* v, {idtype}* kv_indptr, - float* custom_mask, {idtype}* qk_indptr, + uint8_t* custom_mask, {idtype}* qk_indptr, {idtype}* q_offset, {idtype}* k_rope_pos_offset, {dtype_out}* o, float* tmp, float* lse, uint32_t batch_size, uint32_t num_qo_tiles, diff --git a/python/generate_single_prefill_inst.py b/python/generate_single_prefill_inst.py index dac301ee..c518c4dd 100644 --- a/python/generate_single_prefill_inst.py +++ b/python/generate_single_prefill_inst.py @@ -42,7 +42,7 @@ def get_cu_file_str( namespace flashinfer {{ template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {logits_hook}, {kv_layout}, {pos_encoding_mode}, {allow_fp16_qk_reduction}, {mask_mode}, {dtype_in}, {dtype_out}>( - {dtype_in}* q, {dtype_in}* k, {dtype_in}* v, float* custom_mask, {dtype_out}* o, + {dtype_in}* q, {dtype_in}* k, {dtype_in}* v, uint8_t* custom_mask, {dtype_out}* o, float* tmp, float* lse, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); diff --git a/python/tests/test_batch_prefill_kernels.py b/python/tests/test_batch_prefill_kernels.py index c15861fc..66e0dc01 100644 --- a/python/tests/test_batch_prefill_kernels.py +++ b/python/tests/test_batch_prefill_kernels.py @@ -233,9 +233,9 @@ def test_batch_prefill_with_paged_kv_cache_custom_mask( workspace_buffer, kv_layout ) custom_mask = ( - torch.triu( - torch.full((batch_size, qo_len, kv_len), -5e4, dtype=torch.float32), - diagonal=(kv_len - qo_len + 1), + torch.tril( + torch.full((batch_size, qo_len, kv_len), True), + diagonal=(kv_len - qo_len), ) .reshape(-1) .to(0) @@ -357,9 +357,9 @@ def test_batch_prefill_with_ragged_kv_cache_custom_mask( ) custom_mask = ( - torch.triu( - torch.full((batch_size, qo_len, kv_len), -5e4, dtype=torch.float32), - diagonal=(kv_len - qo_len + 1), + torch.tril( + torch.full((batch_size, qo_len, kv_len), True), + diagonal=(kv_len - qo_len), ) .reshape(-1) .to(0) diff --git a/src/bench_single_prefill.cu b/src/bench_single_prefill.cu index 56af6af0..c8f40b98 100644 --- a/src/bench_single_prefill.cu +++ b/src/bench_single_prefill.cu @@ -23,6 +23,8 @@ using flashinfer::PosEncodingMode; using flashinfer::QKVLayout; +inline 
uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; } + template void bench_flashinfer_single_prefill(nvbench::state& state) { size_t kv_len = state.get_int64("kv_len"); @@ -46,7 +48,7 @@ void bench_flashinfer_single_prefill(nvbench::state& state) { thrust::device_vector Q(qo_len * num_qo_heads * head_dim); thrust::device_vector K(kv_len * num_kv_heads * head_dim); thrust::device_vector V(kv_len * num_kv_heads * head_dim); - thrust::device_vector mask(qo_len * kv_len); + thrust::device_vector mask(ceil_div(qo_len * kv_len, 8)); thrust::device_vector O(qo_len * num_qo_heads * head_dim); thrust::device_vector tmp(8 * 1024 * 1024); diff --git a/src/flashinfer_ops.cuh b/src/flashinfer_ops.cuh index 394d114f..144ef98e 100644 --- a/src/flashinfer_ops.cuh +++ b/src/flashinfer_ops.cuh @@ -25,7 +25,7 @@ namespace flashinfer { template cudaError_t SinglePrefillWithKVCacheCustomMask( - DTypeIn* q, DTypeIn* k, DTypeIn* v, float* custom_mask, DTypeOut* o, float* tmp, float* lse, + DTypeIn* q, DTypeIn* k, DTypeIn* v, uint8_t* custom_mask, DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len, uint32_t head_dim, QKVLayout kv_layout = QKVLayout::kNHD, PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, From 93f700f8085c008c69a0cc0ca3f9c67e2b1811ff Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sun, 16 Jun 2024 07:50:37 +0000 Subject: [PATCH 7/7] trailing empty lines --- include/flashinfer/quantization.cuh | 2 +- python/csrc/quantization.cu | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/include/flashinfer/quantization.cuh b/include/flashinfer/quantization.cuh index 780e2b99..ce61af6b 100644 --- a/include/flashinfer/quantization.cuh +++ b/include/flashinfer/quantization.cuh @@ -111,4 +111,4 @@ cudaError_t SegmentPackBits(bool* input, uint8_t* output, IdType* input_indptr, } // namespace quantization } // namespace flashinfer -#endif // FLASHINFER_QUANTIZATION_CUH_ \ No newline at end of file +#endif // FLASHINFER_QUANTIZATION_CUH_ diff --git a/python/csrc/quantization.cu b/python/csrc/quantization.cu index 90c0b1fa..2eb4813d 100644 --- a/python/csrc/quantization.cu +++ b/python/csrc/quantization.cu @@ -61,4 +61,4 @@ torch::Tensor segment_packbits(torch::Tensor x, torch::Tensor input_indptr, bitorder == "big" ? quantization::BitOrder::kBig : quantization::BitOrder::kLittle, c10::cuda::getCurrentCUDAStream()); return y; -} \ No newline at end of file +}
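
Note on the packed custom-mask layout introduced above: the patches replace the additive float mask with a boolean mask that is bit-packed into ``uint8`` (8 mask elements per byte, "little" bit order), either packed eagerly by the caller via ``flashinfer.quantization.packbits`` / ``segment_packbits`` or packed on the fly inside ``single_prefill_with_kv_cache*`` and the wrappers' ``begin_forward`` when only ``custom_mask`` is given. The sketch below is a minimal CPU-only NumPy emulation of that layout, not FlashInfer's CUDA kernel; ``segment_packbits_ref``, the example ``(qo_len, kv_len)`` pairs, and the variable names are illustrative assumptions. The invariant it demonstrates, ``new_indptr[i+1] - new_indptr[i] == (indptr[i+1] - indptr[i] + 7) // 8``, is the one stated in the updated ``segment_packbits`` docstring.

```python
# CPU-only reference emulation (assumption: not FlashInfer's GPU implementation) of the
# per-segment bit-packing used for packed custom masks, "little" bit order.
import numpy as np


def segment_packbits_ref(x: np.ndarray, indptr: np.ndarray, bitorder: str = "little"):
    """Pack each boolean segment x[indptr[i]:indptr[i+1]] into uint8, 8 bits per byte."""
    assert x.dtype == np.bool_ and indptr[-1] == x.size
    seglen = indptr[1:] - indptr[:-1]
    new_indptr = np.concatenate([[0], np.cumsum((seglen + 7) // 8)])
    out = np.empty(new_indptr[-1], dtype=np.uint8)
    for i in range(len(seglen)):
        seg = x[indptr[i]:indptr[i + 1]]
        out[new_indptr[i]:new_indptr[i + 1]] = np.packbits(seg, bitorder=bitorder)
    return out, new_indptr


# Two illustrative requests with (qo_len, kv_len) = (3, 5) and (2, 4); build boolean
# causal masks the same way the updated docstrings do (tril with diagonal = kv_len - qo_len),
# flatten them per request, and record the element-wise qk_indptr.
masks, qk_indptr = [], [0]
for qo_len, kv_len in [(3, 5), (2, 4)]:
    row = np.arange(qo_len)[:, None]
    col = np.arange(kv_len)[None, :]
    mask_i = col <= row + (kv_len - qo_len)   # True = attend, False = masked out
    masks.append(mask_i.reshape(-1))
    qk_indptr.append(qk_indptr[-1] + qo_len * kv_len)

packed, new_indptr = segment_packbits_ref(
    np.concatenate(masks), np.asarray(qk_indptr), bitorder="little"
)
print(packed.dtype, packed.shape)   # uint8, one byte per 8 mask bits within each segment
print(new_indptr)                   # byte-granularity indptr, e.g. [0, 2, 3] here
```

On the wrapper path this mirrors how ``begin_forward`` now keeps the indptr returned by ``segment_packbits`` (byte granularity) instead of recomputing ``_compute_page_qk_indptr`` / ``_compute_qk_indptr``, and it is also why the benchmark above sizes the mask buffer as ``ceil_div(qo_len * kv_len, 8)`` bytes; passing ``packed_custom_mask`` directly skips this packing overhead entirely.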