[Optimizer]: add quick gelu fusion pass for ViT model.
stevenshen36 committed Nov 18, 2022
1 parent 382489d commit a2a25a4
Showing 9 changed files with 175 additions and 3 deletions.
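
For context (not part of the diff below): the pass collapses the three-op QuickGELU subgraph emitted by ViT/CLIP-style models into a single quick_gelu op. A minimal CPU sketch of the intended semantics, assuming quick_gelu(x) = x * sigmoid(scale * x) with the scale factor taken from the matched scale op (commonly 1.702); the helper name is hypothetical and not part of this commit:

#include <cmath>
#include <cstddef>

// Hypothetical reference helper: the element-wise semantics the fused
// quick_gelu op is expected to follow. `scale` is the attribute carried over
// from the matched scale op (typically 1.702 in ViT/CLIP-style QuickGELU).
inline float quick_gelu_ref(float x, float scale = 1.702f) {
  return x / (1.0f + std::exp(-scale * x));  // x * sigmoid(scale * x)
}

void quick_gelu_ref(const float* in, float* out, size_t n, float scale = 1.702f) {
  for (size_t i = 0; i < n; ++i) out[i] = quick_gelu_ref(in[i], scale);
}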
3 changes: 2 additions & 1 deletion lite/api/paddle_place.h
@@ -144,7 +144,8 @@ enum class ActivationType : int {
kSoftPlus = 21,
kMish = 22,
kSilu = 23,
NUM = 24,
kQuickGelu = 24,
NUM = 25,
};

static size_t PrecisionTypeLength(PrecisionType type) {
1 change: 1 addition & 0 deletions lite/api/paddle_use_passes.h
@@ -81,6 +81,7 @@ USE_MIR_PASS(__xpu__resnet_fuse_pass);
USE_MIR_PASS(__xpu__multi_encoder_fuse_pass);
USE_MIR_PASS(__xpu__embedding_with_eltwise_add_fuse_pass);
USE_MIR_PASS(__xpu__fc_fuse_pass);
USE_MIR_PASS(__xpu__quick_gelu_fuse_pass);
USE_MIR_PASS(__xpu__mmdnn_fuse_pass);
USE_MIR_PASS(__xpu__conv2d_affine_channel_fuse_pass);
USE_MIR_PASS(__xpu__conv2d_fuse_pass);
4 changes: 3 additions & 1 deletion lite/core/optimizer/mir/fusion/__xpu__fc_fuse_pass.cc
@@ -199,7 +199,8 @@ class XPUFcFuser : public FuseBase {
{"leaky_relu", 5},
{"hard_swish", 14},
{"hard_sigmoid", 15},
{"relu6", 17}};
{"relu6", 17},
{"quick_gelu", 19}};

float act_param_ = 0.0f;
if (act_type_ == "leaky_relu") {
@@ -281,6 +282,7 @@ class XPUFcFusePass : public ProgramPass {
for (auto with_bias : {true, false}) {
for (auto act_type : {"relu",
"gelu",
"quick_gelu",
/*"sigmoid",
"tanh",
"leaky_relu",
126 changes: 126 additions & 0 deletions lite/core/optimizer/mir/fusion/__xpu__quick_gelu_fuse_pass.cc
@@ -0,0 +1,126 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <memory>
#include <string>
#include "lite/backends/xpu/math.h"
#include "lite/core/optimizer/mir/pass_registry.h"
#include "lite/core/optimizer/mir/pattern_matcher_high_api.h"

namespace paddle {
namespace lite {
namespace mir {
namespace fusion {

class XPUQuickGELUFuser : public FuseBase {
public:
explicit XPUQuickGELUFuser(const std::string& op_type,
const std::string& act_type) {
op_type_ = op_type;
act_type_ = act_type;
}

void BuildPattern() override {
auto scale_teller = [](const Node* node) -> bool {
bool bias_after_scale =
const_cast<Node*>(node)->AsStmt().op_info()->GetAttr<bool>(
"bias_after_scale");
bool has_act =
const_cast<Node*>(node)->AsStmt().op_info()->HasAttr("activation_type");
return bias_after_scale && (!has_act);
};

/*               ______________________________________
                /                                      \
   Create node: X----scale----sigmoid---elementwise_mul---output
*/
auto* x = VarNode("x")->assert_is_op_input("scale", "X");
auto* scale = OpNode("scale", "scale")
->assert_is_op("scale")
->assert_node_satisfied(scale_teller);
auto* scale_out = VarNode("scale_out");
auto* sigmoid = OpNode("sigmoid", act_type_);
auto* sigmoid_out = VarNode("sigmoid_out");
auto* element_mul = OpNode("elementwise_mul", op_type_)
->assert_op_attr_satisfied<int>(
"axis", [](int attr) { return attr == -1 || attr == 0; });
auto* output = VarNode("Out");

// Construct the topological structure for scale-sigmoid-elementwise_mul
*x >> *scale >> *scale_out >> *sigmoid >> *sigmoid_out;
std::vector<PMNode*> element_mul_inputs{x, sigmoid_out};
element_mul_inputs >> *element_mul >> *output;

// Some op specialities.
scale->AsIntermediate();
scale_out->AsIntermediate();
sigmoid->AsIntermediate();
sigmoid_out->AsIntermediate();
element_mul->AsIntermediate();
}

cpp::OpDesc GenOpDesc(const key2nodes_t& matched) {
auto op_desc = *matched.at("scale")->stmt()->op_info();
float scale_val = op_desc.GetAttr<float>("scale");
op_desc.mutable_inputs()->clear();
op_desc.mutable_outputs()->clear();
op_desc.SetType("quick_gelu");
op_desc.SetInput("X", {matched.at("x")->arg()->name});
op_desc.SetOutput("Out", {matched.at("Out")->arg()->name});
op_desc.SetAttr("scale", scale_val);
return op_desc;
}

void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override {
// get op_desc for gelu op.
auto op_desc = GenOpDesc(matched);
// Create gelu op.
auto gelu_op = LiteOpRegistry::Global().Create("quick_gelu");

// find scope and valid_places of original scale op.
auto scale = matched.at("scale")->stmt()->op();
auto* scope = scale->scope();
auto& valid_places = scale->valid_places();

// set gelu op's scope and valid_places which aligned with scale op.
gelu_op->Attach(op_desc, scope);
auto* new_op_node = graph->GraphCreateInstructNode(gelu_op, valid_places);

// link IO to the new op node.
IR_NODE_LINK_TO(matched.at("x"), new_op_node);
IR_NODE_LINK_TO(new_op_node, matched.at("Out"));
}

private:
std::string op_type_;
std::string act_type_;
};

} // namespace fusion

class XPUQuickGELUFusePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override {
fusion::XPUQuickGELUFuser fuser("elementwise_mul", "sigmoid");
fuser(graph.get());
}
};

} // namespace mir
} // namespace lite
} // namespace paddle

REGISTER_MIR_PASS(__xpu__quick_gelu_fuse_pass, paddle::lite::mir::XPUQuickGELUFusePass)
.BindTargets({TARGET(kXPU)})
.BindKernel("quick_gelu");
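
A sketch of the equivalence the rewrite above relies on (illustrative only, reusing the hypothetical quick_gelu_ref helper from the top of this page): the matched subgraph scales the input, applies sigmoid, and multiplies the result back with the original input, which is exactly quick_gelu with the same scale factor, assuming the matched scale op has a zero bias (the teller only checks bias_after_scale and the absence of a fused activation).

#include <cmath>

// Unfused path as the pattern matcher sees it (assuming the scale op's bias
// is zero); this is what the single fused quick_gelu op replaces.
inline float unfused_quick_gelu(float x, float scale) {
  float scaled = scale * x;                          // "scale" node
  float gate   = 1.0f / (1.0f + std::exp(-scaled));  // "sigmoid" node
  return x * gate;                                   // "elementwise_mul" node
  // == quick_gelu_ref(x, scale)
}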
1 change: 1 addition & 0 deletions lite/core/optimizer/optimizer.cc
@@ -200,6 +200,7 @@ std::unique_ptr<RuntimeProgram> RunDefaultOptimizer(
"__xpu__mmdnn_fuse_pass",
"__xpu__bigru_fuse_pass",
"__xpu__roformer_relative_pos_fuse_pass",
"__xpu__quick_gelu_fuse_pass",
"__xpu__multi_encoder_fuse_pass",
"__xpu__embedding_with_eltwise_add_fuse_pass",
"__xpu__fc_fuse_pass",
25 changes: 25 additions & 0 deletions lite/kernels/xpu/activation_compute.cc
@@ -57,6 +57,18 @@ void GeluCompute<T, PType>::Run() {
CHECK_EQ(r, 0);
}

template <typename T, PrecisionType PType>
void QuickGeluCompute<T, PType>::Run() {
auto& param = this->template Param<param_t>();
auto& ctx = this->ctx_->template As<XPUContext>();

int r = xdnn::quick_gelu(ctx.GetRawContext(),
param.X->template data<T>(),
param.Out->template mutable_data<T>(TARGET(kXPU)),
param.X->numel());
CHECK_EQ(r, 0);
}

template <typename T, PrecisionType PType>
void TanhCompute<T, PType>::Run() {
auto& param = this->template Param<param_t>();
@@ -333,6 +345,19 @@ REGISTER_LITE_KERNEL(gelu, kXPU, kFP16, kNCHW, geluFP16, geluFP16)
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
.Finalize();

using quick_gelu_FP32 =
paddle::lite::kernels::xpu::QuickGeluCompute<float, PRECISION(kFloat)>;
using quick_gelu_FP16 =
paddle::lite::kernels::xpu::QuickGeluCompute<float16, PRECISION(kFP16)>;
REGISTER_LITE_KERNEL(quick_gelu, kXPU, kFloat, kNCHW, quick_gelu_FP32, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
REGISTER_LITE_KERNEL(quick_gelu, kXPU, kFP16, kNCHW, quick_gelu_FP16, quick_gelu_FP16)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
.Finalize();

using tanhFP32 =
paddle::lite::kernels::xpu::TanhCompute<float, PRECISION(kFloat)>;
using tanhFP16 =
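
Note that the xdnn::quick_gelu call above forwards only the input, output, and element count, not the op's scale attribute, so the device implementation presumably applies the canonical 1.702 factor internally. A hedged host-side sanity check one might run against output copied back from the device (names and tolerance are hypothetical, not part of this commit):

#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical validation helper: compare device results (copied back to the
// host) against the canonical QuickGELU, x * sigmoid(1.702 * x).
bool check_quick_gelu(const std::vector<float>& x,
                      const std::vector<float>& device_out,
                      float tol = 1e-4f) {
  for (size_t i = 0; i < x.size(); ++i) {
    float ref = x[i] / (1.0f + std::exp(-1.702f * x[i]));
    if (std::fabs(ref - device_out[i]) > tol) return false;
  }
  return true;
}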
10 changes: 10 additions & 0 deletions lite/kernels/xpu/activation_compute.h
@@ -215,6 +215,16 @@ class FloorCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
virtual ~FloorCompute() = default;
};

template <typename T, PrecisionType PType>
class QuickGeluCompute : public KernelLite<TARGET(kXPU), PType> {
public:
using param_t = operators::ActivationParam;

virtual void Run();

virtual ~QuickGeluCompute() = default;
};

} // namespace xpu
} // namespace kernels
} // namespace lite
5 changes: 4 additions & 1 deletion lite/operators/activation_ops.cc
@@ -102,7 +102,9 @@ bool ActivationOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
param_.active_type = lite_api::ActivationType::kMish;
param_.threshold = opdesc.GetAttr<float>("threshold");
}

else if (opdesc.Type() == "quick_gelu") {
param_.active_type = lite_api::ActivationType::kQuickGelu;
}
VLOG(4) << "opdesc.Type():" << opdesc.Type();

param_.Out = scope->FindVar(out_name)->GetMutable<lite::Tensor>();
@@ -126,6 +128,7 @@ REGISTER_LITE_OP(sqrt, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(rsqrt, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(softsign, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(gelu, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(quick_gelu, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(hard_swish, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(reciprocal, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(mish, paddle::lite::operators::ActivationOp);
3 changes: 3 additions & 0 deletions lite/operators/activation_ops.h
@@ -112,6 +112,9 @@ class ActivationOp : public OpLite {
case lite_api::ActivationType::kSilu:
ch->macs = param_.X->numel();
break;
case lite_api::ActivationType::kQuickGelu:
ch->macs = param_.X->numel();
break;
default:
LOG(FATAL) << "This Type of Activation:"
<< static_cast<int>(param_.active_type)
