Skip to content

Commit

Permalink
support selu activation function (apache#12059)
Browse files Browse the repository at this point in the history
  • Loading branch information
haojin2 authored and eric-haibin-lin committed Aug 12, 2018
1 parent 22817be commit 774a75e
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 4 deletions.
4 changes: 1 addition & 3 deletions python/mxnet/gluon/nn/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,9 @@ class SELU(HybridBlock):
"""
def __init__(self, **kwargs):
super(SELU, self).__init__(**kwargs)
self._scale = 1.0507009873554804934193349852946
self._alpha = 1.6732632423543772848170429916717

def hybrid_forward(self, F, x):
return self._scale * F.where(x > 0, x, self._alpha * (F.exp(x) - 1.0))
return F.LeakyReLU(x, act_type='selu', name='fwd')


class Swish(HybridBlock):
Expand Down
19 changes: 18 additions & 1 deletion src/operator/leaky_relu-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ namespace op {
namespace leakyrelu {
enum LeakyReLUOpInputs {kData, kGamma};
enum LeakyReLUOpOutputs {kOut, kMask};
enum LeakyReLUOpType {kLeakyReLU, kPReLU, kRReLU, kELU};
enum LeakyReLUOpType {kLeakyReLU, kPReLU, kRReLU, kELU, kSELU};
enum LeakyReLUOpResource {kRandom};
} // namespace leakyrelu

Expand All @@ -63,6 +63,7 @@ struct LeakyReLUParam : public dmlc::Parameter<LeakyReLUParam> {
.add_enum("leaky", leakyrelu::kLeakyReLU)
.add_enum("prelu", leakyrelu::kPReLU)
.add_enum("elu", leakyrelu::kELU)
.add_enum("selu", leakyrelu::kSELU)
.describe("Activation function to be applied.");
DMLC_DECLARE_FIELD(slope).set_default(0.25f)
.describe("Init slope for the activation. (For leaky and elu only)");
Expand Down Expand Up @@ -182,6 +183,13 @@ class LeakyReLUOp : public Operator {
});
break;
}
case leakyrelu::kSELU: {
MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kOut], Req, {
mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::selu, Req>, xpu>::Launch(
s, out.size(0) * out.size(1) * out.size(2), out.dptr_, data.dptr_);
});
break;
}
default:
LOG(FATAL) << "Not implmented";
}
Expand Down Expand Up @@ -270,6 +278,15 @@ class LeakyReLUOp : public Operator {
});
break;
}
case leakyrelu::kSELU: {
MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kData], Req, {
mxnet_op::Kernel<mxnet_op::op_with_req<
mxnet_op::backward_grad_tuned<mshadow_op::selu_grad>, Req>, xpu>::Launch(
s, gdata.size(0) * gdata.size(1) * gdata.size(2), gdata.dptr_, grad.dptr_,
output.dptr_);
});
break;
}
default:
LOG(FATAL) << "Not implmented";
}
Expand Down
2 changes: 2 additions & 0 deletions src/operator/leaky_relu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ when the input is negative and has a slope of one when input is positive.
The following modified ReLU Activation functions are supported:
- *elu*: Exponential Linear Unit. `y = x > 0 ? x : slope * (exp(x)-1)`
- *selu*: Scaled Exponential Linear Unit. `y = lambda * (x > 0 ? x : alpha * (exp(x) - 1))` where
*lambda = 1.0507009873554804934193349852946* and *alpha = 1.6732632423543772848170429916717*.
- *leaky*: Leaky ReLU. `y = x > 0 ? x : slope * x`
- *prelu*: Parametric ReLU. This is same as *leaky* except that `slope` is learnt during training.
- *rrelu*: Randomized ReLU. same as *leaky* but the `slope` is uniformly and randomly chosen from
Expand Down
10 changes: 10 additions & 0 deletions src/operator/mshadow_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,12 @@ namespace mshadow_op {

#ifdef __CUDA_ARCH__
__constant__ const float PI = 3.14159265358979323846;
__constant__ const float SELU_ALPHA = 1.6732632423543772848170429916717;
__constant__ const float SELU_LAMBDA = 1.0507009873554804934193349852946;
#else
const float PI = 3.14159265358979323846;
const float SELU_ALPHA = 1.6732632423543772848170429916717;
const float SELU_LAMBDA = 1.0507009873554804934193349852946;
using std::isnan;
#endif
using std::enable_if;
Expand Down Expand Up @@ -126,6 +130,12 @@ MXNET_UNARY_MATH_OP_NC(relu, a > DType(0) ? a : DType(0));

MXNET_UNARY_MATH_OP_NC(relu_grad, a > DType(0) ? DType(1) : DType(0));

MXNET_UNARY_MATH_OP_NC(selu, DType(SELU_LAMBDA) *
(a > DType(0) ? a : DType(math::id(SELU_ALPHA) * math::expm1(a))));

MXNET_UNARY_MATH_OP_NC(selu_grad,
DType(SELU_LAMBDA) * (a > DType(0) ? DType(1) : DType(SELU_ALPHA + a)));

MXNET_BINARY_MATH_OP_NC(prelu_grad, a > DType(0) ? DType(0) : a);

MXNET_BINARY_MATH_OP_NC(xelu, a > DType(0) ? a :
Expand Down
2 changes: 2 additions & 0 deletions src/operator/operator_tune.cc
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,8 @@ IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::softsign); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::softsign_grad); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::relu); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::relu_grad); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::selu); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::selu_grad); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::tanh); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::tanh_grad); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::softrelu); // NOLINT()
Expand Down
31 changes: 31 additions & 0 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,37 @@ def fprelu_grad(x, y, gamma):
check_symbolic_backward(y, [xa, gam_full], [np.ones(shape), np.ones(gam_full.shape)],
[g_xa_full, g_gam_full], rtol=rtol, atol=atol, dtype=dtype)

@with_seed()
def test_selu():
alpha = 1.6732632423543772848170429916717
lamb = 1.0507009873554804934193349852946
def fselu(x):
neg_indices = x < 0
out = x.copy()
out[neg_indices] = alpha * np.expm1(out[neg_indices])
return out * lamb
def fselu_grad(grad, x, y):
neg_indices = x < 0
out = np.ones(x.shape).astype(x.dtype)
out[neg_indices] = y[neg_indices] + alpha
return out * lamb

shape = (3, 4)
x = mx.sym.Variable("x")
y = mx.sym.LeakyReLU(data=x, act_type="selu")
for dtype in [np.float16, np.float32, np.float64]:
xa = np.random.uniform(low=-0.1,high=0.1,size=shape).astype(dtype)
eps, rtol, atol = (7.5e-4, 1e-1, 1e-2) if dtype is np.float16 else (1e-4, 1e-2, 1e-4)
if dtype is np.float16:
xa /= 10.0
xa[abs(xa) < eps] = 0.01
ya = fselu(xa)
ga = fselu_grad(np.ones(shape).astype(dtype), xa, ya)
check_numeric_gradient(y, [xa], numeric_eps=eps, rtol=rtol, atol=atol, dtype=dtype)
check_symbolic_forward(y, [xa], [ya], rtol=rtol, atol=atol, dtype=dtype)
check_symbolic_backward(y, [xa], [np.ones(shape)], [ga], rtol=rtol, atol=atol, dtype=dtype)


@with_seed()
def test_sigmoid():
def fsigmoid(a):
Expand Down

0 comments on commit 774a75e

Please sign in to comment.