Fix for quadratic batching in #99 #100
Just tested: my fix will fail on 3D arrays because it uses hcat. Coming up with a solution for that now.
Fast N-d array concatenation (for N > 2) seems to be a common problem that needs to be solved in Base (JuliaLang/julia#21672). We might need to table this PR until that is implemented, or use an explicit solution for 3D arrays.
I think we can introduce something like this:
cat_features(xs::AbstractVector{<:AbstractVector}) = reduce(vcat, xs)
cat_features(xs::AbstractVector{<:AbstractMatrix}) = reduce(hcat, xs)
cat_features(xs::AbstractVector{<:AbstractArray{T,N}}) where {T,N} = reduce((x1, x2) -> cat(x1, x2, dims=N), xs)
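For the N > 2 case, one possible workaround (a sketch under stated assumptions, not what this PR implements) relies on Julia's column-major layout: concatenating along the last dimension gives the same result as flattening the leading dimensions, calling reduce(hcat, ...), and reshaping back. The name cat_lastdim is hypothetical, and the sketch assumes all arrays share the same leading dimensions.

```julia
# Hypothetical helper (not part of the package): concatenate along the last
# dimension without falling back to pairwise `cat` calls.
function cat_lastdim(xs::AbstractVector{<:AbstractArray})
    N = ndims(first(xs))
    N == 1 && return reduce(vcat, xs)   # vectors: stack end to end
    N == 2 && return reduce(hcat, xs)   # matrices: stack columns
    # N > 2: view each array as a matrix whose columns are the last-dim slices.
    lead = size(first(xs))[1:N-1]
    mats = [reshape(x, :, size(x, N)) for x in xs]
    # reduce(hcat, ...) on a vector of matrices uses the optimized method;
    # reshaping restores the leading dimensions.
    return reshape(reduce(hcat, mats), lead..., :)
end

xs = [rand(2, 3, k) for k in 1:4]
@assert cat_lastdim(xs) == cat(xs...; dims=3)
```

The reshape step does not copy for plain Arrays, so the only large allocation should be the final concatenated buffer.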
graph = (s, t, w)
graph_indicator = vcat([ones_like(ei[1],Int,nodes[ii]) .+ (ii - 1) for (ii,ei) in enumerate(edge_indices)]...)
elseif all(y -> isa(y, ADJMAT_T), [g.graph for g in gs] )
graph = blockdiag([g.graph for g in gs]...)
This blockdiag method does not exist for dense matrices.
Maybe we can use dispatch as follows:
Flux.batch(gs::Vector{<:GNNGraph{<:COO_T}}) = ... # specialized method
Flux.batch(gs::Vector{<:GNNGraph{<:SPARSE_T}}) = ... # specialized method
Flux.batch(gs::Vector{<:GNNGraph}) = blockdiag(gs...) # old slow fallback
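To make the dispatch idea concrete, here is a minimal, self-contained sketch. ToyGraph and batch_adj are hypothetical stand-ins (not GNNGraph or Flux.batch), and the dense method is just one way a block-diagonal matrix could be assembled, since SparseArrays only provides blockdiag for sparse matrices.

```julia
using SparseArrays

# Hypothetical stand-in for a graph type parameterized by its storage.
struct ToyGraph{T}
    graph::T
end

# Sparse adjacency matrices: SparseArrays already provides `blockdiag`.
batch_adj(gs::Vector{<:ToyGraph{<:SparseMatrixCSC}}) =
    blockdiag((g.graph for g in gs)...)

# Dense adjacency matrices: fill the diagonal blocks by hand.
function batch_adj(gs::Vector{<:ToyGraph{<:Matrix}})
    n = sum(size(g.graph, 1) for g in gs)
    A = zeros(eltype(first(gs).graph), n, n)
    off = 0
    for g in gs
        k = size(g.graph, 1)
        A[off+1:off+k, off+1:off+k] .= g.graph
        off += k
    end
    return A
end

sparse_gs = [ToyGraph(sparse([0.0 1.0; 1.0 0.0])), ToyGraph(sparse(fill(2.0, 1, 1)))]
dense_gs = [ToyGraph(Matrix(g.graph)) for g in sparse_gs]
@assert Matrix(batch_adj(sparse_gs)) == batch_adj(dense_gs)
```

Dispatching on the element type of the vector keeps each storage format on its own code path, so the dense case never reaches a method that is undefined for it.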
Thanks for the comments! I'll work on the dispatching
Co-authored-by: Carlo Lucibello <carlo.lucibello@gmail.com>
Thanks @tclements, I finished up and merged this in #122.
This PR greatly reduces the time and memory used for batching large numbers of graphs, as described in #99. The key here was using the optimized reduce(hcat, ...) and reduce(vcat, ...) to concatenate a large number of N-d arrays. Timings and memory usage:
Old version:
New version: