Commit a870e3e: torchlib

abcdabcd987 committed Oct 24, 2024
1 parent f6e0010 commit a870e3e
Showing 24 changed files with 1,283 additions and 279 deletions.
20 changes: 9 additions & 11 deletions flashinfer-aot/csrc_aot/batch_decode.cu
@@ -85,13 +85,13 @@ std::vector<int64_t> BatchDecodeWithPagedKVCachePlan(
return plan_info.ToVector();
}

std::vector<torch::Tensor> BatchDecodeWithPagedKVCacheRun(
torch::Tensor BatchDecodeWithPagedKVCacheRun(
torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
std::vector<int64_t> plan_info_vec, torch::Tensor q, torch::Tensor paged_k_cache,
torch::Tensor paged_v_cache, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
torch::Tensor paged_kv_last_page_len, std::optional<torch::Tensor> alibi_slopes,
unsigned int kv_layout_code, int window_left, float logits_soft_cap, float sm_scale,
float rope_scale, float rope_theta, bool return_lse) {
float rope_scale, float rope_theta, std::optional<torch::Tensor> maybe_lse) {
DecodePlanInfo plan_info;
plan_info.FromVector(plan_info_vec);
QKVLayout kv_layout = static_cast<QKVLayout>(kv_layout_code);
Expand All @@ -111,9 +111,11 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCacheRun(

cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
torch::Tensor o = torch::empty_like(q);
torch::Tensor lse;
if (return_lse) {
lse = torch::empty({batch_size, num_qo_heads}, q.options().dtype((torch::kFloat32)));
if (maybe_lse) {
const auto& lse = *maybe_lse;
TORCH_CHECK(lse.size(0) == batch_size, lse.size(0), q.size(0));
TORCH_CHECK(lse.size(1) == num_qo_heads, lse.size(1), q.size(1));
TORCH_CHECK(lse.dtype() == torch::kFloat32, "lse must be float32");
}

TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative");
@@ -156,7 +158,7 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCacheRun(
static_cast<IdType*>(paged_kv_last_page_len.data_ptr()));
ParamsT params(static_cast<DTypeQ*>(q.data_ptr()),
/*q_offset=*/nullptr, paged_kv, static_cast<DTypeO*>(o.data_ptr()),
/*lse=*/(return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr),
/*lse=*/(maybe_lse ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr),
/*alibi_slopes=*/nullptr, num_qo_heads, window_left, logits_soft_cap,
sm_scale, rope_scale, rope_theta);

@@ -190,9 +192,5 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCacheRun(
});
});

if (return_lse) {
return {o, lse};
} else {
return {o};
}
return o;
}
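Note on the change above: the `return_lse` flag is gone. A caller that wants the log-sum-exp now pre-allocates an `lse` tensor and passes it as `maybe_lse`; the kernel fills it in place and the function returns only `o`. The following Python-side adapter is a minimal sketch of that convention; the `run_op` handle, the wrapper name, and the argument packing are hypothetical, while the required [batch_size, num_qo_heads] float32 shape comes from the TORCH_CHECKs above.

import torch
from typing import Optional, Tuple

def run_batch_decode_with_optional_lse(
    run_op,              # hypothetical handle to the BatchDecodeWithPagedKVCacheRun binding
    kernel_args: tuple,  # every positional argument up to and including rope_theta
    q: torch.Tensor,     # [batch_size, num_qo_heads, head_dim]
    return_lse: bool,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    # The old boolean flag becomes a caller-allocated output buffer.
    lse = None
    if return_lse:
        # Must satisfy the new TORCH_CHECKs: shape [batch_size, num_qo_heads], dtype float32.
        lse = torch.empty((q.size(0), q.size(1)), dtype=torch.float32, device=q.device)
    o = run_op(*kernel_args, lse)  # kernel writes `lse` in place and returns only `o`
    return o, lse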
4 changes: 2 additions & 2 deletions flashinfer-aot/csrc_aot/flashinfer_ops_decode.cu
@@ -29,13 +29,13 @@ std::vector<int64_t> BatchDecodeWithPagedKVCachePlan(
torch::Tensor indptr, unsigned int batch_size, unsigned int num_qo_heads,
unsigned int num_kv_heads, unsigned int page_size, bool enable_cuda_graph);

std::vector<torch::Tensor> BatchDecodeWithPagedKVCacheRun(
torch::Tensor BatchDecodeWithPagedKVCacheRun(
torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
std::vector<int64_t> plan_info_vec, torch::Tensor q, torch::Tensor paged_k_cache,
torch::Tensor paged_v_cache, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
torch::Tensor paged_kv_last_page_len, std::optional<torch::Tensor> alibi_slopes,
unsigned int kv_layout_code, int window_left, float logits_soft_cap, float sm_scale,
float rope_scale, float rope_theta, bool return_lse);
float rope_scale, float rope_theta, std::optional<torch::Tensor> maybe_lse);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("single_decode_with_kv_cache", &single_decode_with_kv_cache,
13 changes: 7 additions & 6 deletions flashinfer-aot/csrc_aot/flashinfer_ops_prefill.cu
@@ -15,36 +15,37 @@
*/
#include <torch/extension.h>

std::vector<torch::Tensor> single_prefill_with_kv_cache(
torch::Tensor single_prefill_with_kv_cache(
unsigned int mask_mode_code, torch::Tensor q, torch::Tensor k, torch::Tensor v,
std::optional<torch::Tensor> maybe_packed_custom_mask, torch::Tensor tmp,
std::optional<torch::Tensor> maybe_alibi_slopes, unsigned int layout, int32_t window_left,
float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, bool return_lse);
float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta,
std::optional<torch::Tensor> maybe_lse);

std::vector<int64_t> BatchPrefillWithKVCachePlan(
unsigned int head_dim, torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
torch::Tensor page_locked_int_workspace_buffer, torch::Tensor qo_indptr,
torch::Tensor kv_indptr, unsigned int batch_size, unsigned int num_qo_heads,
unsigned int num_kv_heads, unsigned int page_size, bool enable_cuda_graph);

std::vector<torch::Tensor> BatchPrefillWithRaggedKVCacheRun(
torch::Tensor BatchPrefillWithRaggedKVCacheRun(
unsigned int mask_mode_code, torch::Tensor float_workspace_buffer,
torch::Tensor int_workspace_buffer, std::vector<int64_t> plan_info_vec, torch::Tensor q,
torch::Tensor k, torch::Tensor v, std::optional<torch::Tensor> maybe_custom_mask,
std::optional<torch::Tensor> maybe_alibi_slopes, torch::Tensor qo_indptr,
torch::Tensor kv_indptr, std::optional<torch::Tensor> maybe_qk_indptr, unsigned int layout,
int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta,
bool return_lse);
std::optional<torch::Tensor> maybe_lse);

std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
torch::Tensor BatchPrefillWithPagedKVCacheRun(
unsigned int mask_mode_code, torch::Tensor float_workspace_buffer,
torch::Tensor int_workspace_buffer, std::vector<int64_t> plan_info_vec, torch::Tensor q,
torch::Tensor paged_k_cache, torch::Tensor paged_v_cache,
std::optional<torch::Tensor> maybe_custom_mask, std::optional<torch::Tensor> maybe_alibi_slopes,
torch::Tensor qo_indptr, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
torch::Tensor paged_kv_last_page_len, std::optional<torch::Tensor> maybe_qk_indptr,
unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale,
float rope_scale, float rope_theta, bool return_lse);
float rope_scale, float rope_theta, std::optional<torch::Tensor> maybe_lse);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("single_prefill_with_kv_cache", &single_prefill_with_kv_cache,
20 changes: 9 additions & 11 deletions flashinfer-aot/csrc_aot/single_prefill.cu
@@ -32,12 +32,12 @@ cudaError_t SinglePrefillWithKVCacheDispatched(typename AttentionVariant::Params

} // namespace flashinfer

std::vector<torch::Tensor> single_prefill_with_kv_cache(
torch::Tensor single_prefill_with_kv_cache(
unsigned int mask_mode_code, torch::Tensor q, torch::Tensor k, torch::Tensor v,
std::optional<torch::Tensor> maybe_packed_custom_mask, torch::Tensor tmp,
std::optional<torch::Tensor> maybe_alibi_slopes, unsigned int layout,
int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta,
bool return_lse) {
std::optional<torch::Tensor> maybe_lse) {
auto device = q.device();
unsigned int head_dim = q.size(2);
unsigned int kv_len, qo_len, num_kv_heads, num_qo_heads;
Expand All @@ -58,9 +58,11 @@ std::vector<torch::Tensor> single_prefill_with_kv_cache(
}
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto o = torch::empty_like(q, q.options());
torch::Tensor lse = torch::empty({0});
if (return_lse) {
lse = torch::empty({qo_len, num_qo_heads}, q.options().dtype(torch::kFloat32));
if (maybe_lse) {
const auto& lse = *maybe_lse;
TORCH_CHECK(lse.size(0) == qo_len, lse.size(0), q.size(0));
TORCH_CHECK(lse.size(1) == num_qo_heads, lse.size(1), q.size(1));
TORCH_CHECK(lse.dtype() == torch::kFloat32, "lse must be float32");
}

constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone;
@@ -90,7 +92,7 @@ std::vector<torch::Tensor> single_prefill_with_kv_cache(
? static_cast<uint8_t*>(maybe_packed_custom_mask->data_ptr())
: nullptr,
static_cast<DTypeO*>(o.data_ptr()),
/*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr,
/*lse=*/(maybe_lse ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr),
/*alibi_slopes=*/nullptr, num_qo_heads, num_kv_heads, qo_len, kv_len,
q_stride_n, q_stride_h, kv_stride_n, kv_stride_h, head_dim, window_left,
logits_soft_cap, sm_scale, rope_scale, rope_theta);
Expand All @@ -109,9 +111,5 @@ std::vector<torch::Tensor> single_prefill_with_kv_cache(
});
});

if (return_lse) {
return {o, lse};
} else {
return {o};
}
return o;
}
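The same out-parameter pattern applies to single_prefill_with_kv_cache. To pin down the semantics, below is a naive pure-PyTorch reference, an assumption-laden sketch rather than the CUDA kernel: it assumes NHD layout, matched query/key head counts, and no masking, RoPE, or logits cap, and it writes natural-log log-sum-exp values into a caller-provided float32 buffer of shape [qo_len, num_qo_heads] when one is given; the kernel's exact lse scaling is not shown in this diff.

import math
from typing import Optional
import torch

def single_prefill_reference(
    q: torch.Tensor,                           # [qo_len, num_qo_heads, head_dim]
    k: torch.Tensor,                           # [kv_len, num_qo_heads, head_dim]
    v: torch.Tensor,                           # [kv_len, num_qo_heads, head_dim]
    maybe_lse: Optional[torch.Tensor] = None,  # [qo_len, num_qo_heads], float32
) -> torch.Tensor:
    qo_len, num_qo_heads, head_dim = q.shape
    if maybe_lse is not None:
        # Mirrors the TORCH_CHECKs in the new code path.
        assert maybe_lse.shape == (qo_len, num_qo_heads)
        assert maybe_lse.dtype == torch.float32
    sm_scale = 1.0 / math.sqrt(head_dim)
    logits = torch.einsum("qhd,khd->hqk", q.float(), k.float()) * sm_scale
    o = torch.einsum("hqk,khd->qhd", torch.softmax(logits, dim=-1), v.float())
    if maybe_lse is not None:
        # Written in place into the caller's buffer; only `o` is returned.
        maybe_lse.copy_(torch.logsumexp(logits, dim=-1).transpose(0, 1))
    return o.to(q.dtype)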
6 changes: 3 additions & 3 deletions python/csrc/flashinfer_sampling_ops.cu
@@ -47,10 +47,10 @@ torch::Tensor top_k_renorm_probs(torch::Tensor probs, std::optional<torch::Tenso
torch::Tensor top_k_mask_logits(torch::Tensor logits, std::optional<torch::Tensor> maybe_top_k_arr,
unsigned int top_k_val);

std::vector<torch::Tensor> chain_speculative_sampling(
torch::Tensor chain_speculative_sampling(
torch::Tensor draft_probs, torch::Tensor draft_token_ids, torch::Tensor uniform_samples,
torch::Tensor target_probs, std::optional<torch::Tensor> maybe_output_accepted_token_num,
std::optional<torch::Tensor> maybe_output_emitted_token_num, bool deterministic);
torch::Tensor target_probs, torch::Tensor output_accepted_token_num,
torch::Tensor output_emitted_token_num, bool deterministic);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("sampling_from_probs", &sampling_from_probs, "Sample from probabilities");
22 changes: 6 additions & 16 deletions python/csrc/sampling.cu
@@ -314,10 +314,10 @@ torch::Tensor top_k_mask_logits(torch::Tensor logits, std::optional<torch::Tenso
return mask_logits;
}

std::vector<torch::Tensor> chain_speculative_sampling(
torch::Tensor chain_speculative_sampling(
torch::Tensor draft_probs, torch::Tensor draft_token_ids, torch::Tensor uniform_samples,
torch::Tensor target_probs, std::optional<torch::Tensor> maybe_output_accepted_token_num,
std::optional<torch::Tensor> maybe_output_emitted_token_num, bool deterministic) {
torch::Tensor target_probs, torch::Tensor output_accepted_token_num,
torch::Tensor output_emitted_token_num, bool deterministic) {
CHECK_INPUT(draft_probs);
CHECK_INPUT(draft_token_ids);
CHECK_INPUT(uniform_samples);
Expand All @@ -339,6 +339,8 @@ std::vector<torch::Tensor> chain_speculative_sampling(
CHECK_EQ(num_speculate_tokens + 1, uniform_samples.size(1));
CHECK_EQ(num_speculate_tokens + 1, target_probs.size(1));
CHECK_EQ(vocab_size, target_probs.size(2));
CHECK_EQ(batch_size, output_accepted_token_num.size(0));
CHECK_EQ(batch_size, output_emitted_token_num.size(0));

draft_probs = draft_probs.to(torch::kFloat32);
draft_token_ids = draft_token_ids.to(torch::kInt32);
Expand All @@ -349,18 +351,6 @@ std::vector<torch::Tensor> chain_speculative_sampling(
auto output_token_ids = torch::empty({batch_size, num_speculate_tokens + 1},
torch::dtype(torch::kInt32).device(device));

bool has_output_accepted_token_num = maybe_output_accepted_token_num.has_value();
bool has_output_emitted_token_num = maybe_output_emitted_token_num.has_value();
auto output_accepted_token_num = maybe_output_accepted_token_num.value_or(
torch::zeros({batch_size}, torch::dtype(torch::kInt32).device(device)));
auto output_emitted_token_num = maybe_output_emitted_token_num.value_or(
torch::zeros({batch_size}, torch::dtype(torch::kInt32).device(device)));
if (has_output_accepted_token_num) {
CHECK_EQ(has_output_emitted_token_num, true);
CHECK_EQ(batch_size, output_accepted_token_num.size(0));
CHECK_EQ(batch_size, output_emitted_token_num.size(0));
}

cudaError_t status = sampling::ChainSpeculativeSampling<float, int>(
static_cast<float*>(draft_probs.data_ptr()), static_cast<int*>(draft_token_ids.data_ptr()),
static_cast<float*>(uniform_samples.data_ptr()), static_cast<float*>(target_probs.data_ptr()),
Expand All @@ -372,5 +362,5 @@ std::vector<torch::Tensor> chain_speculative_sampling(
TORCH_CHECK(status == cudaSuccess, "ChainSpeculativeSampling failed with error code " +
std::string(cudaGetErrorString(status)));

return {output_token_ids, output_accepted_token_num, output_emitted_token_num};
return output_token_ids;
}
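With this change the two per-request counters are mandatory, caller-allocated outputs of shape [batch_size], and only `output_token_ids` is returned. A hedged caller-side sketch follows; `op` stands in for the chain_speculative_sampling binding and the wrapper itself is hypothetical, while the shapes follow the CHECK_EQs above and the zero-initialization mirrors the defaults the removed optional-argument path used to create internally.

import torch

def call_chain_speculative_sampling(
    op,                              # hypothetical handle to the chain_speculative_sampling binding
    draft_probs: torch.Tensor,       # [batch_size, num_speculate_tokens, vocab_size]
    draft_token_ids: torch.Tensor,   # [batch_size, num_speculate_tokens]
    uniform_samples: torch.Tensor,   # [batch_size, num_speculate_tokens + 1]
    target_probs: torch.Tensor,      # [batch_size, num_speculate_tokens + 1, vocab_size]
    deterministic: bool = True,
):
    batch_size = draft_probs.size(0)
    device = draft_probs.device
    # Outputs are now required, caller-allocated buffers that the kernel fills in place.
    accepted_token_num = torch.zeros(batch_size, dtype=torch.int32, device=device)
    emitted_token_num = torch.zeros(batch_size, dtype=torch.int32, device=device)
    output_token_ids = op(
        draft_probs, draft_token_ids, uniform_samples, target_probs,
        accepted_token_num, emitted_token_num, deterministic,
    )
    return output_token_ids, accepted_token_num, emitted_token_num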
33 changes: 25 additions & 8 deletions python/flashinfer/activation.py
@@ -14,16 +14,17 @@
limitations under the License.
"""

from typing import Optional
from types import SimpleNamespace

import torch

from .jit import (
load_cuda_ops,
FLASHINFER_GEN_SRC_DIR,
gen_act_and_mul_cu,
has_prebuilt_ops,
load_cuda_ops,
)

import torch

from .utils import register_custom_op, register_fake_op

silu_def_cu_str = r"""
__device__ __forceinline__ float silu(const float& val) {
@@ -73,15 +74,31 @@ def get_act_and_mul_module(act_func_name: str):
if has_prebuilt_ops:
from . import _kernels

_jit_modules[act_func_name] = _kernels
module = _kernels
else:
_jit_modules[act_func_name] = compile_act_and_mul_module(
module = compile_act_and_mul_module(
act_func_name, act_func_def_str[act_func_name]
)

# torch library for act_and_mul
fname = f"{act_func_name}_and_mul"
fn = getattr(module, fname)

@register_custom_op(f"flashinfer::{fname}", mutates_args=("out",))
def _act_and_mul(out: torch.Tensor, input: torch.Tensor) -> None:
fn(out, input)

@register_fake_op(f"flashinfer::{fname}")
def _fake_act_and_mul(out: torch.Tensor, input: torch.Tensor) -> None:
pass

# Register the module
_jit_modules[act_func_name] = SimpleNamespace(**{fname: _act_and_mul})

return _jit_modules[act_func_name]


def _check_shape(input: torch.Tensor, output: torch.Tensor):
def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None:
assert input.ndim == output.ndim, f"{input.ndim} != {output.ndim}"
assert (
input.shape[:-1] == output.shape[:-1]
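The activation change wraps each `<act>_and_mul` kernel in a torch.library custom op: `register_custom_op` declares that the op mutates `out`, `register_fake_op` supplies a no-op meta implementation for tracing, and the cached module entry becomes a SimpleNamespace that exposes only the registered wrapper. Below is a pure-PyTorch sketch of what such an op computes, under the assumption (not shown in this diff) that the first half of the last dimension is the activated half.

import torch

def silu_and_mul_reference(out: torch.Tensor, input: torch.Tensor) -> None:
    # Same shape contract as the registered op: writes into `out` and returns
    # nothing, matching mutates_args=("out",). input: [..., 2 * d], out: [..., d].
    d = input.shape[-1] // 2
    gate, up = input[..., :d], input[..., d:]
    out.copy_(torch.nn.functional.silu(gate) * up)

# Usage sketch: the caller pre-allocates `out` with the same leading shape and half the last dim.
x = torch.randn(4, 256)
out = torch.empty(4, 128)
silu_and_mul_reference(out, x)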