Added GroupNorm Layer #696
@@ -286,3 +286,109 @@ function Base.show(io::IO, l::InstanceNorm)
  (l.λ == identity) || print(io, ", λ = $(l.λ)")
  print(io, ")")
end

"""
Group Normalization.
This layer can outperform Batch-Normalization and Instance-Normalization.

GroupNorm(chs::Integer, G::Integer, λ = identity;
Review comment: According to https://docs.julialang.org/en/v1/manual/documentation/, the signature should be shown at the top of the docstring with a four-space indent; a sketch follows the docstring below.
          initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
          ϵ = 1f-5, momentum = 0.1f0)

chs is the number of channels, the channel dimension of your input.
Review comment: The best practice for writing a variable name is to add backticks around it (`chs`); the REPL can recognize it.
For an array of N dimensions, the (N-1)th index is the channel dimension.

G is the number of groups along which the statistics are computed.
The number of channels must be an integer multiple of the number of groups.

Example:
```
m = Chain(Conv((3,3), 1=>32, leakyrelu; pad = 1),
          GroupNorm(32, 16))  # 32 channels, 16 groups (G = 16), thus 2 channels per group
```

Link: https://arxiv.org/pdf/1803.08494.pdf
Review comment: Personally, I prefer to add a title to this link to inform users what it might be.
""" | ||

mutable struct GroupNorm{F,V,W,N,T}
Review comment: I prefer to add constraints; see the sketch after this definition.
  G::T  # number of groups
  λ::F  # activation function
  β::V  # bias
  γ::V  # scale
  μ::W  # moving mean
  σ²::W # moving std
  ϵ::N
  momentum::N
  active::Bool
end
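
Per the review comment above, a hedged sketch of the same struct with constrained type parameters; the bounds `N<:Real` and `T<:Integer` are illustrative assumptions, not part of the submitted code:

```julia
mutable struct GroupNorm{F,V,W,N<:Real,T<:Integer}
  G::T          # number of groups, now constrained to integer types
  λ::F          # activation function
  β::V          # bias
  γ::V          # scale
  μ::W          # moving mean
  σ²::W         # moving std
  ϵ::N          # numeric constants constrained to real types
  momentum::N
  active::Bool
end
```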

GroupNorm(chs::Integer, G::Integer, λ = identity;
          initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
Review comment: What you need is not add […]

Reply: Thanks! I'll have a look at these and incorporate them in my commit.
  GroupNorm(G, λ, param(initβ(chs)), param(initγ(chs)),
            zeros(G,1), ones(G,1), ϵ, momentum, true)
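
A minimal usage sketch of this constructor, assuming Flux's Tracker-era `param` is in scope as in the rest of this file; the numbers are illustrative:

```julia
gn = GroupNorm(32, 16)  # 32 channels split into 16 groups

length(gn.β)  # 32 (one tracked bias per channel)
length(gn.γ)  # 32 (one tracked scale per channel)
size(gn.μ)    # (16, 1): one moving mean per group
gn.active     # true: the layer starts in training mode
```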

function (gn::GroupNorm)(x)
  size(x, ndims(x)-1) == length(gn.β) || error("Group Norm expected $(length(gn.β)) channels, but got $(size(x, ndims(x)-1)) channels")
  ndims(x) > 2 || error("Need to pass at least a 3-dimensional array for Group Norm to work")
  size(x, ndims(x)-1) % gn.G == 0 || error("The number of groups ($(gn.G)) must divide the number of channels ($(size(x, ndims(x)-1)))")

  dims = length(size(x))
  groups = gn.G
  channels = size(x, dims-1)
  batches = size(x, dims)
  channels_per_group = div(channels, groups)
  affine_shape = ones(Int, dims)

  # Output reshaped to (W,H...,C/G,G,N)
  affine_shape[end-1] = channels

  μ_affine_shape = ones(Int, dims + 1)
  μ_affine_shape[end-1] = groups

  m = prod(size(x)[1:end-2]) * channels_per_group
  γ = reshape(gn.γ, affine_shape...)
  β = reshape(gn.β, affine_shape...)

  y = reshape(x, (size(x)[1:end-2]..., channels_per_group, groups, batches))
  if !gn.active
    og_shape = size(x)
    μ = reshape(gn.μ, μ_affine_shape...)   # Shape : (1,1,...C/G,G,1)
    σ² = reshape(gn.σ², μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
    ϵ = gn.ϵ
  else
    T = eltype(x)
    og_shape = size(x)
    axes = [(1:ndims(y)-2)...] # axes to reduce along (all but the group and batch axes)
    μ = mean(y, dims = axes)
    σ² = mean((y .- μ) .^ 2, dims = axes)

    ϵ = data(convert(T, gn.ϵ))
    # update moving mean/std
    mtm = data(convert(T, gn.momentum))

    gn.μ = mean((1 - mtm) .* gn.μ .+ mtm .* reshape(data(μ), (groups, batches)), dims = 2)
    gn.σ² = mean((1 - mtm) .* gn.σ² .+ (mtm * m / (m - 1)) .* reshape(data(σ²), (groups, batches)), dims = 2)
  end

  let λ = gn.λ
    x̂ = (y .- μ) ./ sqrt.(σ² .+ ϵ)

    # Reshape x̂
    x̂ = reshape(x̂, og_shape)
    λ.(γ .* x̂ .+ β)
  end
end
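
A hedged sketch of the forward pass in use, assuming the definitions above plus Flux's Tracker-era `Flux.data`; it checks that the output keeps the input shape and that, with γ = 1 and β = 0 at initialization, each (group, sample) slice comes out roughly standardized:

```julia
using Statistics

x = rand(Float32, 4, 4, 32, 2)  # (W, H, C, N)
gn = GroupNorm(32, 16)          # 16 groups of 2 channels each
y = gn(x)

size(y) == size(x)  # true: GroupNorm preserves the input shape

# Reshape to (W, H, C/G, G, N) and inspect one group of one sample:
yr = reshape(Flux.data(y), 4, 4, 2, 16, 2)
mean(yr[:, :, :, 1, 1])  # ≈ 0
std(yr[:, :, :, 1, 1])   # ≈ 1
```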

children(gn::GroupNorm) =
  (gn.λ, gn.β, gn.γ, gn.μ, gn.σ², gn.ϵ, gn.momentum, gn.active)

mapchildren(f, gn::GroupNorm) =  # e.g. mapchildren(cu, BN)
  GroupNorm(gn.G, gn.λ, f(gn.β), f(gn.γ), f(gn.μ), f(gn.σ²), gn.ϵ, gn.momentum, gn.active)
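
Mirroring the `mapchildren(cu, BN)` hint in the comment above, a sketch of moving the layer's arrays to the GPU; this assumes CuArrays' `cu` is available and is illustrative only:

```julia
using CuArrays  # assumption: a CUDA-capable setup

gn = GroupNorm(32, 16)
gn_gpu = mapchildren(cu, gn)  # applies `cu` to β, γ, μ and σ²; G, λ, ϵ, momentum and active pass through unchanged
```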

_testmode!(gn::GroupNorm, test) = (gn.active = !test)
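
`_testmode!` is the hook behind Flux's Tracker-era `testmode!`; a sketch of the intended toggle, assuming that API:

```julia
gn = GroupNorm(32, 16)
Flux.testmode!(gn)         # sets gn.active = false: the layer uses the moving μ/σ²
Flux.testmode!(gn, false)  # back to training mode: per-batch statistics are used
```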

function Base.show(io::IO, l::GroupNorm)
  print(io, "GroupNorm($(join(size(l.β), ", "))")
  (l.λ == identity) || print(io, ", λ = $(l.λ)")
  print(io, ")")
end
Review comment: Usually, it's best practice to add one newline at EOF.
Review comment: `Maxout` is accidentally deleted here.
Reply: @johnnychen94 Thanks!