x86 vulkan fallback
nihui committed Oct 15, 2024
1 parent 0f0f310 commit 8113cb8
Showing 3 changed files with 123 additions and 55 deletions.
13 changes: 13 additions & 0 deletions src/layer/vulkan/multiheadattention_vulkan.cpp
@@ -43,6 +43,19 @@ MultiHeadAttention_vulkan::MultiHeadAttention_vulkan()
pipeline_multiheadattention_qkv_cross_pack4to1 = 0;
}

+int MultiHeadAttention_vulkan::load_param(const ParamDict& pd)
+{
+int ret = MultiHeadAttention::load_param(pd);
+
+if (int8_scale_term)
+{
+support_vulkan = false;
+support_image_storage = false;
+}
+
+return ret;
+}

int MultiHeadAttention_vulkan::create_pipeline(const Option& opt)
{
const int embed_dim_per_head = embed_dim / num_heads;
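Together with the x86 changes below, this is the fallback mechanism: load_param() runs before any Vulkan pipeline is created, so clearing support_vulkan and support_image_storage makes ncnn schedule an int8-quantized MultiHeadAttention layer on the CPU implementation even when the rest of the net runs on Vulkan. A minimal caller-side sketch of the behavior, assuming a hypothetical int8-quantized model; the file names, blob names and shapes are placeholders, not part of this commit:

    #include "net.h"

    int main()
    {
        ncnn::Net net;
        net.opt.use_vulkan_compute = true; // request GPU for the whole net

        // hypothetical int8 model containing a MultiHeadAttention layer;
        // that layer reports support_vulkan = false after load_param(),
        // so it transparently executes on the x86 path
        net.load_param("mha_int8.param");
        net.load_model("mha_int8.bin");

        ncnn::Mat q(64, 128); // placeholder input: embed_dim x seqlen
        q.fill(0.5f);

        ncnn::Extractor ex = net.create_extractor();
        ex.input("in0", q);

        ncnn::Mat out;
        ex.extract("out0", out);
        return 0;
    }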
2 changes: 2 additions & 0 deletions src/layer/vulkan/multiheadattention_vulkan.h
@@ -24,6 +24,8 @@ class MultiHeadAttention_vulkan : public MultiHeadAttention
public:
MultiHeadAttention_vulkan();

+virtual int load_param(const ParamDict& pd);

virtual int create_pipeline(const Option& opt);
virtual int destroy_pipeline(const Option& opt);

163 changes: 108 additions & 55 deletions src/layer/x86/multiheadattention_x86.cpp
@@ -36,8 +36,26 @@ MultiHeadAttention_x86::MultiHeadAttention_x86()
o_gemm = 0;
}

-int MultiHeadAttention_x86::create_pipeline(const Option& opt)
+int MultiHeadAttention_x86::create_pipeline(const Option& _opt)
{
+Option opt = _opt;
+if (int8_scale_term)
+{
+support_packing = false;
+
+opt.use_packing_layout = false; // TODO enable packing
+}
+
+{
+qk_softmax = ncnn::create_layer_cpu(ncnn::LayerType::Softmax);
+ncnn::ParamDict pd;
+pd.set(0, -1);
+pd.set(1, 1);
+qk_softmax->load_param(pd);
+qk_softmax->load_model(ModelBinFromMatArray(0));
+qk_softmax->create_pipeline(opt);
+}

const int qdim = weight_data_size / embed_dim;

{
@@ -57,10 +75,16 @@ int MultiHeadAttention_x86::create_pipeline(const Option& opt)
pd.set(11, 0); // output_N1M
pd.set(12, 1); // output_elempack
pd.set(14, 0); // output_transpose
+#if NCNN_INT8
+pd.set(18, int8_scale_term);
+#endif
q_gemm->load_param(pd);
-Mat weights[2];
+Mat weights[3];
weights[0] = q_weight_data;
weights[1] = q_bias_data;
+#if NCNN_INT8
+weights[2] = q_weight_data_int8_scales;
+#endif
q_gemm->load_model(ModelBinFromMatArray(weights));
q_gemm->create_pipeline(opt);

@@ -86,10 +110,16 @@ int MultiHeadAttention_x86::create_pipeline(const Option& opt)
pd.set(11, 0); // output_N1M
pd.set(12, 1); // output_elempack
pd.set(14, 0); // output_transpose
+#if NCNN_INT8
+pd.set(18, int8_scale_term);
+#endif
k_gemm->load_param(pd);
-Mat weights[2];
+Mat weights[3];
weights[0] = k_weight_data;
weights[1] = k_bias_data;
+#if NCNN_INT8
+weights[2] = k_weight_data_int8_scales;
+#endif
k_gemm->load_model(ModelBinFromMatArray(weights));
k_gemm->create_pipeline(opt);

@@ -115,10 +145,16 @@ int MultiHeadAttention_x86::create_pipeline(const Option& opt)
pd.set(11, 0); // output_N1M
pd.set(12, 1); // output_elempack
pd.set(14, 0); // output_transpose
+#if NCNN_INT8
+pd.set(18, int8_scale_term);
+#endif
v_gemm->load_param(pd);
-Mat weights[2];
+Mat weights[3];
weights[0] = v_weight_data;
weights[1] = v_bias_data;
+#if NCNN_INT8
+weights[2] = v_weight_data_int8_scales;
+#endif
v_gemm->load_model(ModelBinFromMatArray(weights));
v_gemm->create_pipeline(opt);

@@ -129,6 +165,41 @@ int MultiHeadAttention_x86::create_pipeline(const Option& opt)
}
}

+{
+o_gemm = ncnn::create_layer_cpu(ncnn::LayerType::Gemm);
+ncnn::ParamDict pd;
+pd.set(2, 1); // transA
+pd.set(3, 1); // transB
+pd.set(4, 0); // constantA
+pd.set(5, 1); // constantB
+pd.set(6, 1); // constantC
+pd.set(7, 0); // M = outch
+pd.set(8, qdim); // N = size
+pd.set(9, embed_dim); // K = maxk*inch
+pd.set(10, 4); // constant_broadcast_type_C
+pd.set(11, 0); // output_N1M
+#if NCNN_INT8
+pd.set(18, int8_scale_term);
+#endif
+o_gemm->load_param(pd);
+Mat weights[3];
+weights[0] = out_weight_data;
+weights[1] = out_bias_data;
+#if NCNN_INT8
+Mat out_weight_data_int8_scales(1);
+out_weight_data_int8_scales[0] = out_weight_data_int8_scale;
+weights[2] = out_weight_data_int8_scales;
+#endif
+o_gemm->load_model(ModelBinFromMatArray(weights));
+o_gemm->create_pipeline(opt);
+
+if (opt.lightmode)
+{
+out_weight_data.release();
+out_bias_data.release();
+}
+}

{
qk_gemm = ncnn::create_layer_cpu(ncnn::LayerType::Gemm);
ncnn::ParamDict pd;
@@ -143,12 +214,16 @@ int MultiHeadAttention_x86::create_pipeline(const Option& opt)
pd.set(10, attn_mask ? 3 : -1); // constant_broadcast_type_C
pd.set(11, 0); // output_N1M
pd.set(12, 1); // output_elempack
+#if NCNN_INT8
+pd.set(18, int8_scale_term);
+#endif
qk_gemm->load_param(pd);
qk_gemm->load_model(ModelBinFromMatArray(0));
Option opt1 = opt;
opt1.num_threads = 1;
qk_gemm->create_pipeline(opt1);
}

{
qkv_gemm = ncnn::create_layer_cpu(ncnn::LayerType::Gemm);
ncnn::ParamDict pd;
@@ -164,55 +239,34 @@ int MultiHeadAttention_x86::create_pipeline(const Option& opt)
pd.set(11, 0); // output_N1M
pd.set(12, 1); // output_elempack
pd.set(14, 1); // output_transpose
+#if NCNN_INT8
+pd.set(18, int8_scale_term);
+#endif
qkv_gemm->load_param(pd);
qkv_gemm->load_model(ModelBinFromMatArray(0));
Option opt1 = opt;
opt1.num_threads = 1;
qkv_gemm->create_pipeline(opt1);
}

-{
-qk_softmax = ncnn::create_layer_cpu(ncnn::LayerType::Softmax);
-ncnn::ParamDict pd;
-pd.set(0, -1);
-pd.set(1, 1);
-qk_softmax->load_param(pd);
-qk_softmax->load_model(ModelBinFromMatArray(0));
-qk_softmax->create_pipeline(opt);
-}
-
-{
-o_gemm = ncnn::create_layer_cpu(ncnn::LayerType::Gemm);
-ncnn::ParamDict pd;
-pd.set(2, 1); // transA
-pd.set(3, 1); // transB
-pd.set(4, 0); // constantA
-pd.set(5, 1); // constantB
-pd.set(6, 1); // constantC
-pd.set(7, 0); // M = outch
-pd.set(8, qdim); // N = size
-pd.set(9, embed_dim); // K = maxk*inch
-pd.set(10, 4); // constant_broadcast_type_C
-pd.set(11, 0); // output_N1M
-o_gemm->load_param(pd);
-Mat weights[2];
-weights[0] = out_weight_data;
-weights[1] = out_bias_data;
-o_gemm->load_model(ModelBinFromMatArray(weights));
-o_gemm->create_pipeline(opt);
-
-if (opt.lightmode)
-{
-out_weight_data.release();
-out_bias_data.release();
-}
-}
-
-return 0;
-}
-
-int MultiHeadAttention_x86::destroy_pipeline(const Option& opt)
-{
+return 0;
+}
+
+int MultiHeadAttention_x86::destroy_pipeline(const Option& _opt)
+{
+Option opt = _opt;
+if (int8_scale_term)
+{
+opt.use_packing_layout = false; // TODO enable packing
+}
+
+if (qk_softmax)
+{
+qk_softmax->destroy_pipeline(opt);
+delete qk_softmax;
+qk_softmax = 0;
+}
+
if (q_gemm)
{
q_gemm->destroy_pipeline(opt);
@@ -234,6 +288,13 @@ int MultiHeadAttention_x86::destroy_pipeline(const Option& opt)
v_gemm = 0;
}

+if (o_gemm)
+{
+o_gemm->destroy_pipeline(opt);
+delete o_gemm;
+o_gemm = 0;
+}

if (qk_gemm)
{
qk_gemm->destroy_pipeline(opt);
@@ -247,30 +308,22 @@ int MultiHeadAttention_x86::destroy_pipeline(const Option& opt)
qkv_gemm = 0;
}

-if (qk_softmax)
-{
-qk_softmax->destroy_pipeline(opt);
-delete qk_softmax;
-qk_softmax = 0;
-}
-
-if (o_gemm)
-{
-o_gemm->destroy_pipeline(opt);
-delete o_gemm;
-o_gemm = 0;
-}

return 0;
}

-int MultiHeadAttention_x86::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
+int MultiHeadAttention_x86::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& _opt) const
{
const Mat& q_blob = bottom_blobs[0];
const Mat& k_blob = (bottom_blobs.size() == 1 || (bottom_blobs.size() == 2 && attn_mask)) ? q_blob : bottom_blobs[1];
const Mat& v_blob = (bottom_blobs.size() == 1 || (bottom_blobs.size() == 2 && attn_mask)) ? q_blob : (bottom_blobs.size() == 2 || (bottom_blobs.size() == 3 && attn_mask)) ? k_blob : bottom_blobs[2];
const Mat& attn_mask_blob = attn_mask ? bottom_blobs[bottom_blobs.size() - 1] : Mat();

+Option opt = _opt;
+if (int8_scale_term)
+{
+opt.use_packing_layout = false; // TODO enable packing
+}

Mat attn_mask_blob_unpacked;
if (attn_mask && attn_mask_blob.elempack != 1)
{
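For reference, q_gemm, k_gemm, v_gemm and o_gemm above all wrap ncnn's inner Gemm layer the same way, and the int8 wiring added by this commit is just param 18 (int8_scale_term) plus a third weight Mat carrying the weight scales. Below is a condensed sketch of the o_gemm construction pulled into a standalone helper; make_out_proj_gemm and its signature are invented for illustration, while the ParamDict values mirror the diff:

    // illustrative only: mirrors the o_gemm block in create_pipeline() above
    static ncnn::Layer* make_out_proj_gemm(const ncnn::Mat& out_weight, const ncnn::Mat& out_bias,
                                           float out_weight_int8_scale,
                                           int qdim, int embed_dim, int int8_scale_term,
                                           const ncnn::Option& opt)
    {
        ncnn::Layer* o_gemm = ncnn::create_layer_cpu(ncnn::LayerType::Gemm);

        ncnn::ParamDict pd;
        pd.set(2, 1);         // transA
        pd.set(3, 1);         // transB
        pd.set(4, 0);         // constantA: attention output arrives at forward time
        pd.set(5, 1);         // constantB: projection weight baked in now
        pd.set(6, 1);         // constantC: bias baked in now
        pd.set(7, 0);         // M = outch, deduced from the input
        pd.set(8, qdim);      // N = size
        pd.set(9, embed_dim); // K = maxk*inch
        pd.set(10, 4);        // constant_broadcast_type_C
        pd.set(11, 0);        // output_N1M
    #if NCNN_INT8
        pd.set(18, int8_scale_term); // let the inner Gemm run its int8 kernels
    #endif
        o_gemm->load_param(pd);

        ncnn::Mat weights[3];
        weights[0] = out_weight;
        weights[1] = out_bias;
    #if NCNN_INT8
        ncnn::Mat scales(1);
        scales[0] = out_weight_int8_scale; // single per-tensor weight scale
        weights[2] = scales;
    #endif
        o_gemm->load_model(ncnn::ModelBinFromMatArray(weights));
        o_gemm->create_pipeline(opt);
        return o_gemm;
    }

create_pipeline() above would call such a helper with out_weight_data, out_bias_data and out_weight_data_int8_scale.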
