Merge pull request #1856 from pxl-th/master
Fix type-stability for normalization layers
ToucheSir authored Feb 3, 2022
2 parents 8d3b8d3 + d151080 commit 5244ade
Showing 3 changed files with 87 additions and 49 deletions.
2 changes: 1 addition & 1 deletion src/layers/conv.jl
@@ -664,7 +664,7 @@ end

function MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N
stride = expand(Val(N), stride)
-  pad = calc_padding(MaxPool ,pad, k, 1, stride)
+  pad = calc_padding(MaxPool, pad, k, 1, stride)
return MaxPool(k, pad, stride)
end
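As an aside, a minimal usage sketch of this constructor (editorial illustration, not part of the diff; it assumes a standard Flux setup):

using Flux

pool = MaxPool((2, 2))                    # stride defaults to the window size, pad to 0
pool2 = MaxPool((3, 3); pad=1, stride=2)  # the keywords forwarded to calc_padding above

x = rand(Float32, 28, 28, 3, 8)           # WHCN input
size(pool(x))                             # (14, 14, 3, 8)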

101 changes: 57 additions & 44 deletions src/layers/normalise.jl
@@ -146,13 +146,13 @@ testmode!(m::AlphaDropout, mode=true) =
LayerNorm(sz, λ=identity; affine=true, ϵ=1f-5)
A [normalisation layer](https://arxiv.org/abs/1607.06450) designed to be
used with recurrent hidden states.
The argument `sz` should be an integer or a tuple of integers.
In the forward pass, the layer normalises the mean and standard
deviation of the input, then applies the elementwise activation `λ`.
The input is normalised along the first `length(sz)` dimensions
for tuple `sz`, along the first dimension for integer `sz`.
The input is expected to have first dimensions' size equal to `sz`.
If `affine=true` also applies a learnable shift and rescaling
as in the [`Diagonal`](@ref) layer.
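An illustrative use of the layer documented above (a sketch, not part of the diff; shapes chosen arbitrarily):

using Flux

ln = LayerNorm(5)                      # normalise over the first dimension, of size 5
x = randn(Float32, 5, 16)              # feature × batch
y = ln(x)                              # each column has mean ≈ 0 and std ≈ 1, then λ = identity

ln2 = LayerNorm((5, 3); affine=false)  # tuple sz: normalise over the first two dimensions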
@@ -192,35 +192,48 @@ end
# Compute the statistics on the slices specified by reduce_dims.
# reduce_dims=[1,...,N-2,N] for BatchNorm
# reduce_dims=[1,...,N-2] for InstanceNorm and GroupNorm
-function _norm_layer_forward(l, x::AbstractArray{T,N}; reduce_dims, affine_shape) where {T, N}
+function _norm_layer_forward(
+  l, x::AbstractArray{T, N}; reduce_dims, affine_shape,
+) where {T, N}
   if !_isactive(l) && l.track_stats # testmode with tracked stats
     stats_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
     μ = reshape(l.μ, stats_shape)
     σ² = reshape(l.σ², stats_shape)
   else # trainmode or testmode without tracked stats
     μ = mean(x; dims=reduce_dims)
     σ² = mean((x .- μ).^2; dims=reduce_dims)
     if l.track_stats
-      ## update moving mean/std
-      Zygote.ignore() do
-        mtm = l.momentum
-        m = prod(size(x, i) for i in reduce_dims) # needed for computing corrected var
-        μnew = vec(N ∈ reduce_dims ? μ : mean(μ, dims=N))
-        σ²new = vec(N ∈ reduce_dims ? σ² : mean(σ², dims=N))
-        l.μ = (1-mtm) .* l.μ .+ mtm .* μnew
-        l.σ² = (1-mtm) .* l.σ² .+ mtm .* (m / (m - one(eltype(l.σ²)))) .* σ²new
-      end
+      _track_stats!(l, x, μ, σ², reduce_dims) # update moving mean/std
     end
   end
-  if hasaffine(l)
-    γ = reshape(l.γ, affine_shape)
-    β = reshape(l.β, affine_shape)
-    return l.λ.(γ .* (x .- μ) ./ sqrt.(σ² .+ l.ϵ) .+ β)
-  else
-    return l.λ.((x .- μ) ./ sqrt.(σ² .+ l.ϵ))
-  end
+
+  o = _norm_layer_forward(x, μ, σ², l.ϵ)
+  hasaffine(l) || return l.λ.(o)
+
+  γ = reshape(l.γ, affine_shape)
+  β = reshape(l.β, affine_shape)
+  return l.λ.(γ .* o .+ β)
 end

@inline _norm_layer_forward(x, μ, σ², ϵ) = (x .- μ) ./ sqrt.(σ² .+ ϵ)

function _track_stats!(
  bn, x::AbstractArray{T, N}, μ, σ², reduce_dims,
) where {T, N}
  V = eltype(bn.σ²)
  mtm = bn.momentum
  res_mtm = one(V) - mtm
  m = prod(size(x, i) for i in reduce_dims)

  μnew = vec(N ∈ reduce_dims ? μ : mean(μ, dims=N))
  σ²new = vec(N ∈ reduce_dims ? σ² : mean(σ², dims=N))

  bn.μ = res_mtm .* bn.μ .+ mtm .* μnew
  bn.σ² = res_mtm .* bn.σ² .+ mtm .* (m / (m - one(V))) .* σ²new
  return nothing
end
Zygote.@nograd _track_stats!
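To see why the refactor above helps (an editorial sketch, not part of the diff): the idea, as far as the diff shows, is to keep the differentiated path free of the mutating bookkeeping. The moving-statistics update lives in a small function hidden from Zygote via `@nograd`, while the normalisation itself is a tiny `@inline` kernel. A stand-alone toy version of the same pattern, with hypothetical names (`ToyNorm`, `_toy_forward`, `_toy_track_stats!`):

using Statistics, Zygote

mutable struct ToyNorm
    μ::Vector{Float32}
    σ²::Vector{Float32}
    momentum::Float32
end

# Pure, inferable kernel: normalisation only.
@inline _toy_forward(x, μ, σ², ϵ) = (x .- μ) ./ sqrt.(σ² .+ ϵ)

# Mutating statistics update, excluded from differentiation.
function _toy_track_stats!(t::ToyNorm, μ, σ²)
    mtm = t.momentum
    t.μ  = (1 - mtm) .* t.μ  .+ mtm .* vec(μ)
    t.σ² = (1 - mtm) .* t.σ² .+ mtm .* vec(σ²)
    return nothing
end
Zygote.@nograd _toy_track_stats!

function (t::ToyNorm)(x::AbstractMatrix)
    μ  = mean(x; dims=2)
    σ² = mean((x .- μ) .^ 2; dims=2)
    _toy_track_stats!(t, μ, σ²)
    return _toy_forward(x, μ, σ², 1f-5)
end

t = ToyNorm(zeros(Float32, 3), ones(Float32, 3), 0.1f0)
x = randn(Float32, 3, 8)
g = Zygote.gradient(x -> sum(t(x)), x)[1]   # stats are still updated, but ignored by AD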

"""
BatchNorm(channels::Integer, λ=identity;
initβ=zeros32, initγ=ones32,
@@ -234,15 +247,15 @@ Given an array with `N` dimensions, call the `N-1`th the channel dimension. For
a batch of feature vectors this is just the data dimension, for `WHCN` images
it's the usual channel dimension.
`BatchNorm` computes the mean and variance for each `D_1×...×D_{N-2}×1×D_N`
input slice and normalises the input accordingly.
If `affine=true`, it also applies a shift and a rescale to the input
through learnable per-channel bias β and scale γ parameters.
After normalisation, elementwise activation `λ` is applied.
If `track_stats=true`, accumulates mean and var statistics in training phase
that will be used to renormalize the input in test phase.
Use [`testmode!`](@ref) during inference.
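A quick illustration of the train/test behaviour described above (editorial sketch, not part of the diff):

using Flux

m = BatchNorm(3)                     # 3 channels, identity activation, track_stats=true
x = rand(Float32, 28, 28, 3, 4)      # WHCN batch

trainmode!(m)
y = m(x)                             # normalised with batch statistics; m.μ, m.σ² are updated

testmode!(m)
ŷ = m(x)                             # normalised with the accumulated statistics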
@@ -272,7 +285,7 @@ mutable struct BatchNorm{F,V,N,W}
end

function BatchNorm(chs::Int, λ=identity;
initβ=zeros32, initγ=ones32,
affine=true, track_stats=true,
ϵ=1f-5, momentum=0.1f0)

@@ -282,8 +295,8 @@ function BatchNorm(chs::Int, λ=identity;
σ² = track_stats ? ones32(chs) : nothing

return BatchNorm(λ, β, γ,
μ, σ², ϵ, momentum,
affine, track_stats,
nothing, chs)
end
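Worth noting (grounded in the constructor above, as an editorial aside): with `track_stats=false` the running statistics are stored as `nothing`, so the fields are union-typed, which is presumably part of why the forward pass needs the careful split introduced in this commit to stay inferable. A quick check one might run:

using Flux

BatchNorm(2).μ                        # 2-element Vector{Float32} of zeros
BatchNorm(2; track_stats=false).μ     # nothing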

@@ -318,19 +331,19 @@ end
[Instance Normalization](https://arxiv.org/abs/1607.08022) layer.
`channels` should be the size of the channel dimension in your data (see below).
Given an array with `N > 2` dimensions, call the `N-1`th the channel dimension.
For `WHCN` images it's the usual channel dimension.
`InstanceNorm` computes the mean and variance for each `D_1×...×D_{N-2}×1×1`
input slice and normalises the input accordingly.
If `affine=true`, it also applies a shift and a rescale to the input
through learnable per-channel bias `β` and scale `γ` parameters.
If `track_stats=true`, accumulates mean and var statistics in training phase
that will be used to renormalize the input in test phase.
**Warning**: the defaults for `affine` and `track_stats` used to be `true`
in previous Flux versions (< v0.12).
"""
mutable struct InstanceNorm{F,V,N,W}
@@ -358,7 +371,7 @@ function InstanceNorm(chs::Int, λ=identity;
σ² = track_stats ? ones32(chs) : nothing

return InstanceNorm(λ, β, γ,
μ, σ², ϵ, momentum,
affine, track_stats,
nothing, chs)
end
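For reference, an illustrative call of the constructor above (sketch, not part of the diff):

using Flux

m = InstanceNorm(3, relu)            # affine=false, track_stats=false by default
x = rand(Float32, 32, 32, 3, 5)      # WHCN
y = m(x)                             # each (W, H) slice normalised per channel and per instance
size(y) == size(x)                   # true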
@@ -401,13 +414,13 @@ The number of channels must be an integer multiple of the number of groups.
`channels` should be the size of the channel dimension in your data (see below).
Given an array with `N > 2` dimensions, call the `N-1`th the channel dimension.
For `WHCN` images it's the usual channel dimension.
If `affine=true`, it also applies a shift and a rescale to the input
through learnable per-channel bias `β` and scale `γ` parameters.
If `track_stats=true`, accumulates mean and var statistics in training phase
that will be used to renormalize the input in test phase.
"""
mutable struct GroupNorm{F,V,N,W}
Expand All @@ -429,7 +442,7 @@ end
trainable(gn::GroupNorm) = hasaffine(gn) ? (gn.β, gn.γ) : ()

function GroupNorm(chs::Int, G::Int, λ=identity;
initβ=zeros32, initγ=ones32,
affine=true, track_stats=false,
ϵ=1f-5, momentum=0.1f0)

@@ -440,11 +453,11 @@ function GroupNorm(chs::Int, G::Int, λ=identity;
μ = track_stats ? zeros32(G) : nothing
σ² = track_stats ? ones32(G) : nothing

return GroupNorm(G, λ,
β, γ,
μ, σ²,
ϵ, momentum,
affine, track_stats,
nothing, chs)
end
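And an illustrative call for `GroupNorm` (sketch, not part of the diff; the channel count must divide evenly into `G` groups):

using Flux

gn = GroupNorm(6, 3)                 # 6 channels split into 3 groups of 2
x = rand(Float32, 8, 8, 6, 2)
y = gn(x)                            # statistics computed per group and per instance
size(y) == size(x)                   # true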

@@ -475,7 +488,7 @@ end
"""
hasaffine(l)
Return `true` if a normalisation layer has trainable shift and
scale parameters, `false` otherwise.
See [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`GroupNorm`](@ref), and [`LayerNorm`](@ref).
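For example (sketch based on the defaults documented above):

using Flux

Flux.hasaffine(BatchNorm(3))                  # true  — affine=true by default
Flux.hasaffine(InstanceNorm(3))               # false — affine=false by default
Flux.hasaffine(InstanceNorm(3; affine=true))  # true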
33 changes: 29 additions & 4 deletions test/layers/normalisation.jl
@@ -149,39 +149,50 @@ end
# 1.3
# 1.3
@test m.σ² ≈ .1 .* var(x, dims=2, corrected=false) .* (3 / 2) .+ .9 .* [1., 1.]

x′ = m(x)
@test isapprox(x′[1], (1 .- 0.3) / sqrt(1.3), atol = 1.0e-5)

@inferred m(x)
end

let m = BatchNorm(2; track_stats=false), x = [1.0 3.0 5.0; 2.0 4.0 6.0]
@inferred m(x)
end
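For readers unfamiliar with it, `Test.@inferred` (used throughout the new tests) evaluates the call and throws if the compiler cannot infer a concrete return type, which is how these tests pin down the type-stability fix. A minimal illustration outside Flux (sketch):

using Test

stable(flag)   = flag ? 1.0 : 2.0      # always Float64
unstable(flag) = flag ? 1 : 1.0        # Int or Float64 → Union return type

@inferred stable(true)                 # passes
# @inferred unstable(true)             # would throw: inferred Union{Float64, Int64}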

# with activation function
let m = BatchNorm(2, sigmoid), x = [1.0 3.0 5.0;
2.0 4.0 6.0]
y = m(x)
@test isapprox(y, sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ)), atol = 1.0e-7)
@inferred m(x)
end

let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:6), 3, 2, 1)
y = reshape(permutedims(x, [2, 1, 3]), 2, :)
y = permutedims(reshape(m(y), 2, 3, 1), [2, 1, 3])
@test m(x) == y
@inferred m(x)
end

let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:12), 2, 3, 2, 1)
y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :)
y = permutedims(reshape(m(y), 2, 2, 3, 1), [2, 3, 1, 4])
@test m(x) == y
@inferred m(x)
end

let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:24), 2, 2, 3, 2, 1)
y = reshape(permutedims(x, [4, 1, 2, 3, 5]), 2, :)
y = permutedims(reshape(m(y), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5])
@test m(x) == y
@inferred m(x)
end

let m = BatchNorm(32), x = randn(Float32, 416, 416, 32, 1);
m(x)
@test (@allocated m(x)) < 100_000_000
@inferred m(x)
end

@test length(Flux.params(BatchNorm(10))) == 2
@@ -232,6 +243,8 @@ end
@test length(m.μ) == 2
@test length(m.σ²) == 2
@test y ≈ (x .- reshape(m.μ, 1,2,1)) ./ sqrt.(reshape(m.σ², 1,2,1) .+ 1f-5) atol=1.0e-5

@inferred m(x)
end

# with activation function
@@ -242,35 +255,41 @@ end
affine_shape[[1,3]] .= 1

y = evalwgrad(m, x)
y = m(x) # inference time after a training step
μ = reshape(m.μ, affine_shape...)
σ² = reshape(m.σ², affine_shape...)
@test y ≈ sigmoid.((x .- μ) ./ sqrt.(σ² .+ m.ϵ)) atol=1.0e-7

@inferred m(x)
end

# with activation function
let m = InstanceNorm(2, sigmoid; affine=true, track_stats=false), sizes = (3, 2, 2),
x = reshape(collect(1:prod(sizes)), sizes)

@test Flux.hasaffine(m) == true
@test length(params(m)) == 2
x = Float64.(x)
y = m(x)
μ = mean(x, dims=1)
σ² = var(x, dims=1, corrected=false)
@test y ≈ sigmoid.((x .- μ) ./ sqrt.(σ² .+ m.ϵ)) atol=1.0e-7

@inferred m(x)
end

let m = InstanceNorm(2, sigmoid), sizes = (3, 2, 2),
x = reshape(collect(1:prod(sizes)), sizes)
@test Flux.hasaffine(m) == false
@test length(params(m)) == 0

x = Float64.(x)
y = m(x)
μ = mean(x, dims=1)
σ² = var(x, dims=1, corrected=false)
@test y ≈ sigmoid.((x .- μ) ./ sqrt.(σ² .+ m.ϵ)) atol=1.0e-7

@inferred m(x)
end


Expand All @@ -279,6 +298,8 @@ end
y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
y = reshape(m(y), sizes...)
@test m(x) == y

@inferred m(x)
end

# check that μ, σ², and the output are the correct size for higher rank tensors
@@ -288,6 +309,8 @@ end
@test size(m.μ) == (sizes[end - 1], )
@test size(m.σ²) == (sizes[end - 1], )
@test size(y) == sizes

@inferred m(x)
end

# show that instance norm is equal to batch norm when channel and batch dims are squashed
@@ -299,6 +322,8 @@ end
let m = InstanceNorm(32), x = randn(Float32, 416, 416, 32, 1);
m(x)
@test (@allocated m(x)) < 100_000_000

@inferred m(x)
end
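The "squashed dims" test referenced above relies on the identity that instance normalisation of a `(W, H, C, N)` array matches batch normalisation after folding the channel and batch dimensions together. A sketch of that check (editorial, shapes chosen for illustration):

using Flux

x = rand(Float32, 4, 4, 3, 2)                      # W×H×C×N
inorm = trainmode!(InstanceNorm(3))
bnorm = trainmode!(BatchNorm(3 * 2))               # one "channel" per (channel, instance) pair

y1 = inorm(x)
y2 = reshape(bnorm(reshape(x, 4, 4, 6, 1)), size(x))
isapprox(y1, y2; atol=1f-5)                        # true, up to floating-point error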

@test length(Flux.params(InstanceNorm(10))) == 0
