From 3e2d3cf3bb949bee5911bf413e90d15b4198ca18 Mon Sep 17 00:00:00 2001 From: Tim Clements Date: Wed, 12 Jan 2022 11:52:43 -0800 Subject: [PATCH 1/5] Fix for quadratic batching in #99 --- src/GNNGraphs/transform.jl | 30 ++++++++++++++++++++++++++++-- src/GNNGraphs/utils.jl | 14 ++++++++++++++ test/GNNGraphs/transform.jl | 1 + 3 files changed, 43 insertions(+), 2 deletions(-) diff --git a/src/GNNGraphs/transform.jl b/src/GNNGraphs/transform.jl index fe802a752..5ff8c0cf1 100644 --- a/src/GNNGraphs/transform.jl +++ b/src/GNNGraphs/transform.jl @@ -206,6 +206,7 @@ function SparseArrays.blockdiag(g1::GNNGraph, gothers::GNNGraph...) end return g end +SparseArrays.blockdiag(gs::Vector{GNNGraph}) = SparseArrays.blockdiag(gs...) """ batch(gs::Vector{<:GNNGraph}) @@ -253,8 +254,33 @@ julia> g12.ndata.x 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ``` """ -Flux.batch(gs::Vector{<:GNNGraph}) = blockdiag(gs...) - +function Flux.batch(gs::Vector{<:GNNGraph}) + nodes = [g.num_nodes for g in gs] + + if all(y -> isa(y, COO_T), [g.graph for g in gs] ) + edge_indices = [edge_index(g) for g in gs] + nodesum = cumsum([0, nodes...])[1:end-1] + s = reduce(vcat, [ei[1] .+ nodesum[ii] for (ii,ei) in enumerate(edge_indices)]) + t = reduce(vcat, [ei[2] .+ nodesum[ii] for (ii,ei) in enumerate(edge_indices)]) + w = reduce(vcat, [get_edge_weight(g) for g in gs]) + w = w isa Vector{Nothing} ? nothing : w + graph = (s, t, w) + graph_indicator = vcat([ones_like(ei[1],Int,nodes[ii]) .+ (ii - 1) for (ii,ei) in enumerate(edge_indices)]...) + elseif all(y -> isa(y, ADJMAT_T), [g.graph for g in gs] ) + graph = blockdiag([g.graph for g in gs]...) + graph_indicator = vcat([ones_like(graph,Int,nodes[ii]) .+ (ii - 1) for ii in 1:length(nodes)]...) + end + + GNNGraph(graph, + sum(nodes), + sum([g.num_edges for g in gs]), + sum([g.num_graphs for g in gs]), + graph_indicator, + cat_features([g.ndata for g in gs]), + cat_features([g.edata for g in gs]), + cat_features([g.gdata for g in gs]), + ) +end """ unbatch(g::GNNGraph) diff --git a/src/GNNGraphs/utils.jl b/src/GNNGraphs/utils.jl index 159227466..19aa0e9dd 100644 --- a/src/GNNGraphs/utils.jl +++ b/src/GNNGraphs/utils.jl @@ -24,6 +24,20 @@ function cat_features(x1::NamedTuple, x2::NamedTuple) NamedTuple(k => cat_features(getfield(x1,k), getfield(x2,k)) for k in keys(x1)) end +function cat_features(xs::Vector{NamedTuple{T1, T2}}) where {T1, T2} + symbols = [sort(collect(keys(x))) for x in xs] + all(y->y==symbols[1], symbols) || @error "cannot concatenate feature data with different keys" + length(xs) == 1 && return xs[1] + + # concatenate + syms = symbols[1] + dims = [max(1, ndims(xs[1][k])) for k in syms] # promote scalar to 1D + methods = [dim == 1 ? vcat : hcat for dim in dims] # use optimized reduce(hcat,xs) or reduce(vcat,xs) + NamedTuple( + k => reduce(methods[ii],[x[k] for x in xs]) for (ii,k) in enumerate(syms) + ) +end + # Turns generic type into named tuple normalize_graphdata(data::Nothing; kws...) = NamedTuple() diff --git a/test/GNNGraphs/transform.jl b/test/GNNGraphs/transform.jl index 75c062eee..6b24ad466 100644 --- a/test/GNNGraphs/transform.jl +++ b/test/GNNGraphs/transform.jl @@ -25,6 +25,7 @@ g12 = Flux.batch([g1, g2]) g12b = blockdiag(g1, g2) + @test g12 == g12b g123 = Flux.batch([g1, g2, g3]) @test g123.graph_indicator == [fill(1, 10); fill(2, 4); fill(3, 7)] From d307c2624fa8ec5fe9b819d156aadc235c84dbf5 Mon Sep 17 00:00:00 2001 From: Tim Clements Date: Thu, 13 Jan 2022 12:34:51 -0800 Subject: [PATCH 2/5] Update src/GNNGraphs/utils.jl Co-authored-by: Carlo Lucibello --- src/GNNGraphs/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/GNNGraphs/utils.jl b/src/GNNGraphs/utils.jl index 19aa0e9dd..892445229 100644 --- a/src/GNNGraphs/utils.jl +++ b/src/GNNGraphs/utils.jl @@ -24,7 +24,7 @@ function cat_features(x1::NamedTuple, x2::NamedTuple) NamedTuple(k => cat_features(getfield(x1,k), getfield(x2,k)) for k in keys(x1)) end -function cat_features(xs::Vector{NamedTuple{T1, T2}}) where {T1, T2} +function cat_features(xs::Vector{<:NamedTuple}) symbols = [sort(collect(keys(x))) for x in xs] all(y->y==symbols[1], symbols) || @error "cannot concatenate feature data with different keys" length(xs) == 1 && return xs[1] From a304a233aabc1f641708f3363908f5903c3b0597 Mon Sep 17 00:00:00 2001 From: Tim Clements Date: Thu, 13 Jan 2022 12:35:06 -0800 Subject: [PATCH 3/5] Update src/GNNGraphs/transform.jl Co-authored-by: Carlo Lucibello --- src/GNNGraphs/transform.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/GNNGraphs/transform.jl b/src/GNNGraphs/transform.jl index 5ff8c0cf1..5fab64869 100644 --- a/src/GNNGraphs/transform.jl +++ b/src/GNNGraphs/transform.jl @@ -206,7 +206,6 @@ function SparseArrays.blockdiag(g1::GNNGraph, gothers::GNNGraph...) end return g end -SparseArrays.blockdiag(gs::Vector{GNNGraph}) = SparseArrays.blockdiag(gs...) """ batch(gs::Vector{<:GNNGraph}) From d4e068833d5eed4e8902988aa525480106413d8a Mon Sep 17 00:00:00 2001 From: Tim Clements Date: Thu, 13 Jan 2022 12:35:38 -0800 Subject: [PATCH 4/5] Update src/GNNGraphs/transform.jl Co-authored-by: Carlo Lucibello --- src/GNNGraphs/transform.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/GNNGraphs/transform.jl b/src/GNNGraphs/transform.jl index 5fab64869..5a1068243 100644 --- a/src/GNNGraphs/transform.jl +++ b/src/GNNGraphs/transform.jl @@ -259,8 +259,8 @@ function Flux.batch(gs::Vector{<:GNNGraph}) if all(y -> isa(y, COO_T), [g.graph for g in gs] ) edge_indices = [edge_index(g) for g in gs] nodesum = cumsum([0, nodes...])[1:end-1] - s = reduce(vcat, [ei[1] .+ nodesum[ii] for (ii,ei) in enumerate(edge_indices)]) - t = reduce(vcat, [ei[2] .+ nodesum[ii] for (ii,ei) in enumerate(edge_indices)]) + s = cat_features([ei[1] .+ nodesum[ii] for (ii, ei) in enumerate(edge_indices)]) + t = cat_features([ei[2] .+ nodesum[ii] for (ii, ei) in enumerate(edge_indices)]) w = reduce(vcat, [get_edge_weight(g) for g in gs]) w = w isa Vector{Nothing} ? nothing : w graph = (s, t, w) From 3dbf70adfa4836647d98d6d884f8cc7c4bb46339 Mon Sep 17 00:00:00 2001 From: Tim Clements Date: Thu, 13 Jan 2022 12:35:49 -0800 Subject: [PATCH 5/5] Update src/GNNGraphs/utils.jl Co-authored-by: Carlo Lucibello --- src/GNNGraphs/utils.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/GNNGraphs/utils.jl b/src/GNNGraphs/utils.jl index 892445229..60d6e819d 100644 --- a/src/GNNGraphs/utils.jl +++ b/src/GNNGraphs/utils.jl @@ -31,10 +31,8 @@ function cat_features(xs::Vector{<:NamedTuple}) # concatenate syms = symbols[1] - dims = [max(1, ndims(xs[1][k])) for k in syms] # promote scalar to 1D - methods = [dim == 1 ? vcat : hcat for dim in dims] # use optimized reduce(hcat,xs) or reduce(vcat,xs) NamedTuple( - k => reduce(methods[ii],[x[k] for x in xs]) for (ii,k) in enumerate(syms) + k => cat_features([x[k] for x in xs]) for (ii,k) in enumerate(syms) ) end