some edit
Signed-off-by: liqunfu <liqun.fu@microsoft.com>
liqunfu committed Feb 15, 2025
1 parent 43bdb44 commit 46353d8
Showing 4 changed files with 14 additions and 23 deletions.
21 changes: 8 additions & 13 deletions onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
@@ -88,20 +88,16 @@ class GQAAttentionBase {
     bool past_present_share_buffer = past_key_data == present_key_data && past_value_data == present_value_data;

     const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K;
-    {
-      ComputeAttentionProbs<T>(static_cast<float*>(attention_probs), Q, k, seqlens_k->Data<int32_t>(), batch_size,
-                               sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data,
-                               present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator);
-    }
+    ComputeAttentionProbs<T>(static_cast<float*>(attention_probs), Q, k, seqlens_k->Data<int32_t>(), batch_size,
+                             sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data,
+                             present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator);
     // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v)
     const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V;
-    {
-      ComputeVxAttentionScore(output->MutableData<T>(), static_cast<float*>(attention_probs), v,
-                              seqlens_k->Data<int32_t>(),
-                              batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size,
-                              hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv,
-                              is_prompt, tp, allocator);
-    }
+    ComputeVxAttentionScore(output->MutableData<T>(), static_cast<float*>(attention_probs), v,
+                            seqlens_k->Data<int32_t>(),
+                            batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size,
+                            hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv,
+                            is_prompt, tp, allocator);
     return Status::OK();
   }

@@ -126,7 +122,6 @@ class GQAAttentionBase {
       const bool is_prompt,            // whether it is prompt
      ThreadPool* tp,                  // thread pool
       AllocatorPtr allocator) const {  // allocator for temporary buffer
-
     const ptrdiff_t packed_batch_stride =
         packed_qkv ? SafeInt<ptrdiff_t>(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size
                    : SafeInt<ptrdiff_t>(0);
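The comment kept in the hunk above spells out the shapes involved: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v). As a reading aid only, here is a minimal scalar sketch of that product; the function name and parameters are hypothetical, and the real ComputeVxAttentionScore distributes this work over MLAS GEMM calls and the thread pool rather than writing the loops out.

// Naive reference for out(B, N, S, H_v) = probs(B, N, S, T) x V(B, N, T, H_v).
// All names here are illustrative, not the ORT internals.
#include <cstddef>

void AttentionProbsTimesV(const float* probs,  // [B, N, S, T]
                          const float* v,      // [B, N, T, H_v]
                          float* out,          // [B, N, S, H_v]
                          std::size_t B, std::size_t N, std::size_t S,
                          std::size_t T, std::size_t Hv) {
  for (std::size_t b = 0; b < B; ++b) {
    for (std::size_t n = 0; n < N; ++n) {
      const float* p = probs + ((b * N + n) * S) * T;  // probabilities for this (batch, head)
      const float* vv = v + ((b * N + n) * T) * Hv;    // values for this (batch, head)
      float* o = out + ((b * N + n) * S) * Hv;
      for (std::size_t s = 0; s < S; ++s) {
        for (std::size_t h = 0; h < Hv; ++h) {
          float acc = 0.0f;
          for (std::size_t t = 0; t < T; ++t) {
            acc += p[s * T + t] * vv[t * Hv + h];  // weighted sum over the key/value time axis
          }
          o[s * Hv + h] = acc;
        }
      }
    }
  }
}

In the kernel itself each (batch, head) slice is a single GEMM, which is why the call site only needs the pointers, the sequence lengths, and the shared-buffer flags.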
7 changes: 3 additions & 4 deletions onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
@@ -193,10 +193,9 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {

   ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
   // Compute the attention score and apply the score to V
-  auto ret = ApplyAttention(q_rotary, packed_qkv ? nullptr : k_rotary, packed_qkv ? nullptr : V.Get<Tensor>().Data<T>(),
-                            past_key, past_value, output, present_k, present_v,
-                            seqlens_k, parameters, allocator, context);
-  return ret;
+  return ApplyAttention(q_rotary, packed_qkv ? nullptr : k_rotary, packed_qkv ? nullptr : V.Get<Tensor>().Data<T>(),
+                        past_key, past_value, output, present_k, present_v,
+                        seqlens_k, parameters, allocator, context);
 }
 }  // namespace contrib
 }  // namespace onnxruntime
5 changes: 1 addition & 4 deletions onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2.cpp
@@ -22,9 +22,6 @@ Module Name:
 //
 const MLAS_ROPE_DISPATCH MlasRopeDispatchAvx2 = []() {
     MLAS_ROPE_DISPATCH d;
-
-#if defined(MLAS_TARGET_AMD64)
-    d.SRope = rope_avx2::RopeKernel_Avx2;
-#endif
+    d.SRope = rope_avx2::RopeKernel_Avx2;
     return d;
 }();
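The hunk above touches the dispatch table that MLAS builds with an immediately invoked lambda; after the change the SRope entry is assigned unconditionally rather than under #if defined(MLAS_TARGET_AMD64). A minimal sketch of that initialization idiom, using hypothetical stand-ins (RopeDispatchTable, RopeKernelScalar) instead of the real MLAS_ROPE_DISPATCH and AVX2 kernel:

#include <cstddef>

// Hypothetical dispatch struct; the real one lives in the MLAS headers.
struct RopeDispatchTable {
  void (*SRope)(const float* input, const float* sin_vals, const float* cos_vals,
                std::size_t dim, bool interleaved, float* output) = nullptr;
};

// Placeholder kernel body for the sketch; a scalar rotation reference appears
// after the next file's diff.
static void RopeKernelScalar(const float* input, const float* sin_vals, const float* cos_vals,
                             std::size_t dim, bool interleaved, float* output) {
  (void)sin_vals; (void)cos_vals; (void)interleaved;
  for (std::size_t i = 0; i < dim; ++i) output[i] = input[i];
}

// The table is filled once, at static-initialization time, by an immediately
// invoked lambda; callers then go through table.SRope with no #if guards.
const RopeDispatchTable RopeDispatchAvx2 = []() {
  RopeDispatchTable d;
  d.SRope = RopeKernelScalar;
  return d;
}();

Because callers read the function pointer at run time, swapping in a different kernel only requires changing this initializer.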
@@ -89,7 +89,7 @@ RopeKernel_Avx2_Impl<true>(
     for (; i + 15 < dim; i += 16) {
         float32x8_t x0 = _mm256_loadu_ps(input + i);
         float32x8_t x1 = _mm256_loadu_ps(input + i + 8);
-        //Load imaginary and real values to seperate non-interleaved vectors
+        //Load imaginary and real values to separate non-interleaved vectors
         float32x8_t real_s = _mm256_shuffle_ps(x0, x1, 0b10001000);
         float32x8_t imag_s = _mm256_shuffle_ps(x0, x1, 0b11011101);
         __m256i in_mask_vec = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0);
@@ -116,7 +116,7 @@ RopeKernel_Avx2_Impl<true>(
     const __m256i mask1 = _mm256_loadu_si256((const __m256i*)(mask_buffer + 8 - (rem>8?(rem-8):0)));
     float32x8_t x0 = _mm256_maskload_ps(input + i, mask0); //Load the first set of data using mask
     float32x8_t x1 = _mm256_maskload_ps(input + i + 8, mask1); //Load the reminder of data using a second mask
-    //Load imaginary and real values to seperate non-interleaved vectors
+    //Load imaginary and real values to separate non-interleaved vectors
     float32x8_t real_s = _mm256_shuffle_ps(x0, x1, 0b10001000);
     float32x8_t imag_s = _mm256_shuffle_ps(x0, x1, 0b11011101);
     __m256i in_mask_vec = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0);
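Context for the two hunks above: _mm256_shuffle_ps with control 0b10001000 gathers the even-indexed (real) floats of x0 and x1 into real_s, and control 0b11011101 gathers the odd-indexed (imaginary) floats into imag_s, each within 128-bit lanes; the index vector built on the last visible line is presumably fed to a cross-lane permute (e.g. _mm256_permutevar8x32_ps) to restore sequential order before the rotation. For reference, a scalar sketch of what the interleaved path computes per (real, imag) pair; names such as sin_vals and cos_vals are illustrative, not the MLAS signature.

#include <cstddef>

// Scalar sketch of interleaved RoPE: de-interleave each (real, imag) pair,
// rotate it by the per-position angle, and write the pair back interleaved.
void RopeInterleavedRef(const float* input, const float* sin_vals, const float* cos_vals,
                        std::size_t dim, float* output) {
  // dim counts floats; consecutive (even, odd) elements form one complex value.
  for (std::size_t i = 0; i + 1 < dim; i += 2) {
    const float re = input[i];      // what the AVX2 code gathers into real_s
    const float im = input[i + 1];  // what the AVX2 code gathers into imag_s
    const float c = cos_vals[i / 2];
    const float s = sin_vals[i / 2];
    output[i] = re * c - im * s;      // rotated real part
    output[i + 1] = re * s + im * c;  // rotated imaginary part
  }
}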
