diff --git a/lite/kernels/xpu/__xpu__mmdnn_compute.cc b/lite/kernels/xpu/__xpu__mmdnn_compute.cc index f2fadf141ce..cc01df9d899 100644 --- a/lite/kernels/xpu/__xpu__mmdnn_compute.cc +++ b/lite/kernels/xpu/__xpu__mmdnn_compute.cc @@ -482,7 +482,13 @@ class MMDNNMatchConvTopk { int dim_in_; int out_channel_; - MMDNNFcOp xw_fc_; + const int16_t* match_weight_{nullptr}; + XPUScratchPadGuard match_weight_max_guard_; + float* match_weight_max_{nullptr}; + XPUScratchPadGuard in_max_guard_; + float* in_max_{nullptr}; + XPUScratchPadGuard out_max_guard_; + float* out_max_{nullptr}; const int16_t* conv_weight_{nullptr}; float conv_weight_max_; XPUScratchPadGuard hbm_buffer_guard_; @@ -525,12 +531,17 @@ class MMDNNMatchConvTopk { out_channel_ = out_channel; topks_ = topks; - xw_fc_.Init(input_w, - input_w_max, - nullptr, - dim_t_ * dim_in_, - dim_in_, - xdnn::Activation_t::LINEAR); + match_weight_ = input_w->data(); + match_weight_max_guard_ = + TargetWrapperXPU::MallocScratchPad(4 * sizeof(float)); + match_weight_max_ = + reinterpret_cast(match_weight_max_guard_->addr_); + FillMax(input_w_max, match_weight_max_); + in_max_guard_ = TargetWrapperXPU::MallocScratchPad(4 * sizeof(float)); + in_max_ = reinterpret_cast(in_max_guard_->addr_); + out_max_guard_ = TargetWrapperXPU::MallocScratchPad(4 * sizeof(float)); + out_max_ = reinterpret_cast(out_max_guard_->addr_); + conv_weight_ = conv_w->data(); conv_weight_max_ = conv_w_max; @@ -644,21 +655,30 @@ class MMDNNMatchConvTopk { } seq_avg_topk_out = out->mutable_data(TARGET(kXPU)); - int max_width = std::max(left_seqlen_max, right_seqlen_max); - xw_fc_.Infer(ctx, left->data(), left_seqlen_sum, xw_out); int r = 0; - r = xdnn::match_matrix_tensor(ctx, - batch, - xw_out, - right->data(), - left_lod_32_, - right_lod_32_, - dim_t_, - dim_in_, - xwy_out, - xw_fc_.out_max, - xdnn::Activation_t::RELU, - max_width); + r = xdnn::findmax( + ctx, left->data(), left_seqlen_sum * dim_in_, in_max_); + CHECK_EQ(r, 0); + r = xdnn::match_matrix_tensor( + ctx, + left->data(), + right->data(), + match_weight_, + xwy_out, + dim_in_, + dim_t_, + true, + {left_lod_32_cpu.data(), + static_cast(left_lod_32_cpu.size()), + left_lod_32_}, + {right_lod_32_cpu.data(), + static_cast(right_lod_32_cpu.size()), + right_lod_32_}, + in_max_, + nullptr, + match_weight_max_, + xdnn::Activation_t::RELU, + xw_out); CHECK_EQ(r, 0); r = xdnn::search_varconv( ctx, diff --git a/lite/kernels/xpu/match_matrix_tensor_compute.cc b/lite/kernels/xpu/match_matrix_tensor_compute.cc index 727338096e0..993f94b7ed1 100644 --- a/lite/kernels/xpu/match_matrix_tensor_compute.cc +++ b/lite/kernels/xpu/match_matrix_tensor_compute.cc @@ -23,8 +23,16 @@ namespace kernels { namespace xpu { void MatchMatrixTensorCompute::PrepareForRun() { - wx_max_xpu_guard_ = - TargetWrapperXPU::MallocScratchPad(XPU_MAX_LOD_SIZE * sizeof(int)); + auto& param = this->Param(); + float w_max = param.__xpu__w_max; + std::vector w_max_v(XPU_QUANT_SCALE_NUM, w_max); + weight_max_xpu_guard_ = + TargetWrapperXPU::MallocScratchPad(XPU_QUANT_SCALE_NUM * sizeof(float)); + XPU_CALL(xpu_memcpy(reinterpret_cast(weight_max_xpu_guard_->addr_), + w_max_v.data(), + XPU_QUANT_SCALE_NUM * sizeof(float), + XPUMemcpyKind::XPU_HOST_TO_DEVICE)); + offset_l_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(XPU_MAX_LOD_SIZE * sizeof(int)); offset_r_xpu_guard_ = @@ -44,7 +52,6 @@ void MatchMatrixTensorCompute::Run() { auto* out = param.out; auto* tmp = param.tmp; int dim_t = param.dim_t; - float w_max = param.__xpu__w_max; bool fuse_relu = param.fuse_relu; bool float_to_fix = param.__xpu__float_to_fix; CHECK(float_to_fix) << "W should be fixed point"; @@ -74,44 +81,15 @@ void MatchMatrixTensorCompute::Run() { auto* bottom_l_trans_data = tmp->mutable_data(TARGET(kXPU)); int batch_size = x->lod()[0].size() - 1; - float* wx_max = reinterpret_cast(wx_max_xpu_guard_->addr_); + float* w_max = reinterpret_cast(weight_max_xpu_guard_->addr_); int* offset_l_xpu = reinterpret_cast(offset_l_xpu_guard_->addr_); int* offset_r_xpu = reinterpret_cast(offset_r_xpu_guard_->addr_); - int r = xdnn::gemm_int16_tmp_api( - ctx.GetRawContext(), /* ctx */ - false, /* trans_a */ - false, /* trans_b */ - x->dims()[0], /* m */ - dim_t * dim_in, /* n */ - dim_in, /* k */ - 1.0f, /* alpha */ - bottom_l_data, /* data_a */ - dim_in, /* lda */ - w_data, /* data_b */ - dim_t * dim_in, /* ldb */ - 0.0f, /* beta */ - bottom_l_trans_data, /* data_c */ - dim_t * dim_in, /* ldc */ - nullptr, /* bias */ - xdnn::Activation_t::LINEAR, /* act */ - 0.0f, /* max_a */ - w_max, /* max_b */ - wx_max /* max_c */); - CHECK_EQ(r, 0); - - int max_width = 0; for (int i = 0; i < offset_l.size(); ++i) { offset_l_cpu[i] = offset_l[i]; - if (i != 0 && (offset_l_cpu[i] - offset_l_cpu[i - 1] > max_width)) { - max_width = offset_l_cpu[i] - offset_l_cpu[i - 1]; - } } for (int i = 0; i < offset_r.size(); ++i) { offset_r_cpu[i] = offset_r[i]; - if (i != 0 && (offset_r_cpu[i] - offset_r_cpu[i - 1] > max_width)) { - max_width = offset_r_cpu[i] - offset_r_cpu[i - 1]; - } } XPU_CALL(xpu_memcpy(offset_l_xpu, offset_l_cpu.get(), @@ -122,20 +100,23 @@ void MatchMatrixTensorCompute::Run() { offset_r.size() * sizeof(int), XPUMemcpyKind::XPU_HOST_TO_DEVICE)); - r = xdnn::match_matrix_tensor(ctx.GetRawContext(), - batch_size, - bottom_l_trans_data, - bottom_r_data, - offset_l_xpu, - offset_r_xpu, - dim_t, - dim_in, - out_data, - wx_max, - act, - max_width); + int r = xdnn::match_matrix_tensor( + ctx.GetRawContext(), + bottom_l_data, + bottom_r_data, + w_data, + out_data, + dim_in, + dim_t, + true, // the weight is trans in XPUMmdnnFloat2Fix + {offset_l_cpu.get(), static_cast(offset_l.size()), offset_l_xpu}, + {offset_r_cpu.get(), static_cast(offset_r.size()), offset_r_xpu}, + nullptr, + nullptr, + w_max, + act, + bottom_l_trans_data); CHECK_EQ(r, 0); - int lod_lv1_size = batch_size * dim_t; int lod_lv2_size = x->lod()[0].back() * dim_t; std::vector out_lod0(batch_size + 1, 0); diff --git a/lite/kernels/xpu/match_matrix_tensor_compute.h b/lite/kernels/xpu/match_matrix_tensor_compute.h index 3bd0b622db1..49831a41880 100644 --- a/lite/kernels/xpu/match_matrix_tensor_compute.h +++ b/lite/kernels/xpu/match_matrix_tensor_compute.h @@ -33,7 +33,7 @@ class MatchMatrixTensorCompute virtual void Run(); private: - XPUScratchPadGuard wx_max_xpu_guard_; + XPUScratchPadGuard weight_max_xpu_guard_; XPUScratchPadGuard offset_l_xpu_guard_; XPUScratchPadGuard offset_r_xpu_guard_;