diff --git a/lite/backends/arm/math/conv_impl.cc b/lite/backends/arm/math/conv_impl.cc
index 3a292302371..85151ddee99 100644
--- a/lite/backends/arm/math/conv_impl.cc
+++ b/lite/backends/arm/math/conv_impl.cc
@@ -735,7 +735,6 @@ void conv1x1s1_gemm_int8(const int8_t* i_data,
                         n,
                         k,
                         flag_bias,
-                        GemmMBias,
                         false,
                         scale_group,
                         act_param,
@@ -1605,7 +1604,6 @@ void conv_im2col_gemm_int8(const int8_t* i_data,
                           n,
                           k,
                           flag_bias,
-                          GemmMBias,
                           false,
                           scale_group,
                           act_param,
diff --git a/lite/core/optimizer/mir/fusion/transformer_attention_fuse_pass.cc b/lite/core/optimizer/mir/fusion/transformer_attention_fuse_pass.cc
index e90464d43d2..d7da59b6869 100644
--- a/lite/core/optimizer/mir/fusion/transformer_attention_fuse_pass.cc
+++ b/lite/core/optimizer/mir/fusion/transformer_attention_fuse_pass.cc
@@ -25,17 +25,28 @@ namespace mir {
 
 void TransformerAttentionFusePass::Apply(
     const std::unique_ptr<SSAGraph>& graph) {
-  fusion::TransformerAttentionFuser fuser;
   bool has_int8 = false;
   for (auto& place : graph->valid_places()) {
     if (place.precision == PRECISION(kInt8)) {
      has_int8 = true;
    }
   }
 
-  if ((has_int8)) {
-    fuser(graph.get());
-  } else {
-    return;
+  std::vector<bool> reshape_has_xshapes = {false, true};
+  std::vector<bool> transpose_has_xshapes = {false, true};
+  std::vector<bool> dropout_masks = {false, true};
+  std::vector<std::string> mul_types = {"matmul", "matmul_v2"};
+  for (auto reshape_has_xshape : reshape_has_xshapes) {
+    for (auto transpose_has_xshape : transpose_has_xshapes) {
+      for (auto dropout_mask : dropout_masks) {
+        for (auto mul_type : mul_types) {
+          fusion::TransformerAttentionFuser fuser(
+              reshape_has_xshape, transpose_has_xshape, dropout_mask, mul_type);
+          if ((has_int8)) {
+            fuser(graph.get());
+          }
+        }
+      }
+    }
   }
 }
diff --git a/lite/core/optimizer/mir/fusion/transformer_attention_fuser.cc b/lite/core/optimizer/mir/fusion/transformer_attention_fuser.cc
index a928a837f7c..c5906c9666f 100644
--- a/lite/core/optimizer/mir/fusion/transformer_attention_fuser.cc
+++ b/lite/core/optimizer/mir/fusion/transformer_attention_fuser.cc
@@ -38,16 +38,19 @@ namespace fusion {
 * scale        |       |
 *    \         /       |
 *     \       /        |
-*     matmul_v2        /
-*        \            /
-*         \          /
-*    elementwise_add /
-*          \        /
-*           \      /
-*         softmax /
-*            \   /
-*             \ /
-*          matmul_v2
+*  matmul_v2/matmul    |
+*        \            /
+*         \          /
+*    elementwise_add /
+*          \        /
+*           \      /
+*        softmax   /
+*           |     /
+*           |    /
+*        dropout /
+*          \    /
+*           \  /
+*      matmul_v2/matmul
 *              |
 *              |
 *            output
@@ -56,20 +59,33 @@ namespace fusion {
 void TransformerAttentionFuser::BuildPattern() {
   auto matmul0_attr_teller = [](const Node* node) -> bool {
     auto op_desc = *const_cast<Node*>(node)->stmt()->op_info();
-    auto trans_x = op_desc.GetAttr<bool>("trans_x");
-    auto trans_y = op_desc.GetAttr<bool>("trans_y");
+    bool trans_x;
+    bool trans_y;
+    if (op_desc.Type() == "matmul") {
+      trans_x = op_desc.GetAttr<bool>("transpose_X");
+      trans_y = op_desc.GetAttr<bool>("transpose_Y");
+    } else {
+      trans_x = op_desc.GetAttr<bool>("trans_x");
+      trans_y = op_desc.GetAttr<bool>("trans_y");
+    }
     auto res = (trans_x == false && trans_y == true);
     return res;
   };
 
   auto matmul1_attr_teller = [](const Node* node) -> bool {
     auto op_desc = *const_cast<Node*>(node)->stmt()->op_info();
-    auto trans_x = op_desc.GetAttr<bool>("trans_x");
-    auto trans_y = op_desc.GetAttr<bool>("trans_y");
+    bool trans_x;
+    bool trans_y;
+    if (op_desc.Type() == "matmul") {
+      trans_x = op_desc.GetAttr<bool>("transpose_X");
+      trans_y = op_desc.GetAttr<bool>("transpose_Y");
+    } else {
+      trans_x = op_desc.GetAttr<bool>("trans_x");
+      trans_y = op_desc.GetAttr<bool>("trans_y");
+    }
     auto res = (trans_x == false && trans_y == false);
     return res;
   };
-  auto* input0 =
-      VarNode("input0")->assert_is_op_input("fc", "Input")->AsInput();
+  auto* input = VarNode("input")->assert_is_op_input("fc", "Input")->AsInput();
   // fc
   auto* fc0_w = VarNode("fc0_w")->assert_is_op_input("fc", "W");
   auto* fc0_bias = VarNode("fc0_bias")->assert_is_op_input("fc", "Bias");
@@ -99,44 +115,54 @@ void TransformerAttentionFuser::BuildPattern() {
   auto* reshape2_out =
       VarNode("reshape2_out")->assert_is_op_output("reshape2", "Out");
-  auto* xshape0 = VarNode("xshape0")->assert_is_op_output("reshape2", "XShape");
-  auto* xshape1 = VarNode("xshape1")->assert_is_op_output("reshape2", "XShape");
-  auto* xshape2 = VarNode("xshape2")->assert_is_op_output("reshape2", "XShape");
+  PMNode* xshape0 = nullptr;
+  PMNode* xshape1 = nullptr;
+  PMNode* xshape2 = nullptr;
+  if (reshape_has_xshape_) {
+    xshape0 = VarNode("xshape0")->assert_is_op_output("reshape2", "XShape");
+    xshape1 = VarNode("xshape1")->assert_is_op_output("reshape2", "XShape");
+    xshape2 = VarNode("xshape2")->assert_is_op_output("reshape2", "XShape");
+  }
+
   // transpose2
   auto* transpose0 = OpNode("transpose0", "transpose2")
                          ->assert_op_attr("axis", std::vector<int>{0, 2, 1, 3});
   auto* transpose0_out =
       VarNode("transpose0_out")->assert_is_op_output("transpose2", "Out");
-  auto* xshape3 =
-      VarNode("xshape3")->assert_is_op_output("transpose2", "XShape");
   auto* transpose1 = OpNode("transpose1", "transpose2")
                          ->assert_op_attr("axis", std::vector<int>{0, 2, 1, 3});
   auto* transpose1_out =
       VarNode("transpose1_out")->assert_is_op_output("transpose2", "Out");
-  auto* xshape4 =
-      VarNode("xshape4")->assert_is_op_output("transpose2", "XShape");
   auto* transpose2 = OpNode("transpose2", "transpose2")
                          ->assert_op_attr("axis", std::vector<int>{0, 2, 1, 3});
   auto* transpose2_out =
       VarNode("transpose2_out")->assert_is_op_output("transpose2", "Out");
-  auto* xshape5 =
-      VarNode("xshape5")->assert_is_op_output("transpose2", "XShape");
+
+  PMNode* xshape3 = nullptr;
+  PMNode* xshape4 = nullptr;
+  PMNode* xshape5 = nullptr;
+  if (transpose_has_xshape_) {
+    xshape3 = VarNode("xshape3")->assert_is_op_output("transpose2", "XShape");
+    xshape4 = VarNode("xshape4")->assert_is_op_output("transpose2", "XShape");
+    xshape5 = VarNode("xshape5")->assert_is_op_output("transpose2", "XShape");
+  }
   // scale
   auto* scale0 = OpNode("scale0", "scale");
   auto* scale0_out = VarNode("scale0_out")->assert_is_op_output("scale", "Out");
-  // matmul_v2
-  auto* matmul0 = OpNode("matmul0", "matmul_v2")
-                      ->assert_node_satisfied(matmul0_attr_teller);
+  // matmul
+  auto* matmul0 =
+      OpNode("matmul0", mul_type_)->assert_node_satisfied(matmul0_attr_teller);
   auto* matmul0_out =
-      VarNode("matmul0_out")->assert_is_op_output("matmul_v2", "Out");
+      VarNode("matmul0_out")->assert_is_op_output(mul_type_, "Out");
   // elementwise_add
-  auto* input1 =
-      VarNode("input1")->assert_is_op_input("elementwise_add", "Y")->AsInput();
+  auto* residual = VarNode("residual")
+                       ->assert_is_op_input("elementwise_add", "Y")
+                       ->AsInput();
   auto* add = OpNode("add", "elementwise_add");
   auto* add0_out =
       VarNode("add0_out")->assert_is_op_output("elementwise_add", "Out");
@@ -146,42 +172,67 @@ void TransformerAttentionFuser::BuildPattern() {
   auto* softmax0_out =
       VarNode("softmax0_out")->assert_is_op_output("softmax", "Out");
 
-  // matmul_v2
-  auto* matmul1 = OpNode("matmul1", "matmul_v2")
-                      ->assert_node_satisfied(matmul1_attr_teller);
+  // dropout
+  auto* dropout = OpNode("dropout", "dropout");
+  auto* dropout_out =
+      VarNode("dropout_out")->assert_is_op_output("dropout", "Out");
+  PMNode* mask_out = nullptr;
+  if (dropout_mask_) {
+    mask_out = VarNode("mask_out")->assert_is_op_output("dropout", "Mask");
+  }
+
+  // matmul
+  auto* matmul1 =
+      OpNode("matmul1", mul_type_)->assert_node_satisfied(matmul1_attr_teller);
 
   auto* Out = VarNode("Out");
-  std::vector<PMNode*> fc0_inputs{input0, fc0_w, fc0_bias};
-  std::vector<PMNode*> fc1_inputs{input0, fc1_w, fc1_bias};
-  std::vector<PMNode*> fc2_inputs{input0, fc2_w, fc2_bias};
+  std::vector<PMNode*> fc0_inputs{input, fc0_w, fc0_bias};
+  std::vector<PMNode*> fc1_inputs{input, fc1_w, fc1_bias};
+  std::vector<PMNode*> fc2_inputs{input, fc2_w, fc2_bias};
   fc0_inputs >> *fc0 >> *fc0_out >> *reshape0 >> *reshape0_out >> *transpose0 >>
       *transpose0_out >> *scale0 >> *scale0_out;
   fc1_inputs >> *fc1 >> *fc1_out >> *reshape1 >> *reshape1_out >> *transpose1 >>
       *transpose1_out;
   fc2_inputs >> *fc2 >> *fc2_out >> *reshape2 >> *reshape2_out >> *transpose2 >>
       *transpose2_out;
-  *reshape0 >> *xshape0;
-  *reshape1 >> *xshape1;
-  *reshape2 >> *xshape2;
-  *transpose0 >> *xshape3;
-  *transpose1 >> *xshape4;
-  *transpose2 >> *xshape5;
+  if (reshape_has_xshape_) {
+    *reshape0 >> *xshape0;
+    *reshape1 >> *xshape1;
+    *reshape2 >> *xshape2;
+  }
+  if (transpose_has_xshape_) {
+    *transpose0 >> *xshape3;
+    *transpose1 >> *xshape4;
+    *transpose2 >> *xshape5;
+  }
 
   std::vector<PMNode*> matmul0_inputs{scale0_out, transpose1_out};
   matmul0_inputs >> *matmul0 >> *matmul0_out;
 
-  std::vector<PMNode*> add0_inputs{matmul0_out, input1};
-  add0_inputs >> *add >> *add0_out >> *softmax0 >> *softmax0_out;
+  std::vector<PMNode*> add0_inputs{matmul0_out, residual};
+  add0_inputs >> *add >> *add0_out >> *softmax0 >> *softmax0_out >> *dropout >>
+      *dropout_out;
+
+  if (dropout_mask_) {
+    *dropout >> *mask_out;
+  }
 
-  std::vector<PMNode*> matmul1_inputs{softmax0_out, transpose2_out};
+  std::vector<PMNode*> matmul1_inputs{dropout_out, transpose2_out};
   matmul1_inputs >> *matmul1 >> *Out;
 
-  xshape0->AsIntermediate();
-  xshape1->AsIntermediate();
-  xshape2->AsIntermediate();
-  xshape3->AsIntermediate();
-  xshape4->AsIntermediate();
-  xshape5->AsIntermediate();
+  if (reshape_has_xshape_) {
+    xshape0->AsIntermediate();
+    xshape1->AsIntermediate();
+    xshape2->AsIntermediate();
+  }
+  if (transpose_has_xshape_) {
+    xshape3->AsIntermediate();
+    xshape4->AsIntermediate();
+    xshape5->AsIntermediate();
+  }
+  if (dropout_mask_) {
+    mask_out->AsIntermediate();
+  }
   fc0->AsIntermediate();
   fc0_out->AsIntermediate();
   reshape0->AsIntermediate();
@@ -208,6 +259,8 @@ void TransformerAttentionFuser::BuildPattern() {
   add0_out->AsIntermediate();
   softmax0->AsIntermediate();
   softmax0_out->AsIntermediate();
+  dropout->AsIntermediate();
+  dropout_out->AsIntermediate();
   matmul1->AsIntermediate();
 }
 
@@ -312,8 +365,8 @@ void TransformerAttentionFuser::InsertNewNode(SSAGraph* graph,
   auto* scope = fc->scope();
 
   // set input
-  op_desc.SetInput("Input0", {matched.at("input0")->arg()->name});
-  op_desc.SetInput("Input1", {matched.at("input1")->arg()->name});
+  op_desc.SetInput("Input", {matched.at("input")->arg()->name});
+  op_desc.SetInput("Residual", {matched.at("residual")->arg()->name});
 
   // fc
   auto fc0_op_desc = matched.at("fc0")->stmt()->op_info();
@@ -393,6 +446,7 @@ void TransformerAttentionFuser::InsertNewNode(SSAGraph* graph,
                   scale0_scale,
                   bias0_dims[0]);
   op_desc.SetAttr<std::vector<float>>("fc0_scale", fuse_scales);
+  op_desc.SetAttr<std::vector<float>>("Input0_scale", fc0_scale_x);
   // fc 1
   auto matmul0_scale_x =
       matmul0_op_desc->GetAttr<std::vector<float>>("X0_scale");
@@ -424,6 +478,8 @@ void TransformerAttentionFuser::InsertNewNode(SSAGraph* graph,
   op_desc.SetInput("Bias", {matched.at("fc0_bias")->arg()->name});
 
   op_desc.SetAttr("op_type", fc0_op_desc->GetAttr<std::string>("op_type"));
+  op_desc.SetAttr("in_num_col_dims",
+                  fc0_op_desc->GetAttr<int>("in_num_col_dims"));
   // reshape
   auto reshape_op_desc = matched.at("reshape0")->stmt()->op_info();
   op_desc.SetAttr<std::vector<int>>(
@@ -440,8 +496,8 @@ void TransformerAttentionFuser::InsertNewNode(SSAGraph* graph,
   fused_attention_op->Attach(op_desc, scope);
   auto* new_op_node =
       graph->GraphCreateInstructNode(fused_attention_op, valid_places);
-  DirectedLink(matched.at("input0"), new_op_node);
-  DirectedLink(matched.at("input1"), new_op_node);
+  DirectedLink(matched.at("input"), new_op_node);
+  DirectedLink(matched.at("residual"), new_op_node);
   DirectedLink(matched.at("fc0_w"), new_op_node);
   DirectedLink(matched.at("fc0_bias"), new_op_node);
   DirectedLink(new_op_node, matched.at("Out"));
diff --git a/lite/core/optimizer/mir/fusion/transformer_attention_fuser.h b/lite/core/optimizer/mir/fusion/transformer_attention_fuser.h
index c944192c0ec..adcb392319b 100644
--- a/lite/core/optimizer/mir/fusion/transformer_attention_fuser.h
+++ b/lite/core/optimizer/mir/fusion/transformer_attention_fuser.h
@@ -25,10 +25,22 @@ namespace fusion {
 
 class TransformerAttentionFuser : public FuseBase {
  public:
+  explicit TransformerAttentionFuser(bool reshape_has_xshape,
+                                     bool transpose_has_xshape,
+                                     bool dropout_mask,
+                                     std::string mul_type)
+      : reshape_has_xshape_(reshape_has_xshape),
+        transpose_has_xshape_(transpose_has_xshape),
+        dropout_mask_(dropout_mask),
+        mul_type_(mul_type) {}
   void BuildPattern() override;
   void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
 
  private:
+  bool reshape_has_xshape_;
+  bool transpose_has_xshape_;
+  bool dropout_mask_;
+  std::string mul_type_;
 };
 
 }  // namespace fusion
diff --git a/lite/core/optimizer/optimizer.cc b/lite/core/optimizer/optimizer.cc
index 0229cce86cb..00f0834de3e 100644
--- a/lite/core/optimizer/optimizer.cc
+++ b/lite/core/optimizer/optimizer.cc
@@ -187,11 +187,11 @@ std::unique_ptr<RuntimeProgram> RunDefaultOptimizer(
       "lite_elementwise_activation_fuse_pass",
       "lite_conv_scale_fuse_pass",
       "lite_conv_elementwise_tree_fuse_pass",
+      "transformer_attention_fuse_pass",
       "lite_greater_than_cast_fuse_pass",
       "identity_dropout_eliminate_pass",
       "sparse_conv_detect_pass",
       // "keepdims_convert_pass",
-      "transformer_attention_fuse_pass",
       "__xpu__max_pooling_pad_zero_detect_fuse_pass",
       "__xpu__graph_dedup_pass",
       "__xpu__resnet_fuse_pass",
diff --git a/lite/kernels/arm/fused_attention_compute.cc b/lite/kernels/arm/fused_attention_compute.cc
index 70731384dbf..e72aa8948aa 100644
--- a/lite/kernels/arm/fused_attention_compute.cc
+++ b/lite/kernels/arm/fused_attention_compute.cc
@@ -70,7 +70,7 @@ void FusedAttentionCompute<PType>::PrepareForRun() {
 template <PrecisionType PType>
 void FusedAttentionCompute<PType>::ReInitWhenNeeded() {
   auto& param = this->template Param<param_t>();
-  auto input0_dims = param.input0->dims();
+  auto input_dims = param.input->dims();
 
   // fc
   act_param_.has_active = false;
@@ -86,29 +86,29 @@ void FusedAttentionCompute<PType>::ReInitWhenNeeded() {
   int in_num_col_dims = param.in_num_col_dims;
   std::string op_type = param.op_type;
   if (op_type == "matmul" || op_type == "matmul_v2") {
-    in_num_col_dims = input0_dims.size() - 1;
+    in_num_col_dims = input_dims.size() - 1;
   }
-  fc_m_ = input0_dims.Slice(0, in_num_col_dims).production();
-  fc_k_ = input0_dims.Slice(in_num_col_dims, input0_dims.size()).production();
+  fc_m_ = input_dims.Slice(0, in_num_col_dims).production();
+  fc_k_ = input_dims.Slice(in_num_col_dims, input_dims.size()).production();
   CHECK_EQ(fc_k_, w_dims[0]);
   fc_n_ = w_dims[1];
-  fc_dims_ = DDim(std::vector<int64_t>{input0_dims[0], fc_m_, fc_n_});
+  fc_dims_ = DDim(std::vector<int64_t>{input_dims[0], fc_m_, fc_n_});
 
   // reshape
-  reshape_shape_.push_back(input0_dims[0]);
+  reshape_shape_.push_back(input_dims[0]);
   reshape_shape_.push_back(param.reshape_shape[2]);
-  reshape_shape_.push_back(input0_dims[1]);
+  reshape_shape_.push_back(input_dims[1]);
   reshape_shape_.push_back(param.reshape_shape[3]);
   // transpose
   transpose_out_dim_ = DDim(std::vector<int64_t>{
-      input0_dims[0], reshape_shape_[1], fc_m_, reshape_shape_[3]});
+      input_dims[0], reshape_shape_[1], fc_m_, reshape_shape_[3]});
 
   // fc1
   fc1_m_ = transpose_out_dim_[2];
   fc1_n_ = transpose_out_dim_[2];
   fc1_k_ = transpose_out_dim_[3];
   fc1_out_dim_ = DDim(std::vector<int64_t>{
-      input0_dims[0], transpose_out_dim_[1], fc1_m_, fc1_n_});
+      input_dims[0], transpose_out_dim_[1], fc1_m_, fc1_n_});
 
   // softmax
   softmax_out_dim_ = fc1_out_dim_;
@@ -135,10 +135,10 @@ template <>
 void FusedAttentionCompute<PRECISION(kInt8)>::Run() {
   auto& param = this->Param<param_t>();
   auto& ctx = this->ctx_->template As<ARMContext>();
-  auto* input0_data = param.input0->data<int8_t>();
-  auto* input1_data = param.input1->data<float>();
+  auto* input_data = param.input->data<int8_t>();
+  auto* residual_data = param.residual->data<float>();
   auto* o_data = param.output->mutable_data<float>();
-  auto input0_dims = param.input0->dims();
+  auto input_dims = param.input->dims();
   auto out_dims = param.output->dims();
 
   // fc + dequant_scale, bias, quant_scale
@@ -153,7 +153,7 @@ void FusedAttentionCompute<PRECISION(kInt8)>::Run() {
                            fc_m_,
                            fc_n_,
                            fc_k_,
-                           input0_data,
+                           input_data,
                            w_data,
                            fc_out,
                            b_data,
@@ -164,7 +164,7 @@ void FusedAttentionCompute<PRECISION(kInt8)>::Run() {
                            &ctx);
   // transpose2 fuse reshape2
   DDim trans_dims = DDim(std::vector<int64_t>{
-      input0_dims[0], reshape_shape_[1], fc_m_, reshape_shape_[3] * 3});
+      input_dims[0], reshape_shape_[1], fc_m_, reshape_shape_[3] * 3});
   Tensor trans_t;
   trans_t.Resize(trans_dims);
   trans_t.mutable_data<int8_t>();
@@ -174,7 +174,7 @@ void FusedAttentionCompute<PRECISION(kInt8)>::Run() {
   auto* v1 = v0 + stride;
   auto* v2 = v1 + stride;
   TransposeCompute_1to3(
-      fc_out, v0, v1, v2, input0_dims[0], fc_m_, fc_n_, transpose_out_dim_[3]);
+      fc_out, v0, v1, v2, input_dims[0], fc_m_, fc_n_, transpose_out_dim_[3]);
   // fc -> out fp32
   Tensor fc1_t;
   fc1_t.Resize(fc1_out_dim_);
@@ -183,7 +183,7 @@ void FusedAttentionCompute<PRECISION(kInt8)>::Run() {
   int x_inner = fc1_m_ * fc1_k_;
   int y_inner = fc1_k_ * fc1_n_;
   int out_inner = fc1_m_ * fc1_n_;
-  auto* fc1_b_data = param.input1->data<float>();
+  auto* fc1_b_data = param.residual->data<float>();
   for (size_t i = 0; i < transpose_out_dim_[1]; ++i) {
     lite::arm::math::gemm_s8(false,
                              true,
@@ -263,9 +263,8 @@ typedef paddle::lite::kernels::arm::FusedAttentionCompute<PRECISION(kInt8)>
     FusedAttentionCompute_Int8;
 REGISTER_LITE_KERNEL(
     fused_attention, kARM, kInt8, kNCHW, FusedAttentionCompute_Int8, def)
-    .BindInput("Input0",
-               {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
-    .BindInput("Input1",
+    .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
+    .BindInput("Residual",
                {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
     .BindInput("W", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
     .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
diff --git a/lite/operators/fused_attention_op.cc b/lite/operators/fused_attention_op.cc
index 2c34d42d019..040b755a4ac 100644
--- a/lite/operators/fused_attention_op.cc
+++ b/lite/operators/fused_attention_op.cc
@@ -20,12 +20,12 @@ namespace lite {
 namespace operators {
 
 bool FusedAttentionOpLite::CheckShape() const {
-  CHECK_OR_FALSE(param_.input0);
-  CHECK_OR_FALSE(param_.input1);
+  CHECK_OR_FALSE(param_.input);
+  CHECK_OR_FALSE(param_.residual);
   CHECK_OR_FALSE(param_.output);
   CHECK_OR_FALSE(param_.fc_w);
 
-  const auto input_dims = param_.input0->dims();
+  const auto input_dims = param_.input->dims();
   const auto w_dims = param_.fc_w->dims();
   CHECK_EQ_OR_FALSE(w_dims.size(), 2UL);
   int64_t w_dims_1 = param_.padding_weights ? w_dims[1] - 4 : w_dims[1];
@@ -58,9 +58,28 @@ static bool CheckPositive(const DDim &dims) {
   }
   return true;
 }
+
 bool FusedAttentionOpLite::InferShape() {
-  lite::DDim x_dims = param_.input0->dims();
-  const DDim::value_type input_size = x_dims.production();
+  lite::DDim x_dims = param_.input->dims();
+
+  // infer fc
+  int in_num_col_dims = param_.in_num_col_dims;
+  std::string op_type = param_.op_type;
+  const auto &w_dims = param_.fc_w->dims();
+  int64_t w_dims_1 = w_dims[1] / 3;
+
+  if (op_type == "matmul" || op_type == "matmul_v2") {
+    in_num_col_dims = x_dims.size() - 1;
+  }
+  DDim::value_type fc_output_size = 1;
+  std::vector<int64_t> fc_output_dims(in_num_col_dims + 1);
+  for (int i = 0; i < in_num_col_dims; ++i) {
+    fc_output_dims[i] = x_dims[i];
+    fc_output_size *= fc_output_dims[i];
+  }
+  fc_output_dims[in_num_col_dims] = w_dims_1;
+  fc_output_size *= fc_output_dims[in_num_col_dims];
+
   std::vector<int> shape = param_.reshape_shape;
   std::vector<int64_t> reshape_output_dims(shape.size());
   DDim::value_type capacity = 1;
@@ -73,7 +92,7 @@ bool FusedAttentionOpLite::InferShape() {
           << "Only one input dimension of Attr(shape) can be unknown.";
       unk_dim_idx = i;
     } else if (shape[i] == copy_dim_val) {
-      CHECK_LT(i, x_dims.size())
+      CHECK_LT(i, fc_output_dims.size())
           << "The index of dimension to copy from input shape must be less "
             "than the size of input shape.";
     } else {
@@ -82,24 +101,24 @@ bool FusedAttentionOpLite::InferShape() {
    }
 
    DDim::value_type output_dim_i =
-        shape[i] ? static_cast<DDim::value_type>(shape[i]) : x_dims[i];
+        shape[i] ? static_cast<DDim::value_type>(shape[i]) : fc_output_dims[i];
    reshape_output_dims[i] = output_dim_i;
    capacity *= output_dim_i;
   }
 
   if (unk_dim_idx != -1) {
-    if (CheckPositive(x_dims)) {
+    if (CheckPositive(lite::DDim(fc_output_dims))) {
      // input_size < 0 and is un-determinate in compile time, skip the check,
      // for example, input_dims = [-1, 8, 1, 1], shape = [-1, 3, 8],
      // capacity = -24, input_size = -8, output_shape[0] = 0
      // the following check will fail.
-      reshape_output_dims[unk_dim_idx] = -input_size / capacity;
-      CHECK_EQ(reshape_output_dims[unk_dim_idx] * capacity, -input_size)
+      reshape_output_dims[unk_dim_idx] = -fc_output_size / capacity;
+      CHECK_EQ(reshape_output_dims[unk_dim_idx] * capacity, -fc_output_size)
          << "Invalid shape is given.";
    } else {
      reshape_output_dims[unk_dim_idx] = -1;
    }
   } else {
-    CHECK_EQ(capacity, input_size) << "Invalid shape is given.";
+    CHECK_EQ(capacity, fc_output_size) << "Invalid shape is given.";
   }
   lite::DDim out_dims = lite::DDim({reshape_output_dims[0],
                                     reshape_output_dims[2],
@@ -108,19 +127,19 @@ bool FusedAttentionOpLite::InferShape() {
   param_.output->Resize(out_dims);
 
   // share LoD
-  param_.output->set_lod(param_.input0->lod());
+  param_.output->set_lod(param_.input->lod());
   return true;
 }
 
 bool FusedAttentionOpLite::AttachImpl(const cpp::OpDesc &op_desc,
                                       lite::Scope *scope) {
-  auto input0 = op_desc.Input("Input0").front();
-  auto input1 = op_desc.Input("Input1").front();
+  auto input = op_desc.Input("Input").front();
+  auto residual = op_desc.Input("Residual").front();
   auto fc_w = op_desc.Input("W").front();
   auto output = op_desc.Output("Out").front();
 
-  param_.input0 = scope->FindVar(input0)->GetMutable<lite::Tensor>();
-  param_.input1 = scope->FindVar(input1)->GetMutable<lite::Tensor>();
+  param_.input = scope->FindVar(input)->GetMutable<lite::Tensor>();
+  param_.residual = scope->FindVar(residual)->GetMutable<lite::Tensor>();
   param_.fc_w = scope->FindVar(fc_w)->GetMutable<lite::Tensor>();
   param_.output = scope->FindVar(output)->GetMutable<lite::Tensor>();
@@ -135,6 +154,7 @@ bool FusedAttentionOpLite::AttachImpl(const cpp::OpDesc &op_desc,
      }
    }
   }
+  param_.in_num_col_dims = op_desc.GetAttr<int>("in_num_col_dims");
   param_.reshape_shape = op_desc.GetAttr<std::vector<int>>("reshape_shape");
   param_.softmax_axis = op_desc.GetAttr<int>("softmax_axis");
 
diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h
index 052ad9e9ad0..abd9227f8c0 100644
--- a/lite/operators/op_params.h
+++ b/lite/operators/op_params.h
@@ -118,8 +118,8 @@ struct FcParam : ParamBase {
 };
 
 struct FusedAttentionParam : ParamBase {
-  lite::Tensor* input0{nullptr};
-  lite::Tensor* input1{nullptr};
+  lite::Tensor* input{nullptr};
+  lite::Tensor* residual{nullptr};
   lite::Tensor* fc_w{nullptr};
   lite::Tensor* fc_bias{nullptr};
   lite::Tensor* output{nullptr};