feat: more sampling operator options (#431)
1. Implement first-top-k-then-top-p sampling to align with vLLM's and
   HuggingFace's behavior (see vllm-project/vllm#7137 (comment)).
2. Add the option of passing either a scalar or a tensor as the top-p/top-k
   threshold to all sampling operators.
yzh119 authored Aug 9, 2024
1 parent daa5566 commit 68df9c4
Showing 10 changed files with 691 additions and 196 deletions.
2 changes: 2 additions & 0 deletions docs/api/python/sampling.rst
@@ -14,7 +14,9 @@ Kernels for LLM sampling.
    top_p_sampling_from_probs
    top_k_sampling_from_probs
    min_p_sampling_from_probs
+   top_k_top_p_sampling_from_logits
    top_k_top_p_sampling_from_probs
    top_p_renorm_prob
    top_k_renorm_prob
+   top_k_mask_logits
    chain_speculative_sampling
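The two newly documented operators, in a hedged usage sketch. The argument order, the (rounds, batch) uniform-sample layout, and the (samples, success) return pair are assumptions carried over from the existing sampling API, not verified signatures:

```python
# Assumed usage of the operators added to the docs in this commit.
import torch
import flashinfer

batch, vocab = 4, 32000
logits = torch.randn(batch, vocab, device="cuda")
uniform = torch.rand(32, batch, device="cuda")  # 32 rejection-sampling rounds

# Fused "top-k first, then top-p" sampling straight from logits.
samples, success = flashinfer.sampling.top_k_top_p_sampling_from_logits(
    logits, uniform, 40, 0.9
)

# Or only mask logits to their top-k entries (rest -> -inf) and
# post-process with any downstream sampler.
masked_logits = flashinfer.sampling.top_k_mask_logits(logits, 40)
```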
311 changes: 217 additions & 94 deletions include/flashinfer/sampling.cuh

Large diffs are not rendered by default.
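Since the kernel diff is not rendered, here is a pure-PyTorch reference (hypothetical `_ref` helpers, not the CUDA code) for what the top-k masking and renormalization operators in sampling.cuh compute:

```python
# Reference semantics only; the CUDA kernels avoid full sorts.
import torch

def top_k_mask_logits_ref(logits: torch.Tensor, top_k: int) -> torch.Tensor:
    # Keep the k largest logits per row, set everything else to -inf.
    pivot = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
    return logits.masked_fill(logits < pivot, float("-inf"))

def top_k_renorm_prob_ref(probs: torch.Tensor, top_k: int) -> torch.Tensor:
    # Zero out all but the k largest probabilities (ties at the pivot are
    # kept), then renormalize each row to sum to 1.
    pivot = torch.topk(probs, top_k, dim=-1).values[..., -1, None]
    masked = torch.where(probs < pivot, torch.zeros_like(probs), probs)
    return masked / masked.sum(dim=-1, keepdim=True)
```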

1 change: 1 addition & 0 deletions python/csrc/flashinfer_ops.cu
@@ -34,6 +34,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Top-k and top-p sampling from probabilities");
m.def("top_k_renorm_prob", &top_k_renorm_prob, "Renormalize probabilities by top-k mask");
m.def("top_p_renorm_prob", &top_p_renorm_prob, "Renormalize probabilities by top-p mask");
m.def("top_k_mask_logits", &top_k_mask_logits, "Mask logits by top-k mask");
m.def("chain_speculative_sampling", &chain_speculative_sampling,
"Speculative sampling from sequence of probabilities");
m.def("rmsnorm", &rmsnorm, "Root mean square normalization");
28 changes: 18 additions & 10 deletions python/csrc/flashinfer_ops.h
@@ -39,25 +39,33 @@ torch::Tensor sampling_from_probs(torch::Tensor probs, torch::Tensor uniform_samples,
                                   bool deterministic);
 
 std::vector<torch::Tensor> top_p_sampling_from_probs(torch::Tensor probs,
-                                                     torch::Tensor uniform_samples, double top_p,
-                                                     bool deterministic);
+                                                     torch::Tensor uniform_samples,
+                                                     std::optional<torch::Tensor> maybe_top_p_arr,
+                                                     double top_p_val, bool deterministic);
 
 std::vector<torch::Tensor> top_k_sampling_from_probs(torch::Tensor probs,
                                                      torch::Tensor uniform_samples,
-                                                     unsigned int top_k, bool deterministic);
+                                                     std::optional<torch::Tensor> maybe_top_k_arr,
+                                                     unsigned int top_k_val, bool deterministic);
 
 std::vector<torch::Tensor> min_p_sampling_from_probs(torch::Tensor probs,
                                                      torch::Tensor uniform_samples,
-                                                     torch::Tensor min_p, bool deterministic);
+                                                     std::optional<torch::Tensor> maybe_min_p_arr,
+                                                     double min_p_val, bool deterministic);
 
-std::vector<torch::Tensor> top_k_top_p_sampling_from_probs(torch::Tensor probs,
-                                                           torch::Tensor uniform_samples,
-                                                           torch::Tensor top_k, torch::Tensor top_p,
-                                                           bool deterministic);
+std::vector<torch::Tensor> top_k_top_p_sampling_from_probs(
+    torch::Tensor probs, torch::Tensor uniform_samples,
+    std::optional<torch::Tensor> maybe_top_k_arr, double top_k_val,
+    std::optional<torch::Tensor> maybe_top_p_arr, double top_p_val, bool deterministic);
 
-torch::Tensor top_p_renorm_prob(torch::Tensor probs, double top_p, double eps);
+torch::Tensor top_p_renorm_prob(torch::Tensor probs, std::optional<torch::Tensor> maybe_top_p_arr,
+                                double top_p_val, double eps);
 
-torch::Tensor top_k_renorm_prob(torch::Tensor probs, unsigned int top_k, double eps);
+torch::Tensor top_k_renorm_prob(torch::Tensor probs, std::optional<torch::Tensor> maybe_top_k_arr,
+                                unsigned int top_k_val, double eps);
+
+torch::Tensor top_k_mask_logits(torch::Tensor logits, std::optional<torch::Tensor> maybe_top_k_arr,
+                                unsigned int top_k_val, double eps);
 
 torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tensor draft_token_ids,
                                          torch::Tensor uniform_samples, torch::Tensor target_probs,
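These paired `maybe_*_arr`/`*_val` parameters encode the new scalar-or-tensor choice: a plain scalar presumably travels down as `*_val` with the optional tensor left empty, while a tensor supplies one threshold per request. A sketch of how this likely looks from the Python side (argument handling assumed, not shown in this diff):

```python
# Assumed Python-level view of the scalar-vs-tensor threshold duality.
import torch
import flashinfer

batch, vocab = 4, 32000
probs = torch.softmax(torch.randn(batch, vocab, device="cuda"), dim=-1)
uniform = torch.rand(32, batch, device="cuda")

# Scalar threshold: one top-p value broadcast to every request
# (presumably forwarded as top_p_val with maybe_top_p_arr empty).
samples, success = flashinfer.sampling.top_p_sampling_from_probs(probs, uniform, 0.9)

# Tensor threshold: an independent top-p value per request
# (presumably populating maybe_top_p_arr).
per_request_top_p = torch.tensor([0.5, 0.7, 0.9, 1.0], device="cuda")
samples, success = flashinfer.sampling.top_p_sampling_from_probs(
    probs, uniform, per_request_top_p
)
```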
166 changes: 128 additions & 38 deletions python/csrc/sampling.cu

Large diffs are not rendered by default.

52 changes: 36 additions & 16 deletions python/flashinfer/__init__.py
@@ -14,26 +14,46 @@
 limitations under the License.
 """
 
-from .cascade import (BatchDecodeWithSharedPrefixPagedKVCacheWrapper,
-                      BatchPrefillWithSharedPrefixPagedKVCacheWrapper,
-                      merge_state, merge_state_in_place, merge_states)
-from .decode import (BatchDecodeWithPagedKVCacheWrapper,
-                     CUDAGraphBatchDecodeWithPagedKVCacheWrapper,
-                     single_decode_with_kv_cache)
+from .cascade import (
+    BatchDecodeWithSharedPrefixPagedKVCacheWrapper,
+    BatchPrefillWithSharedPrefixPagedKVCacheWrapper,
+    merge_state,
+    merge_state_in_place,
+    merge_states,
+)
+from .decode import (
+    BatchDecodeWithPagedKVCacheWrapper,
+    CUDAGraphBatchDecodeWithPagedKVCacheWrapper,
+    single_decode_with_kv_cache,
+)
 from .group_gemm import SegmentGEMMWrapper
 from .norm import fused_add_rmsnorm, rmsnorm
 from .page import append_paged_kv_cache
-from .prefill import (BatchPrefillWithPagedKVCacheWrapper,
-                      BatchPrefillWithRaggedKVCacheWrapper,
-                      single_prefill_with_kv_cache,
-                      single_prefill_with_kv_cache_return_lse)
+from .prefill import (
+    BatchPrefillWithPagedKVCacheWrapper,
+    BatchPrefillWithRaggedKVCacheWrapper,
+    single_prefill_with_kv_cache,
+    single_prefill_with_kv_cache_return_lse,
+)
 from .quantization import packbits, segment_packbits
-from .rope import (apply_llama31_rope, apply_llama31_rope_inplace, apply_rope,
-                   apply_rope_inplace)
-from .sampling import (chain_speculative_sampling, sampling_from_probs,
-                       top_k_renorm_prob, top_k_sampling_from_probs,
-                       top_k_top_p_sampling_from_probs, top_p_renorm_prob,
-                       top_p_sampling_from_probs)
+from .rope import (
+    apply_llama31_rope,
+    apply_llama31_rope_inplace,
+    apply_rope,
+    apply_rope_inplace,
+)
+from .sampling import (
+    chain_speculative_sampling,
+    sampling_from_probs,
+    top_k_renorm_prob,
+    top_k_mask_logits,
+    top_k_sampling_from_probs,
+    top_k_top_p_sampling_from_probs,
+    top_k_top_p_sampling_from_logits,
+    top_p_renorm_prob,
+    top_p_sampling_from_probs,
+    min_p_sampling_from_probs,
+)
 from .sparse import BlockSparseAttentionWrapper

 try:
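With the re-exports above, the new operators become importable from the package root. A hedged smoke test; the shapes and the (samples, success) return convention are assumed from the sibling samplers rather than documented behavior:

```python
# Assumed smoke test for the names newly re-exported at the package root.
import torch
from flashinfer import min_p_sampling_from_probs, top_k_mask_logits

batch, vocab = 2, 1000
probs = torch.softmax(torch.randn(batch, vocab, device="cuda"), dim=-1)
uniform = torch.rand(32, batch, device="cuda")

# min-p: keep tokens with prob >= min_p * max(prob) per request, then sample.
min_p = torch.full((batch,), 0.05, device="cuda")
samples, success = min_p_sampling_from_probs(probs, uniform, min_p)

# top-k logit masking, now reachable without the .sampling submodule.
masked = top_k_mask_logits(torch.randn(batch, vocab, device="cuda"), 50)
```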