diff --git a/csrc/generation/fused_get_rope.cu b/csrc/generation/fused_get_rope.cu
index 011565262643..af34c0f075ba 100644
--- a/csrc/generation/fused_get_rope.cu
+++ b/csrc/generation/fused_get_rope.cu
@@ -44,6 +44,7 @@ __global__ __launch_bounds__(kBlockSize) void fused_get_rotary_embedding_neox(co
                                 const int32_t prompt_num,
                                 const float inv_head_dim,
                                 const int32_t elem_cnt,
+                                const float theta,
                                 float* rope_embedding) {
   /*
   In Naive implementation, it will stacks [freqs, freqs]
@@ -63,7 +64,7 @@ __global__ __launch_bounds__(kBlockSize) void fused_get_rotary_embedding_neox(co
   const int64_t position_offset = bsz_idx * max_position_seq_length + seq_idx + prompt_num;
   const int32_t half_head_idx = (idx % half_head_dim) * PackSize;
   const float exponent_factor = -static_cast<float>(half_head_idx) * inv_head_dim; // * inv_head_dim equals to / head_dim.
-  const float inv_freq_val = powf(10000.0f, exponent_factor);
+  const float inv_freq_val = powf(theta, exponent_factor);
   const float freqs_val = static_cast<float>(position_ids[position_offset]) * inv_freq_val;
   const float cos_embedding_val = cos(freqs_val);
   const float sin_embedding_val = sin(freqs_val);
@@ -100,6 +101,7 @@ __global__ __launch_bounds__(kBlockSize) void fused_get_rotary_embedding(const i
                                 const int32_t prompt_num,
                                 const float inv_head_dim,
                                 const int32_t elem_cnt,
+                                const float theta,
                                 float* rope_embedding) {
   /*
   In Naive implementation, it will stacks [freqs, freqs]
@@ -119,7 +121,7 @@ __global__ __launch_bounds__(kBlockSize) void fused_get_rotary_embedding(const i
   const int64_t position_offset = bsz_idx * max_position_seq_length + seq_idx + prompt_num;
   const int32_t half_head_idx = (idx % half_head_dim) * PackSize;
   const float exponent_factor = -static_cast<float>(half_head_idx) * inv_head_dim; // * inv_head_dim equals to / head_dim.
-  const float inv_freq_val = powf(10000.0f, exponent_factor);
+  const float inv_freq_val = powf(theta, exponent_factor);
   const float freqs_val = static_cast<float>(position_ids[position_offset]) * inv_freq_val;
   const float cos_embedding_val = cos(freqs_val);
   const float sin_embedding_val = sin(freqs_val);
@@ -145,6 +147,7 @@ std::vector<paddle::Tensor> GetRoPE(const paddle::Tensor& input_ids,
                                 const paddle::Tensor& position_ids,
                                 const paddle::Tensor& head_dim_shape_tensor,
                                 int prompt_num,
+                                float theta,
                                 bool use_neox) {
   const int64_t batch_size = input_ids.shape()[0];
   const int64_t max_seq_length = input_ids.shape()[1];
@@ -170,6 +173,7 @@ std::vector<paddle::Tensor> GetRoPE(const paddle::Tensor& input_ids,
         prompt_num,
         inv_head_dim,
         elem_cnt,
+        theta,
         reinterpret_cast<float*>(rotary_embedding.data<float>()));
   } else {
     fused_get_rotary_embedding<<>> (
@@ -181,6 +185,7 @@ std::vector<paddle::Tensor> GetRoPE(const paddle::Tensor& input_ids,
         prompt_num,
         inv_head_dim,
         elem_cnt,
+        theta,
         reinterpret_cast<float*>(rotary_embedding.data<float>()));
   }
   return {rotary_embedding};
@@ -209,7 +214,8 @@ PD_BUILD_OP(fused_get_rotary_embedding)
     .Inputs({"input_ids", "position_ids", "head_dim_shape_tensor"})
     .Outputs({"rotary_embedding"})
     .Attrs({"prompt_num: int",
+            "theta: float",
             "use_neox: bool"})
     .SetKernelFn(PD_KERNEL(GetRoPE))
     .SetInferShapeFn(PD_INFER_SHAPE(GetRoPEInferShape))
-    .SetInferDtypeFn(PD_INFER_DTYPE(GetRoPEInferDtype));
\ No newline at end of file
+    .SetInferDtypeFn(PD_INFER_DTYPE(GetRoPEInferDtype));
diff --git a/paddlenlp/experimental/transformers/llama/modeling.py b/paddlenlp/experimental/transformers/llama/modeling.py
index f22eecb15d19..63e49c1edb56 100644
--- a/paddlenlp/experimental/transformers/llama/modeling.py
+++ b/paddlenlp/experimental/transformers/llama/modeling.py
@@ -432,10 +432,11 @@ def forward(
         seq_lens = seq_len_decoder if is_decoder else seq_len_encoder
 
         position_offset = 0
+        theta = 10000.0
         if not is_decoder and pre_caches is not None:
             position_offset = 128
         new_rope = fused_get_rotary_embedding(
-            input_ids, position_ids, self.head_dim_shape_tensor, position_offset, True
+            input_ids, position_ids, self.head_dim_shape_tensor, position_offset, theta, True
         )
 
         with dy2st_nocheck_guard_context():
diff --git a/paddlenlp/experimental/transformers/qwen/modeling.py b/paddlenlp/experimental/transformers/qwen/modeling.py
index fc6bb92a627d..c032a85e7ce2 100644
--- a/paddlenlp/experimental/transformers/qwen/modeling.py
+++ b/paddlenlp/experimental/transformers/qwen/modeling.py
@@ -320,11 +320,12 @@ def forward(
         seq_lens = seq_len_decoder if is_decoder else seq_len_encoder
 
         position_offset = 0
+        theta = 10000.0
         if not is_decoder and pre_caches is not None:
             position_offset = 128
 
         new_rope = fused_get_rotary_embedding(
-            input_ids, position_ids, self.head_dim_shape_tensor, position_offset, True
+            input_ids, position_ids, self.head_dim_shape_tensor, position_offset, theta, True
         )
 
         with dy2st_nocheck_guard_context():
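For reference, the quantity this patch parameterises is the RoPE inverse-frequency term that the kernels evaluate as powf(theta, -half_head_idx / head_dim), i.e. the usual theta^(-2j/head_dim) schedule, previously hard-coded with theta = 10000. The sketch below is my own NumPy illustration of the cos/sin table a given theta produces, not the fused kernel itself; it ignores the PackSize vectorisation and the neox [freqs, freqs] layout that the CUDA code above handles.

import numpy as np

def rope_cos_sin(position_ids: np.ndarray, head_dim: int, theta: float = 10000.0):
    # inv_freq_j = theta^(-2j / head_dim): the term computed by powf(theta, exponent_factor).
    j = np.arange(head_dim // 2, dtype=np.float32)
    inv_freq = theta ** (-2.0 * j / head_dim)
    # Outer product of positions and inverse frequencies gives the rotation angles.
    freqs = position_ids.astype(np.float32)[:, None] * inv_freq[None, :]
    return np.cos(freqs), np.sin(freqs)

cos_ref, sin_ref = rope_cos_sin(np.arange(16), head_dim=128)                    # previous hard-coded base
cos_big, sin_big = rope_cos_sin(np.arange(16), head_dim=128, theta=1000000.0)   # larger base -> lower rotation frequencies

The modeling.py call sites keep theta = 10000.0, so existing Llama/Qwen behaviour is unchanged; the new positional argument, passed just before use_neox in the prompt_num/theta/use_neox attribute order registered in PD_BUILD_OP, is what lets a different rope base reach the kernel.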