Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: torch.compile and custom_op support #554

Merged
merged 2 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 9 additions & 11 deletions flashinfer-aot/csrc_aot/batch_decode.cu
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,13 @@ std::vector<int64_t> BatchDecodeWithPagedKVCachePlan(
return plan_info.ToVector();
}

std::vector<torch::Tensor> BatchDecodeWithPagedKVCacheRun(
torch::Tensor BatchDecodeWithPagedKVCacheRun(
torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
std::vector<int64_t> plan_info_vec, torch::Tensor q, torch::Tensor paged_k_cache,
torch::Tensor paged_v_cache, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
torch::Tensor paged_kv_last_page_len, std::optional<torch::Tensor> alibi_slopes,
unsigned int kv_layout_code, int window_left, float logits_soft_cap, float sm_scale,
float rope_scale, float rope_theta, bool return_lse) {
float rope_scale, float rope_theta, std::optional<torch::Tensor> maybe_lse) {
DecodePlanInfo plan_info;
plan_info.FromVector(plan_info_vec);
QKVLayout kv_layout = static_cast<QKVLayout>(kv_layout_code);
Expand All @@ -111,9 +111,11 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCacheRun(

cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
torch::Tensor o = torch::empty_like(q);
torch::Tensor lse;
if (return_lse) {
lse = torch::empty({batch_size, num_qo_heads}, q.options().dtype((torch::kFloat32)));
if (maybe_lse) {
const auto& lse = *maybe_lse;
TORCH_CHECK(lse.size(0) == batch_size, lse.size(0), q.size(0));
TORCH_CHECK(lse.size(1) == num_qo_heads, lse.size(1), q.size(1));
TORCH_CHECK(lse.dtype() == torch::kFloat32, "lse must be float32");
}

TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative");
Expand Down Expand Up @@ -160,7 +162,7 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCacheRun(
static_cast<IdType*>(paged_kv_last_page_len.data_ptr()));
ParamsT params(static_cast<DTypeQ*>(q.data_ptr()),
/*q_offset=*/nullptr, paged_kv, static_cast<DTypeO*>(o.data_ptr()),
/*lse=*/(return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr),
/*lse=*/(maybe_lse ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr),
/*alibi_slopes=*/nullptr, num_qo_heads, q_stride_n, q_stride_h, window_left,
logits_soft_cap, sm_scale, rope_scale, rope_theta);

Expand Down Expand Up @@ -194,9 +196,5 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCacheRun(
});
});

if (return_lse) {
return {o, lse};
} else {
return {o};
}
return o;
}
4 changes: 2 additions & 2 deletions flashinfer-aot/csrc_aot/flashinfer_ops_decode.cu
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@ std::vector<int64_t> BatchDecodeWithPagedKVCachePlan(
torch::Tensor indptr, unsigned int batch_size, unsigned int num_qo_heads,
unsigned int num_kv_heads, unsigned int page_size, bool enable_cuda_graph);

std::vector<torch::Tensor> BatchDecodeWithPagedKVCacheRun(
torch::Tensor BatchDecodeWithPagedKVCacheRun(
torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
std::vector<int64_t> plan_info_vec, torch::Tensor q, torch::Tensor paged_k_cache,
torch::Tensor paged_v_cache, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
torch::Tensor paged_kv_last_page_len, std::optional<torch::Tensor> alibi_slopes,
unsigned int kv_layout_code, int window_left, float logits_soft_cap, float sm_scale,
float rope_scale, float rope_theta, bool return_lse);
float rope_scale, float rope_theta, std::optional<torch::Tensor> maybe_lse);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("single_decode_with_kv_cache", &single_decode_with_kv_cache,
Expand Down
13 changes: 7 additions & 6 deletions flashinfer-aot/csrc_aot/flashinfer_ops_prefill.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,36 +15,37 @@
*/
#include <torch/extension.h>

std::vector<torch::Tensor> single_prefill_with_kv_cache(
torch::Tensor single_prefill_with_kv_cache(
unsigned int mask_mode_code, torch::Tensor q, torch::Tensor k, torch::Tensor v,
std::optional<torch::Tensor> maybe_packed_custom_mask, torch::Tensor tmp,
std::optional<torch::Tensor> maybe_alibi_slopes, unsigned int layout, int32_t window_left,
float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, bool return_lse);
float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta,
std::optional<torch::Tensor> maybe_lse);

std::vector<int64_t> BatchPrefillWithKVCachePlan(
unsigned int head_dim, torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
torch::Tensor page_locked_int_workspace_buffer, torch::Tensor qo_indptr,
torch::Tensor kv_indptr, unsigned int batch_size, unsigned int num_qo_heads,
unsigned int num_kv_heads, unsigned int page_size, bool enable_cuda_graph);

std::vector<torch::Tensor> BatchPrefillWithRaggedKVCacheRun(
torch::Tensor BatchPrefillWithRaggedKVCacheRun(
unsigned int mask_mode_code, torch::Tensor float_workspace_buffer,
torch::Tensor int_workspace_buffer, std::vector<int64_t> plan_info_vec, torch::Tensor q,
torch::Tensor k, torch::Tensor v, std::optional<torch::Tensor> maybe_custom_mask,
std::optional<torch::Tensor> maybe_alibi_slopes, torch::Tensor qo_indptr,
torch::Tensor kv_indptr, std::optional<torch::Tensor> maybe_qk_indptr, unsigned int layout,
int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta,
bool return_lse);
std::optional<torch::Tensor> maybe_lse);

std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
torch::Tensor BatchPrefillWithPagedKVCacheRun(
unsigned int mask_mode_code, torch::Tensor float_workspace_buffer,
torch::Tensor int_workspace_buffer, std::vector<int64_t> plan_info_vec, torch::Tensor q,
torch::Tensor paged_k_cache, torch::Tensor paged_v_cache,
std::optional<torch::Tensor> maybe_custom_mask, std::optional<torch::Tensor> maybe_alibi_slopes,
torch::Tensor qo_indptr, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
torch::Tensor paged_kv_last_page_len, std::optional<torch::Tensor> maybe_qk_indptr,
unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale,
float rope_scale, float rope_theta, bool return_lse);
float rope_scale, float rope_theta, std::optional<torch::Tensor> maybe_lse);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("single_prefill_with_kv_cache", &single_prefill_with_kv_cache,
Expand Down
20 changes: 9 additions & 11 deletions flashinfer-aot/csrc_aot/single_prefill.cu
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@ cudaError_t SinglePrefillWithKVCacheDispatched(typename AttentionVariant::Params

} // namespace flashinfer

std::vector<torch::Tensor> single_prefill_with_kv_cache(
torch::Tensor single_prefill_with_kv_cache(
unsigned int mask_mode_code, torch::Tensor q, torch::Tensor k, torch::Tensor v,
std::optional<torch::Tensor> maybe_packed_custom_mask, torch::Tensor tmp,
std::optional<torch::Tensor> maybe_alibi_slopes, unsigned int layout,
int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta,
bool return_lse) {
std::optional<torch::Tensor> maybe_lse) {
auto device = q.device();
unsigned int head_dim = q.size(2);
unsigned int kv_len, qo_len, num_kv_heads, num_qo_heads;
Expand All @@ -58,9 +58,11 @@ std::vector<torch::Tensor> single_prefill_with_kv_cache(
}
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto o = torch::empty_like(q, q.options());
torch::Tensor lse = torch::empty({0});
if (return_lse) {
lse = torch::empty({qo_len, num_qo_heads}, q.options().dtype(torch::kFloat32));
if (maybe_lse) {
const auto& lse = *maybe_lse;
TORCH_CHECK(lse.size(0) == qo_len, lse.size(0), q.size(0));
TORCH_CHECK(lse.size(1) == num_qo_heads, lse.size(1), q.size(1));
TORCH_CHECK(lse.dtype() == torch::kFloat32, "lse must be float32");
}

constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone;
Expand Down Expand Up @@ -90,7 +92,7 @@ std::vector<torch::Tensor> single_prefill_with_kv_cache(
? static_cast<uint8_t*>(maybe_packed_custom_mask->data_ptr())
: nullptr,
static_cast<DTypeO*>(o.data_ptr()),
/*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr,
/*lse=*/(maybe_lse ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr),
/*alibi_slopes=*/nullptr, num_qo_heads, num_kv_heads, qo_len, kv_len,
q_stride_n, q_stride_h, kv_stride_n, kv_stride_h, head_dim, window_left,
logits_soft_cap, sm_scale, rope_scale, rope_theta);
Expand All @@ -109,9 +111,5 @@ std::vector<torch::Tensor> single_prefill_with_kv_cache(
});
});

if (return_lse) {
return {o, lse};
} else {
return {o};
}
return o;
}
6 changes: 3 additions & 3 deletions python/csrc/flashinfer_sampling_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ torch::Tensor top_k_renorm_probs(torch::Tensor probs, std::optional<torch::Tenso
torch::Tensor top_k_mask_logits(torch::Tensor logits, std::optional<torch::Tensor> maybe_top_k_arr,
unsigned int top_k_val);

std::vector<torch::Tensor> chain_speculative_sampling(
torch::Tensor chain_speculative_sampling(
torch::Tensor draft_probs, torch::Tensor draft_token_ids, torch::Tensor uniform_samples,
torch::Tensor target_probs, std::optional<torch::Tensor> maybe_output_accepted_token_num,
std::optional<torch::Tensor> maybe_output_emitted_token_num, bool deterministic);
torch::Tensor target_probs, torch::Tensor output_accepted_token_num,
torch::Tensor output_emitted_token_num, bool deterministic);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("sampling_from_probs", &sampling_from_probs, "Sample from probabilities");
Expand Down
22 changes: 6 additions & 16 deletions python/csrc/sampling.cu
Original file line number Diff line number Diff line change
Expand Up @@ -314,10 +314,10 @@ torch::Tensor top_k_mask_logits(torch::Tensor logits, std::optional<torch::Tenso
return mask_logits;
}

std::vector<torch::Tensor> chain_speculative_sampling(
torch::Tensor chain_speculative_sampling(
torch::Tensor draft_probs, torch::Tensor draft_token_ids, torch::Tensor uniform_samples,
torch::Tensor target_probs, std::optional<torch::Tensor> maybe_output_accepted_token_num,
std::optional<torch::Tensor> maybe_output_emitted_token_num, bool deterministic) {
torch::Tensor target_probs, torch::Tensor output_accepted_token_num,
torch::Tensor output_emitted_token_num, bool deterministic) {
CHECK_INPUT(draft_probs);
CHECK_INPUT(draft_token_ids);
CHECK_INPUT(uniform_samples);
Expand All @@ -339,6 +339,8 @@ std::vector<torch::Tensor> chain_speculative_sampling(
CHECK_EQ(num_speculate_tokens + 1, uniform_samples.size(1));
CHECK_EQ(num_speculate_tokens + 1, target_probs.size(1));
CHECK_EQ(vocab_size, target_probs.size(2));
CHECK_EQ(batch_size, output_accepted_token_num.size(0));
CHECK_EQ(batch_size, output_emitted_token_num.size(0));

draft_probs = draft_probs.to(torch::kFloat32);
draft_token_ids = draft_token_ids.to(torch::kInt32);
Expand All @@ -349,18 +351,6 @@ std::vector<torch::Tensor> chain_speculative_sampling(
auto output_token_ids = torch::empty({batch_size, num_speculate_tokens + 1},
torch::dtype(torch::kInt32).device(device));

bool has_output_accepted_token_num = maybe_output_accepted_token_num.has_value();
bool has_output_emitted_token_num = maybe_output_emitted_token_num.has_value();
auto output_accepted_token_num = maybe_output_accepted_token_num.value_or(
torch::zeros({batch_size}, torch::dtype(torch::kInt32).device(device)));
auto output_emitted_token_num = maybe_output_emitted_token_num.value_or(
torch::zeros({batch_size}, torch::dtype(torch::kInt32).device(device)));
if (has_output_accepted_token_num) {
CHECK_EQ(has_output_emitted_token_num, true);
CHECK_EQ(batch_size, output_accepted_token_num.size(0));
CHECK_EQ(batch_size, output_emitted_token_num.size(0));
}

cudaError_t status = sampling::ChainSpeculativeSampling<float, int>(
static_cast<float*>(draft_probs.data_ptr()), static_cast<int*>(draft_token_ids.data_ptr()),
static_cast<float*>(uniform_samples.data_ptr()), static_cast<float*>(target_probs.data_ptr()),
Expand All @@ -372,5 +362,5 @@ std::vector<torch::Tensor> chain_speculative_sampling(
TORCH_CHECK(status == cudaSuccess, "ChainSpeculativeSampling failed with error code " +
std::string(cudaGetErrorString(status)));

return {output_token_ids, output_accepted_token_num, output_emitted_token_num};
return output_token_ids;
}
33 changes: 25 additions & 8 deletions python/flashinfer/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,17 @@
limitations under the License.
"""

from typing import Optional
from types import SimpleNamespace

import torch

from .jit import (
load_cuda_ops,
FLASHINFER_GEN_SRC_DIR,
gen_act_and_mul_cu,
has_prebuilt_ops,
load_cuda_ops,
)

import torch

from .utils import register_custom_op, register_fake_op

silu_def_cu_str = r"""
__device__ __forceinline__ float silu(const float& val) {
Expand Down Expand Up @@ -73,15 +74,31 @@ def get_act_and_mul_module(act_func_name: str):
if has_prebuilt_ops:
from . import _kernels

_jit_modules[act_func_name] = _kernels
module = _kernels
else:
_jit_modules[act_func_name] = compile_act_and_mul_module(
module = compile_act_and_mul_module(
act_func_name, act_func_def_str[act_func_name]
)

# torch library for act_and_mul
fname = f"{act_func_name}_and_mul"
fn = getattr(module, fname)

@register_custom_op(f"flashinfer::{fname}", mutates_args=("out",))
def _act_and_mul(out: torch.Tensor, input: torch.Tensor) -> None:
fn(out, input)

@register_fake_op(f"flashinfer::{fname}")
def _fake_act_and_mul(out: torch.Tensor, input: torch.Tensor) -> None:
pass

# Register the module
_jit_modules[act_func_name] = SimpleNamespace(**{fname: _act_and_mul})

return _jit_modules[act_func_name]


def _check_shape(input: torch.Tensor, output: torch.Tensor):
def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None:
assert input.ndim == output.ndim, f"{input.ndim} != {output.ndim}"
assert (
input.shape[:-1] == output.shape[:-1]
Expand Down
Loading