diff --git a/src/Flux.jl b/src/Flux.jl
index 32982131ab..a8bd4f0bf9 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -7,7 +7,7 @@ using MacroTools, Juno, Requires, Reexport, Statistics, Random
 using MacroTools: @forward

 export Chain, Dense, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool,
-       DepthwiseConv, Dropout, LayerNorm, BatchNorm,
+       DepthwiseConv, Dropout, LayerNorm, BatchNorm, InstanceNorm,
        params, mapleaves, cpu, gpu, f32, f64

 @reexport using NNlib
diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl
index e48d26fb1d..054ca08b78 100644
--- a/src/layers/normalise.jl
+++ b/src/layers/normalise.jl
@@ -155,3 +155,103 @@ function Base.show(io::IO, l::BatchNorm)
   (l.λ == identity) || print(io, ", λ = $(l.λ)")
   print(io, ")")
 end
+
+
+expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...)
+
+"""
+    InstanceNorm(channels::Integer, σ = identity;
+                 initβ = zeros, initγ = ones,
+                 ϵ = 1e-5, momentum = .1)
+
+Instance Normalization layer. The `channels` input should be the size of the
+channel dimension in your data (see below).
+
+Given an array with `N` dimensions, call the `N-1`th the channel dimension. (For
+a batch of feature vectors this is just the data dimension, for `WHCN` images
+it's the usual channel dimension.)
+
+`InstanceNorm` computes the mean and variance for each `W×H×1×1` slice and
+shifts them to have a new mean and variance (corresponding to the learnable,
+per-channel `bias` and `scale` parameters).
+
+See [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022).
+
+Example:
+```julia
+m = Chain(
+  Conv((3, 3), 1=>16, relu),
+  InstanceNorm(16),
+  MaxPool((2, 2)),
+  Conv((3, 3), 16=>32, relu),
+  InstanceNorm(32))
+```
+"""
+mutable struct InstanceNorm{F,V,W,N}
+  λ::F  # activation function
+  β::V  # bias
+  γ::V  # scale
+  μ::W  # moving mean
+  σ²::W  # moving variance
+  ϵ::N
+  momentum::N
+  active::Bool
+end
+
+InstanceNorm(chs::Integer, λ = identity;
+             initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
+             ϵ = 1f-5, momentum = 0.1f0) =
+  InstanceNorm(λ, param(initβ(chs)), param(initγ(chs)),
+               zeros(chs), ones(chs), ϵ, momentum, true)
+
+function (in::InstanceNorm)(x)
+  size(x, ndims(x)-1) == length(in.β) ||
+    error("InstanceNorm expected $(length(in.β)) channels, got $(size(x, ndims(x)-1))")
+  ndims(x) > 2 ||
+    error("InstanceNorm requires at least 3 dimensions. With 2 dimensions an array of zeros would be returned")
+  # these are repeated later on depending on the batch size
+  dims = length(size(x))
+  c = size(x, dims-1)
+  bs = size(x, dims)
+  affine_shape = ones(Int, dims)
+  affine_shape[end-1] = c
+  affine_shape[end] = bs
+  m = prod(size(x)[1:end-2])
+  γ, β = expand_inst(in.γ, affine_shape), expand_inst(in.β, affine_shape)
+
+  if !in.active
+    μ = expand_inst(in.μ, affine_shape)
+    σ² = expand_inst(in.σ², affine_shape)
+    ϵ = in.ϵ
+  else
+    T = eltype(x)
+
+    ϵ = data(convert(T, in.ϵ))
+    axes = 1:dims-2 # axes to reduce along (all but channels and batch size axes)
+    μ = mean(x, dims = axes)
+    σ² = mean((x .- μ) .^ 2, dims = axes)
+
+    # update moving mean/std
+    mtm = data(convert(T, in.momentum))
+    in.μ = dropdims(mean(repeat((1 - mtm) .* in.μ, outer=[1, bs]) .+ mtm .* reshape(data(μ), (c, bs)), dims = 2), dims=2)
+    in.σ² = dropdims(mean((repeat((1 - mtm) .* in.σ², outer=[1, bs]) .+ (mtm * m / (m - 1)) .* reshape(data(σ²), (c, bs))), dims = 2), dims=2)
+  end
+
+  let λ = in.λ
+    x̂ = (x .- μ) ./ sqrt.(σ² .+ ϵ)
+    λ.(γ .* x̂ .+ β)
+  end
+end
+
+children(in::InstanceNorm) =
+  (in.λ, in.β, in.γ, in.μ, in.σ², in.ϵ, in.momentum, in.active)
+
+mapchildren(f, in::InstanceNorm) =  # e.g. mapchildren(cu, in)
+  InstanceNorm(in.λ, f(in.β), f(in.γ), f(in.μ), f(in.σ²), in.ϵ, in.momentum, in.active)
+
+_testmode!(in::InstanceNorm, test) = (in.active = !test)
+
+function Base.show(io::IO, l::InstanceNorm)
+  print(io, "InstanceNorm($(join(size(l.β), ", "))")
+  (l.λ == identity) || print(io, ", λ = $(l.λ)")
+  print(io, ")")
+end
diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl
index 3ef9eb7ae8..d862944506 100644
--- a/test/layers/normalisation.jl
+++ b/test/layers/normalisation.jl
@@ -104,3 +104,99 @@ end
     @test (@allocated m(x)) < 100_000_000
   end
 end
+
+
+@testset "InstanceNorm" begin
+  # helper functions
+  expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...)
+  # begin tests
+  let m = InstanceNorm(2), sizes = (3, 2, 2),
+      x = param(reshape(collect(1:prod(sizes)), sizes))
+
+    @test m.β.data == [0, 0]  # initβ(2)
+    @test m.γ.data == [1, 1]  # initγ(2)
+
+    @test m.active
+
+    m(x)
+
+    #julia> x
+    #[:, :, 1] =
+    # 1.0  4.0
+    # 2.0  5.0
+    # 3.0  6.0
+    #
+    #[:, :, 2] =
+    # 7.0  10.0
+    # 8.0  11.0
+    # 9.0  12.0
+    #
+    # μ will be
+    # (1. + 2. + 3.) / 3 = 2.
+    # (4. + 5. + 6.) / 3 = 5.
+    #
+    # (7. + 8. + 9.) / 3 = 8.
+    # (10. + 11. + 12.) / 3 = 11.
+    #
+    # ∴ update rule with momentum:
+    # (1. - .1) * 0 + .1 * (2. + 8.) / 2 = .5
+    # (1. - .1) * 0 + .1 * (5. + 11.) / 2 = .8
+    @test m.μ ≈ [0.5, 0.8]
+    # momentum * var * num_items / (num_items - 1) + (1 - momentum) * sigma_sq
+    # julia> reshape(mean(.1 .* var(x.data, dims = 1, corrected=false) .* (3 / 2), dims=3), :) .+ .9 .* 1.
+    # 2-element Array{Float64,1}:
+    #  1.
+    #  1.
+    @test m.σ² ≈ reshape(mean(.1 .* var(x.data, dims = 1, corrected=false) .* (3 / 2), dims=3), :) .+ .9 .* 1.
+
+    testmode!(m)
+    @test !m.active
+
+    x′ = m(x).data
+    @test isapprox(x′[1], (1 - 0.5) / sqrt(1. + 1f-5), atol = 1.0e-5)
+  end
+
+  # with activation function
+  let m = InstanceNorm(2, sigmoid), sizes = (3, 2, 2),
+      x = param(reshape(collect(1:prod(sizes)), sizes))
+
+    affine_shape = collect(sizes)
+    affine_shape[1] = 1
+
+    @test m.active
+    m(x)
+
+    testmode!(m)
+    @test !m.active
+
+    y = m(x).data
+    @test isapprox(y, data(sigmoid.((x .- expand_inst(m.μ, affine_shape)) ./ sqrt.(expand_inst(m.σ², affine_shape) .+ m.ϵ))), atol = 1.0e-7)
+  end
+
+  let m = InstanceNorm(2), sizes = (2, 4, 1, 2, 3),
+      x = param(reshape(collect(1:prod(sizes)), sizes))
+    y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
+    y = reshape(m(y), sizes...)
+    @test m(x) == y
+  end
+
+  # check that μ, σ², and the output are the correct size for higher rank tensors
+  let m = InstanceNorm(2), sizes = (5, 5, 3, 4, 2, 6),
+      x = param(reshape(collect(1:prod(sizes)), sizes))
+    y = m(x)
+    @test size(m.μ) == (sizes[end - 1], )
+    @test size(m.σ²) == (sizes[end - 1], )
+    @test size(y) == sizes
+  end
+
+  # show that instance norm is equal to batch norm when channel and batch dims are squashed
+  let m_inorm = InstanceNorm(2), m_bnorm = BatchNorm(12), sizes = (5, 5, 3, 4, 2, 6),
+      x = param(reshape(collect(1:prod(sizes)), sizes))
+    @test m_inorm(x) == reshape(m_bnorm(reshape(x, (sizes[1:end - 2]..., :, 1))), sizes)
+  end
+
+  let m = InstanceNorm(32), x = randn(Float32, 416, 416, 32, 1);
+    m(x)
+    @test (@allocated m(x)) < 100_000_000
+  end
+
+end
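As a quick way to see the running-statistics update that the first test above checks, here is a minimal sketch at the REPL. It is not part of the diff and assumes this branch is installed, so that `InstanceNorm`, `param`, and `Statistics` are available:

```julia
using Flux, Statistics

# Toy input mirroring the first test case: 3 "spatial" elements, 2 channels, 2 samples.
x = reshape(collect(Float32, 1:12), 3, 2, 2)

m = InstanceNorm(2)
m(param(x))   # one training-mode pass updates m.μ and m.σ²

# Training mode reduces over every dimension except channels and batch,
# so each (channel, sample) slice gets its own mean: 2, 5, 8 and 11 here.
slice_means = dropdims(mean(x, dims = 1), dims = 1)   # 2×2 matrix, channels × samples

# The running mean blends the per-channel average of those slices with momentum 0.1:
@assert m.μ ≈ 0.1 .* vec(mean(slice_means, dims = 2))   # ≈ [0.5, 0.8], as in the first test
```

This per-slice reduction over all but the last two dimensions is exactly what `axes = 1:dims-2` and `m = prod(size(x)[1:end-2])` implement in the forward pass.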