Support selu activation function #12059

Merged (1 commit) on Aug 12, 2018
4 changes: 1 addition & 3 deletions python/mxnet/gluon/nn/activations.py
@@ -176,11 +176,9 @@ class SELU(HybridBlock):
"""
def __init__(self, **kwargs):
super(SELU, self).__init__(**kwargs)
- self._scale = 1.0507009873554804934193349852946
- self._alpha = 1.6732632423543772848170429916717

def hybrid_forward(self, F, x):
- return self._scale * F.where(x > 0, x, self._alpha * (F.exp(x) - 1.0))
+ return F.LeakyReLU(x, act_type='selu', name='fwd')


class Swish(HybridBlock):
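For orientation, here is a minimal usage sketch (commentary, not part of the PR) showing the two entry points this change connects: the Gluon SELU block, which now delegates to the fused LeakyReLU kernel with act_type='selu', and the symbol-level operator exercised by the new unit test. The input values and variable names (act, y, data, out) are illustrative only.

import mxnet as mx
from mxnet.gluon import nn

# Gluon block: SELU has no parameters, so it can be called directly.
act = nn.SELU()
y = act(mx.nd.array([[-2.0, -0.5, 0.0, 0.5, 2.0]]))
print(y)  # negative entries follow lambda * alpha * (exp(x) - 1)

# Equivalent symbol-level usage, as in the new test_selu unit test.
data = mx.sym.Variable("data")
out = mx.sym.LeakyReLU(data=data, act_type="selu")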
19 changes: 18 additions & 1 deletion src/operator/leaky_relu-inl.h
@@ -47,7 +47,7 @@ namespace op {
namespace leakyrelu {
enum LeakyReLUOpInputs {kData, kGamma};
enum LeakyReLUOpOutputs {kOut, kMask};
-enum LeakyReLUOpType {kLeakyReLU, kPReLU, kRReLU, kELU};
+enum LeakyReLUOpType {kLeakyReLU, kPReLU, kRReLU, kELU, kSELU};
enum LeakyReLUOpResource {kRandom};
} // namespace leakyrelu

@@ -63,6 +63,7 @@ struct LeakyReLUParam : public dmlc::Parameter<LeakyReLUParam> {
.add_enum("leaky", leakyrelu::kLeakyReLU)
.add_enum("prelu", leakyrelu::kPReLU)
.add_enum("elu", leakyrelu::kELU)
.add_enum("selu", leakyrelu::kSELU)
.describe("Activation function to be applied.");
DMLC_DECLARE_FIELD(slope).set_default(0.25f)
.describe("Init slope for the activation. (For leaky and elu only)");
@@ -182,6 +183,13 @@ class LeakyReLUOp : public Operator {
});
break;
}
case leakyrelu::kSELU: {
MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kOut], Req, {
mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::selu, Req>, xpu>::Launch(
s, out.size(0) * out.size(1) * out.size(2), out.dptr_, data.dptr_);
});
break;
}
default:
LOG(FATAL) << "Not implemented";
}
@@ -270,6 +278,15 @@ class LeakyReLUOp : public Operator {
});
break;
}
case leakyrelu::kSELU: {
MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kData], Req, {
mxnet_op::Kernel<mxnet_op::op_with_req<
mxnet_op::backward_grad_tuned<mshadow_op::selu_grad>, Req>, xpu>::Launch(
s, gdata.size(0) * gdata.size(1) * gdata.size(2), gdata.dptr_, grad.dptr_,
output.dptr_);
});
break;
}
default:
LOG(FATAL) << "Not implemented";
}
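As a rough reference for what the two new kernel launches compute per element, the NumPy sketch below (commentary, not part of the diff) mirrors the forward mshadow_op::selu map and the backward path. The helper names are ours, and the assumption that the backward kernel scales the incoming gradient by selu_grad evaluated on the forward output follows from the output.dptr_ argument above and from the fselu_grad reference in the new unit test.

import numpy as np

SELU_LAMBDA = 1.0507009873554804934193349852946
SELU_ALPHA = 1.6732632423543772848170429916717

def selu_forward(x):
    # Elementwise map of mshadow_op::selu.
    return SELU_LAMBDA * np.where(x > 0, x, SELU_ALPHA * np.expm1(x))

def selu_backward(grad, y):
    # Incoming gradient times mshadow_op::selu_grad evaluated on the
    # forward output y (assumed composition of backward_grad_tuned).
    return grad * SELU_LAMBDA * np.where(y > 0, 1.0, SELU_ALPHA + y)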
2 changes: 2 additions & 0 deletions src/operator/leaky_relu.cc
@@ -54,6 +54,8 @@ when the input is negative and has a slope of one when input is positive.
The following modified ReLU Activation functions are supported:

- *elu*: Exponential Linear Unit. `y = x > 0 ? x : slope * (exp(x)-1)`
- *selu*: Scaled Exponential Linear Unit. `y = lambda * (x > 0 ? x : alpha * (exp(x) - 1))` where
*lambda = 1.0507009873554804934193349852946* and *alpha = 1.6732632423543772848170429916717*.
- *leaky*: Leaky ReLU. `y = x > 0 ? x : slope * x`
- *prelu*: Parametric ReLU. This is same as *leaky* except that `slope` is learnt during training.
- *rrelu*: Randomized ReLU. same as *leaky* but the `slope` is uniformly and randomly chosen from
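The formula added to the operator documentation can be checked end to end with a small NumPy comparison. This is an illustrative sketch, not part of the PR; the sample values and tolerances are arbitrary.

import mxnet as mx
import numpy as np

lam = 1.0507009873554804934193349852946
alpha = 1.6732632423543772848170429916717

x = np.array([-2.0, -0.5, 0.0, 0.5, 2.0], dtype=np.float32)
expected = lam * np.where(x > 0, x, alpha * np.expm1(x))  # documented formula
got = mx.nd.LeakyReLU(mx.nd.array(x), act_type="selu").asnumpy()
assert np.allclose(got, expected, rtol=1e-5, atol=1e-6)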
10 changes: 10 additions & 0 deletions src/operator/mshadow_op.h
@@ -42,8 +42,12 @@ namespace mshadow_op {

#ifdef __CUDA_ARCH__
__constant__ const float PI = 3.14159265358979323846;
__constant__ const float SELU_ALPHA = 1.6732632423543772848170429916717;
__constant__ const float SELU_LAMBDA = 1.0507009873554804934193349852946;
#else
const float PI = 3.14159265358979323846;
const float SELU_ALPHA = 1.6732632423543772848170429916717;
const float SELU_LAMBDA = 1.0507009873554804934193349852946;
using std::isnan;
#endif
using std::enable_if;
@@ -126,6 +130,12 @@ MXNET_UNARY_MATH_OP_NC(relu, a > DType(0) ? a : DType(0));

MXNET_UNARY_MATH_OP_NC(relu_grad, a > DType(0) ? DType(1) : DType(0));

MXNET_UNARY_MATH_OP_NC(selu, DType(SELU_LAMBDA) *
(a > DType(0) ? a : DType(math::id(SELU_ALPHA) * math::expm1(a))));

MXNET_UNARY_MATH_OP_NC(selu_grad,
DType(SELU_LAMBDA) * (a > DType(0) ? DType(1) : DType(SELU_ALPHA + a)));

MXNET_BINARY_MATH_OP_NC(prelu_grad, a > DType(0) ? DType(0) : a);

MXNET_BINARY_MATH_OP_NC(xelu, a > DType(0) ? a :
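For reference (commentary, not part of the diff), the derivative of the activation with respect to its input, written in the same notation as the operator documentation above, is

dy/dx = lambda * (x > 0 ? 1 : alpha * exp(x))

Note that selu_grad is not applied to the input: the backward launch in leaky_relu-inl.h passes output.dptr_, so the map receives the forward output, which is also the convention the new unit test follows in fselu_grad.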
2 changes: 2 additions & 0 deletions src/operator/operator_tune.cc
@@ -217,6 +217,8 @@ IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::softsign); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::softsign_grad); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::relu); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::relu_grad); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::selu); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::selu_grad); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::tanh); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::tanh_grad); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::softrelu); // NOLINT()
31 changes: 31 additions & 0 deletions tests/python/unittest/test_operator.py
@@ -819,6 +819,37 @@ def fprelu_grad(x, y, gamma):
check_symbolic_backward(y, [xa, gam_full], [np.ones(shape), np.ones(gam_full.shape)],
[g_xa_full, g_gam_full], rtol=rtol, atol=atol, dtype=dtype)

@with_seed()
def test_selu():
alpha = 1.6732632423543772848170429916717
lamb = 1.0507009873554804934193349852946
def fselu(x):
neg_indices = x < 0
out = x.copy()
out[neg_indices] = alpha * np.expm1(out[neg_indices])
return out * lamb
def fselu_grad(grad, x, y):
neg_indices = x < 0
out = np.ones(x.shape).astype(x.dtype)
out[neg_indices] = y[neg_indices] + alpha
return out * lamb

shape = (3, 4)
x = mx.sym.Variable("x")
y = mx.sym.LeakyReLU(data=x, act_type="selu")
for dtype in [np.float16, np.float32, np.float64]:
xa = np.random.uniform(low=-0.1,high=0.1,size=shape).astype(dtype)
eps, rtol, atol = (7.5e-4, 1e-1, 1e-2) if dtype is np.float16 else (1e-4, 1e-2, 1e-4)
if dtype is np.float16:
xa /= 10.0
xa[abs(xa) < eps] = 0.01
ya = fselu(xa)
ga = fselu_grad(np.ones(shape).astype(dtype), xa, ya)
check_numeric_gradient(y, [xa], numeric_eps=eps, rtol=rtol, atol=atol, dtype=dtype)
check_symbolic_forward(y, [xa], [ya], rtol=rtol, atol=atol, dtype=dtype)
check_symbolic_backward(y, [xa], [np.ones(shape)], [ga], rtol=rtol, atol=atol, dtype=dtype)


@with_seed()
def test_sigmoid():
def fsigmoid(a):