From 1f93dbd97c5e32261d6fd6aef6a000ba64eb97f9 Mon Sep 17 00:00:00 2001 From: sprouteer Date: Fri, 24 Feb 2023 16:14:13 +0800 Subject: [PATCH 1/5] fix fuse_attention support old_quant_format test=develop --- .../fusion/transformer_attention_fuse_pass.cc | 2 - .../mir/fusion/transformer_attention_fuser.cc | 27 +++++----- lite/kernels/arm/fused_attention_compute.cc | 37 +++++++------ lite/operators/fused_attention_op.cc | 52 +++++++++++++------ lite/operators/op_params.h | 4 +- 5 files changed, 71 insertions(+), 51 deletions(-) diff --git a/lite/core/optimizer/mir/fusion/transformer_attention_fuse_pass.cc b/lite/core/optimizer/mir/fusion/transformer_attention_fuse_pass.cc index e90464d43d2..0c3e1fde54b 100644 --- a/lite/core/optimizer/mir/fusion/transformer_attention_fuse_pass.cc +++ b/lite/core/optimizer/mir/fusion/transformer_attention_fuse_pass.cc @@ -34,8 +34,6 @@ void TransformerAttentionFusePass::Apply( } if ((has_int8)) { fuser(graph.get()); - } else { - return; } } diff --git a/lite/core/optimizer/mir/fusion/transformer_attention_fuser.cc b/lite/core/optimizer/mir/fusion/transformer_attention_fuser.cc index a928a837f7c..53b9b6be921 100644 --- a/lite/core/optimizer/mir/fusion/transformer_attention_fuser.cc +++ b/lite/core/optimizer/mir/fusion/transformer_attention_fuser.cc @@ -68,8 +68,7 @@ void TransformerAttentionFuser::BuildPattern() { auto res = (trans_x == false && trans_y == false); return res; }; - auto* input0 = - VarNode("input0")->assert_is_op_input("fc", "Input")->AsInput(); + auto* input = VarNode("input")->assert_is_op_input("fc", "Input")->AsInput(); // fc auto* fc0_w = VarNode("fc0_w")->assert_is_op_input("fc", "W"); auto* fc0_bias = VarNode("fc0_bias")->assert_is_op_input("fc", "Bias"); @@ -135,8 +134,9 @@ void TransformerAttentionFuser::BuildPattern() { VarNode("matmul0_out")->assert_is_op_output("matmul_v2", "Out"); // elementwise_add - auto* input1 = - VarNode("input1")->assert_is_op_input("elementwise_add", "Y")->AsInput(); + auto* residual = VarNode("residual") + ->assert_is_op_input("elementwise_add", "Y") + ->AsInput(); auto* add = OpNode("add", "elementwise_add"); auto* add0_out = VarNode("add0_out")->assert_is_op_output("elementwise_add", "Out"); @@ -152,9 +152,9 @@ void TransformerAttentionFuser::BuildPattern() { auto* Out = VarNode("Out"); - std::vector fc0_inputs{input0, fc0_w, fc0_bias}; - std::vector fc1_inputs{input0, fc1_w, fc1_bias}; - std::vector fc2_inputs{input0, fc2_w, fc2_bias}; + std::vector fc0_inputs{input, fc0_w, fc0_bias}; + std::vector fc1_inputs{input, fc1_w, fc1_bias}; + std::vector fc2_inputs{input, fc2_w, fc2_bias}; fc0_inputs >> *fc0 >> *fc0_out >> *reshape0 >> *reshape0_out >> *transpose0 >> *transpose0_out >> *scale0 >> *scale0_out; fc1_inputs >> *fc1 >> *fc1_out >> *reshape1 >> *reshape1_out >> *transpose1 >> @@ -170,7 +170,7 @@ void TransformerAttentionFuser::BuildPattern() { std::vector matmul0_inputs{scale0_out, transpose1_out}; matmul0_inputs >> *matmul0 >> *matmul0_out; - std::vector add0_inputs{matmul0_out, input1}; + std::vector add0_inputs{matmul0_out, residual}; add0_inputs >> *add >> *add0_out >> *softmax0 >> *softmax0_out; std::vector matmul1_inputs{softmax0_out, transpose2_out}; @@ -312,8 +312,8 @@ void TransformerAttentionFuser::InsertNewNode(SSAGraph* graph, auto* scope = fc->scope(); // set input - op_desc.SetInput("Input0", {matched.at("input0")->arg()->name}); - op_desc.SetInput("Input1", {matched.at("input1")->arg()->name}); + op_desc.SetInput("Input", {matched.at("input")->arg()->name}); + op_desc.SetInput("Residual", {matched.at("residual")->arg()->name}); // fc auto fc0_op_desc = matched.at("fc0")->stmt()->op_info(); @@ -393,6 +393,7 @@ void TransformerAttentionFuser::InsertNewNode(SSAGraph* graph, scale0_scale, bias0_dims[0]); op_desc.SetAttr>("fc0_scale", fuse_scales); + op_desc.SetAttr>("Input0_scale", fc0_scale_x); // fc 1 auto matmul0_scale_x = matmul0_op_desc->GetAttr>("X0_scale"); @@ -424,6 +425,8 @@ void TransformerAttentionFuser::InsertNewNode(SSAGraph* graph, op_desc.SetInput("Bias", {matched.at("fc0_bias")->arg()->name}); op_desc.SetAttr("op_type", fc0_op_desc->GetAttr("op_type")); + op_desc.SetAttr("in_num_col_dims", + fc0_op_desc->GetAttr("in_num_col_dims")); // reshape auto reshape_op_desc = matched.at("reshape0")->stmt()->op_info(); op_desc.SetAttr>( @@ -440,8 +443,8 @@ void TransformerAttentionFuser::InsertNewNode(SSAGraph* graph, fused_attention_op->Attach(op_desc, scope); auto* new_op_node = graph->GraphCreateInstructNode(fused_attention_op, valid_places); - DirectedLink(matched.at("input0"), new_op_node); - DirectedLink(matched.at("input1"), new_op_node); + DirectedLink(matched.at("input"), new_op_node); + DirectedLink(matched.at("residual"), new_op_node); DirectedLink(matched.at("fc0_w"), new_op_node); DirectedLink(matched.at("fc0_bias"), new_op_node); DirectedLink(new_op_node, matched.at("Out")); diff --git a/lite/kernels/arm/fused_attention_compute.cc b/lite/kernels/arm/fused_attention_compute.cc index 70731384dbf..e72aa8948aa 100644 --- a/lite/kernels/arm/fused_attention_compute.cc +++ b/lite/kernels/arm/fused_attention_compute.cc @@ -70,7 +70,7 @@ void FusedAttentionCompute::PrepareForRun() { template void FusedAttentionCompute::ReInitWhenNeeded() { auto& param = this->template Param(); - auto input0_dims = param.input0->dims(); + auto input_dims = param.input->dims(); // fc act_param_.has_active = false; @@ -86,29 +86,29 @@ void FusedAttentionCompute::ReInitWhenNeeded() { int in_num_col_dims = param.in_num_col_dims; std::string op_type = param.op_type; if (op_type == "matmul" || op_type == "matmul_v2") { - in_num_col_dims = input0_dims.size() - 1; + in_num_col_dims = input_dims.size() - 1; } - fc_m_ = input0_dims.Slice(0, in_num_col_dims).production(); - fc_k_ = input0_dims.Slice(in_num_col_dims, input0_dims.size()).production(); + fc_m_ = input_dims.Slice(0, in_num_col_dims).production(); + fc_k_ = input_dims.Slice(in_num_col_dims, input_dims.size()).production(); CHECK_EQ(fc_k_, w_dims[0]); fc_n_ = w_dims[1]; - fc_dims_ = DDim(std::vector{input0_dims[0], fc_m_, fc_n_}); + fc_dims_ = DDim(std::vector{input_dims[0], fc_m_, fc_n_}); // reshape - reshape_shape_.push_back(input0_dims[0]); + reshape_shape_.push_back(input_dims[0]); reshape_shape_.push_back(param.reshape_shape[2]); - reshape_shape_.push_back(input0_dims[1]); + reshape_shape_.push_back(input_dims[1]); reshape_shape_.push_back(param.reshape_shape[3]); // transpose transpose_out_dim_ = DDim(std::vector{ - input0_dims[0], reshape_shape_[1], fc_m_, reshape_shape_[3]}); + input_dims[0], reshape_shape_[1], fc_m_, reshape_shape_[3]}); // fc1 fc1_m_ = transpose_out_dim_[2]; fc1_n_ = transpose_out_dim_[2]; fc1_k_ = transpose_out_dim_[3]; fc1_out_dim_ = DDim(std::vector{ - input0_dims[0], transpose_out_dim_[1], fc1_m_, fc1_n_}); + input_dims[0], transpose_out_dim_[1], fc1_m_, fc1_n_}); // softmax softmax_out_dim_ = fc1_out_dim_; @@ -135,10 +135,10 @@ template <> void FusedAttentionCompute::Run() { auto& param = this->Param(); auto& ctx = this->ctx_->template As(); - auto* input0_data = param.input0->data(); - auto* input1_data = param.input1->data(); + auto* input_data = param.input->data(); + auto* residual_data = param.residual->data(); auto* o_data = param.output->mutable_data(); - auto input0_dims = param.input0->dims(); + auto input_dims = param.input->dims(); auto out_dims = param.output->dims(); // fc + dequant_scale, bias, quant_scale @@ -153,7 +153,7 @@ void FusedAttentionCompute::Run() { fc_m_, fc_n_, fc_k_, - input0_data, + input_data, w_data, fc_out, b_data, @@ -164,7 +164,7 @@ void FusedAttentionCompute::Run() { &ctx); // transpose2 fuse reshape2 DDim trans_dims = DDim(std::vector{ - input0_dims[0], reshape_shape_[1], fc_m_, reshape_shape_[3] * 3}); + input_dims[0], reshape_shape_[1], fc_m_, reshape_shape_[3] * 3}); Tensor trans_t; trans_t.Resize(trans_dims); trans_t.mutable_data(); @@ -174,7 +174,7 @@ void FusedAttentionCompute::Run() { auto* v1 = v0 + stride; auto* v2 = v1 + stride; TransposeCompute_1to3( - fc_out, v0, v1, v2, input0_dims[0], fc_m_, fc_n_, transpose_out_dim_[3]); + fc_out, v0, v1, v2, input_dims[0], fc_m_, fc_n_, transpose_out_dim_[3]); // fc -> out fp32 Tensor fc1_t; fc1_t.Resize(fc1_out_dim_); @@ -183,7 +183,7 @@ void FusedAttentionCompute::Run() { int x_inner = fc1_m_ * fc1_k_; int y_inner = fc1_k_ * fc1_n_; int out_inner = fc1_m_ * fc1_n_; - auto* fc1_b_data = param.input1->data(); + auto* fc1_b_data = param.residual->data(); for (size_t i = 0; i < transpose_out_dim_[1]; ++i) { lite::arm::math::gemm_s8(false, true, @@ -263,9 +263,8 @@ typedef paddle::lite::kernels::arm::FusedAttentionCompute FusedAttentionCompute_Int8; REGISTER_LITE_KERNEL( fused_attention, kARM, kInt8, kNCHW, FusedAttentionCompute_Int8, def) - .BindInput("Input0", - {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))}) - .BindInput("Input1", + .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))}) + .BindInput("Residual", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))}) .BindInput("W", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))}) .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))}) diff --git a/lite/operators/fused_attention_op.cc b/lite/operators/fused_attention_op.cc index 2c34d42d019..040b755a4ac 100644 --- a/lite/operators/fused_attention_op.cc +++ b/lite/operators/fused_attention_op.cc @@ -20,12 +20,12 @@ namespace lite { namespace operators { bool FusedAttentionOpLite::CheckShape() const { - CHECK_OR_FALSE(param_.input0); - CHECK_OR_FALSE(param_.input1); + CHECK_OR_FALSE(param_.input); + CHECK_OR_FALSE(param_.residual); CHECK_OR_FALSE(param_.output); CHECK_OR_FALSE(param_.fc_w); - const auto input_dims = param_.input0->dims(); + const auto input_dims = param_.input->dims(); const auto w_dims = param_.fc_w->dims(); CHECK_EQ_OR_FALSE(w_dims.size(), 2UL); int64_t w_dims_1 = param_.padding_weights ? w_dims[1] - 4 : w_dims[1]; @@ -58,9 +58,28 @@ static bool CheckPositive(const DDim &dims) { } return true; } + bool FusedAttentionOpLite::InferShape() { - lite::DDim x_dims = param_.input0->dims(); - const DDim::value_type input_size = x_dims.production(); + lite::DDim x_dims = param_.input->dims(); + + // infer fc + int in_num_col_dims = param_.in_num_col_dims; + std::string op_type = param_.op_type; + const auto &w_dims = param_.fc_w->dims(); + int64_t w_dims_1 = w_dims[1] / 3; + + if (op_type == "matmul" || op_type == "matmul_v2") { + in_num_col_dims = x_dims.size() - 1; + } + DDim::value_type fc_output_size = 1; + std::vector fc_output_dims(in_num_col_dims + 1); + for (int i = 0; i < in_num_col_dims; ++i) { + fc_output_dims[i] = x_dims[i]; + fc_output_size *= fc_output_dims[i]; + } + fc_output_dims[in_num_col_dims] = w_dims_1; + fc_output_size *= fc_output_dims[in_num_col_dims]; + std::vector shape = param_.reshape_shape; std::vector reshape_output_dims(shape.size()); DDim::value_type capacity = 1; @@ -73,7 +92,7 @@ bool FusedAttentionOpLite::InferShape() { << "Only one input dimension of Attr(shape) can be unknown."; unk_dim_idx = i; } else if (shape[i] == copy_dim_val) { - CHECK_LT(i, x_dims.size()) + CHECK_LT(i, fc_output_dims.size()) << "The index of dimension to copy from input shape must be less " "than the size of input shape."; } else { @@ -82,24 +101,24 @@ bool FusedAttentionOpLite::InferShape() { } DDim::value_type output_dim_i = - shape[i] ? static_cast(shape[i]) : x_dims[i]; + shape[i] ? static_cast(shape[i]) : fc_output_dims[i]; reshape_output_dims[i] = output_dim_i; capacity *= output_dim_i; } if (unk_dim_idx != -1) { - if (CheckPositive(x_dims)) { + if (CheckPositive(lite::DDim(fc_output_dims))) { // input_size < 0 and is un-determinate in compile time, skip the check, // for example, input_dims = [-1, 8, 1, 1], shape = [-1, 3, 8], // capacity = -24, input_size = -8, output_shape[0] = 0 // the following check will fail. - reshape_output_dims[unk_dim_idx] = -input_size / capacity; - CHECK_EQ(reshape_output_dims[unk_dim_idx] * capacity, -input_size) + reshape_output_dims[unk_dim_idx] = -fc_output_size / capacity; + CHECK_EQ(reshape_output_dims[unk_dim_idx] * capacity, -fc_output_size) << "Invalid shape is given."; } else { reshape_output_dims[unk_dim_idx] = -1; } } else { - CHECK_EQ(capacity, input_size) << "Invalid shape is given."; + CHECK_EQ(capacity, fc_output_size) << "Invalid shape is given."; } lite::DDim out_dims = lite::DDim({reshape_output_dims[0], reshape_output_dims[2], @@ -108,19 +127,19 @@ bool FusedAttentionOpLite::InferShape() { param_.output->Resize(out_dims); // share LoD - param_.output->set_lod(param_.input0->lod()); + param_.output->set_lod(param_.input->lod()); return true; } bool FusedAttentionOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { - auto input0 = op_desc.Input("Input0").front(); - auto input1 = op_desc.Input("Input1").front(); + auto input = op_desc.Input("Input").front(); + auto residual = op_desc.Input("Residual").front(); auto fc_w = op_desc.Input("W").front(); auto output = op_desc.Output("Out").front(); - param_.input0 = scope->FindVar(input0)->GetMutable(); - param_.input1 = scope->FindVar(input1)->GetMutable(); + param_.input = scope->FindVar(input)->GetMutable(); + param_.residual = scope->FindVar(residual)->GetMutable(); param_.fc_w = scope->FindVar(fc_w)->GetMutable(); param_.output = scope->FindVar(output)->GetMutable(); @@ -135,6 +154,7 @@ bool FusedAttentionOpLite::AttachImpl(const cpp::OpDesc &op_desc, } } } + param_.in_num_col_dims = op_desc.GetAttr("in_num_col_dims"); param_.reshape_shape = op_desc.GetAttr>("reshape_shape"); param_.softmax_axis = op_desc.GetAttr("softmax_axis"); diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index 052ad9e9ad0..abd9227f8c0 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -118,8 +118,8 @@ struct FcParam : ParamBase { }; struct FusedAttentionParam : ParamBase { - lite::Tensor* input0{nullptr}; - lite::Tensor* input1{nullptr}; + lite::Tensor* input{nullptr}; + lite::Tensor* residual{nullptr}; lite::Tensor* fc_w{nullptr}; lite::Tensor* fc_bias{nullptr}; lite::Tensor* output{nullptr}; From be115eae60332798f9c8ab266199e4320bd1065f Mon Sep 17 00:00:00 2001 From: sprouteer Date: Wed, 1 Mar 2023 21:29:10 +0800 Subject: [PATCH 2/5] fix dropout support no_mask, reshape2 transpose2 xshape test=develop --- .../identity_dropout_eliminate_pass.cc | 19 ++++-- .../fusion/transformer_attention_fuse_pass.cc | 13 +++- .../mir/fusion/transformer_attention_fuser.cc | 59 ++++++++++++------- .../mir/fusion/transformer_attention_fuser.h | 6 ++ 4 files changed, 68 insertions(+), 29 deletions(-) diff --git a/lite/core/optimizer/mir/elimination/identity_dropout_eliminate_pass.cc b/lite/core/optimizer/mir/elimination/identity_dropout_eliminate_pass.cc index cb6b38e4100..4862bccc8c5 100644 --- a/lite/core/optimizer/mir/elimination/identity_dropout_eliminate_pass.cc +++ b/lite/core/optimizer/mir/elimination/identity_dropout_eliminate_pass.cc @@ -24,6 +24,8 @@ namespace { class Eliminator : public FuseBase { public: + explicit Eliminator(bool has_mask) : has_mask_(has_mask) {} + static bool DropoutIsTest(const Node* x) { if (x && x->IsStmt()) { auto* op_info = x->stmt()->op_info(); @@ -51,15 +53,18 @@ class Eliminator : public FuseBase { ->assert_op_attr( "dropout_implementation", "upscale_in_train"); auto* out = VarNode("out")->assert_is_op_output("dropout", "Out"); - auto* mask = VarNode("mask")->assert_is_op_output("dropout", "Mask"); + PMNode* mask = nullptr; + if (has_mask_) { + mask = VarNode("mask")->assert_is_op_output("dropout", "Mask"); + } *pre_op >> *x >> *dropout_op >> *out; - *dropout_op >> *mask; + if (mask) *dropout_op >> *mask; // The pre_op will be eliminated, and a new output-updated op will insert. x->AsIntermediate(); // x is pre_op's output, need to update dropout_op->AsIntermediate(); - mask->AsIntermediate(); + if (mask) mask->AsIntermediate(); } private: @@ -73,6 +78,7 @@ class Eliminator : public FuseBase { IR_NODE_LINK_TO(matched.at("preop"), matched.at("out")); } + bool has_mask_; }; } // namespace @@ -80,8 +86,11 @@ class Eliminator : public FuseBase { class IdentityDropoutEliminatePass : public ProgramPass { public: void Apply(const std::unique_ptr& graph) override { - Eliminator eliminator; - eliminator(graph.get()); + std::vector has_masks = {false, true}; + for (auto has_mask : has_masks) { + Eliminator eliminator(has_mask); + eliminator(graph.get()); + } } }; diff --git a/lite/core/optimizer/mir/fusion/transformer_attention_fuse_pass.cc b/lite/core/optimizer/mir/fusion/transformer_attention_fuse_pass.cc index 0c3e1fde54b..517d11b5651 100644 --- a/lite/core/optimizer/mir/fusion/transformer_attention_fuse_pass.cc +++ b/lite/core/optimizer/mir/fusion/transformer_attention_fuse_pass.cc @@ -25,15 +25,22 @@ namespace mir { void TransformerAttentionFusePass::Apply( const std::unique_ptr& graph) { - fusion::TransformerAttentionFuser fuser; bool has_int8 = false; for (auto& place : graph->valid_places()) { if (place.precision == PRECISION(kInt8)) { has_int8 = true; } } - if ((has_int8)) { - fuser(graph.get()); + std::vector reshape_has_xshapes = {false, true}; + std::vector transpose_has_xshapes = {false, true}; + for (auto reshape_has_xshape : reshape_has_xshapes) { + for (auto transpose_has_xshape : transpose_has_xshapes) { + fusion::TransformerAttentionFuser fuser(reshape_has_xshape, + transpose_has_xshape); + if ((has_int8)) { + fuser(graph.get()); + } + } } } diff --git a/lite/core/optimizer/mir/fusion/transformer_attention_fuser.cc b/lite/core/optimizer/mir/fusion/transformer_attention_fuser.cc index 53b9b6be921..44822e6f2ea 100644 --- a/lite/core/optimizer/mir/fusion/transformer_attention_fuser.cc +++ b/lite/core/optimizer/mir/fusion/transformer_attention_fuser.cc @@ -98,30 +98,39 @@ void TransformerAttentionFuser::BuildPattern() { auto* reshape2_out = VarNode("reshape2_out")->assert_is_op_output("reshape2", "Out"); - auto* xshape0 = VarNode("xshape0")->assert_is_op_output("reshape2", "XShape"); - auto* xshape1 = VarNode("xshape1")->assert_is_op_output("reshape2", "XShape"); - auto* xshape2 = VarNode("xshape2")->assert_is_op_output("reshape2", "XShape"); + PMNode* xshape0 = nullptr; + PMNode* xshape1 = nullptr; + PMNode* xshape2 = nullptr; + if (reshape_has_xshape_) { + xshape0 = VarNode("xshape0")->assert_is_op_output("reshape2", "XShape"); + xshape1 = VarNode("xshape1")->assert_is_op_output("reshape2", "XShape"); + xshape2 = VarNode("xshape2")->assert_is_op_output("reshape2", "XShape"); + } + // transpose2 auto* transpose0 = OpNode("transpose0", "transpose2") ->assert_op_attr("axis", std::vector{0, 2, 1, 3}); auto* transpose0_out = VarNode("transpose0_out")->assert_is_op_output("transpose2", "Out"); - auto* xshape3 = - VarNode("xshape3")->assert_is_op_output("transpose2", "XShape"); auto* transpose1 = OpNode("transpose1", "transpose2") ->assert_op_attr("axis", std::vector{0, 2, 1, 3}); auto* transpose1_out = VarNode("transpose1_out")->assert_is_op_output("transpose2", "Out"); - auto* xshape4 = - VarNode("xshape4")->assert_is_op_output("transpose2", "XShape"); auto* transpose2 = OpNode("transpose2", "transpose2") ->assert_op_attr("axis", std::vector{0, 2, 1, 3}); auto* transpose2_out = VarNode("transpose2_out")->assert_is_op_output("transpose2", "Out"); - auto* xshape5 = - VarNode("xshape5")->assert_is_op_output("transpose2", "XShape"); + + PMNode* xshape3 = nullptr; + PMNode* xshape4 = nullptr; + PMNode* xshape5 = nullptr; + if (transpose_has_xshape_) { + xshape3 = VarNode("xshape3")->assert_is_op_output("transpose2", "XShape"); + xshape4 = VarNode("xshape4")->assert_is_op_output("transpose2", "XShape"); + xshape5 = VarNode("xshape5")->assert_is_op_output("transpose2", "XShape"); + } // scale auto* scale0 = OpNode("scale0", "scale"); @@ -161,12 +170,16 @@ void TransformerAttentionFuser::BuildPattern() { *transpose1_out; fc2_inputs >> *fc2 >> *fc2_out >> *reshape2 >> *reshape2_out >> *transpose2 >> *transpose2_out; - *reshape0 >> *xshape0; - *reshape1 >> *xshape1; - *reshape2 >> *xshape2; - *transpose0 >> *xshape3; - *transpose1 >> *xshape4; - *transpose2 >> *xshape5; + if (reshape_has_xshape_) { + *reshape0 >> *xshape0; + *reshape1 >> *xshape1; + *reshape2 >> *xshape2; + } + if (transpose_has_xshape_) { + *transpose0 >> *xshape3; + *transpose1 >> *xshape4; + *transpose2 >> *xshape5; + } std::vector matmul0_inputs{scale0_out, transpose1_out}; matmul0_inputs >> *matmul0 >> *matmul0_out; @@ -176,12 +189,16 @@ void TransformerAttentionFuser::BuildPattern() { std::vector matmul1_inputs{softmax0_out, transpose2_out}; matmul1_inputs >> *matmul1 >> *Out; - xshape0->AsIntermediate(); - xshape1->AsIntermediate(); - xshape2->AsIntermediate(); - xshape3->AsIntermediate(); - xshape4->AsIntermediate(); - xshape5->AsIntermediate(); + if (reshape_has_xshape_) { + xshape0->AsIntermediate(); + xshape1->AsIntermediate(); + xshape2->AsIntermediate(); + } + if (transpose_has_xshape_) { + xshape3->AsIntermediate(); + xshape4->AsIntermediate(); + xshape5->AsIntermediate(); + } fc0->AsIntermediate(); fc0_out->AsIntermediate(); reshape0->AsIntermediate(); diff --git a/lite/core/optimizer/mir/fusion/transformer_attention_fuser.h b/lite/core/optimizer/mir/fusion/transformer_attention_fuser.h index c944192c0ec..95ecf7ad207 100644 --- a/lite/core/optimizer/mir/fusion/transformer_attention_fuser.h +++ b/lite/core/optimizer/mir/fusion/transformer_attention_fuser.h @@ -25,10 +25,16 @@ namespace fusion { class TransformerAttentionFuser : public FuseBase { public: + explicit TransformerAttentionFuser(bool reshape_has_xshape, + bool transpose_has_xshape) + : reshape_has_xshape_(reshape_has_xshape), + transpose_has_xshape_(transpose_has_xshape) {} void BuildPattern() override; void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override; private: + bool reshape_has_xshape_; + bool transpose_has_xshape_; }; } // namespace fusion From 7cebdf8186fa9e4956eee8e17e023cf2d651d181 Mon Sep 17 00:00:00 2001 From: sprouteer Date: Mon, 6 Mar 2023 19:24:48 +0800 Subject: [PATCH 3/5] fix conv gemm_sve bug test=develop --- lite/backends/arm/math/conv_impl.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/lite/backends/arm/math/conv_impl.cc b/lite/backends/arm/math/conv_impl.cc index 3a292302371..85151ddee99 100644 --- a/lite/backends/arm/math/conv_impl.cc +++ b/lite/backends/arm/math/conv_impl.cc @@ -735,7 +735,6 @@ void conv1x1s1_gemm_int8(const int8_t* i_data, n, k, flag_bias, - GemmMBias, false, scale_group, act_param, @@ -1605,7 +1604,6 @@ void conv_im2col_gemm_int8(const int8_t* i_data, n, k, flag_bias, - GemmMBias, false, scale_group, act_param, From 1089c5633851eded6a90e3b5ebfbf83c11099d1f Mon Sep 17 00:00:00 2001 From: sprouteer Date: Mon, 6 Mar 2023 21:22:37 +0800 Subject: [PATCH 4/5] fix fuse_transformer support matmul op test=develop --- .../fusion/transformer_attention_fuse_pass.cc | 11 +++-- .../mir/fusion/transformer_attention_fuser.cc | 40 +++++++++++++------ .../mir/fusion/transformer_attention_fuser.h | 7 +++- 3 files changed, 39 insertions(+), 19 deletions(-) diff --git a/lite/core/optimizer/mir/fusion/transformer_attention_fuse_pass.cc b/lite/core/optimizer/mir/fusion/transformer_attention_fuse_pass.cc index 517d11b5651..d3f5fb44849 100644 --- a/lite/core/optimizer/mir/fusion/transformer_attention_fuse_pass.cc +++ b/lite/core/optimizer/mir/fusion/transformer_attention_fuse_pass.cc @@ -33,12 +33,15 @@ void TransformerAttentionFusePass::Apply( } std::vector reshape_has_xshapes = {false, true}; std::vector transpose_has_xshapes = {false, true}; + std::vector mul_types = {"matmul", "matmul_v2"}; for (auto reshape_has_xshape : reshape_has_xshapes) { for (auto transpose_has_xshape : transpose_has_xshapes) { - fusion::TransformerAttentionFuser fuser(reshape_has_xshape, - transpose_has_xshape); - if ((has_int8)) { - fuser(graph.get()); + for (auto mul_type : mul_types) { + fusion::TransformerAttentionFuser fuser( + reshape_has_xshape, transpose_has_xshape, mul_type); + if ((has_int8)) { + fuser(graph.get()); + } } } } diff --git a/lite/core/optimizer/mir/fusion/transformer_attention_fuser.cc b/lite/core/optimizer/mir/fusion/transformer_attention_fuser.cc index 44822e6f2ea..94d841e943c 100644 --- a/lite/core/optimizer/mir/fusion/transformer_attention_fuser.cc +++ b/lite/core/optimizer/mir/fusion/transformer_attention_fuser.cc @@ -38,7 +38,7 @@ namespace fusion { * scale | | * \ / | * \ / | -* matmul_v2 / +* matmul_v2/matmul / * \ / * \ / * elementwise_add / @@ -47,7 +47,7 @@ namespace fusion { * softmax / * \ / * \ / -* matmul_v2 +* matmul_v2/matmul * | * | * output @@ -56,15 +56,29 @@ namespace fusion { void TransformerAttentionFuser::BuildPattern() { auto matmul0_attr_teller = [](const Node* node) -> bool { auto op_desc = *const_cast(node)->stmt()->op_info(); - auto trans_x = op_desc.GetAttr("trans_x"); - auto trans_y = op_desc.GetAttr("trans_y"); + bool trans_x; + bool trans_y; + if (op_desc.Type() == "matmul") { + trans_x = op_desc.GetAttr("transpose_X"); + trans_y = op_desc.GetAttr("transpose_Y"); + } else { + trans_x = op_desc.GetAttr("trans_x"); + trans_y = op_desc.GetAttr("trans_y"); + } auto res = (trans_x == false && trans_y == true); return res; }; auto matmul1_attr_teller = [](const Node* node) -> bool { auto op_desc = *const_cast(node)->stmt()->op_info(); - auto trans_x = op_desc.GetAttr("trans_x"); - auto trans_y = op_desc.GetAttr("trans_y"); + bool trans_x; + bool trans_y; + if (op_desc.Type() == "matmul") { + trans_x = op_desc.GetAttr("transpose_X"); + trans_y = op_desc.GetAttr("transpose_Y"); + } else { + trans_x = op_desc.GetAttr("trans_x"); + trans_y = op_desc.GetAttr("trans_y"); + } auto res = (trans_x == false && trans_y == false); return res; }; @@ -136,11 +150,11 @@ void TransformerAttentionFuser::BuildPattern() { auto* scale0 = OpNode("scale0", "scale"); auto* scale0_out = VarNode("scale0_out")->assert_is_op_output("scale", "Out"); - // matmul_v2 - auto* matmul0 = OpNode("matmul0", "matmul_v2") - ->assert_node_satisfied(matmul0_attr_teller); + // matmul + auto* matmul0 = + OpNode("matmul0", mul_type_)->assert_node_satisfied(matmul0_attr_teller); auto* matmul0_out = - VarNode("matmul0_out")->assert_is_op_output("matmul_v2", "Out"); + VarNode("matmul0_out")->assert_is_op_output(mul_type_, "Out"); // elementwise_add auto* residual = VarNode("residual") @@ -155,9 +169,9 @@ void TransformerAttentionFuser::BuildPattern() { auto* softmax0_out = VarNode("softmax0_out")->assert_is_op_output("softmax", "Out"); - // matmul_v2 - auto* matmul1 = OpNode("matmul1", "matmul_v2") - ->assert_node_satisfied(matmul1_attr_teller); + // matmul + auto* matmul1 = + OpNode("matmul1", mul_type_)->assert_node_satisfied(matmul1_attr_teller); auto* Out = VarNode("Out"); diff --git a/lite/core/optimizer/mir/fusion/transformer_attention_fuser.h b/lite/core/optimizer/mir/fusion/transformer_attention_fuser.h index 95ecf7ad207..d833909fb6b 100644 --- a/lite/core/optimizer/mir/fusion/transformer_attention_fuser.h +++ b/lite/core/optimizer/mir/fusion/transformer_attention_fuser.h @@ -26,15 +26,18 @@ namespace fusion { class TransformerAttentionFuser : public FuseBase { public: explicit TransformerAttentionFuser(bool reshape_has_xshape, - bool transpose_has_xshape) + bool transpose_has_xshape, + std::string mul_type) : reshape_has_xshape_(reshape_has_xshape), - transpose_has_xshape_(transpose_has_xshape) {} + transpose_has_xshape_(transpose_has_xshape), + mul_type_(mul_type) {} void BuildPattern() override; void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override; private: bool reshape_has_xshape_; bool transpose_has_xshape_; + std::string mul_type_; }; } // namespace fusion From 7bc2da2525815d87ccdb9b00c290e51f550155f9 Mon Sep 17 00:00:00 2001 From: sprouteer Date: Tue, 7 Mar 2023 13:31:27 +0800 Subject: [PATCH 5/5] fix dropout bug test=develop --- .../identity_dropout_eliminate_pass.cc | 19 +++----- .../fusion/transformer_attention_fuse_pass.cc | 13 +++--- .../mir/fusion/transformer_attention_fuser.cc | 44 ++++++++++++++----- .../mir/fusion/transformer_attention_fuser.h | 3 ++ lite/core/optimizer/optimizer.cc | 2 +- 5 files changed, 50 insertions(+), 31 deletions(-) diff --git a/lite/core/optimizer/mir/elimination/identity_dropout_eliminate_pass.cc b/lite/core/optimizer/mir/elimination/identity_dropout_eliminate_pass.cc index 4862bccc8c5..cb6b38e4100 100644 --- a/lite/core/optimizer/mir/elimination/identity_dropout_eliminate_pass.cc +++ b/lite/core/optimizer/mir/elimination/identity_dropout_eliminate_pass.cc @@ -24,8 +24,6 @@ namespace { class Eliminator : public FuseBase { public: - explicit Eliminator(bool has_mask) : has_mask_(has_mask) {} - static bool DropoutIsTest(const Node* x) { if (x && x->IsStmt()) { auto* op_info = x->stmt()->op_info(); @@ -53,18 +51,15 @@ class Eliminator : public FuseBase { ->assert_op_attr( "dropout_implementation", "upscale_in_train"); auto* out = VarNode("out")->assert_is_op_output("dropout", "Out"); - PMNode* mask = nullptr; - if (has_mask_) { - mask = VarNode("mask")->assert_is_op_output("dropout", "Mask"); - } + auto* mask = VarNode("mask")->assert_is_op_output("dropout", "Mask"); *pre_op >> *x >> *dropout_op >> *out; - if (mask) *dropout_op >> *mask; + *dropout_op >> *mask; // The pre_op will be eliminated, and a new output-updated op will insert. x->AsIntermediate(); // x is pre_op's output, need to update dropout_op->AsIntermediate(); - if (mask) mask->AsIntermediate(); + mask->AsIntermediate(); } private: @@ -78,7 +73,6 @@ class Eliminator : public FuseBase { IR_NODE_LINK_TO(matched.at("preop"), matched.at("out")); } - bool has_mask_; }; } // namespace @@ -86,11 +80,8 @@ class Eliminator : public FuseBase { class IdentityDropoutEliminatePass : public ProgramPass { public: void Apply(const std::unique_ptr& graph) override { - std::vector has_masks = {false, true}; - for (auto has_mask : has_masks) { - Eliminator eliminator(has_mask); - eliminator(graph.get()); - } + Eliminator eliminator; + eliminator(graph.get()); } }; diff --git a/lite/core/optimizer/mir/fusion/transformer_attention_fuse_pass.cc b/lite/core/optimizer/mir/fusion/transformer_attention_fuse_pass.cc index d3f5fb44849..d7da59b6869 100644 --- a/lite/core/optimizer/mir/fusion/transformer_attention_fuse_pass.cc +++ b/lite/core/optimizer/mir/fusion/transformer_attention_fuse_pass.cc @@ -33,14 +33,17 @@ void TransformerAttentionFusePass::Apply( } std::vector reshape_has_xshapes = {false, true}; std::vector transpose_has_xshapes = {false, true}; + std::vector dropout_masks = {false, true}; std::vector mul_types = {"matmul", "matmul_v2"}; for (auto reshape_has_xshape : reshape_has_xshapes) { for (auto transpose_has_xshape : transpose_has_xshapes) { - for (auto mul_type : mul_types) { - fusion::TransformerAttentionFuser fuser( - reshape_has_xshape, transpose_has_xshape, mul_type); - if ((has_int8)) { - fuser(graph.get()); + for (auto dropout_mask : dropout_masks) { + for (auto mul_type : mul_types) { + fusion::TransformerAttentionFuser fuser( + reshape_has_xshape, transpose_has_xshape, dropout_mask, mul_type); + if ((has_int8)) { + fuser(graph.get()); + } } } } diff --git a/lite/core/optimizer/mir/fusion/transformer_attention_fuser.cc b/lite/core/optimizer/mir/fusion/transformer_attention_fuser.cc index 94d841e943c..c5906c9666f 100644 --- a/lite/core/optimizer/mir/fusion/transformer_attention_fuser.cc +++ b/lite/core/optimizer/mir/fusion/transformer_attention_fuser.cc @@ -38,15 +38,18 @@ namespace fusion { * scale | | * \ / | * \ / | -* matmul_v2/matmul / -* \ / -* \ / -* elementwise_add / -* \ / -* \ / -* softmax / -* \ / -* \ / +* matmul_v2/matmul | +* \ / +* \ / +* elementwise_add / +* \ / +* \ / +* softmax / +* | / +* | / +* dropout / +* \ / +* \ / * matmul_v2/matmul * | * | @@ -169,6 +172,15 @@ void TransformerAttentionFuser::BuildPattern() { auto* softmax0_out = VarNode("softmax0_out")->assert_is_op_output("softmax", "Out"); + // dropout + auto* dropout = OpNode("dropout", "dropout"); + auto* dropout_out = + VarNode("dropout_out")->assert_is_op_output("dropout", "Out"); + PMNode* mask_out = nullptr; + if (dropout_mask_) { + mask_out = VarNode("mask_out")->assert_is_op_output("dropout", "Mask"); + } + // matmul auto* matmul1 = OpNode("matmul1", mul_type_)->assert_node_satisfied(matmul1_attr_teller); @@ -198,9 +210,14 @@ void TransformerAttentionFuser::BuildPattern() { std::vector matmul0_inputs{scale0_out, transpose1_out}; matmul0_inputs >> *matmul0 >> *matmul0_out; std::vector add0_inputs{matmul0_out, residual}; - add0_inputs >> *add >> *add0_out >> *softmax0 >> *softmax0_out; + add0_inputs >> *add >> *add0_out >> *softmax0 >> *softmax0_out >> *dropout >> + *dropout_out; - std::vector matmul1_inputs{softmax0_out, transpose2_out}; + if (dropout_mask_) { + *dropout >> *mask_out; + } + + std::vector matmul1_inputs{dropout_out, transpose2_out}; matmul1_inputs >> *matmul1 >> *Out; if (reshape_has_xshape_) { @@ -213,6 +230,9 @@ void TransformerAttentionFuser::BuildPattern() { xshape4->AsIntermediate(); xshape5->AsIntermediate(); } + if (dropout_mask_) { + mask_out->AsIntermediate(); + } fc0->AsIntermediate(); fc0_out->AsIntermediate(); reshape0->AsIntermediate(); @@ -239,6 +259,8 @@ void TransformerAttentionFuser::BuildPattern() { add0_out->AsIntermediate(); softmax0->AsIntermediate(); softmax0_out->AsIntermediate(); + dropout->AsIntermediate(); + dropout_out->AsIntermediate(); matmul1->AsIntermediate(); } diff --git a/lite/core/optimizer/mir/fusion/transformer_attention_fuser.h b/lite/core/optimizer/mir/fusion/transformer_attention_fuser.h index d833909fb6b..adcb392319b 100644 --- a/lite/core/optimizer/mir/fusion/transformer_attention_fuser.h +++ b/lite/core/optimizer/mir/fusion/transformer_attention_fuser.h @@ -27,9 +27,11 @@ class TransformerAttentionFuser : public FuseBase { public: explicit TransformerAttentionFuser(bool reshape_has_xshape, bool transpose_has_xshape, + bool dropout_mask, std::string mul_type) : reshape_has_xshape_(reshape_has_xshape), transpose_has_xshape_(transpose_has_xshape), + dropout_mask_(dropout_mask), mul_type_(mul_type) {} void BuildPattern() override; void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override; @@ -37,6 +39,7 @@ class TransformerAttentionFuser : public FuseBase { private: bool reshape_has_xshape_; bool transpose_has_xshape_; + bool dropout_mask_; std::string mul_type_; }; diff --git a/lite/core/optimizer/optimizer.cc b/lite/core/optimizer/optimizer.cc index 0229cce86cb..00f0834de3e 100644 --- a/lite/core/optimizer/optimizer.cc +++ b/lite/core/optimizer/optimizer.cc @@ -187,11 +187,11 @@ std::unique_ptr RunDefaultOptimizer( "lite_elementwise_activation_fuse_pass", "lite_conv_scale_fuse_pass", "lite_conv_elementwise_tree_fuse_pass", + "transformer_attention_fuse_pass", "lite_greater_than_cast_fuse_pass", "identity_dropout_eliminate_pass", "sparse_conv_detect_pass", // "keepdims_convert_pass", - "transformer_attention_fuse_pass", "__xpu__max_pooling_pad_zero_detect_fuse_pass", "__xpu__graph_dedup_pass", "__xpu__resnet_fuse_pass",