Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【PaddlePaddle Hackathon 2】16 新增 API RRelu #41823

Merged
merged 29 commits into from
May 31, 2022
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
b4022f3
rrelu逻辑部分
thunder95 Mar 28, 2022
b03c8d1
unregistered op kernel (unresolved)
thunder95 Apr 14, 2022
a83e2bb
commit before merge
thunder95 Apr 16, 2022
b34da04
merge develop
thunder95 Apr 16, 2022
e5f3910
丰富测试用例
thunder95 Apr 26, 2022
a5388cb
合并develop解决冲突
thunder95 Apr 26, 2022
1ee342e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
thunder95 Apr 26, 2022
f74eb8a
修复rrelu-sig的bug
thunder95 Apr 26, 2022
bb3b47e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
thunder95 Apr 26, 2022
71fdbab
修复cpu环境测试
thunder95 Apr 26, 2022
c51cc32
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
thunder95 Apr 26, 2022
f7cf53b
修改拼写错误
thunder95 Apr 27, 2022
a0cd822
修改code format
thunder95 Apr 27, 2022
e5cf547
尝试优化测试用例timeout的问题
thunder95 Apr 29, 2022
a6f7c04
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
thunder95 Apr 29, 2022
fc85f79
merge develop
thunder95 Apr 29, 2022
b50216e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
thunder95 Apr 29, 2022
db09397
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
thunder95 Apr 29, 2022
90102c0
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
thunder95 Apr 30, 2022
28cd511
优化测试用例
thunder95 Apr 30, 2022
a5fed7f
移除seed, 优化随机函数
thunder95 May 5, 2022
f2a9834
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
thunder95 May 5, 2022
916b5b8
update en doc for rrelu
thunder95 May 6, 2022
51493a3
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
thunder95 May 6, 2022
726a5d5
fix rrelu en docs, test=document_fix
thunder95 May 10, 2022
4502a01
add paper link for en docs, test=document_fix
thunder95 May 10, 2022
2132bf5
udpate en doc
thunder95 May 14, 2022
116873b
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
thunder95 May 14, 2022
8f626c2
add r,test=document_fix
thunder95 May 23, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 126 additions & 0 deletions paddle/fluid/operators/rrelu_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <memory>
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/unary.h"

namespace paddle {
namespace operators {

using framework::Tensor;

class RReluOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};

class RReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The input of RReLU op.");
AddOutput("Out", "The output of RReLU op.");
AddOutput("Noise", "The random sampled RReLU noise.")
.AsIntermediate()
.AsExtra();
AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.")
.SetDefault(false);
float default_lower = 1. / 8.;
AddAttr<float>("lower", "Lower bound of the uniform distribution.")
.SetDefault(default_lower)
.AddCustomChecker([](const float& lower) {
PADDLE_ENFORCE_EQ(lower >= 0.0f && lower < 1.0f, true,
platform::errors::InvalidArgument(
"'RRelu_lower' must be between 0.0 and 1.0."));
});
float defalut_upper = 1. / 3.;
AddAttr<float>("upper", "Upper bound of the uniform distribution.")
.SetDefault(defalut_upper)
.AddCustomChecker([](const float& upper) {
PADDLE_ENFORCE_EQ(upper > 0.0f && upper <= 1.0f, true,
platform::errors::InvalidArgument(
"'RRelu_upper' must be between 0.0 and 1.0."));
});
AddComment(R"DOC(
RReLU Operator.

Applies the randomized leaky rectified liner unit function, element-wise,
as described in the paper:

`Empirical Evaluation of Rectified Activations in Convolutional Network`_.

The function is defined as:

.. math::
\text{RReLU}(x) =
\begin{cases}
x & \text{if } x \geq 0 \\
ax & \text{ otherwise }
\end{cases}

where :math:`a` is randomly sampled from uniform distribution
:math:`\mathcal{U}(\text{lower}, \text{upper})`.

See: https://arxiv.org/pdf/1505.00853.pdf

)DOC");
}
};

class RReluGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};

template <typename T>
class RReluGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("rrelu_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Noise", this->Output("Noise"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(rrelu, RReluInferShapeFunctor,
PD_INFER_META(phi::RReluInferMeta));

REGISTER_OPERATOR(rrelu, ops::RReluOp, ops::RReluOpMaker,
ops::RReluGradOpMaker<paddle::framework::OpDesc>,
ops::RReluGradOpMaker<paddle::imperative::OpBase>,
RReluInferShapeFunctor);

DECLARE_INFER_SHAPE_FUNCTOR(rrelu_grad, RReluGradInferShapeFunctor,
PD_INFER_META(phi::RReluGradInferMeta));
REGISTER_OPERATOR(rrelu_grad, ops::RReluGradOp, RReluGradInferShapeFunctor);
49 changes: 49 additions & 0 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1915,6 +1915,55 @@ void RollInferMeta(const MetaTensor& x,
out->set_dtype(x.dtype());
}

void RReluInferMeta(const MetaTensor& x,
float lower,
float upper,
bool is_test,
MetaTensor* out,
MetaTensor* noise) {
auto x_dims = x.dims();
PADDLE_ENFORCE_GE(lower,
0,
phi::errors::InvalidArgument(
"The lower value should be greater than or equal to 0. "
"But received lower value = %f.",
lower));
PADDLE_ENFORCE_LE(upper,
1,
phi::errors::InvalidArgument(
"The upper value should be less than or equal to 1. "
"But received upper value = %f.",
upper));
PADDLE_ENFORCE_GE(
upper,
lower,
phi::errors::InvalidArgument(
"The upper value should be greater than or equal to lower value "
"But received upper value = %f, lower value = %f.",
upper,
lower));

out->set_dims(x_dims);
out->set_dtype(x.dtype());
out->set_layout(x.layout());
out->share_lod(x);

if (noise != nullptr) {
noise->set_dims(x_dims);
noise->set_dtype(x.dtype());
noise->set_layout(x.layout());
}
}

void RReluGradInferMeta(const MetaTensor& out_grad,
const MetaTensor& noise,
MetaTensor* x_grad) {
auto do_dims = out_grad.dims();
x_grad->set_dims(do_dims);
x_grad->set_dtype(out_grad.dtype());
x_grad->share_lod(out_grad);
}

void SetValueInferMeta(const MetaTensor& x, MetaTensor* out) {
auto in_dims = x.dims();
PADDLE_ENFORCE_LT(
Expand Down
11 changes: 11 additions & 0 deletions paddle/phi/infermeta/unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,17 @@ void RollInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axis,
MetaTensor* out);

void RReluInferMeta(const MetaTensor& x,
float lower,
float upper,
bool is_test,
MetaTensor* out,
MetaTensor* noise);

void RReluGradInferMeta(const MetaTensor& out_grad,
const MetaTensor& noise,
MetaTensor* x_grad);

void SetValueInferMeta(const MetaTensor& x, MetaTensor* out);

void ShapeInferMeta(const MetaTensor& input, MetaTensor* out);
Expand Down
44 changes: 44 additions & 0 deletions paddle/phi/kernels/cpu/rrelu_grad_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/rrelu_grad_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename Context>
void RReluGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& noise,
const DenseTensor& out_grad,
DenseTensor* x_grad) {
const T* n_ptr = noise.data<T>();
const T* x_ptr = x.data<T>();
const T* out_grad_ptr = out_grad.data<T>();
int numel = x.numel();
if (!x_grad) return;

int i = 0;
T* x_grad_ptr = dev_ctx.template Alloc<T>(x_grad);
for (i = 0; i < numel; i++) {
x_grad_ptr[i] = x_ptr[i] > 0 ? out_grad_ptr[i] : n_ptr[i] * out_grad_ptr[i];
}
}

} // namespace phi

PD_REGISTER_KERNEL(
rrelu_grad, CPU, ALL_LAYOUT, phi::RReluGradKernel, float, double) {}
77 changes: 77 additions & 0 deletions paddle/phi/kernels/cpu/rrelu_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/rrelu_kernel.h"

#include "paddle/fluid/framework/generator.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename Context>
void RReluKernel(const Context& dev_ctx,
const DenseTensor& x,
const float lower,
const float upper,
bool is_test,
DenseTensor* out,
DenseTensor* noise) {
const T* x_ptr = x.data<T>();
T* o_ptr = dev_ctx.template Alloc<T>(out);
T* n_ptr = dev_ctx.template Alloc<T>(noise);
T zero = static_cast<T>(0);
int numel = x.numel();
int i = 0;

if (is_test) {
T mid_val = static_cast<T>((lower + upper) / 2.0);
for (i = 0; i < numel; i++) {
if (x_ptr[i] < zero) {
o_ptr[i] = mid_val * x_ptr[i];
n_ptr[i] = mid_val;
} else {
o_ptr[i] = x_ptr[i];
n_ptr[i] = 1.0;
}
}

return;
}

auto engine = paddle::framework::GetCPURandomEngine(0);

std::uniform_real_distribution<float> dist(lower, upper);

for (i = 0; i < numel; i++) {
if (x_ptr[i] < zero) {
T scale = static_cast<T>(dist(*engine));
o_ptr[i] = scale * x_ptr[i];
n_ptr[i] = scale;
} else {
o_ptr[i] = x_ptr[i];
n_ptr[i] = 1.0;
}
}
}

} // namespace phi

PD_REGISTER_KERNEL(rrelu,
CPU,
ALL_LAYOUT,
phi::RReluKernel,
float,
phi::dtype::float16,
double) {}
Loading