[XPU] change match_matrix_tensor op from old version to refactored version, test=develop, test=xpu #7012

Merged 1 commit on Sep 22, 2021
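This PR moves the XPU match_matrix_tensor computation onto the refactored xdnn API in both kernels that call it: the fused __xpu__mmdnn kernel and the standalone match_matrix_tensor kernel. The legacy path first ran an int16 FC/GEMM to produce the X*W product and then fed that product, its device-side max, and a precomputed max_width into xdnn::match_matrix_tensor. The refactored, templated API instead takes the raw inputs and the int16 weight directly, accepts per-tensor float maxima, receives lods as {cpu pointer, length, xpu pointer} triples, and emits the intermediate X*W product through its final argument. The sketch below contrasts the two call shapes; it is distilled from this diff rather than from the xdnn headers, so the parameter labels are descriptive guesses, not the library's names.

```cpp
// Legacy call shape (removed in this PR): consumes a precomputed xw = X*W,
// its device-side max, and the max sequence width.
//   xdnn::match_matrix_tensor(ctx, batch, xw, y, lod_x, lod_y,
//                             dim_t, dim_in, out, xw_max,
//                             act, max_width);
//
// Refactored call shape (added in this PR), templated on <TX, TW, TLOD>:
// raw x/y plus the int16 weight, lods as {cpu_ptr, len, xpu_ptr} triples,
// three per-tensor max slots (passing nullptr appears to let xdnn derive
// the max itself; that reading is an inference from the two call sites in
// this diff, not a documented contract), and xw returned via the last
// argument.
//   xdnn::match_matrix_tensor<float, int16_t, int>(
//       ctx, x, y, w, out, dim_in, dim_t, w_is_trans,
//       {lod_x_cpu, lod_x_len, lod_x_xpu},
//       {lod_y_cpu, lod_y_len, lod_y_xpu},
//       x_max, y_max, w_max, act, xw);
```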
62 changes: 41 additions & 21 deletions lite/kernels/xpu/__xpu__mmdnn_compute.cc
@@ -482,7 +482,13 @@ class MMDNNMatchConvTopk {
   int dim_in_;
   int out_channel_;

-  MMDNNFcOp xw_fc_;
+  const int16_t* match_weight_{nullptr};
+  XPUScratchPadGuard match_weight_max_guard_;
+  float* match_weight_max_{nullptr};
+  XPUScratchPadGuard in_max_guard_;
+  float* in_max_{nullptr};
+  XPUScratchPadGuard out_max_guard_;
+  float* out_max_{nullptr};
   const int16_t* conv_weight_{nullptr};
   float conv_weight_max_;
   XPUScratchPadGuard hbm_buffer_guard_;
@@ -525,12 +531,17 @@ class MMDNNMatchConvTopk {
     out_channel_ = out_channel;
     topks_ = topks;

-    xw_fc_.Init(input_w,
-                input_w_max,
-                nullptr,
-                dim_t_ * dim_in_,
-                dim_in_,
-                xdnn::Activation_t::LINEAR);
+    match_weight_ = input_w->data<int16_t>();
+    match_weight_max_guard_ =
+        TargetWrapperXPU::MallocScratchPad(4 * sizeof(float));
+    match_weight_max_ =
+        reinterpret_cast<float*>(match_weight_max_guard_->addr_);
+    FillMax(input_w_max, match_weight_max_);
+    in_max_guard_ = TargetWrapperXPU::MallocScratchPad(4 * sizeof(float));
+    in_max_ = reinterpret_cast<float*>(in_max_guard_->addr_);
+    out_max_guard_ = TargetWrapperXPU::MallocScratchPad(4 * sizeof(float));
+    out_max_ = reinterpret_cast<float*>(out_max_guard_->addr_);

     conv_weight_ = conv_w->data<int16_t>();
     conv_weight_max_ = conv_w_max;
@@ -644,21 +655,30 @@ class MMDNNMatchConvTopk {
     }
     seq_avg_topk_out = out->mutable_data<float>(TARGET(kXPU));

-    int max_width = std::max(left_seqlen_max, right_seqlen_max);
-    xw_fc_.Infer(ctx, left->data<float>(), left_seqlen_sum, xw_out);
     int r = 0;
-    r = xdnn::match_matrix_tensor(ctx,
-                                  batch,
-                                  xw_out,
-                                  right->data<float>(),
-                                  left_lod_32_,
-                                  right_lod_32_,
-                                  dim_t_,
-                                  dim_in_,
-                                  xwy_out,
-                                  xw_fc_.out_max,
-                                  xdnn::Activation_t::RELU,
-                                  max_width);
+    r = xdnn::findmax<float>(
+        ctx, left->data<float>(), left_seqlen_sum * dim_in_, in_max_);
+    CHECK_EQ(r, 0);
+    r = xdnn::match_matrix_tensor<float, int16_t, int>(
+        ctx,
+        left->data<float>(),
+        right->data<float>(),
+        match_weight_,
+        xwy_out,
+        dim_in_,
+        dim_t_,
+        true,
+        {left_lod_32_cpu.data(),
+         static_cast<int>(left_lod_32_cpu.size()),
+         left_lod_32_},
+        {right_lod_32_cpu.data(),
+         static_cast<int>(right_lod_32_cpu.size()),
+         right_lod_32_},
+        in_max_,
+        nullptr,
+        match_weight_max_,
+        xdnn::Activation_t::RELU,
+        xw_out);
     CHECK_EQ(r, 0);
     r = xdnn::search_varconv<float, int16_t>(
         ctx,
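A note on the pattern above: the refactored API consumes quantization maxima as small device-side float buffers instead of the fused FC's output max, so the kernel now stages them explicitly. The static weight max is filled once at Init time via FillMax, while the activation max is recomputed on device each run via xdnn::findmax. Below is a minimal sketch of that staging, reusing only helpers that already appear in this diff (the 4-float scratchpad size is copied from the diff; variable names are illustrative):

```cpp
// Static weight max: staged once when the kernel is initialized.
XPUScratchPadGuard weight_max_guard =
    TargetWrapperXPU::MallocScratchPad(4 * sizeof(float));
float* weight_max = reinterpret_cast<float*>(weight_max_guard->addr_);
FillMax(w_max_cpu, weight_max);

// Dynamic activation max: recomputed on device before each templated call.
XPUScratchPadGuard in_max_guard =
    TargetWrapperXPU::MallocScratchPad(4 * sizeof(float));
float* in_max = reinterpret_cast<float*>(in_max_guard->addr_);
int r = xdnn::findmax<float>(ctx, x_data, x_len, in_max);
CHECK_EQ(r, 0);
```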
73 changes: 27 additions & 46 deletions lite/kernels/xpu/match_matrix_tensor_compute.cc
@@ -23,8 +23,16 @@ namespace kernels {
 namespace xpu {

 void MatchMatrixTensorCompute::PrepareForRun() {
-  wx_max_xpu_guard_ =
-      TargetWrapperXPU::MallocScratchPad(XPU_MAX_LOD_SIZE * sizeof(int));
+  auto& param = this->Param<param_t>();
+  float w_max = param.__xpu__w_max;
+  std::vector<float> w_max_v(XPU_QUANT_SCALE_NUM, w_max);
+  weight_max_xpu_guard_ =
+      TargetWrapperXPU::MallocScratchPad(XPU_QUANT_SCALE_NUM * sizeof(float));
+  XPU_CALL(xpu_memcpy(reinterpret_cast<float*>(weight_max_xpu_guard_->addr_),
+                      w_max_v.data(),
+                      XPU_QUANT_SCALE_NUM * sizeof(float),
+                      XPUMemcpyKind::XPU_HOST_TO_DEVICE));

   offset_l_xpu_guard_ =
       TargetWrapperXPU::MallocScratchPad(XPU_MAX_LOD_SIZE * sizeof(int));
   offset_r_xpu_guard_ =
@@ -44,7 +52,6 @@ void MatchMatrixTensorCompute::Run() {
   auto* out = param.out;
   auto* tmp = param.tmp;
   int dim_t = param.dim_t;
-  float w_max = param.__xpu__w_max;
   bool fuse_relu = param.fuse_relu;
   bool float_to_fix = param.__xpu__float_to_fix;
   CHECK(float_to_fix) << "W should be fixed point";
@@ -74,44 +81,15 @@ void MatchMatrixTensorCompute::Run() {
   auto* bottom_l_trans_data = tmp->mutable_data<float>(TARGET(kXPU));
   int batch_size = x->lod()[0].size() - 1;

-  float* wx_max = reinterpret_cast<float*>(wx_max_xpu_guard_->addr_);
+  float* w_max = reinterpret_cast<float*>(weight_max_xpu_guard_->addr_);
   int* offset_l_xpu = reinterpret_cast<int*>(offset_l_xpu_guard_->addr_);
   int* offset_r_xpu = reinterpret_cast<int*>(offset_r_xpu_guard_->addr_);

-  int r = xdnn::gemm_int16_tmp_api<float, int16_t, float>(
-      ctx.GetRawContext(),        /* ctx */
-      false,                      /* trans_a */
-      false,                      /* trans_b */
-      x->dims()[0],               /* m */
-      dim_t * dim_in,             /* n */
-      dim_in,                     /* k */
-      1.0f,                       /* alpha */
-      bottom_l_data,              /* data_a */
-      dim_in,                     /* lda */
-      w_data,                     /* data_b */
-      dim_t * dim_in,             /* ldb */
-      0.0f,                       /* beta */
-      bottom_l_trans_data,        /* data_c */
-      dim_t * dim_in,             /* ldc */
-      nullptr,                    /* bias */
-      xdnn::Activation_t::LINEAR, /* act */
-      0.0f,                       /* max_a */
-      w_max,                      /* max_b */
-      wx_max /* max_c */);
-  CHECK_EQ(r, 0);
-
-  int max_width = 0;
   for (int i = 0; i < offset_l.size(); ++i) {
     offset_l_cpu[i] = offset_l[i];
-    if (i != 0 && (offset_l_cpu[i] - offset_l_cpu[i - 1] > max_width)) {
-      max_width = offset_l_cpu[i] - offset_l_cpu[i - 1];
-    }
   }
   for (int i = 0; i < offset_r.size(); ++i) {
     offset_r_cpu[i] = offset_r[i];
-    if (i != 0 && (offset_r_cpu[i] - offset_r_cpu[i - 1] > max_width)) {
-      max_width = offset_r_cpu[i] - offset_r_cpu[i - 1];
-    }
   }
   XPU_CALL(xpu_memcpy(offset_l_xpu,
                       offset_l_cpu.get(),
@@ -122,20 +100,23 @@ void MatchMatrixTensorCompute::Run() {
                       offset_r.size() * sizeof(int),
                       XPUMemcpyKind::XPU_HOST_TO_DEVICE));

-  r = xdnn::match_matrix_tensor(ctx.GetRawContext(),
-                                batch_size,
-                                bottom_l_trans_data,
-                                bottom_r_data,
-                                offset_l_xpu,
-                                offset_r_xpu,
-                                dim_t,
-                                dim_in,
-                                out_data,
-                                wx_max,
-                                act,
-                                max_width);
+  int r = xdnn::match_matrix_tensor<float, int16_t, int>(
+      ctx.GetRawContext(),
+      bottom_l_data,
+      bottom_r_data,
+      w_data,
+      out_data,
+      dim_in,
+      dim_t,
+      true,  // the weight is trans in XPUMmdnnFloat2Fix
+      {offset_l_cpu.get(), static_cast<int>(offset_l.size()), offset_l_xpu},
+      {offset_r_cpu.get(), static_cast<int>(offset_r.size()), offset_r_xpu},
+      nullptr,
+      nullptr,
+      w_max,
+      act,
+      bottom_l_trans_data);
   CHECK_EQ(r, 0);

   int lod_lv1_size = batch_size * dim_t;
   int lod_lv2_size = x->lod()[0].back() * dim_t;
   std::vector<size_t> out_lod0(batch_size + 1, 0);
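The standalone kernel applies the same migration with one twist: its weight max arrives as a scalar attribute (param.__xpu__w_max), so PrepareForRun broadcasts it into an XPU_QUANT_SCALE_NUM-float device buffer once, and Run drops both the explicit GEMM and the max_width bookkeeping. The lod arguments are brace-initialized {cpu pointer, length, xpu pointer} triples, which look like xdnn's VectorParam<int>, though the diff never names the type. Below is a hypothetical helper, under that assumption, showing how such a triple is assembled:

```cpp
// Hypothetical helper distilled from the Run() hunk above: convert a host
// lod (size_t) into an int buffer and mirror it to the device, so both
// views can be handed to the refactored API as a {cpu, len, xpu} triple.
// offset_cpu / offset_xpu correspond to offset_l_cpu / offset_l_xpu here.
static void StageLod(const std::vector<size_t>& offset,
                     int* offset_cpu,    // host scratch, >= offset.size()
                     int* offset_xpu) {  // device scratch, >= offset.size()
  for (size_t i = 0; i < offset.size(); ++i) {
    offset_cpu[i] = static_cast<int>(offset[i]);
  }
  XPU_CALL(xpu_memcpy(offset_xpu,
                      offset_cpu,
                      offset.size() * sizeof(int),
                      XPUMemcpyKind::XPU_HOST_TO_DEVICE));
  // Caller then passes:
  //   {offset_cpu, static_cast<int>(offset.size()), offset_xpu}
}
```

Carrying both the host and device copies in one argument lets the library read whichever side it needs without performing another transfer.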
2 changes: 1 addition & 1 deletion lite/kernels/xpu/match_matrix_tensor_compute.h
@@ -33,7 +33,7 @@ class MatchMatrixTensorCompute
   virtual void Run();

  private:
-  XPUScratchPadGuard wx_max_xpu_guard_;
+  XPUScratchPadGuard weight_max_xpu_guard_;
   XPUScratchPadGuard offset_l_xpu_guard_;
   XPUScratchPadGuard offset_r_xpu_guard_;