From b308346a4118e35b17f90831b44a4aad73d8b08b Mon Sep 17 00:00:00 2001 From: newway Date: Fri, 3 Sep 2021 17:13:53 +0800 Subject: [PATCH 1/4] [xpu] support quanted ernie model; test=develop --- .../mir/fusion/__xpu__fc_fuse_pass.cc | 10 + .../__xpu__link_previous_out_max_pass.cc | 8 +- .../fusion/__xpu__multi_encoder_fuse_pass.cc | 326 ++++++++++-------- .../optimizer/mir/static_kernel_pick_pass.cc | 8 +- lite/kernels/xpu/__xpu__fc_compute.cc | 67 +++- .../xpu/__xpu__multi_encoder_compute.cc | 7 +- lite/operators/__xpu__fc_op.cc | 9 + lite/operators/__xpu__multi_encoder_op.cc | 4 + lite/operators/op_params.h | 4 + 9 files changed, 291 insertions(+), 152 deletions(-) diff --git a/lite/core/optimizer/mir/fusion/__xpu__fc_fuse_pass.cc b/lite/core/optimizer/mir/fusion/__xpu__fc_fuse_pass.cc index 2f9e4942b22..cb0a66cd617 100644 --- a/lite/core/optimizer/mir/fusion/__xpu__fc_fuse_pass.cc +++ b/lite/core/optimizer/mir/fusion/__xpu__fc_fuse_pass.cc @@ -91,6 +91,15 @@ class XPUFcFuser : public FuseBase { } else if (GetStringFromEnv("XPU_ENCODER_PRECISION", "int16") == "int8" || lite::TargetWrapperXPU::multi_encoder_precision == "int8") { precision = "int8"; + if (op_desc.HasAttr("enable_int8") && + op_desc.GetAttr("enable_int8")) { + CHECK(op_desc.HasAttr("X0_scale")) << " quant model fc no X0_scale"; + CHECK(op_desc.HasAttr("Y0_scale")) << " quant model fc no Y0_scale"; + VLOG(3) << "Use int8 quant model in XPUFcOp, InputMax:" + << 127 * op_desc.GetAttr>("X0_scale")[0] + << ", WeightMax: " + << 127 * op_desc.GetAttr>("Y0_scale")[0]; + } VLOG(3) << "Use int8 in XPUFcOp"; } #endif @@ -134,6 +143,7 @@ class XPUFcFuser : public FuseBase { "in_num_col_dims", matched.at("mul")->stmt()->op_info()->GetAttr("x_num_col_dims")); + // meaningless when enable_int8 std::string max_output_name = output_name + "_max"; auto* max_output_node = graph->NewArgumentNode(max_output_name); max_output_node->arg()->type = LiteType::GetTensorTy( diff --git a/lite/core/optimizer/mir/fusion/__xpu__link_previous_out_max_pass.cc b/lite/core/optimizer/mir/fusion/__xpu__link_previous_out_max_pass.cc index e259839d02e..64b0fc50d3b 100644 --- a/lite/core/optimizer/mir/fusion/__xpu__link_previous_out_max_pass.cc +++ b/lite/core/optimizer/mir/fusion/__xpu__link_previous_out_max_pass.cc @@ -59,8 +59,14 @@ class XPULinkMaxFuser : public FuseBase { public: explicit XPULinkMaxFuser(const std::string& op_type) { op_type_ = op_type; } void BuildPattern() override { + auto non_quant_teller = [](const Node* node) -> bool { + auto op_desc = *const_cast(node)->stmt()->op_info(); + return (!op_desc.HasAttr("enable_int8") + || !op_desc.GetAttr("enable_int8")); + }; auto* input = VarNode("input")->assert_is_op_input(op_type_, "Input"); - auto* xpu_fusion_op = OpNode("xpu_fusion_op", op_type_); + auto* xpu_fusion_op = OpNode("xpu_fusion_op", op_type_) + ->assert_node_satisfied(non_quant_teller); *input >> *xpu_fusion_op; } diff --git a/lite/core/optimizer/mir/fusion/__xpu__multi_encoder_fuse_pass.cc b/lite/core/optimizer/mir/fusion/__xpu__multi_encoder_fuse_pass.cc index 80c7acda9d1..c2675bce7db 100644 --- a/lite/core/optimizer/mir/fusion/__xpu__multi_encoder_fuse_pass.cc +++ b/lite/core/optimizer/mir/fusion/__xpu__multi_encoder_fuse_pass.cc @@ -482,6 +482,48 @@ class XPUSingleEncoderFuser : public FuseBase { op_desc.SetAttr>("input_data_names", {}); op_desc.SetAttr>("output_data_names", {}); + auto* q_mul_op_info = matched.at("q_mul")->stmt()->op_info(); + if (q_mul_op_info->HasAttr("enable_int8") && + q_mul_op_info->GetAttr("enable_int8")) { + 
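        // Note on the convention assumed just below: Paddle's quant passes
        // store per-tensor scales so that absmax = 127 * scale for int8
        // tensors. The pass therefore recovers each FC's activation/weight
        // range as, e.g.
        //   float x_max = 127 * info->GetAttr<std::vector<float>>("X0_scale")[0];
        //   float y_max = 127 * info->GetAttr<std::vector<float>>("Y0_scale")[0];
        // and records the results as the X0_max/Y0_max attributes for the six
        // FCs of one layer (q_mul, k_mul, v_mul, qkv_mul, qkv_mul_3,
        // qkv_mul_4, in this order).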
op_desc.SetAttr("enable_int8", true); + op_desc.SetAttr>("X0_max", + { + 127 * matched.at("q_mul")->stmt()->op_info() + ->GetAttr>("X0_scale")[0], + 127 * matched.at("k_mul")->stmt()->op_info() + ->GetAttr>("X0_scale")[0], + 127 * matched.at("v_mul")->stmt()->op_info() + ->GetAttr>("X0_scale")[0], + 127 * matched.at("qkv_mul")->stmt()->op_info() + ->GetAttr>("X0_scale")[0], + 127 * matched.at("qkv_mul_3")->stmt()->op_info() + ->GetAttr>("X0_scale")[0], + 127 * matched.at("qkv_mul_4")->stmt()->op_info() + ->GetAttr>("X0_scale")[0], + }); + op_desc.SetAttr>("Y0_max", + { + 127 * matched.at("q_mul")->stmt()->op_info() + ->GetAttr>("Y0_scale")[0], + 127 * matched.at("k_mul")->stmt()->op_info() + ->GetAttr>("Y0_scale")[0], + 127 * matched.at("v_mul")->stmt()->op_info() + ->GetAttr>("Y0_scale")[0], + 127 * matched.at("qkv_mul")->stmt()->op_info() + ->GetAttr>("Y0_scale")[0], + 127 * matched.at("qkv_mul_3")->stmt()->op_info() + ->GetAttr>("Y0_scale")[0], + 127 * matched.at("qkv_mul_4")->stmt()->op_info() + ->GetAttr>("Y0_scale")[0], + }); + VLOG(3) << "q/k/v 127*y0_scale: " << + 127 * matched.at("q_mul")->stmt()->op_info() + ->GetAttr>("Y0_scale")[0] << ", " << + 127 * matched.at("k_mul")->stmt()->op_info() + ->GetAttr>("Y0_scale")[0] << ", " << + 127 * matched.at("v_mul")->stmt()->op_info() + ->GetAttr>("Y0_scale")[0]; + } // extra traits to distill auto* reshape_op_info = matched.at("q_reshape2")->stmt()->op_info(); auto reshape_dim = reshape_op_info->GetAttr>("shape"); @@ -590,6 +632,9 @@ class XPUMultiEncoderFuser { return; } + const bool enable_int8 = all_encoders[0]->stmt() + ->op_info()->HasAttr("enable_int8") + && all_encoders[0]->stmt()->op_info()->GetAttr("enable_int8"); // TODO(miaotianxiang): more verification const bool norm_before_0 = all_encoders[0]->stmt()->op_info()->GetAttr("norm_before"); @@ -615,9 +660,21 @@ class XPUMultiEncoderFuser { std::vector arg_names{ "FCWeight", "FCBias", "LNScale", "LNBias"}; std::map> arg_map; + std::vector fc_weight_max; + std::vector fc_input_max; for (size_t i = 0; i < all_encoders.size(); ++i) { Node* cur_encoder = all_encoders[i]; auto* op_info = cur_encoder->stmt()->op_info(); + if (enable_int8) { + CHECK(op_info->HasAttr("enable_int8") && op_info->HasAttr("Y0_max") + && op_info->HasAttr("X0_max")/* && op_info->HasAttr("Out0_max")*/); + for (auto y0 : op_info->GetAttr>("Y0_max")) { + fc_weight_max.push_back(y0); + } + for (auto x0 : op_info->GetAttr>("X0_max")) { + fc_input_max.push_back(x0); + } + } for (auto arg_name : arg_names) { auto real_names = op_info->Input(arg_name); for (auto name : real_names) { @@ -662,6 +719,19 @@ class XPUMultiEncoderFuser { op_desc.SetOutput("Output", {out_name}); op_desc.SetAttr("xpu", 1); op_desc.SetAttr("norm_before", norm_before_0); + op_desc.SetAttr("enable_int8", enable_int8); + if (enable_int8) { + CHECK_EQ(fc_precision_, "int8"); + CHECK_EQ(fc_input_max.size(), all_encoders.size() * 6); + CHECK_EQ(fc_weight_max.size(), all_encoders.size() * 6); + op_desc.SetAttr>("FCInputMax", fc_input_max); + // "FCWeightMax" is also stored as "Input" now + op_desc.SetAttr>("FCWeightMax", fc_weight_max); + // only support adaptive_seqlen in int8 quant model + CHECK_EQ(adaptive_seqlen_, true); + } else { + fc_weight_max.resize(arg_map["FCWeight"].size()); + } auto* first_encoder_op_info = multi_encoder_stmt->op_info(); op_desc.SetAttr("head_num", first_encoder_op_info->GetAttr("head_num")); @@ -681,148 +751,18 @@ class XPUMultiEncoderFuser { op_desc.SetAttr("enable_qkv_fusion", enable_qkv_fusion); auto* scope = 
multi_encoder_stmt->op()->scope(); - std::vector fc_weight_max(arg_map["FCWeight"].size()); auto& fc_weight_names = arg_map["FCWeight"]; + CHECK_EQ(fc_weight_max.size(), fc_weight_names.size()); for (size_t i = 0; i < fc_weight_names.size(); ++i) { if (enable_qkv_fusion && (i % 6 == 0)) { - // q/k/v FCWeight fusion - auto* weight_q = scope->FindMutableTensor(fc_weight_names[i]); - auto* weight_k = scope->FindMutableTensor(fc_weight_names[i + 1]); - auto* weight_v = scope->FindMutableTensor(fc_weight_names[i + 2]); - auto weight_q_dims = weight_q->dims(); - auto weight_k_dims = weight_k->dims(); - auto weight_v_dims = weight_v->dims(); - int weight_q_len = weight_q->numel(); - int weight_k_len = weight_k->numel(); - int weight_v_len = weight_v->numel(); - float* weight_q_on_host = weight_q->mutable_data(); - float* weight_k_on_host = weight_k->mutable_data(); - float* weight_v_on_host = weight_v->mutable_data(); - int qkv_len = weight_q_len + weight_k_len + weight_v_len; - int qkv_offset = 0; - CHECK_EQ(weight_q_dims[0], weight_k_dims[0]); - CHECK_EQ(weight_q_dims[0], weight_v_dims[0]); - - // 1. transpose - std::unique_ptr weight_q_trans(new float[weight_q_len]); - std::unique_ptr weight_k_trans(new float[weight_k_len]); - std::unique_ptr weight_v_trans(new float[weight_v_len]); - std::unique_ptr weight_qkv_trans(new float[qkv_len]); - paddle::lite::xpu::math::Transpose(weight_q_on_host, - weight_q_trans.get(), - weight_q_dims[0], - weight_q_dims[1]); - paddle::lite::xpu::math::Transpose(weight_k_on_host, - weight_k_trans.get(), - weight_k_dims[0], - weight_k_dims[1]); - paddle::lite::xpu::math::Transpose(weight_v_on_host, - weight_v_trans.get(), - weight_v_dims[0], - weight_v_dims[1]); - - // 2. concat - memcpy(weight_qkv_trans.get() + qkv_offset, - weight_q_trans.get(), - weight_q_len * sizeof(float)); - qkv_offset += weight_q_len; - memcpy(weight_qkv_trans.get() + qkv_offset, - weight_k_trans.get(), - weight_k_len * sizeof(float)); - qkv_offset += weight_k_len; - memcpy(weight_qkv_trans.get() + qkv_offset, - weight_v_trans.get(), - weight_v_len * sizeof(float)); - qkv_offset += weight_v_len; - CHECK_EQ(qkv_offset, qkv_len); - - weight_q->Resize( - {weight_q_dims[1] + weight_k_dims[1] + weight_v_dims[1], - weight_q_dims[0]}); - - // 3. 
int31 or int16 - float max_f = paddle::lite::xpu::math::FindMaxAbs( - weight_qkv_trans.get(), qkv_len); - fc_weight_max[i] = max_f; - VLOG(3) << "QKV fused FC-" << i << ", weight_max:" << max_f; - if (fc_precision_ == "int31") { - memcpy(weight_q->mutable_data(), - weight_qkv_trans.get(), - qkv_len * sizeof(float)); - } else if (fc_precision_ == "int8") { - std::unique_ptr weight_qkv_trans_int8(new int8_t[qkv_len]); - paddle::lite::xpu::math::ConvertFP32ToInt8( - weight_qkv_trans.get(), - weight_qkv_trans_int8.get(), - max_f, - qkv_len); - memcpy(weight_q->mutable_data(), - weight_qkv_trans_int8.get(), - qkv_len * sizeof(int8_t)); - } else { - std::unique_ptr weight_qkv_trans_int16( - new int16_t[qkv_len]); - paddle::lite::xpu::math::ConvertFP32ToInt16( - weight_qkv_trans.get(), - weight_qkv_trans_int16.get(), - max_f, - qkv_len); - memcpy(weight_q->mutable_data(), - weight_qkv_trans_int16.get(), - qkv_len * sizeof(int16_t)); - } - + // quant q/k/v weight into q + update_weight(scope, fc_weight_names, i, i + 3, + enable_int8, &fc_weight_max); continue; } - - // no q/k/v fusion - auto* weight_t = scope->FindMutableTensor(fc_weight_names[i]); - auto weight_dims = weight_t->dims(); - int weight_len = weight_t->numel(); - float* weight_on_host = weight_t->mutable_data(); - - float max_f = - paddle::lite::xpu::math::FindMaxAbs(weight_on_host, weight_len); - VLOG(3) << "FC-" << i << ", weight_max:" << max_f; - // i ranges from 0 to 6*encoder_num, so we need to do i%6 to get relative - // position in the encoder - if (fc_precision_ == "int31") { - // FCs in encoder use int31 - std::unique_ptr weight_trans_fp32(new float[weight_len]); - paddle::lite::xpu::math::Transpose(weight_on_host, - weight_trans_fp32.get(), - weight_dims[0], - weight_dims[1]); - - memcpy(weight_on_host, - weight_trans_fp32.get(), - weight_len * sizeof(float)); - } else if (fc_precision_ == "int8") { - std::unique_ptr weight_int8(new int8_t[weight_len]); - std::unique_ptr weight_trans_int8(new int8_t[weight_len]); - paddle::lite::xpu::math::ConvertFP32ToInt8( - weight_on_host, weight_int8.get(), max_f, weight_len); - paddle::lite::xpu::math::Transpose(weight_int8.get(), - weight_trans_int8.get(), - weight_dims[0], - weight_dims[1]); - memcpy(weight_on_host, - weight_trans_int8.get(), - weight_len * sizeof(int8_t)); - } else { - std::unique_ptr weight_int16(new int16_t[weight_len]); - std::unique_ptr weight_trans_int16(new int16_t[weight_len]); - paddle::lite::xpu::math::ConvertFP32ToInt16( - weight_on_host, weight_int16.get(), max_f, weight_len); - paddle::lite::xpu::math::Transpose(weight_int16.get(), - weight_trans_int16.get(), - weight_dims[0], - weight_dims[1]); - memcpy(weight_on_host, - weight_trans_int16.get(), - weight_len * sizeof(int16_t)); - } - fc_weight_max[i] = max_f; + // quant weight + update_weight(scope, fc_weight_names, i, i + 1, + enable_int8, &fc_weight_max); } auto& fc_bias_names = arg_map["FCBias"]; @@ -941,6 +881,117 @@ class XPUMultiEncoderFuser { private: std::string fc_precision_; bool adaptive_seqlen_; + void update_weight(Scope* scope, + const std::vector& fc_weight_names, int start, + int end, bool enable_int8, std::vector* fc_weight_max) { + CHECK(start >=0 && end <= fc_weight_names.size()); + CHECK(start < end) << " start:" << start << ", end:" << end; + std::vector weight_tensor_vec(end - start, nullptr); + std::vector weight_dims_vec(end - start); + std::vector weight_len_vec(end - start); + int qkv_len = 0; + int weight_dim1_acc = 0; + for (int i = 0; i < (end - start); ++i) { + 
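      // Gather the FC weights in [start, end): cache each tensor with its
      // dims and element count, accumulate the fused output width
      // (weight_dim1_acc) and the total length (qkv_len), and require a
      // shared dim0 so the transposed weights can be concatenated into one
      // qkv weight below.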
weight_tensor_vec[i] = scope->FindMutableTensor( + fc_weight_names[start + i]); + CHECK(weight_tensor_vec[i] != nullptr); + weight_dims_vec[i] = weight_tensor_vec[i]->dims(); + weight_len_vec[i] = weight_tensor_vec[i]->numel(); + qkv_len += weight_len_vec[i]; + weight_dim1_acc += weight_dims_vec[i][1]; + if (i > 0) { + CHECK_EQ(weight_dims_vec[i][0], weight_dims_vec[i-1][0]); + } + } + + int qkv_offset = 0; + if (enable_int8) { + CHECK_EQ(fc_precision_, "int8"); + CHECK(end <= fc_weight_max->size()); + std::unique_ptr weight_qkv_trans(new int8_t[qkv_len]); + float max_f = (*fc_weight_max)[start]; + for (int i = 0; i < (end - start); ++i) { + // the quanted weight is alreay int8 in quanted model + int8_t* weight_host_ptr = + weight_tensor_vec[i]->mutable_data(); + std::unique_ptr + weight_host_trans(new int8_t[weight_len_vec[i]]); + paddle::lite::xpu::math::Transpose(weight_host_ptr, + weight_host_trans.get(), + weight_dims_vec[i][0], + weight_dims_vec[i][1]); + memcpy(weight_qkv_trans.get() + qkv_offset, + weight_host_trans.get(), + weight_len_vec[i] * sizeof(int8_t)); + qkv_offset += weight_len_vec[i]; + if (i > 0) { + max_f = std::max(max_f, (*fc_weight_max)[start + i]); + VLOG(5) << "start+i:" << start + i << ", weigh_max: " + << (*fc_weight_max)[start + i] << ", max_f:" << max_f; + } + } + CHECK_EQ(qkv_offset, qkv_len); + weight_tensor_vec[0]->Resize({weight_dim1_acc, + weight_dims_vec[0][0]}); + (*fc_weight_max)[start] = max_f; + VLOG(3) << "QKV fused FC-" << start << ", weight_max:" << max_f; + memcpy(weight_tensor_vec[0]->mutable_data(), + weight_qkv_trans.get(), + qkv_len * sizeof(int8_t)); + } else { + std::unique_ptr weight_qkv_trans(new float[qkv_len]); + for (int i = 0; i < (end - start); ++i) { + float* weight_host_ptr = + weight_tensor_vec[i]->mutable_data(); + std::unique_ptr + weight_host_trans(new float[weight_len_vec[i]]); + paddle::lite::xpu::math::Transpose(weight_host_ptr, + weight_host_trans.get(), + weight_dims_vec[i][0], + weight_dims_vec[i][1]); + memcpy(weight_qkv_trans.get() + qkv_offset, + weight_host_trans.get(), + weight_len_vec[i] * sizeof(float)); + qkv_offset += weight_len_vec[i]; + } + CHECK_EQ(qkv_offset, qkv_len); + weight_tensor_vec[0]->Resize({weight_dim1_acc, + weight_dims_vec[0][0]}); + float max_f = paddle::lite::xpu::math::FindMaxAbs( + weight_qkv_trans.get(), qkv_len); + CHECK(start < fc_weight_max->size()); + (*fc_weight_max)[start] = max_f; + VLOG(3) << "QKV fused FC-" << start << ", weight_max:" << max_f; + if (fc_precision_ == "int31") { + memcpy(weight_tensor_vec[0]->mutable_data(), + weight_qkv_trans.get(), + qkv_len * sizeof(float)); + } else if (fc_precision_ == "int8") { + // quant the weight here, not from the quanted-model + std::unique_ptr + weight_qkv_trans_int8(new int8_t[qkv_len]); + paddle::lite::xpu::math::ConvertFP32ToInt8( + weight_qkv_trans.get(), + weight_qkv_trans_int8.get(), + max_f, + qkv_len); + memcpy(weight_tensor_vec[0]->mutable_data(), + weight_qkv_trans_int8.get(), + qkv_len * sizeof(int8_t)); + } else { + std::unique_ptr weight_qkv_trans_int16( + new int16_t[qkv_len]); + paddle::lite::xpu::math::ConvertFP32ToInt16( + weight_qkv_trans.get(), + weight_qkv_trans_int16.get(), + max_f, + qkv_len); + memcpy(weight_tensor_vec[0]->mutable_data(), + weight_qkv_trans_int16.get(), + qkv_len * sizeof(int16_t)); + } + } + } }; } // namespace fusion @@ -986,6 +1037,7 @@ class XPUMultiEncoderFusePass : public ProgramPass { << lite::TargetWrapperXPU::multi_encoder_precision; } adaptive_seqlen = 
lite::TargetWrapperXPU::multi_encoder_adaptive_seqlen; + VLOG(3) << "adaptive_seqlen: " << adaptive_seqlen; #endif for (auto& act_type : act_types) { diff --git a/lite/core/optimizer/mir/static_kernel_pick_pass.cc b/lite/core/optimizer/mir/static_kernel_pick_pass.cc index 43d24275aab..be0b0e009d5 100644 --- a/lite/core/optimizer/mir/static_kernel_pick_pass.cc +++ b/lite/core/optimizer/mir/static_kernel_pick_pass.cc @@ -92,7 +92,9 @@ void StaticKernelPickPass::Apply(const std::unique_ptr& graph) { } else { bool out_type_int8 = true; // Quantized lstm has fp32 output - if (instruct.op_type() == "lstm" || instruct.op_type() == "gru") { + if (instruct.op_type() == "lstm" || instruct.op_type() == "gru" + || instruct.op_type() == "__xpu__multi_encoder" + || instruct.op_type() == "__xpu__fc") { out_type_int8 = false; } // Only if all ops linked to this op output has enable_int8 attr, @@ -105,7 +107,9 @@ void StaticKernelPickPass::Apply(const std::unique_ptr& graph) { CHECK(tmp_op->IsStmt()); auto* tmp_op_info = tmp_op->AsStmt().op_info(); if (!tmp_op_info->HasAttr("enable_int8") || - tmp_op_info->Type() == "lstm" || tmp_op_info->Type() == "gru") { + tmp_op_info->Type() == "lstm" || tmp_op_info->Type() == "gru" + || instruct.op_type() == "__xpu__multi_encoder" + || instruct.op_type() == "__xpu__fc") { out_type_int8 = false; break; } diff --git a/lite/kernels/xpu/__xpu__fc_compute.cc b/lite/kernels/xpu/__xpu__fc_compute.cc index a70f1dfc587..7737c1f8b45 100644 --- a/lite/kernels/xpu/__xpu__fc_compute.cc +++ b/lite/kernels/xpu/__xpu__fc_compute.cc @@ -29,15 +29,37 @@ void XPUFcCompute::PrepareForRun() { auto w_ptr = param.w->data(); auto w_len = param.w->numel(); auto weight_dims = param.w->dims(); + bool quant_int8 = false; + if (param.quant_w_max > 0.f) { + quant_int8 = true; + } // max - w_max = paddle::lite::xpu::math::FindMaxAbs(w_ptr, w_len); - std::vector w_max_v(4, w_max); - weight_max_guard_ = TargetWrapperXPU::MallocScratchPad(4 * sizeof(float)); - XPU_CALL(xpu_memcpy(reinterpret_cast(weight_max_guard_->addr_), - w_max_v.data(), - 4 * sizeof(float), - XPUMemcpyKind::XPU_HOST_TO_DEVICE)); + if (!quant_int8) { + w_max = paddle::lite::xpu::math::FindMaxAbs(w_ptr, w_len); + std::vector w_max_v(4, w_max); + weight_max_guard_ = TargetWrapperXPU::MallocScratchPad(4 * sizeof(float)); + XPU_CALL(xpu_memcpy(reinterpret_cast(weight_max_guard_->addr_), + w_max_v.data(), + 4 * sizeof(float), + XPUMemcpyKind::XPU_HOST_TO_DEVICE)); + input_max_guard_ = TargetWrapperXPU::MallocScratchPad(4 * sizeof(float)); + } // transpose + if (quant_int8) { + std::vector transpose_w_int8(w_len, 0); + paddle::lite::xpu::math::Transpose( + reinterpret_cast(w_ptr), + transpose_w_int8.data(), + weight_dims[0], + weight_dims[1]); + quant_weight_guard_ = + TargetWrapperXPU::MallocScratchPad(w_len * sizeof(int8_t)); + XPU_CALL(xpu_memcpy(reinterpret_cast(quant_weight_guard_->addr_), + transpose_w_int8.data(), + w_len * sizeof(int8_t), + XPUMemcpyKind::XPU_HOST_TO_DEVICE)); + return; + } std::vector transpose_w(w_len, 0); paddle::lite::xpu::math::Transpose( w_ptr, transpose_w.data(), weight_dims[0], weight_dims[1]); @@ -70,7 +92,6 @@ void XPUFcCompute::PrepareForRun() { w_len * sizeof(int8_t), XPUMemcpyKind::XPU_HOST_TO_DEVICE)); } - input_max_guard_ = TargetWrapperXPU::MallocScratchPad(4 * sizeof(float)); } void XPUFcCompute::Run() { @@ -82,11 +103,15 @@ void XPUFcCompute::Run() { int m = in_mat_dims[0]; int k = in_mat_dims[1]; int n = param.w->dims()[1]; + bool quant_int8 = param.quant_w_max > 0.f; - float* output_max = 
param.output_max->mutable_data(TARGET(kXPU)); + float* output_max = quant_int8 + ? nullptr + : param.output_max->mutable_data(TARGET(kXPU)); const auto* bias = param.has_bias ? param.bias->data() : nullptr; const float* input_max = - param.input_max ? param.input_max->data() : nullptr; + quant_int8 ? nullptr + : (param.input_max ? param.input_max->data() : nullptr); xdnn::Activation_t act((xdnn::Activation_t::act_enum)param.act_type); if (param.act_type == 5) { act.leaky_alpha = param.act_param; @@ -150,6 +175,26 @@ void XPUFcCompute::Run() { } else if (param.precision == "int8") { bool x_trans = false; bool w_trans = true; + if (quant_int8) { + int r = xdnn::fc_int8( + ctx.GetRawContext(), + false, + true, + m, + n, + k, + 1.0f, + param.input->data(), + param.quant_input_max, + reinterpret_cast(quant_weight_guard_->addr_), + param.quant_w_max, + 0.f, + param.output->mutable_data(TARGET(kXPU)), + bias, + act); + CHECK_EQ(r, 0); + return; + } int ldx = (x_trans ? m : k); int ldw = (w_trans ? k : n); int ldy = n; @@ -172,7 +217,7 @@ void XPUFcCompute::Run() { 1.0f, /* alpha */ 0.0f, /* beta */ bias, /* bias */ - act /* act_type */); + act); /* act_type */ CHECK_EQ(r, 0); } else { LOG(FATAL) << "Unsupport XPUFC Precision: " << param.precision; diff --git a/lite/kernels/xpu/__xpu__multi_encoder_compute.cc b/lite/kernels/xpu/__xpu__multi_encoder_compute.cc index 08a446bf95b..d1d477d2d06 100644 --- a/lite/kernels/xpu/__xpu__multi_encoder_compute.cc +++ b/lite/kernels/xpu/__xpu__multi_encoder_compute.cc @@ -54,6 +54,10 @@ void XPUMultiEncoderCompute::PrepareForRun() { encoder_param_.n_layers = param.n_layers; encoder_param_.pretrans_b = true; encoder_param_.use_l3 = true; + if (param.input_max.size()) { + encoder_param_.input_max = param.input_max; + encoder_param_.weight_max = param.weight_max; + } encoder_param_.slice_starts = param.slice_starts; encoder_param_.slice_ends = param.slice_ends; encoder_param_.slice_axes = param.slice_axes; @@ -94,7 +98,8 @@ int XPUMultiEncoderCompute::bert_encoder_run() { arg_fc_bias_, /* fc_biass */ arg_ln_scale_, /* ln_scales */ arg_ln_bias_, /* ln_biass */ - param.fc_weight_max->data(), /* fc_weights_max */ + /* fc_weights_max = param.weight_max */ + param.fc_weight_max->data(), encoder_param_); } else { r = xdnn::bert_encoder_transformer_int16( diff --git a/lite/operators/__xpu__fc_op.cc b/lite/operators/__xpu__fc_op.cc index 46fe9d92f2d..2840ad71977 100644 --- a/lite/operators/__xpu__fc_op.cc +++ b/lite/operators/__xpu__fc_op.cc @@ -112,6 +112,15 @@ bool XPUFcOp::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) { if (op_desc.HasAttr("precision")) { param_.precision = op_desc.GetAttr("precision"); } + if (op_desc.HasAttr("enable_int8") + && op_desc.GetAttr("enable_int8")) { + CHECK(param_.precision == "int8") + << "enable_int8 precison:" << param_.precision; + param_.quant_input_max = + 127 * op_desc.GetAttr>("X0_scale")[0]; + param_.quant_w_max = + 127 * op_desc.GetAttr>("Y0_scale")[0]; + } return true; } diff --git a/lite/operators/__xpu__multi_encoder_op.cc b/lite/operators/__xpu__multi_encoder_op.cc index 75a558abfe9..b99c6b969f7 100644 --- a/lite/operators/__xpu__multi_encoder_op.cc +++ b/lite/operators/__xpu__multi_encoder_op.cc @@ -140,6 +140,10 @@ bool XPUMultiEncoderOp::AttachImpl(const cpp::OpDesc& op_desc, param_.enable_qkv_fusion = op_desc.GetAttr("enable_qkv_fusion"); param_.norm_before = op_desc.GetAttr("norm_before"); param_.adaptive_seqlen = op_desc.GetAttr("adaptive_seqlen"); + if (op_desc.HasAttr("enable_int8") && 
op_desc.GetAttr("enable_int8")) { + param_.input_max = op_desc.GetAttr>("FCInputMax"); + param_.weight_max = op_desc.GetAttr>("FCWeightMax"); + } if (op_desc.HasAttr("slice_axes")) { param_.slice_axes = op_desc.GetAttr>("slice_axes"); diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index a73517ab333..bb39339ebce 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -1985,6 +1985,8 @@ struct XPUMultiEncoderParam : ParamBase { std::vector slice_starts{}; std::vector slice_ends{}; std::vector slice_decrease_axis{}; + std::vector input_max{}; + std::vector weight_max{}; int n_layers{}; int head_num{}; int size_per_head{}; @@ -2016,6 +2018,8 @@ struct XPUFcParam : ParamBase { int act_type; float act_param; + float quant_input_max{0.f}; + float quant_w_max{0.f}; std::string precision{}; bool has_bias{false}; int in_num_col_dims{1}; From 7c626ea5c94d2e52d83be6fa7a782d2f83d447ea Mon Sep 17 00:00:00 2001 From: weihaoji Date: Mon, 6 Sep 2021 12:43:50 +0800 Subject: [PATCH 2/4] [XPU] support xpu_kl2_encoder, test=develop, test=xpu --- lite/backends/xpu/target_wrapper.h | 2 + ...xpu__embedding_with_eltwise_add_compute.cc | 111 ++----- ..._xpu__embedding_with_eltwise_add_compute.h | 3 - .../xpu/__xpu__multi_encoder_compute.cc | 284 ++++++++++-------- .../xpu/__xpu__multi_encoder_compute.h | 16 +- lite/kernels/xpu/assign_value_compute.cc | 16 +- lite/kernels/xpu/batch_norm_compute.cc | 27 +- lite/kernels/xpu/layer_norm_compute.cc | 9 +- lite/kernels/xpu/stack_compute.cc | 42 ++- lite/kernels/xpu/stack_compute.h | 6 - .../__xpu__embedding_with_eltwise_add_op.cc | 26 +- 11 files changed, 273 insertions(+), 269 deletions(-) diff --git a/lite/backends/xpu/target_wrapper.h b/lite/backends/xpu/target_wrapper.h index 4d1234153c3..d4bf718ec91 100644 --- a/lite/backends/xpu/target_wrapper.h +++ b/lite/backends/xpu/target_wrapper.h @@ -34,6 +34,8 @@ namespace lite { const int XPU_MAX_LOD_SIZE = 32; // MAX(lod[i + 1] - lod[i]) = 512 const int XPU_MAX_LOD_SEQ_LEN = 512; +// QUANT SCALE NUM == XPU CDNN NUM +const int XPU_QUANT_SCALE_NUM = 6; using TargetWrapperXPU = TargetWrapper; diff --git a/lite/kernels/xpu/__xpu__embedding_with_eltwise_add_compute.cc b/lite/kernels/xpu/__xpu__embedding_with_eltwise_add_compute.cc index 58ee28b6630..afa41682417 100644 --- a/lite/kernels/xpu/__xpu__embedding_with_eltwise_add_compute.cc +++ b/lite/kernels/xpu/__xpu__embedding_with_eltwise_add_compute.cc @@ -21,40 +21,17 @@ namespace lite { namespace kernels { namespace xpu { -bool CheckEmbeddingIds(const int64_t* idx, - size_t idx_len, - size_t table_len, - int64_t padding_idx) { - CHECK_GT(idx_len, 0); - CHECK_GT(table_len, 0); - for (auto i = 0; i < idx_len; i++) { - if (idx[i] >= table_len || idx[i] < 0) { - if (idx[i] != padding_idx) { - return false; - } - } - } - return true; -} - void XPUEmbeddingWithEltwiseAddCompute::PrepareForRun() { auto& param = this->Param(); - - arg_ids_.reserve(param.Ids.size()); - arg_tables_.reserve(param.Tables.size()); + CHECK_GT(param.Tables.size(), 0); + auto embed_dim = param.Tables[0]->dims()[1]; for (auto* table : param.Tables) { auto& table_dims = table->dims(); CHECK_EQ(table_dims.size(), 2); /* shape like [table_len, embed_dim] */ + CHECK_EQ(table_dims[1], embed_dim); table_lens_cpu_.push_back(table_dims[0]); + arg_tables_.push_back(table->data()); } - - size_t lens_size = table_lens_cpu_.size() * sizeof(int); - table_lens_guard_ = TargetWrapperXPU::MallocScratchPad(lens_size); - XPU_CALL(xpu_memcpy(table_lens_guard_->addr_, - 
&table_lens_cpu_[0], - lens_size, - XPU_HOST_TO_DEVICE)); - idx_guard_ = TargetWrapperXPU::MallocScratchPad(32768 * sizeof(int64_t)); } void XPUEmbeddingWithEltwiseAddCompute::Run() { @@ -63,18 +40,16 @@ void XPUEmbeddingWithEltwiseAddCompute::Run() { auto& id_dims = param.Ids[0]->dims(); int idx_len = id_dims[0] * id_dims[1]; int emb_layer_num = param.Ids.size(); - auto& table_dims = param.Tables[0]->dims(); - int embed_dim = table_dims[1]; - for (size_t i = 0; i < param.Tables.size(); ++i) { - arg_tables_[i] = param.Tables[i]->data(); - } + int embed_dim = param.Tables[0]->dims()[1]; + std::vector> int_idx(emb_layer_num, + std::vector(idx_len, 0)); + std::vector> arg_ids_; + if (param.Mask && param.Mask->data()) { auto& mask_dims = param.Mask->dims(); auto batch_size = mask_dims[0]; auto pad_seq_len = mask_dims[1]; param.PadSeqLen->mutable_data()[0] = pad_seq_len; - CHECK_EQ(batch_size, id_dims[0]); - CHECK_EQ(idx_len, param.Mask->numel()); auto* seq_lod = param.SeqLod; seq_lod->Resize({batch_size + 1}); std::vector cpu_seq_lod{0}; @@ -94,59 +69,39 @@ void XPUEmbeddingWithEltwiseAddCompute::Run() { memcpy(seq_lod_ptr, cpu_seq_lod.data(), cpu_seq_lod.size() * sizeof(int)); idx_len = cpu_seq_lod.back(); - idx_guard_->Reserve(emb_layer_num * idx_len * sizeof(int64_t)); - int64_t* idx_xpu_ptr = static_cast(idx_guard_->addr_); - std::vector> idx_remove_pad( - emb_layer_num, std::vector(idx_len, 0)); for (size_t i = 0; i < emb_layer_num; ++i) { auto* idx_pad_ptr = param.Ids[i]->data(); for (auto batch_idx = 0; batch_idx < batch_size; batch_idx++) { - memcpy(&idx_remove_pad[i][cpu_seq_lod[batch_idx]], - idx_pad_ptr + batch_idx * pad_seq_len, - sizeof(int64_t) * - (cpu_seq_lod[batch_idx + 1] - cpu_seq_lod[batch_idx])); + for (auto j = 0; + j < cpu_seq_lod[batch_idx + 1] - cpu_seq_lod[batch_idx]; + j++) { + int_idx[i][cpu_seq_lod[batch_idx] + j] = + static_cast(idx_pad_ptr[batch_idx * pad_seq_len + j]); + } } - CHECK_EQ(CheckEmbeddingIds(&idx_remove_pad[i][0], - idx_len, - param.Tables[i]->dims()[0], - param.padding_idx), - true); - XPU_CALL(xpu_memcpy(idx_xpu_ptr + i * idx_len, - &idx_remove_pad[i][0], - sizeof(int64_t) * idx_len, - XPU_HOST_TO_DEVICE)); - arg_ids_[i] = idx_xpu_ptr + i * idx_len; + arg_ids_.push_back( + xdnn::VectorParam{int_idx[i].data(), idx_len, nullptr}); } param.Out->Resize({1, idx_len, embed_dim}); } else { - idx_guard_->Reserve(emb_layer_num * idx_len * sizeof(int64_t)); - int64_t* idx_xpu_ptr = static_cast(idx_guard_->addr_); - for (size_t i = 0; i < emb_layer_num; ++i) { - CHECK_EQ(idx_len, param.Ids[i]->numel()); - CHECK_EQ(CheckEmbeddingIds(param.Ids[i]->data(), - idx_len, - param.Tables[i]->dims()[0], - param.padding_idx), - true); - XPU_CALL(xpu_memcpy(idx_xpu_ptr + idx_len * i, - param.Ids[i]->data(), - sizeof(int64_t) * idx_len, - XPU_HOST_TO_DEVICE)); - arg_ids_[i] = idx_xpu_ptr + idx_len * i; + for (size_t i = 0; i < emb_layer_num; i++) { + for (size_t j = 0; j < idx_len; j++) { + int_idx[i][j] = static_cast(param.Ids[i]->data()[j]); + } + arg_ids_.push_back( + xdnn::VectorParam{int_idx[i].data(), idx_len, nullptr}); } } - int r = xdnn::embedding_with_ewadd( - ctx.GetRawContext(), /* context */ - embed_dim, /* embed_dim */ - idx_len, /* idx_len */ - emb_layer_num, /* emb_layer_num */ - param.padding_idx, /* padding_idx */ - &arg_tables_[0], /* tables */ - &arg_ids_[0], /* indices */ - static_cast(table_lens_guard_->addr_), /* table_lens */ - nullptr, /* scale_after_emb */ - nullptr, /* scale_after_ewadd */ - param.Out->mutable_data(TARGET(kXPU)) /* top */); + 
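    // The scratch-pad copies of ids and table_lens used by the old
    // embedding_with_ewadd path are gone: multi_embedding_fusion appears to
    // accept host-side ids (the xdnn::VectorParam built from int_idx above,
    // with a null device pointer) plus plain host vectors for the table
    // lengths, scales, and padding indices.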
int r = xdnn::multi_embedding_fusion( + ctx.GetRawContext(), + arg_tables_, /* tables */ + param.Out->mutable_data(TARGET(kXPU)), + arg_ids_, + table_lens_cpu_, + embed_dim, + std::vector(table_lens_cpu_.size(), 1.0f), + std::vector(table_lens_cpu_.size(), + static_cast(param.padding_idx))); CHECK_EQ(r, 0); } diff --git a/lite/kernels/xpu/__xpu__embedding_with_eltwise_add_compute.h b/lite/kernels/xpu/__xpu__embedding_with_eltwise_add_compute.h index 06c84823b1c..fd977d7c60f 100644 --- a/lite/kernels/xpu/__xpu__embedding_with_eltwise_add_compute.h +++ b/lite/kernels/xpu/__xpu__embedding_with_eltwise_add_compute.h @@ -33,10 +33,7 @@ class XPUEmbeddingWithEltwiseAddCompute void Run() override; private: - std::vector arg_ids_; std::vector arg_tables_; - XPUScratchPadGuard table_lens_guard_; - XPUScratchPadGuard idx_guard_; std::vector table_lens_cpu_; }; diff --git a/lite/kernels/xpu/__xpu__multi_encoder_compute.cc b/lite/kernels/xpu/__xpu__multi_encoder_compute.cc index 08a446bf95b..6e589b2794a 100644 --- a/lite/kernels/xpu/__xpu__multi_encoder_compute.cc +++ b/lite/kernels/xpu/__xpu__multi_encoder_compute.cc @@ -29,150 +29,198 @@ static std::vector prepare_weight( return result; } +template +std::vector* XPUMultiEncoderCompute::get_weight() { + LOG(FATAL) << "Invalid Weight Type"; + return nullptr; +} + +template <> +std::vector* XPUMultiEncoderCompute::get_weight() { + return &arg_fc_weight_int16_; +} + +template <> +std::vector* XPUMultiEncoderCompute::get_weight() { + return &arg_fc_weight_int8_; +} + +template <> +std::vector* XPUMultiEncoderCompute::get_weight() { + return &arg_fc_weight_fp32_; +} + void XPUMultiEncoderCompute::PrepareForRun() { auto& param = this->Param(); - - if (param.precision == "int16") { - arg_fc_weight_int16_ = prepare_weight(param.fc_weight); - } else if (param.precision == "int8") { - arg_fc_weight_int8_ = prepare_weight(param.fc_weight); - } else if (param.precision == "int31") { - arg_fc_weight_fp32_ = prepare_weight(param.fc_weight); - } + // prepare bias for (auto* fc_bias : param.fc_bias) { arg_fc_bias_.push_back(fc_bias->data()); } + // prepare scale for (auto* ln_scale : param.ln_scale) { arg_ln_scale_.push_back(ln_scale->data()); } + // prepare ln_bias for (auto* ln_bias : param.ln_bias) { arg_ln_bias_.push_back(ln_bias->data()); } - - encoder_param_.head_num = param.head_num; - encoder_param_.size_per_head = param.size_per_head; - encoder_param_.n_layers = param.n_layers; - encoder_param_.pretrans_b = true; - encoder_param_.use_l3 = true; - encoder_param_.slice_starts = param.slice_starts; - encoder_param_.slice_ends = param.slice_ends; - encoder_param_.slice_axes = param.slice_axes; - if (param.act_type == "relu") { - encoder_param_.act_type = xdnn::Activation_t::RELU; - } else if (param.act_type == "gelu") { - encoder_param_.act_type = xdnn::Activation_t::GELU; - } -} - -int XPUMultiEncoderCompute::bert_encoder_run() { - auto& param = this->Param(); - auto& ctx = this->ctx_->As(); - ctx.GetRawContext()->qkv_fusion = param.enable_qkv_fusion; - - int r = -1; - if (param.precision == "int31") { - r = xdnn::bert_encoder_transformer_int31( - ctx.GetRawContext(), /* context */ - param.input->data(), /* from_tensor */ - param.input->data(), /* to_tensor */ - param.mask->data(), /* att_mask */ - param.output->mutable_data(TARGET(kXPU)), /* output */ - arg_fc_weight_fp32_, /* fc_weights */ - arg_fc_bias_, /* fc_biass */ - arg_ln_scale_, /* ln_scales */ - arg_ln_bias_, /* ln_biass */ - param.fc_weight_max->data(), /* fc_weights_max */ - encoder_param_); + // 
prepare weights + if (param.precision == "int16") { + arg_fc_weight_int16_ = prepare_weight(param.fc_weight); } else if (param.precision == "int8") { - r = xdnn::bert_encoder_transformer_int8( - ctx.GetRawContext(), /* context */ - param.input->data(), /* from_tensor */ - param.input->data(), /* to_tensor */ - param.mask->data(), /* att_mask */ - param.output->mutable_data(TARGET(kXPU)), /* output */ - arg_fc_weight_int8_, /* fc_weights */ - arg_fc_bias_, /* fc_biass */ - arg_ln_scale_, /* ln_scales */ - arg_ln_bias_, /* ln_biass */ - param.fc_weight_max->data(), /* fc_weights_max */ - encoder_param_); - } else { - r = xdnn::bert_encoder_transformer_int16( - ctx.GetRawContext(), /* context */ - param.input->data(), /* from_tensor */ - param.input->data(), /* to_tensor */ - param.mask->data(), /* att_mask */ - param.output->mutable_data(TARGET(kXPU)), /* output */ - arg_fc_weight_int16_, /* fc_weights */ - arg_fc_bias_, /* fc_biass */ - arg_ln_scale_, /* ln_scales */ - arg_ln_bias_, /* ln_biass */ - param.fc_weight_max->data(), /* fc_weights_max */ - encoder_param_); + arg_fc_weight_int8_ = prepare_weight(param.fc_weight); + } else if (param.precision == "int31") { + arg_fc_weight_fp32_ = prepare_weight(param.fc_weight); } - return r; + // prepare weight_max + weight_max_guard_ = TargetWrapperXPU::MallocScratchPad( + param.fc_weight_max->numel() * lite::XPU_QUANT_SCALE_NUM * sizeof(float)); + float* weight_max_ptr = reinterpret_cast(weight_max_guard_->addr_); + for (int i = 0; i < param.fc_weight_max->numel(); i++) { + float* cur_weight_max_ptr = weight_max_ptr + i * lite::XPU_QUANT_SCALE_NUM; + std::vector cpu_max(lite::XPU_QUANT_SCALE_NUM, + param.fc_weight_max->data()[i]); + lite::TargetWrapperXPU::MemcpySync( + cur_weight_max_ptr, + cpu_max.data(), + sizeof(float) * lite::XPU_QUANT_SCALE_NUM, + IoDirection::HtoD); + fc_weight_max_.push_back(cur_weight_max_ptr); + } + // prepare act_type + if (param.act_type == "gelu") { + qkv_act = xdnn::Activation_t::GELU; + } else if (param.act_type != "relu") { + CHECK(false) << "Invalid QKV Activation Type: " << param.act_type; + } + // prepare with sice + if ((param.slice_starts.size() > 0 && param.slice_starts[0] == 0) && + (param.slice_ends.size() > 0 && param.slice_ends[0] == 1) && + (param.slice_axes.size() > 0 && param.slice_axes[0] == 1)) { + slice_idx = 0; + } + // prepare input_cast and output_cast guard_ + cast_in_guard_ = TargetWrapperXPU::MallocScratchPad(4 * 1024 * 1024); + cast_out_guard_ = TargetWrapperXPU::MallocScratchPad(4 * 1024 * 1024); } -int XPUMultiEncoderCompute::transformer_encoder_run() { +template +void XPUMultiEncoderCompute::run_encoder(const T* in, T* out) { auto& param = this->Param(); auto& ctx = this->ctx_->As(); - ctx.GetRawContext()->qkv_fusion = param.enable_qkv_fusion; - int r = -1; - if (param.precision == "int31") { - LOG(FATAL) << "Not support int31 at now"; - } else if (param.precision == "int8") { - LOG(FATAL) << "Not support int8 at now"; + xdnn::VectorParam query_lod; + if (param.SeqLod && param.SeqLod->data()) { + // vsl + query_lod = {param.SeqLod->data(), + static_cast(param.SeqLod->numel()), + nullptr}; + xdnn::QKVAttnParam qkv_attn_param(query_lod, /* lod */ + param.head_num, + param.size_per_head, + qkv_act, + slice_idx, + true /* qkv fusion */); + + int r = xdnn::transformer_encoder( + ctx.GetRawContext(), + in, + *(XPUMultiEncoderCompute::get_weight()), + out, + nullptr, + fc_weight_max_, + nullptr, + arg_fc_bias_, + arg_ln_scale_, + arg_ln_bias_, + qkv_attn_param); + CHECK_EQ(r, 0); } else { - r = 
xdnn::transformer_encoder_int16( - ctx.GetRawContext(), /* context */ - param.input->data(), /* from_tensor */ - param.input->data(), /* to_tensor */ - param.mask->data(), /* att_mask */ - param.output->mutable_data(TARGET(kXPU)), /* output */ - arg_fc_weight_int16_, /* fc_weights */ - arg_fc_bias_, /* fc_biass */ - arg_ln_scale_, /* ln_scales */ - arg_ln_bias_, /* ln_biass */ - param.fc_weight_max->data(), /* fc_weights_max */ - encoder_param_); + // no vsl + int batch = static_cast(param.input->dims()[0]); + int max_seqlen = static_cast(param.input->dims()[1]); + xdnn::QKVAttnParam qkv_attn_param( + batch, + max_seqlen, + param.head_num, + param.size_per_head, + {batch, param.head_num, max_seqlen, max_seqlen}, + qkv_act, + slice_idx, + true); + int r = xdnn::transformer_encoder( + ctx.GetRawContext(), + in, + *(XPUMultiEncoderCompute::get_weight()), + out, + nullptr, + fc_weight_max_, + nullptr, + arg_fc_bias_, + arg_ln_scale_, + arg_ln_bias_, + qkv_attn_param, + param.mask->data()); + CHECK_EQ(r, 0); } - return r; } void XPUMultiEncoderCompute::Run() { auto& param = this->Param(); - std::vector mask_shape = param.mask->dims().Vectorize(); - encoder_param_.mask_shape = - std::vector(mask_shape.begin(), mask_shape.end()); - encoder_param_.slice_starts = param.slice_starts; - encoder_param_.slice_ends = param.slice_ends; - encoder_param_.slice_axes = param.slice_axes; - const bool norm_before_ = param.norm_before; - if (param.SeqLod && param.SeqLod->data()) { - auto& ctx = this->ctx_->As(); - ctx.GetRawContext()->batch_split_type = -1; // disable auto split batch - encoder_param_.seq_lod.resize(param.SeqLod->numel()); - memcpy(encoder_param_.seq_lod.data(), - param.SeqLod->data(), - sizeof(int) * param.SeqLod->numel()); - encoder_param_.adaptive_seqlen = true; - encoder_param_.batch_size = param.SeqLod->numel() - 1; - encoder_param_.from_seq_len = param.PadSeqLen->data()[0]; - encoder_param_.to_seq_len = param.PadSeqLen->data()[0]; - } else { - encoder_param_.adaptive_seqlen = false; - encoder_param_.batch_size = param.input->dims()[0]; - encoder_param_.from_seq_len = param.input->dims()[1]; - encoder_param_.to_seq_len = param.input->dims()[1]; - } - int r = -1; - if (norm_before_) { - r = transformer_encoder_run(); + auto& ctx = this->ctx_->As(); + const float* in = param.input->data(); + float* out = param.output->mutable_data(TARGET(kXPU)); + if (ctx.GetRawContext()->dev().type() == xdnn::kXPU1) { + if (param.precision == "int8") { + run_encoder(in, out); + } else if (param.precision == "int16") { + run_encoder(in, out); + } else if (param.precision == "int31") { + run_encoder(in, out); + } else { + CHECK(false); + } } else { - r = bert_encoder_run(); + cast_in_guard_->Reserve(param.input->numel() * sizeof(float)); + cast_out_guard_->Reserve(param.output->numel() * sizeof(float)); + if (param.precision == "int8") { + int r = xdnn::cast_v2( + ctx.GetRawContext(), + in, + reinterpret_cast(cast_in_guard_->addr_), + param.input->numel()); + CHECK_EQ(r, 0); + run_encoder( + reinterpret_cast(cast_in_guard_->addr_), + reinterpret_cast(cast_out_guard_->addr_)); + r = xdnn::cast_v2( + ctx.GetRawContext(), + reinterpret_cast(cast_out_guard_->addr_), + out, + param.output->numel()); + CHECK_EQ(r, 0); + } else if (param.precision == "int16") { + int r = xdnn::cast_v2( + ctx.GetRawContext(), + in, + reinterpret_cast(cast_in_guard_->addr_), + param.input->numel()); + CHECK_EQ(r, 0); + run_encoder( + reinterpret_cast(cast_in_guard_->addr_), + reinterpret_cast(cast_out_guard_->addr_)); + r = 
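          // cast the float16 encoder output back to fp32 for the kernel's
          // fp32 output tensor (the mirror of the fp32 -> fp16 input cast
          // above)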
xdnn::cast_v2( + ctx.GetRawContext(), + reinterpret_cast(cast_out_guard_->addr_), + out, + param.output->numel()); + CHECK_EQ(r, 0); + } else if (param.precision == "int31") { + run_encoder(in, out); + } else { + CHECK(false); + } } - CHECK_EQ(r, 0); } } // namespace xpu diff --git a/lite/kernels/xpu/__xpu__multi_encoder_compute.h b/lite/kernels/xpu/__xpu__multi_encoder_compute.h index fddd810aae4..52c8768a68f 100644 --- a/lite/kernels/xpu/__xpu__multi_encoder_compute.h +++ b/lite/kernels/xpu/__xpu__multi_encoder_compute.h @@ -34,16 +34,24 @@ class XPUMultiEncoderCompute virtual void Run(); private: - int bert_encoder_run(); - int transformer_encoder_run(); - std::vector arg_fc_weight_int8_; std::vector arg_fc_weight_int16_; std::vector arg_fc_weight_fp32_; std::vector arg_fc_bias_; std::vector arg_ln_scale_; std::vector arg_ln_bias_; - xdnn::EncoderParam encoder_param_; + std::vector fc_weight_max_; + XPUScratchPadGuard weight_max_guard_; + XPUScratchPadGuard cast_in_guard_; + XPUScratchPadGuard cast_out_guard_; + xdnn::Activation_t qkv_act = xdnn::Activation_t::RELU; + int slice_idx = -1; + + template + std::vector *get_weight(); + + template + void run_encoder(const T *in, T *out); }; } // namespace xpu diff --git a/lite/kernels/xpu/assign_value_compute.cc b/lite/kernels/xpu/assign_value_compute.cc index 59a22e5de3f..4af03ff6177 100644 --- a/lite/kernels/xpu/assign_value_compute.cc +++ b/lite/kernels/xpu/assign_value_compute.cc @@ -30,16 +30,16 @@ void AssignValueCompute::Run() { CHECK_GT(param.shape.size(), 0UL); if (dtype == static_cast(lite::core::FluidType::INT32)) { auto* out = param.Out->mutable_data(TARGET(kXPU)); - XPU_CALL(xpu_memcpy(out, - int32_values.data(), - sizeof(int) * int32_values.size(), - XPUMemcpyKind::XPU_HOST_TO_DEVICE)); + lite::TargetWrapperXPU::MemcpySync(out, + int32_values.data(), + sizeof(int) * int32_values.size(), + IoDirection::HtoD); } else if (dtype == static_cast(lite::core::FluidType::FP32)) { auto* out = param.Out->mutable_data(TARGET(kXPU)); - XPU_CALL(xpu_memcpy(out, - fp32_values.data(), - sizeof(float) * fp32_values.size(), - XPUMemcpyKind::XPU_HOST_TO_DEVICE)); + lite::TargetWrapperXPU::MemcpySync(out, + fp32_values.data(), + sizeof(float) * fp32_values.size(), + IoDirection::HtoD); } else { LOG(FATAL) << "Unsupported dtype for assign_value_op:" << dtype; } diff --git a/lite/kernels/xpu/batch_norm_compute.cc b/lite/kernels/xpu/batch_norm_compute.cc index 77048528fc9..4920883b4d8 100644 --- a/lite/kernels/xpu/batch_norm_compute.cc +++ b/lite/kernels/xpu/batch_norm_compute.cc @@ -32,20 +32,21 @@ void BatchNormCompute::Run() { for (int i = 0; i < x_dims.size(); i++) { x_shape[i] = x_dims[i]; } - int r = - xdnn::batch_norm_infer_forward(ctx.GetRawContext(), - epsilon, - x_shape[0], - x_shape[1], - x_shape[2], - x_shape[3], - param.x->data(), - param.y->mutable_data(TARGET(kXPU)), - param.scale->data(), - param.bias->data(), - param.mean->data(), - param.variance->data()); + xdnn::batch_norm_infer(ctx.GetRawContext(), + param.x->data(), + param.y->mutable_data(TARGET(kXPU)), + x_shape[0], + x_shape[1], + x_shape[2], + x_shape[3], + epsilon, + param.scale->data(), + param.bias->data(), + param.mean->data(), + param.variance->data(), + true); + CHECK_EQ(r, 0); } diff --git a/lite/kernels/xpu/layer_norm_compute.cc b/lite/kernels/xpu/layer_norm_compute.cc index 538ad849d93..47fd4e9986e 100644 --- a/lite/kernels/xpu/layer_norm_compute.cc +++ b/lite/kernels/xpu/layer_norm_compute.cc @@ -31,13 +31,16 @@ void LayerNormCompute::Run() { float epsilon = 
param.epsilon; int r = xdnn::layer_norm(ctx.GetRawContext(), /* context */ - matrix_dim[0], /* m */ - matrix_dim[1], /* n */ param.X->data(), /* in */ param.Y->mutable_data(TARGET(kXPU)), /* out */ + matrix_dim[0], /* m */ + matrix_dim[1], /* n */ + epsilon, /* epsilon */ param.Scale->data(), /* scale */ param.Bias->data(), /* bias */ - epsilon /* epsilon */); + nullptr, + nullptr); + CHECK_EQ(r, 0); } diff --git a/lite/kernels/xpu/stack_compute.cc b/lite/kernels/xpu/stack_compute.cc index 99c9e17bc2e..bd929349389 100644 --- a/lite/kernels/xpu/stack_compute.cc +++ b/lite/kernels/xpu/stack_compute.cc @@ -21,15 +21,6 @@ namespace lite { namespace kernels { namespace xpu { -void StackCompute::PrepareForRun() { - auto& param = this->Param(); - - int n = param.X.size(); - x_ptr_guard_ = - TargetWrapperXPU::MallocScratchPad(n * 8 /* sizeof(__global__ float*) */); - x_ptr_cpu_.reserve(n); -} - void StackCompute::Run() { auto& param = this->Param(); auto& ctx = this->ctx_->As(); @@ -37,25 +28,26 @@ void StackCompute::Run() { int n = param.X.size(); auto x_dims = param.X[0]->dims(); int axis = param.axis; - // XXX(miaotianxiang): +1? - if (axis < 0) axis += (x_dims.size() + 1); - auto matrix = x_dims.Flatten2D(axis); - int height = matrix[0]; - int width = matrix[1]; + if (axis < 0) { + axis += (x_dims.size() + 1); + } + std::vector x_shape; + auto y_dim = param.Out->dims(); + for (int i = 0; i < y_dim.size(); i++) { + x_shape.push_back(y_dim[i]); + } + x_shape[axis] = 1; + std::vector> xdims_list(n, x_shape); + std::vector x_list(n, nullptr); for (int i = 0; i < n; ++i) { - x_ptr_cpu_[i] = param.X[i]->data(); + x_list[i] = param.X[i]->data(); } - XPU_CALL(xpu_memcpy( - x_ptr_guard_->addr_, &x_ptr_cpu_[0], n * 8, XPU_HOST_TO_DEVICE)); - - int r = xdnn::stack_forward( - ctx.GetRawContext(), /* context */ - height, /* height */ - width, /* width */ - n, /* n */ - x_ptr_guard_->addr_, /* x_ptr */ - param.Out->mutable_data(TARGET(kXPU)) /* out */); + int r = xdnn::concat(ctx.GetRawContext(), + x_list, + param.Out->mutable_data(TARGET(kXPU)), + xdims_list, + axis); CHECK_EQ(r, 0); } diff --git a/lite/kernels/xpu/stack_compute.h b/lite/kernels/xpu/stack_compute.h index 7618e2a147b..00f01b9466a 100644 --- a/lite/kernels/xpu/stack_compute.h +++ b/lite/kernels/xpu/stack_compute.h @@ -27,15 +27,9 @@ class StackCompute : public KernelLite { public: using param_t = operators::StackParam; - virtual void PrepareForRun(); - virtual void Run(); virtual ~StackCompute() = default; - - private: - XPUScratchPadGuard x_ptr_guard_; - std::vector x_ptr_cpu_; }; } // namespace xpu diff --git a/lite/operators/__xpu__embedding_with_eltwise_add_op.cc b/lite/operators/__xpu__embedding_with_eltwise_add_op.cc index 435cabe2c9a..13819d61046 100644 --- a/lite/operators/__xpu__embedding_with_eltwise_add_op.cc +++ b/lite/operators/__xpu__embedding_with_eltwise_add_op.cc @@ -21,18 +21,22 @@ namespace lite { namespace operators { bool XPUEmbeddingWithEltwiseAddOp::CheckShape() const { - CHECK_OR_FALSE(param_.Ids.size() == param_.Tables.size()); - - auto& id_dims = param_.Ids[0]->dims(); - auto& table_dims = param_.Tables[0]->dims(); - - int id_rank = id_dims.size(); - - CHECK_EQ_OR_FALSE(table_dims.size(), 2); - // id_dims must be [batch_size, seq_len] or [batch_size, seq_len, 1] - CHECK(id_rank == 2 || id_rank == 3) << "unsupported id_rank: " << id_rank; - + CHECK_EQ(param_.Ids.size(), param_.Tables.size()); + auto ids_dim = param_.Ids[0]->dims(); + auto id_rank = ids_dim.size(); + CHECK(id_rank == 2 || (id_rank == 3 && ids_dim[2] 
== 1)) + << "unsupported id_rank: " << id_rank; + for (size_t i = 1; i < param_.Ids.size(); i++) { + CHECK_EQ(id_rank, param_.Ids[i]->dims().size()); + for (size_t j = 0; j < id_rank; j++) { + CHECK_EQ(ids_dim[j], param_.Ids[i]->dims()[j]); + } + } if (param_.Mask != nullptr) { + CHECK_EQ(id_rank, param_.Mask->dims().size()); + for (size_t j = 0; j < id_rank; j++) { + CHECK_EQ(ids_dim[j], param_.Mask->dims()[j]); + } CHECK(param_.SeqLod != nullptr); CHECK(param_.PadSeqLen != nullptr); } From cfc4d7f756766b7847d5aaf9045eedc4fa1d0836 Mon Sep 17 00:00:00 2001 From: weihaoji Date: Mon, 13 Sep 2021 17:12:46 +0800 Subject: [PATCH 3/4] [XPU] refine encoder, test=develop, test=xpu --- lite/api/paddle_api.h | 2 +- lite/backends/xpu/target_wrapper.cc | 5 +++-- lite/backends/xpu/target_wrapper.h | 9 +++++++++ .../mir/fusion/__xpu__multi_encoder_fuse_pass.cc | 2 +- lite/kernels/xpu/__xpu__multi_encoder_compute.cc | 6 ++---- 5 files changed, 16 insertions(+), 8 deletions(-) diff --git a/lite/api/paddle_api.h b/lite/api/paddle_api.h index 67ea3dca2b9..c20aedb9e9c 100644 --- a/lite/api/paddle_api.h +++ b/lite/api/paddle_api.h @@ -397,7 +397,7 @@ class LITE_API CxxConfig : public ConfigBase { // XPU only, set the size of the workspace memory from L3 cache for the // current thread. // **DEPRECATED**, use set_xpu_l3_cache_method() in the future - void set_xpu_workspace_l3_size_per_thread(int l3_size = 0xfffc00); + void set_xpu_workspace_l3_size_per_thread(int l3_size = 0x4000000); void set_xpu_l3_cache_method(size_t l3_size, bool locked = false); void set_xpu_conv_autotune(bool autotune = true, diff --git a/lite/backends/xpu/target_wrapper.cc b/lite/backends/xpu/target_wrapper.cc index 69ebdce903d..35fd386d77e 100644 --- a/lite/backends/xpu/target_wrapper.cc +++ b/lite/backends/xpu/target_wrapper.cc @@ -13,6 +13,7 @@ // limitations under the License. 
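// L3 sizing in this file now starts from "as much as possible"
// (std::numeric_limits<size_t>::max()) instead of the old fixed 0xfffc00
// default, and is clamped to the device's real capacity when the context is
// first created; condensed, the clamp added to target_wrapper.h in this
// patch reads:
//   uint64_t max_l3_size = 0;
//   XPU_CALL(xpu_device_get_attr(
//       &max_l3_size, XPUDeviceAttr(XPUATTR_MEM_L3_CAPACITY), devid));
//   if (local_l3_size > max_l3_size) local_l3_size = max_l3_size;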
#include "lite/backends/xpu/target_wrapper.h" +#include #include "lite/utils/macros.h" namespace paddle { @@ -55,7 +56,6 @@ void TargetWrapperXPU::MemcpySync(void* dst, XPU_CALL(xpu_memcpy(dst, src, size, XPU_HOST_TO_DEVICE)); break; case IoDirection::DtoH: - // TODO(weihaoji): remove xpu_wait XPU_CALL(xpu_wait()); XPU_CALL(xpu_memcpy(dst, src, size, XPU_DEVICE_TO_HOST)); break; @@ -120,7 +120,8 @@ void TargetWrapperXPU::FreeL3Cache() { LITE_THREAD_LOCAL std::string TargetWrapperXPU::multi_encoder_precision; // NOLINT -LITE_THREAD_LOCAL size_t TargetWrapperXPU::local_l3_size{0xfffc00}; +LITE_THREAD_LOCAL size_t TargetWrapperXPU::local_l3_size{ + std::numeric_limits::max()}; LITE_THREAD_LOCAL bool TargetWrapperXPU::conv_autotune{false}; LITE_THREAD_LOCAL std::string TargetWrapperXPU::conv_autotune_file; // NOLINT LITE_THREAD_LOCAL xdnn::Context* TargetWrapperXPU::tls_raw_ctx_{nullptr}; diff --git a/lite/backends/xpu/target_wrapper.h b/lite/backends/xpu/target_wrapper.h index d4bf718ec91..ff29e48983a 100644 --- a/lite/backends/xpu/target_wrapper.h +++ b/lite/backends/xpu/target_wrapper.h @@ -83,6 +83,15 @@ class TargetWrapper { tls_raw_ctx_->_xpu1_conv_selector.set_autotune_file( conv_autotune_file.c_str()); } + int devid = -1; + uint64_t max_l3_size = 0; + XPU_CALL(xpu_current_device(&devid)); + XPU_CALL(xpu_device_get_attr( + &max_l3_size, XPUDeviceAttr(XPUATTR_MEM_L3_CAPACITY), devid)); + if (local_l3_size > max_l3_size) { + local_l3_size = max_l3_size; + } + CHECK_LE(shared_l3_size, max_l3_size); } return tls_raw_ctx_; } diff --git a/lite/core/optimizer/mir/fusion/__xpu__multi_encoder_fuse_pass.cc b/lite/core/optimizer/mir/fusion/__xpu__multi_encoder_fuse_pass.cc index 80c7acda9d1..c4d48945e9b 100644 --- a/lite/core/optimizer/mir/fusion/__xpu__multi_encoder_fuse_pass.cc +++ b/lite/core/optimizer/mir/fusion/__xpu__multi_encoder_fuse_pass.cc @@ -956,7 +956,7 @@ class XPUMultiEncoderFusePass : public ProgramPass { std::vector matmul_types{"matmul", "matmul_v2"}; std::vector mul_types{"mul", "matmul"}; std::vector with_q_scales{true, false}; - std::vector norm_befores{true, false}; + std::vector norm_befores{false}; std::string fc_precision; bool adaptive_seqlen = false; diff --git a/lite/kernels/xpu/__xpu__multi_encoder_compute.cc b/lite/kernels/xpu/__xpu__multi_encoder_compute.cc index 6e589b2794a..e9ff62edb30 100644 --- a/lite/kernels/xpu/__xpu__multi_encoder_compute.cc +++ b/lite/kernels/xpu/__xpu__multi_encoder_compute.cc @@ -127,9 +127,8 @@ void XPUMultiEncoderCompute::run_encoder(const T* in, T* out) { in, *(XPUMultiEncoderCompute::get_weight()), out, - nullptr, + {}, fc_weight_max_, - nullptr, arg_fc_bias_, arg_ln_scale_, arg_ln_bias_, @@ -153,9 +152,8 @@ void XPUMultiEncoderCompute::run_encoder(const T* in, T* out) { in, *(XPUMultiEncoderCompute::get_weight()), out, - nullptr, + {}, fc_weight_max_, - nullptr, arg_fc_bias_, arg_ln_scale_, arg_ln_bias_, From 730b07714e1971feb595b71b85094c9e887bf7da Mon Sep 17 00:00:00 2001 From: weihaoji Date: Wed, 15 Sep 2021 20:25:39 +0800 Subject: [PATCH 4/4] [XPU] kl2 encoder mask bugfix, test=develop, test=xpu --- .../xpu/__xpu__multi_encoder_compute.cc | 26 ++++++++++++------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/lite/kernels/xpu/__xpu__multi_encoder_compute.cc b/lite/kernels/xpu/__xpu__multi_encoder_compute.cc index e9ff62edb30..90d64258c82 100644 --- a/lite/kernels/xpu/__xpu__multi_encoder_compute.cc +++ b/lite/kernels/xpu/__xpu__multi_encoder_compute.cc @@ -115,12 +115,14 @@ void 
XPUMultiEncoderCompute::run_encoder(const T* in, T* out) {
   auto& param = this->Param<param_t>();
   auto& ctx = this->ctx_->As<XPUContext>();
   xdnn::VectorParam<int> query_lod;
   if (param.SeqLod && param.SeqLod->data<int>()) {
     // vsl
     query_lod = {param.SeqLod->data<int>(),
                  static_cast<int>(param.SeqLod->numel()),
                  nullptr};
+    int max_pad_seqlen = slice_idx == -1 ? param.SeqLod->data<int>()[0] : -1;
     xdnn::QKVAttnParam qkv_attn_param(query_lod, /* lod */
                                       param.head_num,
                                       param.size_per_head,
                                       qkv_act,
                                       slice_idx,
-                                      true /* qkv fusion */);
+                                      true /* qkv fusion */,
+                                      max_pad_seqlen);
     int r = xdnn::transformer_encoder(
         ctx.GetRawContext(),
         in,
         *(XPUMultiEncoderCompute::get_weight()),
         out,
         nullptr,
         fc_weight_max_,
         nullptr,
         arg_fc_bias_,
         arg_ln_scale_,
         arg_ln_bias_,
         qkv_attn_param);
     CHECK_EQ(r, 0);
   } else {
     // no vsl
     int batch = static_cast<int>(param.input->dims()[0]);
     int max_seqlen = static_cast<int>(param.input->dims()[1]);
-    xdnn::QKVAttnParam qkv_attn_param(
-        batch,
-        max_seqlen,
-        param.head_num,
-        param.size_per_head,
-        {batch, param.head_num, max_seqlen, max_seqlen},
-        qkv_act,
-        slice_idx,
-        true);
+
+    std::vector<int64_t> mask_shape = param.mask->dims().Vectorize();
+    std::vector<int> encoder_mask_shape =
+        std::vector<int>(mask_shape.begin(), mask_shape.end());
+
+    xdnn::QKVAttnParam qkv_attn_param(batch,
+                                      max_seqlen,
+                                      param.head_num,
+                                      param.size_per_head,
+                                      encoder_mask_shape,
+                                      qkv_act,
+                                      slice_idx,
+                                      true);
     int r = xdnn::transformer_encoder(
         ctx.GetRawContext(),
         in,