[xpu][cherry-pick] Fc int31 (PaddlePaddle#7514) test=develop
* [xpu] more checks with encoder fuse pass

* [xpu] fix continuous encoder fuse and fc max size

* [xpu] refactor fc int31 for KL2
newway committed Nov 9, 2021
1 parent 92eb5d7 commit e23bbdb
Showing 3 changed files with 45 additions and 33 deletions.
36 changes: 24 additions & 12 deletions lite/core/mir/fusion/__xpu__multi_encoder_fuse_pass.cc
@@ -665,21 +665,24 @@ class XPUMultiEncoderFuser {

void operator()(SSAGraph* graph) {
std::vector<Node*> all_encoders;
for (auto* node : graph->StmtTopologicalOrder()) {
CHECK(node->IsStmt());
if (node->stmt()->op_info()->Type() == "single_encoder") {
if (all_encoders.empty() ||
IsDirectPredecessorOf(all_encoders.back(), node)) {
all_encoders.push_back(node);
} else {
break;
// if no node linked from all_encoders.back(), search is over
int encoder_num = 0;
do {
encoder_num = all_encoders.size();
for (auto* node : graph->StmtTopologicalOrder()) {
CHECK(node->IsStmt());
if (node->stmt()->op_info()->Type() == "single_encoder") {
if (all_encoders.empty() ||
IsDirectPredecessorOf(all_encoders.back(), node)) {
all_encoders.push_back(node);
}
}
}
}
VLOG(3) << "Found continuous " << all_encoders.size() << " single_encoder";
} while (encoder_num != all_encoders.size());
if (all_encoders.size() == 0) {
return;
}
VLOG(3) << "Found continuous " << all_encoders.size() << " single_encoder";

const bool enable_int8 =
all_encoders[0]->stmt()->op_info()->HasAttr("enable_int8") &&
@@ -773,6 +776,14 @@ class XPUMultiEncoderFuser {
CHECK_EQ(fc_precision_, "int8");
CHECK_EQ(fc_input_max.size(), all_encoders.size() * 6);
CHECK_EQ(fc_weight_max.size(), all_encoders.size() * 6);
for (int i = 0; i < fc_weight_max.size(); i += 6) {
CHECK_LT(std::abs(fc_weight_max[i] - fc_weight_max[i + 1]), 1e-5)
<< " quanted ernie's q/k weight scale should be equal: "
<< fc_weight_max[i] << ", " << fc_weight_max[i + 1];
CHECK_LT(std::abs(fc_weight_max[i] - fc_weight_max[i + 2]), 1e-5)
<< " quanted ernie's q/v weight scale should be equal: "
<< fc_weight_max[i] << ", " << fc_weight_max[i + 2];
}
op_desc.SetAttr<std::vector<float>>("FCInputMax", fc_input_max);
// "FCWeightMax" is also stored as "Input" now
op_desc.SetAttr<std::vector<float>>("FCWeightMax", fc_weight_max);
@@ -977,6 +988,7 @@ class XPUMultiEncoderFuser {
weight_dim1_acc += weight_dims_vec[i][1];
if (i > 0) {
CHECK_EQ(weight_dims_vec[i][0], weight_dims_vec[i - 1][0]);
CHECK_EQ(start % 6, 0) << " qkv fuse position invalid: " << start;
}
}

@@ -1046,7 +1058,7 @@ class XPUMultiEncoderFuser {
weight_qkv_trans_int8.get(),
max_f,
qkv_len);
memcpy(weight_tensor_vec[0]->mutable_data<float>(),
memcpy(weight_tensor_vec[0]->mutable_data<int8_t>(),
weight_qkv_trans_int8.get(),
qkv_len * sizeof(int8_t));
} else {
@@ -1056,7 +1068,7 @@
weight_qkv_trans_int16.get(),
max_f,
qkv_len);
memcpy(weight_tensor_vec[0]->mutable_data<float>(),
memcpy(weight_tensor_vec[0]->mutable_data<int16_t>(),
weight_qkv_trans_int16.get(),
qkv_len * sizeof(int16_t));
}
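The hunk above replaces the early break with a skip-and-resweep search: the pass keeps scanning the topological order until a full sweep adds no new single_encoder, so an unrelated encoder sitting between two chained ones no longer cuts the chain short. A minimal standalone sketch of that fixed-point search, using a simplified Node type in place of the SSAGraph statement nodes (the struct, the sample graph, and main are illustrative assumptions, not code from the pass):

#include <iostream>
#include <string>
#include <vector>

// Simplified stand-in for the pass's graph node (assumption for this sketch).
struct Node {
  std::string type;
  std::vector<Node*> outputs;  // direct successors
};

bool IsDirectPredecessorOf(const Node* a, const Node* b) {
  for (const Node* out : a->outputs) {
    if (out == b) return true;
  }
  return false;
}

// Sweep the topological order repeatedly until one full pass adds no encoder,
// mirroring the do/while above: a non-chain encoder is skipped instead of
// terminating the search.
std::vector<Node*> CollectContinuousEncoders(const std::vector<Node*>& topo) {
  std::vector<Node*> all_encoders;
  size_t before = 0;
  do {
    before = all_encoders.size();
    for (Node* node : topo) {
      if (node->type != "single_encoder") continue;
      if (all_encoders.empty() ||
          IsDirectPredecessorOf(all_encoders.back(), node)) {
        all_encoders.push_back(node);
      }
    }
  } while (before != all_encoders.size());
  return all_encoders;
}

int main() {
  // Chain a1 -> a2, with an unrelated encoder b1 between them in topo order.
  Node a1{"single_encoder", {}};
  Node b1{"single_encoder", {}};
  Node a2{"single_encoder", {}};
  a1.outputs = {&a2};
  std::vector<Node*> topo = {&a1, &b1, &a2};
  // An early break at b1 would report 1 encoder; the resweep version reports 2.
  std::cout << CollectContinuousEncoders(topo).size() << " continuous encoders\n";
  return 0;
}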
41 changes: 21 additions & 20 deletions lite/kernels/xpu/__xpu__fc_compute.cc
@@ -107,6 +107,7 @@ void XPUFcCompute::Run() {
int n = param.w->dims()[1];
bool quant_int8 = param.quant_w_max > 0.f;

param.output_max->Resize({lite::XPU_QUANT_SCALE_NUM});
float* output_max = quant_int8
? nullptr
: param.output_max->mutable_data<float>(TARGET(kXPU));
@@ -123,26 +124,26 @@
}
// TODO(weihaoji): remove fc_int31 and fc_int16 after xpu fc wrapper refactor
if (param.precision == "int31") {
int r = xdnn::fc_int31(
ctx.GetRawContext(), /* context */
false, /* TransA */
true, /* TransB */
m, /* m */
n, /* n */
k, /* k */
1.0f, /* alpha */
param.input->data<float>(), /* A */
nullptr, /* max_a ptr */
reinterpret_cast<const float*>(quant_weight_guard_->addr_), /* B */
w_max, /* max_b */
0.0f, /* beta */
param.output->mutable_data<float>(TARGET(kXPU)), /* C */
nullptr, /* max_c ptr */
bias, /* bias */
act /* act_type */);
CHECK_EQ(r, 0);
r = xdnn::findmax<float>(
ctx.GetRawContext(), param.output->data<float>(), m * n, output_max);
int r = xdnn::fc_fusion<float, float, float, int>(
ctx.GetRawContext(), // ctx
param.input->data<float>(), // x
reinterpret_cast<const float*>(quant_weight_guard_->addr_), // w
param.output->mutable_data<float>(TARGET(kXPU)), // y
m, // m
n, // n
k, // k
false, // x_trans
true, // w_trans
input_max, // x_maxptr
reinterpret_cast<const float*>(weight_max_guard_->addr_), // w_maxptr
output_max, // y_maxptr
k, // ldx
k, // ldw
n, // ldy
1.0f, // alpha
0.0f, // beta
bias, // bias
act);
CHECK_EQ(r, 0);
} else if (param.precision == "int16") {
int r = 0;
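The int31 branch now routes through the same xdnn::fc_fusion entry point as the other precisions, and output_max is produced via the y_maxptr argument instead of the separate findmax call used before. Below is a rough CPU reference for the computation those arguments describe, under my reading of the mapping (x_trans = false, w_trans = true, ldx = ldw = k, ldy = n, a row-major [n, k] weight buffer and a length-n bias); it illustrates the parameter meanings only and says nothing about xdnn internals:

#include <cstdio>
#include <vector>

// y[m x n] = act(alpha * x[m x k] * w[n x k]^T + beta * y + bias), with the
// activation reduced to an optional ReLU for illustration.
void fc_reference(const float* x, const float* w, const float* bias, float* y,
                  int m, int n, int k, float alpha = 1.0f, float beta = 0.0f,
                  bool relu = false) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (int p = 0; p < k; ++p) {
        acc += x[i * k + p] * w[j * k + p];  // ldx = k, ldw = k, w transposed
      }
      float v = alpha * acc + beta * y[i * n + j];  // ldy = n
      if (bias != nullptr) v += bias[j];
      if (relu && v < 0.0f) v = 0.0f;
      y[i * n + j] = v;
    }
  }
}

int main() {
  const int m = 2, n = 3, k = 4;
  std::vector<float> x(m * k, 1.0f), w(n * k, 0.5f), bias(n, 0.1f), y(m * n, 0.0f);
  fc_reference(x.data(), w.data(), bias.data(), y.data(), m, n, k);
  std::printf("y[0][0] = %.2f\n", y[0]);  // 4 * 0.5 + 0.1 = 2.10
  return 0;
}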
1 change: 0 additions & 1 deletion lite/operators/__xpu__fc_op.cc
@@ -62,7 +62,6 @@ bool XPUFcOp::InferShapeImpl() const {
}
output_dims[in_num_col_dims] = w_dims_1;
param_.output->Resize(output_dims);
param_.output_max->Resize({4});

// share LoD
param_.output->set_lod(param_.input->lod());
