From 129a708b6f0c36b794729d99535ac56e5a63f4fb Mon Sep 17 00:00:00 2001
From: David Pollack
Date: Wed, 20 Feb 2019 14:01:05 +0100
Subject: [PATCH 1/4] instance normalization

---
 src/Flux.jl                  |  2 +-
 src/layers/normalise.jl      | 98 ++++++++++++++++++++++++++++++++++++
 test/layers/normalisation.jl | 81 +++++++++++++++++++++++++++++
 3 files changed, 180 insertions(+), 1 deletion(-)

diff --git a/src/Flux.jl b/src/Flux.jl
index 32982131ab..a8bd4f0bf9 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -7,7 +7,7 @@ using MacroTools, Juno, Requires, Reexport, Statistics, Random
 using MacroTools: @forward

 export Chain, Dense, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool,
-       DepthwiseConv, Dropout, LayerNorm, BatchNorm,
+       DepthwiseConv, Dropout, LayerNorm, BatchNorm, InstanceNorm,
        params, mapleaves, cpu, gpu, f32, f64

 @reexport using NNlib
diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl
index e48d26fb1d..eaa994b22d 100644
--- a/src/layers/normalise.jl
+++ b/src/layers/normalise.jl
@@ -155,3 +155,101 @@ function Base.show(io::IO, l::BatchNorm)
   (l.λ == identity) || print(io, ", λ = $(l.λ)")
   print(io, ")")
 end
+
+
+"""
+    InstanceNorm(channels::Integer, σ = identity;
+                 initβ = zeros, initγ = ones,
+                 ϵ = 1f-5, momentum = 0.1f0)
+
+Instance Normalization layer. The `channels` input should be the size of the
+channel dimension in your data (see below).
+
+Given an array with `N` dimensions, call the `N-1`th the channel dimension. (For
+a batch of feature vectors this is just the data dimension, for `WHCN` images
+it's the usual channel dimension.)
+
+`InstanceNorm` computes the mean and variance for each `W×H×1×1` slice and
+shifts them to have a new mean and variance (corresponding to the learnable,
+per-channel `bias` and `scale` parameters).
+
+See [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022).
+
+Example:
+```julia
+m = Chain(
+  Dense(28^2, 64),
+  InstanceNorm(64, relu),
+  Dense(64, 10),
+  InstanceNorm(10),
+  softmax)
+```
+"""
+mutable struct InstanceNorm{F,V,W,N}
+  λ::F # activation function
+  β::V # bias
+  γ::V # scale
+  μ::W # moving mean
+  σ²::W # moving variance
+  ϵ::N
+  momentum::N
+  active::Bool
+end
+
+InstanceNorm(chs::Integer, λ = identity;
+             initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
+  InstanceNorm(λ, param(initβ(chs)), param(initγ(chs)),
+               zeros(chs), ones(chs), ϵ, momentum, true)
+
+function (IN::InstanceNorm)(x)
+  size(x, ndims(x)-1) == length(IN.β) ||
+    error("InstanceNorm expected $(length(IN.β)) channels, got $(size(x, ndims(x)-1))")
+  ndims(x) > 2 ||
+    error("InstanceNorm requires at least 3 dimensions. With 2 dimensions an array of zeros would be returned")
+  # these are repeated later on depending on the batch size
+  γ, β = IN.γ, IN.β
+  dims = length(size(x))
+  c = size(x, dims-1)
+  bs = size(x, dims)
+  affine_shape = ones(Int, dims)
+  affine_shape[end-1] = c
+  affine_shape[end] = bs
+  m = prod(size(x)[1:end-2])
+
+  if !IN.active
+    μ = reshape(repeat(IN.μ, outer=[bs]), affine_shape...)
+    σ² = reshape(repeat(IN.σ², outer=[bs]), affine_shape...)
+  else
+    T = eltype(x)
+
+    ϵ = data(convert(T, IN.ϵ))
+    axes = 1:dims-2 # axes to reduce along (all but channels and batch size axes)
+    μ = mean(x, dims = axes)
+    σ² = mean((x .- μ) .^ 2, dims = axes)
+
+    # update moving mean/std
+    mtm = data(convert(T, IN.momentum))
+    IN.μ = reshape(mean((1 - mtm) .* repeat(IN.μ, outer=[1, bs]) .+ mtm .* reshape(data(μ), (c, bs)), dims = 2), :)
+    IN.σ² = reshape(mean(((1 - mtm) .* repeat(IN.σ², outer=[1, bs]) .+ mtm .* reshape(data(σ²), (c, bs)) .* (m / (m - 1))), dims = 2), :)
+  end
+
+  let λ = IN.λ
+    temp = reshape(repeat(γ, outer=[bs]), affine_shape...) .* ((x .- μ) ./ sqrt.(σ² .+ IN.ϵ))
+    # This is intentionally not fused because of an extreme slowdown doing so
+    λ.(temp .+ reshape(repeat(β, outer=[bs]), affine_shape...))
+  end
+end
+
+children(IN::InstanceNorm) =
+  (IN.λ, IN.β, IN.γ, IN.μ, IN.σ², IN.ϵ, IN.momentum, IN.active)
+
+mapchildren(f, IN::InstanceNorm) = # e.g. mapchildren(cu, IN)
+  InstanceNorm(IN.λ, f(IN.β), f(IN.γ), f(IN.μ), f(IN.σ²), IN.ϵ, IN.momentum, IN.active)
+
+_testmode!(IN::InstanceNorm, test) = (IN.active = !test)
+
+function Base.show(io::IO, l::InstanceNorm)
+  print(io, "InstanceNorm($(join(size(l.β), ", "))")
+  (l.λ == identity) || print(io, ", λ = $(l.λ)")
+  print(io, ")")
+end
diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl
index 3ef9eb7ae8..a249a4f42f 100644
--- a/test/layers/normalisation.jl
+++ b/test/layers/normalisation.jl
@@ -104,3 +104,84 @@ end
     @test (@allocated m(x)) < 100_000_000
   end
 end
+
+
+@testset "InstanceNorm" begin
+  # helper functions
+  expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...)
+  # begin tests
+  let m = InstanceNorm(2), sizes = (3, 2, 2),
+      x = param(reshape(collect(1:prod(sizes)), sizes))
+
+    @test m.β.data == [0, 0] # initβ(2)
+    @test m.γ.data == [1, 1] # initγ(2)
+
+    @test m.active
+
+    m(x)
+
+    #julia> x
+    #[:, :, 1] =
+    # 1.0  4.0
+    # 2.0  5.0
+    # 3.0  6.0
+    #
+    #[:, :, 2] =
+    # 7.0  10.0
+    # 8.0  11.0
+    # 9.0  12.0
+    #
+    # μ will be
+    # (1. + 2. + 3.) / 3 = 2.
+    # (4. + 5. + 6.) / 3 = 5.
+    #
+    # (7. + 8. + 9.) / 3 = 8.
+    # (10. + 11. + 12.) / 3 = 11.
+    #
+    # ∴ update rule with momentum:
+    # (1. - .1) * 0 + .1 * (2. + 8.) / 2 = .5
+    # (1. - .1) * 0 + .1 * (5. + 11.) / 2 = .8
+    @test m.μ ≈ [0.5, 0.8]
+    # momentum * var * num_items / (num_items - 1) + (1 - momentum) * sigma_sq
+    # julia> reshape(mean(.1 .* var(x.data, dims = 1, corrected=false) .* (3 / 2), dims=3), :) .+ .9 .* 1.
+    # 2-element Array{Float64,1}:
+    #  1.
+    #  1.
+    @test m.σ² ≈ reshape(mean(.1 .* var(x.data, dims = 1, corrected=false) .* (3 / 2), dims=3), :) .+ .9 .* 1.
+
+    testmode!(m)
+    @test !m.active
+
+    x′ = m(x).data
+    @test isapprox(x′[1], (1 - 0.5) / sqrt(1. + 1f-5), atol = 1.0e-5)
+  end
+  # with activation function
+  let m = InstanceNorm(2, sigmoid), sizes = (3, 2, 2),
+      x = param(reshape(collect(1:prod(sizes)), sizes))
+
+    affine_shape = collect(sizes)
+    affine_shape[1] = 1
+
+    @test m.active
+    m(x)
+
+    testmode!(m)
+    @test !m.active
+
+    y = m(x).data
+    @test isapprox(y, data(sigmoid.((x .- expand_inst(m.μ, affine_shape)) ./ sqrt.(expand_inst(m.σ², affine_shape) .+ m.ϵ))), atol = 1.0e-7)
+  end
+
+  let m = InstanceNorm(2), sizes = (2, 4, 1, 2, 3),
+      x = param(reshape(collect(1:prod(sizes)), sizes))
+    y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
+    y = reshape(m(y), sizes...)
+    @test m(x) == y
+  end
+
+  let m = BatchNorm(32), x = randn(Float32, 416, 416, 32, 1);
+    m(x)
+    @test (@allocated m(x)) < 100_000_000
+  end
+
+end

From c41f8910052ab4ee85374c339bd8651fd84b4597 Mon Sep 17 00:00:00 2001
From: David Pollack
Date: Wed, 20 Feb 2019 14:51:55 +0100
Subject: [PATCH 2/4] changes based on the improved batchnorm in PR#633

---
 src/layers/normalise.jl | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl
index eaa994b22d..168f33632e 100644
--- a/src/layers/normalise.jl
+++ b/src/layers/normalise.jl
@@ -185,6 +185,7 @@ m = Chain(
   softmax)
 ```
 """
+expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...)
 mutable struct InstanceNorm{F,V,W,N}
   λ::F # activation function
   β::V # bias
@@ -207,7 +208,6 @@ function (IN::InstanceNorm)(x)
   ndims(x) > 2 ||
     error("InstanceNorm requires at least 3 dimensions. With 2 dimensions an array of zeros would be returned")
   # these are repeated later on depending on the batch size
-  γ, β = IN.γ, IN.β
   dims = length(size(x))
   c = size(x, dims-1)
   bs = size(x, dims)
@@ -215,10 +215,12 @@ function (IN::InstanceNorm)(x)
   affine_shape[end-1] = c
   affine_shape[end] = bs
   m = prod(size(x)[1:end-2])
+  γ, β = expand_inst(IN.γ, affine_shape), expand_inst(IN.β, affine_shape)

   if !IN.active
-    μ = reshape(repeat(IN.μ, outer=[bs]), affine_shape...)
-    σ² = reshape(repeat(IN.σ², outer=[bs]), affine_shape...)
+    μ = expand_inst(IN.μ, affine_shape)
+    σ² = expand_inst(IN.σ², affine_shape)
+    ϵ = IN.ϵ
   else
     T = eltype(x)

@@ -229,14 +231,13 @@ function (IN::InstanceNorm)(x)

     # update moving mean/std
     mtm = data(convert(T, IN.momentum))
-    IN.μ = reshape(mean((1 - mtm) .* repeat(IN.μ, outer=[1, bs]) .+ mtm .* reshape(data(μ), (c, bs)), dims = 2), :)
-    IN.σ² = reshape(mean(((1 - mtm) .* repeat(IN.σ², outer=[1, bs]) .+ mtm .* reshape(data(σ²), (c, bs)) .* (m / (m - 1))), dims = 2), :)
+    IN.μ = reshape(mean(repeat((1 - mtm) .* IN.μ, outer=[1, bs]) .+ mtm .* reshape(data(μ), (c, bs)), dims = 2), :)
+    IN.σ² = reshape(mean((repeat((1 - mtm) .* IN.σ², outer=[1, bs]) .+ reshape(data(σ²), (c, bs)) .* (mtm * m / (m - 1))), dims = 2), :)
   end

   let λ = IN.λ
-    temp = reshape(repeat(γ, outer=[bs]), affine_shape...) .* ((x .- μ) ./ sqrt.(σ² .+ IN.ϵ))
-    # This is intentionally not fused because of an extreme slowdown doing so
-    λ.(temp .+ reshape(repeat(β, outer=[bs]), affine_shape...))
+    x̂ = (x .- μ) ./ sqrt.(σ² .+ ϵ)
+    λ.(γ .* x̂ .+ β)
   end
 end

From 83b4b3a7140592f2a8860cb12af23f55ae407a29 Mon Sep 17 00:00:00 2001
From: David Pollack
Date: Wed, 27 Feb 2019 12:03:29 +0100
Subject: [PATCH 3/4] changes based on PR comments

---
 src/layers/normalise.jl      |  5 +++--
 test/layers/normalisation.jl | 17 ++++++++++++++++-
 2 files changed, 19 insertions(+), 3 deletions(-)

diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl
index 168f33632e..7562e84f5d 100644
--- a/src/layers/normalise.jl
+++ b/src/layers/normalise.jl
@@ -186,6 +186,7 @@ m = Chain(
 ```
 """
 expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...)
+
 mutable struct InstanceNorm{F,V,W,N}
   λ::F # activation function
   β::V # bias
@@ -231,8 +232,8 @@ function (IN::InstanceNorm)(x)

     # update moving mean/std
     mtm = data(convert(T, IN.momentum))
-    IN.μ = reshape(mean(repeat((1 - mtm) .* IN.μ, outer=[1, bs]) .+ mtm .* reshape(data(μ), (c, bs)), dims = 2), :)
-    IN.σ² = reshape(mean((repeat((1 - mtm) .* IN.σ², outer=[1, bs]) .+ reshape(data(σ²), (c, bs)) .* (mtm * m / (m - 1))), dims = 2), :)
+    IN.μ = dropdims(mean(repeat((1 - mtm) .* IN.μ, outer=[1, bs]) .+ mtm .* reshape(data(μ), (c, bs)), dims = 2), dims=2)
+    IN.σ² = dropdims(mean((repeat((1 - mtm) .* IN.σ², outer=[1, bs]) .+ (mtm * m / (m - 1)) .* reshape(data(σ²), (c, bs))), dims = 2), dims=2)
   end
diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl
index a249a4f42f..d862944506 100644
--- a/test/layers/normalisation.jl
+++ b/test/layers/normalisation.jl
@@ -179,7 +179,22 @@ end
     @test m(x) == y
   end

-  let m = BatchNorm(32), x = randn(Float32, 416, 416, 32, 1);
+  # check that μ, σ², and the output are the correct size for higher rank tensors
+  let m = InstanceNorm(2), sizes = (5, 5, 3, 4, 2, 6),
+      x = param(reshape(collect(1:prod(sizes)), sizes))
+    y = m(x)
+    @test size(m.μ) == (sizes[end - 1], )
+    @test size(m.σ²) == (sizes[end - 1], )
+    @test size(y) == sizes
+  end
+
+  # show that instance norm is equal to batch norm when channel and batch dims are squashed
+  let m_inorm = InstanceNorm(2), m_bnorm = BatchNorm(12), sizes = (5, 5, 3, 4, 2, 6),
+      x = param(reshape(collect(1:prod(sizes)), sizes))
+    @test m_inorm(x) == reshape(m_bnorm(reshape(x, (sizes[1:end - 2]..., :, 1))), sizes)
+  end
+
+  let m = InstanceNorm(32), x = randn(Float32, 416, 416, 32, 1);
     m(x)
     @test (@allocated m(x)) < 100_000_000
   end

From 7b9b64f1cbfae549acac8f9b4f8996395f322b51 Mon Sep 17 00:00:00 2001
From: David Pollack
Date: Thu, 7 Mar 2019 09:44:55 +0100
Subject: [PATCH 4/4] change IN to in

---
 src/layers/normalise.jl | 36 ++++++++++++++++++------------------
 1 file changed, 18 insertions(+), 18 deletions(-)

diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl
index 7562e84f5d..054ca08b78 100644
--- a/src/layers/normalise.jl
+++ b/src/layers/normalise.jl
@@ -203,9 +203,9 @@ InstanceNorm(chs::Integer, λ = identity;
   InstanceNorm(λ, param(initβ(chs)), param(initγ(chs)),
                zeros(chs), ones(chs), ϵ, momentum, true)

-function (IN::InstanceNorm)(x)
-  size(x, ndims(x)-1) == length(IN.β) ||
-    error("InstanceNorm expected $(length(IN.β)) channels, got $(size(x, ndims(x)-1))")
+function (in::InstanceNorm)(x)
+  size(x, ndims(x)-1) == length(in.β) ||
+    error("InstanceNorm expected $(length(in.β)) channels, got $(size(x, ndims(x)-1))")
   ndims(x) > 2 ||
     error("InstanceNorm requires at least 3 dimensions. With 2 dimensions an array of zeros would be returned")
   # these are repeated later on depending on the batch size
@@ -216,39 +216,39 @@ function (in::InstanceNorm)(x)
   affine_shape = ones(Int, dims)
   affine_shape[end-1] = c
   affine_shape[end] = bs
   m = prod(size(x)[1:end-2])
-  γ, β = expand_inst(IN.γ, affine_shape), expand_inst(IN.β, affine_shape)
+  γ, β = expand_inst(in.γ, affine_shape), expand_inst(in.β, affine_shape)

-  if !IN.active
-    μ = expand_inst(IN.μ, affine_shape)
-    σ² = expand_inst(IN.σ², affine_shape)
-    ϵ = IN.ϵ
+  if !in.active
+    μ = expand_inst(in.μ, affine_shape)
+    σ² = expand_inst(in.σ², affine_shape)
+    ϵ = in.ϵ
   else
     T = eltype(x)

-    ϵ = data(convert(T, IN.ϵ))
+    ϵ = data(convert(T, in.ϵ))
     axes = 1:dims-2 # axes to reduce along (all but channels and batch size axes)
     μ = mean(x, dims = axes)
     σ² = mean((x .- μ) .^ 2, dims = axes)

     # update moving mean/std
-    mtm = data(convert(T, IN.momentum))
-    IN.μ = dropdims(mean(repeat((1 - mtm) .* IN.μ, outer=[1, bs]) .+ mtm .* reshape(data(μ), (c, bs)), dims = 2), dims=2)
-    IN.σ² = dropdims(mean((repeat((1 - mtm) .* IN.σ², outer=[1, bs]) .+ (mtm * m / (m - 1)) .* reshape(data(σ²), (c, bs))), dims = 2), dims=2)
+    mtm = data(convert(T, in.momentum))
+    in.μ = dropdims(mean(repeat((1 - mtm) .* in.μ, outer=[1, bs]) .+ mtm .* reshape(data(μ), (c, bs)), dims = 2), dims=2)
+    in.σ² = dropdims(mean((repeat((1 - mtm) .* in.σ², outer=[1, bs]) .+ (mtm * m / (m - 1)) .* reshape(data(σ²), (c, bs))), dims = 2), dims=2)
   end

-  let λ = IN.λ
+  let λ = in.λ
     x̂ = (x .- μ) ./ sqrt.(σ² .+ ϵ)
     λ.(γ .* x̂ .+ β)
   end
 end

-children(IN::InstanceNorm) =
-  (IN.λ, IN.β, IN.γ, IN.μ, IN.σ², IN.ϵ, IN.momentum, IN.active)
+children(in::InstanceNorm) =
+  (in.λ, in.β, in.γ, in.μ, in.σ², in.ϵ, in.momentum, in.active)

-mapchildren(f, IN::InstanceNorm) = # e.g. mapchildren(cu, IN)
-  InstanceNorm(IN.λ, f(IN.β), f(IN.γ), f(IN.μ), f(IN.σ²), IN.ϵ, IN.momentum, IN.active)
+mapchildren(f, in::InstanceNorm) = # e.g. mapchildren(cu, in)
+  InstanceNorm(in.λ, f(in.β), f(in.γ), f(in.μ), f(in.σ²), in.ϵ, in.momentum, in.active)

-_testmode!(IN::InstanceNorm, test) = (IN.active = !test)
+_testmode!(in::InstanceNorm, test) = (in.active = !test)

 function Base.show(io::IO, l::InstanceNorm)
   print(io, "InstanceNorm($(join(size(l.β), ", "))")
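
The constants asserted in the first test can be reproduced without Flux. Below is a minimal plain-Julia sketch of the same computation under stated assumptions: no `param`/`data` tracking, no test-mode branch, and the name `instance_norm_sketch` and its arguments are illustrative only, not part of the patches.

```julia
using Statistics

# Sketch of the per-(channel, sample) statistics, the momentum update of the
# running statistics, and the per-channel affine transform for a WHCN-style array.
function instance_norm_sketch(x::AbstractArray, γ, β, μrun, σ²run;
                              ϵ = 1f-5, momentum = 0.1f0)
  N = ndims(x)
  reduce_dims = Tuple(1:N-2)                 # all but the channel and batch dims
  c, bs = size(x, N - 1), size(x, N)
  m = prod(size(x)[1:N-2])                   # elements in one (channel, sample) slice

  μ  = mean(x, dims = reduce_dims)           # one mean per (channel, sample) pair
  σ² = mean((x .- μ) .^ 2, dims = reduce_dims)

  # blend the batch-averaged instance statistics into the running statistics;
  # the variance gets the m/(m-1) correction, as in the patches
  μnew  = (1 - momentum) .* μrun  .+ momentum .* vec(mean(reshape(μ, (c, bs)), dims = 2))
  σ²new = (1 - momentum) .* σ²run .+ (momentum * m / (m - 1)) .* vec(mean(reshape(σ², (c, bs)), dims = 2))

  affine = (ntuple(_ -> 1, N - 2)..., c, 1)  # broadcast γ/β over spatial dims and batch
  x̂ = (x .- μ) ./ sqrt.(σ² .+ ϵ)
  return reshape(γ, affine) .* x̂ .+ reshape(β, affine), μnew, σ²new
end

# Worked example from the first test: per-slice means are [2, 5] and [8, 11],
# so one momentum step from μrun = [0, 0] gives [0.5, 0.8], and σ²run stays [1, 1].
x = Float32.(reshape(1:12, 3, 2, 2))
y, μrun, σ²run = instance_norm_sketch(x, ones(Float32, 2), zeros(Float32, 2),
                                      zeros(Float32, 2), ones(Float32, 2))
```

This is only meant to make the `@test m.μ ≈ [0.5, 0.8]` and `@test m.σ² ≈ …` assertions easy to verify by hand; the layer in the patches additionally handles tracked arrays and the inactive (test-mode) branch.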
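The "instance norm is equal to batch norm when channel and batch dims are squashed" test added in PATCH 3/4 follows from the statistics alone: reshaping a `(d1, …, dk, C, N)` array to `(d1, …, dk, C*N, 1)` turns every (channel, sample) pair into its own channel seen for exactly one sample, so per-channel batch statistics coincide with per-instance statistics. A small plain-array check of that reshape argument (variable names are illustrative; no Flux required):

```julia
using Statistics

sizes = (5, 5, 3, 4, 2, 6)
x = randn(Float32, sizes)

# one mean per (channel, sample) pair, as InstanceNorm computes them
μ_instance = mean(x, dims = Tuple(1:length(sizes)-2))

# the same means via the squashed view used in the patch-3 test
x_squashed = reshape(x, (sizes[1:end-2]..., :, 1))               # size (5, 5, 3, 4, 12, 1)
μ_batch    = mean(x_squashed, dims = Tuple(1:length(sizes)-2))   # one mean per squashed channel

@assert vec(μ_instance) ≈ vec(μ_batch)
```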