diff --git a/lite/kernels/xpu/__xpu__mmdnn_compute.cc b/lite/kernels/xpu/__xpu__mmdnn_compute.cc
index e2feeae7b23..7dc0468320c 100644
--- a/lite/kernels/xpu/__xpu__mmdnn_compute.cc
+++ b/lite/kernels/xpu/__xpu__mmdnn_compute.cc
@@ -34,53 +34,6 @@ void FillMax(float max, float* xpu_ptr) {
                       XPU_QUANT_SCALE_NUM * sizeof(float),
                       XPUMemcpyKind::XPU_HOST_TO_DEVICE));
 }
-
-void GrnnLayout(int batch,
-                const std::vector<int>& offset,
-                std::vector<int>* new_offset_ptr,
-                std::vector<int>* idx_sorted_ptr) {
-  auto& new_offset = *new_offset_ptr;
-  auto& idx_sorted = *idx_sorted_ptr;
-
-  std::vector<int> width;
-  width.resize(batch);
-  new_offset.clear();
-  idx_sorted.clear();
-
-  idx_sorted.resize(batch);
-  for (int i = 0; i < batch; i++) {
-    width[i] = offset[i + 1] - offset[i];
-    idx_sorted[i] = i;
-  }
-  std::sort(idx_sorted.data(),
-            idx_sorted.data() + batch,
-            [&width](int a, int b) { return width[a] > width[b]; });
-  int max_width = width[idx_sorted[0]];
-  new_offset.resize(max_width + 1);
-  new_offset[0] = 0;
-  int j = batch - 1;
-  int last_width = 0;
-  int sub_row = 0;
-  int sub_col = 0;
-
-  for (int i = 1; i <= max_width;) {
-    for (int k = j; k >= 0; --k) {
-      if (width[idx_sorted[k]] > last_width) {
-        sub_row = width[idx_sorted[k]] - last_width;
-        sub_col = k + 1;
-        for (int s = 0; s < sub_row; s++) {
-          new_offset[i] = new_offset[i - 1] + sub_col;
-          i++;
-        }
-        // move on
-        last_width = width[idx_sorted[k]];
-        j = k - 1;
-        break;
-      }
-    }
-  }
-}
-
 }  // anonymous namespace
 
 class MMDNNIdInfo {
@@ -90,38 +43,29 @@ class MMDNNIdInfo {
   char* cpu_buffer_{nullptr};
 
   std::vector<int> lod;
-  std::vector<int> new_offset;
-  std::vector<int> idx_sorted;
+  std::vector<int> seqlen_list;
+  std::vector<int64_t> lod_64;
   std::vector<int64_t> id0_64_cpu;
   std::vector<int64_t> id1_64_cpu;
-  std::vector<int> seqlen_list;
 
  public:
-  int64_t* id0_64{nullptr};
-  int64_t* id1_64{nullptr};
-  int64_t* lod_64{nullptr};
   int* lod_32{nullptr};
-  int* new_offset_32{nullptr};
-  int* idx_sorted_32{nullptr};
-
   int batch;
   int seqlen_max;
   int seqlen_sum;
   int seqlen_square_sum;
 
-  xdnn::VectorParam<int64_t> id0_64_vector;
-  xdnn::VectorParam<int64_t> id1_64_vector;
   xdnn::VectorParam<int> lod_32_vector;
   xdnn::VectorParam<int> seqlen_list_vector;
+  xdnn::VectorParam<int64_t> id0_64_vector;
+  xdnn::VectorParam<int64_t> id1_64_vector;
+  xdnn::VectorParam<int64_t> lod_64_vector;
 
   void Init(int upper_bound_batch, int upper_bound_seqlen) {
     int ub_lod_64_size = (upper_bound_batch + 1) * sizeof(int64_t);
     int ub_lod_32_size = (upper_bound_batch + 1) * sizeof(int);
-    int ub_new_offset_32_size = (upper_bound_seqlen + 1) * sizeof(int);
-    int ub_idx_sorted_32_size = (upper_bound_batch + 1) * sizeof(int);
     int ub_seqlen_list_size = upper_bound_batch * sizeof(int);
-    int total_size = ub_lod_64_size + ub_lod_32_size + ub_new_offset_32_size +
-                     ub_idx_sorted_32_size + ub_seqlen_list_size;
+    int total_size = ub_lod_64_size + ub_lod_32_size + ub_seqlen_list_size;
 
     // TODO(miaotianxiang): use l3?
     l3_buffer_guard_ = TargetWrapperXPU::MallocScratchPad(total_size);
@@ -134,8 +78,10 @@ class MMDNNIdInfo {
     auto& id0_lod = id0->lod()[0];
     int idx_len = id0_lod.back();
     lod.clear();
+    lod_64.clear();
     for (auto e : id0_lod) {
       lod.push_back(e);
+      lod_64.push_back(e);
     }
 
     seqlen_max = 0;
@@ -150,10 +96,9 @@ class MMDNNIdInfo {
       seqlen_square_sum = seqlen_square_sum + seqlen * seqlen;
       seqlen_list[i] = seqlen;
     }
-    GrnnLayout(batch, lod, &new_offset, &idx_sorted);
 
-    id0_64 = const_cast<int64_t*>(id0->data<int64_t>());
-    id1_64 = const_cast<int64_t*>(id1->data<int64_t>());
+    int64_t* id0_64 = const_cast<int64_t*>(id0->data<int64_t>());
+    int64_t* id1_64 = const_cast<int64_t*>(id1->data<int64_t>());
     id0_64_cpu.resize(idx_len);
     id1_64_cpu.resize(idx_len);
     XPU_CALL(xpu_memcpy(id0_64_cpu.data(),
@@ -170,7 +115,10 @@ class MMDNNIdInfo {
         xdnn::VectorParam<int64_t>{id1_64_cpu.data(), idx_len, id1_64};
 
     int offset = 0;
-    lod_64 = reinterpret_cast<int64_t*>(l3_buffer_ + offset);
+    lod_64_vector = xdnn::VectorParam<int64_t>{
+        lod_64.data(),
+        static_cast<int>(lod_64.size()),
+        reinterpret_cast<int64_t*>(l3_buffer_ + offset)};
     memcpy(
         cpu_buffer_ + offset, id0_lod.data(), id0_lod.size() * sizeof(int64_t));
     offset += id0_lod.size() * sizeof(int64_t);
@@ -179,16 +127,6 @@ class MMDNNIdInfo {
         lod.data(), static_cast<int>(lod.size()), lod_32};
     memcpy(cpu_buffer_ + offset, lod.data(), lod.size() * sizeof(int));
     offset += lod.size() * sizeof(int);
-    new_offset_32 = reinterpret_cast<int*>(l3_buffer_ + offset);
-    memcpy(cpu_buffer_ + offset,
-           new_offset.data(),
-           new_offset.size() * sizeof(int));
-    offset += new_offset.size() * sizeof(int);
-    idx_sorted_32 = reinterpret_cast<int*>(l3_buffer_ + offset);
-    memcpy(cpu_buffer_ + offset,
-           idx_sorted.data(),
-           idx_sorted.size() * sizeof(int));
-    offset += idx_sorted.size() * sizeof(int);
     seqlen_list_vector = xdnn::VectorParam<int>{
         seqlen_list.data(), batch, reinterpret_cast<int*>(l3_buffer_ + offset)};
     memcpy(cpu_buffer_ + offset,
@@ -291,22 +229,14 @@ class MMDNNFcOp {
 };
 
 class MMDNNGrnnOp {
-  MMDNNFcOp fc_e2h0_;
-  MMDNNFcOp fc_e2h1_;
-  MMDNNFcOp fc_e2h2_;
+  const int16_t* dense_e2h_{nullptr};
   const int16_t* dense_h2h_{nullptr};
-  float dense_h2h_max_[3];
-  XPUScratchPadGuard input_max_guard_;
-  float* input_max_{nullptr};
-  XPUScratchPadGuard hbm_buffer_guard_;
-  float* hbm_buffer_{nullptr};
-  // require: cap_l * max(cap_e_, cap_h_) * 5
-  // seq2batch_out: [cap_l, cap_e_]
-  // fc_e2h_out: [3, cap_l, cap_h_]
-  // gru_out: [cap_l, cap_h_]
+  XPUScratchPadGuard weight_x_max_guard_;
+  float* weight_x_max_{nullptr};
+  XPUScratchPadGuard weight_w_max_guard_;
+  float* weight_w_max_{nullptr};
   int cap_e_;
   int cap_h_;
-  int max_cap_l_;
 
  public:
   void Init(lite::Tensor* wh,
@@ -314,108 +244,53 @@ class MMDNNGrnnOp {
             lite::Tensor* wi,
             const std::vector<float>& wi_maxs,
             int cap_e,
-            int cap_h,
-            int max_cap_l) {
+            int cap_h) {
     cap_e_ = cap_e;
     cap_h_ = cap_h;
-    max_cap_l_ = max_cap_l;
-
-    // weight
-    auto* dense_e2h = wi->data<int16_t>();
-    fc_e2h0_.Init(dense_e2h,
-                  wi_maxs[0],
-                  nullptr,
-                  cap_h_,
-                  cap_e_,
-                  xdnn::Activation_t::LINEAR);
-    fc_e2h1_.Init(dense_e2h + cap_e_ * cap_h_,
-                  wi_maxs[1],
-                  nullptr,
-                  cap_h_,
-                  cap_e_,
-                  xdnn::Activation_t::LINEAR);
-    fc_e2h2_.Init(dense_e2h + cap_e_ * cap_h_ * 2,
-                  wi_maxs[2],
-                  nullptr,
-                  cap_h_,
-                  cap_e_,
-                  xdnn::Activation_t::LINEAR);
-
     dense_h2h_ = wh->data<int16_t>();
-    dense_h2h_max_[0] = wh_maxs[0];
-    dense_h2h_max_[1] = wh_maxs[1];
-    dense_h2h_max_[2] = wh_maxs[2];
-
-    input_max_guard_ =
-        TargetWrapperXPU::MallocScratchPad(XPU_QUANT_SCALE_NUM * sizeof(float));
-    input_max_ = reinterpret_cast<float*>(input_max_guard_->addr_);
-    hbm_buffer_guard_ = TargetWrapperXPU::MallocScratchPad(
-        5 * std::max(cap_e_, cap_h_) * max_cap_l_ * sizeof(float));
-    hbm_buffer_ = reinterpret_cast<float*>(hbm_buffer_guard_->addr_);
+    dense_e2h_ = wi->data<int16_t>();
+    weight_x_max_guard_ = TargetWrapperXPU::MallocScratchPad(
+        XPU_QUANT_SCALE_NUM * 3 * sizeof(float));
+    weight_x_max_ = reinterpret_cast<float*>(weight_x_max_guard_->addr_);
+    FillMax(wi_maxs[0], weight_x_max_);
+    FillMax(wi_maxs[1], weight_x_max_ + XPU_QUANT_SCALE_NUM);
+    FillMax(wi_maxs[2], weight_x_max_ + XPU_QUANT_SCALE_NUM * 2);
+    weight_w_max_guard_ = TargetWrapperXPU::MallocScratchPad(
+        XPU_QUANT_SCALE_NUM * 3 * sizeof(float));
+    weight_w_max_ = reinterpret_cast<float*>(weight_w_max_guard_->addr_);
+    FillMax(wh_maxs[0], weight_w_max_);
+    FillMax(wh_maxs[1], weight_w_max_ + XPU_QUANT_SCALE_NUM);
+    FillMax(wh_maxs[2], weight_w_max_ + XPU_QUANT_SCALE_NUM * 2);
   }
 
   void Infer(xdnn::Context* ctx,
             const MMDNNIdInfo& sentense,
             const float* in,
-             float* out,
-             float* l3_buffer = nullptr,
-             int l3_size = 0) {
-    int batch = sentense.batch;
-    int cap_l = sentense.seqlen_sum;
-    int max_width = sentense.seqlen_max;
-
-    int slot_size = cap_l * std::max(cap_e_, cap_h_);
-    float* seq2batch_out = hbm_buffer_;
-    float* fc_e2h_out = hbm_buffer_ + 1 * slot_size;
-    float* gru_out = hbm_buffer_ + 4 * slot_size;
-    if (l3_size > 0 && l3_size >= 5 * slot_size * sizeof(float)) {
-      seq2batch_out = l3_buffer;
-      fc_e2h_out = l3_buffer + 1 * slot_size;
-      gru_out = l3_buffer + 4 * slot_size;
-    }
-
+             float* out) {
     int r = 0;
-    r = xdnn::search_seq2batch(ctx,
-                               batch,
-                               max_width,
-                               cap_e_,
-                               sentense.idx_sorted_32,
-                               sentense.lod_32,
-                               sentense.new_offset_32,
-                               in,
-                               seq2batch_out);
-    CHECK_EQ(r, 0);
-
-    r = xdnn::findmax(ctx, in, cap_l * cap_e_, input_max_);
-    CHECK_EQ(r, 0);
-    fc_e2h0_.Infer(ctx, seq2batch_out, cap_l, fc_e2h_out, input_max_);
-    fc_e2h1_.Infer(
-        ctx, seq2batch_out, cap_l, fc_e2h_out + cap_l * cap_h_, input_max_);
-    fc_e2h2_.Infer(
-        ctx, seq2batch_out, cap_l, fc_e2h_out + cap_l * cap_h_ * 2, input_max_);
-    r = xdnn::search_grnn(ctx,
-                          cap_l,
-                          cap_h_,
-                          cap_e_,
-                          max_width,
-                          sentense.new_offset_32,
-                          fc_e2h_out,
-                          dense_h2h_,
-                          gru_out,
-                          dense_h2h_max_[0],
-                          dense_h2h_max_[1],
-                          dense_h2h_max_[2]);
-    CHECK_EQ(r, 0);
-
-    r = xdnn::search_batch2seq(ctx,
-                               batch,
-                               max_width,
-                               cap_h_,
-                               sentense.idx_sorted_32,
-                               sentense.lod_32,
-                               sentense.new_offset_32,
-                               gru_out,
-                               out);
+    r = xdnn::grnn_cell(
+        ctx,
+        in,
+        nullptr,
+        {dense_e2h_,
+         dense_e2h_ + cap_e_ * cap_h_,
+         dense_e2h_ + cap_e_ * cap_h_ * 2},
+        {dense_h2h_,
+         dense_h2h_ + cap_h_ * cap_h_,
+         dense_h2h_ + cap_h_ * cap_h_ * 2},
+        out,
+        cap_e_,
+        cap_h_,
+        sentense.lod_32_vector,
+        nullptr,
+        nullptr,
+        {weight_x_max_,
+         weight_x_max_ + XPU_QUANT_SCALE_NUM,
+         weight_x_max_ + XPU_QUANT_SCALE_NUM * 2},
+        {weight_w_max_,
+         weight_w_max_ + XPU_QUANT_SCALE_NUM,
+         weight_w_max_ + XPU_QUANT_SCALE_NUM * 2},
+        nullptr);
     CHECK_EQ(r, 0);
   }
 };
@@ -824,10 +699,8 @@ class MMDNNBidEmbGrnnAtt {
     cap_h_ = emb_dim_;
     int max_cap_l = upper_bound_batch * upper_bound_seqlen;
 
-    bi_fw_.Init(
-        fw_wh, fw_wh_maxs, fw_wi, fw_wi_maxs, emb_dim_, cap_h_, max_cap_l);
-    bi_rv_.Init(
-        rv_wh, rv_wh_maxs, rv_wi, rv_wi_maxs, emb_dim_, cap_h_, max_cap_l);
+    bi_fw_.Init(fw_wh, fw_wh_maxs, fw_wi, fw_wi_maxs, emb_dim_, cap_h_);
+    bi_rv_.Init(rv_wh, rv_wh_maxs, rv_wi, rv_wi_maxs, emb_dim_, cap_h_);
     att_.Init(att_fc_w,
               att_fc_w_max,
               att_fc_b,
@@ -871,25 +744,18 @@ class MMDNNBidEmbGrnnAtt {
     att_out = att_pool_out->mutable_data<float>(TARGET(kXPU));
 
     int r = 0;
-    r = xdnn::search_bid_emb_ew(ctx,
-                                batch,
-                                sentense.lod_64,
-                                sentense.id0_64,
-                                sentense.id1_64,
-                                table_,
-                                table_len_,
-                                emb_dim_,
-                                emb_fw,
-                                emb_rv,
-                                table_len_ - 2,
-                                1);
+    r = xdnn::bidirection_embedding_add(ctx,
+                                        table_,
+                                        emb_fw,
+                                        emb_rv,
+                                        sentense.lod_64_vector,
+                                        sentense.id0_64_vector,
+                                        sentense.id1_64_vector,
+                                        table_len_,
+                                        emb_dim_,
+                                        table_len_ - 2);
     CHECK_EQ(r, 0);
-    bi_rv_.Infer(ctx,
-                 sentense,
-                 emb_rv,
-                 grnn_rv,
-                 l3_buffer + 2 * slot_len,
-                 l3_size - 2 * slot_len * sizeof(float));
+    bi_rv_.Infer(ctx, sentense, emb_rv, grnn_rv);
     r = xdnn::sequence_reverse(
         ctx, grnn_rv, sentense.lod_32, grnn_rv_rv, batch, cap_h_);
     CHECK_EQ(r, 0);
@@ -897,12 +763,7 @@ class MMDNNBidEmbGrnnAtt {
         ctx, grnn_rv, pool_rv, sentense.lod_32_vector, batch, cap_h_, 0);
     CHECK_EQ(r, 0);
 
-    bi_fw_.Infer(ctx,
-                 sentense,
-                 emb_fw,
-                 grnn_fw,
-                 l3_buffer + 2 * slot_len,
-                 l3_size - 2 * slot_len * sizeof(float));
+    bi_fw_.Infer(ctx, sentense, emb_fw, grnn_fw);
     r = xdnn::sequence_last_pool(
         ctx, grnn_fw, pool_fw, sentense.lod_32_vector, batch, cap_h_, 0);
     CHECK_EQ(r, 0);
@@ -1034,15 +895,13 @@ class MMDNNMergeAll {
                       grnn_fw_wi,
                       grnn_fw_wi_maxs,
                       cap_e_,
-                      cap_h_,
-                      max_cap_l);
+                      cap_h_);
     coverage_rv_.Init(grnn_rv_wh,
                       grnn_rv_wh_maxs,
                       grnn_rv_wi,
                       grnn_rv_wi_maxs,
                       cap_e_,
-                      cap_h_,
-                      max_cap_l);
+                      cap_h_);
 
     fc0_.Init(
         fc0_w, fc0_w_max, fc0_b, fc0_n_, fc0_k_, xdnn::Activation_t::RELU);
@@ -1109,18 +968,8 @@ class MMDNNMergeAll {
                        batch,
                        cap_e_);
     CHECK_EQ(r, 0);
-    coverage_fw_.Infer(ctx,
-                       sentense,
-                       topk_concat_out_fw,
-                       grnn_fw,
-                       l3_buffer + hbm_total_len,
-                       l3_size - hbm_total_len * sizeof(float));
-    coverage_rv_.Infer(ctx,
-                       sentense,
-                       topk_concat_out_rv,
-                       grnn_rv,
-                       l3_buffer + hbm_total_len,
-                       l3_size - hbm_total_len * sizeof(float));
+    coverage_fw_.Infer(ctx, sentense, topk_concat_out_fw, grnn_fw);
+    coverage_rv_.Infer(ctx, sentense, topk_concat_out_rv, grnn_rv);
     r = xdnn::sequence_last_pool(
         ctx, grnn_fw, pool_fw, sentense.lod_32_vector, batch, cap_h_, 0);
     CHECK_EQ(r, 0);
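
Note on the xdnn::VectorParam bookkeeping introduced above: as used in MMDNNIdInfo::Update, it is aggregate-initialized as {host_ptr, length, device_ptr}, i.e. a host copy of the LoD/seqlen metadata plus its device-side mirror in the L3/HBM scratch pad. Passing that pair directly to the fused kernels (xdnn::grnn_cell, xdnn::bidirection_embedding_add) is what makes the hand-rolled GrnnLayout / search_seq2batch / search_batch2seq bookkeeping removable. A minimal sketch of the mirroring pattern follows; HostDeviceVector is a hypothetical stand-in for xdnn::VectorParam<T>, and ordinary heap memory stands in for the XPU scratch pad and xpu_memcpy.

// Sketch only: HostDeviceVector is a hypothetical stand-in for
// xdnn::VectorParam<T>, which the diff initializes as {host_ptr, len, xpu_ptr}.
#include <cstdint>
#include <cstring>
#include <vector>

template <typename T>
struct HostDeviceVector {
  const T* cpu;  // host copy of the metadata
  int len;       // number of elements
  T* xpu;        // device-side mirror of the same data
};

// Mirror LoD offsets (uint64_t in the tensor) into an int64_t host copy plus
// a "device" buffer, the same bookkeeping MMDNNIdInfo::Update performs when it
// builds lod_64_vector; the real kernel copies into an L3/HBM scratch pad.
inline HostDeviceVector<int64_t> MakeLodParam(
    const std::vector<uint64_t>& id0_lod,
    std::vector<int64_t>* host_copy,
    int64_t* device_buffer) {
  host_copy->assign(id0_lod.begin(), id0_lod.end());
  std::memcpy(device_buffer,
              host_copy->data(),
              host_copy->size() * sizeof(int64_t));
  return {host_copy->data(),
          static_cast<int>(host_copy->size()),
          device_buffer};
}

Keeping both views around means the CPU-side setup code can keep reading the offsets (e.g. to compute seqlen_max/seqlen_sum) while the kernels consume the device copy, without an extra transfer at call time.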