
Added GroupNorm Layer #696

Merged 9 commits into FluxML:master on Mar 29, 2019

Conversation

shreyas-kowshik (Contributor):

The previous PR was closed due to some silly git errors on my side. Here is the original code again.

@MikeInnes requested a review from staticfloat on March 25, 2019 14:12
src/layers/normalise.jl (resolved)

function(gn::GroupNorm)(x)
size(x,ndims(x)-1) == length(gn.β) || error("Group Norm expected $(length(gn.β)) channels, but got $(size(x,ndims(x)-1)) channels")
ndims(x) > 2 || error("Need to pass atleast 3 channels for Group Norm to work")
Contributor:

atleast -> at least


mutable struct GroupNorm{F,V,W,N,T}
G::T # number of groups
N::T # Batch Size
Contributor:

I don't think we need to bake N into the GroupNorm; we should be able to allow variable batch sizes with this operation.

Contributor Author:

@staticfloat The size of the mean and variance matrices depends on N and is (Channels/Groups, Batch Size). So shouldn't that require N in the implementation?
I apologize if it's a silly question, but this is what I understand so far.

Contributor:

You're right that the size of μ and σ² will change depending on N; but we don't need to know it beforehand. We can just broadcast μ and σ² up to the proper size when we get a new batch. (You are correct that we cannot have fully variable batch sizes; e.g. running with N=8 and then N=16 won't work; we will need to reset the μ and σ² before doing that).

You can initialize μ and σ² as size (G, 1) where G is the number of groups, then when you do this line:

gn.μ = (1 - mtm) .* gn.μ .+ mtm .* reshape(data(μ), (groups,batches))

gn.μ will be automatically broadcast up to the proper size.
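
A minimal sketch of that broadcasting behaviour, in plain Julia with made-up sizes (not the PR's code):

G = 4                               # number of groups
mtm = 0.1f0                         # momentum
μ = zeros(Float32, G, 1)            # running mean, initialised once as (G, 1)
batch_μ = rand(Float32, G, 8)       # per-batch mean for a batch of 8

μ = (1 - mtm) .* μ .+ mtm .* batch_μ   # μ is broadcast from (G, 1) up to (G, 8)
size(μ)                                # (4, 8)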

size(x,ndims(x)-1) == length(gn.β) || error("Group Norm expected $(length(gn.β)) channels, but got $(size(x,ndims(x)-1)) channels")
ndims(x) > 2 || error("Need to pass atleast 3 channels for Group Norm to work")
(size(x,ndims(x) -1))%gn.G == 0 || error("The number of groups ($(gn.G)) must divide the number of channels ($(size(x,ndims(x) -1)))")
(size(x,ndims(x)) == gn.N) || error("Number of samples in batch not equal to that passed")
Contributor:
I don't think we need this check; we should be able to deal with variable batch sizes.

(size(x,ndims(x) -1))%gn.G == 0 || error("The number of groups ($(gn.G)) must divide the number of channels ($(size(x,ndims(x) -1)))")
(size(x,ndims(x)) == gn.N) || error("Number of samples in batch not equal to that passed")
# γ : (1,1...,C,1)
# β : (1,1...,C,1)
Contributor:
I'm not sure what these comments mean. Can you expand them or remove them?
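
For what it's worth, those shape comments appear to describe reshaping γ and β so they broadcast over everything except the channel dimension; a small illustration with made-up sizes:

x  = rand(Float32, 5, 5, 3, 2)     # (W, H, C, N) input
γ  = ones(Float32, 3)              # one scale per channel
γr = reshape(γ, 1, 1, 3, 1)        # (1, 1, ..., C, 1) as in the comment
size(x .* γr)                      # (5, 5, 3, 2): scales broadcast across W, H and N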

groups = gn.G
channels = size(x, dims-1)
batches = size(x,dims)
channels_per_group = convert(Int32,div(channels,groups))
Contributor:
Is there a reason you explicitly want channels_per_group to be an Int32? div() should already give you an integral type.
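
For reference, div on two Ints already returns an Int, so the convert appears redundant:

channels, groups = 32, 16
channels_per_group = div(channels, groups)   # 2
typeof(channels_per_group)                   # Int64 (on a 64-bit machine)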

else
T = eltype(x)
og_shape = size(x)
x = reshape(x,((size(x))[1:end-2]...,channels_per_group,groups,batches))
Contributor:
Clever use of reshape(), I suggest that you name this something other than x though, as it makes looking at things like ndims(x) below needlessly confusing.
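
A quick illustration of that reshape with hypothetical sizes (4 channels split into 2 groups), using a new name for the result as suggested:

x = rand(Float32, 5, 5, 4, 3)        # width, height, channels, batch
channels_per_group, groups, batches = 2, 2, 3
y = reshape(x, size(x)[1:end-2]..., channels_per_group, groups, batches)
size(y)                              # (5, 5, 2, 2, 3)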

@shreyas-kowshik (Contributor Author):

@staticfloat Thank you for your feedback. I have made the requested changes.

@staticfloat (Contributor):

Add some tests as well; I think there might be some problems with variable names and multiple code branches, so it will be good to do tests (both with an active layer and a !active layer). Take a look at the instance normalization tests for inspiration.
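
A rough sketch of the kind of test being asked for, modelled on the InstanceNorm tests; the constructor call, field access and expected values here are illustrative, not taken from the final PR:

using Flux, Test
using Flux: testmode!

let m = GroupNorm(4, 2), sizes = (3, 4, 2)
    x = reshape(collect(Float32, 1:prod(sizes)), sizes)
    @test m.β.data == zeros(Float32, 4)   # β starts at zero
    @test size(m(x)) == size(x)           # active (training) branch
    testmode!(m)                          # switch the layer to test mode
    @test size(m(x)) == size(x)           # !active branch
end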

@shreyas-kowshik (Contributor Author):

@staticfloat I have added the tests for Group Normalization. They are passing on my machine. Could you please review them?

initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
ϵ = 1f-5, momentum = 0.1f0)

chs is the numebr of channels, the channeld dimension of your input.
Contributor:
numebr, channeld both look to be typos.


"""
Group Normalization.
Known to improve the overall accuracy in case of classification and segmentation tasks.
Contributor:
Let's make the second line mention that this is a normalization layer that can perform better than Batch or Instance normalization.

For an array of N dimensions, the (N-1)th index is the channel dimension.

G is the number of groups along which the statistics would be computed.
The number of groups must divide the number of channels for this to work.
Contributor:
I think a better way of saying this is that the number of channels must be an integer multiple of the number of groups.

else
T = eltype(x)
og_shape = size(x)
y = reshape(x,((size(x))[1:end-2]...,channels_per_group,groups,batches))
Contributor:
Since you're calculating the same thing for y in each branch, you can just pull that out of the if statement and do it once before.
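
A sketch of the refactor being suggested, with the shared reshape hoisted above the branch; `active` is a hypothetical stand-in for the layer's training flag, and the branch bodies are elided:

function grouped_forward(x, channels_per_group, groups, batches, active)
    y = reshape(x, size(x)[1:end-2]..., channels_per_group, groups, batches)
    if active
        # training: compute per-batch statistics from y and update the running μ / σ²
    else
        # test mode: normalise y with the stored running statistics
    end
    return y
end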

@@ -1,5 +1,5 @@
using Flux: testmode!
using Flux.Tracker: data
using Flux.Tracker: data
Contributor:
extra whitespace

test/layers/normalisation.jl (resolved)
@staticfloat (Contributor):

It's coming together! I have more minor comments this time around, once these are addressed I think we'll be ready to merge!

@shreyas-kowshik (Contributor Author):

@staticfloat Thank you for the feedback. I have made the requested changes.

@MikeInnes (Member):

Looks like the tests are failing; if it's unrelated it might just need a merge of master.

@@ -6,10 +6,8 @@ using Base: tail
using MacroTools, Juno, Requires, Reexport, Statistics, Random
using MacroTools: @forward

export Chain, Dense, Maxout,
Contributor:
Maxout is accidentally deleted here.

Contributor Author:
@johnnychen94 Thanks!

@johnnychen94 (Contributor) left a comment:

I didn't go into the code details; I just added some comments on your docstring. Since this is your first PR, it's better to read the Julia documentation style guide and check against it.

https://docs.julialang.org/en/v1/manual/documentation/

initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
ϵ = 1f-5, momentum = 0.1f0)

chs is the number of channels, the channel dimension of your input.
Contributor:
Best practice for writing a variable name is to add `` around it -- the REPL can recognize it:

  • chs --> `chs`
  • G --> `G`

Group Normalization.
This layer can outperform Batch-Normalization and Instance-Normalization.

GroupNorm(chs::Integer, G::Integer, λ = identity;
Contributor:
According to https://docs.julialang.org/en/v1/manual/documentation/

Always show the signature of a function at the top of the documentation, with a four-space indent so that it is printed as Julia code.
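
An illustrative docstring in that style (signature indented by four spaces, variable names in backticks); the wording is an example only, not the PR's final docstring:

"""
    GroupNorm(chs::Integer, G::Integer, λ = identity;
              initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
              ϵ = 1f-5, momentum = 0.1f0)

Group Normalization layer. `chs` is the number of channels and `G` the number of
groups; `chs` must be an integer multiple of `G`.
"""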

GroupNorm(32,16)) # 32 channels, 16 groups (G = 16), thus 2 channels per group used
```

Link : https://arxiv.org/pdf/1803.08494.pdf
Contributor:
Personally, I prefer to add a title to this link to inform users what it points to, e.g.
"""
References:
[1] Wu, Y., & He, K. (2018). Group normalization. In Proceedings of the European Conference on Computer Vision (ECCV) (pp. 3-19). https://arxiv.org/abs/1803.08494
"""

print(io, "GroupNorm($(join(size(l.β), ", "))")
(l.λ == identity) || print(io, ", λ = $(l.λ)")
print(io, ")")
end
Contributor:
usually, it's best practice to add one newline at EOF

end

GroupNorm(chs::Integer, G::Integer, λ = identity;
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
Contributor:
What you need is not to add Float32 here. Instead, you need to do the type conversion in the implementation, i.e. in function(gn::GroupNorm)(x).
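
A minimal sketch of that suggestion: convert scalar hyper-parameters to the input's element type inside the forward pass instead of hard-coding Float32 (the function name and body here are hypothetical):

function normalise_sketch(x, ϵ)
    T  = eltype(x)
    ϵT = convert(T, ϵ)            # ϵ follows the input's precision
    μ  = sum(x) / length(x)
    return (x .- μ) ./ sqrt(ϵT + sum(abs2, x .- μ) / length(x))
end

normalise_sketch(rand(Float64, 4), 1f-5)   # ϵ is promoted to Float64 here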

Contributor Author:

Thanks! I'll have a look at these and incorporate them in my commit.

Link : https://arxiv.org/pdf/1803.08494.pdf
"""

mutable struct GroupNorm{F,V,W,N,T}
Contributor:
I prefer to add constraints GroupNorm{F<:Function, V<:Number, W<:Number, T<:Integer} and do the type conversions with new constructors.
Also, do we really need N here? Could it be absorbed by V or W?
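
A sketch of the constrained-parameter idea; the type name and field set here are hypothetical and much smaller than the PR's struct:

mutable struct NormSketch{F<:Function, W<:AbstractArray, T<:Integer}
    λ::F     # activation
    β::W     # shift, one per channel
    γ::W     # scale, one per channel
    G::T     # number of groups
end

gn = NormSketch(identity, zeros(Float32, 4), ones(Float32, 4), 2)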

@shreyas-kowshik (Contributor Author):

@MikeInnes The tests related to GroupNorm have passed. The issue was due to the missing export of Maxout, as pointed out by @johnnychen94.

@staticfloat merged commit 7418a2d into FluxML:master on Mar 29, 2019
@staticfloat (Contributor):

Thanks @shreyas-kowshik!

@MikeInnes (Member):

Can you add this layer to the docs, and also an entry to NEWS.md?

@johnnychen94 (Contributor):

I'm still doubtful about the default initializer; we can't just let some type (in this case Float32) be the default type.

GroupNorm(chs::Integer, G::Integer, λ = identity;
          initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0)

But since there'll be deprecation on these init keywords, it doesn't matter much. #671
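
An illustration of the concern, using the default initβ from the signature above (the size is arbitrary):

initβ = (i) -> zeros(Float32, i)
β = initβ(4)
eltype(β)    # Float32, regardless of the element type the model is later run with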

@shreyas-kowshik (Contributor Author):

@MikeInnes #728 adds GroupNorm to docs and NEWS.md
