Merge pull request #180 from avik-pal/ap/luxlib_parity
Add instancenorm and alpha_dropout implementations
avik-pal authored Oct 27, 2022
2 parents c769ccd + a759abd commit 2aeae0d
Showing 9 changed files with 303 additions and 11 deletions.
2 changes: 2 additions & 0 deletions docs/src/lib/LuxLib/api.md
@@ -5,6 +5,7 @@ CurrentModule = LuxLib
## Dropout

```@docs
alpha_dropout
dropout
```

@@ -13,6 +14,7 @@ dropout
```@docs
batchnorm
groupnorm
instancenorm
layernorm
```

2 changes: 1 addition & 1 deletion lib/LuxLib/Project.toml
@@ -1,7 +1,7 @@
name = "LuxLib"
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
version = "0.1.6"
version = "0.1.7"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
5 changes: 3 additions & 2 deletions lib/LuxLib/src/LuxLib.jl
@@ -16,9 +16,10 @@ include("impl/normalization.jl")
include("api/batchnorm.jl")
include("api/dropout.jl")
include("api/groupnorm.jl")
include("api/instancenorm.jl")
include("api/layernorm.jl")

export batchnorm, groupnorm, layernorm
export dropout
export batchnorm, groupnorm, instancenorm, layernorm
export alpha_dropout, dropout

end
2 changes: 1 addition & 1 deletion lib/LuxLib/src/api/batchnorm.jl
@@ -4,7 +4,7 @@
Batch Normalization. For details see [1].
Batch Normalization computes the mean and variance for each
``D_1 \times ... \times D_{N - 2} \times 1 \times D_N` input slice and normalises the input
``D_1 \times ... \times D_{N - 2} \times 1 \times D_N`` input slice and normalises the input
accordingly.
## Arguments
55 changes: 53 additions & 2 deletions lib/LuxLib/src/api/dropout.jl
@@ -1,6 +1,7 @@
@doc doc"""
dropout(rng::AbstractRNG, x, p, ::Val{training}; dims, invp=inv(p))
dropout(rng::AbstractRNG, x, mask, p, ::Val{training}, ::Val{update_mask}; dims, invp=inv(p))
dropout(rng::AbstractRNG, x, mask, p, ::Val{training}, ::Val{update_mask}; dims,
invp=inv(p))
Dropout: A Simple Way to Prevent Neural Networks from Overfitting. For details see [1].
@@ -29,7 +30,7 @@ Dropout: A Simple Way to Prevent Neural Networks from Overfitting. For details see
## References
[1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent neural networks from
overfitting." The journal of machine learning research 15.1 (2014): 1929-1958.
overfitting." The journal of machine learning research 15.1 (2014): 1929-1958.
"""
function dropout(rng::AbstractRNG, x::AbstractArray, p::T, ::Val{true}; dims,
invp::T=inv(p)) where {T}
@@ -62,6 +63,56 @@ function dropout(rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{
return (x, mask, rng)
end

@doc doc"""
alpha_dropout(rng::AbstractRNG, x, p, ::Val{training})
alpha_dropout(rng::AbstractRNG, x, p, ::Val{training}, α, A, B)
Alpha Dropout: Dropout ensuring that the mean and variance of the output remain the same as
those of the input. For details see [1]. Use the second call signature to avoid recomputing
the constants for a fixed dropout probability.
## Arguments
- `rng`: Random number generator
- `x`: Input Array
- `p`: Probability of an element to be dropped out
- `Val(training)`: If `true` then dropout is applied on `x` with probability `p`. Else,
`x` is returned
- `α`: `-1.7580993408473766`. Computed in the limit as `x` tends to `-∞`, `selu(x) = -λβ = α`
- `A`: Scaling factor for the mean
- `B`: Scaling factor for the variance
## Returns
- Output Array after applying alpha dropout
- Updated state for the random number generator
## References
[1] Klambauer, Günter, et al. "Self-normalizing neural networks." Advances in neural
information processing systems 30 (2017).
"""
function alpha_dropout(rng::AbstractRNG, x::AbstractArray{T}, p, t::Val{true}) where {T}
α = T(-1.7580993408473766)
A = T(inv(sqrt((1 - p) * (1 + p * α^2))))
B = T(-A * α * p)

return alpha_dropout(rng, x, p, t, α, A, B)
end

function alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, t::Val{false})
return alpha_dropout(rng, x, p, t, 0, 0, 0)
end

function alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, ::Val{true}, α, A, B)
rng = _replicate(rng)
noise = rand!(rng, similar(x))
return (A .* ifelse.(noise .> p, x, α) .+ B), rng
end

alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, ::Val{false}, α, A, B) = (x, rng)

# Mask Generation
@inline _dropout_shape(s, ::Colon) = size(s)
@inline function _dropout_shape(s, dims)
return tuple((i in dims ? si : 1 for (i, si) in enumerate(size(s)))...)
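A minimal usage sketch of the `alpha_dropout` methods added in this file, following the call signatures and return values shown above; the RNG seed, element type, array shape, and drop probability are illustrative choices, not part of the diff.

```julia
using LuxLib, Random

rng = MersenneTwister(0)
x = randn(rng, Float32, 4, 8)

# Training mode: each element is replaced by α with probability p, then the
# result is scaled by A and shifted by B so that mean and variance are
# (approximately) preserved. A new RNG state is returned alongside y.
y, rng_train = alpha_dropout(rng, x, 0.5f0, Val(true))

# Inference mode: the input is returned unchanged together with the RNG.
y_inf, rng_inf = alpha_dropout(rng, x, 0.5f0, Val(false))
@assert y_inf == x
```

For a fixed `p`, the constants `α`, `A`, and `B` can be computed once and passed explicitly via the longer `alpha_dropout(rng, x, p, Val(true), α, A, B)` signature, as the docstring above notes.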
53 changes: 53 additions & 0 deletions lib/LuxLib/src/api/instancenorm.jl
@@ -0,0 +1,53 @@
@doc doc"""
instancenorm(x, scale, bias; epsilon, training)
Instance Normalization. For details see [1].
Instance Normalization computes the mean and variance for each
``D_1 \times ... \times D_{N - 2} \times 1 \times 1`` input slice and normalises the input
accordingly.
## Arguments
- `x`: Input to be Normalized (must be at least 3D)
- `scale`: Scale factor (``\gamma``) (can be `nothing`)
- `bias`: Bias factor (``\beta``) (can be `nothing`)
## Keyword Arguments
- `epsilon`: Value added to the denominator for numerical stability
- `training`: Set to `Val(true)` if running in training mode
## Returns
Normalized array of the same size as `x`, along with a `NamedTuple` containing the updated
running mean and variance.
## References
[1] Ulyanov, Dmitry, Andrea Vedaldi, and Victor Lempitsky. "Instance normalization: The
missing ingredient for fast stylization." arXiv preprint arXiv:1607.08022 (2016).
"""
function instancenorm(x::AbstractArray{<:Real, N},
scale::Union{AbstractVector{<:Real}, Nothing},
bias::Union{AbstractVector{<:Real}, Nothing}; training::Val,
epsilon::Real) where {N}
_test_valid_instancenorm_arguments(x)

x_, xm, xv = _normalization(x, nothing, nothing, scale, bias,
_get_instancenorm_reduce_dims(x), training, zero(eltype(x)),
epsilon)

return x_, (; running_mean=xm, running_var=xv)
end

@generated function _get_instancenorm_reduce_dims(::AbstractArray{T, N}) where {T, N}
return :($(Val(Tuple([1:(N - 2)]...))))
end

function _test_valid_instancenorm_arguments(x::AbstractArray{T, N}) where {T, N}
N > 2 || throw(ArgumentError("`ndims(x) = $(N)` must be at least 3."))
return nothing
end

CRC.@non_differentiable _test_valid_instancenorm_arguments(::Any...)
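
A short usage sketch for the new `instancenorm`, based on the signature and return value in the file above; the input size, channel count, and epsilon are placeholder values.

```julia
using LuxLib, Random

rng = MersenneTwister(0)

# Input must be at least 3D; here: width × height × channels × batch.
x = randn(rng, Float32, 8, 8, 3, 2)
scale = ones(Float32, 3)    # γ, one entry per channel
bias  = zeros(Float32, 3)   # β, one entry per channel

y, stats = instancenorm(x, scale, bias; epsilon=1.0f-5, training=Val(true))

@assert size(y) == size(x)   # normalization preserves the shape
# `stats` is a NamedTuple with fields `running_mean` and `running_var`.
```

Per the docstring, `scale` and `bias` may also be `nothing`, in which case no affine transformation is applied after normalization.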
74 changes: 69 additions & 5 deletions lib/LuxLib/test/api/dropout.jl
@@ -12,7 +12,7 @@ rng = MersenneTwister(0)

println("DRP_CPU: $T $(x_shape)")

x = randn(T, x_shape)
x = randn(rng, T, x_shape)

@inferred dropout(rng, x, T(0.5), Val(true); dims=Colon())

@@ -44,7 +44,7 @@ rng = MersenneTwister(0)

println("DRP_GPU: $T $(x_shape)")

x = CUDA.randn(T, x_shape)
x = T.(cu(randn(rng, T, x_shape)))

@inferred dropout(rng, x, T(0.5), Val(true); dims=Colon())

Expand All @@ -71,14 +71,78 @@ rng = MersenneTwister(0)
end
end

@testset "Alpha Dropout" begin
if cpu_testing()
for T in (Float16, Float32, Float64),
x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1))

println("ADRP_CPU: $T $(x_shape)")

x = randn(rng, T, x_shape)

@inferred alpha_dropout(rng, x, T(0.5), Val(true))

y, rng_ = alpha_dropout(rng, x, T(0.5), Val(true))

@test y isa Array{T, length(x_shape)}
@test size(y) == x_shape
@test rng != rng_
# @test isapprox(std(y), std(x); atol=0.4, rtol=0.4)

__f = x -> sum(first(alpha_dropout(rng, x, T(0.5), Val(true))))
test_gradient_correctness_fdm(__f, x; atol=1.0f-2, rtol=1.0f-2)

@inferred alpha_dropout(rng, x, T(0.5), Val(false))

y, rng_ = alpha_dropout(rng, x, T(0.5), Val(false))

@test y isa Array{T, length(x_shape)}
@test size(y) == x_shape
@test rng == rng_
@test y == x
end
end

if gpu_testing()
for T in (Float16, Float32, Float64),
x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1))

println("ADRP_GPU: $T $(x_shape)")

x = T.(cu(randn(rng, T, x_shape)))

@inferred alpha_dropout(rng, x, T(0.5), Val(true))

y, rng_ = alpha_dropout(rng, x, T(0.5), Val(true))

@test y isa CuArray{T, length(x_shape)}
@test size(y) == x_shape
@test rng != rng_
# @test isapprox(std(y), std(x); atol=0.4, rtol=0.4)

# __f = x -> sum(first(alpha_dropout(rng, x, T(0.5), Val(true))))
# test_gradient_correctness_fdm(__f, x; atol=1.0f-2, rtol=1.0f-2)

@inferred alpha_dropout(rng, x, T(0.5), Val(false))

y, rng_ = alpha_dropout(rng, x, T(0.5), Val(false))

@test y isa CuArray{T, length(x_shape)}
@test size(y) == x_shape
@test rng == rng_
@test y == x
end
end
end

@testset "Dropout with Preset Mask" begin
if cpu_testing()
for T in (Float16, Float32, Float64),
x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1))

println("DRP_CPU: $T $(x_shape)")

x = randn(T, x_shape)
x = randn(rng, T, x_shape)
mask = rand(T, x_shape)

# Update mask
@@ -154,8 +218,8 @@ end

println("DRP_GPU: $T $(x_shape)")

x = CUDA.randn(T, x_shape)
mask = CUDA.rand(T, x_shape)
x = T.(cu(randn(rng, T, x_shape)))
mask = T.(cu(rand(rng, T, x_shape)))

# Update mask
@inferred dropout(rng, x, mask, T(0.5), Val(true), Val(true); dims=Colon())

2 comments on commit 2aeae0d

@avik-pal
Member Author


@JuliaRegistrator register subdir=lib/LuxLib

@JuliaRegistrator


Registration pull request created: JuliaRegistries/General/71146

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a LuxLib-v0.1.7 -m "<description of version>" 2aeae0d4af08aae8ad3e82fe806bc8af6fa278bf
git push origin LuxLib-v0.1.7
