Added GroupNorm Layer #696
Conversation
src/layers/normalise.jl
Outdated
function(gn::GroupNorm)(x)
size(x,ndims(x)-1) == length(gn.β) || error("Group Norm expected $(length(gn.β)) channels, but got $(size(x,ndims(x)-1)) channels")
ndims(x) > 2 || error("Need to pass atleast 3 channels for Group Norm to work")
`atleast` -> `at least`
src/layers/normalise.jl
Outdated
mutable struct GroupNorm{F,V,W,N,T}
G::T # number of groups
N::T # Batch Size
I don't think we need to bake `N` into the `GroupNorm`; we should be able to allow variable batch sizes with this operation.
@staticfloat The size of the mean and variance matrices depends on N and is (Channels/Groups, Batch Size). So shouldn't the implementation need `N`?
I apologize if it's a silly question, but this is what I understand so far.
You're right that the size of `μ` and `σ²` will change depending on `N`, but we don't need to know it beforehand. We can just broadcast `μ` and `σ²` up to the proper size when we get a new batch. (You are correct that we cannot have fully variable batch sizes; e.g. running with `N=8` and then `N=16` won't work; we will need to reset `μ` and `σ²` before doing that.)

You can initialize `μ` and `σ²` as size `(G, 1)`, where `G` is the number of groups; then when you do this line:

gn.μ = (1 - mtm) .* gn.μ .+ mtm .* reshape(data(μ), (groups,batches))

`gn.μ` will be automatically broadcast up to the proper size.
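A minimal standalone sketch of that broadcasting behaviour (the variable names below are made up for illustration, not the PR's code):

```julia
# Running statistics stored as (groups, 1); per-batch statistics are (groups, batches).
groups, batches = 4, 8
μ_running = zeros(Float32, groups, 1)        # initialized once, batch size unknown
μ_batch   = randn(Float32, groups, batches)  # computed from the current batch
mtm = 0.1f0

# Broadcasting expands μ_running from (4, 1) to (4, 8) on the first update,
# which is why switching to a different batch size later requires a reset.
μ_running = (1 - mtm) .* μ_running .+ mtm .* μ_batch
size(μ_running)  # (4, 8)
```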
src/layers/normalise.jl
Outdated
size(x,ndims(x)-1) == length(gn.β) || error("Group Norm expected $(length(gn.β)) channels, but got $(size(x,ndims(x)-1)) channels")
ndims(x) > 2 || error("Need to pass atleast 3 channels for Group Norm to work")
(size(x,ndims(x) -1))%gn.G == 0 || error("The number of groups ($(gn.G)) must divide the number of channels ($(size(x,ndims(x) -1)))")
(size(x,ndims(x)) == gn.N) || error("Number of samples in batch not equal to that passed")
I don't think we need this check; we should be able to deal with variable batch sizes.
src/layers/normalise.jl
Outdated
(size(x,ndims(x) -1))%gn.G == 0 || error("The number of groups ($(gn.G)) must divide the number of channels ($(size(x,ndims(x) -1)))")
(size(x,ndims(x)) == gn.N) || error("Number of samples in batch not equal to that passed")
# γ : (1,1...,C,1)
# β : (1,1...,C,1)
I'm not sure what these comments mean. Can you expand them or remove them?
src/layers/normalise.jl
Outdated
groups = gn.G
channels = size(x, dims-1)
batches = size(x,dims)
channels_per_group = convert(Int32,div(channels,groups))
Is there a reason you explicitly want `channels_per_group` to be an `Int32`? `div()` should already give you an integral type.
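For reference, a quick REPL check (using the 32-channel, 16-group numbers that appear later in this thread):

```julia
channels, groups = 32, 16
channels_per_group = div(channels, groups)  # 2
typeof(channels_per_group)                  # Int64 — already an integer, no convert(Int32, ...) needed
```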
src/layers/normalise.jl
Outdated
else
T = eltype(x)
og_shape = size(x)
x = reshape(x,((size(x))[1:end-2]...,channels_per_group,groups,batches))
Clever use of `reshape()`, I suggest that you name this something other than `x` though, as it makes looking at things like `ndims(x)` below needlessly confusing.
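A hedged sketch of the shapes involved (the sizes are made up for illustration):

```julia
# A WHCN batch with 32 channels split into 16 groups → 2 channels per group.
x = rand(Float32, 28, 28, 32, 8)
groups = 16
channels_per_group = div(size(x, ndims(x) - 1), groups)
batches = size(x, ndims(x))

# Giving the reshaped array its own name keeps later uses of ndims(x) unambiguous.
x_grouped = reshape(x, size(x)[1:end-2]..., channels_per_group, groups, batches)
size(x_grouped)  # (28, 28, 2, 16, 8)
```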
75355af to 595f1cf
@staticfloat Thank you for your feedback. I have made the requested changes.
Add some tests as well; I think there might be some problems with variable names and multiple code branches, so it will be good to do tests (both with an
@staticfloat I have added the tests for Group Normalization. They are passing on my machine. Can you please review it once?
src/layers/normalise.jl
Outdated
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
ϵ = 1f-5, momentum = 0.1f0)

chs is the numebr of channels, the channeld dimension of your input.
`numebr`, `channeld` both look to be typos.
src/layers/normalise.jl
Outdated
"""
Group Normalization.
Known to improve the overall accuracy in case of classification and segmentation tasks.
Let's make the second line mention that this is a normalization layer that can perform better than Batch or Instance normalization.
src/layers/normalise.jl
Outdated
For an array of N dimensions, the (N-1)th index is the channel dimension.

G is the number of groups along which the statistics would be computed.
The number of groups must divide the number of channels for this to work.
I think a better way of saying this is that the number of channels must be an integer multiple of the number of groups.
src/layers/normalise.jl
Outdated
else
T = eltype(x)
og_shape = size(x)
y = reshape(x,((size(x))[1:end-2]...,channels_per_group,groups,batches))
Since you're calculating the same thing for `y` in each branch, you can just pull that out of the `if` statement and do it once before.
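A toy sketch of the suggested restructuring (names and shapes are hypothetical, not the PR's code):

```julia
using Statistics

x = rand(Float32, 28, 28, 32, 8)
y = reshape(x, 28, 28, 2, 16, 8)   # shared work, hoisted out of the branch
training = true

if training
    μ, σ² = mean(y, dims = (1, 2, 3)), var(y, dims = (1, 2, 3))  # batch statistics
else
    μ, σ² = 0f0, 1f0   # stand-ins for the stored running statistics
end
```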
test/layers/normalisation.jl
Outdated
@@ -1,5 +1,5 @@
using Flux: testmode!
using Flux.Tracker: data
using Flux.Tracker: data
extra whitespace
It's coming together! I have more minor comments this time around, once these are addressed I think we'll be ready to merge!
@staticfloat Thank you for the feedback. I have made the requested changes.
Looks like the tests are failing; if it's unrelated it might just need a merge of master.
@@ -6,10 +6,8 @@ using Base: tail
using MacroTools, Juno, Requires, Reexport, Statistics, Random
using MacroTools: @forward

export Chain, Dense, Maxout, |
`Maxout` is accidentally deleted here.
@johnnychen94 Thanks!
I didn't go into the code details, just added some comments on your docstring. Since this is your first PR, it's better to read the Julia documentation style guide and check against it.
src/layers/normalise.jl
Outdated
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
ϵ = 1f-5, momentum = 0.1f0)

chs is the number of channels, the channel dimension of your input.
Best practice when writing a variable name is to add `` around it -- the REPL can recognize it.
- chs --> `chs`
- G --> `G`
src/layers/normalise.jl
Outdated
Group Normalization.
This layer can outperform Batch-Normalization and Instance-Normalization.

GroupNorm(chs::Integer, G::Integer, λ = identity;
According to https://docs.julialang.org/en/v1/manual/documentation/
Always show the signature of a function at the top of the documentation, with a four-space indent so that it is printed as Julia code.
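For illustration, combining this with the backtick suggestion above, the docstring header might look roughly like this (a sketch, not the PR's final wording):

```julia
"""
    GroupNorm(chs::Integer, G::Integer, λ = identity;
              initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
              ϵ = 1f-5, momentum = 0.1f0)

Group Normalization layer. `chs` is the number of channels (the channel
dimension of your input) and `G` is the number of groups.
"""
```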
GroupNorm(32,16)) # 32 channels, 16 groups (G = 16), thus 2 channels per group used
```

Link : https://arxiv.org/pdf/1803.08494.pdf
Personally, I prefer to add a title to this link to inform the users what it might be:
"""
References:
[1] Wu, Y., & He, K. (2018). Group normalization. In Proceedings of the European Conference on Computer Vision (ECCV) (pp. 3-19). https://arxiv.org/abs/1803.08494
"""
src/layers/normalise.jl
Outdated
print(io, "GroupNorm($(join(size(l.β), ", "))")
(l.λ == identity) || print(io, ", λ = $(l.λ)")
print(io, ")")
end
usually, it's best practice to add one newline at EOF
end

GroupNorm(chs::Integer, G::Integer, λ = identity;
  initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
What you need is not to add `Float32` here. Instead, you need to do the type conversion in the implementation, i.e. inside `function(gn::GroupNorm)(x)`.
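A self-contained sketch of that idea, using a toy layer rather than GroupNorm itself (names are hypothetical):

```julia
using Statistics

# Toy normalisation layer: ϵ is stored as whatever the constructor received
# and converted to match the input's element type at call time.
struct TinyNorm{T<:Real}
    ϵ::T
end

function (n::TinyNorm)(x)
    ϵ = convert(eltype(x), n.ϵ)   # no Float32 baked into the forward pass
    (x .- mean(x)) ./ (std(x) + ϵ)
end

TinyNorm(1f-5)(rand(Float64, 4, 4))   # ϵ is promoted to Float64 to match the input
```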
Thanks! I'll have a look at these and incorporate them in my commit.
Link : https://arxiv.org/pdf/1803.08494.pdf
"""

mutable struct GroupNorm{F,V,W,N,T}
I prefer to add constraints `GroupNorm{F<:Function, V<:Number, W<:Number, T<:Integer}` and do type conversions with new constructors. And do we really need `N` here? Is it possible for it to be absorbed by `V` or `W`?
@MikeInnes The tests related to GroupNorm have passed. The issue was due to the missing export of Maxout as pointed out by @johnnychen94 .
Thanks @shreyas-kowshik!
Can you add this layer to the docs, and also an entry to NEWS.md?
I'm still suspicious about the default initializer; we can't just let some type (in this case

GroupNorm(chs::Integer, G::Integer, λ = identity;
  initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0)

But since there'll be deprecation on these
@MikeInnes #728 adds GroupNorm to docs and NEWS.md |
The previous PR was closed due to some silly git errors from my side. Here is the original code added.