diff --git a/flashinfer-aot/csrc_aot/batch_decode.cu b/flashinfer-aot/csrc_aot/batch_decode.cu index 48e0e6bd..118652e0 100644 --- a/flashinfer-aot/csrc_aot/batch_decode.cu +++ b/flashinfer-aot/csrc_aot/batch_decode.cu @@ -85,13 +85,13 @@ std::vector BatchDecodeWithPagedKVCachePlan( return plan_info.ToVector(); } -std::vector BatchDecodeWithPagedKVCacheRun( +torch::Tensor BatchDecodeWithPagedKVCacheRun( torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, std::vector plan_info_vec, torch::Tensor q, torch::Tensor paged_k_cache, torch::Tensor paged_v_cache, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len, std::optional alibi_slopes, unsigned int kv_layout_code, int window_left, float logits_soft_cap, float sm_scale, - float rope_scale, float rope_theta, bool return_lse) { + float rope_scale, float rope_theta, std::optional maybe_lse) { DecodePlanInfo plan_info; plan_info.FromVector(plan_info_vec); QKVLayout kv_layout = static_cast(kv_layout_code); @@ -111,9 +111,11 @@ std::vector BatchDecodeWithPagedKVCacheRun( cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); torch::Tensor o = torch::empty_like(q); - torch::Tensor lse; - if (return_lse) { - lse = torch::empty({batch_size, num_qo_heads}, q.options().dtype((torch::kFloat32))); + if (maybe_lse) { + const auto& lse = *maybe_lse; + TORCH_CHECK(lse.size(0) == batch_size, lse.size(0), q.size(0)); + TORCH_CHECK(lse.size(1) == num_qo_heads, lse.size(1), q.size(1)); + TORCH_CHECK(lse.dtype() == torch::kFloat32, "lse must be float32"); } TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative"); @@ -160,7 +162,7 @@ std::vector BatchDecodeWithPagedKVCacheRun( static_cast(paged_kv_last_page_len.data_ptr())); ParamsT params(static_cast(q.data_ptr()), /*q_offset=*/nullptr, paged_kv, static_cast(o.data_ptr()), - /*lse=*/(return_lse ? static_cast(lse.data_ptr()) : nullptr), + /*lse=*/(maybe_lse ? 
static_cast<float*>(maybe_lse->data_ptr()) : nullptr),
                    /*alibi_slopes=*/nullptr, num_qo_heads, q_stride_n, q_stride_h, window_left,
                    logits_soft_cap, sm_scale, rope_scale, rope_theta);
@@ -194,9 +196,5 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCacheRun(
     });
   });
 
-  if (return_lse) {
-    return {o, lse};
-  } else {
-    return {o};
-  }
+  return o;
 }
diff --git a/flashinfer-aot/csrc_aot/flashinfer_ops_decode.cu b/flashinfer-aot/csrc_aot/flashinfer_ops_decode.cu
index 6483d983..fe665a1d 100644
--- a/flashinfer-aot/csrc_aot/flashinfer_ops_decode.cu
+++ b/flashinfer-aot/csrc_aot/flashinfer_ops_decode.cu
@@ -29,13 +29,13 @@ std::vector<int64_t> BatchDecodeWithPagedKVCachePlan(
     torch::Tensor indptr, unsigned int batch_size, unsigned int num_qo_heads,
     unsigned int num_kv_heads, unsigned int page_size, bool enable_cuda_graph);
 
-std::vector<torch::Tensor> BatchDecodeWithPagedKVCacheRun(
+torch::Tensor BatchDecodeWithPagedKVCacheRun(
     torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
     std::vector<int64_t> plan_info_vec, torch::Tensor q, torch::Tensor paged_k_cache,
     torch::Tensor paged_v_cache, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
     torch::Tensor paged_kv_last_page_len, std::optional<torch::Tensor> alibi_slopes,
     unsigned int kv_layout_code, int window_left, float logits_soft_cap, float sm_scale,
-    float rope_scale, float rope_theta, bool return_lse);
+    float rope_scale, float rope_theta, std::optional<torch::Tensor> maybe_lse);
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("single_decode_with_kv_cache", &single_decode_with_kv_cache,
diff --git a/flashinfer-aot/csrc_aot/flashinfer_ops_prefill.cu b/flashinfer-aot/csrc_aot/flashinfer_ops_prefill.cu
index 9ef91d8b..2f353d02 100644
--- a/flashinfer-aot/csrc_aot/flashinfer_ops_prefill.cu
+++ b/flashinfer-aot/csrc_aot/flashinfer_ops_prefill.cu
@@ -15,11 +15,12 @@
  */
 #include <torch/extension.h>
 
-std::vector<torch::Tensor> single_prefill_with_kv_cache(
+torch::Tensor single_prefill_with_kv_cache(
     unsigned int mask_mode_code, torch::Tensor q, torch::Tensor k, torch::Tensor v,
     std::optional<torch::Tensor> maybe_packed_custom_mask, torch::Tensor tmp,
     std::optional<torch::Tensor> maybe_alibi_slopes, unsigned int layout, int32_t window_left,
-    float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, bool return_lse);
+    float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta,
+    std::optional<torch::Tensor> maybe_lse);
 
 std::vector<int64_t> BatchPrefillWithKVCachePlan(
     unsigned int head_dim, torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
@@ -27,16 +28,16 @@ std::vector<int64_t> BatchPrefillWithKVCachePlan(
     torch::Tensor kv_indptr, unsigned int batch_size, unsigned int num_qo_heads,
     unsigned int num_kv_heads, unsigned int page_size, bool enable_cuda_graph);
 
-std::vector<torch::Tensor> BatchPrefillWithRaggedKVCacheRun(
+torch::Tensor BatchPrefillWithRaggedKVCacheRun(
     unsigned int mask_mode_code, torch::Tensor float_workspace_buffer,
     torch::Tensor int_workspace_buffer, std::vector<int64_t> plan_info_vec, torch::Tensor q,
     torch::Tensor k, torch::Tensor v, std::optional<torch::Tensor> maybe_custom_mask,
     std::optional<torch::Tensor> maybe_alibi_slopes, torch::Tensor qo_indptr, torch::Tensor kv_indptr,
     std::optional<torch::Tensor> maybe_qk_indptr, unsigned int layout, int32_t window_left,
     float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta,
-    bool return_lse);
+    std::optional<torch::Tensor> maybe_lse);
 
-std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
+torch::Tensor BatchPrefillWithPagedKVCacheRun(
     unsigned int mask_mode_code, torch::Tensor float_workspace_buffer,
     torch::Tensor int_workspace_buffer, std::vector<int64_t> plan_info_vec, torch::Tensor q,
     torch::Tensor paged_k_cache, torch::Tensor paged_v_cache,
@@ 
-44,7 +45,7 @@ std::vector BatchPrefillWithPagedKVCacheRun( torch::Tensor qo_indptr, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len, std::optional maybe_qk_indptr, unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale, - float rope_scale, float rope_theta, bool return_lse); + float rope_scale, float rope_theta, std::optional maybe_lse); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("single_prefill_with_kv_cache", &single_prefill_with_kv_cache, diff --git a/flashinfer-aot/csrc_aot/single_prefill.cu b/flashinfer-aot/csrc_aot/single_prefill.cu index c406ce95..4c944af7 100644 --- a/flashinfer-aot/csrc_aot/single_prefill.cu +++ b/flashinfer-aot/csrc_aot/single_prefill.cu @@ -32,12 +32,12 @@ cudaError_t SinglePrefillWithKVCacheDispatched(typename AttentionVariant::Params } // namespace flashinfer -std::vector single_prefill_with_kv_cache( +torch::Tensor single_prefill_with_kv_cache( unsigned int mask_mode_code, torch::Tensor q, torch::Tensor k, torch::Tensor v, std::optional maybe_packed_custom_mask, torch::Tensor tmp, std::optional maybe_alibi_slopes, unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, - bool return_lse) { + std::optional maybe_lse) { auto device = q.device(); unsigned int head_dim = q.size(2); unsigned int kv_len, qo_len, num_kv_heads, num_qo_heads; @@ -58,9 +58,11 @@ std::vector single_prefill_with_kv_cache( } cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); auto o = torch::empty_like(q, q.options()); - torch::Tensor lse = torch::empty({0}); - if (return_lse) { - lse = torch::empty({qo_len, num_qo_heads}, q.options().dtype(torch::kFloat32)); + if (maybe_lse) { + const auto& lse = *maybe_lse; + TORCH_CHECK(lse.size(0) == qo_len, lse.size(0), q.size(0)); + TORCH_CHECK(lse.size(1) == num_qo_heads, lse.size(1), q.size(1)); + TORCH_CHECK(lse.dtype() == torch::kFloat32, "lse must be float32"); } constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone; @@ -90,7 +92,7 @@ std::vector single_prefill_with_kv_cache( ? static_cast(maybe_packed_custom_mask->data_ptr()) : nullptr, static_cast(o.data_ptr()), - /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, + /*lse=*/(maybe_lse ? 
static_cast(maybe_lse->data_ptr()) : nullptr), /*alibi_slopes=*/nullptr, num_qo_heads, num_kv_heads, qo_len, kv_len, q_stride_n, q_stride_h, kv_stride_n, kv_stride_h, head_dim, window_left, logits_soft_cap, sm_scale, rope_scale, rope_theta); @@ -109,9 +111,5 @@ std::vector single_prefill_with_kv_cache( }); }); - if (return_lse) { - return {o, lse}; - } else { - return {o}; - } + return o; } diff --git a/python/csrc/flashinfer_sampling_ops.cu b/python/csrc/flashinfer_sampling_ops.cu index 0ab59fc9..37f9daa6 100644 --- a/python/csrc/flashinfer_sampling_ops.cu +++ b/python/csrc/flashinfer_sampling_ops.cu @@ -47,10 +47,10 @@ torch::Tensor top_k_renorm_probs(torch::Tensor probs, std::optional maybe_top_k_arr, unsigned int top_k_val); -std::vector chain_speculative_sampling( +torch::Tensor chain_speculative_sampling( torch::Tensor draft_probs, torch::Tensor draft_token_ids, torch::Tensor uniform_samples, - torch::Tensor target_probs, std::optional maybe_output_accepted_token_num, - std::optional maybe_output_emitted_token_num, bool deterministic); + torch::Tensor target_probs, torch::Tensor output_accepted_token_num, + torch::Tensor output_emitted_token_num, bool deterministic); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("sampling_from_probs", &sampling_from_probs, "Sample from probabilities"); diff --git a/python/csrc/sampling.cu b/python/csrc/sampling.cu index db4c0a5c..c43e5f49 100644 --- a/python/csrc/sampling.cu +++ b/python/csrc/sampling.cu @@ -314,10 +314,10 @@ torch::Tensor top_k_mask_logits(torch::Tensor logits, std::optional chain_speculative_sampling( +torch::Tensor chain_speculative_sampling( torch::Tensor draft_probs, torch::Tensor draft_token_ids, torch::Tensor uniform_samples, - torch::Tensor target_probs, std::optional maybe_output_accepted_token_num, - std::optional maybe_output_emitted_token_num, bool deterministic) { + torch::Tensor target_probs, torch::Tensor output_accepted_token_num, + torch::Tensor output_emitted_token_num, bool deterministic) { CHECK_INPUT(draft_probs); CHECK_INPUT(draft_token_ids); CHECK_INPUT(uniform_samples); @@ -339,6 +339,8 @@ std::vector chain_speculative_sampling( CHECK_EQ(num_speculate_tokens + 1, uniform_samples.size(1)); CHECK_EQ(num_speculate_tokens + 1, target_probs.size(1)); CHECK_EQ(vocab_size, target_probs.size(2)); + CHECK_EQ(batch_size, output_accepted_token_num.size(0)); + CHECK_EQ(batch_size, output_emitted_token_num.size(0)); draft_probs = draft_probs.to(torch::kFloat32); draft_token_ids = draft_token_ids.to(torch::kInt32); @@ -349,18 +351,6 @@ std::vector chain_speculative_sampling( auto output_token_ids = torch::empty({batch_size, num_speculate_tokens + 1}, torch::dtype(torch::kInt32).device(device)); - bool has_output_accepted_token_num = maybe_output_accepted_token_num.has_value(); - bool has_output_emitted_token_num = maybe_output_emitted_token_num.has_value(); - auto output_accepted_token_num = maybe_output_accepted_token_num.value_or( - torch::zeros({batch_size}, torch::dtype(torch::kInt32).device(device))); - auto output_emitted_token_num = maybe_output_emitted_token_num.value_or( - torch::zeros({batch_size}, torch::dtype(torch::kInt32).device(device))); - if (has_output_accepted_token_num) { - CHECK_EQ(has_output_emitted_token_num, true); - CHECK_EQ(batch_size, output_accepted_token_num.size(0)); - CHECK_EQ(batch_size, output_emitted_token_num.size(0)); - } - cudaError_t status = sampling::ChainSpeculativeSampling( static_cast(draft_probs.data_ptr()), static_cast(draft_token_ids.data_ptr()), 
static_cast(uniform_samples.data_ptr()), static_cast(target_probs.data_ptr()), @@ -372,5 +362,5 @@ std::vector chain_speculative_sampling( TORCH_CHECK(status == cudaSuccess, "ChainSpeculativeSampling failed with error code " + std::string(cudaGetErrorString(status))); - return {output_token_ids, output_accepted_token_num, output_emitted_token_num}; + return output_token_ids; } diff --git a/python/flashinfer/activation.py b/python/flashinfer/activation.py index 0b5c18e5..5dea2dd6 100644 --- a/python/flashinfer/activation.py +++ b/python/flashinfer/activation.py @@ -14,16 +14,17 @@ limitations under the License. """ -from typing import Optional +from types import SimpleNamespace + +import torch + from .jit import ( - load_cuda_ops, FLASHINFER_GEN_SRC_DIR, gen_act_and_mul_cu, has_prebuilt_ops, + load_cuda_ops, ) - -import torch - +from .utils import register_custom_op, register_fake_op silu_def_cu_str = r""" __device__ __forceinline__ float silu(const float& val) { @@ -73,15 +74,31 @@ def get_act_and_mul_module(act_func_name: str): if has_prebuilt_ops: from . import _kernels - _jit_modules[act_func_name] = _kernels + module = _kernels else: - _jit_modules[act_func_name] = compile_act_and_mul_module( + module = compile_act_and_mul_module( act_func_name, act_func_def_str[act_func_name] ) + + # torch library for act_and_mul + fname = f"{act_func_name}_and_mul" + fn = getattr(module, fname) + + @register_custom_op(f"flashinfer::{fname}", mutates_args=("out",)) + def _act_and_mul(out: torch.Tensor, input: torch.Tensor) -> None: + fn(out, input) + + @register_fake_op(f"flashinfer::{fname}") + def _fake_act_and_mul(out: torch.Tensor, input: torch.Tensor) -> None: + pass + + # Register the module + _jit_modules[act_func_name] = SimpleNamespace(**{fname: _act_and_mul}) + return _jit_modules[act_func_name] -def _check_shape(input: torch.Tensor, output: torch.Tensor): +def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None: assert input.ndim == output.ndim, f"{input.ndim} != {output.ndim}" assert ( input.shape[:-1] == output.shape[:-1] diff --git a/python/flashinfer/cascade.py b/python/flashinfer/cascade.py index 7c10e5cb..5b29bddc 100644 --- a/python/flashinfer/cascade.py +++ b/python/flashinfer/cascade.py @@ -14,11 +14,19 @@ limitations under the License. """ -import math -from typing import Optional, Tuple, List -from .jit import load_cuda_ops, FLASHINFER_CSRC_DIR, has_prebuilt_ops +from typing import List, Optional, Tuple + import torch +from .decode import ( + BatchDecodeWithPagedKVCacheWrapper, +) +from .jit import FLASHINFER_CSRC_DIR, has_prebuilt_ops, load_cuda_ops +from .prefill import ( + BatchPrefillWithPagedKVCacheWrapper, + single_prefill_with_kv_cache, +) +from .utils import register_custom_op, register_fake_op _cascade_module = None @@ -41,15 +49,7 @@ def get_cascade_module(): return _cascade_module -from .decode import ( - BatchDecodeWithPagedKVCacheWrapper, -) -from .prefill import ( - single_prefill_with_kv_cache_return_lse, - BatchPrefillWithPagedKVCacheWrapper, -) - - +@register_custom_op("flashinfer::merge_state", mutates_args=()) def merge_state( v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -76,10 +76,10 @@ def merge_state( ------- V : torch.Tensor The merged attention output (equivalent to attention with merged KV-segment - ``[A: B]``), shape: ``[batch_size, num_heads, head_dim]``. + ``[A: B]``), shape: ``[seq_len, num_heads, head_dim]``. 
S : torch.Tensor The logsumexp value from the merged KV-segment ``[A: B]``, shape: - ``[batch_size, num_heads]``. + ``[seq_len, num_heads]``. Example ------- @@ -101,6 +101,16 @@ def merge_state( return get_cascade_module().merge_state(v_a, s_a, v_b, s_b) +@register_fake_op("flashinfer::merge_state") +def _fake_merge_state( + v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + v = torch.empty_like(v_a) + s = torch.empty_like(s_a) + return v, s + + +@register_custom_op("flashinfer::merge_state_in_place", mutates_args=("v", "s")) def merge_state_in_place( v: torch.Tensor, s: torch.Tensor, @@ -147,6 +157,18 @@ def merge_state_in_place( get_cascade_module().merge_state_in_place(v, s, v_other, s_other, mask) +@register_fake_op("flashinfer::merge_state_in_place") +def _fake_merge_state_in_place( + v: torch.Tensor, + s: torch.Tensor, + v_other: torch.Tensor, + s_other: torch.Tensor, + mask: Optional[torch.Tensor] = None, +) -> None: + pass + + +@register_custom_op("flashinfer::merge_states", mutates_args=()) def merge_states(v: torch.Tensor, s: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: r"""Merge multiple attention states (v, s). @@ -187,6 +209,15 @@ def merge_states(v: torch.Tensor, s: torch.Tensor) -> Tuple[torch.Tensor, torch. return get_cascade_module().merge_states(v, s) +@register_fake_op("flashinfer::merge_states") +def _fake_merge_states( + v: torch.Tensor, s: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + v = torch.empty_like(v) + s = torch.empty_like(s) + return v, s + + class MultiLevelCascadeAttentionWrapper: r"""Attention wrapper for memory efficient multi-level cascade inference, this API assumes all levels KV-Cache are stored in a unified paged table. @@ -442,12 +473,13 @@ def run( :attr:`kv_layout` is ``HND``. Where ``paged_kv_cache[:, 0]`` is the key-cache and ``paged_kv_cache[:, 1]`` is the value-cache. """ - out, lse = self._batch_prefill_wrappers[-1].run_return_lse( + out, lse = self._batch_prefill_wrappers[-1].run( q, paged_kv_cache, + return_lse=True, ) for wrapper in self._batch_prefill_wrappers[:-1]: - out_i, lse_i = wrapper.run_return_lse(q, paged_kv_cache) + out_i, lse_i = wrapper.run(q, paged_kv_cache, return_lse=True) merge_state_in_place(out, lse, out_i, lse_i) return out @@ -670,7 +702,7 @@ def forward( V : torch.Tensor The attention output, shape: ``[batch_size, num_heads, head_dim]`` """ - V_shared, S_shared = single_prefill_with_kv_cache_return_lse( + V_shared, S_shared = single_prefill_with_kv_cache( q, k_shared, v_shared, @@ -680,6 +712,7 @@ def forward( sm_scale=self._batch_decode_wrapper._sm_scale, rope_scale=self._batch_decode_wrapper._rope_scale, rope_theta=self._batch_decode_wrapper._rope_theta, + return_lse=True, ) V_unique, S_unique = self._batch_decode_wrapper.forward_return_lse( q, @@ -943,7 +976,7 @@ def forward( -------- MultiLevelCascadeAttentionWrapper """ - V_shared, S_shared = single_prefill_with_kv_cache_return_lse( + V_shared, S_shared = single_prefill_with_kv_cache( q, k_shared, v_shared, @@ -954,6 +987,7 @@ def forward( sm_scale=sm_scale, rope_scale=rope_scale, rope_theta=rope_theta, + return_lse=True, ) V_unique, S_unique = self._batch_prefill_wrapper.forward_return_lse( q, diff --git a/python/flashinfer/decode.py b/python/flashinfer/decode.py index aea27471..643416dd 100644 --- a/python/flashinfer/decode.py +++ b/python/flashinfer/decode.py @@ -14,36 +14,37 @@ limitations under the License. 
""" +import functools import math -from typing import Optional, Union, Dict, Tuple, Any from types import SimpleNamespace +from typing import Any, List, Optional, Tuple, Union + import torch -import functools from .jit import ( - load_cuda_ops, - FLASHINFER_GEN_SRC_DIR, - gen_single_decode_cu, - get_single_decode_uri, gen_batch_decode_cu, + gen_single_decode_cu, get_batch_decode_uri, + get_single_decode_uri, has_prebuilt_ops, + load_cuda_ops, prebuilt_ops_uri, ) -from .prefill import get_single_prefill_module, get_batch_prefill_module - +from .prefill import get_batch_prefill_module, get_single_prefill_module from .utils import ( + MaskMode, PosEncodingMode, TensorLayout, - MaskMode, - canonicalize_torch_dtype, - _check_pos_encoding_mode, - _check_kv_layout, - _unpack_paged_kv_cache, _check_cached_qkv_data_type, - _get_cache_buf, + _check_kv_layout, + _check_pos_encoding_mode, _get_cache_alibi_slopes_buf, + _get_cache_buf, _get_range_buf, + _unpack_paged_kv_cache, + canonicalize_torch_dtype, + register_custom_op, + register_fake_op, ) @@ -78,21 +79,70 @@ def compile_batch_decode_module( def get_single_decode_module(*args): global _single_decode_modules if args not in _single_decode_modules: - if has_prebuilt_ops and get_single_decode_uri(*args) in prebuilt_ops_uri: + uri = get_single_decode_uri(*args) + if has_prebuilt_ops and uri in prebuilt_ops_uri: from . import _decode_kernels - _single_decode_modules[args] = SimpleNamespace( - run=_decode_kernels.single_decode_with_kv_cache, - ) + run_func = _decode_kernels.single_decode_with_kv_cache else: - _single_decode_modules[args] = compile_single_decode_module(*args) + run_func = compile_single_decode_module(*args).run + + # torch library for single_decode_with_kv_cache + + @register_custom_op(f"flashinfer::{uri}_run", mutates_args=("tmp",)) + def run_single_decode( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tmp: torch.Tensor, + alibi_slopes: Optional[torch.Tensor], + kv_layout_code: int, + window_left: int, + logits_soft_cap: float, + sm_scale: float, + rope_scale: float, + rope_theta: float, + ) -> torch.Tensor: + return run_func( + q, + k, + v, + tmp, + alibi_slopes, + kv_layout_code, + window_left, + logits_soft_cap, + sm_scale, + rope_scale, + rope_theta, + ) + + @register_fake_op(f"flashinfer::{uri}_run") + def _fake_run_single_decode( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tmp: torch.Tensor, + alibi_slopes: Optional[torch.Tensor], + kv_layout_code: int, + window_left: int, + logits_soft_cap: float, + sm_scale: float, + rope_scale: float, + rope_theta: float, + ) -> torch.Tensor: + return torch.empty_like(q) + + # Register the module. + _single_decode_modules[args] = SimpleNamespace(run=run_single_decode) return _single_decode_modules[args] def get_batch_decode_module(*args): global _batch_decode_modules if args not in _batch_decode_modules: - if has_prebuilt_ops and get_batch_decode_uri(*args) in prebuilt_ops_uri: + uri = get_batch_decode_uri(*args) + if has_prebuilt_ops and uri in prebuilt_ops_uri: from . 
import _decode_kernels # NOTE(Zihao): we should avoid hard-coded index like this, refactor it later @@ -100,20 +150,102 @@ def get_batch_decode_module(*args): dtype_kv = args[1] head_dim = args[4] use_logits_cap = args[7] - plan_func = lambda *plan_args: _decode_kernels.batch_decode_with_paged_kv_cache_plan( - use_logits_cap, - head_dim, - torch.empty(0, dtype=dtype_q), - torch.empty(0, dtype=dtype_kv), - *plan_args, + plan_func = ( + lambda *plan_args: _decode_kernels.batch_decode_with_paged_kv_cache_plan( + use_logits_cap, + head_dim, + torch.empty(0, dtype=dtype_q), + torch.empty(0, dtype=dtype_kv), + *plan_args, + ) ) run_func = _decode_kernels.batch_decode_with_paged_kv_cache_run - _batch_decode_modules[args] = SimpleNamespace( - plan=plan_func, - run=run_func, - ) else: - _batch_decode_modules[args] = compile_batch_decode_module(*args) + mod = compile_batch_decode_module(*args) + plan_func = mod.plan + run_func = mod.run + + # torch library for batch_decode_with_paged_kv_cache_run + + @register_custom_op( + f"flashinfer::{uri}_run", + mutates_args=( + "float_workspace_buffer", + "int_workspace_buffer", + "paged_k_cache", + "paged_v_cache", + "maybe_lse", + ), + ) + def run_batch_decode( + float_workspace_buffer: torch.Tensor, + int_workspace_buffer: torch.Tensor, + plan_info_vec: List[int], + q: torch.Tensor, + paged_k_cache: Optional[torch.Tensor], + paged_v_cache: Optional[torch.Tensor], + paged_kv_indptr: torch.Tensor, + paged_kv_indices: torch.Tensor, + paged_kv_last_page_len: torch.Tensor, + alibi_slopes: Optional[torch.Tensor], + kv_layout_code: int, + window_left: int, + logits_soft_cap: float, + sm_scale: float, + rope_scale: float, + rope_theta: float, + maybe_lse: Optional[torch.Tensor], + ) -> torch.Tensor: + return run_func( + float_workspace_buffer, + int_workspace_buffer, + plan_info_vec, + q, + paged_k_cache, + paged_v_cache, + paged_kv_indptr, + paged_kv_indices, + paged_kv_last_page_len, + alibi_slopes, + kv_layout_code, + window_left, + logits_soft_cap, + sm_scale, + rope_scale, + rope_theta, + maybe_lse, + ) + + @register_fake_op(f"flashinfer::{uri}_run") + def _fake_run_batch_decode( + float_workspace_buffer: torch.Tensor, + int_workspace_buffer: torch.Tensor, + plan_info_vec: List[int], + q: torch.Tensor, + paged_k_cache: Optional[torch.Tensor], + paged_v_cache: Optional[torch.Tensor], + paged_kv_indptr: torch.Tensor, + paged_kv_indices: torch.Tensor, + paged_kv_last_page_len: torch.Tensor, + alibi_slopes: Optional[torch.Tensor], + kv_layout_code: int, + window_left: int, + logits_soft_cap: float, + sm_scale: float, + rope_scale: float, + rope_theta: float, + maybe_lse: Optional[torch.Tensor], + ) -> torch.Tensor: + return torch.empty_like(q) + + # Register the module. + # + # Note that plan is not part of model logic. It should not be included in + # Cuda Graph or torch.compile. So, we don't provide a torch library for plan. + _batch_decode_modules[args] = SimpleNamespace( + plan=plan_func, + run=run_batch_decode, + ) return _batch_decode_modules[args] @@ -556,6 +688,8 @@ def plan( The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads`` is not equal to ``num_kv_heads``, the function will use `grouped query attention `_. + + The :meth:`plan` method cannot be used in Cuda Graph or in ``torch.compile``. 
""" batch_size = len(last_page_len) if logits_soft_cap is None: @@ -752,6 +886,12 @@ def run( if rope_theta is None: rope_theta = 1e4 + lse = None + if return_lse: + lse = torch.empty( + (q.size(0), q.size(1)), dtype=torch.float32, device=q.device + ) + if self.use_tensor_cores: out = self._cached_module.paged_run( self._float_workspace_buffer, @@ -773,7 +913,7 @@ def run( sm_scale, rope_scale, rope_theta, - return_lse, + lse, ) else: out = self._cached_module.run( @@ -793,12 +933,12 @@ def run( sm_scale, rope_scale, rope_theta, - return_lse, + lse, ) if v_scale is not None: - out[0] *= v_scale + out *= v_scale - return out if return_lse else out[0] + return (out, lse) if return_lse else out def forward_return_lse( self, @@ -821,8 +961,13 @@ def forward_return_lse( self._sm_scale = sm_scale self._rope_scale = rope_scale self._rope_theta = rope_theta - return self.run_return_lse( - q, paged_kv_cache, q_scale=q_scale, k_scale=k_scale, v_scale=v_scale + return self.run( + q, + paged_kv_cache, + q_scale=q_scale, + k_scale=k_scale, + v_scale=v_scale, + return_lse=True, ) run_return_lse = functools.partialmethod(run, return_lse=True) diff --git a/python/flashinfer/gemm.py b/python/flashinfer/gemm.py index 5b938765..4c88ac2e 100644 --- a/python/flashinfer/gemm.py +++ b/python/flashinfer/gemm.py @@ -14,14 +14,18 @@ limitations under the License. """ +from types import SimpleNamespace from typing import Optional import torch -from .utils import get_indptr, get_compute_capability -from .jit import load_cuda_ops, FLASHINFER_CSRC_DIR, has_prebuilt_ops -from typing import Optional - +from .jit import FLASHINFER_CSRC_DIR, has_prebuilt_ops, load_cuda_ops +from .utils import ( + get_compute_capability, + get_indptr, + register_custom_op, + register_fake_op, +) _gemm_module = None _gemm_module_sm90 = None @@ -33,9 +37,9 @@ def get_gemm_module(): if has_prebuilt_ops: from . 
import _kernels - _gemm_module = _kernels + module = _kernels else: - _gemm_module = load_cuda_ops( + module = load_cuda_ops( "gemm", [ FLASHINFER_CSRC_DIR / "bmm_fp8.cu", @@ -43,6 +47,69 @@ def get_gemm_module(): FLASHINFER_CSRC_DIR / "flashinfer_gemm_ops.cu", ], ) + + # torch library for bmm_fp8 + + @register_custom_op("flashinfer::bmm_fp8", mutates_args=("D",)) + def bmm_fp8( + A: torch.Tensor, + B: torch.Tensor, + D: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + ) -> None: + module.bmm_fp8(A, B, D, A_scale, B_scale) + + @register_fake_op("flashinfer::bmm_fp8") + def _fake_bmm_fp8( + A: torch.Tensor, + B: torch.Tensor, + D: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + ) -> None: + pass + + # torch library for cutlass_segment_gemm + + @register_custom_op("flashinfer::cutlass_segment_gemm", mutates_args=()) + def cutlass_segment_gemm( + workspace_buffer: torch.Tensor, + seg_indptr: torch.Tensor, + weight_indices: torch.Tensor, + x: torch.Tensor, + weights: torch.Tensor, + batch_size: int, + weight_column_major: bool, + ) -> torch.Tensor: + return module.cutlass_segment_gemm( + workspace_buffer, + seg_indptr, + weight_indices, + x, + weights, + batch_size, + weight_column_major, + ) + + @register_fake_op("flashinfer::cutlass_segment_gemm") + def _fake_cutlass_segment_gemm( + workspace_buffer: torch.Tensor, + seg_indptr: torch.Tensor, + weight_indices: torch.Tensor, + x: torch.Tensor, + weights: torch.Tensor, + batch_size: int, + weight_column_major: bool, + ) -> torch.Tensor: + return torch.empty_like(x) + + # Register the module + _gemm_module = SimpleNamespace( + bmm_fp8=bmm_fp8, + cutlass_segment_gemm=cutlass_segment_gemm, + ) + return _gemm_module @@ -53,9 +120,9 @@ def get_gemm_sm90_module(): if has_prebuilt_ops: from . 
import _kernels_sm90 - _gemm_module_sm90 = _kernels_sm90 + module = _kernels_sm90 else: - _gemm_module_sm90 = load_cuda_ops( + module = load_cuda_ops( "gemm_sm90", [ FLASHINFER_CSRC_DIR / "group_gemm_sm90.cu", @@ -63,6 +130,49 @@ def get_gemm_sm90_module(): ], extra_cuda_cflags=["-gencode", "arch=compute_90a,code=sm_90a"], ) + + # torch library for cutlass_segment_gemm_sm90 + + @register_custom_op("flashinfer::cutlass_segment_gemm_sm90", mutates_args=()) + def cutlass_segment_gemm_sm90( + workspace_buffer: torch.Tensor, + int_workspace_buffer: torch.Tensor, + seg_indptr: torch.Tensor, + weight_indices: torch.Tensor, + x: torch.Tensor, + weights: torch.Tensor, + batch_size: int, + weight_column_major: bool, + ) -> torch.Tensor: + return module.cutlass_segment_gemm_sm90( + workspace_buffer, + int_workspace_buffer, + seg_indptr, + weight_indices, + x, + weights, + batch_size, + weight_column_major, + ) + + @register_fake_op("flashinfer::cutlass_segment_gemm_sm90") + def _fake_cutlass_segment_gemm_sm90( + workspace_buffer: torch.Tensor, + int_workspace_buffer: torch.Tensor, + seg_indptr: torch.Tensor, + weight_indices: torch.Tensor, + x: torch.Tensor, + weights: torch.Tensor, + batch_size: int, + weight_column_major: bool, + ) -> torch.Tensor: + return torch.empty_like(x) + + # Register the module + _gemm_module_sm90 = SimpleNamespace( + cutlass_segment_gemm_sm90=cutlass_segment_gemm_sm90, + ) + return _gemm_module_sm90 diff --git a/python/flashinfer/jit/batch_decode_templ.py b/python/flashinfer/jit/batch_decode_templ.py index 349b3e95..ecbadbba 100644 --- a/python/flashinfer/jit/batch_decode_templ.py +++ b/python/flashinfer/jit/batch_decode_templ.py @@ -62,7 +62,7 @@ return plan_info.ToVector(); } -std::vector BatchDecodeWithPagedKVCacheRun( +torch::Tensor BatchDecodeWithPagedKVCacheRun( torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, std::vector plan_info_vec, @@ -73,7 +73,8 @@ torch::Tensor paged_kv_last_page_len, std::optional alibi_slopes, unsigned int kv_layout_code, int window_left, - float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, bool return_lse) { + float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, + std::optional maybe_lse) { DecodePlanInfo plan_info; plan_info.FromVector(plan_info_vec); QKVLayout kv_layout = static_cast(kv_layout_code); @@ -91,9 +92,11 @@ cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); torch::Tensor o = torch::empty_like(q); - torch::Tensor lse; - if (return_lse) { - lse = torch::empty({batch_size, num_qo_heads}, q.options().dtype((torch::kFloat32))); + if (maybe_lse) { + const auto& lse = *maybe_lse; + TORCH_CHECK(lse.size(0) == q.size(0), lse.size(0), q.size(0)); + TORCH_CHECK(lse.size(1) == q.size(1), lse.size(1), q.size(1)); + TORCH_CHECK(lse.dtype() == torch::kFloat32, "lse must be float32"); } TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative"); @@ -122,7 +125,7 @@ ParamsT params( static_cast<{{ dtype_q }}*>(q.data_ptr()), /*q_offset=*/nullptr, paged_kv, static_cast<{{ dtype_o }}*>(o.data_ptr()), - /*lse=*/(return_lse ? static_cast(lse.data_ptr()) : nullptr), + /*lse=*/(maybe_lse ? 
static_cast(maybe_lse->data_ptr()) : nullptr), {% if use_alibi == "true" %}static_cast(alibi_slopes->data_ptr()){% else %}nullptr{% endif %}, num_qo_heads, q_stride_n, q_stride_h, window_left, logits_soft_cap, sm_scale, rope_scale, rope_theta); @@ -147,11 +150,7 @@ TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ", cudaGetErrorString(status)); - if (return_lse) { - return {o, lse}; - } else { - return {o}; - } + return o; } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { diff --git a/python/flashinfer/jit/batch_prefill_templ.py b/python/flashinfer/jit/batch_prefill_templ.py index f0cdf89c..d32e7b1d 100644 --- a/python/flashinfer/jit/batch_prefill_templ.py +++ b/python/flashinfer/jit/batch_prefill_templ.py @@ -67,7 +67,7 @@ return plan_info.ToVector(); } -std::vector BatchPrefillWithRaggedKVCacheRun( +torch::Tensor BatchPrefillWithRaggedKVCacheRun( torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, std::vector plan_info_vec, torch::Tensor q, torch::Tensor k, torch::Tensor v, @@ -76,7 +76,7 @@ torch::Tensor qo_indptr, torch::Tensor kv_indptr, std::optional maybe_qk_indptr, unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale, - float rope_scale, float rope_theta, bool return_lse) { + float rope_scale, float rope_theta, std::optional maybe_lse) { PrefillPlanInfo plan_info; plan_info.FromVector(plan_info_vec); QKVLayout kv_layout = static_cast(layout); @@ -96,10 +96,11 @@ auto device = float_workspace_buffer.device(); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); auto o = torch::empty_like(q, q.options()); - int64_t nnz_qo = q.size(0); - torch::Tensor lse = torch::empty({0}); - if (return_lse) { - lse = torch::empty({nnz_qo, num_qo_heads}, q.options().dtype(torch::kFloat32)); + if (maybe_lse) { + const auto& lse = *maybe_lse; + TORCH_CHECK(lse.size(0) == q.size(0), lse.size(0), q.size(0)); + TORCH_CHECK(lse.size(1) == q.size(1), lse.size(1), q.size(1)); + TORCH_CHECK(lse.dtype() == torch::kFloat32, "lse must be float32"); } void* float_buffer_ptr = float_workspace_buffer.data_ptr(); @@ -114,7 +115,7 @@ {% if mask_mode == "MaskMode::kCustom" %}static_cast<{{ dtype_idx }}*>(maybe_qk_indptr->data_ptr()){% else %}nullptr{% endif %}, /*q_offset=*/nullptr, /*k_rope_pos_offset=*/nullptr, static_cast<{{ dtype_o }}*>(o.data_ptr()), - /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, + /*lse=*/(maybe_lse ? 
static_cast(maybe_lse->data_ptr()) : nullptr), {% if use_alibi == "true" %}static_cast(maybe_alibi_slopes->data_ptr()){% else %}nullptr{% endif %}, num_qo_heads, num_kv_heads, q_stride_n, q_stride_h, kv_stride_n, kv_stride_h, window_left, logits_soft_cap, sm_scale, rope_scale, rope_theta); @@ -148,14 +149,10 @@ TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithRaggedKVCache failed with error ", cudaGetErrorString(status)); - if (return_lse) { - return {o, lse}; - } else { - return {o}; - } + return o; } -std::vector BatchPrefillWithPagedKVCacheRun( +torch::Tensor BatchPrefillWithPagedKVCacheRun( torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, std::vector plan_info_vec, torch::Tensor q, @@ -169,7 +166,7 @@ torch::Tensor paged_kv_last_page_len, std::optional maybe_qk_indptr, unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale, - float rope_scale, float rope_theta, bool return_lse) { + float rope_scale, float rope_theta, std::optional maybe_lse) { PrefillPlanInfo plan_info; plan_info.FromVector(plan_info_vec); QKVLayout kv_layout = static_cast(layout); @@ -187,10 +184,11 @@ cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); auto o = torch::empty_like(q, q.options()); - int64_t nnz_qo = q.size(0); - torch::Tensor lse = torch::empty({0}); - if (return_lse) { - lse = torch::empty({nnz_qo, num_qo_heads}, q.options().dtype(torch::kFloat32)); + if (maybe_lse) { + const auto& lse = *maybe_lse; + TORCH_CHECK(lse.size(0) == q.size(0), lse.size(0), q.size(0)); + TORCH_CHECK(lse.size(1) == q.size(1), lse.size(1), q.size(1)); + TORCH_CHECK(lse.dtype() == torch::kFloat32, "lse must be float32"); } void* float_buffer_ptr = static_cast(float_workspace_buffer.data_ptr()); @@ -222,7 +220,7 @@ {% if mask_mode == "MaskMode::kCustom" %}static_cast<{{ dtype_idx }}*>(maybe_qk_indptr->data_ptr()){% else %}nullptr{% endif %}, /*q_offset=*/nullptr, static_cast<{{ dtype_o }}*>(o.data_ptr()), - /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, + /*lse=*/(maybe_lse ? 
static_cast(maybe_lse->data_ptr()) : nullptr), {% if use_alibi == "true" %}static_cast(maybe_alibi_slopes->data_ptr()){% else %}nullptr{% endif %}, num_qo_heads, q_stride_n, q_stride_h, window_left, logits_soft_cap, sm_scale, rope_scale, rope_theta); @@ -255,11 +253,7 @@ TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error ", cudaGetErrorString(status)); - if (return_lse) { - return {o, lse}; - } else { - return {o}; - } + return o; } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { diff --git a/python/flashinfer/jit/single_prefill_templ.py b/python/flashinfer/jit/single_prefill_templ.py index 0715934b..12e39b5f 100644 --- a/python/flashinfer/jit/single_prefill_templ.py +++ b/python/flashinfer/jit/single_prefill_templ.py @@ -154,10 +154,10 @@ using ParamsT = SinglePrefillParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}>; using AttentionVariant = ComposedAttention; -std::vector single_prefill_with_kv_cache( +torch::Tensor single_prefill_with_kv_cache( torch::Tensor q, torch::Tensor k, torch::Tensor v, std::optional maybe_packed_custom_mask, torch::Tensor tmp, std::optional maybe_alibi_slopes, unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale, - float rope_scale, float rope_theta, bool return_lse) { + float rope_scale, float rope_theta, std::optional maybe_lse) { auto device = q.device(); unsigned int head_dim = q.size(2); unsigned int kv_len, qo_len, num_kv_heads, num_qo_heads; @@ -178,9 +178,11 @@ } cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); auto o = torch::empty_like(q, q.options()); - torch::Tensor lse = torch::empty({0}); - if (return_lse) { - lse = torch::empty({qo_len, num_qo_heads}, q.options().dtype(torch::kFloat32)); + if (maybe_lse) { + const auto& lse = *maybe_lse; + TORCH_CHECK(lse.size(0) == q.size(0), lse.size(0), q.size(0)); + TORCH_CHECK(lse.size(1) == q.size(1), lse.size(1), q.size(1)); + TORCH_CHECK(lse.dtype() == torch::kFloat32, "lse must be float32"); } ParamsT params( @@ -188,7 +190,7 @@ static_cast<{{ dtype_kv }}*>(v.data_ptr()), {% if mask_mode == "MaskMode::kCustom" %}static_cast(maybe_packed_custom_mask->data_ptr()){% else %}nullptr{% endif %}, static_cast<{{ dtype_o }}*>(o.data_ptr()), - /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, + /*lse=*/(maybe_lse ? static_cast(maybe_lse->data_ptr()) : nullptr), {% if use_alibi == "true" %}static_cast(maybe_alibi_slopes->data_ptr()){% else %}nullptr{% endif %}, num_qo_heads, num_kv_heads, qo_len, kv_len, q_stride_n, q_stride_h, kv_stride_n, kv_stride_h, head_dim, window_left, logits_soft_cap, sm_scale, @@ -201,11 +203,7 @@ "SinglePrefillWithKVCache kernel launch failed, error: " + std::string(cudaGetErrorString(status))); - if (return_lse) { - return {o, lse}; - } else { - return {o}; - } + return o; } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { diff --git a/python/flashinfer/page.py b/python/flashinfer/page.py index 0a649f8b..cc3cbb93 100644 --- a/python/flashinfer/page.py +++ b/python/flashinfer/page.py @@ -14,10 +14,18 @@ limitations under the License. 
""" +from typing import Optional + import torch -from .utils import TensorLayout, _check_kv_layout, _unpack_paged_kv_cache -from .jit import load_cuda_ops, FLASHINFER_CSRC_DIR, has_prebuilt_ops +from .jit import FLASHINFER_CSRC_DIR, has_prebuilt_ops, load_cuda_ops +from .utils import ( + TensorLayout, + _check_kv_layout, + _unpack_paged_kv_cache, + register_custom_op, + register_fake_op, +) _page_module = None @@ -40,6 +48,49 @@ def get_page_module(): return _page_module +@register_custom_op( + "flashinfer::append_paged_kv_cache", + mutates_args=("paged_k_cache", "paged_v_cache"), +) +def _append_paged_kv_cache_kernel( + append_key: torch.Tensor, + append_value: torch.Tensor, + append_indptr: torch.Tensor, + paged_k_cache: Optional[torch.Tensor], + paged_v_cache: Optional[torch.Tensor], + kv_indices: torch.Tensor, + kv_indptr: torch.Tensor, + kv_last_page_len: torch.Tensor, + layout: int, +) -> None: + get_page_module().append_paged_kv_cache( + append_key, + append_value, + append_indptr, + paged_k_cache, + paged_v_cache, + kv_indices, + kv_indptr, + kv_last_page_len, + layout, + ) + + +@register_fake_op("flashinfer::append_paged_kv_cache") +def _fake_append_paged_kv_cache_kernel( + append_key: torch.Tensor, + append_value: torch.Tensor, + append_indptr: torch.Tensor, + paged_k_cache: Optional[torch.Tensor], + paged_v_cache: Optional[torch.Tensor], + kv_indices: torch.Tensor, + kv_indptr: torch.Tensor, + kv_last_page_len: torch.Tensor, + layout: int, +) -> None: + pass + + def append_paged_kv_cache( append_key: torch.Tensor, append_value: torch.Tensor, @@ -135,7 +186,7 @@ def append_paged_kv_cache( incorporated appended k/v. """ _check_kv_layout(kv_layout) - get_page_module().append_paged_kv_cache( + _append_paged_kv_cache_kernel( append_key, append_value, append_indptr, diff --git a/python/flashinfer/prefill.py b/python/flashinfer/prefill.py index 90da5fdf..6c927665 100644 --- a/python/flashinfer/prefill.py +++ b/python/flashinfer/prefill.py @@ -14,38 +14,39 @@ limitations under the License. """ +import functools +import logging import math -from typing import Optional, Dict, Tuple, Union, Any from types import SimpleNamespace -import functools +from typing import Any, List, Optional, Tuple, Union + import torch -import logging from .jit import ( - load_cuda_ops, - FLASHINFER_GEN_SRC_DIR, - gen_single_prefill_cu, - get_single_prefill_uri, gen_batch_prefill_cu, + gen_single_prefill_cu, get_batch_prefill_uri, + get_single_prefill_uri, has_prebuilt_ops, + load_cuda_ops, prebuilt_ops_uri, ) +from .quantization import packbits, segment_packbits from .utils import ( - PosEncodingMode, MaskMode, + PosEncodingMode, TensorLayout, - _check_pos_encoding_mode, - _check_kv_layout, _check_cached_qkv_data_type, + _check_kv_layout, + _check_pos_encoding_mode, + _get_cache_alibi_slopes_buf, + _get_cache_buf, _unpack_paged_kv_cache, - is_float8, canonicalize_torch_dtype, - _get_cache_buf, - _get_cache_alibi_slopes_buf, + is_float8, + register_custom_op, + register_fake_op, ) -from .quantization import packbits, segment_packbits - if has_prebuilt_ops: from . 
import _prefill_kernels @@ -82,25 +83,80 @@ def compile_batch_prefill_module( def get_single_prefill_module(*args): global _single_prefill_modules if args not in _single_prefill_modules: - if has_prebuilt_ops and get_single_prefill_uri(*args) in prebuilt_ops_uri: + uri = get_single_prefill_uri(*args) + if has_prebuilt_ops and uri in prebuilt_ops_uri: # NOTE(Zihao): we should avoid hard-coded index like this, refactor it later mask_mode = args[5] run_func = lambda *run_args: _prefill_kernels.single_prefill_with_kv_cache( mask_mode, *run_args, ) - _single_prefill_modules[args] = SimpleNamespace( - run=run_func, - ) else: - _single_prefill_modules[args] = compile_single_prefill_module(*args) + run_func = compile_single_prefill_module(*args).run + + # torch library for single_prefill_with_kv_cache + + @register_custom_op(f"flashinfer::{uri}_run", mutates_args=("tmp", "maybe_lse")) + def run_single_prefill( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + maybe_packed_custom_mask: Optional[torch.Tensor], + tmp: torch.Tensor, + maybe_alibi_slopes: Optional[torch.Tensor], + layout: int, + window_left: int, + logits_soft_cap: float, + sm_scale: float, + rope_scale: float, + rope_theta: float, + maybe_lse: Optional[torch.Tensor], + ) -> torch.Tensor: + return run_func( + q, + k, + v, + maybe_packed_custom_mask, + tmp, + maybe_alibi_slopes, + layout, + window_left, + logits_soft_cap, + sm_scale, + rope_scale, + rope_theta, + maybe_lse, + ) + + @register_fake_op(f"flashinfer::{uri}_run") + def _fake_run_single_prefill( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + maybe_packed_custom_mask: Optional[torch.Tensor], + tmp: torch.Tensor, + maybe_alibi_slopes: Optional[torch.Tensor], + layout: int, + window_left: int, + logits_soft_cap: float, + sm_scale: float, + rope_scale: float, + rope_theta: float, + maybe_lse: Optional[torch.Tensor], + ) -> torch.Tensor: + return torch.empty_like(q) + + # Register the module + _single_prefill_modules[args] = SimpleNamespace(run=run_single_prefill) + return _single_prefill_modules[args] def get_batch_prefill_module(*args): global _batch_prefill_modules if args not in _batch_prefill_modules: - if has_prebuilt_ops and get_batch_prefill_uri(*args) in prebuilt_ops_uri: + uri = get_batch_prefill_uri(*args) + if has_prebuilt_ops and uri in prebuilt_ops_uri: # NOTE(Zihao): we should avoid hard-coded index like this, refactor it later head_dim = args[4] plan_func = ( @@ -110,21 +166,189 @@ def get_batch_prefill_module(*args): ) ) mask_mode = args[6] - ragged_run_func = lambda *run_args: _prefill_kernels.batch_prefill_with_ragged_kv_cache_run( - mask_mode, - *run_args, - ) - paged_run_func = lambda *run_args: _prefill_kernels.batch_prefill_with_paged_kv_cache_run( - mask_mode, - *run_args, + ragged_run_func = ( + lambda *run_args: _prefill_kernels.batch_prefill_with_ragged_kv_cache_run( + mask_mode, + *run_args, + ) ) - _batch_prefill_modules[args] = SimpleNamespace( - plan=plan_func, - ragged_run=ragged_run_func, - paged_run=paged_run_func, + paged_run_func = ( + lambda *run_args: _prefill_kernels.batch_prefill_with_paged_kv_cache_run( + mask_mode, + *run_args, + ) ) else: - _batch_prefill_modules[args] = compile_batch_prefill_module(*args) + module = compile_batch_prefill_module(*args) + plan_func = module.plan + ragged_run_func = module.ragged_run + paged_run_func = module.paged_run + + # torch library for ragged_run + + @register_custom_op( + f"flashinfer::{uri}_ragged_run", + mutates_args=( + "float_workspace_buffer", + "int_workspace_buffer", + 
"maybe_lse", + ), + ) + def ragged_run( + float_workspace_buffer: torch.Tensor, + int_workspace_buffer: torch.Tensor, + plan_info_vec: List[int], + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + maybe_custom_mask: Optional[torch.Tensor], + maybe_alibi_slopes: Optional[torch.Tensor], + qo_indptr: torch.Tensor, + kv_indptr: torch.Tensor, + maybe_qk_indptr: Optional[torch.Tensor], + layout: int, + window_left: int, + logits_soft_cap: float, + sm_scale: float, + rope_scale: float, + rope_theta: float, + maybe_lse: Optional[torch.Tensor], + ) -> torch.Tensor: + return ragged_run_func( + float_workspace_buffer, + int_workspace_buffer, + plan_info_vec, + q, + k, + v, + maybe_custom_mask, + maybe_alibi_slopes, + qo_indptr, + kv_indptr, + maybe_qk_indptr, + layout, + window_left, + logits_soft_cap, + sm_scale, + rope_scale, + rope_theta, + maybe_lse, + ) + + @register_fake_op(f"flashinfer::{uri}_ragged_run") + def _fake_ragged_run( + float_workspace_buffer: torch.Tensor, + int_workspace_buffer: torch.Tensor, + plan_info_vec: List[int], + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + maybe_custom_mask: Optional[torch.Tensor], + maybe_alibi_slopes: Optional[torch.Tensor], + qo_indptr: torch.Tensor, + kv_indptr: torch.Tensor, + maybe_qk_indptr: Optional[torch.Tensor], + layout: int, + window_left: int, + logits_soft_cap: float, + sm_scale: float, + rope_scale: float, + rope_theta: float, + maybe_lse: Optional[torch.Tensor], + ) -> torch.Tensor: + return torch.empty_like(q) + + # torch library for paged_run + + @register_custom_op( + f"flashinfer::{get_batch_prefill_uri(*args)}_paged_run", + mutates_args=( + "float_workspace_buffer", + "int_workspace_buffer", + "paged_k_cache", + "paged_v_cache", + "maybe_lse", + ), + ) + def paged_run( + float_workspace_buffer: torch.Tensor, + int_workspace_buffer: torch.Tensor, + plan_info_vec: List[int], + q: torch.Tensor, + paged_k_cache: torch.Tensor, + paged_v_cache: torch.Tensor, + maybe_custom_mask: Optional[torch.Tensor], + maybe_alibi_slopes: Optional[torch.Tensor], + qo_indptr: torch.Tensor, + paged_kv_indptr: torch.Tensor, + paged_kv_indices: torch.Tensor, + paged_kv_last_page_len: torch.Tensor, + maybe_qk_indptr: Optional[torch.Tensor], + layout: int, + window_left: int, + logits_soft_cap: float, + sm_scale: float, + rope_scale: float, + rope_theta: float, + maybe_lse: Optional[torch.Tensor], + ) -> torch.Tensor: + return paged_run_func( + float_workspace_buffer, + int_workspace_buffer, + plan_info_vec, + q, + paged_k_cache, + paged_v_cache, + maybe_custom_mask, + maybe_alibi_slopes, + qo_indptr, + paged_kv_indptr, + paged_kv_indices, + paged_kv_last_page_len, + maybe_qk_indptr, + layout, + window_left, + logits_soft_cap, + sm_scale, + rope_scale, + rope_theta, + maybe_lse, + ) + + @register_fake_op(f"flashinfer::{uri}_paged_run") + def _fake_paged_run( + float_workspace_buffer: torch.Tensor, + int_workspace_buffer: torch.Tensor, + plan_info_vec: List[int], + q: torch.Tensor, + paged_k_cache: torch.Tensor, + paged_v_cache: torch.Tensor, + maybe_custom_mask: Optional[torch.Tensor], + maybe_alibi_slopes: Optional[torch.Tensor], + qo_indptr: torch.Tensor, + paged_kv_indptr: torch.Tensor, + paged_kv_indices: torch.Tensor, + paged_kv_last_page_len: torch.Tensor, + maybe_qk_indptr: Optional[torch.Tensor], + layout: int, + window_left: int, + logits_soft_cap: float, + sm_scale: float, + rope_scale: float, + rope_theta: float, + maybe_lse: Optional[torch.Tensor], + ) -> torch.Tensor: + return torch.empty_like(q) + + # Register the module. 
+ # + # Note that plan is not part of model logic. It should not be included in + # Cuda Graph or torch.compile. So, we don't provide a torch library for plan. + _batch_prefill_modules[args] = SimpleNamespace( + plan=plan_func, + ragged_run=ragged_run, + paged_run=paged_run, + ) return _batch_prefill_modules[args] @@ -139,10 +363,13 @@ def single_prefill_with_kv_cache_with_jit_module( return_lse: bool = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: tmp = _get_cache_buf("single_prefill_with_kv_cache_tmp", 32 * 1024 * 1024, q.device) + lse = None + if return_lse: + lse = torch.empty((q.size(0), q.size(1)), dtype=torch.float32, device=q.device) out = jit_module.run( - q, k, v, tmp, TensorLayout[kv_layout].value, window_left, return_lse, *args + q, k, v, tmp, TensorLayout[kv_layout].value, window_left, lse, *args ) - return out if return_lse else out[0] + return (out, lse) if return_lse else out def single_prefill_with_kv_cache( @@ -292,6 +519,10 @@ def single_prefill_with_kv_cache( else: mask_mode = MaskMode.NON_CAUSAL.value + lse = None + if return_lse: + lse = torch.empty((q.size(0), q.size(1)), dtype=torch.float32, device=q.device) + out = get_single_prefill_module( q.dtype, k.dtype, @@ -315,10 +546,10 @@ def single_prefill_with_kv_cache( sm_scale, rope_scale, rope_theta, - return_lse, + lse, ) - return out if return_lse else out[0] + return (out, lse) if return_lse else out single_prefill_with_kv_cache_return_lse = functools.partial( @@ -691,6 +922,8 @@ def plan( The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads`` is not equal to ``num_kv_heads``, the function will use `grouped query attention `_. + + The :meth:`plan` method cannot be used in Cuda Graph or in ``torch.compile``. """ q_data_type = canonicalize_torch_dtype(q_data_type) if kv_data_type is None: @@ -908,6 +1141,11 @@ def run( rope_scale = 1.0 if rope_theta is None: rope_theta = 1e4 + lse = None + if return_lse: + lse = torch.empty( + (q.size(0), q.size(1)), dtype=torch.float32, device=q.device + ) out = self._cached_module.paged_run( self._float_workspace_buffer, @@ -929,13 +1167,13 @@ def run( sm_scale, rope_scale, rope_theta, - return_lse, + lse, ) if v_scale is not None: - out[0] *= v_scale + out *= v_scale - return out if return_lse else out[0] + return (out, lse) if return_lse else out run_return_lse = functools.partialmethod(run, return_lse=True) @@ -1278,6 +1516,8 @@ def plan( The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads`` is not equal to ``num_kv_heads``, the function will use `grouped query attention `_. + + The :meth:`plan` method cannot be used in Cuda Graph or in ``torch.compile``. 
""" q_data_type = canonicalize_torch_dtype(q_data_type) if kv_data_type is None: @@ -1330,6 +1570,10 @@ def plan( self._custom_mask_buf = packed_custom_mask.to(self.device) self._qk_indptr_buf = qk_indptr.to(self.device) + # NOTE(Zihao): only required if qo_indptr/paged_kv_indptr are device tensors + qo_indptr_host = qo_indptr.to("cpu", non_blocking=True) + kv_indptr_host = kv_indptr.to("cpu", non_blocking=True) + if packed_custom_mask is not None: mask_mode = MaskMode.CUSTOM.value else: @@ -1356,8 +1600,8 @@ def plan( self._float_workspace_buffer, self._int_workspace_buffer, self._pin_memory_int_workspace_buffer, - qo_indptr, - kv_indptr, + qo_indptr_host, + kv_indptr_host, batch_size, num_qo_heads, num_kv_heads, @@ -1447,6 +1691,11 @@ def run( rope_scale = 1.0 if rope_theta is None: rope_theta = 1e4 + lse = None + if return_lse: + lse = torch.empty( + (q.size(0), q.size(1)), dtype=torch.float32, device=q.device + ) if is_float8(q): logging.warning( "Our current prefill kernel implementation needs f16 input, the f8 inputs " @@ -1474,10 +1723,10 @@ def run( sm_scale, rope_scale, rope_theta, - return_lse, + lse, ) - return out if return_lse else out[0] + return (out, lse) if return_lse else out run_return_lse = functools.partialmethod(run, return_lse=True) diff --git a/python/flashinfer/quantization.py b/python/flashinfer/quantization.py index 919d5177..38af10bc 100644 --- a/python/flashinfer/quantization.py +++ b/python/flashinfer/quantization.py @@ -14,10 +14,12 @@ limitations under the License. """ -import torch from typing import Tuple -from .jit import load_cuda_ops, FLASHINFER_CSRC_DIR, has_prebuilt_ops +import torch + +from .jit import FLASHINFER_CSRC_DIR, has_prebuilt_ops, load_cuda_ops +from .utils import register_custom_op, register_fake_op _quantization_module = None @@ -40,6 +42,7 @@ def get_quantization_module(): return _quantization_module +@register_custom_op("flashinfer::packbits", mutates_args=()) def packbits(x: torch.Tensor, bitorder: str = "big") -> torch.Tensor: r"""Pack the elements of a binary-valued array into bits in a uint8 array. @@ -74,6 +77,11 @@ def packbits(x: torch.Tensor, bitorder: str = "big") -> torch.Tensor: return get_quantization_module().packbits(x, bitorder) +@register_fake_op("flashinfer::packbits") +def _fake_packbits(x: torch.Tensor, bitorder: str = "big") -> torch.Tensor: + return torch.empty((x.size(0) + 7) // 8, dtype=torch.uint8, device=x.device) + + def segment_packbits( x: torch.Tensor, indptr: torch.Tensor, bitorder: str = "big" ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -112,6 +120,10 @@ def segment_packbits( >>> new_indptr tensor([0, 1, 2, 3], device='cuda:0') + Note + ---- + ``torch.compile`` is not supported for this function because it's data dependent. + See Also -------- packbits diff --git a/python/flashinfer/rope.py b/python/flashinfer/rope.py index 3fb71435..408c1f4a 100644 --- a/python/flashinfer/rope.py +++ b/python/flashinfer/rope.py @@ -14,9 +14,12 @@ limitations under the License. 
""" +from typing import Tuple + import torch -from .jit import load_cuda_ops, FLASHINFER_CSRC_DIR, has_prebuilt_ops +from .jit import FLASHINFER_CSRC_DIR, has_prebuilt_ops, load_cuda_ops +from .utils import register_custom_op, register_fake_op _rope_module = None @@ -39,6 +42,7 @@ def get_rope_module(): return _rope_module +@register_custom_op("flashinfer::apply_rope_inplace", mutates_args=("q", "k")) def apply_rope_inplace( q: torch.Tensor, k: torch.Tensor, @@ -119,6 +123,20 @@ def apply_rope_inplace( ) +@register_fake_op("flashinfer::apply_rope_inplace") +def _fake_apply_rope_inplace( + q: torch.Tensor, + k: torch.Tensor, + indptr: torch.Tensor, + offsets: torch.Tensor, + interleave: bool = False, + rope_scale: float = 1, + rope_theta: float = 1e4, +) -> None: + pass + + +@register_custom_op("flashinfer::apply_llama31_rope_inplace", mutates_args=("q", "k")) def apply_llama31_rope_inplace( q: torch.Tensor, k: torch.Tensor, @@ -218,6 +236,23 @@ def apply_llama31_rope_inplace( ) +@register_fake_op("flashinfer::apply_llama31_rope_inplace") +def _fake_apply_llama31_rope_inplace( + q: torch.Tensor, + k: torch.Tensor, + indptr: torch.Tensor, + offsets: torch.Tensor, + interleave: bool = True, + rope_scale: float = 8, + rope_theta: float = 5e5, + low_freq_factor: float = 1, + high_freq_factor: float = 4, + old_context_len: int = 8192, +) -> None: + pass + + +@register_custom_op("flashinfer::apply_rope", mutates_args=()) def apply_rope( q: torch.Tensor, k: torch.Tensor, @@ -226,7 +261,7 @@ def apply_rope( interleave: bool = False, rope_scale: float = 1, rope_theta: float = 1e4, -) -> None: +) -> Tuple[torch.Tensor, torch.Tensor]: r"""Apply rotary embedding to a batch of queries/keys (stored as RaggedTensor). We use :attr:`indptr` to denote the start pointer of each segment in the batch, the i-th @@ -309,6 +344,20 @@ def apply_rope( ) +@register_fake_op("flashinfer::apply_rope") +def _fake_apply_rope( + q: torch.Tensor, + k: torch.Tensor, + indptr: torch.Tensor, + offsets: torch.Tensor, + interleave: bool = False, + rope_scale: float = 1, + rope_theta: float = 1e4, +) -> Tuple[torch.Tensor, torch.Tensor]: + return torch.empty_like(q), torch.empty_like(k) + + +@register_custom_op("flashinfer::apply_llama31_rope", mutates_args=()) def apply_llama31_rope( q: torch.Tensor, k: torch.Tensor, @@ -320,7 +369,7 @@ def apply_llama31_rope( low_freq_factor: float = 1, high_freq_factor: float = 4, old_context_len: int = 8192, -) -> None: +) -> Tuple[torch.Tensor, torch.Tensor]: r"""Apply Llama 3.1 style rotary embedding to a batch of queries/keys (stored as RaggedTensor). @@ -417,3 +466,19 @@ def apply_llama31_rope( high_freq_factor, float(old_context_len), ) + + +@register_fake_op("flashinfer::apply_llama31_rope") +def _fake_apply_llama31_rope( + q: torch.Tensor, + k: torch.Tensor, + indptr: torch.Tensor, + offsets: torch.Tensor, + interleave: bool = True, + rope_scale: float = 8, + rope_theta: float = 5e5, + low_freq_factor: float = 1, + high_freq_factor: float = 4, + old_context_len: int = 8192, +) -> Tuple[torch.Tensor, torch.Tensor]: + return torch.empty_like(q), torch.empty_like(k) diff --git a/python/flashinfer/sampling.py b/python/flashinfer/sampling.py index 6ef0ddeb..b73049a9 100644 --- a/python/flashinfer/sampling.py +++ b/python/flashinfer/sampling.py @@ -14,10 +14,13 @@ limitations under the License. 
""" +from types import SimpleNamespace +from typing import Optional, Tuple, Union + import torch -from typing import Tuple, Union, Optional -from .jit import load_cuda_ops, FLASHINFER_CSRC_DIR, has_prebuilt_ops +from .jit import FLASHINFER_CSRC_DIR, has_prebuilt_ops, load_cuda_ops +from .utils import register_custom_op, register_fake_op _sampling_module = None @@ -28,15 +31,248 @@ def get_sampling_module(): if has_prebuilt_ops: from . import _kernels - _sampling_module = _kernels + module = _kernels else: - _sampling_module = load_cuda_ops( + module = load_cuda_ops( "sampling", [ FLASHINFER_CSRC_DIR / "sampling.cu", FLASHINFER_CSRC_DIR / "flashinfer_sampling_ops.cu", ], ) + + # torch library for sampling_from_probs + + @register_custom_op("flashinfer::sampling_from_probs", mutates_args=()) + def sampling_from_probs( + probs: torch.Tensor, + uniform_samples: torch.Tensor, + deterministic: bool, + ) -> torch.Tensor: + return module.sampling_from_probs(probs, uniform_samples, deterministic) + + @register_fake_op("flashinfer::sampling_from_probs") + def _fake_sampling_from_probs( + probs: torch.Tensor, + uniform_samples: torch.Tensor, + deterministic: bool, + ) -> torch.Tensor: + return torch.empty(probs.size(0), dtype=torch.int32, device=probs.device) + + # torch library for top_p_sampling_from_probs + + @register_custom_op("flashinfer::top_p_sampling_from_probs", mutates_args=()) + def top_p_sampling_from_probs( + probs: torch.Tensor, + uniform_samples: torch.Tensor, + maybe_top_p_arr: Optional[torch.Tensor], + top_p_val: float, + deterministic: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + samples, success = module.top_p_sampling_from_probs( + probs, uniform_samples, maybe_top_p_arr, top_p_val, deterministic + ) + return samples, success + + @register_fake_op("flashinfer::top_p_sampling_from_probs") + def _fake_top_p_sampling_from_probs( + probs: torch.Tensor, + uniform_samples: torch.Tensor, + maybe_top_p_arr: Optional[torch.Tensor], + top_p_val: float, + deterministic: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + sample = torch.empty(probs.size(0), dtype=torch.int32, device=probs.device) + success = torch.empty(probs.size(0), dtype=torch.bool, device=probs.device) + return sample, success + + # torch library for top_k_sampling_from_probs + + @register_custom_op("flashinfer::top_k_sampling_from_probs", mutates_args=()) + def top_k_sampling_from_probs( + probs: torch.Tensor, + uniform_samples: torch.Tensor, + maybe_top_k_arr: Optional[torch.Tensor], + top_k_val: int, + deterministic: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + samples, success = module.top_k_sampling_from_probs( + probs, uniform_samples, maybe_top_k_arr, top_k_val, deterministic + ) + return samples, success + + @register_fake_op("flashinfer::top_k_sampling_from_probs") + def _fake_top_k_sampling_from_probs( + probs: torch.Tensor, + uniform_samples: torch.Tensor, + maybe_top_k_arr: Optional[torch.Tensor], + top_k_val: int, + deterministic: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + sample = torch.empty(probs.size(0), dtype=torch.int32, device=probs.device) + success = torch.empty(probs.size(0), dtype=torch.bool, device=probs.device) + return sample, success + + # torch library for min_p_sampling_from_probs + + @register_custom_op("flashinfer::min_p_sampling_from_probs", mutates_args=()) + def min_p_sampling_from_probs( + probs: torch.Tensor, + uniform_samples: torch.Tensor, + maybe_min_p_arr: Optional[torch.Tensor], + min_p_val: float, + deterministic: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + 
samples, success = module.min_p_sampling_from_probs( + probs, uniform_samples, maybe_min_p_arr, min_p_val, deterministic + ) + return samples, success + + # torch library for top_k_top_p_sampling_from_probs + + @register_custom_op( + "flashinfer::top_k_top_p_sampling_from_probs", mutates_args=() + ) + def top_k_top_p_sampling_from_probs( + probs: torch.Tensor, + uniform_samples: torch.Tensor, + maybe_top_k_arr: Optional[torch.Tensor], + top_k_val: int, + maybe_top_p_arr: Optional[torch.Tensor], + top_p_val: float, + deterministic: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + samples, success = module.top_k_top_p_sampling_from_probs( + probs, + uniform_samples, + maybe_top_k_arr, + top_k_val, + maybe_top_p_arr, + top_p_val, + deterministic, + ) + return samples, success + + @register_fake_op("flashinfer::top_k_top_p_sampling_from_probs") + def _fake_top_k_top_p_sampling_from_probs( + probs: torch.Tensor, + uniform_samples: torch.Tensor, + maybe_top_k_arr: Optional[torch.Tensor], + top_k_val: int, + maybe_top_p_arr: Optional[torch.Tensor], + top_p_val: float, + deterministic: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + sample = torch.empty(probs.size(0), dtype=torch.int32, device=probs.device) + success = torch.empty(probs.size(0), dtype=torch.bool, device=probs.device) + return sample, success + + # torch library for top_p_renorm_probs + + @register_custom_op("flashinfer::top_p_renorm_probs", mutates_args=()) + def top_p_renorm_probs( + probs: torch.Tensor, + maybe_top_p_arr: Optional[torch.Tensor], + top_p_val: float, + ) -> torch.Tensor: + return module.top_p_renorm_probs(probs, maybe_top_p_arr, top_p_val) + + @register_fake_op("flashinfer::top_p_renorm_probs") + def _fake_top_p_renorm_probs( + probs: torch.Tensor, + maybe_top_p_arr: Optional[torch.Tensor], + top_p_val: float, + ) -> torch.Tensor: + return torch.empty_like(probs) + + # torch library for top_k_renorm_probs + + @register_custom_op("flashinfer::top_k_renorm_probs", mutates_args=()) + def top_k_renorm_probs( + probs: torch.Tensor, + maybe_top_k_arr: Optional[torch.Tensor], + top_k_val: int, + ) -> torch.Tensor: + return module.top_k_renorm_probs(probs, maybe_top_k_arr, top_k_val) + + @register_fake_op("flashinfer::top_k_renorm_probs") + def _fake_top_k_renorm_probs( + probs: torch.Tensor, + maybe_top_k_arr: Optional[torch.Tensor], + top_k_val: int, + ) -> torch.Tensor: + return torch.empty_like(probs) + + # torch library for top_k_mask_logits + + @register_custom_op("flashinfer::top_k_mask_logits", mutates_args=()) + def top_k_mask_logits( + logits: torch.Tensor, + maybe_top_k_arr: Optional[torch.Tensor], + top_k_val: int, + ) -> torch.Tensor: + return module.top_k_mask_logits(logits, maybe_top_k_arr, top_k_val) + + @register_fake_op("flashinfer::top_k_mask_logits") + def _fake_top_k_mask_logits( + logits: torch.Tensor, + maybe_top_k_arr: Optional[torch.Tensor], + top_k_val: int, + ) -> torch.Tensor: + return torch.empty_like(logits) + + # torch library for chain_speculative_sampling + + @register_custom_op( + "flashinfer::chain_speculative_sampling", + mutates_args=("output_accepted_token_num", "output_emitted_token_num"), + ) + def chain_speculative_sampling( + draft_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + uniform_samples: torch.Tensor, + target_probs: torch.Tensor, + output_accepted_token_num: torch.Tensor, + output_emitted_token_num: torch.Tensor, + deterministic: bool, + ) -> torch.Tensor: + return module.chain_speculative_sampling( + draft_probs, + draft_token_ids, + uniform_samples, + 
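The fake implementations in this block all follow one contract: ignore the values and return empty tensors of the right shape and dtype (an int32 sample plus a bool success flag per row of probs for the sampling ops, empty_like for the renorm/mask ops). For min_p_sampling_from_probs, a fake written to that same convention would look like the sketch below (illustrative only; in the patch these registrations live inside get_sampling_module(), next to the custom op they describe):

from typing import Optional, Tuple

import torch
from flashinfer.utils import register_fake_op

@register_fake_op("flashinfer::min_p_sampling_from_probs")
def _fake_min_p_sampling_from_probs(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    maybe_min_p_arr: Optional[torch.Tensor],
    min_p_val: float,
    deterministic: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Same output contract as the top-p/top-k variants: one int32 token id
    # and one bool success flag per batch row of probs.
    sample = torch.empty(probs.size(0), dtype=torch.int32, device=probs.device)
    success = torch.empty(probs.size(0), dtype=torch.bool, device=probs.device)
    return sample, success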
target_probs, + output_accepted_token_num, + output_emitted_token_num, + deterministic, + ) + + @register_fake_op("flashinfer::chain_speculative_sampling") + def _fake_chain_speculative_sampling( + draft_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + uniform_samples: torch.Tensor, + target_probs: torch.Tensor, + output_accepted_token_num: torch.Tensor, + output_emitted_token_num: torch.Tensor, + deterministic: bool, + ) -> torch.Tensor: + b, n = draft_token_ids.shape + device = draft_token_ids.device + return torch.empty((b, n + 1), dtype=torch.int32, device=device) + + # Register the module + _sampling_module = SimpleNamespace( + sampling_from_probs=sampling_from_probs, + top_p_sampling_from_probs=top_p_sampling_from_probs, + top_k_sampling_from_probs=top_k_sampling_from_probs, + min_p_sampling_from_probs=min_p_sampling_from_probs, + top_k_top_p_sampling_from_probs=top_k_top_p_sampling_from_probs, + top_p_renorm_probs=top_p_renorm_probs, + top_k_renorm_probs=top_k_renorm_probs, + top_k_mask_logits=top_k_mask_logits, + chain_speculative_sampling=chain_speculative_sampling, + ) + return _sampling_module @@ -847,12 +1083,23 @@ def chain_speculative_sampling( >>> output_emitted_token_num tensor([1], device='cuda:0') """ - return get_sampling_module().chain_speculative_sampling( + b = draft_probs.size(0) + dev = draft_probs.device + if maybe_output_accepted_token_num is None: + output_accepted_token_num = torch.zeros(b, dtype=torch.int32, device=dev) + else: + output_accepted_token_num = maybe_output_accepted_token_num + if maybe_output_emitted_token_num is None: + output_emitted_token_num = torch.zeros(b, dtype=torch.int32, device=dev) + else: + output_emitted_token_num = maybe_output_emitted_token_num + output_token_ids = get_sampling_module().chain_speculative_sampling( draft_probs, draft_token_ids, uniform_samples, target_probs, - maybe_output_accepted_token_num, - maybe_output_emitted_token_num, + output_accepted_token_num, + output_emitted_token_num, deterministic, ) + return output_token_ids, output_accepted_token_num, output_emitted_token_num diff --git a/tests/conftest.py b/tests/conftest.py index e8948378..95738ddb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,12 +8,51 @@ from torch.torch_version import __version__ as torch_version TORCH_COMPILE_FNS = [ + flashinfer.activation.silu_and_mul, + flashinfer.activation.gelu_and_mul, + flashinfer.activation.gelu_tanh_and_mul, + flashinfer.cascade.merge_state, + flashinfer.cascade.merge_state_in_place, + flashinfer.cascade.merge_states, + flashinfer.cascade.MultiLevelCascadeAttentionWrapper.run, + flashinfer.cascade.BatchDecodeWithSharedPrefixPagedKVCacheWrapper.forward, + flashinfer.cascade.BatchPrefillWithSharedPrefixPagedKVCacheWrapper.forward, + flashinfer.decode.single_decode_with_kv_cache, + flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper.run, + flashinfer.gemm.bmm_fp8, + flashinfer.gemm.SegmentGEMMWrapper.run, flashinfer.norm.rmsnorm, flashinfer.norm.fused_add_rmsnorm, flashinfer.norm.gemma_rmsnorm, flashinfer.norm.gemma_fused_add_rmsnorm, + flashinfer.page.append_paged_kv_cache, + flashinfer.prefill.single_prefill_with_kv_cache, + flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper.run, + flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper.run, + flashinfer.quantization.packbits, + flashinfer.rope.apply_rope, + flashinfer.rope.apply_rope_inplace, + flashinfer.rope.apply_llama31_rope, + flashinfer.rope.apply_llama31_rope_inplace, + flashinfer.sampling.sampling_from_probs, + 
flashinfer.sampling.top_p_sampling_from_probs, + flashinfer.sampling.top_k_sampling_from_probs, + flashinfer.sampling.min_p_sampling_from_probs, + flashinfer.sampling.top_k_top_p_sampling_from_probs, + flashinfer.sampling.top_p_renorm_probs, + flashinfer.sampling.top_k_renorm_probs, + flashinfer.sampling.top_k_mask_logits, + flashinfer.sampling.chain_speculative_sampling, ] +_TORCH_COMPILE_CACHE = dict() + + +def _set_torch_compile_options(): + import torch._dynamo.config + + torch._dynamo.config.cache_size_limit = 128 + def _monkeypatch_add_torch_compile(func): """ @@ -29,28 +68,49 @@ def _monkeypatch_add_torch_compile(func): else: raise ValueError(f"Unsupported fn type {type(func)}") - components = fn.__module__.split(".") + fullname = fn.__module__ + "." + fn.__qualname__ + components = fullname.split(".") assert components[0] == "flashinfer" module = flashinfer - for component in components[1:]: + for component in components[1:-1]: module = getattr(module, component) + if not hasattr(module, components[-1]): + raise ValueError(f"Failed to monkeypatch: {fullname}") + + def wrapper(*args, **kwargs): + compiled = _TORCH_COMPILE_CACHE.get(fullname) + if compiled is None: + # Warmup -- JIT compile / import the kernels. + # + # From user side, users also need to warmup the model beforehand, + # as suggested by PyTorch Cuda Graph docs (not sure if it's also + # recommended for torch.compile as well.) + # + # For the convenience of FlashInfer testing, we do the warmup here, + # on the first run of the function. The caveat is that the first + # call will run twice: once to warmup, and another through the + # compiled version. + func(*args, **kwargs) + + # Compile + compiled = torch.compile( + func, + fullgraph=True, + backend="inductor", + mode="max-autotune-no-cudagraphs", + ) + _TORCH_COMPILE_CACHE[fullname] = compiled + + return compiled(*args, **kwargs) - setattr( - module, - fn.__name__, - torch.compile( - func, - fullgraph=True, - backend="inductor", - mode="max-autotune-no-cudagraphs", - ), - ) - print("Applied torch.compile to", f"{fn.__module__}.{fn.__name__}") + setattr(module, fn.__name__, wrapper) + print("Applied torch.compile to", fullname) def pytest_configure(config): if os.environ.get("FLASHINFER_TEST_TORCH_COMPILE", "0") == "1": if torch_version < TorchVersion("2.4"): pytest.skip("torch.compile requires torch >= 2.4") + _set_torch_compile_options() for fn in TORCH_COMPILE_FNS: _monkeypatch_add_torch_compile(fn) diff --git a/tests/test_batch_decode_kernels.py b/tests/test_batch_decode_kernels.py index 1c9e78f0..2cd6e806 100644 --- a/tests/test_batch_decode_kernels.py +++ b/tests/test_batch_decode_kernels.py @@ -81,7 +81,9 @@ def test_batch_decode_with_paged_kv_cache( ).to(0) workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0) - wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, kv_layout) + wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, kv_layout + ) wrapper.plan( kv_indptr, kv_indices, @@ -96,7 +98,7 @@ def test_batch_decode_with_paged_kv_cache( q_data_type=q_dtype, ) if return_lse: - o, _ = wrapper.run_return_lse(q, kv_data) + o, _ = wrapper.run(q, kv_data, return_lse=True) else: o = wrapper.run(q, kv_data) @@ -134,7 +136,7 @@ def test_batch_decode_with_paged_kv_cache( ], dim=0, ).to(kv_dtype) - o_ref_i = flashinfer.single_decode_with_kv_cache( + o_ref_i = flashinfer.decode.single_decode_with_kv_cache( qi, ki, vi, @@ -212,7 +214,9 @@ def test_batch_decode_with_tuple_paged_kv_cache( ).to(0)
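The conftest monkeypatch above amounts to a compile-once-with-warmup cache keyed by the function's fully qualified name: the first call runs eagerly (so lazy kernel JIT/import happens outside tracing), then the function is compiled and every later call goes through the cached compiled version. Stripped of the flashinfer-specific attribute lookup, the pattern reduces to roughly the sketch below (a generic illustration, not the test harness itself); running the suite with FLASHINFER_TEST_TORCH_COMPILE=1 then exercises every function in TORCH_COMPILE_FNS through this path.

import torch

_COMPILED = {}

def compile_with_warmup(fn, *args, **kwargs):
    key = fn.__module__ + "." + fn.__qualname__
    compiled = _COMPILED.get(key)
    if compiled is None:
        fn(*args, **kwargs)  # warmup: one eager run before capturing the graph
        compiled = torch.compile(fn, fullgraph=True, mode="max-autotune-no-cudagraphs")
        _COMPILED[key] = compiled
    return compiled(*args, **kwargs)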
workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0) - wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, kv_layout) + wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, kv_layout + ) wrapper.plan( kv_indptr, kv_indices, @@ -227,7 +231,7 @@ def test_batch_decode_with_tuple_paged_kv_cache( q_data_type=q_dtype, ) if return_lse: - o, _ = wrapper.run_return_lse(q, kv_data) + o, _ = wrapper.run(q, kv_data, return_lse=True) else: o = wrapper.run(q, kv_data) @@ -267,7 +271,7 @@ def test_batch_decode_with_tuple_paged_kv_cache( ], dim=0, ).to(kv_dtype) - o_ref_i = flashinfer.single_decode_with_kv_cache( + o_ref_i = flashinfer.decode.single_decode_with_kv_cache( qi, ki, vi, @@ -289,7 +293,6 @@ def test_batch_decode_with_tuple_paged_kv_cache( @pytest.mark.parametrize( "kv_dtype", [torch.float16, torch.float8_e4m3fn, torch.float8_e5m2] ) -@pytest.mark.parametrize("contiguous_kv", [True, False]) def test_cuda_graph_batch_decode_with_paged_kv_cache( batch_size, kv_len, @@ -340,7 +343,7 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache( kv_last_page_device_buffer = torch.empty(batch_size).int().to(0) workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0) - wrapper = flashinfer.CUDAGraphBatchDecodeWithPagedKVCacheWrapper( + wrapper = flashinfer.decode.CUDAGraphBatchDecodeWithPagedKVCacheWrapper( workspace_buffer, kv_indptr_device_buffer, kv_indices_device_buffer, @@ -451,7 +454,7 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache( ], dim=0, ).to(kv_dtype) - o_ref_i = flashinfer.single_decode_with_kv_cache( + o_ref_i = flashinfer.decode.single_decode_with_kv_cache( qi, ki, vi, pos_encoding_mode=pos_encoding_mode ) torch.testing.assert_close(o[i], o_ref_i, rtol=1e-3, atol=1e-3) diff --git a/tests/test_batch_prefill_kernels.py b/tests/test_batch_prefill_kernels.py index c2ad5c15..09b0b2cb 100644 --- a/tests/test_batch_prefill_kernels.py +++ b/tests/test_batch_prefill_kernels.py @@ -87,7 +87,7 @@ def test_batch_prefill_with_paged_kv_cache( kv_indptr_gpu = kv_indptr_cpu.to(0) kv_indices_gpu = kv_indices_cpu.to(0) kv_last_page_len_gpu = kv_last_page_len_cpu.to(0) - wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( workspace_buffer, kv_layout ) wrapper.plan( @@ -104,7 +104,7 @@ def test_batch_prefill_with_paged_kv_cache( logits_soft_cap=logits_soft_cap, ) if return_lse: - o, _ = wrapper.run_return_lse(q, kv_data) + o, _ = wrapper.run(q, kv_data, return_lse=True) else: o = wrapper.run(q, kv_data) else: @@ -112,7 +112,7 @@ def test_batch_prefill_with_paged_kv_cache( kv_indptr_buffer = torch.empty(batch_size + 1).int().to(0) kv_indices_buffer = torch.empty(total_num_pages).int().to(0) kv_last_page_len_buffer = torch.empty(batch_size).int().to(0) - wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( workspace_buffer, kv_layout, use_cuda_graph=True, @@ -148,7 +148,7 @@ def test_batch_prefill_with_paged_kv_cache( with torch.cuda.stream(s): for _ in range(3): if return_lse: - o, _ = wrapper.run_return_lse(q, kv_data) + o, _ = wrapper.run(q, kv_data, return_lse=True) else: o = wrapper.run(q, kv_data) torch.cuda.current_stream().wait_stream(s) @@ -156,9 +156,9 @@ def test_batch_prefill_with_paged_kv_cache( g = torch.cuda.CUDAGraph() with torch.cuda.graph(g): if return_lse: - o, _ = wrapper.run_return_lse(q, kv_data) + o, _ = wrapper.run(q, kv_data, return_lse=True) 
else: - o = wrapper.run(q, kv_data) + o = wrapper.run(q, kv_data) wrapper.plan( q_indptr_cpu, @@ -218,7 +218,7 @@ def test_batch_prefill_with_paged_kv_cache( ], dim=0, ).half() - o_ref_i = flashinfer.single_prefill_with_kv_cache( + o_ref_i = flashinfer.prefill.single_prefill_with_kv_cache( qi, ki, vi, @@ -304,7 +304,7 @@ def test_batch_prefill_with_tuple_paged_kv_cache( kv_indptr_gpu = kv_indptr_cpu.to(0) kv_indices_gpu = kv_indices_cpu.to(0) kv_last_page_len_gpu = kv_last_page_len_cpu.to(0) - wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( workspace_buffer, kv_layout ) wrapper.plan( @@ -321,7 +321,7 @@ def test_batch_prefill_with_tuple_paged_kv_cache( logits_soft_cap=logits_soft_cap, ) if return_lse: - o, _ = wrapper.run_return_lse(q, kv_data) + o, _ = wrapper.run(q, kv_data, return_lse=True) else: o = wrapper.run(q, kv_data) else: @@ -329,7 +329,7 @@ def test_batch_prefill_with_tuple_paged_kv_cache( kv_indptr_buffer = torch.empty(batch_size + 1).int().to(0) kv_indices_buffer = torch.empty(total_num_pages).int().to(0) kv_last_page_len_buffer = torch.empty(batch_size).int().to(0) - wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( workspace_buffer, kv_layout, use_cuda_graph=True, @@ -364,7 +364,7 @@ def test_batch_prefill_with_tuple_paged_kv_cache( with torch.cuda.stream(s): for _ in range(3): if return_lse: - o, _ = wrapper.run_return_lse(q, kv_data) + o, _ = wrapper.run(q, kv_data, return_lse=True) else: o = wrapper.run(q, kv_data) torch.cuda.current_stream().wait_stream(s) @@ -372,7 +372,7 @@ def test_batch_prefill_with_tuple_paged_kv_cache( g = torch.cuda.CUDAGraph() with torch.cuda.graph(g): if return_lse: - o, _ = wrapper.run_return_lse(q, kv_data) + o, _ = wrapper.run(q, kv_data, return_lse=True) else: o = wrapper.run(q, kv_data) @@ -427,7 +427,7 @@ def test_batch_prefill_with_tuple_paged_kv_cache( ], dim=0, ).half() - o_ref_i = flashinfer.single_prefill_with_kv_cache( + o_ref_i = flashinfer.prefill.single_prefill_with_kv_cache( qi, ki, vi, @@ -495,7 +495,7 @@ def test_batch_prefill_with_paged_kv_cache_custom_mask( ).to(0) workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0) - wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( workspace_buffer, kv_layout ) custom_mask = ( @@ -522,7 +522,7 @@ def test_batch_prefill_with_paged_kv_cache_custom_mask( logits_soft_cap=logits_soft_cap, ) if return_lse: - o_custom, _ = wrapper.run_return_lse(q, kv_data) + o_custom, _ = wrapper.run(q, kv_data, return_lse=True) else: o_custom = wrapper.run(q, kv_data) @@ -541,12 +541,10 @@ def test_batch_prefill_with_paged_kv_cache_custom_mask( logits_soft_cap=logits_soft_cap, ) if return_lse: - o_causal, _ = wrapper.run_return_lse(q, kv_data) + o_causal, _ = wrapper.run(q, kv_data, return_lse=True) else: o_causal = wrapper.run(q, kv_data) - torch.testing.assert_close( - o_custom, o_causal, rtol=1e-3, atol=1e-3 - ) + torch.testing.assert_close(o_custom, o_causal, rtol=1e-3, atol=1e-3) @pytest.mark.parametrize("batch_size", [12, 17]) @@ -580,7 +578,7 @@ def test_batch_prefill_with_ragged_kv_cache( kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0) - wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper( + wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper( 
workspace_buffer, kv_layout ) wrapper.plan( @@ -594,12 +592,12 @@ def test_batch_prefill_with_ragged_kv_cache( logits_soft_cap=logits_soft_cap, ) if return_lse: - o, _ = wrapper.run_return_lse(q, k, v) + o, _ = wrapper.run(q, k, v, return_lse=True) else: o = wrapper.run(q, k, v) for i in range(batch_size): - o_ref_i = flashinfer.single_prefill_with_kv_cache( + o_ref_i = flashinfer.prefill.single_prefill_with_kv_cache( q[q_indptr[i] : q_indptr[i + 1]], k[kv_indptr[i] : kv_indptr[i + 1]], v[kv_indptr[i] : kv_indptr[i + 1]], @@ -640,7 +638,7 @@ def test_batch_prefill_with_ragged_kv_cache_custom_mask( kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0) - wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper( + wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper( workspace_buffer, kv_layout ) @@ -665,7 +663,7 @@ def test_batch_prefill_with_ragged_kv_cache_custom_mask( logits_soft_cap=logits_soft_cap, ) if return_lse: - o_custom, _ = wrapper.run_return_lse(q, k, v) + o_custom, _ = wrapper.run(q, k, v, return_lse=True) else: o_custom = wrapper.run(q, k, v) @@ -681,12 +679,10 @@ def test_batch_prefill_with_ragged_kv_cache_custom_mask( logits_soft_cap=logits_soft_cap, ) if return_lse: - o_causal, _ = wrapper.run_return_lse(q, k, v) + o_causal, _ = wrapper.run(q, k, v, return_lse=True) else: o_causal = wrapper.run(q, k, v) - torch.testing.assert_close( - o_custom, o_causal, rtol=1e-3, atol=1e-3 - ) + torch.testing.assert_close(o_custom, o_causal, rtol=1e-3, atol=1e-3) if __name__ == "__main__": diff --git a/tests/test_block_sparse.py b/tests/test_block_sparse.py index f9ca24de..f6d261f5 100644 --- a/tests/test_block_sparse.py +++ b/tests/test_block_sparse.py @@ -36,7 +36,7 @@ def bsr_attention_ref( shape=(M, N), ) dense_mask = torch.tensor(bsr.toarray(), dtype=bool, device=q.device) - o = flashinfer.single_prefill_with_kv_cache(q, k, v, custom_mask=dense_mask) + o = flashinfer.prefill.single_prefill_with_kv_cache(q, k, v, custom_mask=dense_mask) return o @@ -70,7 +70,9 @@ def test_block_sparse_attention( o_ref = bsr_attention_ref(q, k, v, indptr, indices, data_mask) workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.uint8, device=0) - sparse_attention_wrapper = flashinfer.BlockSparseAttentionWrapper(workspace_buffer) + sparse_attention_wrapper = flashinfer.sparse.BlockSparseAttentionWrapper( + workspace_buffer + ) sparse_attention_wrapper.plan( indptr, diff --git a/tests/test_page.py b/tests/test_page.py new file mode 100644 index 00000000..8084c573 --- /dev/null +++ b/tests/test_page.py @@ -0,0 +1,41 @@ +import flashinfer +import torch + + +def test_append_paged_kv_cache(): + nnz_kv = 100 + num_kv_heads = 32 + head_dim = 128 + k_append = torch.randn(nnz_kv, num_kv_heads, head_dim).half().to(0) + v_append = torch.randn(nnz_kv, num_kv_heads, head_dim).half().to(0) + # 45 + 8 + 25 + 22 = nnz_kv + kv_append_length = torch.tensor([45, 8, 25, 22], dtype=torch.int32, device="cuda:0") + kv_append_indptr = torch.cat( + [torch.zeros(1).int().to(0), torch.cumsum(kv_append_length, dim=0)] + ).int() + max_num_pages = 1000 + page_size = 16 + paged_kv_cache = ( + torch.randn(max_num_pages, 2, page_size, num_kv_heads, head_dim).half().to(0) + ) + num_pages_per_req = torch.tensor([3, 1, 2, 2], dtype=torch.int32, device="cuda:0") + kv_page_indptr = torch.cat( + [torch.zeros(1).int().to(0), torch.cumsum(num_pages_per_req, dim=0)] + ).int() + # use first 8 pages in the paged-kv + kv_page_indices = 
torch.arange(8, dtype=torch.int32, device="cuda:0") + # 45 = (3 - 1) * 16 + 13 + # 8 = (1 - 1) * 16 + 8 + # 25 = (2 - 1) * 16 + 9 + # 22 = (2 - 1) * 16 + 6 + kv_last_page_len = torch.tensor([13, 8, 9, 6], dtype=torch.int32, device="cuda:0") + + flashinfer.page.append_paged_kv_cache( + k_append, + v_append, + kv_append_indptr, + paged_kv_cache, + kv_page_indices, + kv_page_indptr, + kv_last_page_len, + ) diff --git a/tests/test_quantization.py b/tests/test_quantization.py index aaa60b34..77d120c1 100644 --- a/tests/test_quantization.py +++ b/tests/test_quantization.py @@ -33,7 +33,7 @@ def test_packbits(num_elements, bitorder): x_cpu = torch.rand(num_elements) < 0.5 x_gpu = x_cpu.to(0) x_packed_ref = numpy_packbits_ref(x_cpu, bitorder) - x_packed = flashinfer.packbits(x_gpu, bitorder) + x_packed = flashinfer.quantization.packbits(x_gpu, bitorder) assert torch.equal(x_packed_ref.cpu(), x_packed.cpu()) @@ -47,7 +47,9 @@ def test_segment_packbits(batch_size, bitorder): x_cpu = torch.rand(num_elements) < 0.5 x_gpu = x_cpu.to(0) - y_gpu, new_indptr = flashinfer.segment_packbits(x_gpu, old_indptr, bitorder) + y_gpu, new_indptr = flashinfer.quantization.segment_packbits( + x_gpu, old_indptr, bitorder + ) for i in range(batch_size): x_segment_i = x_gpu[old_indptr[i] : old_indptr[i + 1]]
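The bookkeeping in test_page.py above is worth spelling out: for each request, the appended KV length decomposes as (num_pages - 1) * page_size + last_page_len, which is exactly what the comment block 45 = (3 - 1) * 16 + 13 records. A small self-contained check of that relation, using the same numbers as the test (pure PyTorch, no FlashInfer kernels involved):

import torch

page_size = 16
kv_append_length = torch.tensor([45, 8, 25, 22], dtype=torch.int32)

# Pages needed per request: ceil(len / page_size).
num_pages_per_req = (kv_append_length + page_size - 1) // page_size
# Entries occupied in the final page of each request (in 1..page_size).
kv_last_page_len = kv_append_length - (num_pages_per_req - 1) * page_size

assert num_pages_per_req.tolist() == [3, 1, 2, 2]
assert kv_last_page_len.tolist() == [13, 8, 9, 6]
assert torch.equal((num_pages_per_req - 1) * page_size + kv_last_page_len, kv_append_length)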