Skip to content

Commit

Permalink
implement #643
Browse files Browse the repository at this point in the history
  • Loading branch information
MikeInnes committed Mar 8, 2019
1 parent 2f256b3 commit 5514a0f
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 50 deletions.
6 changes: 2 additions & 4 deletions src/Flux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,13 @@ module Flux
using Base: tail
using MacroTools, Juno, Requires, Reexport, Statistics, Random
using MacroTools: @forward
@reexport using NNlib
using Zygote: Params, @adjoint, gradient

export Chain, Dense, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool,
DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm,
params, mapleaves, cpu, gpu, f32, f64

@reexport using NNlib

using Zygote

include("optimise/Optimise.jl")
using .Optimise
using .Optimise: @epochs
Expand Down
66 changes: 21 additions & 45 deletions src/layers/normalise.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,6 @@
"""
testmode!(m)
testmode!(m, false)
Put layers like [`Dropout`](@ref) and [`BatchNorm`](@ref) into testing mode
(or back to training mode with `false`).
"""
function testmode!(m, val::Bool=true)
prefor(x -> _testmode!(x, val), m)
return m
end
istraining() = false

_testmode!(m, test) = nothing
@adjoint istraining() = true, _ -> nothing

"""
Dropout(p)
Expand All @@ -23,44 +13,38 @@ Does nothing to the input once in [`testmode!`](@ref).
"""
mutable struct Dropout{F}
p::F
active::Bool
end

function Dropout(p)
@assert 0 p 1
Dropout{typeof(p)}(p, true)
function Dropout(p)
@assert 0 p 1
new{typeof(p)}(p)
end
end

_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)

function (a::Dropout)(x)
a.active || return x
istraining() || return x
y = similar(x)
rand!(y)
y .= _dropout_kernel.(y, a.p, 1 - a.p)
return x .* y
end

_testmode!(a::Dropout, test) = (a.active = !test)

"""
AlphaDropout(p)
A dropout layer. It is used in Self-Normalizing Neural Networks.
A dropout layer. It is used in Self-Normalizing Neural Networks.
(https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf)
The AlphaDropout layer ensures that mean and variance of activations remains the same as before.
"""
mutable struct AlphaDropout{F}
p::F
active::Bool
end

function AlphaDropout(p)
@assert 0 p 1
AlphaDropout(p,true)
function AlphaDropout(p)
@assert 0 p 1
new{typeof(p)}(p)
end
end

function (a::AlphaDropout)(x)
a.active || return x
istraining() || return x
λ = eltype(x)(1.0507009873554804934193349852946)
α = eltype(x)(1.6732632423543772848170429916717)
α1 = eltype(x)(-λ*α)
Expand All @@ -72,8 +56,6 @@ function (a::AlphaDropout)(x)
return x
end

_testmode!(a::AlphaDropout, test) = (a.active = !test)

"""
LayerNorm(h::Integer)
Expand Down Expand Up @@ -133,13 +115,12 @@ mutable struct BatchNorm{F,V,W,N}
σ²::W # moving std
ϵ::N
momentum::N
active::Bool
end

BatchNorm(chs::Integer, λ = identity;
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
BatchNorm(λ, initβ(chs), initγ(chs),
zeros(chs), ones(chs), ϵ, momentum, true)
zeros(chs), ones(chs), ϵ, momentum)

function (BN::BatchNorm)(x)
size(x, ndims(x)-1) == length(BN.β) ||
Expand All @@ -151,7 +132,7 @@ function (BN::BatchNorm)(x)
m = prod(size(x)[1:end-2]) * size(x)[end]
γ = reshape(BN.γ, affine_shape...)
β = reshape(BN.β, affine_shape...)
if !BN.active
if !istraining()
μ = reshape(BN.μ, affine_shape...)
σ² = reshape(BN.σ², affine_shape...)
ϵ = BN.ϵ
Expand All @@ -174,12 +155,10 @@ function (BN::BatchNorm)(x)
end

children(BN::BatchNorm) =
(BN.λ, BN.β, BN.γ, BN.μ, BN.σ², BN.ϵ, BN.momentum, BN.active)
(BN.λ, BN.β, BN.γ, BN.μ, BN.σ², BN.ϵ, BN.momentum)

mapchildren(f, BN::BatchNorm) = # e.g. mapchildren(cu, BN)
BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ²), BN.ϵ, BN.momentum, BN.active)

_testmode!(BN::BatchNorm, test) = (BN.active = !test)
BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ²), BN.ϵ, BN.momentum)

function Base.show(io::IO, l::BatchNorm)
print(io, "BatchNorm($(join(size(l.β), ", "))")
Expand Down Expand Up @@ -226,13 +205,12 @@ mutable struct InstanceNorm{F,V,W,N}
σ²::W # moving std
ϵ::N
momentum::N
active::Bool
end

InstanceNorm(chs::Integer, λ = identity;
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
InstanceNorm(λ, initβ(chs), initγ(chs),
zeros(chs), ones(chs), ϵ, momentum, true)
zeros(chs), ones(chs), ϵ, momentum)

function (in::InstanceNorm)(x)
size(x, ndims(x)-1) == length(in.β) ||
Expand All @@ -249,7 +227,7 @@ function (in::InstanceNorm)(x)
m = prod(size(x)[1:end-2])
γ, β = expand_inst(in.γ, affine_shape), expand_inst(in.β, affine_shape)

if !in.active
if !istraining()
μ = expand_inst(in.μ, affine_shape)
σ² = expand_inst(in.σ², affine_shape)
ϵ = in.ϵ
Expand All @@ -274,12 +252,10 @@ function (in::InstanceNorm)(x)
end

children(in::InstanceNorm) =
(in.λ, in.β, in.γ, in.μ, in.σ², in.ϵ, in.momentum, in.active)
(in.λ, in.β, in.γ, in.μ, in.σ², in.ϵ, in.momentum)

mapchildren(f, in::InstanceNorm) = # e.g. mapchildren(cu, in)
InstanceNorm(in.λ, f(in.β), f(in.γ), f(in.μ), f(in.σ²), in.ϵ, in.momentum, in.active)

_testmode!(in::InstanceNorm, test) = (in.active = !test)
InstanceNorm(in.λ, f(in.β), f(in.γ), f(in.μ), f(in.σ²), in.ϵ, in.momentum)

function Base.show(io::IO, l::InstanceNorm)
print(io, "InstanceNorm($(join(size(l.β), ", "))")
Expand Down
2 changes: 1 addition & 1 deletion src/treelike.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import Adapt: adapt, adapt_storage
import .Zygote: IdSet
import Zygote: IdSet

children(x) = ()
mapchildren(f, x) = x
Expand Down

0 comments on commit 5514a0f

Please sign in to comment.