Skip to content

Commit

Permalink
[XPU] [Cherry-Pick] change fc_int16 op to fc_fusion (PaddlePaddle#7029)…
Browse files Browse the repository at this point in the history
… test=develop
  • Loading branch information
shanliang1992 authored and newway committed Nov 11, 2021
1 parent e23bbdb commit 1cd2315
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 75 deletions.
38 changes: 20 additions & 18 deletions lite/kernels/xpu/__xpu__bigru_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -195,24 +195,26 @@ void XPUBiGRUCompute::MulRun(bool forward) {
(bias_guard == nullptr)
? nullptr
: reinterpret_cast<const float*>(bias_guard->addr_);

int r = xdnn::fc_int16(
ctx.GetRawContext(), /* context */
false, /* TransA */
true, /* TransB */
m,
n,
k,
1.0f, /* alpha */
x_matrix.data<float>(), /* A */
reinterpret_cast<const float*>(input_max_guard_->addr_),
reinterpret_cast<const int16_t*>(quant_weight_guard->addr_), /* B */
reinterpret_cast<const float*>(weight_max_guard->addr_),
0.0f, /* beta */
output.mutable_data<float>(TARGET(kXPU)), /* C */
reinterpret_cast<float*>(mul_output_max_guard_->addr_),
bias_ptr,
xdnn::Activation_t::LINEAR);
int r = xdnn::fc_fusion<float, int16_t, float, int16_t>(
ctx.GetRawContext(), // ctx
x_matrix.data<float>(), // x
reinterpret_cast<const int16_t*>(quant_weight_guard->addr_), // w
output.mutable_data<float>(TARGET(kXPU)), // y
m, // m
n, // n
k, // k
false, // x_trans
true, // w_trans
reinterpret_cast<const float*>(input_max_guard_->addr_), // x_maxptr
reinterpret_cast<const float*>(weight_max_guard->addr_), // w_maxptr
reinterpret_cast<float*>(mul_output_max_guard_->addr_), // y_maxptr,
k, // ldx
k, // ldw
n, // ldy
1.0f, // alpha
0.0f, // beta
bias_ptr, // bias
xdnn::Activation_t::LINEAR); // act
CHECK_EQ(r, 0);
*(output.mutable_lod()) = origin_x.lod();
}
Expand Down
38 changes: 21 additions & 17 deletions lite/kernels/xpu/__xpu__fc_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,25 +155,29 @@ void XPUFcCompute::Run() {
reinterpret_cast<float*>(input_max_guard_->addr_));
CHECK_EQ(r, 0);
}
r = xdnn::fc_int16(
ctx.GetRawContext(), /* context */
false, /* TransA */
true, /* TransB */
m, /* m */
n, /* n */
k, /* k */
1.0f, /* alpha */
param.input->data<float>(), /* A */
r = xdnn::fc_fusion<float, int16_t, float, int16_t>(
ctx.GetRawContext(), // ctx
param.input->data<float>(), // x
reinterpret_cast<const int16_t*>(quant_weight_guard_->addr_), // w
param.output->mutable_data<float>(TARGET(kXPU)), // y
m, // m
n, // n
k, // k
false, // x_trans
true, // w_trans
(input_max == nullptr)
? reinterpret_cast<const float*>(input_max_guard_->addr_)
: input_max,
reinterpret_cast<const int16_t*>(quant_weight_guard_->addr_), /* B */
reinterpret_cast<const float*>(weight_max_guard_->addr_),
0.0f, /* beta */
param.output->mutable_data<float>(TARGET(kXPU)), /* C */
output_max,
bias, /* bias */
act /* act_type */);
: input_max, // x_maxptr
reinterpret_cast<const float*>(weight_max_guard_->addr_), // w_maxptr
output_max, // y_maxptr
k, // ldx
k, // ldw
n, // ldy
1.0f, // alpha
0.0f, // beta
bias, // bias
act); // act

CHECK_EQ(r, 0);
} else if (param.precision == "int8") {
bool x_trans = false;
Expand Down
37 changes: 20 additions & 17 deletions lite/kernels/xpu/__xpu__mmdnn_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -224,23 +224,26 @@ class MMDNNFcOp {
CHECK_EQ(r, 0);
in_max_by_caller = in_max_;
}

r = xdnn::fc_int16(ctx,
false,
true,
m,
n_,
k_,
1.0f,
in,
in_max_by_caller,
weight_,
weight_max_,
0.0f,
out,
out_max,
bias_,
act_type_);
r = xdnn::fc_fusion<float, int16_t, float, int16_t>(
ctx, // ctx
in, // x
weight_, // w
out, // y
m, // m
n_, // n
k_, // k
false, // x_trans
true, // w_trans
in_max_by_caller, // x_maxptr
weight_max_, // w_maxptr
out_max, // y_maxptr
k_, // ldx
k_, // ldw
n_, // ldy
1.0f, // alpha
0.0f, // beta
bias_, // bias
act_type_); // act
CHECK_EQ(r, 0);
}
};
Expand Down
32 changes: 21 additions & 11 deletions lite/kernels/xpu/matmul_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,17 +66,27 @@ void MatMulCompute::Run() {

int r = 0;
if (mat_dim_a.batch_size_ == 0 || mat_dim_a.batch_size_ == 1) {
r = xdnn::fc_int16(ctx.GetRawContext(), /* context */
mat_dim_a.trans_, /* TransA */
mat_dim_b.trans_, /* TransB */
mat_dim_a.height_, /* m */
mat_dim_b.width_, /* n */
mat_dim_a.width_, /* k */
param.alpha, /* alpha */
x->data<float>(), /* A */
y->data<float>(), /* B */
0.0f, /* beta */
out->mutable_data<float>(TARGET(kXPU)) /* C */);
r = xdnn::fc_fusion<float, float, float, int16_t>(
ctx.GetRawContext(), // ctx
x->data<float>(), // x
y->data<float>(), // w
out->mutable_data<float>(TARGET(kXPU)), // y
mat_dim_a.height_, // m
mat_dim_b.width_, // n
mat_dim_a.width_, // k
mat_dim_a.trans_, // x_trans
mat_dim_b.trans_, // w_trans
nullptr, // x_maxptr
nullptr, // w_maxptr
nullptr, // y_maxptr
lda, // ldx
ldb, // ldw
ldc, // ldy
param.alpha, // alpha
0.0f, // beta
nullptr, // bias
xdnn::Activation_t::LINEAR); // act

} else {
// batch matmul
r = xdnn::gemm_strided_batched_int16<float, float, float>(
Expand Down
33 changes: 21 additions & 12 deletions lite/kernels/xpu/mul_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,27 @@ void MulCompute::Run() {
int k = x_matrix.dims()[1];
int n = y_matrix.dims()[1];

int r =
xdnn::fc_int16(ctx.GetRawContext(), /* context */
false, /* TransA */
false, /* TransB */
m,
n,
k,
1.0f, /* alpha */
x_matrix.data<float>(), /* A */
y_matrix.data<float>(), /* B */
0.0f, /* beta */
param.output->mutable_data<float>(TARGET(kXPU)) /* C */);
int r = xdnn::fc_fusion<float, float, float, int16_t>(
ctx.GetRawContext(), // ctx
x_matrix.data<float>(), // x
y_matrix.data<float>(), // w
param.output->mutable_data<float>(TARGET(kXPU)), // y
m, // m
n, // n
k, // k
false, // x_trans
false, // w_trans
nullptr, // x_maxptr
nullptr, // w_maxptr
nullptr, // y_maxptr
k, // ldx
n, // ldw
n, // ldy
1.0f, // alpha
0.0f, // beta
nullptr, // bias
xdnn::Activation_t::LINEAR); // act

CHECK_EQ(r, 0);
}

Expand Down

0 comments on commit 1cd2315

Please sign in to comment.