Skip to content

Commit

Permalink
mha int8
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Oct 14, 2024
1 parent 1c7af00 commit 07b739e
Show file tree
Hide file tree
Showing 3 changed files with 479 additions and 4 deletions.
40 changes: 36 additions & 4 deletions src/layer/arm/multiheadattention_arm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,16 @@ int MultiHeadAttention_arm::create_pipeline(const Option& _opt)
pd.set(11, 0); // output_N1M
pd.set(12, 1); // output_elempack
pd.set(14, 0); // output_transpose
#if NCNN_INT8
pd.set(18, int8_scale_term);
#endif
q_gemm->load_param(pd);
Mat weights[2];
Mat weights[3];
weights[0] = q_weight_data;
weights[1] = q_bias_data;
#if NCNN_INT8
weights[2] = q_weight_data_int8_scales;
#endif
q_gemm->load_model(ModelBinFromMatArray(weights));
q_gemm->create_pipeline(opt);

Expand All @@ -105,10 +111,16 @@ int MultiHeadAttention_arm::create_pipeline(const Option& _opt)
pd.set(11, 0); // output_N1M
pd.set(12, 1); // output_elempack
pd.set(14, 0); // output_transpose
#if NCNN_INT8
pd.set(18, int8_scale_term);
#endif
k_gemm->load_param(pd);
Mat weights[2];
Mat weights[3];
weights[0] = k_weight_data;
weights[1] = k_bias_data;
#if NCNN_INT8
weights[2] = k_weight_data_int8_scales;
#endif
k_gemm->load_model(ModelBinFromMatArray(weights));
k_gemm->create_pipeline(opt);

Expand All @@ -134,10 +146,16 @@ int MultiHeadAttention_arm::create_pipeline(const Option& _opt)
pd.set(11, 0); // output_N1M
pd.set(12, 1); // output_elempack
pd.set(14, 0); // output_transpose
#if NCNN_INT8
pd.set(18, int8_scale_term);
#endif
v_gemm->load_param(pd);
Mat weights[2];
Mat weights[3];
weights[0] = v_weight_data;
weights[1] = v_bias_data;
#if NCNN_INT8
weights[2] = v_weight_data_int8_scales;
#endif
v_gemm->load_model(ModelBinFromMatArray(weights));
v_gemm->create_pipeline(opt);

Expand All @@ -161,10 +179,18 @@ int MultiHeadAttention_arm::create_pipeline(const Option& _opt)
pd.set(9, embed_dim); // K = maxk*inch
pd.set(10, 4); // constant_broadcast_type_C = null
pd.set(11, 0); // output_N1M
#if NCNN_INT8
pd.set(18, int8_scale_term);
#endif
o_gemm->load_param(pd);
Mat weights[2];
Mat weights[3];
weights[0] = out_weight_data;
weights[1] = out_bias_data;
#if NCNN_INT8
Mat out_weight_data_int8_scales(1);
out_weight_data_int8_scales[0] = out_weight_data_int8_scale;
weights[2] = out_weight_data_int8_scales;
#endif
o_gemm->load_model(ModelBinFromMatArray(weights));
o_gemm->create_pipeline(opt);

Expand All @@ -189,6 +215,9 @@ int MultiHeadAttention_arm::create_pipeline(const Option& _opt)
pd.set(10, attn_mask ? 3 : -1); // constant_broadcast_type_C
pd.set(11, 0); // output_N1M
pd.set(12, 1); // output_elempack
#if NCNN_INT8
pd.set(18, int8_scale_term);
#endif
qk_gemm->load_param(pd);
qk_gemm->load_model(ModelBinFromMatArray(0));
Option opt1 = opt;
Expand All @@ -211,6 +240,9 @@ int MultiHeadAttention_arm::create_pipeline(const Option& _opt)
pd.set(11, 0); // output_N1M
pd.set(12, 1); // output_elempack
pd.set(14, 1); // output_transpose
#if NCNN_INT8
pd.set(18, int8_scale_term);
#endif
qkv_gemm->load_param(pd);
qkv_gemm->load_model(ModelBinFromMatArray(0));
Option opt1 = opt;
Expand Down
Loading

0 comments on commit 07b739e

Please sign in to comment.