Commit

fix tests

CarloLucibello committed Dec 10, 2020
1 parent 25b5f6e commit a0661ff
Showing 5 changed files with 183 additions and 165 deletions.
3 changes: 2 additions & 1 deletion Project.toml
@@ -16,8 +16,9 @@ Requires = "0.5, 1.0"
julia = "1.3"

[extras]
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "Zygote"]
test = ["Test", "Statistics", "Zygote"]
9 changes: 3 additions & 6 deletions src/softmax.jl
@@ -116,19 +116,16 @@ end
∇logsoftmax(Δ, xs; dims=1) = Δ .- sum(Δ, dims=dims) .* softmax(xs, dims=dims)
∇logsoftmax!(Δ, xs) = ∇logsoftmax!(Δ, Δ, xs)



"""
logsumexp(x; dims=:)
Computes `log(sum(exp.(x); dims=dims)) in a numerically stable
Computes `log.(sum(exp.(x); dims=dims))` in a numerically stable
way.
See also [`logsoftmax`](@ref).
"""
function logsumexp(xs::AbstractArray; dims=:)
max_ = maximum(xs, dims=dims)
exp_ = exp.(xs .- max_)
log_ = log.(sum(exp_, dims=dims))
max_ .+ log_
log_ = log.(sum(exp.(xs .- max_), dims=dims))
return max_ .+ log_
end
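
As a quick illustration of what the max-shift in logsumexp buys over the naive formula (the input values are illustrative only, and the function is qualified with the module name in case it is not exported):

    using NNlib

    x = [1000.0, 1000.0]
    log(sum(exp.(x)))     # Inf: exp(1000.0) overflows Float64
    NNlib.logsumexp(x)    # ≈ 1000.6931, i.e. 1000 + log(2), computed without overflow
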
295 changes: 148 additions & 147 deletions test/activation.jl
@@ -1,6 +1,9 @@
using NNlib, Test, Zygote

-ACTIVATION_FUNCTIONS = [σ, hardσ, logσ, hardtanh, relu, leakyrelu, relu6, rrelu, elu, gelu, celu, swish, lisht, selu, trelu, softplus, softsign, logcosh, mish, tanhshrink, softshrink];
+ACTIVATION_FUNCTIONS = [σ, hardσ, logσ, hardtanh, relu, leakyrelu,
+                        relu6, rrelu, elu, gelu, celu, swish, lisht,
+                        selu, trelu, softplus, softsign, logcosh, mish,
+                        tanhshrink, softshrink];

function test_value_float_precision_preserving(a)
@testset "$(a): " begin
@@ -35,174 +38,172 @@ function test_gradient_float_precision_preserving(a)
end
end

@testset "Activation Functions" begin
@test σ(0.0) == 0.5
@test hardσ(0.0) == 0.5
@test hardtanh(0.0) == 0.0
@test relu(0.0) == 0.0
@test leakyrelu(0.0) == 0.0
@test relu6(0.0) == 0.0
@test rrelu(0.0) == 0.0
@test elu(0.0) == 0.0
@test gelu(0.0) == 0.0
@test swish(0.0) == 0.0
@test lisht(0.0) == 0.0
@test softplus(0.0) ≈ log(2.0)
@test softplus(1e8) ≈ 1e8
@test softplus(-1e8) ≈ 0.0
@test softsign(0.0) == 0.0
@test selu(0.0) == 0.0
@test celu(0.0) == 0.0
@test trelu(0.0) == 0.0
@test logcosh(0.0) == log(cosh(0.0))
@test mish(0.0) == 0.0
@test tanhshrink(0.0) == 0.0
@test softshrink(0.0) == 0.0

@test σ(1.0) == 1.0 / (1.0 + exp(-1.0))
@test hardσ(1.0) == max(0,min(1,0.2*1.0 + 0.5))
@test hardtanh(1.0) == 1.0
@test relu(1.0) == 1.0
@test leakyrelu(1.0) == 1.0
@test relu6(1.0) == 1.0
@test rrelu(1.0) == 1.0
@test elu(1.0) == 1.0
@test gelu(1.0) == 0.8411919906082768
@test swish(1.0) == σ(1.0)
@test lisht(1.0) ≈ 1.0 * tanh(1.0)
@test softplus(1.0) ≈ log(exp(1.0) + 1.0)
@test softsign(1.0) == 0.5
@test selu(1.0) == 1.0507009873554804934193349852946
@test celu(1.0) == 1.0
@test trelu(1.0) == 0.0
@test logcosh(1.0) ≈ log(cosh(1.0))
@test mish(1.0) ≈ tanh(log(1.0 + exp(1.0)))
@test tanhshrink(1.0) ≈ 0.23840584404423515
@test softshrink(1.0) == 0.5

@test σ(-1.0) == exp(-1.0) / (1.0 + exp(-1.0))
@test hardσ(-1.0) == max(0,min(1,0.2*-1.0 + 0.5))
@test hardtanh(-1.0) == -1.0
@test relu(-1.0) == 0.0
@test leakyrelu(-1.0) == -0.01
@test relu6(-1.0) == 0.0
@test -1/3.0 <= rrelu(-1.0) <= -1/8.0
@test elu(-1.0) == exp(-1.0) - 1.0
@test gelu(-1.0) == -0.15880800939172324
@test swish(-1.0) == -σ(-1.0)
@test lisht(-1.0) ≈ -1.0 * tanh(-1.0)
@test softplus(-1.0) ≈ log(exp(-1.0) + 1.0)
@test softsign(-1.0) == -0.5
@test selu(-1.0) == 1.0507009873554804934193349852946 * 1.6732632423543772848170429916717 * (exp(-1.0) - 1.0)
@test celu(-1.0) == exp(-1.0) - 1
@test trelu(-1.0) == 0.0
@test log(cosh(-1.0)) ≈ log(cosh(-1.0))
@test mish(-1.0) ≈ -tanh(log(1.0 + exp(-1.0)))
@test tanhshrink(-1.0) ≈ -0.23840584404423515
@test softshrink(-1.0) == -0.5

@testset "Float inference" begin
test_value_float_precision_preserving.(ACTIVATION_FUNCTIONS)
end

@testset "Array input" begin
x = rand(5)
for a in ACTIVATION_FUNCTIONS
@test_throws ErrorException a(x)
end
@testset "Array input" begin
x = rand(5)
for a in ACTIVATION_FUNCTIONS
@test_throws ErrorException a(x)
end
end

@testset "Test Integer64 and Integer32 inputs will force Float64 outputs" begin
test_value_int_input_forces_float64.(filter(x -> (x != relu && x != relu6 && x != hardtanh && x != trelu), ACTIVATION_FUNCTIONS))

@testset "relu: " begin
# relu doesn't have to force floating point outputs
@test typeof(relu(Int64(1))) == Int64
@test typeof(relu(Int32(1))) == Int32
end
@testset "Test Integer64 and Integer32 inputs will force Float64 outputs" begin
test_value_int_input_forces_float64.(filter(x -> (x != relu && x != relu6 && x != hardtanh && x != trelu), ACTIVATION_FUNCTIONS))

@testset "relu6: " begin
# relu6 doesn't have to force floating point outputs
@test typeof(relu6(Int64(1))) == Int64
@test typeof(relu6(Int32(1))) == Int32
end
@testset "relu: " begin
# relu doesn't have to force floating point outputs
@test typeof(relu(Int64(1))) == Int64
@test typeof(relu(Int32(1))) == Int32
end

@testset "hardtanh: " begin
# hardtanh doesn't have to force floating point outputs
@test typeof(hardtanh(Int64(1))) == Int64
@test typeof(hardtanh(Int32(1))) == Int32
end
@testset "relu6: " begin
# relu6 doesn't have to force floating point outputs
@test typeof(relu6(Int64(1))) == Int64
@test typeof(relu6(Int32(1))) == Int32
end

@testset "trelu: " begin
# trelu doesn't have to force floating point outputs
@test typeof(trelu(Int64(1))) == Int64
@test typeof(trelu(Int32(1))) == Int32
end
@testset "hardtanh: " begin
# hardtanh doesn't have to force floating point outputs
@test typeof(hardtanh(Int64(1))) == Int64
@test typeof(hardtanh(Int32(1))) == Int32
end

@testset "Float gradient inference" begin
test_gradient_float_precision_preserving.(ACTIVATION_FUNCTIONS)
@testset "trelu: " begin
# trelu doesn't have to force floating point outputs
@test typeof(trelu(Int64(1))) == Int64
@test typeof(trelu(Int32(1))) == Int32
end
end

@testset "elu" begin
@test elu(42) == 42
@test elu(42.) == 42.
@testset "Float gradient inference" begin
test_gradient_float_precision_preserving.(ACTIVATION_FUNCTIONS)
end

@testset "elu" begin
@test elu(42) == 42
@test elu(42.) == 42.

@testset "mish" begin
@test mish(-5) -0.033576237730161704
@test mish(9) == 9*tanh(log(1 + exp(9)))
xs = Float32[1 2 3; 1000 2000 3000]
@test typeof(mish.(xs)) == typeof(xs)
end
@test elu(-4) ≈ (exp(-4) - 1)
end

@testset "mish" begin
@test mish(-5) ≈ -0.033576237730161704
@test mish(9) == 9*tanh(log(1 + exp(9)))
xs = Float32[1 2 3; 1000 2000 3000]
@test typeof(mish.(xs)) == typeof(xs)
end

@test leakyrelu( 0.4,0.3) ≈ 0.4
@test leakyrelu(-0.4,0.3) ≈ -0.12

@testset "celu" begin
@test celu(42) == 42
@test celu(42.) == 42.
@test relu6(10.0) == 6.0
@test -0.2 <= rrelu(-0.4,0.25,0.5) <= -0.1

@testset "celu" begin
@test celu(42) == 42
@test celu(42.) == 42.

@testset "softshrink" begin
@test softshrink(15., 5.) == 10.
@test softshrink(4., 5.) == 0.
@test softshrink(-15., 5.) == -10.
end
@test celu(-4, 0.5) ≈ 0.5*(exp(-4.0/0.5) - 1)
end

@testset "logsigmoid" begin
xs = randn(10,10)
@test logsigmoid.(xs) log.(sigmoid.(xs))
for T in [:Float32, :Float64]
@eval @test logsigmoid.($T[-100_000, 100_000.]) $T[-100_000, 0.]
end
@testset "softshrink" begin
@test softshrink(15., 5.) == 10.
@test softshrink(4., 5.) == 0.
@test softshrink(-15., 5.) == -10.
end

@testset "logsigmoid" begin
xs = randn(10,10)
@test logsigmoid.(xs) ≈ log.(sigmoid.(xs))
for T in [:Float32, :Float64]
@eval @test logsigmoid.($T[-100_000, 100_000.]) ≈ $T[-100_000, 0.]
end
end

@test logcosh(1_000.0) + log(2) == 1_000.0

@testset "hardsigmoid" begin
@test hardsigmoid(0.3) == 0.56
@test hardsigmoid(-0.3) == 0.44
@test hardsigmoid(0.1,0.5) == 0.55
for T in [:Float32, :Float64]
@eval @test hardsigmoid.($T[-100_000, 100_000.]) $T[0., 1.]
end
@testset "hardsigmoid" begin
@test hardsigmoid(0.3) == 0.56
@test hardsigmoid(-0.3) == 0.44
@test hardsigmoid(0.1,0.5) == 0.55
for T in [:Float32, :Float64]
@eval @test hardsigmoid.($T[-100_000, 100_000.]) ≈ $T[0., 1.]
end
end

@test hardtanh(10.0) == 1.0
@test lisht(2.5) == 2.5*tanh(2.5)

@testset "trelu" begin
@test trelu(0.5) == 0.0
@test trelu(1.0) == 0.0
@test trelu(1.1) == 1.1
@test trelu(0.9,0.5) == 0.9
end
@testset "trelu" begin
@test trelu(0.5) == 0.0
@test trelu(1.0) == 0.0
@test trelu(1.1) == 1.1
@test trelu(0.9,0.5) == 0.9
end
25 changes: 19 additions & 6 deletions test/runtests.jl
@@ -1,7 +1,20 @@
-using NNlib, Test
+using NNlib, Test, Statistics

-include("activation.jl")
-include("conv.jl")
-include("batchedmul.jl")
-include("pooling.jl")
-include("inference.jl")
+@testset "Activation Functions" begin
+    include("activation.jl")
+end
+@testset "Batched Multiplication" begin
+    include("batchedmul.jl")
+end
+@testset "Convolution" begin
+    include("conv.jl")
+end
+@testset "Inference" begin
+    include("inference.jl")
+end
+@testset "Pooling" begin
+    include("pooling.jl")
+end
+@testset "Softmax" begin
+    include("softmax.jl")
+end
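
For anyone reproducing this locally, the reorganized suite is still driven by the standard Pkg test entry point; a minimal sketch, assuming the working directory is the NNlib checkout:

    using Pkg
    Pkg.activate(".")   # activate the NNlib project in the current directory
    Pkg.test()          # runs test/runtests.jl; each @testset above reports its own summary
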