From 4f40420e24d65cabd8be731e12f96a5ef0795a4b Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Tue, 29 Oct 2024 14:51:19 -0700 Subject: [PATCH] feat: support huggingface transformer style rope interface (#568) Previously our rope apis assume the position indices of each request is contiguous, which is not appropriate for applications such as speculative decoding, this PR fixes the issue by supporting the huggingface transformer-style API which use `pos_ids` argument to specify positions. This PR implements parts of the feature of #530 , other requests are coming in later PRs. cc @dreaming-panda @abcdabcd987 @byronhsu --- flashinfer-aot/csrc_aot/flashinfer_ops.cu | 39 ++-- include/flashinfer/pos_enc.cuh | 243 +++++++++++++--------- python/csrc/flashinfer_rope_ops.cu | 30 ++- python/csrc/rope.cu | 132 ++++++------ python/flashinfer/__init__.py | 2 + python/flashinfer/rope.py | 153 +++++++++++++- tests/test_rope.py | 133 ++++++++++-- 7 files changed, 521 insertions(+), 211 deletions(-) diff --git a/flashinfer-aot/csrc_aot/flashinfer_ops.cu b/flashinfer-aot/csrc_aot/flashinfer_ops.cu index 05b259f5..9ab9a86c 100644 --- a/flashinfer-aot/csrc_aot/flashinfer_ops.cu +++ b/flashinfer-aot/csrc_aot/flashinfer_ops.cu @@ -61,10 +61,11 @@ torch::Tensor top_k_renorm_probs(torch::Tensor probs, std::optional maybe_top_k_arr, unsigned int top_k_val); -torch::Tensor chain_speculative_sampling( - torch::Tensor draft_probs, torch::Tensor draft_token_ids, torch::Tensor uniform_samples, - torch::Tensor target_probs, torch::Tensor output_accepted_token_num, - torch::Tensor output_emitted_token_num, bool deterministic); +torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tensor draft_token_ids, + torch::Tensor uniform_samples, torch::Tensor target_probs, + torch::Tensor output_accepted_token_num, + torch::Tensor output_emitted_token_num, + bool deterministic); void rmsnorm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, double eps); @@ -82,24 +83,30 @@ void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input); void gelu_and_mul(torch::Tensor& out, torch::Tensor& input); -void apply_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, - torch::Tensor offsets, bool interleave, float rope_scale, float rope_theta); - -void apply_llama31_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, - torch::Tensor offsets, bool interleave, float rope_scale, - float rope_theta, float low_freq_factor, float high_freq_factor, - float old_context_length); - -std::vector apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, +std::vector apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, + torch::Tensor k_rope, torch::Tensor indptr, torch::Tensor offsets, bool interleave, float rope_scale, float rope_theta); std::vector apply_llama31_rope(torch::Tensor q, torch::Tensor k, + torch::Tensor q_rope, torch::Tensor k_rope, torch::Tensor indptr, torch::Tensor offsets, bool interleave, float rope_scale, float rope_theta, float low_freq_factor, float high_freq_factor, float old_context_length); +std::vector apply_rope_pos_ids(torch::Tensor q, torch::Tensor k, + torch::Tensor q_rope, torch::Tensor k_rope, + torch::Tensor pos_ids, bool interleave, + float rope_scale, float rope_theta); + +std::vector apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k, + torch::Tensor q_rope, torch::Tensor k_rope, + torch::Tensor pos_ids, bool interleave, + float rope_scale, float rope_theta, + float low_freq_factor, float high_freq_factor, + float old_context_length); + torch::Tensor packbits(torch::Tensor x, const std::string& bitorder); torch::Tensor segment_packbits(torch::Tensor x, torch::Tensor input_indptr, @@ -141,11 +148,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("silu_and_mul", &silu_and_mul, "Fused SiLU and Mul"); m.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, "Fused GeLU Tanh and Mul"); m.def("gelu_and_mul", &gelu_and_mul, "Fused GeLU and Mul"); - m.def("apply_rope_inplace", &apply_rope_inplace, "Apply RoPE in-place"); - m.def("apply_llama31_rope_inplace", &apply_llama31_rope_inplace, - "Apply Llama 3.1 style RoPE in-place"); m.def("apply_rope", &apply_rope, "Apply RoPE"); m.def("apply_llama31_rope", &apply_llama31_rope, "Apply Llama 3.1 style RoPE"); + m.def("apply_rope_pos_ids", &apply_rope_pos_ids, "Apply RoPE with positional ids"); + m.def("apply_llama31_rope_pos_ids", &apply_llama31_rope_pos_ids, + "Apply Llama 3.1 style RoPE with positional ids"); m.def("packbits", &packbits, "GPU packbits operator"); m.def("segment_packbits", &segment_packbits, "GPU segment packbits operator"); m.def("cutlass_segment_gemm", &CutlassSegmentGEMM, "Cutlass Segment GEMM operator"); diff --git a/include/flashinfer/pos_enc.cuh b/include/flashinfer/pos_enc.cuh index d6f96e4c..ed0b732a 100644 --- a/include/flashinfer/pos_enc.cuh +++ b/include/flashinfer/pos_enc.cuh @@ -17,6 +17,7 @@ #define FLASHINFER_POS_ENC_CUH_ #include +#include #include #include "layout.cuh" @@ -94,6 +95,25 @@ __device__ __forceinline__ vec_t vec_apply_llama_rope( return vec; } +template +__device__ __forceinline__ vec_t vec_apply_llama_rope_cos_sin( + const T* x, const vec_t& cos, const vec_t& sin) { + constexpr uint32_t head_dim = vec_size * bdx; + vec_t permuted_vec, vec; + vec.cast_load(x + threadIdx.x * vec_size); + permuted_vec.cast_load(x + ((threadIdx.x * vec_size < head_dim / 2) + ? threadIdx.x * vec_size + head_dim / 2 + : threadIdx.x * vec_size - head_dim / 2)); + +#pragma unroll + for (uint32_t i = 0; i < vec_size; ++i) { + vec[i] = + vec[i] * cos[i] + + ((threadIdx.x * vec_size < head_dim / 2) ? -permuted_vec[i] : permuted_vec[i]) * sin[i]; + } + return vec; +} + /*! * \brief Apply RoPE (Rotary Positional Embeddings) to x[0: head_dim] with interleave, * return thread-local vector. @@ -122,13 +142,28 @@ __device__ __forceinline__ vec_t vec_apply_llama_rope_interleav return vec; } +template +__device__ __forceinline__ vec_t vec_apply_llama_rope_cos_sin_interleave( + const T* x, const vec_t& cos, const vec_t& sin) { + vec_t vec, vec_before; + vec.cast_load(x + threadIdx.x * vec_size); + vec_before = vec; + +#pragma unroll + for (uint32_t i = 0; i < vec_size; ++i) { + vec[i] = vec[i] * cos[i] + ((i % 2 == 0) ? -vec_before[i ^ 1] : vec_before[i ^ 1]) * sin[i]; + } + return vec; +} + template -__global__ void BatchQKApplyRotaryInPlaceKernel( - DType* __restrict__ q, DType* __restrict__ k, IdType* __restrict__ indptr, - IdType* __restrict__ offsets, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, - size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, float smooth_a, - float smooth_b, float rope_rcp_scale, float rope_rcp_theta) { +__global__ void BatchQKApplyRotaryPosIdsKernel( + DType* q, DType* k, DType* q_rope, DType* k_rope, IdType* __restrict__ pos_ids, uint32_t nnz, + uint32_t num_qo_heads, uint32_t num_kv_heads, size_t q_stride_n, size_t q_stride_h, + size_t k_stride_n, size_t k_stride_h, size_t q_rope_stride_n, size_t q_rope_stride_h, + size_t k_rope_stride_n, size_t k_rope_stride_h, float smooth_a, float smooth_b, + float rope_rcp_scale, float rope_rcp_theta) { uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y; const uint32_t bdy = blockDim.y; vec_t freq; @@ -146,61 +181,56 @@ __global__ void BatchQKApplyRotaryInPlaceKernel( freq[i] = (1 - smooth) * (freq[i] * rope_rcp_scale) + smooth * freq[i]; } - if (bx < batch_size * num_qo_heads) { - // apply rotary to q - const uint32_t batch_idx = bx / num_qo_heads; - const uint32_t qo_head_idx = bx % num_qo_heads; - const uint32_t seq_len = indptr[batch_idx + 1] - indptr[batch_idx]; - const uint32_t offset = offsets[batch_idx]; -#pragma unroll 2 - for (uint32_t i = 0; i < (seq_len + bdy - 1) / bdy; ++i) { + vec_t cos, sin; + + if (bx * bdy + ty < nnz) { + const uint32_t idx = bx * bdy + ty; + const IdType pos = pos_ids[idx]; + +#pragma unroll + for (uint32_t i = 0; i < vec_size; ++i) { + float embed = float(pos) * freq[i]; + __sincosf(embed, &sin[i], &cos[i]); + } + +#pragma unroll 1 + for (uint32_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) { + DType* q_ptr = q + get_elem_offset_impl(idx, qo_head_idx, 0, q_stride_n, q_stride_h); + DType* q_rope_ptr = + q_rope + get_elem_offset_impl(idx, qo_head_idx, 0, q_rope_stride_n, q_rope_stride_h); vec_t q_vec; - if (i * bdy + ty < seq_len) { - DType* q_ptr = q + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, qo_head_idx, 0, - q_stride_n, q_stride_h); - if constexpr (interleave) { - q_vec = - vec_apply_llama_rope_interleave(q_ptr, freq, offset + i * bdy + ty); - } else { - q_vec = vec_apply_llama_rope(q_ptr, freq, offset + i * bdy + ty); - } - q_vec.cast_store(q_ptr + tx * vec_size); + if constexpr (interleave) { + q_vec = vec_apply_llama_rope_cos_sin_interleave(q_ptr, cos, sin); + } else { + q_vec = vec_apply_llama_rope_cos_sin(q_ptr, cos, sin); } + q_vec.cast_store(q_rope_ptr + tx * vec_size); } - } else { - // apply rotary to k - uint32_t batch_idx = (bx - batch_size * num_qo_heads) / num_kv_heads; - uint32_t kv_head_idx = (bx - batch_size * num_qo_heads) % num_kv_heads; - const uint32_t seq_len = indptr[batch_idx + 1] - indptr[batch_idx]; - const uint32_t offset = offsets[batch_idx]; -#pragma unroll 2 - for (uint32_t i = 0; i < (seq_len + bdy - 1) / bdy; ++i) { + +#pragma unroll 1 + for (uint32_t kv_head_idx = 0; kv_head_idx < num_kv_heads; ++kv_head_idx) { + DType* k_ptr = k + get_elem_offset_impl(idx, kv_head_idx, 0, k_stride_n, k_stride_h); + DType* k_rope_ptr = + k_rope + get_elem_offset_impl(idx, kv_head_idx, 0, k_rope_stride_n, k_rope_stride_h); vec_t k_vec; - if (i * bdy + ty < seq_len) { - DType* k_ptr = k + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, kv_head_idx, 0, - k_stride_n, k_stride_h); - if constexpr (interleave) { - k_vec = - vec_apply_llama_rope_interleave(k_ptr, freq, offset + i * bdy + ty); - } else { - k_vec = vec_apply_llama_rope(k_ptr, freq, offset + i * bdy + ty); - } - k_vec.cast_store(k_ptr + tx * vec_size); + if constexpr (interleave) { + k_vec = vec_apply_llama_rope_cos_sin_interleave(k_ptr, cos, sin); + } else { + k_vec = vec_apply_llama_rope_cos_sin(k_ptr, cos, sin); } + k_vec.cast_store(k_rope_ptr + tx * vec_size); } } } template -__global__ void BatchQKApplyRotaryKernel(DType* __restrict__ q, DType* __restrict__ k, - DType* __restrict__ q_rope, DType* __restrict__ k_rope, - IdType* __restrict__ indptr, IdType* __restrict__ offsets, - uint32_t batch_size, uint32_t num_qo_heads, - uint32_t num_kv_heads, size_t q_stride_n, - size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, - float smooth_a, float smooth_b, float rope_rcp_scale, - float rope_rcp_theta) { +__global__ void BatchQKApplyRotaryKernel( + DType* q, DType* k, DType* q_rope, DType* k_rope, IdType* __restrict__ indptr, + IdType* __restrict__ offsets, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, + size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, + size_t q_rope_stride_n, size_t q_rope_stride_h, size_t k_rope_stride_n, size_t k_rope_stride_h, + float smooth_a, float smooth_b, float rope_rcp_scale, float rope_rcp_theta) { uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y; const uint32_t bdy = blockDim.y; vec_t freq; @@ -232,8 +262,7 @@ __global__ void BatchQKApplyRotaryKernel(DType* __restrict__ q, DType* __restric q_stride_n, q_stride_h); DType* q_rope_ptr = q_rope + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, qo_head_idx, 0, - /*q_stride_n=*/num_qo_heads * head_dim, - /*q_stride_h=*/head_dim); + q_rope_stride_n, q_rope_stride_h); if constexpr (interleave) { q_vec = vec_apply_llama_rope_interleave(q_ptr, freq, offset + i * bdy + ty); @@ -257,8 +286,7 @@ __global__ void BatchQKApplyRotaryKernel(DType* __restrict__ q, DType* __restric k_stride_n, k_stride_h); DType* k_rope_ptr = k_rope + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, kv_head_idx, 0, - /*kv_stride_n=*/num_kv_heads * head_dim, - /*kv_stride_h=*/head_dim); + k_rope_stride_n, k_rope_stride_h); if constexpr (interleave) { k_vec = vec_apply_llama_rope_interleave(k_ptr, freq, offset + i * bdy + ty); @@ -281,13 +309,14 @@ __global__ void BatchQKApplyRotaryKernel(DType* __restrict__ q, DType* __restric } template -cudaError_t BatchQKApplyRotaryInPlace(DType* __restrict__ q, DType* __restrict__ k, - IdType* __restrict__ indptr, IdType* __restrict__ offsets, - uint32_t batch_size, uint32_t num_qo_heads, - uint32_t num_kv_heads, uint32_t head_dim, size_t q_stride_n, - size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, - bool interleave, float rope_scale, float rope_theta, - cudaStream_t stream = nullptr) { +cudaError_t BatchQKApplyRotaryPosIds(DType* q, DType* k, DType* q_rope, DType* k_rope, + IdType* __restrict__ pos_ids, uint32_t nnz, + uint32_t num_qo_heads, uint32_t num_kv_heads, + uint32_t head_dim, size_t q_stride_n, size_t q_stride_h, + size_t k_stride_n, size_t k_stride_h, size_t q_rope_stride_n, + size_t q_rope_stride_h, size_t k_rope_stride_n, + size_t k_rope_stride_h, bool interleave, float rope_scale, + float rope_theta, cudaStream_t stream = nullptr) { float rope_rcp_scale = 1.0f / rope_scale; float rope_rcp_theta = 1.0f / rope_theta; float smooth_a = 0.f; @@ -299,21 +328,26 @@ cudaError_t BatchQKApplyRotaryInPlace(DType* __restrict__ q, DType* __restrict__ constexpr uint32_t bdx = HEAD_DIM / vec_size; uint32_t num_threads = std::max(128U, bdx); uint32_t bdy = num_threads / bdx; - dim3 nblks(batch_size * (num_qo_heads + num_kv_heads)); + dim3 nblks((nnz + bdy - 1) / bdy); dim3 nthrs(bdx, bdy); auto kernel = - BatchQKApplyRotaryInPlaceKernel; + BatchQKApplyRotaryPosIdsKernel; void* args[] = {(void*)&q, (void*)&k, - (void*)&indptr, - (void*)&offsets, - (void*)&batch_size, + (void*)&q_rope, + (void*)&k_rope, + (void*)&pos_ids, + (void*)&nnz, (void*)&num_qo_heads, (void*)&num_kv_heads, (void*)&q_stride_n, (void*)&q_stride_h, (void*)&k_stride_n, (void*)&k_stride_h, + (void*)&q_rope_stride_n, + (void*)&q_rope_stride_h, + (void*)&k_rope_stride_n, + (void*)&k_rope_stride_h, (void*)&smooth_a, (void*)&smooth_b, (void*)&rope_rcp_scale, @@ -326,16 +360,18 @@ cudaError_t BatchQKApplyRotaryInPlace(DType* __restrict__ q, DType* __restrict__ } template -cudaError_t BatchQKApplyLlama31RotaryInPlace( - DType* __restrict__ q, DType* __restrict__ k, IdType* __restrict__ indptr, - IdType* __restrict__ offsets, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, - uint32_t head_dim, size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, - bool interleave, float rope_scale, float rope_theta, float low_freq_factor, - float high_freq_factor, float old_context_length, cudaStream_t stream = nullptr) { +cudaError_t BatchQKApplyRotary(DType* q, DType* k, DType* q_rope, DType* k_rope, + IdType* __restrict__ indptr, IdType* __restrict__ offsets, + uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, + uint32_t head_dim, size_t q_stride_n, size_t q_stride_h, + size_t k_stride_n, size_t k_stride_h, size_t q_rope_stride_n, + size_t q_rope_stride_h, size_t k_rope_stride_n, + size_t k_rope_stride_h, bool interleave, float rope_scale, + float rope_theta, cudaStream_t stream = nullptr) { float rope_rcp_scale = 1.0f / rope_scale; float rope_rcp_theta = 1.0f / rope_theta; - float smooth_a = old_context_length / (2 * M_PI * high_freq_factor - 2 * M_PI * low_freq_factor); - float smooth_b = -1.0f / (high_freq_factor / low_freq_factor - 1.0f); + float smooth_a = 0.f; + float smooth_b = 0.f; DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { @@ -345,10 +381,11 @@ cudaError_t BatchQKApplyLlama31RotaryInPlace( uint32_t bdy = num_threads / bdx; dim3 nblks(batch_size * (num_qo_heads + num_kv_heads)); dim3 nthrs(bdx, bdy); - auto kernel = - BatchQKApplyRotaryInPlaceKernel; + auto kernel = BatchQKApplyRotaryKernel; void* args[] = {(void*)&q, (void*)&k, + (void*)&q_rope, + (void*)&k_rope, (void*)&indptr, (void*)&offsets, (void*)&batch_size, @@ -358,6 +395,10 @@ cudaError_t BatchQKApplyLlama31RotaryInPlace( (void*)&q_stride_h, (void*)&k_stride_n, (void*)&k_stride_h, + (void*)&q_rope_stride_n, + (void*)&q_rope_stride_h, + (void*)&k_rope_stride_n, + (void*)&k_rope_stride_h, (void*)&smooth_a, (void*)&smooth_b, (void*)&rope_rcp_scale, @@ -370,17 +411,17 @@ cudaError_t BatchQKApplyLlama31RotaryInPlace( } template -cudaError_t BatchQKApplyRotary(DType* __restrict__ q, DType* __restrict__ k, - DType* __restrict__ q_rope, DType* __restrict__ k_rope, - IdType* __restrict__ indptr, IdType* __restrict__ offsets, - uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, - uint32_t head_dim, size_t q_stride_n, size_t q_stride_h, - size_t k_stride_n, size_t k_stride_h, bool interleave, - float rope_scale, float rope_theta, cudaStream_t stream = nullptr) { +cudaError_t BatchQKApplyLlama31Rotary( + DType* q, DType* k, DType* q_rope, DType* k_rope, IdType* __restrict__ indptr, + IdType* __restrict__ offsets, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, + uint32_t head_dim, size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, + size_t q_rope_stride_n, size_t q_rope_stride_h, size_t k_rope_stride_n, size_t k_rope_stride_h, + bool interleave, float rope_scale, float rope_theta, float low_freq_factor, + float high_freq_factor, float old_context_length, cudaStream_t stream = nullptr) { float rope_rcp_scale = 1.0f / rope_scale; float rope_rcp_theta = 1.0f / rope_theta; - float smooth_a = 0.f; - float smooth_b = 0.f; + float smooth_a = old_context_length / (2 * M_PI * high_freq_factor - 2 * M_PI * low_freq_factor); + float smooth_b = -1.0f / (high_freq_factor / low_freq_factor - 1.0f); DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { @@ -404,6 +445,10 @@ cudaError_t BatchQKApplyRotary(DType* __restrict__ q, DType* __restrict__ k, (void*)&q_stride_h, (void*)&k_stride_n, (void*)&k_stride_h, + (void*)&q_rope_stride_n, + (void*)&q_rope_stride_h, + (void*)&k_rope_stride_n, + (void*)&k_rope_stride_h, (void*)&smooth_a, (void*)&smooth_b, (void*)&rope_rcp_scale, @@ -416,15 +461,13 @@ cudaError_t BatchQKApplyRotary(DType* __restrict__ q, DType* __restrict__ k, } template -cudaError_t BatchQKApplyLlama31Rotary(DType* __restrict__ q, DType* __restrict__ k, - DType* __restrict__ q_rope, DType* __restrict__ k_rope, - IdType* __restrict__ indptr, IdType* __restrict__ offsets, - uint32_t batch_size, uint32_t num_qo_heads, - uint32_t num_kv_heads, uint32_t head_dim, size_t q_stride_n, - size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, - bool interleave, float rope_scale, float rope_theta, - float low_freq_factor, float high_freq_factor, - float old_context_length, cudaStream_t stream = nullptr) { +cudaError_t BatchQKApplyLlama31RotaryPosIds( + DType* q, DType* k, DType* q_rope, DType* k_rope, IdType* pos_ids, uint32_t nnz, + uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim, size_t q_stride_n, + size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, size_t q_rope_stride_n, + size_t q_rope_stride_h, size_t k_rope_stride_n, size_t k_rope_stride_h, bool interleave, + float rope_scale, float rope_theta, float low_freq_factor, float high_freq_factor, + float old_context_length, cudaStream_t stream = nullptr) { float rope_rcp_scale = 1.0f / rope_scale; float rope_rcp_theta = 1.0f / rope_theta; float smooth_a = old_context_length / (2 * M_PI * high_freq_factor - 2 * M_PI * low_freq_factor); @@ -436,22 +479,26 @@ cudaError_t BatchQKApplyLlama31Rotary(DType* __restrict__ q, DType* __restrict__ constexpr uint32_t bdx = HEAD_DIM / vec_size; uint32_t num_threads = std::max(128U, bdx); uint32_t bdy = num_threads / bdx; - dim3 nblks(batch_size * (num_qo_heads + num_kv_heads)); + dim3 nblks((nnz + bdy - 1) / bdy); dim3 nthrs(bdx, bdy); - auto kernel = BatchQKApplyRotaryKernel; + auto kernel = + BatchQKApplyRotaryPosIdsKernel; void* args[] = {(void*)&q, (void*)&k, (void*)&q_rope, (void*)&k_rope, - (void*)&indptr, - (void*)&offsets, - (void*)&batch_size, + (void*)&pos_ids, + (void*)&nnz, (void*)&num_qo_heads, (void*)&num_kv_heads, (void*)&q_stride_n, (void*)&q_stride_h, (void*)&k_stride_n, (void*)&k_stride_h, + (void*)&q_rope_stride_n, + (void*)&q_rope_stride_h, + (void*)&k_rope_stride_n, + (void*)&k_rope_stride_h, (void*)&smooth_a, (void*)&smooth_b, (void*)&rope_rcp_scale, diff --git a/python/csrc/flashinfer_rope_ops.cu b/python/csrc/flashinfer_rope_ops.cu index 4075930b..ef046ead 100644 --- a/python/csrc/flashinfer_rope_ops.cu +++ b/python/csrc/flashinfer_rope_ops.cu @@ -15,28 +15,36 @@ */ #include -void apply_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, - torch::Tensor offsets, bool interleave, float rope_scale, float rope_theta); +#include -void apply_llama31_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, - torch::Tensor offsets, bool interleave, float rope_scale, - float rope_theta, float low_freq_factor, float high_freq_factor, - float old_context_length); - -std::vector apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, +std::vector apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, + torch::Tensor k_rope, torch::Tensor indptr, torch::Tensor offsets, bool interleave, float rope_scale, float rope_theta); std::vector apply_llama31_rope(torch::Tensor q, torch::Tensor k, + torch::Tensor q_rope, torch::Tensor k_rope, torch::Tensor indptr, torch::Tensor offsets, bool interleave, float rope_scale, float rope_theta, float low_freq_factor, float high_freq_factor, float old_context_length); +std::vector apply_rope_pos_ids(torch::Tensor q, torch::Tensor k, + torch::Tensor q_rope, torch::Tensor k_rope, + torch::Tensor pos_ids, bool interleave, + float rope_scale, float rope_theta); + +std::vector apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k, + torch::Tensor q_rope, torch::Tensor k_rope, + torch::Tensor pos_ids, bool interleave, + float rope_scale, float rope_theta, + float low_freq_factor, float high_freq_factor, + float old_context_length); + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("apply_rope_inplace", &apply_rope_inplace, "Apply RoPE in-place"); - m.def("apply_llama31_rope_inplace", &apply_llama31_rope_inplace, - "Apply Llama 3.1 style RoPE in-place"); m.def("apply_rope", &apply_rope, "Apply RoPE"); m.def("apply_llama31_rope", &apply_llama31_rope, "Apply Llama 3.1 style RoPE"); + m.def("apply_rope_pos_ids", &apply_rope_pos_ids, "Apply RoPE with positional ids"); + m.def("apply_llama31_rope_pos_ids", &apply_llama31_rope_pos_ids, + "Apply Llama 3.1 style RoPE with positional ids"); } diff --git a/python/csrc/rope.cu b/python/csrc/rope.cu index bb8d5a19..d2ca9155 100644 --- a/python/csrc/rope.cu +++ b/python/csrc/rope.cu @@ -19,9 +19,10 @@ using namespace flashinfer; -void apply_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, - torch::Tensor offsets, bool interleave, float rope_scale, - float rope_theta) { +std::vector apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, + torch::Tensor k_rope, torch::Tensor indptr, + torch::Tensor offsets, bool interleave, float rope_scale, + float rope_theta) { CHECK_CUDA(q); // not necessarily contiguous CHECK_CUDA(k); // not necessarily contiguous CHECK_INPUT(indptr); @@ -44,68 +45,80 @@ void apply_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, size_t q_stride_h = q.stride(1); size_t k_stride_n = k.stride(0); size_t k_stride_h = k.stride(1); + size_t q_rope_stride_n = q_rope.stride(0); + size_t q_rope_stride_h = q_rope.stride(1); + size_t k_rope_stride_n = k_rope.stride(0); + size_t k_rope_stride_h = k_rope.stride(1); indptr = indptr.to(torch::kInt32); offsets = offsets.to(torch::kInt32); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] { - cudaError_t status = BatchQKApplyRotaryInPlace( + cudaError_t status = BatchQKApplyRotary( static_cast(q.data_ptr()), static_cast(k.data_ptr()), + static_cast(q_rope.data_ptr()), static_cast(k_rope.data_ptr()), static_cast(indptr.data_ptr()), static_cast(offsets.data_ptr()), batch_size, num_qo_heads, num_kv_heads, head_dim, q_stride_n, q_stride_h, k_stride_n, - k_stride_h, interleave, rope_scale, rope_theta, torch_current_stream); - TORCH_CHECK(status == cudaSuccess, "BatchQKApplyRotaryInPlace failed with error code " + + k_stride_h, q_rope_stride_n, q_rope_stride_h, k_rope_stride_n, k_rope_stride_h, interleave, + rope_scale, rope_theta, torch_current_stream); + TORCH_CHECK(status == cudaSuccess, "BatchQKApplyRotary failed with error code " + std::string(cudaGetErrorString(status))); return true; }); + + return {q_rope, k_rope}; } -void apply_llama31_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, - torch::Tensor offsets, bool interleave, float rope_scale, - float rope_theta, float low_freq_factor, float high_freq_factor, - float old_context_length) { +std::vector apply_rope_pos_ids(torch::Tensor q, torch::Tensor k, + torch::Tensor q_rope, torch::Tensor k_rope, + torch::Tensor pos_ids, bool interleave, + float rope_scale, float rope_theta) { CHECK_CUDA(q); // not necessarily contiguous CHECK_CUDA(k); // not necessarily contiguous - CHECK_INPUT(indptr); - CHECK_INPUT(offsets); + CHECK_INPUT(pos_ids); auto device = q.device(); CHECK_EQ(k.device(), device); - CHECK_DIM(3, q); // q: (nnz, H_Q, D) - CHECK_DIM(3, k); // k: (nnz, H_K, D) - CHECK_DIM(1, indptr); // indptr: (B + 1) - CHECK_DIM(1, offsets); // offsets: (B) + CHECK_DIM(3, q); // q: (nnz, H_Q, D) + CHECK_DIM(3, k); // k: (nnz, H_K, D) CHECK_EQ(q.size(0), k.size(0)); CHECK_EQ(q.size(2), k.size(2)); unsigned int num_qo_heads = q.size(1); unsigned int num_kv_heads = k.size(1); unsigned int head_dim = q.size(2); - unsigned int batch_size = offsets.size(0); - CHECK_EQ(indptr.size(0), batch_size + 1); + unsigned int nnz = q.size(0); size_t q_stride_n = q.stride(0); size_t q_stride_h = q.stride(1); size_t k_stride_n = k.stride(0); size_t k_stride_h = k.stride(1); - indptr = indptr.to(torch::kInt32); - offsets = offsets.to(torch::kInt32); + size_t q_rope_stride_n = q_rope.stride(0); + size_t q_rope_stride_h = q_rope.stride(1); + size_t k_rope_stride_n = k_rope.stride(0); + size_t k_rope_stride_h = k_rope.stride(1); + pos_ids = pos_ids.to(torch::kInt32); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] { - cudaError_t status = BatchQKApplyLlama31RotaryInPlace( + cudaError_t status = BatchQKApplyRotaryPosIds( static_cast(q.data_ptr()), static_cast(k.data_ptr()), - static_cast(indptr.data_ptr()), static_cast(offsets.data_ptr()), - batch_size, num_qo_heads, num_kv_heads, head_dim, q_stride_n, q_stride_h, k_stride_n, - k_stride_h, interleave, rope_scale, rope_theta, low_freq_factor, high_freq_factor, - old_context_length, torch_current_stream); - TORCH_CHECK(status == cudaSuccess, "BatchQKApplyLlama31RotaryInPlace failed with error code " + + static_cast(q_rope.data_ptr()), static_cast(k_rope.data_ptr()), + static_cast(pos_ids.data_ptr()), nnz, num_qo_heads, num_kv_heads, head_dim, + q_stride_n, q_stride_h, k_stride_n, k_stride_h, q_rope_stride_n, q_rope_stride_h, + k_rope_stride_n, k_rope_stride_h, interleave, rope_scale, rope_theta, torch_current_stream); + TORCH_CHECK(status == cudaSuccess, "BatchQKApplyRotaryPosIds failed with error code " + std::string(cudaGetErrorString(status))); return true; }); + + return {q_rope, k_rope}; } -std::vector apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, - torch::Tensor offsets, bool interleave, float rope_scale, - float rope_theta) { +std::vector apply_llama31_rope(torch::Tensor q, torch::Tensor k, + torch::Tensor q_rope, torch::Tensor k_rope, + torch::Tensor indptr, torch::Tensor offsets, + bool interleave, float rope_scale, float rope_theta, + float low_freq_factor, float high_freq_factor, + float old_context_length) { CHECK_CUDA(q); // not necessarily contiguous CHECK_CUDA(k); // not necessarily contiguous CHECK_INPUT(indptr); @@ -128,21 +141,24 @@ std::vector apply_rope(torch::Tensor q, torch::Tensor k, torch::T size_t q_stride_h = q.stride(1); size_t k_stride_n = k.stride(0); size_t k_stride_h = k.stride(1); + size_t q_rope_stride_n = q_rope.stride(0); + size_t q_rope_stride_h = q_rope.stride(1); + size_t k_rope_stride_n = k_rope.stride(0); + size_t k_rope_stride_h = k_rope.stride(1); indptr = indptr.to(torch::kInt32); offsets = offsets.to(torch::kInt32); - // NOTE(Zihao): empty_like do not copy strides so it's okay to use it here. - auto q_rope = torch::empty_like(q); - auto k_rope = torch::empty_like(k); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] { - cudaError_t status = BatchQKApplyRotary( + cudaError_t status = BatchQKApplyLlama31Rotary( static_cast(q.data_ptr()), static_cast(k.data_ptr()), static_cast(q_rope.data_ptr()), static_cast(k_rope.data_ptr()), static_cast(indptr.data_ptr()), static_cast(offsets.data_ptr()), batch_size, num_qo_heads, num_kv_heads, head_dim, q_stride_n, q_stride_h, k_stride_n, - k_stride_h, interleave, rope_scale, rope_theta, torch_current_stream); - TORCH_CHECK(status == cudaSuccess, "BatchQKApplyRotary failed with error code " + + k_stride_h, q_rope_stride_n, q_rope_stride_h, k_rope_stride_n, k_rope_stride_h, interleave, + rope_scale, rope_theta, low_freq_factor, high_freq_factor, old_context_length, + torch_current_stream); + TORCH_CHECK(status == cudaSuccess, "BatchQKApplyLlama31Rotary failed with error code " + std::string(cudaGetErrorString(status))); return true; }); @@ -150,50 +166,46 @@ std::vector apply_rope(torch::Tensor q, torch::Tensor k, torch::T return {q_rope, k_rope}; } -std::vector apply_llama31_rope(torch::Tensor q, torch::Tensor k, - torch::Tensor indptr, torch::Tensor offsets, - bool interleave, float rope_scale, float rope_theta, - float low_freq_factor, float high_freq_factor, - float old_context_length) { +std::vector apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k, + torch::Tensor q_rope, torch::Tensor k_rope, + torch::Tensor pos_ids, bool interleave, + float rope_scale, float rope_theta, + float low_freq_factor, float high_freq_factor, + float old_context_length) { CHECK_CUDA(q); // not necessarily contiguous CHECK_CUDA(k); // not necessarily contiguous - CHECK_INPUT(indptr); - CHECK_INPUT(offsets); + CHECK_INPUT(pos_ids); auto device = q.device(); CHECK_EQ(k.device(), device); - CHECK_DIM(3, q); // q: (nnz, H_Q, D) - CHECK_DIM(3, k); // k: (nnz, H_K, D) - CHECK_DIM(1, indptr); // indptr: (B + 1) - CHECK_DIM(1, offsets); // offsets: (B) + CHECK_DIM(3, q); // q: (nnz, H_Q, D) + CHECK_DIM(3, k); // k: (nnz, H_K, D) CHECK_EQ(q.size(0), k.size(0)); CHECK_EQ(q.size(2), k.size(2)); unsigned int num_qo_heads = q.size(1); unsigned int num_kv_heads = k.size(1); unsigned int head_dim = q.size(2); - unsigned int batch_size = offsets.size(0); - CHECK_EQ(indptr.size(0), batch_size + 1); + unsigned int nnz = q.size(0); size_t q_stride_n = q.stride(0); size_t q_stride_h = q.stride(1); size_t k_stride_n = k.stride(0); size_t k_stride_h = k.stride(1); - indptr = indptr.to(torch::kInt32); - offsets = offsets.to(torch::kInt32); - - // NOTE(Zihao): empty_like do not copy strides so it's okay to use it here. - auto q_rope = torch::empty_like(q); - auto k_rope = torch::empty_like(k); + size_t q_rope_stride_n = q_rope.stride(0); + size_t q_rope_stride_h = q_rope.stride(1); + size_t k_rope_stride_n = k_rope.stride(0); + size_t k_rope_stride_h = k_rope.stride(1); + pos_ids = pos_ids.to(torch::kInt32); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] { - cudaError_t status = BatchQKApplyLlama31Rotary( + cudaError_t status = BatchQKApplyLlama31RotaryPosIds( static_cast(q.data_ptr()), static_cast(k.data_ptr()), static_cast(q_rope.data_ptr()), static_cast(k_rope.data_ptr()), - static_cast(indptr.data_ptr()), static_cast(offsets.data_ptr()), - batch_size, num_qo_heads, num_kv_heads, head_dim, q_stride_n, q_stride_h, k_stride_n, - k_stride_h, interleave, rope_scale, rope_theta, low_freq_factor, high_freq_factor, - old_context_length, torch_current_stream); - TORCH_CHECK(status == cudaSuccess, "BatchQKApplyLlama31Rotary failed with error code " + + static_cast(pos_ids.data_ptr()), nnz, num_qo_heads, num_kv_heads, head_dim, + q_stride_n, q_stride_h, k_stride_n, k_stride_h, q_rope_stride_n, q_rope_stride_h, + k_rope_stride_n, k_rope_stride_h, interleave, rope_scale, rope_theta, low_freq_factor, + high_freq_factor, old_context_length, torch_current_stream); + TORCH_CHECK(status == cudaSuccess, "BatchQKApplyLlama31RotaryPosIds failed with error code " + std::string(cudaGetErrorString(status))); return true; }); diff --git a/python/flashinfer/__init__.py b/python/flashinfer/__init__.py index cb023a67..724fc3f0 100644 --- a/python/flashinfer/__init__.py +++ b/python/flashinfer/__init__.py @@ -60,6 +60,8 @@ apply_llama31_rope_inplace as apply_llama31_rope_inplace, apply_rope as apply_rope, apply_rope_inplace as apply_rope_inplace, + apply_rope_pos_ids as apply_rope_pos_ids, + apply_rope_pos_ids_inplace as apply_rope_pos_ids_inplace, ) from .sampling import ( chain_speculative_sampling as chain_speculative_sampling, diff --git a/python/flashinfer/rope.py b/python/flashinfer/rope.py index 408c1f4a..29c2fcb7 100644 --- a/python/flashinfer/rope.py +++ b/python/flashinfer/rope.py @@ -118,8 +118,8 @@ def apply_rope_inplace( -------- apply_rope """ - return get_rope_module().apply_rope_inplace( - q, k, indptr, offsets, interleave, rope_scale, rope_theta + get_rope_module().apply_rope( + q, k, q, k, indptr, offsets, interleave, rope_scale, rope_theta ) @@ -136,6 +136,70 @@ def _fake_apply_rope_inplace( pass +@register_custom_op("flashinfer::apply_rope_pos_ids_inplace", mutates_args=("q", "k")) +def apply_rope_pos_ids_inplace( + q: torch.Tensor, + k: torch.Tensor, + pos_ids: torch.Tensor, + interleave: bool = False, + rope_scale: float = 1, + rope_theta: float = 1e4, +) -> None: + r"""Apply rotary embedding to a batch of queries/keys (stored as RaggedTensor) inplace. + + We use :attr:`indptr` to denote the start pointer of each segment in the batch, the i-th + segment the query of the i-th segment is ``q[indptr[i]:indptr[i+1]]`` and the key of the + i-th segment is ``k[indptr[i]:indptr[i+1]]``, the first element of :attr:`indptr` is always + 0 and the last element of :attr:`indptr` is the total number of queries/keys in the batch. + Please see :ref:`Ragged Tensor tutorial ` for more details about the + ragged tensor. + + Parameters + ---------- + q : torch.Tensor + Query ragged tensor, shape: ``(nnz, num_q_heads, head_dim)`, where ``nnz`` is the last + element of ``indptr``. + k : torch.Tensor + Key ragged tensor, shape: ``(nnz, num_k_heads, head_dim)``, where ``nnz`` is the last + element of ``indptr``. + pos_ids : torch.Tensor + Position indices, shape: ``(nnz)``. + interleave : bool + Whether to use interleaved layout in the last dimension, default: ``False``. + + * If ``True``, the last dimension of the query/key tensor is interleaved, i.e., + we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``. + + * If ``False``, the last dimension of the query/key tensor is not interleaved, i.e., + we rorate the first half dimensions ``([..., :head_dim//2])`` and the second half + dimensions ``([..., head_dim//2:])``. + + rope_scale : float + The scaling factor used in the rope embedding, default: ``1``. + rope_theta : float + The theta value used in the rope embedding, default: ``1e4``. + + See Also + -------- + apply_rope_pos_ids + """ + get_rope_module().apply_rope_pos_ids( + q, k, q, k, pos_ids, interleave, rope_scale, rope_theta + ) + + +@register_fake_op("flashinfer::apply_rope_pos_ids_inplace") +def _fake_apply_rope_pos_ids_inplace( + q: torch.Tensor, + k: torch.Tensor, + pos_ids: torch.Tensor, + interleave: bool = False, + rope_scale: float = 1, + rope_theta: float = 1e4, +) -> None: + pass + + @register_custom_op("flashinfer::apply_llama31_rope_inplace", mutates_args=("q", "k")) def apply_llama31_rope_inplace( q: torch.Tensor, @@ -222,7 +286,9 @@ def apply_llama31_rope_inplace( -------- apply_llama31_rope """ - return get_rope_module().apply_llama31_rope_inplace( + get_rope_module().apply_llama31_rope( + q, + k, q, k, indptr, @@ -339,8 +405,10 @@ def apply_rope( -------- apply_rope_inplace """ + q_rope = torch.empty_like(q) + k_rope = torch.empty_like(k) return get_rope_module().apply_rope( - q, k, indptr, offsets, interleave, rope_scale, rope_theta + q, k, q_rope, k_rope, indptr, offsets, interleave, rope_scale, rope_theta ) @@ -357,6 +425,79 @@ def _fake_apply_rope( return torch.empty_like(q), torch.empty_like(k) +@register_custom_op("flashinfer::apply_rope_pos_ids", mutates_args=()) +def apply_rope_pos_ids( + q: torch.Tensor, + k: torch.Tensor, + pos_ids: torch.Tensor, + interleave: bool = False, + rope_scale: float = 1, + rope_theta: float = 1e4, +) -> Tuple[torch.Tensor, torch.Tensor]: + r"""Apply rotary embedding to a batch of queries/keys (stored as RaggedTensor). + + We use :attr:`indptr` to denote the start pointer of each segment in the batch, the i-th + segment the query of the i-th segment is ``q[indptr[i]:indptr[i+1]]`` and the key of the + i-th segment is ``k[indptr[i]:indptr[i+1]]``, the first element of :attr:`indptr` is always + 0 and the last element of :attr:`indptr` is the total number of queries/keys in the batch. + Please see :ref:`Ragged Tensor tutorial ` for more details about the + ragged tensor. + + Parameters + ---------- + q : torch.Tensor + Query ragged tensor, shape: ``(nnz, num_q_heads, head_dim)`, where ``nnz`` is the last + element of ``indptr``. + k : torch.Tensor + Key ragged tensor, shape: ``(nnz, num_k_heads, head_dim)``, where ``nnz`` is the last + element of ``indptr``. + pos_ids : torch.Tensor + Position indices, shape: ``(batch_size + 1)``. + interleave : bool + Whether to use interleaved layout in the last dimension, default: ``False``. + + * If ``True``, the last dimension of the query/key tensor is interleaved, i.e., + we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``. + + * If ``False``, the last dimension of the query/key tensor is not interleaved, i.e., + we rorate the first half dimensions ``([..., :head_dim//2])`` and the second half + dimensions ``([..., head_dim//2:])``. + + rope_scale : float + The scaling factor used in the rope embedding, default: ``1``. + rope_theta : float + The theta value used in the rope embedding, default: ``1e4``. + + Returns + ------- + q_rope : torch.Tensor + The rotated query tensor, shape: ``(nnz, num_q_heads, head_dim)``. + k_rope : torch.Tensor + The rotated key tensor, shape: ``(nnz, num_k_heads, head_dim)``. + + See Also + -------- + apply_rope_inplace + """ + q_rope = torch.empty_like(q) + k_rope = torch.empty_like(k) + return get_rope_module().apply_rope_pos_ids( + q, k, q_rope, k_rope, pos_ids, interleave, rope_scale, rope_theta + ) + + +@register_fake_op("flashinfer::apply_rope_pos_ids") +def _fake_apply_rope_pos_ids( + q: torch.Tensor, + k: torch.Tensor, + pos_ids: torch.Tensor, + interleave: bool = False, + rope_scale: float = 1, + rope_theta: float = 1e4, +) -> Tuple[torch.Tensor, torch.Tensor]: + return torch.empty_like(q), torch.empty_like(k) + + @register_custom_op("flashinfer::apply_llama31_rope", mutates_args=()) def apply_llama31_rope( q: torch.Tensor, @@ -454,9 +595,13 @@ def apply_llama31_rope( -------- apply_llama31_rope_inplace """ + q_rope = torch.empty_like(q) + k_rope = torch.empty_like(k) return get_rope_module().apply_llama31_rope( q, k, + q_rope, + k_rope, indptr, offsets, interleave, diff --git a/tests/test_rope.py b/tests/test_rope.py index 1750ed34..f7ee84c9 100644 --- a/tests/test_rope.py +++ b/tests/test_rope.py @@ -69,12 +69,8 @@ def test_llama_rope_inplace( ) # compare - torch.testing.assert_close( - q_rope_ref, q, rtol=1e-3, atol=1e-3 - ) - torch.testing.assert_close( - k_rope_ref, k, rtol=1e-3, atol=1e-3 - ) + torch.testing.assert_close(q_rope_ref, q, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(k_rope_ref, k, rtol=1e-3, atol=1e-3) @pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) @@ -125,12 +121,111 @@ def test_llama_rope( ) # compare - torch.testing.assert_close( - q_rope_ref, q_rope, rtol=1e-3, atol=1e-3 + torch.testing.assert_close(q_rope_ref, q_rope, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(k_rope_ref, k_rope, rtol=1e-3, atol=1e-3) + + +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("qkv_len", [1, 4, 19, 204]) +@pytest.mark.parametrize("num_qo_heads", [8, 16]) +@pytest.mark.parametrize("num_kv_heads", [8]) +@pytest.mark.parametrize("offset", [0, 15, 99]) +@pytest.mark.parametrize("head_dim", [64, 128, 256]) +def test_llama_rope_pos_ids( + batch_size, + qkv_len, + num_qo_heads, + num_kv_heads, + offset, + head_dim, +): + nnz = batch_size * qkv_len + qkv_packed = torch.randn( + nnz, + (num_qo_heads + 2 * num_kv_heads) * head_dim, + dtype=torch.float16, + device="cuda:0", ) - torch.testing.assert_close( - k_rope_ref, k_rope, rtol=1e-3, atol=1e-3 + q = qkv_packed[:, : num_qo_heads * head_dim].reshape(nnz, num_qo_heads, head_dim) + k = qkv_packed[ + :, num_qo_heads * head_dim : (num_qo_heads + num_kv_heads) * head_dim + ].reshape(nnz, num_kv_heads, head_dim) + indptr = torch.tensor( + [i * qkv_len for i in range(batch_size + 1)], dtype=torch.int32, device="cuda:0" ) + offsets = torch.full((batch_size,), offset, dtype=torch.int32, device="cuda:0") + + pos_ids = torch.cat( + [ + torch.arange(offset, qkv_len + offset, dtype=torch.int32) + for _ in range(batch_size) + ] + ).to("cuda:0") + + q_rope, k_rope = flashinfer.apply_rope( + q, k, indptr, offsets, interleave=True, rope_theta=1e4 + ) + + q_rope_pos_ids, k_rope_pos_ids = flashinfer.apply_rope_pos_ids( + q, k, pos_ids, interleave=True, rope_theta=1e4 + ) + + # compare + torch.testing.assert_close(q_rope_pos_ids, q_rope, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(k_rope_pos_ids, k_rope, rtol=1e-3, atol=1e-3) + + +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("qkv_len", [1, 4, 19, 204]) +@pytest.mark.parametrize("num_qo_heads", [8, 16]) +@pytest.mark.parametrize("num_kv_heads", [8]) +@pytest.mark.parametrize("offset", [0, 15, 99]) +@pytest.mark.parametrize("head_dim", [64, 128, 256]) +def test_llama_rope_pos_ids_inplace( + batch_size, + qkv_len, + num_qo_heads, + num_kv_heads, + offset, + head_dim, +): + nnz = batch_size * qkv_len + qkv_packed = torch.randn( + nnz, + (num_qo_heads + 2 * num_kv_heads) * head_dim, + dtype=torch.float16, + device="cuda:0", + ) + q = qkv_packed[:, : num_qo_heads * head_dim].reshape(nnz, num_qo_heads, head_dim) + k = qkv_packed[ + :, num_qo_heads * head_dim : (num_qo_heads + num_kv_heads) * head_dim + ].reshape(nnz, num_kv_heads, head_dim) + indptr = torch.tensor( + [i * qkv_len for i in range(batch_size + 1)], dtype=torch.int32, device="cuda:0" + ) + offsets = torch.full((batch_size,), offset, dtype=torch.int32, device="cuda:0") + + pos_ids = torch.cat( + [ + torch.arange(offset, qkv_len + offset, dtype=torch.int32) + for _ in range(batch_size) + ] + ).to("cuda:0") + + q_clone = q.clone() + k_clone = k.clone() + + flashinfer.apply_rope_inplace( + q, k, indptr, offsets, interleave=True, rope_theta=1e4 + ) + + flashinfer.apply_rope_pos_ids_inplace( + q_clone, k_clone, pos_ids, interleave=True, rope_theta=1e4 + ) + + # compare + torch.testing.assert_close(q_clone, q, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(k_clone, k, rtol=1e-3, atol=1e-3) @pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) @@ -181,12 +276,8 @@ def test_llama31_rope_inplace( ) # compare - torch.testing.assert_close( - q_rope_ref, q, rtol=1e-3, atol=1e-3 - ) - torch.testing.assert_close( - k_rope_ref, k, rtol=1e-3, atol=1e-3 - ) + torch.testing.assert_close(q_rope_ref, q, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(k_rope_ref, k, rtol=1e-3, atol=1e-3) @pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) @@ -237,12 +328,8 @@ def test_llama31_rope( ) # compare - torch.testing.assert_close( - q_rope_ref, q_rope, rtol=1e-3, atol=1e-3 - ) - torch.testing.assert_close( - k_rope_ref, k_rope, rtol=1e-3, atol=1e-3 - ) + torch.testing.assert_close(q_rope_ref, q_rope, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(k_rope_ref, k_rope, rtol=1e-3, atol=1e-3) if __name__ == "__main__": @@ -250,3 +337,5 @@ def test_llama31_rope( test_llama31_rope_inplace(1, 1, 8, 8, 0, 128) test_llama_rope(2, 1, 8, 8, 1, 128) test_llama31_rope(1, 1, 8, 8, 0, 128) + test_llama_rope_pos_ids(2, 1, 8, 8, 1, 128) + test_llama_rope_pos_ids_inplace(2, 1, 8, 8, 1, 128)