
Added GroupNorm Layer #696

Merged: 9 commits, Mar 29, 2019
10 changes: 10 additions & 0 deletions docs/src/models/recurrence.md
@@ -114,3 +114,13 @@ truncate!(m)
Calling `truncate!` wipes the slate clean, so we can call the model with more inputs without building up an expensive gradient computation.

`truncate!` makes sense when you are working with multiple chunks of a large sequence, but we may also want to work with a set of independent sequences. In this case the hidden state should be completely reset to its original value, throwing away any accumulated information. `reset!` does this for you.

In general, when training with recurrent layers in your model, you'll want to call `reset!` or `truncate!` for each loss calculation:

```julia
function loss(x,y)
l = Flux.mse(m(x), y)
Flux.reset!(m)
return l
end
```
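When the data comes as chunks of one long sequence, the loop below is a minimal sketch (not part of this diff) of calling `truncate!` between chunks; the helper name and the chunked inputs `xs`/`ys` are assumptions:

```julia
# Hypothetical helper (not from this PR): walk over consecutive chunks of one
# long sequence, truncating the gradient history between chunks while keeping
# the hidden state intact.
function loss_over_chunks(m, xs, ys)
    losses = []
    for (x, y) in zip(xs, ys)
        push!(losses, Flux.mse(m(x), y))
        Flux.truncate!(m)
    end
    return losses
end
```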
6 changes: 2 additions & 4 deletions src/Flux.jl
@@ -6,10 +6,8 @@ using Base: tail
using MacroTools, Juno, Requires, Reexport, Statistics, Random
using MacroTools: @forward

export Chain, Dense, Maxout,

Contributor: Maxout is accidentally deleted here.

Contributor Author: @johnnychen94 Thanks!

RNN, LSTM, GRU,
Conv, ConvTranspose, MaxPool, MeanPool, DepthwiseConv,
Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm,
export Chain, Dense, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool,
DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
params, mapleaves, cpu, gpu, f32, f64

@reexport using NNlib
106 changes: 106 additions & 0 deletions src/layers/normalise.jl
@@ -286,3 +286,109 @@ function Base.show(io::IO, l::InstanceNorm)
(l.λ == identity) || print(io, ", λ = $(l.λ)")
print(io, ")")
end

"""
Group Normalization.
This layer can outperform Batch-Normalization and Instance-Normalization.

GroupNorm(chs::Integer, G::Integer, λ = identity;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to https://docs.julialang.org/en/v1/manual/documentation/

Always show the signature of a function at the top of the documentation, with a four-space indent so that it is printed as Julia code.

initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
ϵ = 1f-5, momentum = 0.1f0)

chs is the number of channels, the channel dimension of your input.

Contributor: Best practice when writing a variable name is to add `` around it -- the REPL can recognize it.
  • chs --> `chs`
  • G --> `G`

For an array of N dimensions, the (N-1)th index is the channel dimension.

G is the number of groups along which the statistics are computed.
The number of channels must be an integer multiple of the number of groups.

Example:
```
m = Chain(Conv((3,3), 1=>32, leakyrelu; pad = 1),
          GroupNorm(32, 16)) # 32 channels, 16 groups (G = 16), so 2 channels per group
```

Link: https://arxiv.org/pdf/1803.08494.pdf

Contributor: Personally, I prefer to add a title to this link to inform users what it points to.
"""
References:
[1] Wu, Y., & He, K. (2018). Group normalization. In Proceedings of the European Conference on Computer Vision (ECCV) (pp. 3-19). https://arxiv.org/abs/1803.08494
"""

"""

mutable struct GroupNorm{F,V,W,N,T}

Contributor: I prefer to add constraints `GroupNorm{F<:Function, V<:Number, W<:Number, T<:Integer}` and do type conversions with new constructors. And do we really need N here? Could it be absorbed by V or W? (See the sketch after the struct definition below.)

G::T # number of groups
λ::F # activation function
β::V # bias
γ::V # scale
μ::W # moving mean
σ²::W # moving variance
ϵ::N
momentum::N
active::Bool
end
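
A rough sketch of the kind of constrained definition the reviewer suggests; the exact bounds are assumptions (the β/γ/μ/σ² fields hold arrays, so array bounds are used here rather than `Number`), and this is an illustration rather than code from the PR:

```julia
# Hypothetical constrained variant (illustration only, not part of this PR).
mutable struct GroupNormSketch{F<:Function,V<:AbstractVector,W<:AbstractArray,N<:Real,T<:Integer}
    G::T          # number of groups
    λ::F          # activation function
    β::V          # bias
    γ::V          # scale
    μ::W          # moving mean
    σ²::W         # moving variance
    ϵ::N
    momentum::N
    active::Bool
end
```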

GroupNorm(chs::Integer, G::Integer, λ = identity;
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =

Contributor: What you need is not to add Float32 here. Instead, do the type conversion in the implementation, `function(gn::GroupNorm)(x)`.

Contributor Author: Thanks! I'll have a look at these and incorporate them in my commit.

GroupNorm(G, λ, param(initβ(chs)), param(initγ(chs)),
zeros(G,1), ones(G,1), ϵ, momentum, true)

function(gn::GroupNorm)(x)
size(x,ndims(x)-1) == length(gn.β) || error("Group Norm expected $(length(gn.β)) channels, but got $(size(x,ndims(x)-1)) channels")
ndims(x) > 2 || error("Need to pass an array with at least 3 dimensions for Group Norm to work")
(size(x,ndims(x) -1))%gn.G == 0 || error("The number of groups ($(gn.G)) must divide the number of channels ($(size(x,ndims(x) -1)))")

dims = length(size(x))
groups = gn.G
channels = size(x, dims-1)
batches = size(x,dims)
channels_per_group = div(channels,groups)
affine_shape = ones(Int, dims)

# Output reshaped to (W,H...,C/G,G,N)
affine_shape[end-1] = channels

μ_affine_shape = ones(Int,dims + 1)
μ_affine_shape[end-1] = groups

m = prod(size(x)[1:end-2]) * channels_per_group # number of elements per group, used for the unbiased variance correction
γ = reshape(gn.γ, affine_shape...)
β = reshape(gn.β, affine_shape...)

y = reshape(x,((size(x))[1:end-2]...,channels_per_group,groups,batches))
if !gn.active
og_shape = size(x)
μ = reshape(gn.μ, μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
σ² = reshape(gn.σ², μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
ϵ = gn.ϵ
else
T = eltype(x)
og_shape = size(x)
axes = [(1:ndims(y)-2)...] # axes to reduce along (all but the group and batch axes)
μ = mean(y, dims = axes)
σ² = mean((y .- μ) .^ 2, dims = axes)

ϵ = data(convert(T, gn.ϵ))
# update moving mean/std
mtm = data(convert(T, gn.momentum))

gn.μ = mean((1 - mtm) .* gn.μ .+ mtm .* reshape(data(μ), (groups,batches)),dims=2)
gn.σ² = mean((1 - mtm) .* gn.σ² .+ (mtm * m / (m - 1)) .* reshape(data(σ²), (groups,batches)),dims=2)
end

let λ = gn.λ
x̂ = (y .- μ) ./ sqrt.(σ² .+ ϵ)

# Reshape x̂
x̂ = reshape(x̂,og_shape)
λ.(γ .* x̂ .+ β)
end
end
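
To make the reshape above concrete, a small hypothetical shape check (not part of the diff), assuming a 28×28 input with 32 channels, a batch of 4, and G = 16:

```julia
using Statistics

# Hypothetical shape walk-through of the reshape used in the forward pass above.
x = rand(Float32, 28, 28, 32, 4)                    # (W, H, C, N) with C = 32
G, C, N = 16, size(x, 3), size(x, 4)
y = reshape(x, size(x)[1:end-2]..., C ÷ G, G, N)    # (W, H, C/G, G, N)
size(y)                                             # (28, 28, 2, 16, 4)
# Statistics are reduced over dims 1:ndims(y)-2, i.e. over W, H, and C/G,
# leaving one mean and variance per group and per batch element:
size(mean(y, dims = 1:ndims(y)-2))                  # (1, 1, 1, 16, 4)
```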

children(gn::GroupNorm) =
(gn.λ, gn.β, gn.γ, gn.μ, gn.σ², gn.ϵ, gn.momentum, gn.active)

mapchildren(f, gn::GroupNorm) = # e.g. mapchildren(cu, GN)
GroupNorm(gn.G, gn.λ, f(gn.β), f(gn.γ), f(gn.μ), f(gn.σ²), gn.ϵ, gn.momentum, gn.active)

_testmode!(gn::GroupNorm, test) = (gn.active = !test)

function Base.show(io::IO, l::GroupNorm)
print(io, "GroupNorm($(join(size(l.β), ", "))")
(l.λ == identity) || print(io, ", λ = $(l.λ)")
print(io, ")")
end
Contributor: Usually, it's best practice to add one newline at EOF.

111 changes: 111 additions & 0 deletions test/layers/normalisation.jl
@@ -200,3 +200,114 @@ end
end

end

@testset "GroupNorm" begin
# begin tests
squeeze(x) = dropdims(x, dims = tuple(findall(size(x) .== 1)...)) # To remove all singular dimensions

let m = GroupNorm(4,2), sizes = (3,4,2),
x = param(reshape(collect(1:prod(sizes)), sizes))

@test m.β.data == [0, 0, 0, 0] # initβ(4)
@test m.γ.data == [1, 1, 1, 1] # initγ(4)

@test m.active

m(x)

#julia> x
#[:, :, 1] =
# 1.0 4.0 7.0 10.0
# 2.0 5.0 8.0 11.0
# 3.0 6.0 9.0 12.0
#
#[:, :, 2] =
# 13.0 16.0 19.0 22.0
# 14.0 17.0 20.0 23.0
# 15.0 18.0 21.0 24.0
#
# μ will be
# (1. + 2. + 3. + 4. + 5. + 6.) / 6 = 3.5
# (7. + 8. + 9. + 10. + 11. + 12.) / 6 = 9.5
#
# (13. + 14. + 15. + 16. + 17. + 18.) / 6 = 15.5
# (19. + 20. + 21. + 22. + 23. + 24.) / 6 = 21.5
#
# μ =
# 3.5 15.5
# 9.5 21.5
#
# ∴ update rule with momentum:
# (1. - .1) * 0 + .1 * (3.5 + 15.5) / 2 = 0.95
# (1. - .1) * 0 + .1 * (9.5 + 21.5) / 2 = 1.55
@test m.μ ≈ [0.95, 1.55]

# julia> mean(var(reshape(x,3,2,2,2),dims=(1,2)).* .1,dims=2) .+ .9*1.
# 2-element Array{Tracker.TrackedReal{Float64},1}:
# 1.25
# 1.25
@test m.σ² ≈ mean(squeeze(var(reshape(x,3,2,2,2),dims=(1,2))).*.1,dims=2) .+ .9*1.

testmode!(m)
@test !m.active

x′ = m(x).data
@test isapprox(x′[1], (1 - 0.95) / sqrt(1.25 + 1f-5), atol = 1.0e-5)
end
# with activation function
let m = GroupNorm(4,2, sigmoid), sizes = (3, 4, 2),
x = param(reshape(collect(1:prod(sizes)), sizes))

μ_affine_shape = ones(Int,length(sizes) + 1)
μ_affine_shape[end-1] = 2 # Number of groups

affine_shape = ones(Int,length(sizes) + 1)
affine_shape[end-2] = 2 # Channels per group
affine_shape[end-1] = 2 # Number of groups
affine_shape[1] = sizes[1]
affine_shape[end] = sizes[end]

og_shape = size(x)

@test m.active
m(x)

testmode!(m)
@test !m.active

y = m(x)
x_ = reshape(x,affine_shape...)
out = reshape(data(sigmoid.((x_ .- reshape(m.μ,μ_affine_shape...)) ./ sqrt.(reshape(m.σ²,μ_affine_shape...) .+ m.ϵ))),og_shape)
@test isapprox(y, out, atol = 1.0e-7)
end

let m = GroupNorm(2,2), sizes = (2, 4, 1, 2, 3),
x = param(reshape(collect(1:prod(sizes)), sizes))
y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
y = reshape(m(y), sizes...)
@test m(x) == y
end

# check that μ, σ², and the output are the correct size for higher rank tensors
let m = GroupNorm(4,2), sizes = (5, 5, 3, 4, 4, 6),
x = param(reshape(collect(1:prod(sizes)), sizes))
y = m(x)
@test size(m.μ) == (m.G,1)
@test size(m.σ²) == (m.G,1)
@test size(y) == sizes
end

# show that group norm is the same as instance norm when the number of groups is the same as the number of channels
let IN = InstanceNorm(4), GN = GroupNorm(4,4), sizes = (2,2,3,4,5),
x = param(reshape(collect(1:prod(sizes)), sizes))
@test IN(x) ≈ GN(x)
end

# show that group norm is the same as batch norm for a group of size 1 and batch of size 1
let BN = BatchNorm(4), GN = GroupNorm(4,4), sizes = (2,2,3,4,1),
x = param(reshape(collect(1:prod(sizes)), sizes))
@test BN(x) ≈ GN(x)
end

end