Fix for quadratic batching in #99 (#100)

Closed · wants to merge 5 commits

Conversation

@tclements (Contributor) commented on Jan 12, 2022:

This PR greatly reduces the time and memory used when batching large numbers of graphs, as described in #99. The key is using the optimized reduce(hcat, ...) and reduce(vcat, ...) to concatenate a large number of N-d arrays in a single pass instead of pairwise. Timings and memory usage:
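
For context, a minimal sketch (not from the PR) of why this matters: Base has a specialized reduce method for hcat/vcat over a vector of arrays that allocates the output once, whereas folding hcat pairwise re-copies the growing accumulator at every step, which is quadratic in the number of arrays:

using BenchmarkTools

xs = [ones(8, 4) for _ in 1:1024]
fast = reduce(hcat, xs)                 # optimized method: one output allocation
slow = foldl((a, b) -> hcat(a, b), xs)  # pairwise: copies the accumulator 1023 times
fast == slow                            # true; only the cost differs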

Old version:

using BenchmarkTools
using GraphNeuralNetworks

for ngraphs in 2 .^ (10:12)
    gs = [rand_graph(4, 6, ndata=ones(8, 4)) for _ in 1:ngraphs]
    println("\n=======================\nBatchsize = $ngraphs graphs\n=======================\n")
    b = @benchmark GraphNeuralNetworks.blockdiag($gs...)
    display(b)
end

=======================
Batchsize = 1024 graphs
=======================

BenchmarkTools.Trial: 76 samples with 1 evaluation.
 Range (min … max):  51.236 ms … 117.125 ms  ┊ GC (min … max): 16.32% … 24.58%
 Time  (median):     63.494 ms               ┊ GC (median):    21.49%
 Time  (mean ± σ):   66.366 ms ±  10.085 ms  ┊ GC (mean ± σ):  20.39% ±  3.10%

        ▂▂ ▅▅ █ ▅█ █ █▅     ▂        ▂
  ▅▁▁▁▁▁██▅███████▅█▁████▅▁▁█▅▅▅▅█▅▁▁█▁█▁▁█▁▅▁▅▁▁▁▁▁▁▁▁▅▁▁▁▁▁▅ ▁
  51.2 ms         Histogram: frequency by time         93.7 ms <

 Memory estimate: 196.20 MiB, allocs estimate: 73941.

=======================
Batchsize = 2048 graphs
=======================

BenchmarkTools.Trial: 20 samples with 1 evaluation.
 Range (min … max):  225.263 ms … 300.921 ms  ┊ GC (min … max): 18.52% … 17.75%
 Time  (median):     247.739 ms               ┊ GC (median):    17.96%
 Time  (mean ± σ):   256.413 ms ±  23.805 ms  ┊ GC (mean ± σ):  17.74% ±  0.95%

  ▁    █▁   ▁▁ ▁▁ ▁ █▁   ▁        ▁       ▁       ▁ ▁▁   ▁    ▁
  █▁▁▁▁██▁▁▁██▁██▁█▁██▁▁▁█▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁█▁▁▁▁▁▁▁█▁██▁▁▁█▁▁▁▁█ ▁
  225 ms           Histogram: frequency by time          301 ms <

 Memory estimate: 776.31 MiB, allocs estimate: 149717.

=======================
Batchsize = 4096 graphs
=======================

BenchmarkTools.Trial: 5 samples with 1 evaluation.
 Range (min … max):  922.417 ms … 1.126 s  ┊ GC (min … max): 16.82% … 18.22%
 Time  (median):     976.010 ms              ┊ GC (median):    16.93%
 Time  (mean ± σ):      1.009 s ± 80.602 ms  ┊ GC (mean ± σ):  17.26% ±  0.59%

  █            █ █                      █                    █
  █▁▁▁▁▁▁▁▁▁▁▁▁█▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█ ▁
  922 ms          Histogram: frequency by time          1.13 s <

 Memory estimate: 3.02 GiB, allocs estimate: 301269.

New version:

using BenchmarkTools
using GraphNeuralNetworks

for ngraphs in 2 .^ (10:12)
    gs = [rand_graph(4, 6, ndata=ones(8, 4)) for _ in 1:ngraphs]
    println("\n=======================\nBatchsize = $ngraphs graphs\n=======================\n")
    b = @benchmark GraphNeuralNetworks.batch($gs)
    display(b)
end

=======================
Batchsize = 1024 graphs
=======================

BenchmarkTools.Trial: 7919 samples with 1 evaluation.
 Range (min … max):  359.600 μs … 15.122 ms  ┊ GC (min … max):  0.00% … 88.54%
 Time  (median):     483.500 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   627.847 μs ± 906.415 μs  ┊ GC (mean ± σ):  17.96% ± 11.83%

  ██▅▂▂▁▁                                                       ▂
  █████████▇▇▅▆▅▅▅▃▄▁▄▁▁▃▁▁▃▁▁▁▁▁▁▁▁▁▁▁▁▃▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▆▅▇▆▇ █
  360 μs        Histogram: log(frequency) by time        6.7 ms <

 Memory estimate: 1.38 MiB, allocs estimate: 13389.

=======================
Batchsize = 2048 graphs
=======================

BenchmarkTools.Trial: 3931 samples with 1 evaluation.
 Range (min … max):  727.700 μs … 13.003 ms  ┊ GC (min … max):  0.00% … 82.01%
 Time  (median):       1.033 ms              ┊ GC (median):     0.00%
 Time  (mean ± σ):     1.268 ms ±  1.190 ms  ┊ GC (mean ± σ):  16.59% ± 15.12%

  ▅█▇█▆▃▁▁▁                                                    ▁
  ██████████▇▆▃▅▅▃▄▁▁▃▃▃▃▁▃▁▁▁▁▁▁▁▃▁▁▃▁▃▁▁▁▁▁▁▁▁▁▁▁▁▃▃▆▆▆▇█▇█▇ █
  728 μs        Histogram: log(frequency) by time      7.43 ms <

 Memory estimate: 2.75 MiB, allocs estimate: 26705.

=======================
Batchsize = 4096 graphs
=======================

BenchmarkTools.Trial: 1308 samples with 1 evaluation.
 Range (min … max):  1.877 ms … 36.357 ms  ┊ GC (min … max):  0.00% … 87.78%
 Time  (median):     2.747 ms              ┊ GC (median):     0.00%
 Time  (mean ± σ):   3.809 ms ±  2.920 ms  ┊ GC (mean ± σ):  15.78% ± 18.05%

  ▂▃▆█▇▆▄▂▂▂ ▂▂▂▁▁
  ████████████████▇█▅▆▇▆▆▇▆▅▅▅▆▅▅▁▅▅▅▄▅▆▁▆▆▄▇▆█▇▇▆▅▅▇▆▅▄▅▅▅▄ █
  1.88 ms      Histogram: log(frequency) by time     13.3 ms <

 Memory estimate: 5.50 MiB, allocs estimate: 53343.
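
For reference, collecting the medians above: the new version is two to three orders of magnitude faster and allocates far less memory.

Batch size   Old median    Old memory    New median   New memory
1024         63.494 ms     196.20 MiB    483.500 μs   1.38 MiB
2048         247.739 ms    776.31 MiB    1.033 ms     2.75 MiB
4096         976.010 ms    3.02 GiB      2.747 ms     5.50 MiB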

@tclements (Contributor, Author) commented:

Just tested: my fix will fail on 3-d arrays because it uses hcat. Coming up with a solution for that now.
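
One way to see the problem (a sketch, assuming node features are stored with nodes along the last dimension): for arrays with more than two dimensions, hcat falls back to cat(...; dims=2), so the concatenation lands on the wrong axis:

a, b = ones(3, 4, 2), ones(3, 4, 2)
size(hcat(a, b))         # (3, 8, 2): joined along dims=2, not the node dimension
size(cat(a, b; dims=3))  # (3, 4, 4): what batching node features actually needs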

@tclements (Contributor, Author) commented:

Fast N-d array concatenation (for N > 2) seems to be a common problem that still needs to be solved in Base (JuliaLang/julia#21672). We might need to table this PR until that is implemented, or use an explicit solution for 3-d arrays.
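
One explicit workaround in the meantime (a sketch, not the code that landed in this PR): preallocate the output once and copy each array into its slice along the last dimension, avoiding the pairwise copies entirely:

function cat3(xs::AbstractVector{<:AbstractArray{T,3}}) where {T}
    # assumes every array shares the first two dimensions
    d1, d2, _ = size(first(xs))
    out = similar(first(xs), d1, d2, sum(size(x, 3) for x in xs))
    i = 1
    for x in xs
        k = size(x, 3)
        out[:, :, i:i+k-1] = x   # copy into the preallocated slice
        i += k
    end
    return out
end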

@CarloLucibello (Member) commented:

I think we can introduce something like this:

cat_features(xs::AbstractVector{<:AbstractVector}) = reduce(vcat, xs)
cat_features(xs::AbstractVector{<:AbstractMatrix}) = reduce(hcat, xs)
cat_features(xs::AbstractVector{<:AbstractArray{T,N}}) where {T,N} = reduce((x1, x2) -> cat(x1, x2, dims=N), xs)
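
Hypothetical usage of the helper sketched above (note the N-d branch still reduces pairwise, so it remains quadratic for N > 2 until Base grows an optimized method):

cat_features([rand(4) for _ in 1:3])        # length-12 Vector via reduce(vcat, ...)
cat_features([rand(8, 4) for _ in 1:3])     # 8×12 Matrix via reduce(hcat, ...)
cat_features([rand(2, 3, 4) for _ in 1:3])  # 2×3×12 Array via cat(...; dims=3)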

Review threads on src/GNNGraphs/transform.jl and src/GNNGraphs/utils.jl (outdated, now resolved). Quoted context from src/GNNGraphs/transform.jl:
graph = (s, t, w)
graph_indicator = vcat([ones_like(ei[1],Int,nodes[ii]) .+ (ii - 1) for (ii,ei) in enumerate(edge_indices)]...)
elseif all(y -> isa(y, ADJMAT_T), [g.graph for g in gs])
graph = blockdiag([g.graph for g in gs]...)
@CarloLucibello (Member) commented on Jan 13, 2022:

This blockdiag method does not exist for dense matrices. Maybe we can use dispatch as follows:

Flux.batch(gs::Vector{<:GNNGraph{<:COO_T}}) = ...     # specialized method
Flux.batch(gs::Vector{<:GNNGraph{<:SPARSE_T}}) = ...  # specialized method
Flux.batch(gs::Vector{<:GNNGraph}) = blockdiag(gs...) # old slow fallback

@tclements (Contributor, Author) replied:

Thanks for the comments! I'll work on the dispatching.

tclements and others added 4 commits on January 13, 2022, each co-authored by Carlo Lucibello <carlo.lucibello@gmail.com>.
@CarloLucibello mentioned this pull request on Jan 29, 2022.
@CarloLucibello (Member) commented:

Thanks @tclements, I finished this up and merged it in #122.
