From 07b840087b57a671a1b0f6f7814e05be8a8d71ef Mon Sep 17 00:00:00 2001 From: nihui Date: Wed, 26 Jul 2023 10:43:26 +0800 Subject: [PATCH] code clean for mha arm opt (#4613) --- src/layer/arm/multiheadattention_arm.cpp | 394 +++-------------------- 1 file changed, 44 insertions(+), 350 deletions(-) diff --git a/src/layer/arm/multiheadattention_arm.cpp b/src/layer/arm/multiheadattention_arm.cpp index 4e491bfc428..15eca715699 100644 --- a/src/layer/arm/multiheadattention_arm.cpp +++ b/src/layer/arm/multiheadattention_arm.cpp @@ -41,16 +41,11 @@ MultiHeadAttention_arm::MultiHeadAttention_arm() qk_softmax = 0; } -int MultiHeadAttention_arm::create_pipeline(const Option& opt) +int MultiHeadAttention_arm::create_pipeline(const Option& _opt) { - Option optn = opt; - optn.use_bf16_storage = false; - - Option opt32 = opt; - opt32.use_bf16_storage = false; - opt32.use_fp16_arithmetic = false; - opt32.use_fp16_packed = false; - opt32.use_fp16_storage = false; + Option opt = _opt; + opt.use_fp16_storage &= support_fp16_storage; + opt.use_bf16_storage &= support_bf16_storage; { qk_softmax = ncnn::create_layer(ncnn::LayerType::Softmax); @@ -59,185 +54,8 @@ int MultiHeadAttention_arm::create_pipeline(const Option& opt) pd.set(1, 1); qk_softmax->load_param(pd); qk_softmax->load_model(ModelBinFromMatArray(0)); - qk_softmax->create_pipeline(opt32); - } - -#if NCNN_ARM82 - if (support_fp16_storage && optn.use_fp16_storage) - { - Option optopt = optn; - - { - const int embed_dim_per_head = embed_dim / num_heads; - const float inv_sqrt_embed_dim_per_head = 1.f / sqrtf(embed_dim_per_head); - - q_gemm = ncnn::create_layer(ncnn::LayerType::Gemm); - ncnn::ParamDict pd; - pd.set(0, inv_sqrt_embed_dim_per_head); - pd.set(1, 1.f); - pd.set(2, 0); // transA - pd.set(3, 1); // transB - pd.set(4, 1); // constantA - pd.set(5, 0); // constantB - pd.set(6, 1); // constantC - pd.set(7, embed_dim); // M - pd.set(8, 0); // N - pd.set(9, embed_dim); // K - pd.set(10, 1); // constant_broadcast_type_C - pd.set(11, 0); // output_N1M - pd.set(12, 1); // output_elempack - pd.set(14, 0); // output_transpose - q_gemm->load_param(pd); - Mat weights[2]; - weights[0] = q_weight_data; - weights[1] = q_bias_data; - q_gemm->load_model(ModelBinFromMatArray(weights)); - q_gemm->create_pipeline(optopt); - - if (optopt.lightmode) - { - q_weight_data.release(); - q_bias_data.release(); - } - } - - { - k_gemm = ncnn::create_layer(ncnn::LayerType::Gemm); - ncnn::ParamDict pd; - pd.set(2, 0); // transA - pd.set(3, 1); // transB - pd.set(4, 1); // constantA - pd.set(5, 0); // constantB - pd.set(6, 1); // constantC - pd.set(7, embed_dim); // M - pd.set(8, 0); // N - pd.set(9, kdim); // K - pd.set(10, 1); // constant_broadcast_type_C - pd.set(11, 0); // output_N1M - pd.set(12, 1); // output_elempack - pd.set(14, 0); // output_transpose - k_gemm->load_param(pd); - Mat weights[2]; - weights[0] = k_weight_data; - weights[1] = k_bias_data; - k_gemm->load_model(ModelBinFromMatArray(weights)); - k_gemm->create_pipeline(optopt); - - if (optopt.lightmode) - { - k_weight_data.release(); - k_bias_data.release(); - } - } - - { - v_gemm = ncnn::create_layer(ncnn::LayerType::Gemm); - ncnn::ParamDict pd; - pd.set(2, 0); // transA - pd.set(3, 1); // transB - pd.set(4, 1); // constantA - pd.set(5, 0); // constantB - pd.set(6, 1); // constantC - pd.set(7, embed_dim); // M - pd.set(8, 0); // N - pd.set(9, vdim); // K - pd.set(10, 1); // constant_broadcast_type_C - pd.set(11, 0); // output_N1M - pd.set(12, 1); // output_elempack - pd.set(14, 0); // output_transpose - v_gemm->load_param(pd); - Mat weights[2]; - weights[0] = v_weight_data; - weights[1] = v_bias_data; - v_gemm->load_model(ModelBinFromMatArray(weights)); - v_gemm->create_pipeline(optopt); - - if (optopt.lightmode) - { - v_weight_data.release(); - v_bias_data.release(); - } - } - - { - o_gemm = ncnn::create_layer(ncnn::LayerType::Gemm); - ncnn::ParamDict pd; - pd.set(2, 1); // transA - pd.set(3, 1); // transB - pd.set(4, 0); // constantA - pd.set(5, 1); // constantB - pd.set(6, 1); // constantC - pd.set(7, 0); // M = outch - pd.set(8, embed_dim); // N = size - pd.set(9, embed_dim); // K = maxk*inch - pd.set(10, 4); // constant_broadcast_type_C = null - pd.set(11, 0); // output_N1M - o_gemm->load_param(pd); - Mat weights[2]; - weights[0] = out_weight_data; - weights[1] = out_bias_data; - o_gemm->load_model(ModelBinFromMatArray(weights)); - o_gemm->create_pipeline(optopt); - - if (optopt.lightmode) - { - out_weight_data.release(); - out_bias_data.release(); - } - } - - { - qk_gemm = ncnn::create_layer(ncnn::LayerType::Gemm); - ncnn::ParamDict pd; - pd.set(2, 1); // transA - pd.set(3, 0); // transB - pd.set(4, 0); // constantA - pd.set(5, 0); // constantB - pd.set(6, attn_mask ? 0 : 1); // constantC - pd.set(7, 0); // M - pd.set(8, 0); // N - pd.set(9, 0); // K - pd.set(10, attn_mask ? 3 : -1); // constant_broadcast_type_C - pd.set(11, 0); // output_N1M - pd.set(12, 1); // output_elempack - qk_gemm->load_param(pd); - qk_gemm->load_model(ModelBinFromMatArray(0)); - Option opt1 = optopt; - opt1.num_threads = 1; - qk_gemm->create_pipeline(opt1); - } - - { - qkv_gemm = ncnn::create_layer(ncnn::LayerType::Gemm); - ncnn::ParamDict pd; - pd.set(2, 0); // transA - pd.set(3, 1); // transB - pd.set(4, 0); // constantA - pd.set(5, 0); // constantB - pd.set(6, 1); // constantC - pd.set(7, 0); // M - pd.set(8, 0); // N - pd.set(9, 0); // K - pd.set(10, -1); // constant_broadcast_type_C - pd.set(11, 0); // output_N1M - pd.set(12, 1); // output_elempack - pd.set(14, 1); // output_transpose - qkv_gemm->load_param(pd); - qkv_gemm->load_model(ModelBinFromMatArray(0)); - Option opt1 = optopt; - opt1.num_threads = 1; - qkv_gemm->create_pipeline(opt1); - } - - return 0; + qk_softmax->create_pipeline(opt); } -#endif - - Option optopt = optn; - optopt.use_bf16_storage = false; - optopt.use_fp16_arithmetic = false; - optopt.use_fp16_packed = false; - optopt.use_fp16_storage = false; { const int embed_dim_per_head = embed_dim / num_heads; @@ -264,9 +82,9 @@ int MultiHeadAttention_arm::create_pipeline(const Option& opt) weights[0] = q_weight_data; weights[1] = q_bias_data; q_gemm->load_model(ModelBinFromMatArray(weights)); - q_gemm->create_pipeline(optopt); + q_gemm->create_pipeline(opt); - if (optopt.lightmode) + if (opt.lightmode) { q_weight_data.release(); q_bias_data.release(); @@ -293,9 +111,9 @@ int MultiHeadAttention_arm::create_pipeline(const Option& opt) weights[0] = k_weight_data; weights[1] = k_bias_data; k_gemm->load_model(ModelBinFromMatArray(weights)); - k_gemm->create_pipeline(optopt); + k_gemm->create_pipeline(opt); - if (optopt.lightmode) + if (opt.lightmode) { k_weight_data.release(); k_bias_data.release(); @@ -322,9 +140,9 @@ int MultiHeadAttention_arm::create_pipeline(const Option& opt) weights[0] = v_weight_data; weights[1] = v_bias_data; v_gemm->load_model(ModelBinFromMatArray(weights)); - v_gemm->create_pipeline(optopt); + v_gemm->create_pipeline(opt); - if (optopt.lightmode) + if (opt.lightmode) { v_weight_data.release(); v_bias_data.release(); @@ -349,9 +167,9 @@ int MultiHeadAttention_arm::create_pipeline(const Option& opt) weights[0] = out_weight_data; weights[1] = out_bias_data; o_gemm->load_model(ModelBinFromMatArray(weights)); - o_gemm->create_pipeline(optopt); + o_gemm->create_pipeline(opt); - if (optopt.lightmode) + if (opt.lightmode) { out_weight_data.release(); out_bias_data.release(); @@ -374,7 +192,7 @@ int MultiHeadAttention_arm::create_pipeline(const Option& opt) pd.set(12, 1); // output_elempack qk_gemm->load_param(pd); qk_gemm->load_model(ModelBinFromMatArray(0)); - Option opt1 = optopt; + Option opt1 = opt; opt1.num_threads = 1; qk_gemm->create_pipeline(opt1); } @@ -396,7 +214,7 @@ int MultiHeadAttention_arm::create_pipeline(const Option& opt) pd.set(14, 1); // output_transpose qkv_gemm->load_param(pd); qkv_gemm->load_model(ModelBinFromMatArray(0)); - Option opt1 = optopt; + Option opt1 = opt; opt1.num_threads = 1; qkv_gemm->create_pipeline(opt1); } @@ -404,119 +222,57 @@ int MultiHeadAttention_arm::create_pipeline(const Option& opt) return 0; } -int MultiHeadAttention_arm::destroy_pipeline(const Option& opt) +int MultiHeadAttention_arm::destroy_pipeline(const Option& _opt) { - Option optn = opt; - optn.use_bf16_storage = false; - - Option opt32 = optn; - opt32.use_bf16_storage = false; - opt32.use_fp16_arithmetic = false; - opt32.use_fp16_packed = false; - opt32.use_fp16_storage = false; + Option opt = _opt; + opt.use_fp16_storage &= support_fp16_storage; + opt.use_bf16_storage &= support_bf16_storage; if (qk_softmax) { - qk_softmax->destroy_pipeline(opt32); + qk_softmax->destroy_pipeline(opt); delete qk_softmax; qk_softmax = 0; } -#if NCNN_ARM82 - if (support_fp16_storage && optn.use_fp16_storage) - { - Option optopt = optn; - - if (q_gemm) - { - q_gemm->destroy_pipeline(optopt); - delete q_gemm; - q_gemm = 0; - } - - if (k_gemm) - { - k_gemm->destroy_pipeline(optopt); - delete k_gemm; - k_gemm = 0; - } - - if (v_gemm) - { - v_gemm->destroy_pipeline(optopt); - delete v_gemm; - v_gemm = 0; - } - - if (o_gemm) - { - o_gemm->destroy_pipeline(optopt); - delete o_gemm; - o_gemm = 0; - } - - if (qk_gemm) - { - qk_gemm->destroy_pipeline(optopt); - delete qk_gemm; - qk_gemm = 0; - } - - if (qkv_gemm) - { - qkv_gemm->destroy_pipeline(optopt); - delete qkv_gemm; - qkv_gemm = 0; - } - - return 0; - } -#endif - - Option optopt = optn; - optopt.use_bf16_storage = false; - optopt.use_fp16_arithmetic = false; - optopt.use_fp16_packed = false; - optopt.use_fp16_storage = false; - if (q_gemm) { - q_gemm->destroy_pipeline(optopt); + q_gemm->destroy_pipeline(opt); delete q_gemm; q_gemm = 0; } if (k_gemm) { - k_gemm->destroy_pipeline(optopt); + k_gemm->destroy_pipeline(opt); delete k_gemm; k_gemm = 0; } if (v_gemm) { - v_gemm->destroy_pipeline(optopt); + v_gemm->destroy_pipeline(opt); delete v_gemm; v_gemm = 0; } if (o_gemm) { - o_gemm->destroy_pipeline(optopt); + o_gemm->destroy_pipeline(opt); delete o_gemm; o_gemm = 0; } if (qk_gemm) { - qk_gemm->destroy_pipeline(optopt); + qk_gemm->destroy_pipeline(opt); delete qk_gemm; qk_gemm = 0; } if (qkv_gemm) { - qkv_gemm->destroy_pipeline(optopt); + qkv_gemm->destroy_pipeline(opt); delete qkv_gemm; qkv_gemm = 0; } @@ -524,13 +280,17 @@ int MultiHeadAttention_arm::destroy_pipeline(const Option& opt) return 0; } -int MultiHeadAttention_arm::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +int MultiHeadAttention_arm::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& _opt) const { const Mat& q_blob = bottom_blobs[0]; const Mat& k_blob = (bottom_blobs.size() == 1 || (bottom_blobs.size() == 2 && attn_mask)) ? q_blob : bottom_blobs[1]; const Mat& v_blob = (bottom_blobs.size() == 1 || (bottom_blobs.size() == 2 && attn_mask)) ? q_blob : (bottom_blobs.size() == 2 || (bottom_blobs.size() == 3 && attn_mask)) ? k_blob : bottom_blobs[2]; const Mat& attn_mask_blob = attn_mask ? bottom_blobs[bottom_blobs.size() - 1] : Mat(); + Option opt = _opt; + opt.use_fp16_storage &= support_fp16_storage; + opt.use_bf16_storage &= support_bf16_storage; + Mat attn_mask_blob_unpacked; if (attn_mask_blob.elempack != 1) { @@ -545,84 +305,18 @@ int MultiHeadAttention_arm::forward(const std::vector& bottom_blobs, std::v const int src_seqlen = q_blob.h * q_blob.elempack; const int dst_seqlen = k_blob.h * k_blob.elempack; - const int elembits = q_blob.elembits(); - - Option optn = opt; - optn.use_bf16_storage = false; + // const int elembits = q_blob.elembits(); - Option opt32 = optn; - opt32.use_bf16_storage = false; - opt32.use_fp16_arithmetic = false; - opt32.use_fp16_packed = false; - opt32.use_fp16_storage = false; - -#if NCNN_ARM82 - if (support_fp16_storage && optn.use_fp16_storage && elembits == 16) - { - // TODO implement true fp16s with gemm output_elemtype fp32 - Mat q_affine; - q_gemm->forward(q_blob, q_affine, optn); - - Mat k_affine; - k_gemm->forward(k_blob, k_affine, optn); - - Mat qk_cross(dst_seqlen, src_seqlen * num_heads, 2u, optn.blob_allocator); - #pragma omp parallel for num_threads(optn.num_threads) - for (int i = 0; i < num_heads; i++) - { - std::vector qk_bottom_blobs(2); - qk_bottom_blobs[0] = q_affine.row_range(i * embed_dim_per_head, embed_dim_per_head); - qk_bottom_blobs[1] = k_affine.row_range(i * embed_dim_per_head, embed_dim_per_head); - if (attn_mask) - { - const Mat& maskm = attn_mask_blob_unpacked.dims == 3 ? attn_mask_blob_unpacked.channel(i) : attn_mask_blob_unpacked; - qk_bottom_blobs.push_back(maskm); - } - std::vector qk_top_blobs(1); - qk_top_blobs[0] = qk_cross.row_range(i * src_seqlen, src_seqlen); - Option opt1 = optn; - opt1.num_threads = 1; - qk_gemm->forward(qk_bottom_blobs, qk_top_blobs, opt1); - } - - q_affine.release(); - k_affine.release(); - - qk_softmax->forward_inplace(qk_cross, optn); - - Mat v_affine; - v_gemm->forward(v_blob, v_affine, optn); - - Mat qkv_cross(src_seqlen, embed_dim_per_head * num_heads, 2u, optn.blob_allocator); - #pragma omp parallel for num_threads(optn.num_threads) - for (int i = 0; i < num_heads; i++) - { - std::vector qkv_bottom_blobs(2); - qkv_bottom_blobs[0] = qk_cross.row_range(i * src_seqlen, src_seqlen); - qkv_bottom_blobs[1] = v_affine.row_range(i * embed_dim_per_head, embed_dim_per_head); - std::vector qkv_top_blobs(1); - qkv_top_blobs[0] = qkv_cross.row_range(i * embed_dim_per_head, embed_dim_per_head); - Option opt1 = optn; - opt1.num_threads = 1; - qkv_gemm->forward(qkv_bottom_blobs, qkv_top_blobs, opt1); - } - - v_affine.release(); - - o_gemm->forward(qkv_cross, top_blobs[0], optn); - - return 0; - } -#endif + size_t elemsize = q_blob.elemsize / q_blob.elempack; Mat q_affine; - q_gemm->forward(q_blob, q_affine, opt32); + q_gemm->forward(q_blob, q_affine, opt); Mat k_affine; - k_gemm->forward(k_blob, k_affine, opt32); + k_gemm->forward(k_blob, k_affine, opt); - Mat qk_cross(dst_seqlen, src_seqlen * num_heads, 4u, opt32.blob_allocator); - #pragma omp parallel for num_threads(opt32.num_threads) + Mat qk_cross(dst_seqlen, src_seqlen * num_heads, elemsize, opt.blob_allocator); + #pragma omp parallel for num_threads(opt.num_threads) for (int i = 0; i < num_heads; i++) { std::vector qk_bottom_blobs(2); @@ -635,7 +329,7 @@ int MultiHeadAttention_arm::forward(const std::vector& bottom_blobs, std::v } std::vector qk_top_blobs(1); qk_top_blobs[0] = qk_cross.row_range(i * src_seqlen, src_seqlen); - Option opt1 = opt32; + Option opt1 = opt; opt1.num_threads = 1; qk_gemm->forward(qk_bottom_blobs, qk_top_blobs, opt1); } @@ -643,13 +337,13 @@ int MultiHeadAttention_arm::forward(const std::vector& bottom_blobs, std::v q_affine.release(); k_affine.release(); - qk_softmax->forward_inplace(qk_cross, opt32); + qk_softmax->forward_inplace(qk_cross, opt); Mat v_affine; - v_gemm->forward(v_blob, v_affine, opt32); + v_gemm->forward(v_blob, v_affine, opt); - Mat qkv_cross(src_seqlen, embed_dim_per_head * num_heads, 4u, opt32.blob_allocator); - #pragma omp parallel for num_threads(opt32.num_threads) + Mat qkv_cross(src_seqlen, embed_dim_per_head * num_heads, elemsize, opt.blob_allocator); + #pragma omp parallel for num_threads(opt.num_threads) for (int i = 0; i < num_heads; i++) { std::vector qkv_bottom_blobs(2); @@ -657,14 +351,14 @@ int MultiHeadAttention_arm::forward(const std::vector& bottom_blobs, std::v qkv_bottom_blobs[1] = v_affine.row_range(i * embed_dim_per_head, embed_dim_per_head); std::vector qkv_top_blobs(1); qkv_top_blobs[0] = qkv_cross.row_range(i * embed_dim_per_head, embed_dim_per_head); - Option opt1 = opt32; + Option opt1 = opt; opt1.num_threads = 1; qkv_gemm->forward(qkv_bottom_blobs, qkv_top_blobs, opt1); } v_affine.release(); - o_gemm->forward(qkv_cross, top_blobs[0], opt32); + o_gemm->forward(qkv_cross, top_blobs[0], opt); return 0; }