[xpu] support quantized ernie model; test=develop
newway committed Sep 3, 2021
1 parent 4110285 commit 3f68355
Showing 9 changed files with 278 additions and 152 deletions.
10 changes: 10 additions & 0 deletions lite/core/optimizer/mir/fusion/__xpu__fc_fuse_pass.cc
@@ -91,6 +91,15 @@ class XPUFcFuser : public FuseBase {
} else if (GetStringFromEnv("XPU_ENCODER_PRECISION", "int16") == "int8" ||
lite::TargetWrapperXPU::multi_encoder_precision == "int8") {
precision = "int8";
if (op_desc.HasAttr("enable_int8") &&
op_desc.GetAttr<bool>("enable_int8")) {
CHECK(op_desc.HasAttr("X0_scale")) << " quant model fc no X0_scale";
CHECK(op_desc.HasAttr("Y0_scale")) << " quant model fc no Y0_scale";
VLOG(3) << "Use int8 quant model in XPUFcOp, InputMax:"
<< 127 * op_desc.GetAttr<std::vector<float>>("X0_scale")[0]
<< ", WeightMax: "
<< 127 * op_desc.GetAttr<std::vector<float>>("Y0_scale")[0];
}
VLOG(3) << "Use int8 in XPUFcOp";
}
#endif
@@ -134,6 +143,7 @@ class XPUFcFuser : public FuseBase {
"in_num_col_dims",
matched.at("mul")->stmt()->op_info()->GetAttr<int>("x_num_col_dims"));

// meaningless when enable_int8
std::string max_output_name = output_name + "_max";
auto* max_output_node = graph->NewArgumentNode(max_output_name);
max_output_node->arg()->type = LiteType::GetTensorTy(
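The VLOG above reports 127 * scale for both input and weight; this converts a per-tensor quantization scale into the "max" convention the XPU kernels consume. A minimal standalone sketch of that conversion, assuming symmetric int8 quantization and single-element scale vectors (the helper name is illustrative, not part of this commit):

#include <cassert>
#include <vector>

// Symmetric int8 stores real_value = scale * int8_value, so the largest
// representable magnitude ("max") is 127 * scale.
float ScaleToMax(const std::vector<float>& scale) {
  assert(!scale.empty());
  return 127.f * scale[0];
}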
@@ -59,8 +59,14 @@ class XPULinkMaxFuser : public FuseBase {
public:
explicit XPULinkMaxFuser(const std::string& op_type) { op_type_ = op_type; }
void BuildPattern() override {
auto non_quant_teller = [](const Node* node) -> bool {
auto op_desc = *const_cast<Node*>(node)->stmt()->op_info();
return (!op_desc.HasAttr("enable_int8")
|| !op_desc.GetAttr<bool>("enable_int8"));
};
auto* input = VarNode("input")->assert_is_op_input(op_type_, "Input");
auto* xpu_fusion_op = OpNode("xpu_fusion_op", op_type_);
auto* xpu_fusion_op = OpNode("xpu_fusion_op", op_type_)
->assert_node_satisfied(non_quant_teller);
*input >> *xpu_fusion_op;
}

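The non_quant_teller predicate keeps XPULinkMaxFuser away from ops carrying enable_int8: a quantized op takes its input max from the attached scale attributes rather than from a linked max tensor. For reference, a hedged sketch of the complementary predicate, using the same Paddle-Lite Node API as the lambda above (illustrative only, not part of this commit):

auto quant_teller = [](const Node* node) -> bool {
  auto op_desc = *const_cast<Node*>(node)->stmt()->op_info();
  return op_desc.HasAttr("enable_int8") &&
         op_desc.GetAttr<bool>("enable_int8");
};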
326 changes: 189 additions & 137 deletions lite/core/optimizer/mir/fusion/__xpu__multi_encoder_fuse_pass.cc

Large diffs are not rendered by default.

8 changes: 6 additions & 2 deletions lite/core/optimizer/mir/static_kernel_pick_pass.cc
@@ -92,7 +92,9 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
} else {
bool out_type_int8 = true;
// Quantized lstm has fp32 output
if (instruct.op_type() == "lstm" || instruct.op_type() == "gru") {
if (instruct.op_type() == "lstm" || instruct.op_type() == "gru"
|| instruct.op_type() == "__xpu__multi_encoder"
|| instruct.op_type() == "__xpu__fc") {
out_type_int8 = false;
}
// Only if all ops linked to this op output has enable_int8 attr,
@@ -105,7 +107,9 @@
CHECK(tmp_op->IsStmt());
auto* tmp_op_info = tmp_op->AsStmt().op_info();
if (!tmp_op_info->HasAttr("enable_int8") ||
tmp_op_info->Type() == "lstm" || tmp_op_info->Type() == "gru") {
tmp_op_info->Type() == "lstm" || tmp_op_info->Type() == "gru"
|| tmp_op_info->Type() == "__xpu__multi_encoder"
|| tmp_op_info->Type() == "__xpu__fc") {
out_type_int8 = false;
break;
}
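Both the outer check and the loop body now enumerate the same four op types whose kernels keep fp32 outputs even in quantized models. A hedged refactor sketch that factors the list into one helper (hypothetical name, not in the commit):

#include <set>
#include <string>

// Ops whose kernels emit fp32 outputs even when enable_int8 is set.
static bool OutputStaysFp32(const std::string& op_type) {
  static const std::set<std::string> kOps{
      "lstm", "gru", "__xpu__multi_encoder", "__xpu__fc"};
  return kOps.count(op_type) > 0;
}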
54 changes: 43 additions & 11 deletions lite/kernels/xpu/__xpu__fc_compute.cc
@@ -29,15 +29,35 @@ void XPUFcCompute::PrepareForRun() {
auto w_ptr = param.w->data<float>();
auto w_len = param.w->numel();
auto weight_dims = param.w->dims();
bool quant_int8 = false;
if (param.quant_w_max > 0.f) {
quant_int8 = true;
}
// max
w_max = paddle::lite::xpu::math::FindMaxAbs(w_ptr, w_len);
std::vector<float> w_max_v(4, w_max);
weight_max_guard_ = TargetWrapperXPU::MallocScratchPad(4 * sizeof(float));
XPU_CALL(xpu_memcpy(reinterpret_cast<float*>(weight_max_guard_->addr_),
w_max_v.data(),
4 * sizeof(float),
XPUMemcpyKind::XPU_HOST_TO_DEVICE));
if (!quant_int8) {
w_max = paddle::lite::xpu::math::FindMaxAbs(w_ptr, w_len);
std::vector<float> w_max_v(4, w_max);
weight_max_guard_ = TargetWrapperXPU::MallocScratchPad(4 * sizeof(float));
XPU_CALL(xpu_memcpy(reinterpret_cast<float*>(weight_max_guard_->addr_),
w_max_v.data(),
4 * sizeof(float),
XPUMemcpyKind::XPU_HOST_TO_DEVICE));
input_max_guard_ = TargetWrapperXPU::MallocScratchPad(4 * sizeof(float));
}
// transpose
if (quant_int8) {
std::vector<int8_t> transpose_w_int8(w_len, 0);
paddle::lite::xpu::math::Transpose<int8_t>(
reinterpret_cast<const int8_t*>(w_ptr),
transpose_w_int8.data(), weight_dims[0], weight_dims[1]);
quant_weight_guard_ =
TargetWrapperXPU::MallocScratchPad(w_len * sizeof(int8_t));
XPU_CALL(xpu_memcpy(reinterpret_cast<int8_t*>(quant_weight_guard_->addr_),
transpose_w_int8.data(),
w_len * sizeof(int8_t),
XPUMemcpyKind::XPU_HOST_TO_DEVICE));
return;
}
std::vector<float> transpose_w(w_len, 0);
paddle::lite::xpu::math::Transpose(
w_ptr, transpose_w.data(), weight_dims[0], weight_dims[1]);
@@ -70,7 +90,6 @@ void XPUFcCompute::PrepareForRun() {
w_len * sizeof(int8_t),
XPUMemcpyKind::XPU_HOST_TO_DEVICE));
}
input_max_guard_ = TargetWrapperXPU::MallocScratchPad(4 * sizeof(float));
}

void XPUFcCompute::Run() {
@@ -82,11 +101,13 @@ void XPUFcCompute::Run() {
int m = in_mat_dims[0];
int k = in_mat_dims[1];
int n = param.w->dims()[1];
bool quant_int8 = param.quant_w_max > 0.f;

float* output_max = param.output_max->mutable_data<float>(TARGET(kXPU));
float* output_max = quant_int8 ? nullptr
: param.output_max->mutable_data<float>(TARGET(kXPU));
const auto* bias = param.has_bias ? param.bias->data<float>() : nullptr;
const float* input_max =
param.input_max ? param.input_max->data<float>() : nullptr;
const float* input_max = quant_int8 ? nullptr :
(param.input_max ? param.input_max->data<float>() : nullptr);
xdnn::Activation_t act((xdnn::Activation_t::act_enum)param.act_type);
if (param.act_type == 5) {
act.leaky_alpha = param.act_param;
@@ -150,6 +171,17 @@
} else if (param.precision == "int8") {
bool x_trans = false;
bool w_trans = true;
if (quant_int8) {
int r = xdnn::fc_int8(ctx.GetRawContext(), false, true,
m, n, k,
1.0f, param.input->data<float>(), param.quant_input_max,
reinterpret_cast<const int8_t*>(quant_weight_guard_->addr_),
param.quant_w_max,
0.f, param.output->mutable_data<float>(TARGET(kXPU)),
bias, act);
CHECK_EQ(r, 0);
return;
}
int ldx = (x_trans ? m : k);
int ldw = (w_trans ? k : n);
int ldy = n;
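PrepareForRun treats quant_w_max > 0.f as the signal that the weights arrive pre-quantized as int8, so it skips FindMaxAbs and only transposes them; Run then dispatches to xdnn::fc_int8 with the attribute-derived maxes. For context, a self-contained sketch of the offline symmetric quantization that produces such weights, consistent with the max = 127 * scale convention (assumption: saturating to the ±127 range):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

std::vector<int8_t> QuantizeSymmetric(const std::vector<float>& x, float max) {
  std::vector<int8_t> q(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    float v = std::round(x[i] * 127.f / max);  // map [-max, max] to [-127, 127]
    q[i] = static_cast<int8_t>(std::max(-127.f, std::min(127.f, v)));
  }
  return q;
}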
7 changes: 6 additions & 1 deletion lite/kernels/xpu/__xpu__multi_encoder_compute.cc
@@ -54,6 +54,10 @@ void XPUMultiEncoderCompute::PrepareForRun() {
encoder_param_.n_layers = param.n_layers;
encoder_param_.pretrans_b = true;
encoder_param_.use_l3 = true;
if (param.input_max.size()) {
encoder_param_.input_max = param.input_max;
encoder_param_.weight_max = param.weight_max;
}
encoder_param_.slice_starts = param.slice_starts;
encoder_param_.slice_ends = param.slice_ends;
encoder_param_.slice_axes = param.slice_axes;
@@ -94,7 +98,8 @@ int XPUMultiEncoderCompute::bert_encoder_run() {
arg_fc_bias_, /* fc_biass */
arg_ln_scale_, /* ln_scales */
arg_ln_bias_, /* ln_biass */
param.fc_weight_max->data<float>(), /* fc_weights_max */
/* fc_weights_max = param.weight_max */
param.fc_weight_max->data<float>(),
encoder_param_);
} else {
r = xdnn::bert_encoder_transformer_int16<float, int16_t, float>(
9 changes: 9 additions & 0 deletions lite/operators/__xpu__fc_op.cc
@@ -112,6 +112,15 @@ bool XPUFcOp::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
if (op_desc.HasAttr("precision")) {
param_.precision = op_desc.GetAttr<std::string>("precision");
}
if (op_desc.HasAttr("enable_int8")
&& op_desc.GetAttr<bool>("enable_int8")) {
CHECK(param_.precision == "int8")
<< "enable_int8 precison:" << param_.precision;
param_.quant_input_max =
127 * op_desc.GetAttr<std::vector<float>>("X0_scale")[0];
param_.quant_w_max =
127 * op_desc.GetAttr<std::vector<float>>("Y0_scale")[0];
}
return true;
}

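AttachImpl folds the factor of 127 into each scale as it reads the attributes. A hedged helper sketch using the same cpp::OpDesc accessors seen above (hypothetical helper, not part of the commit):

static float ScaleAttrToMax(const cpp::OpDesc& op_desc,
                            const std::string& attr) {
  auto scales = op_desc.GetAttr<std::vector<float>>(attr);
  CHECK(!scales.empty()) << attr << " must hold at least one scale";
  return 127.f * scales[0];  // int8 symmetric: max = 127 * scale
}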
4 changes: 4 additions & 0 deletions lite/operators/__xpu__multi_encoder_op.cc
@@ -140,6 +140,10 @@ bool XPUMultiEncoderOp::AttachImpl(const cpp::OpDesc& op_desc,
param_.enable_qkv_fusion = op_desc.GetAttr<bool>("enable_qkv_fusion");
param_.norm_before = op_desc.GetAttr<bool>("norm_before");
param_.adaptive_seqlen = op_desc.GetAttr<bool>("adaptive_seqlen");
if (op_desc.HasAttr("enable_int8") && op_desc.GetAttr<bool>("enable_int8")) {
param_.input_max = op_desc.GetAttr<std::vector<float>>("FCInputMax");
param_.weight_max = op_desc.GetAttr<std::vector<float>>("FCWeightMax");
}

if (op_desc.HasAttr("slice_axes")) {
param_.slice_axes = op_desc.GetAttr<std::vector<int>>("slice_axes");
4 changes: 4 additions & 0 deletions lite/operators/op_params.h
@@ -1985,6 +1985,8 @@ struct XPUMultiEncoderParam : ParamBase {
std::vector<int> slice_starts{};
std::vector<int> slice_ends{};
std::vector<int> slice_decrease_axis{};
std::vector<float> input_max{};
std::vector<float> weight_max{};
int n_layers{};
int head_num{};
int size_per_head{};
@@ -2016,6 +2018,8 @@ struct XPUFcParam : ParamBase {

int act_type;
float act_param;
float quant_input_max{0.f};
float quant_w_max{0.f};
std::string precision{};
bool has_bias{false};
int in_num_col_dims{1};
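The zero defaults double as the mode flag: quant_w_max > 0.f is exactly the quant_int8 test in XPUFcCompute. A hedged usage sketch with made-up scales (0.05 and 0.02 are hypothetical values, and op_params.h is assumed to be included):

XPUFcParam p;
p.quant_input_max = 127.f * 0.05f;  // hypothetical X0_scale[0]
p.quant_w_max = 127.f * 0.02f;      // hypothetical Y0_scale[0]
bool quant_int8 = p.quant_w_max > 0.f;  // matches the kernel's check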
