Skip to content

Commit

Permalink
[XPU] change match_matrix_tensor op from old version to refector veri…
Browse files Browse the repository at this point in the history
…son, test=develop, test=xpu
  • Loading branch information
shanliang1992 committed Sep 22, 2021
1 parent 85b8a9b commit 7e99266
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 68 deletions.
1 change: 1 addition & 0 deletions lite/backends/xpu/target_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ namespace lite {
const int XPU_MAX_LOD_SIZE = 32;
// MAX(lod[i + 1] - lod[i]) = 512
const int XPU_MAX_LOD_SEQ_LEN = 512;
const int XPU_MAXPTR_SIZE = 6;

using TargetWrapperXPU = TargetWrapper<TARGET(kXPU)>;

Expand Down
62 changes: 41 additions & 21 deletions lite/kernels/xpu/__xpu__mmdnn_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,13 @@ class MMDNNMatchConvTopk {
int dim_in_;
int out_channel_;

MMDNNFcOp xw_fc_;
const int16_t* match_weight_{nullptr};
XPUScratchPadGuard match_weight_max_guard_;
float* match_weight_max_{nullptr};
XPUScratchPadGuard in_max_guard_;
float* in_max_{nullptr};
XPUScratchPadGuard out_max_guard_;
float* out_max_{nullptr};
const int16_t* conv_weight_{nullptr};
float conv_weight_max_;
XPUScratchPadGuard hbm_buffer_guard_;
Expand Down Expand Up @@ -525,12 +531,17 @@ class MMDNNMatchConvTopk {
out_channel_ = out_channel;
topks_ = topks;

xw_fc_.Init(input_w,
input_w_max,
nullptr,
dim_t_ * dim_in_,
dim_in_,
xdnn::Activation_t::LINEAR);
match_weight_ = input_w->data<int16_t>();
match_weight_max_guard_ =
TargetWrapperXPU::MallocScratchPad(4 * sizeof(float));
match_weight_max_ =
reinterpret_cast<float*>(match_weight_max_guard_->addr_);
FillMax(input_w_max, match_weight_max_);
in_max_guard_ = TargetWrapperXPU::MallocScratchPad(4 * sizeof(float));
in_max_ = reinterpret_cast<float*>(in_max_guard_->addr_);
out_max_guard_ = TargetWrapperXPU::MallocScratchPad(4 * sizeof(float));
out_max_ = reinterpret_cast<float*>(out_max_guard_->addr_);

conv_weight_ = conv_w->data<int16_t>();
conv_weight_max_ = conv_w_max;

Expand Down Expand Up @@ -644,21 +655,30 @@ class MMDNNMatchConvTopk {
}
seq_avg_topk_out = out->mutable_data<float>(TARGET(kXPU));

int max_width = std::max(left_seqlen_max, right_seqlen_max);
xw_fc_.Infer(ctx, left->data<float>(), left_seqlen_sum, xw_out);
int r = 0;
r = xdnn::match_matrix_tensor(ctx,
batch,
xw_out,
right->data<float>(),
left_lod_32_,
right_lod_32_,
dim_t_,
dim_in_,
xwy_out,
xw_fc_.out_max,
xdnn::Activation_t::RELU,
max_width);
r = xdnn::findmax<float>(
ctx, left->data<float>(), left_seqlen_sum * dim_in_, in_max_);
CHECK_EQ(r, 0);
r = xdnn::match_matrix_tensor<float, int16_t, int>(
ctx,
left->data<float>(),
right->data<float>(),
match_weight_,
xwy_out,
dim_in_,
dim_t_,
true,
{left_lod_32_cpu.data(),
static_cast<int>(left_lod_32_cpu.size()),
left_lod_32_},
{right_lod_32_cpu.data(),
static_cast<int>(right_lod_32_cpu.size()),
right_lod_32_},
in_max_,
nullptr,
match_weight_max_,
xdnn::Activation_t::RELU,
xw_out);
CHECK_EQ(r, 0);
r = xdnn::search_varconv<float, int16_t>(
ctx,
Expand Down
73 changes: 27 additions & 46 deletions lite/kernels/xpu/match_matrix_tensor_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,16 @@ namespace kernels {
namespace xpu {

void MatchMatrixTensorCompute::PrepareForRun() {
wx_max_xpu_guard_ =
TargetWrapperXPU::MallocScratchPad(XPU_MAX_LOD_SIZE * sizeof(int));
auto& param = this->Param<param_t>();
float w_max = param.__xpu__w_max;
std::vector<float> w_max_v(XPU_MAXPTR_SIZE, w_max);
weight_max_xpu_guard_ =
TargetWrapperXPU::MallocScratchPad(XPU_MAXPTR_SIZE * sizeof(float));
XPU_CALL(xpu_memcpy(reinterpret_cast<float*>(weight_max_xpu_guard_->addr_),
w_max_v.data(),
XPU_MAXPTR_SIZE * sizeof(float),
XPUMemcpyKind::XPU_HOST_TO_DEVICE));

offset_l_xpu_guard_ =
TargetWrapperXPU::MallocScratchPad(XPU_MAX_LOD_SIZE * sizeof(int));
offset_r_xpu_guard_ =
Expand All @@ -44,7 +52,6 @@ void MatchMatrixTensorCompute::Run() {
auto* out = param.out;
auto* tmp = param.tmp;
int dim_t = param.dim_t;
float w_max = param.__xpu__w_max;
bool fuse_relu = param.fuse_relu;
bool float_to_fix = param.__xpu__float_to_fix;
CHECK(float_to_fix) << "W should be fixed point";
Expand Down Expand Up @@ -74,44 +81,15 @@ void MatchMatrixTensorCompute::Run() {
auto* bottom_l_trans_data = tmp->mutable_data<float>(TARGET(kXPU));
int batch_size = x->lod()[0].size() - 1;

float* wx_max = reinterpret_cast<float*>(wx_max_xpu_guard_->addr_);
float* w_max = reinterpret_cast<float*>(weight_max_xpu_guard_->addr_);
int* offset_l_xpu = reinterpret_cast<int*>(offset_l_xpu_guard_->addr_);
int* offset_r_xpu = reinterpret_cast<int*>(offset_r_xpu_guard_->addr_);

int r = xdnn::gemm_int16_tmp_api<float, int16_t, float>(
ctx.GetRawContext(), /* ctx */
false, /* trans_a */
false, /* trans_b */
x->dims()[0], /* m */
dim_t * dim_in, /* n */
dim_in, /* k */
1.0f, /* alpha */
bottom_l_data, /* data_a */
dim_in, /* lda */
w_data, /* data_b */
dim_t * dim_in, /* ldb */
0.0f, /* beta */
bottom_l_trans_data, /* data_c */
dim_t * dim_in, /* ldc */
nullptr, /* bias */
xdnn::Activation_t::LINEAR, /* act */
0.0f, /* max_a */
w_max, /* max_b */
wx_max /* max_c */);
CHECK_EQ(r, 0);

int max_width = 0;
for (int i = 0; i < offset_l.size(); ++i) {
offset_l_cpu[i] = offset_l[i];
if (i != 0 && (offset_l_cpu[i] - offset_l_cpu[i - 1] > max_width)) {
max_width = offset_l_cpu[i] - offset_l_cpu[i - 1];
}
}
for (int i = 0; i < offset_r.size(); ++i) {
offset_r_cpu[i] = offset_r[i];
if (i != 0 && (offset_r_cpu[i] - offset_r_cpu[i - 1] > max_width)) {
max_width = offset_r_cpu[i] - offset_r_cpu[i - 1];
}
}
XPU_CALL(xpu_memcpy(offset_l_xpu,
offset_l_cpu.get(),
Expand All @@ -122,20 +100,23 @@ void MatchMatrixTensorCompute::Run() {
offset_r.size() * sizeof(int),
XPUMemcpyKind::XPU_HOST_TO_DEVICE));

r = xdnn::match_matrix_tensor(ctx.GetRawContext(),
batch_size,
bottom_l_trans_data,
bottom_r_data,
offset_l_xpu,
offset_r_xpu,
dim_t,
dim_in,
out_data,
wx_max,
act,
max_width);
int r = xdnn::match_matrix_tensor<float, int16_t, int>(
ctx.GetRawContext(),
bottom_l_data,
bottom_r_data,
w_data,
out_data,
dim_in,
dim_t,
true, // the weight is trans in XPUMmdnnFloat2Fix
{offset_l_cpu.get(), static_cast<int>(offset_l.size()), offset_l_xpu},
{offset_r_cpu.get(), static_cast<int>(offset_r.size()), offset_r_xpu},
nullptr,
nullptr,
w_max,
act,
bottom_l_trans_data);
CHECK_EQ(r, 0);

int lod_lv1_size = batch_size * dim_t;
int lod_lv2_size = x->lod()[0].back() * dim_t;
std::vector<size_t> out_lod0(batch_size + 1, 0);
Expand Down
2 changes: 1 addition & 1 deletion lite/kernels/xpu/match_matrix_tensor_compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class MatchMatrixTensorCompute
virtual void Run();

private:
XPUScratchPadGuard wx_max_xpu_guard_;
XPUScratchPadGuard weight_max_xpu_guard_;
XPUScratchPadGuard offset_l_xpu_guard_;
XPUScratchPadGuard offset_r_xpu_guard_;

Expand Down

0 comments on commit 7e99266

Please sign in to comment.