From 5a95eb83449a2a0b6df4c85946460bcf176c1648 Mon Sep 17 00:00:00 2001 From: shenyijun01 Date: Fri, 18 Nov 2022 13:35:30 +0800 Subject: [PATCH] [Optimizer]: add quick gelu fusion pass for ViT model. --- lite/api/paddle_use_passes.h | 1 + .../mir/fusion/__xpu__fc_fuse_pass.cc | 4 +- .../mir/fusion/__xpu__quick_gelu_fuse_pass.cc | 128 ++++++++++++++++++ lite/core/optimizer/optimizer.cc | 1 + lite/kernels/xpu/CMakeLists.txt | 1 + lite/kernels/xpu/__xpu__quick_gelu_compute.cc | 53 ++++++++ lite/kernels/xpu/__xpu__quick_gelu_compute.h | 36 +++++ lite/operators/CMakeLists.txt | 1 + lite/operators/__xpu__quick_gelu_op.cc | 47 +++++++ lite/operators/__xpu__quick_gelu_op.h | 48 +++++++ lite/operators/op_params.h | 5 + 11 files changed, 324 insertions(+), 1 deletion(-) create mode 100644 lite/core/optimizer/mir/fusion/__xpu__quick_gelu_fuse_pass.cc create mode 100644 lite/kernels/xpu/__xpu__quick_gelu_compute.cc create mode 100644 lite/kernels/xpu/__xpu__quick_gelu_compute.h create mode 100644 lite/operators/__xpu__quick_gelu_op.cc create mode 100644 lite/operators/__xpu__quick_gelu_op.h diff --git a/lite/api/paddle_use_passes.h b/lite/api/paddle_use_passes.h index 8c31ad412bb..d8c6486da3d 100644 --- a/lite/api/paddle_use_passes.h +++ b/lite/api/paddle_use_passes.h @@ -81,6 +81,7 @@ USE_MIR_PASS(__xpu__resnet_fuse_pass); USE_MIR_PASS(__xpu__multi_encoder_fuse_pass); USE_MIR_PASS(__xpu__embedding_with_eltwise_add_fuse_pass); USE_MIR_PASS(__xpu__fc_fuse_pass); +USE_MIR_PASS(__xpu__quick_gelu_fuse_pass); USE_MIR_PASS(__xpu__mmdnn_fuse_pass); USE_MIR_PASS(__xpu__conv2d_affine_channel_fuse_pass); USE_MIR_PASS(__xpu__conv2d_fuse_pass); diff --git a/lite/core/optimizer/mir/fusion/__xpu__fc_fuse_pass.cc b/lite/core/optimizer/mir/fusion/__xpu__fc_fuse_pass.cc index 464436f6577..030ff71b4f6 100644 --- a/lite/core/optimizer/mir/fusion/__xpu__fc_fuse_pass.cc +++ b/lite/core/optimizer/mir/fusion/__xpu__fc_fuse_pass.cc @@ -199,7 +199,8 @@ class XPUFcFuser : public FuseBase { {"leaky_relu", 5}, {"hard_swish", 14}, {"hard_sigmoid", 15}, - {"relu6", 17}}; + {"relu6", 17}, + {"__xpu__quick_gelu", 19}}; float act_param_ = 0.0f; if (act_type_ == "leaky_relu") { @@ -281,6 +282,7 @@ class XPUFcFusePass : public ProgramPass { for (auto with_bias : {true, false}) { for (auto act_type : {"relu", "gelu", + "__xpu__quick_gelu", /*"sigmoid", "tanh", "leaky_relu", diff --git a/lite/core/optimizer/mir/fusion/__xpu__quick_gelu_fuse_pass.cc b/lite/core/optimizer/mir/fusion/__xpu__quick_gelu_fuse_pass.cc new file mode 100644 index 00000000000..1b7d1874217 --- /dev/null +++ b/lite/core/optimizer/mir/fusion/__xpu__quick_gelu_fuse_pass.cc @@ -0,0 +1,128 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "lite/backends/xpu/math.h" +#include "lite/core/optimizer/mir/pass_registry.h" +#include "lite/core/optimizer/mir/pattern_matcher_high_api.h" + +namespace paddle { +namespace lite { +namespace mir { +namespace fusion { + +class XPUQuickGELUFuser : public FuseBase { + public: + explicit XPUQuickGELUFuser(const std::string& op_type, + const std::string& act_type) { + op_type_ = op_type; + act_type_ = act_type; + } + + void BuildPattern() override { + auto scale_teller = [](const Node* node) -> bool { + bool bias_after_scale = + const_cast(node)->AsStmt().op_info()->GetAttr( + "bias_after_scale"); + bool has_act = const_cast(node)->AsStmt().op_info()->HasAttr( + "activation_type"); + return bias_after_scale && (!has_act); + }; + + /* _____________________ + / \ + Create node: X----scale----sigmoid---elementwise_mul---output + */ + auto* x = VarNode("x")->assert_is_op_input("scale", "X"); + auto* scale = OpNode("scale", "scale") + ->assert_is_op("scale") + ->assert_node_satisfied(scale_teller); + auto* scale_out = VarNode("scale_out"); + auto* sigmoid = OpNode("sigmoid", act_type_); + auto* sigmoid_out = VarNode("sigmoid_out"); + auto* element_mul = + OpNode("elementwise_mul", op_type_) + ->assert_op_attr_satisfied( + "axis", [](int attr) { return attr == -1 || attr == 0; }); + auto* output = VarNode("Out"); + + // Construct the topological structure for scale-sigmoid-elementwise_mul + *x >> *scale >> *scale_out >> *sigmoid >> *sigmoid_out; + std::vector element_mul_inputs{x, sigmoid_out}; + element_mul_inputs >> *element_mul >> *output; + + // Some op specialities. + scale->AsIntermediate(); + scale_out->AsIntermediate(); + sigmoid->AsIntermediate(); + sigmoid_out->AsIntermediate(); + element_mul->AsIntermediate(); + } + + cpp::OpDesc GenOpDesc(const key2nodes_t& matched) { + auto op_desc = *matched.at("scale")->stmt()->op_info(); + float scale_val = op_desc.GetAttr("scale"); + op_desc.mutable_inputs()->clear(); + op_desc.mutable_outputs()->clear(); + op_desc.SetType("__xpu__quick_gelu"); + op_desc.SetInput("X", {matched.at("x")->arg()->name}); + op_desc.SetOutput("Out", {matched.at("Out")->arg()->name}); + op_desc.SetAttr("scale", scale_val); + return op_desc; + } + + void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override { + // get op_desc for gelu op. + auto op_desc = GenOpDesc(matched); + // Create gelu op. + auto gelu_op = LiteOpRegistry::Global().Create("__xpu__quick_gelu"); + + // find scope and valid_places of original scale op. + auto scale = matched.at("scale")->stmt()->op(); + auto* scope = scale->scope(); + auto& valid_places = scale->valid_places(); + + // set gelu op's scope and valid_places which aligned with scale op. + gelu_op->Attach(op_desc, scope); + auto* new_op_node = graph->GraphCreateInstructNode(gelu_op, valid_places); + + // link IO to the new op node. + IR_NODE_LINK_TO(matched.at("x"), new_op_node); + IR_NODE_LINK_TO(new_op_node, matched.at("Out")); + } + + private: + std::string op_type_; + std::string act_type_; +}; + +} // namespace fusion + +class XPUQuickGELUFusePass : public ProgramPass { + public: + void Apply(const std::unique_ptr& graph) override { + fusion::XPUQuickGELUFuser fuser("elementwise_mul", "sigmoid"); + fuser(graph.get()); + } +}; + +} // namespace mir +} // namespace lite +} // namespace paddle + +REGISTER_MIR_PASS(__xpu__quick_gelu_fuse_pass, + paddle::lite::mir::XPUQuickGELUFusePass) + .BindTargets({TARGET(kXPU)}) + .BindKernel("__xpu__quick_gelu"); diff --git a/lite/core/optimizer/optimizer.cc b/lite/core/optimizer/optimizer.cc index 4d0dcfb1570..7afecb28e56 100644 --- a/lite/core/optimizer/optimizer.cc +++ b/lite/core/optimizer/optimizer.cc @@ -200,6 +200,7 @@ std::unique_ptr RunDefaultOptimizer( "__xpu__mmdnn_fuse_pass", "__xpu__bigru_fuse_pass", "__xpu__roformer_relative_pos_fuse_pass", + "__xpu__quick_gelu_fuse_pass", "__xpu__multi_encoder_fuse_pass", "__xpu__embedding_with_eltwise_add_fuse_pass", "__xpu__fc_fuse_pass", diff --git a/lite/kernels/xpu/CMakeLists.txt b/lite/kernels/xpu/CMakeLists.txt index 5e746ed5776..a37e833b628 100644 --- a/lite/kernels/xpu/CMakeLists.txt +++ b/lite/kernels/xpu/CMakeLists.txt @@ -110,6 +110,7 @@ add_kernel(__xpu__resnet50_compute_xpu XPU extra SRCS __xpu__resnet50_compute.cc add_kernel(__xpu__multi_encoder_compute_xpu XPU extra SRCS __xpu__multi_encoder_compute.cc) add_kernel(__xpu__embedding_with_eltwise_add_compute_xpu XPU extra SRCS __xpu__embedding_with_eltwise_add_compute.cc) add_kernel(__xpu__fc_compute_xpu XPU extra SRCS __xpu__fc_compute.cc) +add_kernel(__xpu__quick_gelu_compute_xpu XPU extra SRCS __xpu__quick_gelu_compute.cc) add_kernel(__xpu__search_attention_compute_xpu XPU extra SRCS __xpu__search_attention_compute.cc) add_kernel(__xpu__search_attention_2_compute_xpu XPU extra SRCS __xpu__search_attention_2_compute.cc) add_kernel(__xpu__mmdnn_compute_xpu XPU extra SRCS __xpu__mmdnn_compute.cc) diff --git a/lite/kernels/xpu/__xpu__quick_gelu_compute.cc b/lite/kernels/xpu/__xpu__quick_gelu_compute.cc new file mode 100644 index 00000000000..0a20a2e4517 --- /dev/null +++ b/lite/kernels/xpu/__xpu__quick_gelu_compute.cc @@ -0,0 +1,53 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/xpu/__xpu__quick_gelu_compute.h" +#include "lite/backends/xpu/xpu_header_sitter.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace xpu { + +template +void QuickGeluCompute::Run() { + auto& param = this->template Param(); + auto& ctx = this->ctx_->template As(); + + int r = xdnn::quick_gelu(ctx.GetRawContext(), + param.X->template data(), + param.Out->template mutable_data(TARGET(kXPU)), + param.X->numel()); + CHECK_EQ(r, 0); +} + +} // namespace xpu +} // namespace kernels +} // namespace lite +} // namespace paddle + +using quick_gelu_FP32 = + paddle::lite::kernels::xpu::QuickGeluCompute; +using qucik_gelu_FP16 = + paddle::lite::kernels::xpu::QuickGeluCompute; +REGISTER_LITE_KERNEL(quick_gelu, kXPU, kFloat, kNCHW, quick_gelu_FP32, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))}) + .Finalize(); +REGISTER_LITE_KERNEL( + quick_gelu, kXPU, kFP16, kNCHW, qucik_gelu_FP16, qucik_gelu_FP16) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))}) + .Finalize(); diff --git a/lite/kernels/xpu/__xpu__quick_gelu_compute.h b/lite/kernels/xpu/__xpu__quick_gelu_compute.h new file mode 100644 index 00000000000..16ba982483f --- /dev/null +++ b/lite/kernels/xpu/__xpu__quick_gelu_compute.h @@ -0,0 +1,36 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "lite/core/kernel.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace xpu { + +template +class QuickGeluCompute : public KernelLite { + public: + using param_t = operators::XPUQuickGeluParam; + + virtual void Run(); + + virtual ~QuickGeluCompute() = default; +}; + +} // namespace xpu +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt index db3dd661f45..9138e42ffe0 100755 --- a/lite/operators/CMakeLists.txt +++ b/lite/operators/CMakeLists.txt @@ -233,6 +233,7 @@ add_operator(__xpu__softmax_topk_op extra SRCS __xpu__softmax_topk_op.cc) add_operator(__xpu__multi_encoder_op extra SRCS __xpu__multi_encoder_op.cc) add_operator(__xpu__embedding_with_eltwise_add_op extra SRCS __xpu__embedding_with_eltwise_add_op.cc) add_operator(__xpu__fc_op extra SRCS __xpu__fc_op.cc) +add_operator(__xpu__quick_gelu_op extra SRCS __xpu__quick_gelu_op.cc) add_operator(__xpu__roformer_relative_embedding_op extra SRCS __xpu__roformer_relative_embedding_op.cc) add_operator(__xpu__search_attention_op extra SRCS __xpu__search_attention_op.cc) add_operator(__xpu__mmdnn_op extra SRCS __xpu__mmdnn_op.cc) diff --git a/lite/operators/__xpu__quick_gelu_op.cc b/lite/operators/__xpu__quick_gelu_op.cc new file mode 100644 index 00000000000..95f9756548f --- /dev/null +++ b/lite/operators/__xpu__quick_gelu_op.cc @@ -0,0 +1,47 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.i + +#include "lite/operators/__xpu__quick_gelu_op.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +bool XPUQuickGeluOp::CheckShape() const { + CHECK_OR_FALSE(param_.X); + CHECK_OR_FALSE(param_.Out); + return true; +} + +bool XPUQuickGeluOp::InferShapeImpl() const { + param_.Out->Resize(param_.X->dims()); + auto out_lod = param_.Out->mutable_lod(); + *out_lod = param_.X->lod(); + return true; +} + +bool XPUQuickGeluOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) { + auto x_name = opdesc.Input("X").front(); + auto out_name = opdesc.Output("Out").front(); + param_.X = scope->FindVar(x_name)->GetMutable(); + param_.Out = scope->FindVar(out_name)->GetMutable(); + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(__xpu__quick_gelu, paddle::lite::operators::XPUQuickGeluOp); diff --git a/lite/operators/__xpu__quick_gelu_op.h b/lite/operators/__xpu__quick_gelu_op.h new file mode 100644 index 00000000000..c347cd47f8f --- /dev/null +++ b/lite/operators/__xpu__quick_gelu_op.h @@ -0,0 +1,48 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "lite/core/op_lite.h" +#ifdef LITE_WITH_PROFILE +#include "lite/api/paddle_place.h" +#endif + +namespace paddle { +namespace lite { +namespace operators { + +class XPUQuickGeluOp : public OpLite { + public: + explicit XPUQuickGeluOp(const std::string& type) : OpLite(type) {} + + bool CheckShape() const override; + + bool InferShapeImpl() const override; + + bool InferType() override { return true; } + + bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override; + + void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } + + std::string DebugString() const override { return "XPUQuickGelu"; } + + private: + mutable operators::XPUQuickGeluParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index e074b096526..b20fafebd2a 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -1767,6 +1767,11 @@ struct XPUEmbeddingWithEltwiseAddParam : ParamBase { int64_t padding_idx{-1}; }; +struct XPUQuickGeluParam : ParamBase { + const lite::Tensor* X{}; + lite::Tensor* Out{}; +}; + struct XPUFcParam : ParamBase { const lite::Tensor* input{nullptr}; const lite::Tensor* w{nullptr};