Merge pull request #347 from mcabbott/leakyrelu
Improve a few activation functions
CarloLucibello authored Oct 18, 2021
2 parents 20c110b + b6a8964 commit eb2f248
Showing 2 changed files with 30 additions and 12 deletions.
27 changes: 15 additions & 12 deletions src/activations.jl
@@ -17,7 +17,7 @@ end
# Aliases
export sigmoid, hardsigmoid, logsigmoid, thresholdrelu

-# of type float
+# of type float (to allow for integer inputs)
oftf(x, y) = oftype(float(x), y)

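As a quick illustration of what oftf gives (a REPL sketch, not part of the diff):

julia> oftf(1f0, 0.01)   # constant converted to the float type of the input
0.01f0

julia> oftf(1, 0.01)     # integer input: float(1) is Float64
0.01
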
"""
@@ -34,11 +34,11 @@ end
const sigmoid = σ

"""
-hardσ(x) = max(0, min(1, (x + 3) / 6)
+hardσ(x) = max(0, min(1, (x + 3) / 6))
Piecewise linear approximation of sigmoid.
"""
-hardσ(x) = max(0, min(1, (x + 3) / 6))
+hardσ(x) = clamp((x + 3) / 6, 0, 1)

# https://pytorch.org/docs/stable/generated/torch.nn.Hardsigmoid.html

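The clamp form computes the same piecewise values as before; for instance (REPL sketch, assuming the updated file is loaded):

julia> hardσ(0.0), hardσ(-4.0), hardσ(4.0)
(0.5, 0.0, 1.0)
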
@@ -58,15 +58,15 @@ const logsigmoid = logσ
Segment-wise linear approximation of tanh. Cheaper and more computationally efficient than tanh.
See [Large Scale Machine Learning](https://ronan.collobert.com/pub/matos/2004_phdthesis_lip6.pdf).
"""
-hardtanh(x) = max(-one(x), min(one(x), x))
+hardtanh(x) = clamp(x, oftype(x, -1), oftype(x, 1)) # clamp(x, -1, 1) is type-stable, but would promote Int32, for which we have tests

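The oftype bounds are what keep integer inputs un-promoted, as the comment above notes (REPL sketch; the Int64 result assumes a 64-bit machine):

julia> hardtanh(Int32(3)), typeof(hardtanh(Int32(3)))
(1, Int32)

julia> typeof(clamp(Int32(3), -1, 1))   # literal bounds promote via Int64
Int64
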
"""
relu(x) = max(0, x)
[Rectified Linear Unit](https://en.wikipedia.org/wiki/Rectifier_(neural_networks))
activation function.
"""
-relu(x) = max(zero(x), x)
+relu(x) = ifelse(x<0, zero(x), x) # faster than max(zero(x), x), still preserves NaN

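Both the old max form and the new ifelse form send NaN to NaN (Julia's max already propagates NaN), so the rewrite is a speed win with unchanged semantics; e.g. (REPL sketch):

julia> relu(NaN32)
NaN32

julia> relu(-3), relu(3)   # integer types pass through
(0, 3)
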
"""
leakyrelu(x, a=0.01) = max(a*x, x)
@@ -75,7 +75,7 @@ Leaky [Rectified Linear Unit](https://en.wikipedia.org/wiki/Rectifier_(neural_ne
activation function.
You can also specify the coefficient explicitly, e.g. `leakyrelu(x, 0.01)`.
"""
-leakyrelu(x, a=oftf(x, 0.01)) = max(a * x, x)
+leakyrelu(x, a=oftf(x, 0.01)) = ifelse(x>0, float(x), oftf(x, a*x)) # max(a*x, x) is 3x slower

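Since ifelse evaluates both branches, the two arms must agree in type: float(x) on the positive side, oftf(x, a*x) on the negative. A check with an exactly representable coefficient (REPL sketch):

julia> leakyrelu(-2f0, 0.5f0)   # explicit coefficient, Float32 preserved
-1.0f0

julia> leakyrelu(2)             # Int promoted by float(x)
2.0
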
"""
relu6(x) = min(max(0, x), 6)
@@ -84,7 +84,7 @@ leakyrelu(x, a=oftf(x, 0.01)) = max(a * x, x)
activation function capped at 6.
See [Convolutional Deep Belief Networks on CIFAR-10](https://www.cs.toronto.edu/~kriz/conv-cifar10-aug2010.pdf)
"""
-relu6(x) = min(relu(x), oftype(x, 6))
+relu6(x) = clamp(x, oftype(x, 0), oftype(x, 6)) # clamp promotes, but clamp(x, 0, 6) would promote x::Int32

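As with hardtanh, the oftype bounds avoid promoting the input (REPL sketch):

julia> relu6(8f0), relu6(-1f0)
(6.0f0, 0.0f0)

julia> typeof(relu6(Int32(8)))
Int32
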
"""
rrelu(x, l=1/8, u=1/3) = max(a*x, x)
@@ -109,7 +109,7 @@ You can also specify the coefficient explicitly, e.g. `elu(x, 1)`.
"""
elu(x, α=1) = ifelse(x ≥ 0, float(x), α * (exp(x) - 1))

-deriv_elu(Ω, α=1) = ifelse(Ω ≥ 0, 1, Ω + α)
+deriv_elu(Ω, α=1) = ifelse(Ω ≥ 0, one(Ω), Ω + α)

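Writing the derivative in terms of the output works because, for x < 0, Ω = α(exp(x) - 1), so dΩ/dx = α·exp(x) = Ω + α. A spot check (REPL sketch):

julia> deriv_elu(elu(-1.2)) ≈ exp(-1.2)
true
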
"""
gelu(x) = 0.5x * (1 + tanh(√(2/π) * (x + 0.044715x^3)))
@@ -179,7 +179,7 @@ celu(x, α=1) = ifelse(x ≥ 0, float(x), α * (exp(x/α) - 1))
Threshold Gated Rectified Linear.
See [ThresholdRelu](https://arxiv.org/abs/1402.3337)
"""
-trelu(x, theta=1) = ifelse(x > theta, x, zero(x))
+trelu(x, theta=1) = ifelse(x <= theta, zero(x), x)

const thresholdrelu = trelu

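Flipping the comparison also changes how NaN behaves: NaN <= theta is false, so a NaN input now falls through to the x branch instead of being zeroed, in line with the NaN-propagation test added below (REPL sketch):

julia> trelu(NaN)
NaN

julia> trelu(0.5), trelu(2.0)
(0.0, 2.0)
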
@@ -260,14 +260,17 @@ for (f, df) in UNARY_ACTS
    @eval function rrule(::typeof(broadcasted),
                         ::typeof($f), x::Numeric)
        Ω = $f.(x)
-       function $pullback(Δ)
-           NoTangent(), NoTangent(), @.(Δ * $df)
+       function $pullback(Δ)
+           x_thunk = InplaceableThunk(
+               dx -> @.(dx += Δ * $df),
+               @thunk @.(Δ * $df)
+           )
+           NoTangent(), NoTangent(), x_thunk
        end
        return Ω, $pullback
    end
end

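The InplaceableThunk pairs an in-place accumulator with an allocating fallback, so the AD engine can add the gradient into an existing buffer instead of materializing a fresh array. A minimal sketch of how a consumer would use one (hypothetical values; assumes ChainRulesCore's documented add!!/unthunk behavior):

julia> using ChainRulesCore

julia> t = InplaceableThunk(dx -> dx .+= [1.0, 2.0], @thunk [1.0, 2.0]);

julia> ChainRulesCore.add!!([10.0, 10.0], t)   # takes the in-place path
2-element Vector{Float64}:
 11.0
 12.0

julia> unthunk(t)   # fallback materializes the value
2-element Vector{Float64}:
 1.0
 2.0
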

BINARY_ACTS = [ # f, df1, df2
(:elu, :(deriv_elu(Ω, x2)), :(NoTangent())), # TODO use real deriv instead of DNE
]
15 changes: 15 additions & 0 deletions test/activations.jl
@@ -113,6 +113,21 @@ end
end
end

@testset "NaN propagation" begin
@testset "$a" for a in ACTIVATION_FUNCTIONS
# With NaN input, all should produce NaN output:
@test isnan(a(NaN32))

# Ideally +-Inf would not lead to NaN, but perhaps
# these aren't worth the complication of fixing:
a == softsign && continue
@test !isnan(a(Inf32))

a in [gelu, swish, logcosh, mish] && continue
@test !isnan(a(-Inf32))
end
end

@testset "Test Integer64 and Integer32 inputs will force Float64 outputs" begin
test_value_int_input_forces_float64.(filter(x -> (x != relu && x != relu6 && x != hardtanh && x != trelu), ACTIVATION_FUNCTIONS))
