Commit

[xpu] Add __xpu__quick_gelu op and fuse it into __xpu__multi_encoder_op for ViT model (PaddlePaddle#9755)
stevenshen36 authored and qfyinbd committed Nov 29, 2022
1 parent 12ab107 commit 0b7755f
Showing 13 changed files with 322 additions and 2 deletions.
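
For context: QuickGELU is the sigmoid-based GELU approximation used by ViT/CLIP-style models. The fuse pass added in this commit matches its three-op form (scale, sigmoid, elementwise_mul). As a reference (author's summary, not part of the diff):

    quick_gelu(x) = x · sigmoid(1.702 · x) ≈ gelu(x) = x · Φ(x)

where Φ is the standard normal CDF and 1.702 is the constant the fuser's scale_teller checks for.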
1 change: 1 addition & 0 deletions lite/api/paddle_use_passes.h
@@ -82,6 +82,7 @@ USE_MIR_PASS(__xpu__resnet_fuse_pass);
USE_MIR_PASS(__xpu__multi_encoder_fuse_pass);
USE_MIR_PASS(__xpu__embedding_with_eltwise_add_fuse_pass);
USE_MIR_PASS(__xpu__fc_fuse_pass);
USE_MIR_PASS(__xpu__quick_gelu_fuse_pass);
USE_MIR_PASS(__xpu__mmdnn_fuse_pass);
USE_MIR_PASS(__xpu__conv2d_affine_channel_fuse_pass);
USE_MIR_PASS(__xpu__conv2d_fuse_pass);
4 changes: 3 additions & 1 deletion lite/core/optimizer/mir/fusion/__xpu__fc_fuse_pass.cc
@@ -199,7 +199,8 @@ class XPUFcFuser : public FuseBase {
{"leaky_relu", 5},
{"hard_swish", 14},
{"hard_sigmoid", 15},
{"relu6", 17}};
{"relu6", 17},
{"__xpu__quick_gelu", 19}};

float act_param_ = 0.0f;
if (act_type_ == "leaky_relu") {
@@ -281,6 +282,7 @@ class XPUFcFusePass : public ProgramPass {
for (auto with_bias : {true, false}) {
for (auto act_type : {"relu",
"gelu",
"__xpu__quick_gelu",
/*"sigmoid",
"tanh",
"leaky_relu",
2 changes: 1 addition & 1 deletion lite/core/optimizer/mir/fusion/__xpu__multi_encoder_fuse_pass.cc
@@ -1422,7 +1422,7 @@ class XPUMultiEncoderFusePass : public ProgramPass {
void Apply(const std::unique_ptr<SSAGraph>& graph) override {
if (GetBoolFromEnv("XPU_ENABLE_XTCL")) return;
// TODO(miaotianxiang): backup graph, recover from failed match
std::vector<std::string> act_types{"gelu", "relu"};
std::vector<std::string> act_types{"gelu", "relu", "__xpu__quick_gelu"};
std::vector<std::string> input_poss{"X", "Y"};
std::vector<std::string> qkv_ln_2_out_poss{"X", "Y"};
std::vector<std::string> matmul_types{"matmul", "matmul_v2"};
122 changes: 122 additions & 0 deletions lite/core/optimizer/mir/fusion/__xpu__quick_gelu_fuse_pass.cc
@@ -0,0 +1,122 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cmath>
#include <memory>
#include <string>
#include "lite/backends/xpu/math.h"
#include "lite/core/optimizer/mir/pass_registry.h"
#include "lite/core/optimizer/mir/pattern_matcher_high_api.h"

namespace paddle {
namespace lite {
namespace mir {
namespace fusion {

class XPUQuickGELUFuser : public FuseBase {
public:
XPUQuickGELUFuser() {}

void BuildPattern() override {
auto scale_teller = [](const Node* node) -> bool {
float bias_v =
const_cast<Node*>(node)->AsStmt().op_info()->GetAttr<float>("bias");
float scale_v =
const_cast<Node*>(node)->AsStmt().op_info()->GetAttr<float>("scale");
bool expect_bias = (bias_v == 0.0f);
bool expect_scale = (std::fabs(scale_v - 1.702f) < 1e-5f);
bool has_act =
const_cast<Node*>(node)->AsStmt().op_info()->HasAttr("activation_type");
return expect_bias && expect_scale && !has_act;
};

/*               ________________________________________
 *              /                                        \
 * Create node: X --> scale --> sigmoid --> elementwise_mul --> output
 */
auto* x = VarNode("x")->assert_is_op_input("scale", "X")->AsInput();
auto* scale = OpNode("scale", "scale")->assert_node_satisfied(scale_teller);
auto* scale_out = VarNode("scale_out");
auto* sigmoid = OpNode("sigmoid", "sigmoid");
auto* sigmoid_out = VarNode("sigmoid_out");
auto* element_mul =
OpNode("elementwise_mul", "elementwise_mul")
->assert_op_attr_satisfied<int>(
"axis", [](int attr) { return attr == -1 || attr == 0; });
auto* output = VarNode("Out")->AsOutput();

// Construct the topological structure for scale-sigmoid-elementwise_mul
*x >> *scale >> *scale_out >> *sigmoid >> *sigmoid_out;
std::vector<PMNode*> element_mul_inputs{x, sigmoid_out};
element_mul_inputs >> *element_mul >> *output;

// Mark the inner nodes as intermediate so they are removed after fusion.
scale->AsIntermediate();
scale_out->AsIntermediate();
sigmoid->AsIntermediate();
sigmoid_out->AsIntermediate();
element_mul->AsIntermediate();
}

cpp::OpDesc GenOpDesc(const key2nodes_t& matched) {
auto op_desc = *matched.at("scale")->stmt()->op_info();
float scale_val = op_desc.GetAttr<float>("scale");
op_desc.mutable_inputs()->clear();
op_desc.mutable_outputs()->clear();
op_desc.SetType("__xpu__quick_gelu");
op_desc.SetInput("X", {matched.at("x")->arg()->name});
op_desc.SetOutput("Out", {matched.at("Out")->arg()->name});
op_desc.SetAttr("scale", scale_val);
return op_desc;
}

void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override {
// get op_desc for gelu op.
auto op_desc = GenOpDesc(matched);
// Create gelu op.
auto gelu_op = LiteOpRegistry::Global().Create("__xpu__quick_gelu");

// find scope and valid_places of original scale op.
auto scale = matched.at("scale")->stmt()->op();
auto* scope = scale->scope();
auto& valid_places = scale->valid_places();

// set gelu op's scope and valid_places, aligned with the original scale op.
gelu_op->Attach(op_desc, scope);
auto* new_op_node = graph->GraphCreateInstructNode(gelu_op, valid_places);

// link IO to the new op node.
IR_NODE_LINK_TO(matched.at("x"), new_op_node);
IR_NODE_LINK_TO(new_op_node, matched.at("Out"));
}
};

} // namespace fusion

class XPUQuickGELUFusePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override {
fusion::XPUQuickGELUFuser fuser;
fuser(graph.get());
}
};

} // namespace mir
} // namespace lite
} // namespace paddle

REGISTER_MIR_PASS(__xpu__quick_gelu_fuse_pass,
paddle::lite::mir::XPUQuickGELUFusePass)
.BindTargets({TARGET(kXPU)})
.BindKernel("__xpu__quick_gelu");
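
For illustration only (not part of the commit): the subgraph matched by BuildPattern() computes QuickGELU step by step, which is the identity the fuser relies on. quick_gelu_unfused below is a hypothetical reference helper mirroring the three matched ops:

#include <cmath>

// Reference semantics of the fused pattern: out = x * sigmoid(1.702 * x).
float quick_gelu_unfused(float x) {
  float scaled = 1.702f * x;                      // "scale" op (scale = 1.702, bias = 0)
  float sig = 1.0f / (1.0f + std::exp(-scaled));  // "sigmoid" op
  return x * sig;                                 // "elementwise_mul" with the original x
}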
1 change: 1 addition & 0 deletions lite/core/optimizer/optimizer.cc
@@ -201,6 +201,7 @@ std::unique_ptr<RuntimeProgram> RunDefaultOptimizer(
"__xpu__mmdnn_fuse_pass",
"__xpu__bigru_fuse_pass",
"__xpu__roformer_relative_pos_fuse_pass",
"__xpu__quick_gelu_fuse_pass",
"__xpu__multi_encoder_fuse_pass",
"__xpu__embedding_with_eltwise_add_fuse_pass",
"__xpu__fc_fuse_pass",
1 change: 1 addition & 0 deletions lite/kernels/xpu/CMakeLists.txt
@@ -113,6 +113,7 @@ add_kernel(__xpu__resnet50_compute_xpu XPU extra SRCS __xpu__resnet50_compute.cc)
add_kernel(__xpu__multi_encoder_compute_xpu XPU extra SRCS __xpu__multi_encoder_compute.cc)
add_kernel(__xpu__embedding_with_eltwise_add_compute_xpu XPU extra SRCS __xpu__embedding_with_eltwise_add_compute.cc)
add_kernel(__xpu__fc_compute_xpu XPU extra SRCS __xpu__fc_compute.cc)
add_kernel(__xpu__quick_gelu_compute_xpu XPU extra SRCS __xpu__quick_gelu_compute.cc)
add_kernel(__xpu__search_attention_compute_xpu XPU extra SRCS __xpu__search_attention_compute.cc)
add_kernel(__xpu__search_attention_2_compute_xpu XPU extra SRCS __xpu__search_attention_2_compute.cc)
add_kernel(__xpu__mmdnn_compute_xpu XPU extra SRCS __xpu__mmdnn_compute.cc)
2 changes: 2 additions & 0 deletions lite/kernels/xpu/__xpu__multi_encoder_compute.cc
@@ -196,6 +196,8 @@ void XPUMultiEncoderCompute::PrepareForRun() {
// prepare act_type
if (param.act_type == "gelu") {
qkv_act = xdnn::Activation_t::GELU;
} else if (param.act_type == "__xpu__quick_gelu") {
qkv_act = xdnn::Activation_t::QUICK_GELU;
} else if (param.act_type != "relu") {
CHECK(false) << "Invalid QKV Activation Type: " << param.act_type;
}
54 changes: 54 additions & 0 deletions lite/kernels/xpu/__xpu__quick_gelu_compute.cc
@@ -0,0 +1,54 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/kernels/xpu/__xpu__quick_gelu_compute.h"
#include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/op_registry.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {

template <typename T, PrecisionType PType>
void QuickGeluCompute<T, PType>::Run() {
auto& param = this->template Param<param_t>();
auto& ctx = this->ctx_->template As<XPUContext>();

int r = xdnn::quick_gelu(ctx.GetRawContext(),
param.X->template data<T>(),
param.Out->template mutable_data<T>(TARGET(kXPU)),
param.X->numel());
CHECK_EQ(r, 0);
}

} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle

using quick_gelu_FP32 =
paddle::lite::kernels::xpu::QuickGeluCompute<float, PRECISION(kFloat)>;
using quick_gelu_FP16 =
paddle::lite::kernels::xpu::QuickGeluCompute<float16, PRECISION(kFP16)>;
REGISTER_LITE_KERNEL(
__xpu__quick_gelu, kXPU, kFloat, kNCHW, quick_gelu_FP32, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
REGISTER_LITE_KERNEL(
__xpu__quick_gelu, kXPU, kFP16, kNCHW, quick_gelu_FP16, quick_gelu_FP16)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
.Finalize();
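
As a quick sanity check (author's addition, not in the commit): the approximation registered above stays within about 2e-2 of exact GELU over typical activation ranges. A minimal host-side program, assuming only the standard library:

#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
  float max_err = 0.0f;
  for (float x = -6.0f; x <= 6.0f; x += 0.01f) {
    float quick = x / (1.0f + std::exp(-1.702f * x));                 // x * sigmoid(1.702 * x)
    float exact = 0.5f * x * (1.0f + std::erf(x / std::sqrt(2.0f)));  // exact GELU
    max_err = std::max(max_err, std::fabs(quick - exact));
  }
  std::printf("max |quick_gelu - gelu| on [-6, 6]: %f\n", max_err);  // prints ~0.02
  return 0;
}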
36 changes: 36 additions & 0 deletions lite/kernels/xpu/__xpu__quick_gelu_compute.h
@@ -0,0 +1,36 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "lite/core/kernel.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {

template <typename T, PrecisionType PType>
class QuickGeluCompute : public KernelLite<TARGET(kXPU), PType> {
public:
using param_t = operators::XPUQuickGeluParam;

void Run() override;

virtual ~QuickGeluCompute() = default;
};

} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
1 change: 1 addition & 0 deletions lite/operators/CMakeLists.txt
@@ -234,6 +234,7 @@ add_operator(__xpu__softmax_topk_op extra SRCS __xpu__softmax_topk_op.cc)
add_operator(__xpu__multi_encoder_op extra SRCS __xpu__multi_encoder_op.cc)
add_operator(__xpu__embedding_with_eltwise_add_op extra SRCS __xpu__embedding_with_eltwise_add_op.cc)
add_operator(__xpu__fc_op extra SRCS __xpu__fc_op.cc)
add_operator(__xpu__quick_gelu_op extra SRCS __xpu__quick_gelu_op.cc)
add_operator(__xpu__roformer_relative_embedding_op extra SRCS __xpu__roformer_relative_embedding_op.cc)
add_operator(__xpu__search_attention_op extra SRCS __xpu__search_attention_op.cc)
add_operator(__xpu__mmdnn_op extra SRCS __xpu__mmdnn_op.cc)
47 changes: 47 additions & 0 deletions lite/operators/__xpu__quick_gelu_op.cc
@@ -0,0 +1,47 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/operators/__xpu__quick_gelu_op.h"
#include "lite/core/op_registry.h"

namespace paddle {
namespace lite {
namespace operators {

bool XPUQuickGeluOp::CheckShape() const {
CHECK_OR_FALSE(param_.X);
CHECK_OR_FALSE(param_.Out);
return true;
}

bool XPUQuickGeluOp::InferShapeImpl() const {
param_.Out->Resize(param_.X->dims());
auto out_lod = param_.Out->mutable_lod();
*out_lod = param_.X->lod();
return true;
}

bool XPUQuickGeluOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
auto x_name = opdesc.Input("X").front();
auto out_name = opdesc.Output("Out").front();
param_.X = scope->FindVar(x_name)->GetMutable<lite::Tensor>();
param_.Out = scope->FindVar(out_name)->GetMutable<lite::Tensor>();
return true;
}

} // namespace operators
} // namespace lite
} // namespace paddle

REGISTER_LITE_OP(__xpu__quick_gelu, paddle::lite::operators::XPUQuickGeluOp);
48 changes: 48 additions & 0 deletions lite/operators/__xpu__quick_gelu_op.h
@@ -0,0 +1,48 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <string>
#include "lite/core/op_lite.h"
#ifdef LITE_WITH_PROFILE
#include "lite/api/paddle_place.h"
#endif

namespace paddle {
namespace lite {
namespace operators {

class XPUQuickGeluOp : public OpLite {
public:
explicit XPUQuickGeluOp(const std::string& type) : OpLite(type) {}

bool CheckShape() const override;

bool InferShapeImpl() const override;

bool InferType() override { return true; }

bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;

void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }

std::string DebugString() const override { return "XPUQuickGelu"; }

private:
mutable operators::XPUQuickGeluParam param_;
};

} // namespace operators
} // namespace lite
} // namespace paddle
5 changes: 5 additions & 0 deletions lite/operators/op_params.h
@@ -1797,6 +1797,11 @@ struct XPUEmbeddingWithEltwiseAddParam : ParamBase {
int mask_dtype{static_cast<int>(VarDescAPI::VarDataType::FP32)};
};

struct XPUQuickGeluParam : ParamBase {
const lite::Tensor* X{};
lite::Tensor* Out{};
};

struct XPUFcParam : ParamBase {
const lite::Tensor* input{nullptr};
const lite::Tensor* w{nullptr};
