Skip to content

Commit

Permalink
Added GroupNorm Layer
Browse files Browse the repository at this point in the history
  • Loading branch information
shreyas-kowshik committed Mar 25, 2019
1 parent b348e31 commit 75355af
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/src/models/layers.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,5 @@ BatchNorm
Dropout
AlphaDropout
LayerNorm
GroupNorm
```
2 changes: 1 addition & 1 deletion src/Flux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using MacroTools, Juno, Requires, Reexport, Statistics, Random
using MacroTools: @forward

# NOTE(review): the commit dropped the trailing comma after the last name on the
# continued line, which terminates the `export` statement early and silently
# stops exporting `params, mapleaves, cpu, gpu, f32, f64`. Restore the comma.
export Chain, Dense, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool,
       DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
       params, mapleaves, cpu, gpu, f32, f64

@reexport using NNlib
Expand Down
87 changes: 87 additions & 0 deletions src/layers/normalise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -286,3 +286,90 @@ function Base.show(io::IO, l::InstanceNorm)
(l.λ == identity) || print(io, ", λ = $(l.λ)")
print(io, ")")
end

"""
Group Normalization. Known to improve the overall accuracy in case of classification and segmentation tasks.
Link : https://arxiv.org/pdf/1803.08494.pdf
"""

mutable struct GroupNorm{F,V,W,N,T}
G::T # number of groups
N::T # Batch Size
λ::F # activation function
β::V # bias
γ::V # scale
μ::W # moving mean
σ²::W # moving std
ϵ::N
momentum::N
active::Bool
end

# Outer constructor: `chs` channels split into `G` groups, expecting batches of
# size `N`. β/γ are trainable (tracked) parameters.
# Fix: the moving statistics were created with the untyped `zeros(G,N)`/`ones(G,N)`
# (Float64) while β, γ, ϵ and momentum all default to Float32 — mixed-precision
# arithmetic then silently promotes everything to Float64. Allocate them as
# Float32 for consistency.
GroupNorm(chs::Integer, G::Integer, N::Integer, λ = identity;
          initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
          ϵ = 1f-5, momentum = 0.1f0) =
  GroupNorm(G, N, λ, param(initβ(chs)), param(initγ(chs)),
            zeros(Float32, G, N), ones(Float32, G, N), ϵ, momentum, true)

# Apply group normalization to `x`, laid out as (W, H, ..., C, N): channels are
# split into `gn.G` groups, each group is normalized by its mean/variance
# (batch statistics while `active`, stored moving statistics otherwise), then
# the per-channel affine transform γ·x̂ + β and activation λ are applied.
#
# Fixes relative to the original:
#   * `x̂` assignments had the variable name garbled/missing.
#   * `og_shape` and the grouped reshape of `x` were only computed in the
#     training branch, so inference mode hit an UndefVarError.
#   * In inference mode the stored (G, N) statistics were reshaped to the
#     per-channel `affine_shape` (length C) — a length mismatch; they must be
#     reshaped to broadcast against the grouped layout (..., C/G, G, N).
function (gn::GroupNorm)(x)
  size(x, ndims(x)-1) == length(gn.β) || error("Group Norm expected $(length(gn.β)) channels, but got $(size(x,ndims(x)-1)) channels")
  ndims(x) > 2 || error("Need to pass at least 3 channels for Group Norm to work")
  (size(x, ndims(x)-1)) % gn.G == 0 || error("The number of groups ($(gn.G)) must divide the number of channels ($(size(x,ndims(x) -1)))")
  (size(x, ndims(x)) == gn.N) || error("Number of samples in batch not equal to that passed")

  dims = length(size(x))
  groups = gn.G
  channels = size(x, dims-1)
  batches = size(x, dims)
  channels_per_group = div(channels, groups)

  # γ : (1,1,...,C,1)  — broadcasts over the channel axis of the output
  # β : (1,1,...,C,1)
  affine_shape = ones(Int, dims)
  affine_shape[end-1] = channels
  γ = reshape(gn.γ, affine_shape...)
  β = reshape(gn.β, affine_shape...)

  # Stored statistics are (G, N); reshape so they broadcast against the grouped
  # layout (1,...,1,1,G,N).
  stats_shape = ones(Int, dims + 1)
  stats_shape[end-1] = groups
  stats_shape[end] = batches

  og_shape = size(x)
  # Reshape input to (W, H, ..., C/G, G, N) so each group is its own axis.
  y = reshape(x, (og_shape[1:end-2]..., channels_per_group, groups, batches))
  # Number of elements each (group, batch) statistic is computed over.
  m = prod(og_shape[1:end-2]) * channels_per_group

  if !gn.active
    μ = reshape(gn.μ, stats_shape...)
    σ² = reshape(gn.σ², stats_shape...)
    ϵ = gn.ϵ
  else
    T = eltype(x)
    axes = [(1:ndims(y)-2)...] # reduce along all but the group and batch axes
    μ = mean(y, dims = axes)
    σ² = sum((y .- μ) .^ 2, dims = axes) ./ m
    ϵ = data(convert(T, gn.ϵ))
    # Update moving mean/variance; m/(m-1) is the unbiased-variance correction.
    mtm = data(convert(T, gn.momentum))
    gn.μ = (1 - mtm) .* gn.μ .+ mtm .* reshape(data(μ), (groups, batches))
    gn.σ² = (1 - mtm) .* gn.σ² .+ (mtm * m / (m - 1)) .* reshape(data(σ²), (groups, batches))
  end

  let λ = gn.λ
    x̂ = (y .- μ) ./ sqrt.(σ² .+ ϵ)
    # Back to the caller's shape before the per-channel affine transform.
    x̂ = reshape(x̂, og_shape)
    λ.(γ .* x̂ .+ β)
  end
end

# Expose every field so generic tree traversal (params collection, device
# mapping, etc.) can reach the layer's contents.
function children(gn::GroupNorm)
  return (gn.λ, gn.β, gn.γ, gn.μ, gn.σ², gn.ϵ, gn.momentum, gn.active)
end

# Rebuild the layer with `f` applied to each array child, e.g. mapchildren(cu, gn)
# to move the parameters and statistics to the GPU.
# Fix: the original called `GroupNorm(gn,G, ...)` — passing the layer itself plus
# an undefined global `G`, and omitting `gn.N` — which cannot match the 10-field
# constructor. The first two arguments must be `gn.G, gn.N`.
mapchildren(f, gn::GroupNorm) = # e.g. mapchildren(cu, BN)
  GroupNorm(gn.G, gn.N, gn.λ, f(gn.β), f(gn.γ), f(gn.μ), f(gn.σ²), gn.ϵ, gn.momentum, gn.active)

# Switch between training mode (batch statistics) and test mode (moving statistics).
function _testmode!(gn::GroupNorm, test)
  gn.active = !test
end

# Compact one-line display, e.g. `GroupNorm(32, λ = relu)`; the activation is
# shown only when it is not the default `identity`.
function Base.show(io::IO, l::GroupNorm)
  channels = join(size(l.β), ", ")
  print(io, "GroupNorm($channels")
  l.λ == identity || print(io, ", λ = $(l.λ)")
  print(io, ")")
end

0 comments on commit 75355af

Please sign in to comment.