mha allow qdim differs from embed_dim #5519

Merged · 4 commits · Jun 19, 2024
src/layer/arm/multiheadattention_arm.cpp (6 changes: 4 additions & 2 deletions)
@@ -57,6 +57,8 @@ int MultiHeadAttention_arm::create_pipeline(const Option& _opt)
qk_softmax->create_pipeline(opt);
}

+ const int qdim = weight_data_size / embed_dim;
+
{
const int embed_dim_per_head = embed_dim / num_heads;
const float inv_sqrt_embed_dim_per_head = 1.f / sqrtf(embed_dim_per_head);
@@ -72,7 +74,7 @@ int MultiHeadAttention_arm::create_pipeline(const Option& _opt)
pd.set(6, 1); // constantC
pd.set(7, embed_dim); // M
pd.set(8, 0); // N
- pd.set(9, embed_dim); // K
+ pd.set(9, qdim); // K
pd.set(10, 1); // constant_broadcast_type_C
pd.set(11, 0); // output_N1M
pd.set(12, 1); // output_elempack
@@ -158,7 +160,7 @@ int MultiHeadAttention_arm::create_pipeline(const Option& _opt)
pd.set(5, 1); // constantB
pd.set(6, 1); // constantC
pd.set(7, 0); // M = outch
- pd.set(8, embed_dim); // N = size
+ pd.set(8, qdim); // N = size
pd.set(9, embed_dim); // K = maxk*inch
pd.set(10, 4); // constant_broadcast_type_C = null
pd.set(11, 0); // output_N1M
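As background for the Gemm constants changed above (and for the identical vulkan and x86 changes below): the query projection weights are stored as embed_dim rows of qdim values, so qdim can be recovered from the weight blob size. The following standalone sketch is illustrative only and is not part of the patch; the example dimensions are arbitrary.

#include <cstdio>

int main()
{
    // example dimensions, chosen only for illustration
    const int embed_dim = 64;
    const int qdim = 48; // query input width; may differ from embed_dim

    // layer param 2 stores weight_data_size = embed_dim * qdim
    const int weight_data_size = embed_dim * qdim;

    // the loaders and create_pipeline recover qdim by division, as in the patch
    const int recovered_qdim = weight_data_size / embed_dim;

    // q projection Gemm:   M = embed_dim, K = qdim       (weights: embed_dim x qdim)
    // out projection Gemm: N = qdim,      K = embed_dim  (weights: qdim x embed_dim)
    printf("recovered qdim = %d\n", recovered_qdim);
    return 0;
}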
src/layer/multiheadattention.cpp (29 changes: 21 additions & 8 deletions)
@@ -36,7 +36,9 @@ int MultiHeadAttention::load_param(const ParamDict& pd)

int MultiHeadAttention::load_model(const ModelBin& mb)
{
- q_weight_data = mb.load(weight_data_size, 0);
+ const int qdim = weight_data_size / embed_dim;
+
+ q_weight_data = mb.load(embed_dim * qdim, 0);
if (q_weight_data.empty())
return -100;

@@ -60,11 +62,11 @@ int MultiHeadAttention::load_model(const ModelBin& mb)
if (v_bias_data.empty())
return -100;

- out_weight_data = mb.load(weight_data_size, 0);
+ out_weight_data = mb.load(qdim * embed_dim, 0);
if (out_weight_data.empty())
return -100;

- out_bias_data = mb.load(embed_dim, 1);
+ out_bias_data = mb.load(qdim, 1);
if (out_bias_data.empty())
return -100;

@@ -82,21 +84,32 @@ int MultiHeadAttention::forward(const std::vector<Mat>& bottom_blobs, std::vecto
const int src_seqlen = q_blob.h;
const int dst_seqlen = k_blob.h;
const int embed_dim_per_head = embed_dim / num_heads;
+ const int qdim = weight_data_size / embed_dim;

// assert k_blob.h == v_blob.h

Mat& top_blob = top_blobs[0];
- top_blob.create(embed_dim, src_seqlen, 4u, opt.blob_allocator);
+ top_blob.create(qdim, src_seqlen, 4u, opt.blob_allocator);
if (top_blob.empty())
- return -1;
+ return -100;

Mat xq(embed_dim_per_head, src_seqlen, num_heads, 4u, opt.workspace_allocator);
+ if (xq.empty())
+ return -100;
Mat xk(embed_dim_per_head, dst_seqlen, num_heads, 4u, opt.workspace_allocator);
+ if (xk.empty())
+ return -100;
Mat xv(dst_seqlen, embed_dim_per_head, num_heads, 4u, opt.workspace_allocator);
+ if (xv.empty())
+ return -100;

Mat xqk(dst_seqlen, src_seqlen, num_heads, 4u, opt.workspace_allocator);
+ if (xqk.empty())
+ return -100;

Mat xqkv(embed_dim_per_head, num_heads, src_seqlen, 4u, opt.workspace_allocator);
+ if (xqkv.empty())
+ return -100;

const float inv_sqrt_embed_dim_per_head = 1.f / sqrtf(embed_dim_per_head);

@@ -114,10 +127,10 @@ int MultiHeadAttention::forward(const std::vector<Mat>& bottom_blobs, std::vecto
for (int j = 0; j < embed_dim_per_head; j++)
{
const float* ptr = q_blob.row(i);
- const float* kptr = (const float*)q_weight_data + embed_dim * (q * embed_dim_per_head + j);
+ const float* kptr = (const float*)q_weight_data + qdim * (q * embed_dim_per_head + j);

float sum = q_bias_data[q * embed_dim_per_head + j];
- for (int k = 0; k < embed_dim; k++)
+ for (int k = 0; k < qdim; k++)
{
sum += *ptr++ * *kptr++;
}
@@ -286,7 +299,7 @@ int MultiHeadAttention::forward(const std::vector<Mat>& bottom_blobs, std::vecto
{
float* outptr = top_blob.row(i);

- for (int j = 0; j < embed_dim; j++)
+ for (int j = 0; j < qdim; j++)
{
const float* ptr = xqkv.channel(i);
const float* kptr = (const float*)out_weight_data + embed_dim * j;
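To make the indexing in the q-projection hunk above easier to follow, here is a condensed restatement of that reference loop as a free function. The name project_q_row and the flattened row index are illustrative, not ncnn API; the actual layer additionally folds in the 1/sqrt(embed_dim_per_head) scale and splits rows across heads.

#include <cstddef>

// naive q projection for one input row:
// out[j] = bias[j] + sum_k in_row[k] * W[j * qdim + k]
// W holds embed_dim rows of qdim weights, matching q_weight_data after this patch.
static void project_q_row(const float* in_row, const float* W, const float* bias,
                          float* out, int embed_dim, int qdim)
{
    for (int j = 0; j < embed_dim; j++)
    {
        const float* kptr = W + (size_t)qdim * j;
        float sum = bias[j];
        for (int k = 0; k < qdim; k++)
            sum += in_row[k] * kptr[k];
        out[j] = sum;
    }
}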
src/layer/vulkan/multiheadattention_vulkan.cpp (5 changes: 3 additions & 2 deletions)
@@ -46,6 +46,7 @@ MultiHeadAttention_vulkan::MultiHeadAttention_vulkan()
int MultiHeadAttention_vulkan::create_pipeline(const Option& opt)
{
const int embed_dim_per_head = embed_dim / num_heads;
+ const int qdim = weight_data_size / embed_dim;
{
const float inv_sqrt_embed_dim_per_head = 1.f / sqrtf(embed_dim_per_head);

@@ -61,7 +62,7 @@ int MultiHeadAttention_vulkan::create_pipeline(const Option& opt)
pd.set(6, 1); // constantC
pd.set(7, embed_dim); // M
pd.set(8, 0); // N
- pd.set(9, embed_dim); // K
+ pd.set(9, qdim); // K
pd.set(10, 1); // constant_broadcast_type_C
pd.set(11, 0); // output_N1M
// pd.set(12, 1); // output_elempack
@@ -220,7 +221,7 @@ int MultiHeadAttention_vulkan::create_pipeline(const Option& opt)
pd.set(5, 1); // constantB
pd.set(6, 1); // constantC
pd.set(7, 0); // M = outch
- pd.set(8, embed_dim); // N = size
+ pd.set(8, qdim); // N = size
pd.set(9, embed_dim); // K = maxk*inch
pd.set(10, 4); // constant_broadcast_type_C
pd.set(11, 0); // output_N1M
src/layer/x86/multiheadattention_x86.cpp (6 changes: 4 additions & 2 deletions)
@@ -38,6 +38,8 @@ MultiHeadAttention_x86::MultiHeadAttention_x86()

int MultiHeadAttention_x86::create_pipeline(const Option& opt)
{
+ const int qdim = weight_data_size / embed_dim;
+
{
const int embed_dim_per_head = embed_dim / num_heads;
const float inv_sqrt_embed_dim_per_head = 1.f / sqrtf(embed_dim_per_head);
@@ -53,7 +55,7 @@ int MultiHeadAttention_x86::create_pipeline(const Option& opt)
pd.set(6, 1); // constantC
pd.set(7, embed_dim); // M
pd.set(8, 0); // N
- pd.set(9, embed_dim); // K
+ pd.set(9, qdim); // K
pd.set(10, 1); // constant_broadcast_type_C
pd.set(11, 0); // output_N1M
pd.set(12, 1); // output_elempack
@@ -191,7 +193,7 @@ int MultiHeadAttention_x86::create_pipeline(const Option& opt)
pd.set(5, 1); // constantB
pd.set(6, 1); // constantC
pd.set(7, 0); // M = outch
- pd.set(8, embed_dim); // N = size
+ pd.set(8, qdim); // N = size
pd.set(9, embed_dim); // K = maxk*inch
pd.set(10, 4); // constant_broadcast_type_C
pd.set(11, 0); // output_N1M
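The arm, vulkan, and x86 hunks all set the output projection Gemm to N = qdim and K = embed_dim. As a sanity check, this is the equivalent naive loop from the reference layer restated as a free function (project_out_row is an illustrative name, not ncnn API):

// naive output projection for one row of the attention result:
// out[j] = out_bias[j] + sum_k xqkv_row[k] * Wout[j * embed_dim + k]
// Wout holds qdim rows of embed_dim weights, so the layer output width is qdim.
static void project_out_row(const float* xqkv_row, const float* Wout, const float* bias,
                            float* out, int embed_dim, int qdim)
{
    for (int j = 0; j < qdim; j++)
    {
        const float* kptr = Wout + embed_dim * j;
        float sum = bias[j];
        for (int k = 0; k < embed_dim; k++)
            sum += xqkv_row[k] * kptr[k];
        out[j] = sum;
    }
}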
tests/test_multiheadattention.cpp (83 changes: 44 additions & 39 deletions)
@@ -14,27 +14,29 @@

#include "testutil.h"

- static int test_multiheadattention(const ncnn::Mat& q, const ncnn::Mat& k, const ncnn::Mat& v, int num_heads, int kdim, int vdim, int attn_mask)
+ static int test_multiheadattention(const ncnn::Mat& q, const ncnn::Mat& k, const ncnn::Mat& v, int embed_dim, int num_heads, int attn_mask)
{
- int embed_dim = q.w;
+ const int qdim = q.w;
+ const int kdim = k.w;
+ const int vdim = v.w;

ncnn::ParamDict pd;
pd.set(0, embed_dim);
pd.set(1, num_heads);
- pd.set(2, embed_dim * embed_dim);
+ pd.set(2, embed_dim * qdim);
pd.set(3, kdim);
pd.set(4, vdim);
pd.set(5, attn_mask);

std::vector<ncnn::Mat> weights(8);
- weights[0] = RandomMat(embed_dim * embed_dim);
+ weights[0] = RandomMat(embed_dim * qdim);
weights[1] = RandomMat(embed_dim);
weights[2] = RandomMat(embed_dim * kdim);
weights[3] = RandomMat(embed_dim);
weights[4] = RandomMat(embed_dim * vdim);
weights[5] = RandomMat(embed_dim);
- weights[6] = RandomMat(embed_dim * embed_dim);
- weights[7] = RandomMat(embed_dim);
+ weights[6] = RandomMat(qdim * embed_dim);
+ weights[7] = RandomMat(qdim);

std::vector<ncnn::Mat> as(3);
as[0] = q;
@@ -51,32 +53,33 @@ static int test_multiheadattention(const ncnn::Mat& q, const ncnn::Mat& k, const
int ret = test_layer("MultiHeadAttention", pd, weights, as, 1, epsilon);
if (ret != 0)
{
fprintf(stderr, "test_multiheadattention failed q=(%d %d) k=(%d %d) v=(%d %d) num_heads=%d kdim=%d vdim=%d attn_mask=%d\n", q.w, q.h, k.w, k.h, v.w, v.h, num_heads, kdim, vdim, attn_mask);
fprintf(stderr, "test_multiheadattention failed q=(%d %d) k=(%d %d) v=(%d %d) embed_dim=%d num_heads=%d kdim=%d vdim=%d attn_mask=%d\n", q.w, q.h, k.w, k.h, v.w, v.h, embed_dim, num_heads, kdim, vdim, attn_mask);
}

return ret;
}

- static int test_multiheadattention_samekv(const ncnn::Mat& q, const ncnn::Mat& kv, int num_heads, int kvdim)
+ static int test_multiheadattention_samekv(const ncnn::Mat& q, const ncnn::Mat& kv, int embed_dim, int num_heads)
{
- int embed_dim = q.w;
+ const int qdim = q.w;
+ const int kvdim = kv.w;

ncnn::ParamDict pd;
pd.set(0, embed_dim);
pd.set(1, num_heads);
- pd.set(2, embed_dim * embed_dim);
+ pd.set(2, embed_dim * qdim);
pd.set(3, kvdim);
pd.set(4, kvdim);

std::vector<ncnn::Mat> weights(8);
- weights[0] = RandomMat(embed_dim * embed_dim);
+ weights[0] = RandomMat(embed_dim * qdim);
weights[1] = RandomMat(embed_dim);
weights[2] = RandomMat(embed_dim * kvdim);
weights[3] = RandomMat(embed_dim);
weights[4] = RandomMat(embed_dim * kvdim);
weights[5] = RandomMat(embed_dim);
- weights[6] = RandomMat(embed_dim * embed_dim);
- weights[7] = RandomMat(embed_dim);
+ weights[6] = RandomMat(qdim * embed_dim);
+ weights[7] = RandomMat(qdim);

std::vector<ncnn::Mat> as(2);
as[0] = q;
@@ -87,30 +90,32 @@ static int test_multiheadattention_samekv(const ncnn::Mat& q, const ncnn::Mat& k
int ret = test_layer("MultiHeadAttention", pd, weights, as, 1, epsilon);
if (ret != 0)
{
fprintf(stderr, "test_multiheadattention_samekv failed q=(%d %d) kv=(%d %d) num_heads=%d kvdim=%d\n", q.w, q.h, kv.w, kv.h, num_heads, kvdim);
fprintf(stderr, "test_multiheadattention_samekv failed q=(%d %d) kv=(%d %d) embed_dim=%d num_heads=%d kvdim=%d\n", q.w, q.h, kv.w, kv.h, embed_dim, num_heads, kvdim);
}

return ret;
}

- static int test_multiheadattention_sameqkv(const ncnn::Mat& a, int num_heads)
+ static int test_multiheadattention_sameqkv(const ncnn::Mat& a, int embed_dim, int num_heads)
{
- int embed_dim = a.w;
+ const int qdim = a.w;

ncnn::ParamDict pd;
pd.set(0, embed_dim);
pd.set(1, num_heads);
- pd.set(2, embed_dim * embed_dim);
+ pd.set(2, embed_dim * qdim);
+ pd.set(3, qdim);
+ pd.set(4, qdim);

std::vector<ncnn::Mat> weights(8);
- weights[0] = RandomMat(embed_dim * embed_dim);
+ weights[0] = RandomMat(embed_dim * qdim);
weights[1] = RandomMat(embed_dim);
- weights[2] = RandomMat(embed_dim * embed_dim);
+ weights[2] = RandomMat(embed_dim * qdim);
weights[3] = RandomMat(embed_dim);
- weights[4] = RandomMat(embed_dim * embed_dim);
+ weights[4] = RandomMat(embed_dim * qdim);
weights[5] = RandomMat(embed_dim);
- weights[6] = RandomMat(embed_dim * embed_dim);
- weights[7] = RandomMat(embed_dim);
+ weights[6] = RandomMat(qdim * embed_dim);
+ weights[7] = RandomMat(qdim);

std::vector<ncnn::Mat> as(1);
as[0] = a;
@@ -120,7 +125,7 @@ static int test_multiheadattention_sameqkv(const ncnn::Mat& a, int num_heads)
int ret = test_layer("MultiHeadAttention", pd, weights, as, 1, epsilon);
if (ret != 0)
{
fprintf(stderr, "test_multiheadattention_sameqkv failed a=(%d %d) num_heads=%d\n", a.w, a.h, num_heads);
fprintf(stderr, "test_multiheadattention_sameqkv failed a=(%d %d) embed_dim=%d num_heads=%d\n", a.w, a.h, embed_dim, num_heads);
}

return ret;
@@ -129,32 +134,32 @@ static int test_multiheadattention_sameqkv(const ncnn::Mat& a, int num_heads)
static int test_multiheadattention_0()
{
return 0
- || test_multiheadattention(RandomMat(62, 66), RandomMat(32, 66), RandomMat(20, 66), 2, 32, 20, 0)
- || test_multiheadattention(RandomMat(26, 64), RandomMat(32, 64), RandomMat(18, 64), 2, 32, 18, 1)
- || test_multiheadattention(RandomMat(64, 128), RandomMat(64, 128), RandomMat(64, 128), 4, 64, 64, 0)
- || test_multiheadattention(RandomMat(64, 127), RandomMat(64, 127), RandomMat(64, 127), 16, 64, 64, 1)
- || test_multiheadattention(RandomMat(16, 128), RandomMat(44, 128), RandomMat(55, 128), 2, 44, 55, 0)
- || test_multiheadattention(RandomMat(16, 128), RandomMat(44, 127), RandomMat(55, 127), 4, 44, 55, 1)
- || test_multiheadattention(RandomMat(12, 17), RandomMat(28, 127), RandomMat(32, 127), 3, 28, 32, 0)
- || test_multiheadattention(RandomMat(12, 17), RandomMat(28, 32), RandomMat(11, 32), 3, 28, 11, 1);
+ || test_multiheadattention(RandomMat(62, 66), RandomMat(32, 66), RandomMat(20, 66), 62, 2, 0)
+ || test_multiheadattention(RandomMat(26, 64), RandomMat(32, 64), RandomMat(18, 64), 26, 2, 1)
+ || test_multiheadattention(RandomMat(64, 128), RandomMat(64, 128), RandomMat(64, 128), 64, 4, 0)
+ || test_multiheadattention(RandomMat(48, 127), RandomMat(64, 127), RandomMat(64, 127), 64, 16, 1)
+ || test_multiheadattention(RandomMat(16, 128), RandomMat(44, 128), RandomMat(55, 128), 16, 2, 0)
+ || test_multiheadattention(RandomMat(12, 128), RandomMat(44, 127), RandomMat(55, 127), 16, 4, 1)
+ || test_multiheadattention(RandomMat(12, 17), RandomMat(28, 127), RandomMat(32, 127), 12, 3, 0)
+ || test_multiheadattention(RandomMat(12, 17), RandomMat(28, 32), RandomMat(11, 32), 12, 3, 1);
}

static int test_multiheadattention_1()
{
return 0
- || test_multiheadattention_samekv(RandomMat(64, 128), RandomMat(64, 128), 4, 64)
- || test_multiheadattention_samekv(RandomMat(64, 127), RandomMat(64, 127), 16, 64)
- || test_multiheadattention_samekv(RandomMat(16, 128), RandomMat(44, 128), 2, 44)
- || test_multiheadattention_samekv(RandomMat(16, 128), RandomMat(22, 127), 4, 22)
- || test_multiheadattention_samekv(RandomMat(12, 17), RandomMat(28, 127), 3, 28)
- || test_multiheadattention_samekv(RandomMat(12, 17), RandomMat(11, 32), 3, 11);
+ || test_multiheadattention_samekv(RandomMat(64, 128), RandomMat(64, 128), 64, 4)
+ || test_multiheadattention_samekv(RandomMat(48, 127), RandomMat(64, 127), 64, 16)
+ || test_multiheadattention_samekv(RandomMat(16, 128), RandomMat(44, 128), 16, 2)
+ || test_multiheadattention_samekv(RandomMat(12, 128), RandomMat(22, 127), 16, 4)
+ || test_multiheadattention_samekv(RandomMat(12, 17), RandomMat(28, 127), 12, 3)
+ || test_multiheadattention_samekv(RandomMat(12, 17), RandomMat(11, 32), 12, 3);
}

static int test_multiheadattention_2()
{
return 0
- || test_multiheadattention_sameqkv(RandomMat(64, 128), 4)
- || test_multiheadattention_sameqkv(RandomMat(64, 127), 8);
+ || test_multiheadattention_sameqkv(RandomMat(64, 128), 64, 4)
+ || test_multiheadattention_sameqkv(RandomMat(48, 127), 64, 8);
}

int main()
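For anyone reproducing a qdim != embed_dim case outside the harness, here is a minimal sketch of the parameter and weight layout, following the conventions of the test above. It assumes RandomMat from testutil.h and mirrors the fourth case of test_multiheadattention_0 (q width 48, embed_dim 64, kdim = vdim = 64, 16 heads, attn_mask on); it is illustrative, not a complete test.

// illustrative dimensions taken from one of the new test cases
const int embed_dim = 64, qdim = 48, kdim = 64, vdim = 64, num_heads = 16;

ncnn::ParamDict pd;
pd.set(0, embed_dim);
pd.set(1, num_heads);
pd.set(2, embed_dim * qdim); // weight_data_size; the layer derives qdim from it
pd.set(3, kdim);
pd.set(4, vdim);
pd.set(5, 1);                // attn_mask

std::vector<ncnn::Mat> weights(8);
weights[0] = RandomMat(embed_dim * qdim); // q weight, embed_dim x qdim
weights[1] = RandomMat(embed_dim);        // q bias
weights[2] = RandomMat(embed_dim * kdim); // k weight
weights[3] = RandomMat(embed_dim);        // k bias
weights[4] = RandomMat(embed_dim * vdim); // v weight
weights[5] = RandomMat(embed_dim);        // v bias
weights[6] = RandomMat(qdim * embed_dim); // out weight, qdim x embed_dim
weights[7] = RandomMat(qdim);             // out bias, qdim wide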