[PaddlePaddle Hackathon] 4. Add RReLU to Paddle #37047

Status: Closed · wants to merge 2 commits
25 changes: 25 additions & 0 deletions python/paddle/fluid/tests/unittests/test_activation_nn_grad.py
@@ -165,6 +165,31 @@ def test_grad(self):
            self.func(p)


class TestRReluGradCheck(unittest.TestCase):
    @prog_scope()
    def func(self, place):
        shape = [2, 3, 7, 9]
        eps = 0.005
        dtype = np.float64
        seed = 2022

        x = layers.data('x', shape, False, dtype)
        x.persistable = True

        y = F.rrelu(x, seed=seed)
        x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
        # Keep inputs away from the kink at 0, where the numerical gradient
        # of a piecewise-linear activation is unreliable.
        x_arr[np.abs(x_arr) < 0.005] = 0.02

        gradient_checker.grad_check([x], y, x_init=x_arr, place=place, eps=eps)

    def test_grad(self):
        places = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda():
            places.append(fluid.CUDAPlace(0))
        for p in places:
            self.func(p)


class TestELUDoubleGradCheck(unittest.TestCase):
    @prog_scope()
    def func(self, place):
54 changes: 54 additions & 0 deletions python/paddle/fluid/tests/unittests/test_activation_op.py
@@ -1356,6 +1356,60 @@ def test_errors(self):
            F.leaky_relu(x_fp16)


class TestRReluAPI(unittest.TestCase):
    # test paddle.nn.RReLU and paddle.nn.functional.rrelu
    def setUp(self):
        np.random.seed(1024)
        self.x_np = np.random.uniform(-1, 1, [10, 12]).astype('float32')
        self.one_np = np.array([-1.]).astype('float32')
        self.place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
            else paddle.CPUPlace()

    def test_static_api(self):
        paddle.enable_static()
        with paddle.static.program_guard(paddle.static.Program()):
            # F.rrelu(-1) equals -negative_slope, so fetching rand_alpha lets
            # the test recover the sampled slope and reuse ref_leaky_relu.
            minus_one = paddle.fluid.data('one', [1])
            rand_alpha = F.rrelu(minus_one, seed=2022)
            x = paddle.fluid.data('X', [10, 12])
            out1 = F.rrelu(x, seed=2022)
            m = paddle.nn.RReLU(seed=2022)
            out2 = m(x)
            exe = paddle.static.Executor(self.place)
            res = exe.run(feed={'X': self.x_np,
                                'one': self.one_np},
                          fetch_list=[out1, out2, rand_alpha])
        out_ref = ref_leaky_relu(self.x_np, alpha=-res[2])
        for r in range(2):
            self.assertEqual(np.allclose(out_ref, res[r]), True)

    def test_dygraph_api(self):
        paddle.disable_static(self.place)
        rand_alpha = F.rrelu(paddle.to_tensor(-1.), seed=2022)
        x = paddle.to_tensor(self.x_np)
        out1 = F.rrelu(x, seed=2022)
        m = paddle.nn.RReLU(seed=2022)
        out2 = m(x)
        out_ref = ref_leaky_relu(self.x_np, alpha=-rand_alpha.numpy().item())
        for r in [out1, out2]:
            self.assertEqual(np.allclose(out_ref, r.numpy()), True)

        paddle.enable_static()

    def test_errors(self):
        paddle.enable_static()
        with paddle.static.program_guard(paddle.static.Program()):
            # The input type must be Variable.
            self.assertRaises(TypeError, F.rrelu, 1)
            # The input dtype must be float16, float32 or float64.
            x_int32 = paddle.fluid.data(
                name='x_int32', shape=[12, 10], dtype='int32')
            self.assertRaises(TypeError, F.rrelu, x_int32)
            # float16 input is supported.
            x_fp16 = paddle.fluid.data(
                name='x_fp16', shape=[12, 10], dtype='float16')
            F.rrelu(x_fp16)


def gelu(x, approximate):
    if approximate:
        y_ref = 0.5 * x * (1.0 + np.tanh(
3 changes: 3 additions & 0 deletions python/paddle/fluid/tests/unittests/test_imperative_layers.py
@@ -60,6 +60,9 @@ def test_layer_str(self):
        module = nn.LeakyReLU()
        self.assertEqual(str(module), 'LeakyReLU(negative_slope=0.01)')

        module = nn.RReLU()
        self.assertEqual(str(module), 'RReLU(lower=0.125, upper=0.333, seed=0)')

        module = nn.Sigmoid()
        self.assertEqual(str(module), 'Sigmoid()')

1 change: 1 addition & 0 deletions python/paddle/nn/__init__.py
@@ -38,6 +38,7 @@
from .layer.activation import SELU # noqa: F401
from .layer.activation import Silu # noqa: F401
from .layer.activation import LeakyReLU # noqa: F401
from .layer.activation import RReLU # noqa: F401
from .layer.activation import Sigmoid # noqa: F401
from .layer.activation import Hardsigmoid # noqa: F401
from .layer.activation import LogSigmoid # noqa: F401
1 change: 1 addition & 0 deletions python/paddle/nn/functional/__init__.py
@@ -24,6 +24,7 @@
from .activation import hardsigmoid # noqa: F401
from .activation import hardswish # noqa: F401
from .activation import leaky_relu # noqa: F401
from .activation import rrelu # noqa: F401
from .activation import log_sigmoid # noqa: F401
from .activation import maxout # noqa: F401
from .activation import prelu # noqa: F401
56 changes: 56 additions & 0 deletions python/paddle/nn/functional/activation.py
@@ -27,6 +27,7 @@
from ...fluid.data_feeder import check_variable_and_dtype, check_dtype
import paddle
from paddle import _C_ops
import numpy as np

__all__ = []

@@ -435,6 +436,61 @@ def leaky_relu(x, negative_slope=0.01, name=None):
    return out


def rrelu(x, lower=0.125, upper=0.333, seed=0, name=None):
    r"""
    rrelu activation

    .. math::

        rrelu(x)=
            \left\{
                \begin{array}{rcl}
                    x, & & if \ x >= 0 \\
                    negative\_slope * x, & & otherwise \\
                \end{array}
            \right. \\
        negative\_slope \sim U(lower, upper)

    Args:
        x (Tensor): The input Tensor with data type float32 or float64.
        lower (float, optional): The lower bound of the uniform distribution. Default is 0.125.
        upper (float, optional): The upper bound of the uniform distribution. Default is 0.333.
        seed (int, optional): The random seed of the uniform distribution engine. If seed is 0,
            the seed of the global default generator (which can be set by ``paddle.seed``) is used.
            Note that if seed is not 0, this operator always generates the same random
            negative_slope. Default is 0.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        A Tensor with the same data type and shape as ``x``.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F

            x = paddle.to_tensor([-2., 0., 1.])
            out = F.rrelu(x)  # e.g. [-2 * negative_slope, 0., 1.], negative_slope ~ U(0.125, 0.333)

    """
    np.random.seed(seed)
    negative_slope = float(np.random.uniform(lower, upper))
Contributor (review comment):

negative_slope should be randomly sampled for every element of the input Tensor, not once for the whole Tensor. This op most likely needs its own operator implementation.

Contributor Author (reply):

@zhiboniu Hello, could this be implemented by element-wise multiplying the input with a masked random tensor (positions where the input is greater than 0 are set to 1)? Also, would it be acceptable to provide only the nn.Layer interface and not the nn.functional one?

Contributor (reply):

Multiplying by a masked random tensor is not a good implementation. nn.Layer is really just the nn.functional function wrapped in a class.
    if in_dygraph_mode():
        return _C_ops.leaky_relu(x, 'alpha', negative_slope)

    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
                             'leaky_relu')
    helper = LayerHelper('leaky_relu', **locals())
    out = helper.create_variable_for_type_inference(dtype=x.dtype)
    helper.append_op(
        type='leaky_relu',
        inputs={'X': x},
        outputs={'Out': out},
        attrs={'alpha': negative_slope})
    return out
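
To make the review exchange above concrete, here is a minimal NumPy sketch of an element-wise RReLU, written outside of Paddle. The function name rrelu_reference and the training flag are illustrative assumptions, not part of this PR; the sketch only shows the per-element sampling the reviewer asks for, together with the "masked random tensor" idea the author mentions.

import numpy as np

def rrelu_reference(x, lower=0.125, upper=0.333, training=True):
    # Hypothetical reference behaviour, NOT the implementation in this PR.
    x = np.asarray(x, dtype=np.float64)
    if training:
        # Draw an independent slope for every element, then build the
        # "masked random tensor": positions with x >= 0 get a factor of 1.
        slopes = np.random.uniform(lower, upper, size=x.shape)
        factor = np.where(x >= 0, 1.0, slopes)
    else:
        # At inference time RReLU is commonly evaluated with the fixed
        # slope (lower + upper) / 2.
        factor = np.where(x >= 0, 1.0, (lower + upper) / 2.0)
    return x * factor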


def prelu(x, weight, name=None):
"""
prelu activation.
1 change: 1 addition & 0 deletions python/paddle/nn/layer/__init__.py
@@ -23,6 +23,7 @@
from .activation import ReLU # noqa: F401
from .activation import ReLU6 # noqa: F401
from .activation import LeakyReLU # noqa: F401
from .activation import RReLU # noqa: F401
from .activation import Sigmoid # noqa: F401
from .activation import Softmax # noqa: F401
from .activation import LogSoftmax # noqa: F401
57 changes: 57 additions & 0 deletions python/paddle/nn/layer/activation.py
@@ -603,6 +603,63 @@ def extra_repr(self):
        return 'negative_slope={}{}'.format(self._negative_slope, name_str)


class RReLU(Layer):
    r"""
    Randomized Leaky ReLU Activation.

    .. math::

        RReLU(x)=
            \left\{
                \begin{array}{rcl}
                    x, & & if \ x >= 0 \\
                    negative\_slope * x, & & otherwise \\
                \end{array}
            \right. \\
        negative\_slope \sim U(lower, upper)

    Parameters:
        lower (float, optional): The lower bound of the uniform distribution. Default is 0.125.
        upper (float, optional): The upper bound of the uniform distribution. Default is 0.333.
        seed (int, optional): The random seed of the uniform distribution engine. If seed is 0,
            the seed of the global default generator (which can be set by ``paddle.seed``) is used.
            Note that if seed is not 0, this operator always generates the same random
            negative_slope. Default is 0.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        - input: Tensor with any shape.
        - output: Tensor with the same shape as input.

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

            m = paddle.nn.RReLU()
            x = paddle.to_tensor(np.array([-2, 0, 1], 'float32'))
            out = m(x)  # e.g. [-2 * negative_slope, 0., 1.], negative_slope ~ U(0.125, 0.333)
    """

    def __init__(self, lower=0.125, upper=0.333, seed=0, name=None):
        super(RReLU, self).__init__()
        self._lower = lower
        self._upper = upper
        self._seed = seed
        self._name = name

    def forward(self, x):
        return F.rrelu(x, self._lower, self._upper, self._seed, self._name)

    def extra_repr(self):
        name_str = ', name={}'.format(self._name) if self._name else ''
        return 'lower={}, upper={}, seed={}{}'.format(self._lower, self._upper,
                                                      self._seed, name_str)
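
As a usage note on the reviewer's remark that the Layer is just a class wrapper around the functional form, the following hedged sketch (it assumes this PR is applied; the tensor shape and seed value are arbitrary) shows that nn.RReLU and F.rrelu agree when given the same seed:

import paddle
import paddle.nn.functional as F

x = paddle.uniform([4, 8], min=-1.0, max=1.0)
layer_out = paddle.nn.RReLU(seed=2022)(x)   # class interface
func_out = F.rrelu(x, seed=2022)            # functional interface
print(paddle.allclose(layer_out, func_out).numpy())  # expected: [True]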


class Sigmoid(Layer):
"""
this interface is used to construct a callable object of the ``Sigmoid`` class. This layer calcluate the `sigmoid` of input x.