Update GNNChain #202

Merged
merged 7 commits into from
Jul 29, 2022

Member

@CarloLucibello CarloLucibello commented Jul 22, 2022

Keeps GNNChain in sync with the changes made to the Flux.Chain implementation over the last few months (FluxML/Flux.jl#1809).

TODO:

  • improve GNNChain docstring
  • benchmarks

@CarloLucibello CarloLucibello changed the title Update a GNNChain Update GNNChain Jul 22, 2022
Comment on lines +108 to +115
## TODO see if this is faster for small chains
## see https://github.com/FluxML/Flux.jl/pull/1809#discussion_r781691180
# @generated function _applychain(layers::Tuple{Vararg{<:Any,N}}, g::GNNGraph, x) where {N}
# symbols = vcat(:x, [gensym() for _ in 1:N])
# calls = [:($(symbols[i+1]) = _applylayer(layers[$i], $(symbols[i]))) for i in 1:N]
# Expr(:block, calls...)
# end
# _applychain(layers::NamedTuple, g, x) = _applychain(Tuple(layers), x)
Member Author

Note to self: remember to benchmark this before merging.
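
For reference, the unrolling idea in the commented-out draft can be tried in isolation. The sketch below is only an illustration under stated assumptions: `ToyGraph`, `ToyGraphLayer`, `applychain_loop` and `applychain_generated` are hypothetical stand-ins, not part of GraphNeuralNetworks.jl. It also forwards `g` to every `_applylayer` call and to the `NamedTuple` method, which the draft as quoted does not do, on the assumption that this is the intended behavior.

```julia
# Hypothetical stand-ins so the sketch runs without GraphNeuralNetworks.jl.
struct ToyGraph              # plays the role of GNNGraph
    nnodes::Int
end

struct ToyGraphLayer         # a layer that needs the graph as well as the features
    scale::Float64
end
(l::ToyGraphLayer)(g::ToyGraph, x) = l.scale .* x ./ g.nnodes

# Graph-aware layers receive (g, x); any other callable only sees x.
_applylayer(l::ToyGraphLayer, g, x) = l(g, x)
_applylayer(l, g, x) = l(x)

# Loop-based reference implementation.
function applychain_loop(layers::Tuple, g, x)
    for l in layers
        x = _applylayer(l, g, x)
    end
    return x
end

# Unrolled variant: the generator emits one assignment per layer, so the
# compiled method contains N explicit calls instead of a loop over the tuple.
@generated function applychain_generated(layers::Tuple{Vararg{Any,N}}, g, x) where {N}
    symbols = vcat(:x, [gensym() for _ in 1:N])
    calls = [:($(symbols[i+1]) = _applylayer(layers[$i], g, $(symbols[i]))) for i in 1:N]
    Expr(:block, calls...)
end

# Named layers are converted to a plain tuple; g is passed along unchanged.
applychain_generated(layers::NamedTuple, g, x) = applychain_generated(Tuple(layers), g, x)

# Small usage example: both variants should agree.
layers = (ToyGraphLayer(2.0), x -> x .+ 1, ToyGraphLayer(3.0))
g = ToyGraph(2)
x = [1.0, 2.0]
y_loop = applychain_loop(layers, g, x)
y_gen = applychain_generated(layers, g, x)
y_named = applychain_generated((a = layers[1], b = layers[2], c = layers[3]), g, x)
println(y_loop == y_gen == y_named)
```

Whether the unrolled form is actually faster depends on the chain length and on the AD backend, which is what the benchmark is meant to check.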

Member Author

Benchmarked on a small graph and a small network:

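        using Flux, Graphs, GraphNeuralNetworks

        n, deg = 10, 4            # number of nodes, degree of each node
        din, d, dout = 3, 4, 2    # input, hidden and output feature sizes

        # GRAPH_T is defined outside this snippet; graph_type accepts :coo, :sparse or :dense
        g = GNNGraph(random_regular_graph(n, deg),
                    graph_type=GRAPH_T,
                    ndata=randn(Float32, din, n))
        x = g.ndata.x

        gnn = GNNChain(GCNConv(din => d),
                       BatchNorm(d),
                       x -> tanh.(x),
                       GraphConv(d => d, tanh),
                       Dropout(0.5),
                       Dense(d, dout))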

The generated _applychain is faster (roughly 9% on the forward pass and 12% on the gradient in the timings below), but the gain is not large enough to make the change worthwhile.

julia> using BenchmarkTools

### without @generated _applychain

julia> @btime gnn(g, x)
  7.469 μs (84 allocations: 12.48 KiB)
2×10 Matrix{Float32}:
 -0.8186    -0.570312  -0.777638    -0.641642  -0.684857  -0.975505
  0.305567   0.559996   0.631279      0.4687     0.479899   0.321139

julia> @btime gradient(x -> sum(gnn(g, x)), x)
  515.917 μs (2422 allocations: 160.52 KiB)
(Float32[0.3974119 -0.5917164  -0.9200875 1.1957061; -0.54502636 -1.5056851  -2.6915464 2.5114572; -0.97105116 0.7726713  1.0995824 -1.5013595],)

### with @generated _applychain

julia> @btime gnn(g, x)
  6.825 μs (73 allocations: 11.55 KiB)
2×10 Matrix{Float32}:
 -0.8186    -0.570312  -0.777638    -0.641642  -0.684857  -0.975505
  0.305567   0.559996   0.631279      0.4687     0.479899   0.321139

julia> @btime gradient(x -> sum(gnn(g, x)), x)
  454.750 μs (2157 allocations: 161.00 KiB)
(Float32[-0.564121 0.3105453  0.19531891 -0.22819248; -0.6428803 0.13550264  0.9421329 -0.79201597; 0.7816532 -0.4734739  0.23667078 0.033573348],)

In both cases the gradient is very slow; this should be investigated further.
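
To put the numbers above in proportion, the relative gains and the gradient-to-forward cost ratios can be computed directly from the reported timings. This is only a back-of-the-envelope sketch: the timing values are copied from the `@btime` output in this comment, while `relative_gain` and the variable names are introduced here for illustration.

```julia
# Timings in microseconds, copied from the @btime output above.
timings = (
    loop      = (forward = 7.469, gradient = 515.917),   # without @generated
    generated = (forward = 6.825, gradient = 454.750),   # with @generated
)

# Percentage reduction in run time when going from `old` to `new`.
relative_gain(old, new) = 100 * (old - new) / old

fwd_gain  = relative_gain(timings.loop.forward, timings.generated.forward)
grad_gain = relative_gain(timings.loop.gradient, timings.generated.gradient)

# How many forward passes one gradient call costs in each variant.
ratio_loop = timings.loop.gradient / timings.loop.forward
ratio_gen  = timings.generated.gradient / timings.generated.forward

println("forward gain:  ", round(fwd_gain; digits = 1), " %")
println("gradient gain: ", round(grad_gain; digits = 1), " %")
println("gradient/forward ratio (loop):      ", round(ratio_loop; digits = 1))
println("gradient/forward ratio (generated): ", round(ratio_gen; digits = 1))
```

With either implementation a single gradient call costs well over sixty forward passes, which is the gap flagged for further investigation.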

codecov bot commented Jul 22, 2022

Codecov Report

Merging #202 (1558ffa) into master (93b6fa2) will increase coverage by 0.90%.
The diff coverage is 89.65%.

@@            Coverage Diff             @@
##           master     #202      +/-   ##
==========================================
+ Coverage   86.44%   87.35%   +0.90%     
==========================================
  Files          15       15              
  Lines        1365     1368       +3     
==========================================
+ Hits         1180     1195      +15     
+ Misses        185      173      -12     
| Impacted Files | Coverage Δ | |
|---|---|---|
| src/layers/basic.jl | 80.00% <89.65%> (+30.00%) | ⬆️ |

Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@CarloLucibello CarloLucibello merged commit 651c216 into master Jul 29, 2022