Fix type-stability for normalization layers #1856

Merged 9 commits on Feb 3, 2022
2 changes: 1 addition & 1 deletion src/layers/conv.jl
@@ -664,7 +664,7 @@ end

function MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N
stride = expand(Val(N), stride)
pad = calc_padding(MaxPool ,pad, k, 1, stride)
pad = calc_padding(MaxPool, pad, k, 1, stride)
return MaxPool(k, pad, stride)
end

101 changes: 57 additions & 44 deletions src/layers/normalise.jl
@@ -146,13 +146,13 @@ testmode!(m::AlphaDropout, mode=true) =
LayerNorm(sz, λ=identity; affine=true, ϵ=1f-5)
A [normalisation layer](https://arxiv.org/abs/1607.06450) designed to be
used with recurrent hidden states.
The argument `sz` should be an integer or a tuple of integers.
In the forward pass, the layer normalises the mean and standard
deviation of the input, then applies the elementwise activation `λ`.
The input is normalised along the first `length(sz)` dimensions
for tuple `sz`, along the first dimension for integer `sz`.
The input is expected to have first dimensions' size equal to `sz`.
If `affine=true` also applies a learnable shift and rescaling
as in the [`Diagonal`](@ref) layer.
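For orientation, a minimal usage sketch of the layer described above (illustrative only, not taken from this diff; it assumes only the exported `LayerNorm`):

    using Flux
    ln = LayerNorm(5)            # normalise along the first dimension of size 5
    x = randn(Float32, 5, 16)    # 5 features × 16 samples
    y = ln(x)                    # per-sample mean ≈ 0 and unit variance, then identity λ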
@@ -192,35 +192,48 @@ end
# Compute the statistics on the slices specified by reduce_dims.
# reduce_dims=[1,...,N-2,N] for BatchNorm
# reduce_dims=[1,...,N-2] for InstanceNorm and GroupNorm
function _norm_layer_forward(l, x::AbstractArray{T,N}; reduce_dims, affine_shape) where {T, N}
function _norm_layer_forward(
l, x::AbstractArray{T, N}; reduce_dims, affine_shape,
) where {T, N}
if !_isactive(l) && l.track_stats # testmode with tracked stats
stats_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
μ = reshape(l.μ, stats_shape)
σ² = reshape(l.σ², stats_shape)
else # trainmode or testmode without tracked stats
μ = mean(x; dims=reduce_dims)
σ² = mean((x .- μ).^2; dims=reduce_dims)
if l.track_stats
## update moving mean/std
Zygote.ignore() do
mtm = l.momentum
m = prod(size(x, i) for i in reduce_dims) # needed for computing corrected var
μnew = vec(N ∈ reduce_dims ? μ : mean(μ, dims=N))
σ²new = vec(N ∈ reduce_dims ? σ² : mean(σ², dims=N))
l.μ = (1-mtm) .* l.μ .+ mtm .* μnew
l.σ² = (1-mtm) .* l.σ² .+ mtm .* (m / (m - one(eltype(l.σ²)))) .* σ²new
end
_track_stats!(l, x, μ, σ², reduce_dims) # update moving mean/std
end
end
if hasaffine(l)
γ = reshape(l.γ, affine_shape)
β = reshape(l.β, affine_shape)
return l.λ.(γ .* (x .- μ) ./ sqrt.(σ² .+ l.ϵ) .+ β)
else
return l.λ.((x .- μ) ./ sqrt.(σ² .+ l.ϵ))
end

o = _norm_layer_forward(x, μ, σ², l.ϵ)
hasaffine(l) || return l.λ.(o)

γ = reshape(l.γ, affine_shape)
β = reshape(l.β, affine_shape)
return l.λ.(γ .* o .+ β)
end

@inline _norm_layer_forward(x, μ, σ², ϵ) = (x .- μ) ./ sqrt.(σ² .+ ϵ)

function _track_stats!(
bn, x::AbstractArray{T, N}, μ, σ², reduce_dims,
) where {T, N}
V = eltype(bn.σ²)
mtm = bn.momentum
res_mtm = one(V) - mtm
m = prod(size(x, i) for i in reduce_dims)

μnew = vec(N ∈ reduce_dims ? μ : mean(μ, dims=N))
σ²new = vec(N ∈ reduce_dims ? σ² : mean(σ², dims=N))

bn.μ = res_mtm .* bn.μ .+ mtm .* μnew
bn.σ² = res_mtm .* bn.σ² .+ mtm .* (m / (m - one(V))) .* σ²new
return nothing
end
Zygote.@nograd _track_stats!
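Splitting the running-statistics update into `_track_stats!` and marking it `@nograd` appears aimed at keeping the mutation of `l.μ`/`l.σ²` out of the differentiated path, so the forward pass has a single inferrable return type. A rough way to check this property outside the test suite (a sketch, not part of the diff):

    using Flux, Test
    bn = BatchNorm(2)
    x = randn(Float32, 2, 8)
    @inferred bn(x)          # errors if the output type cannot be inferred
    # @code_warntype bn(x)   # interactive alternative: look for Any/Union annotations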

"""
BatchNorm(channels::Integer, λ=identity;
initβ=zeros32, initγ=ones32,
@@ -234,15 +247,15 @@ Given an array with `N` dimensions, call the `N-1`th the channel dimension. For
a batch of feature vectors this is just the data dimension, for `WHCN` images
it's the usual channel dimension.
`BatchNorm` computes the mean and variance for each `D_1×...×D_{N-2}×1×D_N`
input slice and normalises the input accordingly.
If `affine=true`, it also applies a shift and a rescale to the input
through learnable per-channel bias `β` and scale `γ` parameters.
After normalisation, elementwise activation `λ` is applied.
If `track_stats=true`, accumulates mean and var statistics in training phase
that will be used to renormalize the input in test phase.
Use [`testmode!`](@ref) during inference.
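A brief usage sketch (illustrative, not part of the diff; assumes a `WHCN` image batch):

    m = BatchNorm(3)                   # one μ/σ² pair per channel
    x = rand(Float32, 28, 28, 3, 16)   # WHCN batch of 16 images
    y = m(x)                           # normalised over width, height and batch
    testmode!(m)                       # use the accumulated statistics at inference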
@@ -272,7 +285,7 @@ mutable struct BatchNorm{F,V,N,W}
end

function BatchNorm(chs::Int, λ=identity;
initβ=zeros32, initγ=ones32,
affine=true, track_stats=true,
ϵ=1f-5, momentum=0.1f0)

Expand All @@ -282,8 +295,8 @@ function BatchNorm(chs::Int, λ=identity;
σ² = track_stats ? ones32(chs) : nothing

return BatchNorm(λ, β, γ,
μ, σ², ϵ, momentum,
affine, track_stats,
nothing, chs)
end

@@ -318,19 +331,19 @@ end
[Instance Normalization](https://arxiv.org/abs/1607.08022) layer.
`channels` should be the size of the channel dimension in your data (see below).
Given an array with `N > 2` dimensions, call the `N-1`th the channel dimension.
For `WHCN` images it's the usual channel dimension.
`InstanceNorm` computes the mean and variance for each `D_1×...×D_{N-2}×1×1`
input slice and normalises the input accordingly.
If `affine=true`, it also applies a shift and a rescale to the input
through learnable per-channel bias `β` and scale `γ` parameters.
If `track_stats=true`, accumulates mean and var statistics in training phase
that will be used to renormalize the input in test phase.
**Warning**: the defaults for `affine` and `track_stats` used to be `true`
in previous Flux versions (< v0.12).
"""
mutable struct InstanceNorm{F,V,N,W}
@@ -358,7 +371,7 @@ function InstanceNorm(chs::Int, λ=identity;
σ² = track_stats ? ones32(chs) : nothing

return InstanceNorm(λ, β, γ,
μ, σ², ϵ, momentum,
affine, track_stats,
nothing, chs)
end
@@ -401,13 +414,13 @@ The number of channels must be an integer multiple of the number of groups.
`channels` should be the size of the channel dimension in your data (see below).
Given an array with `N > 2` dimensions, call the `N-1`th the channel dimension.
For `WHCN` images it's the usual channel dimension.
If `affine=true`, it also applies a shift and a rescale to the input
through learnable per-channel bias `β` and scale `γ` parameters.
If `track_stats=true`, accumulates mean and var statistics in training phase
that will be used to renormalize the input in test phase.
"""
mutable struct GroupNorm{F,V,N,W}
@@ -429,7 +442,7 @@ end
trainable(gn::GroupNorm) = hasaffine(gn) ? (gn.β, gn.γ) : ()

function GroupNorm(chs::Int, G::Int, λ=identity;
initβ=zeros32, initγ=ones32,
affine=true, track_stats=false,
ϵ=1f-5, momentum=0.1f0)

@@ -440,11 +453,11 @@ function GroupNorm(chs::Int, G::Int, λ=identity;
μ = track_stats ? zeros32(G) : nothing
σ² = track_stats ? ones32(G) : nothing

return GroupNorm(G, λ,
β, γ,
μ, σ²,
ϵ, momentum,
affine, track_stats,
nothing, chs)
end

@@ -475,7 +488,7 @@ end
"""
hasaffine(l)
Return `true` if a normalisation layer has trainable shift and
scale parameters, `false` otherwise.
See [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`GroupNorm`](@ref), and [`LayerNorm`](@ref).
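For example (a sketch, using the current defaults noted above):

    Flux.hasaffine(BatchNorm(3))       # true:  affine defaults to true
    Flux.hasaffine(InstanceNorm(3))    # false: affine defaults to false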
33 changes: 29 additions & 4 deletions test/layers/normalisation.jl
@@ -149,39 +149,50 @@ end
# 1.3
# 1.3
@test m.σ² ≈ .1 .* var(x, dims=2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]

x′ = m(x)
@test isapprox(x′[1], (1 .- 0.3) / sqrt(1.3), atol = 1.0e-5)

@inferred m(x)
end

let m = BatchNorm(2; track_stats=false), x = [1.0 3.0 5.0; 2.0 4.0 6.0]
@inferred m(x)
end

# with activation function
let m = BatchNorm(2, sigmoid), x = [1.0 3.0 5.0;
2.0 4.0 6.0]
y = m(x)
@test isapprox(y, sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ)), atol = 1.0e-7)
@inferred m(x)
end

let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:6), 3, 2, 1)
y = reshape(permutedims(x, [2, 1, 3]), 2, :)
y = permutedims(reshape(m(y), 2, 3, 1), [2, 1, 3])
@test m(x) == y
@inferred m(x)
end

let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:12), 2, 3, 2, 1)
y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :)
y = permutedims(reshape(m(y), 2, 2, 3, 1), [2, 3, 1, 4])
@test m(x) == y
@inferred m(x)
end

let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:24), 2, 2, 3, 2, 1)
y = reshape(permutedims(x, [4, 1, 2, 3, 5]), 2, :)
y = permutedims(reshape(m(y), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5])
@test m(x) == y
@inferred m(x)
end

let m = BatchNorm(32), x = randn(Float32, 416, 416, 32, 1);
m(x)
@test (@allocated m(x)) < 100_000_000
@inferred m(x)
end

@test length(Flux.params(BatchNorm(10))) == 2
@@ -232,6 +243,8 @@ end
@test length(m.μ) == 2
@test length(m.σ²) == 2
@test y ≈ (x .- reshape(m.μ, 1,2,1)) ./ sqrt.(reshape(m.σ², 1,2,1) .+ 1f-5) atol=1.0e-5

@inferred m(x)
end

# with activation function
@@ -242,35 +255,41 @@ end
affine_shape[[1,3]] .= 1

y = evalwgrad(m, x)
y = m(x) # inference time after a training step
μ = reshape(m.μ, affine_shape...)
σ² = reshape(m.σ², affine_shape...)
@test y ≈ sigmoid.((x .- μ) ./ sqrt.(σ² .+ m.ϵ)) atol=1.0e-7

@inferred m(x)
end

# with activation function
let m = InstanceNorm(2, sigmoid; affine=true, track_stats=false), sizes = (3, 2, 2),
x = reshape(collect(1:prod(sizes)), sizes)

@test Flux.hasaffine(m) == true
@test length(params(m)) == 2
x = Float64.(x)
y = m(x)
μ = mean(x, dims=1)
σ² = var(x, dims=1, corrected=false)
@test y ≈ sigmoid.((x .- μ) ./ sqrt.(σ² .+ m.ϵ)) atol=1.0e-7

@inferred m(x)
end

let m = InstanceNorm(2, sigmoid), sizes = (3, 2, 2),
x = reshape(collect(1:prod(sizes)), sizes)
@test Flux.hasaffine(m) == false
@test length(params(m)) == 0

x = Float64.(x)
y = m(x)
μ = mean(x, dims=1)
σ² = var(x, dims=1, corrected=false)
@test y ≈ sigmoid.((x .- μ) ./ sqrt.(σ² .+ m.ϵ)) atol=1.0e-7

@inferred m(x)
end


@@ -279,6 +298,8 @@ end
y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
y = reshape(m(y), sizes...)
@test m(x) == y

@inferred m(x)
end

# check that μ, σ², and the output are the correct size for higher rank tensors
@@ -288,6 +309,8 @@ end
@test size(m.μ) == (sizes[end - 1], )
@test size(m.σ²) == (sizes[end - 1], )
@test size(y) == sizes

@inferred m(x)
end

# show that instance norm is equal to batch norm when channel and batch dims are squashed
@@ -299,6 +322,8 @@ end
let m = InstanceNorm(32), x = randn(Float32, 416, 416, 32, 1);
m(x)
@test (@allocated m(x)) < 100_000_000

@inferred m(x)
end

@test length(Flux.params(InstanceNorm(10))) == 0