Skip to content

Commit

Permalink
add activation xpu gelu (PaddlePaddle#7527)
Browse files Browse the repository at this point in the history
  • Loading branch information
Gradie authored and newway committed Dec 29, 2021
1 parent 7fb282e commit e7f907b
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
17 changes: 17 additions & 0 deletions lite/kernels/xpu/activation_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,17 @@ void Relu6Compute::Run() {
CHECK_EQ(r, 0);
}

void GeluCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->As<XPUContext>();

int r = xdnn::gelu(ctx.GetRawContext(),
param.X->data<float>(),
param.Out->mutable_data<float>(TARGET(kXPU)),
param.X->numel());
CHECK_EQ(r, 0);
}

void TanhCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->As<XPUContext>();
Expand Down Expand Up @@ -266,6 +277,12 @@ REGISTER_LITE_KERNEL(
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();

REGISTER_LITE_KERNEL(
gelu, kXPU, kFloat, kNCHW, paddle::lite::kernels::xpu::GeluCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();

REGISTER_LITE_KERNEL(
tanh, kXPU, kFloat, kNCHW, paddle::lite::kernels::xpu::TanhCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
Expand Down
9 changes: 9 additions & 0 deletions lite/kernels/xpu/activation_compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,15 @@ class Relu6Compute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
virtual ~Relu6Compute() = default;
};

class GeluCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public:
using param_t = operators::ActivationParam;

virtual void Run();

virtual ~GeluCompute() = default;
};

class TanhCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public:
using param_t = operators::ActivationParam;
Expand Down

0 comments on commit e7f907b

Please sign in to comment.