
Fix type-stability for normalization layers #1856

Merged · 9 commits · Feb 3, 2022
Conversation

@pxl-th (Member) commented Feb 3, 2022

This PR fixes type-stability for the normalization layers.
It also makes the ResNet model (among others) type-stable.

While there is an open PR that reworks the normalization layers, its status is unclear.
@ToucheSir has also suggested that the functional part should be moved into NNlib.jl, so I'm not sure whether this PR should be accepted in the first place.
But at least we can use it to see what improvements type-stability brings in this case.

Also, I feel this is not the cleanest solution, as it essentially computes the output type `o::_basetype(typeof(x)){O, N}`.
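To make the idea concrete, here is an illustrative sketch (hypothetical helper name, not the code from this PR): compute a single concrete output element type from the input and parameter types up front, so the compiler can infer a concrete return type for the forward pass.

```julia
# Hypothetical sketch of the general idea behind the fix;
# `output_eltype` is an illustrative helper, not part of Flux or this PR.

# Promote the input eltype with the eltypes of the layer parameters,
# so the output eltype is one concrete type known at compile time.
output_eltype(x::AbstractArray, ps::AbstractArray...) =
    promote_type(eltype(x), map(eltype, ps)...)

x = rand(Float32, 28, 28, 64, 4)
γ = ones(Float32, 64)   # per-channel scale
β = zeros(Float32, 64)  # per-channel shift

O = output_eltype(x, γ, β)   # Float32
y = similar(x, O)            # same container base type, concrete eltype O
```

Because `O` depends only on types, not values, an output constructed this way lets `@inferred` succeed on the forward pass.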

Benchmark:

```julia
using Test
using Flux
using BenchmarkTools

function main()
    x = rand(Float32, 28, 28, 64, 4)
    bn = BatchNorm(64)
    θ = Flux.params(bn)

    @info "forward"
    @time bn(x)
    @btime $bn($x)

    @info "grad"
    @time gradient(θ) do
        sum(bn(x))
    end
    @btime gradient($θ) do
        sum($bn($x))
    end

    @inferred bn(x) # master fails here
end
```

Before:

```
[ Info: forward
  0.254515 seconds (856.55 k allocations: 43.045 MiB, 98.02% compilation time)
  911.059 μs (29 allocations: 785.33 KiB)
[ Info: grad
  4.050493 seconds (12.31 M allocations: 635.845 MiB, 3.65% gc time, 99.81% compilation time)
  2.838 ms (804 allocations: 13.05 MiB)
```

After:

```
[ Info: forward
  0.000917 seconds (15 allocations: 1.532 MiB)
  455.609 μs (15 allocations: 1.53 MiB)
[ Info: grad
  2.245501 seconds (7.16 M allocations: 372.776 MiB, 3.75% gc time, 99.17% compilation time)
  2.337 ms (694 allocations: 12.28 MiB)
```

PR Checklist

  • Tests are added
  • Entry in NEWS.md
  • Documentation, if applicable
  • API changes require approval from a committer (different from the author, if applicable)

@codecov-commenter commented
Codecov Report

Merging #1856 (c85ec1e) into master (8d3b8d3) will increase coverage by 0.20%.
The diff coverage is 94.44%.

```
@@            Coverage Diff             @@
##           master    #1856      +/-   ##
==========================================
+ Coverage   73.85%   74.06%   +0.20%
==========================================
  Files          28       28
  Lines        1683     1704      +21
==========================================
+ Hits         1243     1262      +19
- Misses        440      442       +2
```

| Impacted Files | Coverage Δ |
| --- | --- |
| src/layers/normalise.jl | 83.33% <94.28%> (+0.94%) ⬆️ |
| src/layers/conv.jl | 79.66% <100.00%> (ø) |

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 8d3b8d3...c85ec1e.

(Resolved review comment on src/layers/normalise.jl, now outdated)
@pxl-th pxl-th requested a review from ToucheSir February 3, 2022 16:15
@ToucheSir (Member) left a comment

How do the times look now? More or less the same?

@pxl-th (Member, Author) commented Feb 3, 2022

> How do the times look now? More or less the same?

Yes, the times are the same.

```
[ Info: forward
  0.000813 seconds (15 allocations: 1.532 MiB)
  461.305 μs (15 allocations: 1.53 MiB)
[ Info: grad
  1.970701 seconds (7.04 M allocations: 366.512 MiB, 4.10% gc time, 99.81% compilation time)
  2.209 ms (634 allocations: 12.28 MiB)
```

@ToucheSir (Member) commented Feb 3, 2022

Magnificent, thanks!

P.S.

> @ToucheSir has also suggested that the functional part should be moved into NNlib.jl, so I'm not sure whether this PR should be accepted in the first place.

For any prospective contributors reading this comment, here's a "good first issue" ;)

@ToucheSir ToucheSir merged commit 5244ade into FluxML:master Feb 3, 2022