diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 6cb564924e..0942a1db3c 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -664,7 +664,7 @@ end function MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N stride = expand(Val(N), stride) - pad = calc_padding(MaxPool ,pad, k, 1, stride) + pad = calc_padding(MaxPool, pad, k, 1, stride) return MaxPool(k, pad, stride) end diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 164fc0d782..d7fabf7dc8 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -146,13 +146,13 @@ testmode!(m::AlphaDropout, mode=true) = LayerNorm(sz, λ=identity; affine=true, ϵ=1fe-5) A [normalisation layer](https://arxiv.org/abs/1607.06450) designed to be -used with recurrent hidden states. -The argument `sz` should be an integer or a tuple of integers. -In the forward pass, the layer normalises the mean and standard +used with recurrent hidden states. +The argument `sz` should be an integer or a tuple of integers. +In the forward pass, the layer normalises the mean and standard deviation of the input, the applied the elementwise activation `λ`. The input is normalised along the first `length(sz)` dimensions for tuple `sz`, along the first dimension for integer `sz`. -The input is expected to have first dimensions' size equal to `sz`. +The input is expected to have first dimensions' size equal to `sz`. If `affine=true` also applies a learnable shift and rescaling as in the [`Diagonal`](@ref) layer. @@ -192,35 +192,48 @@ end # Compute the statistics on the slices specified by reduce_dims. # reduce_dims=[1,...,N-2,N] for BatchNorm # reduce_dims=[1,...,N-2] for InstanceNorm and GroupNorm -function _norm_layer_forward(l, x::AbstractArray{T,N}; reduce_dims, affine_shape) where {T, N} +function _norm_layer_forward( + l, x::AbstractArray{T, N}; reduce_dims, affine_shape, +) where {T, N} if !_isactive(l) && l.track_stats # testmode with tracked stats stats_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N) μ = reshape(l.μ, stats_shape) σ² = reshape(l.σ², stats_shape) - else # trainmode or testmode without tracked stats + else # trainmode or testmode without tracked stats μ = mean(x; dims=reduce_dims) σ² = mean((x .- μ).^2; dims=reduce_dims) if l.track_stats - ## update moving mean/std - Zygote.ignore() do - mtm = l.momentum - m = prod(size(x, i) for i in reduce_dims) # needed for computing corrected var - μnew = vec(N ∈ reduce_dims ? μ : mean(μ, dims=N)) - σ²new = vec(N ∈ reduce_dims ? σ² : mean(σ², dims=N)) - l.μ = (1-mtm) .* l.μ .+ mtm .* μnew - l.σ² = (1-mtm) .* l.σ² .+ mtm .* (m / (m - one(eltype(l.σ²)))) .* σ²new - end + _track_stats!(l, x, μ, σ², reduce_dims) # update moving mean/std end end - if hasaffine(l) - γ = reshape(l.γ, affine_shape) - β = reshape(l.β, affine_shape) - return l.λ.(γ .* (x .- μ) ./ sqrt.(σ² .+ l.ϵ) .+ β) - else - return l.λ.((x .- μ) ./ sqrt.(σ² .+ l.ϵ)) - end + + o = _norm_layer_forward(x, μ, σ², l.ϵ) + hasaffine(l) || return l.λ.(o) + + γ = reshape(l.γ, affine_shape) + β = reshape(l.β, affine_shape) + return l.λ.(γ .* o .+ β) end +@inline _norm_layer_forward(x, μ, σ², ϵ) = (x .- μ) ./ sqrt.(σ² .+ ϵ) + +function _track_stats!( + bn, x::AbstractArray{T, N}, μ, σ², reduce_dims, +) where {T, N} + V = eltype(bn.σ²) + mtm = bn.momentum + res_mtm = one(V) - mtm + m = prod(size(x, i) for i in reduce_dims) + + μnew = vec(N ∈ reduce_dims ? μ : mean(μ, dims=N)) + σ²new = vec(N ∈ reduce_dims ? 
σ² : mean(σ², dims=N)) + + bn.μ = res_mtm .* bn.μ .+ mtm .* μnew + bn.σ² = res_mtm .* bn.σ² .+ mtm .* (m / (m - one(V))) .* σ²new + return nothing +end +Zygote.@nograd _track_stats! + """ BatchNorm(channels::Integer, λ=identity; initβ=zeros32, initγ=ones32, @@ -234,15 +247,15 @@ Given an array with `N` dimensions, call the `N-1`th the channel dimension. For a batch of feature vectors this is just the data dimension, for `WHCN` images it's the usual channel dimension. -`BatchNorm` computes the mean and variance for each `D_1×...×D_{N-2}×1×D_N` +`BatchNorm` computes the mean and variance for each `D_1×...×D_{N-2}×1×D_N` input slice and normalises the input accordingly. -If `affine=true`, it also applies a shift and a rescale to the input +If `affine=true`, it also applies a shift and a rescale to the input through to learnable per-channel bias β and scale γ parameters. -After normalisation, elementwise activation `λ` is applied. +After normalisation, elementwise activation `λ` is applied. -If `track_stats=true`, accumulates mean and var statistics in training phase +If `track_stats=true`, accumulates mean and var statistics in training phase that will be used to renormalize the input in test phase. Use [`testmode!`](@ref) during inference. @@ -272,7 +285,7 @@ mutable struct BatchNorm{F,V,N,W} end function BatchNorm(chs::Int, λ=identity; - initβ=zeros32, initγ=ones32, + initβ=zeros32, initγ=ones32, affine=true, track_stats=true, ϵ=1f-5, momentum=0.1f0) @@ -282,8 +295,8 @@ function BatchNorm(chs::Int, λ=identity; σ² = track_stats ? ones32(chs) : nothing return BatchNorm(λ, β, γ, - μ, σ², ϵ, momentum, - affine, track_stats, + μ, σ², ϵ, momentum, + affine, track_stats, nothing, chs) end @@ -318,19 +331,19 @@ end [Instance Normalization](https://arxiv.org/abs/1607.08022) layer. `channels` should be the size of the channel dimension in your data (see below). -Given an array with `N > 2` dimensions, call the `N-1`th the channel dimension. +Given an array with `N > 2` dimensions, call the `N-1`th the channel dimension. For `WHCN` images it's the usual channel dimension. -`InstanceNorm` computes the mean and variance for each `D_1×...×D_{N-2}×1×1` +`InstanceNorm` computes the mean and variance for each `D_1×...×D_{N-2}×1×1` input slice and normalises the input accordingly. -If `affine=true`, it also applies a shift and a rescale to the input +If `affine=true`, it also applies a shift and a rescale to the input through to learnable per-channel bias `β` and scale `γ` parameters. -If `track_stats=true`, accumulates mean and var statistics in training phase +If `track_stats=true`, accumulates mean and var statistics in training phase that will be used to renormalize the input in test phase. -**Warning**: the defaults for `affine` and `track_stats` used to be `true` +**Warning**: the defaults for `affine` and `track_stats` used to be `true` in previous Flux versions (< v0.12). """ mutable struct InstanceNorm{F,V,N,W} @@ -358,7 +371,7 @@ function InstanceNorm(chs::Int, λ=identity; σ² = track_stats ? ones32(chs) : nothing return InstanceNorm(λ, β, γ, - μ, σ², ϵ, momentum, + μ, σ², ϵ, momentum, affine, track_stats, nothing, chs) end @@ -401,13 +414,13 @@ The number of channels must be an integer multiple of the number of groups. `channels` should be the size of the channel dimension in your data (see below). -Given an array with `N > 2` dimensions, call the `N-1`th the channel dimension. +Given an array with `N > 2` dimensions, call the `N-1`th the channel dimension. 
For `WHCN` images it's the usual channel dimension. -If `affine=true`, it also applies a shift and a rescale to the input +If `affine=true`, it also applies a shift and a rescale to the input through to learnable per-channel bias `β` and scale `γ` parameters. -If `track_stats=true`, accumulates mean and var statistics in training phase +If `track_stats=true`, accumulates mean and var statistics in training phase that will be used to renormalize the input in test phase. """ mutable struct GroupNorm{F,V,N,W} @@ -429,7 +442,7 @@ end trainable(gn::GroupNorm) = hasaffine(gn) ? (gn.β, gn.γ) : () function GroupNorm(chs::Int, G::Int, λ=identity; - initβ=zeros32, initγ=ones32, + initβ=zeros32, initγ=ones32, affine=true, track_stats=false, ϵ=1f-5, momentum=0.1f0) @@ -440,11 +453,11 @@ function GroupNorm(chs::Int, G::Int, λ=identity; μ = track_stats ? zeros32(G) : nothing σ² = track_stats ? ones32(G) : nothing - return GroupNorm(G, λ, + return GroupNorm(G, λ, β, γ, - μ, σ², - ϵ, momentum, - affine, track_stats, + μ, σ², + ϵ, momentum, + affine, track_stats, nothing, chs) end @@ -475,7 +488,7 @@ end """ hasaffine(l) -Return `true` if a normalisation layer has trainable shift and +Return `true` if a normalisation layer has trainable shift and scale parameters, `false` otherwise. See [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`GroupNorm`](@ref), and [`LayerNorm`](@ref). diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl index 5e30bcb94b..9ab74e4a1d 100644 --- a/test/layers/normalisation.jl +++ b/test/layers/normalisation.jl @@ -149,9 +149,15 @@ end # 1.3 # 1.3 @test m.σ² ≈ .1 .* var(x, dims=2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.] - + x′ = m(x) @test isapprox(x′[1], (1 .- 0.3) / sqrt(1.3), atol = 1.0e-5) + + @inferred m(x) + end + + let m = BatchNorm(2; track_stats=false), x = [1.0 3.0 5.0; 2.0 4.0 6.0] + @inferred m(x) end # with activation function @@ -159,29 +165,34 @@ end 2.0 4.0 6.0] y = m(x) @test isapprox(y, sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ)), atol = 1.0e-7) + @inferred m(x) end let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:6), 3, 2, 1) y = reshape(permutedims(x, [2, 1, 3]), 2, :) y = permutedims(reshape(m(y), 2, 3, 1), [2, 1, 3]) @test m(x) == y + @inferred m(x) end let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:12), 2, 3, 2, 1) y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :) y = permutedims(reshape(m(y), 2, 2, 3, 1), [2, 3, 1, 4]) @test m(x) == y + @inferred m(x) end let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:24), 2, 2, 3, 2, 1) y = reshape(permutedims(x, [4, 1, 2, 3, 5]), 2, :) y = permutedims(reshape(m(y), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5]) @test m(x) == y + @inferred m(x) end let m = BatchNorm(32), x = randn(Float32, 416, 416, 32, 1); m(x) @test (@allocated m(x)) < 100_000_000 + @inferred m(x) end @test length(Flux.params(BatchNorm(10))) == 2 @@ -232,6 +243,8 @@ end @test length(m.μ) == 2 @test length(m.σ²) == 2 @test y ≈ (x .- reshape(m.μ, 1,2,1)) ./ sqrt.(reshape(m.σ², 1,2,1) .+ 1f-5) atol=1.0e-5 + + @inferred m(x) end # with activation function @@ -242,10 +255,12 @@ end affine_shape[[1,3]] .= 1 y = evalwgrad(m, x) - y = m(x) # inference time after a training step + y = m(x) # inference time after a training step μ = reshape(m.μ, affine_shape...) σ² = reshape(m.σ², affine_shape...) 
@test y ≈ sigmoid.((x .- μ) ./ sqrt.(σ² .+ m.ϵ)) atol=1.0e-7 + + @inferred m(x) end # with activation function @@ -253,24 +268,28 @@ end x = reshape(collect(1:prod(sizes)), sizes) @test Flux.hasaffine(m) == true - @test length(params(m)) == 2 + @test length(params(m)) == 2 x = Float64.(x) y = m(x) μ = mean(x, dims=1) σ² = var(x, dims=1, corrected=false) @test y ≈ sigmoid.((x .- μ) ./ sqrt.(σ² .+ m.ϵ)) atol=1.0e-7 + + @inferred m(x) end let m = InstanceNorm(2, sigmoid), sizes = (3, 2, 2), x = reshape(collect(1:prod(sizes)), sizes) @test Flux.hasaffine(m) == false @test length(params(m)) == 0 - + x = Float64.(x) y = m(x) μ = mean(x, dims=1) σ² = var(x, dims=1, corrected=false) @test y ≈ sigmoid.((x .- μ) ./ sqrt.(σ² .+ m.ϵ)) atol=1.0e-7 + + @inferred m(x) end @@ -279,6 +298,8 @@ end y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3) y = reshape(m(y), sizes...) @test m(x) == y + + @inferred m(x) end # check that μ, σ², and the output are the correct size for higher rank tensors @@ -288,6 +309,8 @@ end @test size(m.μ) == (sizes[end - 1], ) @test size(m.σ²) == (sizes[end - 1], ) @test size(y) == sizes + + @inferred m(x) end # show that instance norm is equal to batch norm when channel and batch dims are squashed @@ -299,6 +322,8 @@ end let m = InstanceNorm(32), x = randn(Float32, 416, 416, 32, 1); m(x) @test (@allocated m(x)) < 100_000_000 + + @inferred m(x) end @test length(Flux.params(InstanceNorm(10))) == 0
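
The main change above replaces the inline `Zygote.ignore() do ... end` block with a separate mutating `_track_stats!` function marked `Zygote.@nograd`, keeping the running-statistics update out of the differentiated path while `_norm_layer_forward` reduces to a pure `@inline` broadcast kernel. A minimal sketch of the same pattern follows — the `RunningStats` struct and the `update_stats!`/`normalise_sketch` names are hypothetical stand-ins for illustration, not Flux internals:

using Statistics, Zygote

mutable struct RunningStats
    μ::Vector{Float32}
    σ²::Vector{Float32}
    momentum::Float32
end

# Mutating moving-average update, analogous to `_track_stats!`; the m/(m-1)
# factor is the same Bessel correction the updated test checks via `.* (3 / 2)`
# for three samples.
function update_stats!(s::RunningStats, μbatch, σ²batch, m)
    mtm = s.momentum
    s.μ  .= (1 - mtm) .* s.μ  .+ mtm .* vec(μbatch)
    s.σ² .= (1 - mtm) .* s.σ² .+ mtm .* (m / (m - 1)) .* vec(σ²batch)
    return nothing
end
Zygote.@nograd update_stats!   # exclude the mutation from gradient tracking

function normalise_sketch(s::RunningStats, x; ϵ = 1f-5)
    μ  = mean(x; dims = 2)
    σ² = mean((x .- μ) .^ 2; dims = 2)
    update_stats!(s, μ, σ², size(x, 2))   # side effect only, not differentiated
    return (x .- μ) ./ sqrt.(σ² .+ ϵ)     # pure, differentiable part
end

s = RunningStats(zeros(Float32, 2), ones(Float32, 2), 0.1f0)
x = rand(Float32, 2, 8)
Zygote.gradient(x -> sum(normalise_sketch(s, x)), x)   # works: update_stats! is ignored

This mirrors the structure of the patch: the statistics update mutates layer state and must not be traced, while everything the gradient needs flows through the small pure kernel.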
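
The `reduce_dims` comment kept above (`[1,...,N-2,N]` for `BatchNorm`, `[1,...,N-2]` for `InstanceNorm` and `GroupNorm`) is what produces the `D_1×...×D_{N-2}×1×D_N` versus `D_1×...×D_{N-2}×1×1` slices described in the docstrings. A rough standalone illustration with plain `Statistics` on a made-up `WHCN` array (not part of the patch):

using Statistics

x = rand(Float32, 4, 4, 3, 2)          # W×H×C×N
N = ndims(x)

batch_dims    = (1:N-2..., N)          # (1, 2, 4): reduce everything but the channel dim
instance_dims = Tuple(1:N-2)           # (1, 2):    reduce per channel *and* per sample

μ_batch    = mean(x; dims = batch_dims)     # size (1, 1, 3, 1) — one value per channel
μ_instance = mean(x; dims = instance_dims)  # size (1, 1, 3, 2) — one value per channel and sample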
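
The test additions all take the form `@inferred m(x)`: `Test.@inferred` throws unless the compiler can infer a concrete return type for the call, so these lines assert that the refactored forward pass is type-stable on every branch (tracked stats, no tracked stats, train mode, higher-rank inputs). A standalone version of the same check, assuming a Flux version that includes this change:

using Test, Flux

x = rand(Float32, 2, 4)

m = BatchNorm(2)                       # test-mode path with tracked statistics
@inferred m(x)

m = BatchNorm(2; track_stats = false)  # path without tracked statistics
@inferred m(x)

m = Flux.trainmode!(BatchNorm(2))      # train-mode path using batch statistics
@inferred m(x)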