From e2b85505fbbd41d15c8827ffc182bfdc1a839f8c Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Fri, 22 Jul 2022 17:28:13 +0200 Subject: [PATCH 1/7] update GNNChain --- docs/src/index.md | 2 +- docs/src/tutorials/gnn_intro_pluto.jl | 4 +- .../tutorials/graph_classification_pluto.jl | 2 +- examples/graph_classification_tudataset.jl | 2 +- examples/link_prediction_pubmed.jl | 2 +- examples/neural_ode_cora.jl | 2 +- examples/node_classification_cora.jl | 2 +- perf/neural_ode_mnist.jl | 2 +- .../node_classification_cora_geometricflux.jl | 2 +- src/layers/basic.jl | 97 ++++++++++--------- test.jl | 7 ++ test/examples/node_classification_cora.jl | 2 +- 12 files changed, 68 insertions(+), 58 deletions(-) create mode 100644 test.jl diff --git a/docs/src/index.md b/docs/src/index.md index 3da002eba..5a7d96e71 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -54,7 +54,7 @@ model = GNNChain(GCNConv(16 => 64), Dense(64, 1)) |> device ps = Flux.params(model) -opt = ADAM(1f-4) +opt = Adam(1f-4) ``` ### Training diff --git a/docs/src/tutorials/gnn_intro_pluto.jl b/docs/src/tutorials/gnn_intro_pluto.jl index ce81d1f78..6b903deac 100644 --- a/docs/src/tutorials/gnn_intro_pluto.jl +++ b/docs/src/tutorials/gnn_intro_pluto.jl @@ -266,7 +266,7 @@ Since everything in our model is differentiable and parameterized, we can add so Here, we make use of a semi-supervised or transductive learning procedure: We simply train against one node per class, but are allowed to make use of the complete input graph data. Training our model is very similar to any other Flux model. -In addition to defining our network architecture, we define a loss criterion (here, `logitcrossentropy` and initialize a stochastic gradient optimizer (here, `ADAM`). +In addition to defining our network architecture, we define a loss criterion (here, `logitcrossentropy` and initialize a stochastic gradient optimizer (here, `Adam`). After that, we perform multiple rounds of optimization, where each round consists of a forward and backward pass to compute the gradients of our model parameters w.r.t. to the loss derived from the forward pass. If you are not new to Flux, this scheme should appear familar to you. @@ -285,7 +285,7 @@ Let us now start training and see how our node embeddings evolve over time (best begin model = GCN(num_features, num_classes) ps = Flux.params(model) - opt = ADAM(1e-2) + opt = Adam(1e-2) epochs = 2000 emb = h diff --git a/docs/src/tutorials/graph_classification_pluto.jl b/docs/src/tutorials/graph_classification_pluto.jl index a54a19372..ed80e2810 100644 --- a/docs/src/tutorials/graph_classification_pluto.jl +++ b/docs/src/tutorials/graph_classification_pluto.jl @@ -202,7 +202,7 @@ function train!(model; epochs=200, η=1e-2, infotime=10) device = Flux.cpu model = model |> device ps = Flux.params(model) - opt = ADAM(1e-3) + opt = Adam(1e-3) function report(epoch) diff --git a/examples/graph_classification_tudataset.jl b/examples/graph_classification_tudataset.jl index a724165c4..c2c5c68ce 100644 --- a/examples/graph_classification_tudataset.jl +++ b/examples/graph_classification_tudataset.jl @@ -82,7 +82,7 @@ function train(; kws...) Dense(nhidden, 1)) |> device ps = Flux.params(model) - opt = ADAM(args.η) + opt = Adam(args.η) # LOGGING FUNCTION diff --git a/examples/link_prediction_pubmed.jl b/examples/link_prediction_pubmed.jl index d7c5d88ea..5693c651b 100644 --- a/examples/link_prediction_pubmed.jl +++ b/examples/link_prediction_pubmed.jl @@ -77,7 +77,7 @@ function train(; kws...) 
pred = DotPredictor() ps = Flux.params(model) - opt = ADAM(args.η) + opt = Adam(args.η) ### LOSS FUNCTION ############ diff --git a/examples/neural_ode_cora.jl b/examples/neural_ode_cora.jl index f5a340fc6..f90c9e4b9 100644 --- a/examples/neural_ode_cora.jl +++ b/examples/neural_ode_cora.jl @@ -48,7 +48,7 @@ model = GNNChain(GCNConv(nin => nhidden, relu), ps = Flux.params(model); # ## Optimizer -opt = ADAM(0.01) +opt = Adam(0.01) function eval_loss_accuracy(X, y, mask) diff --git a/examples/node_classification_cora.jl b/examples/node_classification_cora.jl index 3746b02ba..185f40a05 100644 --- a/examples/node_classification_cora.jl +++ b/examples/node_classification_cora.jl @@ -57,7 +57,7 @@ function train(; kws...) Dense(nhidden, nout)) |> device ps = Flux.params(model) - opt = ADAM(args.η) + opt = Adam(args.η) display(g) diff --git a/perf/neural_ode_mnist.jl b/perf/neural_ode_mnist.jl index 0ddc8cded..6c41f2420 100644 --- a/perf/neural_ode_mnist.jl +++ b/perf/neural_ode_mnist.jl @@ -40,7 +40,7 @@ model = Chain(Flux.flatten, ps = Flux.params(model); # ## Optimizer -opt = ADAM(0.01) +opt = Adam(0.01) function eval_loss_accuracy(X, y) ŷ = model(X) diff --git a/perf/node_classification_cora_geometricflux.jl b/perf/node_classification_cora_geometricflux.jl index 4188bb579..e3d39e9d5 100644 --- a/perf/node_classification_cora_geometricflux.jl +++ b/perf/node_classification_cora_geometricflux.jl @@ -59,7 +59,7 @@ function train(; kws...) Dense(nhidden, nout)) |> device ps = Flux.params(model) - opt = ADAM(args.η) + opt = Adam(args.η) @info g diff --git a/src/layers/basic.jl b/src/layers/basic.jl index e0b7b9882..1faae2c3a 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -49,20 +49,6 @@ WithGraph(model, g::GNNGraph; traingraph=false) = WithGraph(model, g, traingraph @functor WithGraph Flux.trainable(l::WithGraph) = l.traingraph ? (; l.model, l.g) : (; l.model,) -# Work around -# https://github.com/FluxML/Flux.jl/issues/1733 -# Revisit after -# https://github.com/FluxML/Flux.jl/pull/1742 -function Flux.destructure(m::WithGraph) - @assert m.traingraph == false # TODO - p, re = Flux.destructure(m.model) - function re_withgraph(x) - WithGraph(re(x), m.g, m.traingraph) - end - - return p, re_withgraph -end - (l::WithGraph)(g::GNNGraph, x...; kws...) = l.model(g, x...; kws...) (l::WithGraph)(x...; kws...) = l.model(l.g, x...; kws...) @@ -99,60 +85,77 @@ julia> m(g, x) -0.0134364 -0.0120716 -0.0172505 ``` """ -struct GNNChain{T} <: GNNLayer +struct GNNChain{T<:Union{Tuple, NamedTuple, AbstractVector}} <: GNNLayer layers::T +end - GNNChain(xs...) = new{typeof(xs)}(xs) - - function GNNChain(; kw...) - :layers in Base.keys(kw) && throw(ArgumentError("a GNNChain cannot have a named layer called `layers`")) - isempty(kw) && return new{Tuple{}}(()) - new{typeof(values(kw))}(values(kw)) - end +@functor GNNChain + +GNNChain(xs...) = GNNChain(xs) + +function GNNChain(; kw...) + :layers in Base.keys(kw) && throw(ArgumentError("a GNNChain cannot have a named layer called `layers`")) + isempty(kw) && return GNNChain(()) + GNNChain(values(kw)) end @forward GNNChain.layers Base.getindex, Base.length, Base.first, Base.last, - Base.iterate, Base.lastindex, Base.keys + Base.iterate, Base.lastindex, Base.keys, Base.firstindex -Flux.functor(::Type{<:GNNChain}, c) = c.layers, ls -> GNNChain(ls...) -Flux.functor(::Type{<:GNNChain}, c::Tuple) = c, ls -> GNNChain(ls...) 
+(c::GNNChain)(g::GNNGraph, x) = _applychain(c.layers, g, x) +(c::GNNChain)(g::GNNGraph) = _applychain(c.layers, g) -# input from graph -applylayer(l, g::GNNGraph) = GNNGraph(g, ndata=l(node_features(g))) -applylayer(l::GNNLayer, g::GNNGraph) = l(g) +## TODO see if this is faster for small chains +# @generated function applychain(layers::Tuple{Vararg{<:Any,N}}, g::GNNGraph, x) where {N} +# symbols = vcat(:x, [gensym() for _ in 1:N]) +# calls = [:($(symbols[i+1]) = _applylayer(layers[$i], $(symbols[i]))) for i in 1:N] +# Expr(:block, calls...) +# end -# explicit input -applylayer(l, g::GNNGraph, x) = l(x) -applylayer(l::GNNLayer, g::GNNGraph, x) = l(g, x) -# Handle Flux.Parallel -applylayer(l::Parallel, g::GNNGraph) = GNNGraph(g, ndata=applylayer(l, g, node_features(g))) -applylayer(l::Parallel, g::GNNGraph, x::AbstractArray) = mapreduce(f -> applylayer(f, g, x), l.connection, l.layers) +function _applychain(layers, g::GNNGraph, x) # type-unstable path, helps compile times + for l in layers + x = _applylayer(l, g, x) + end + x +end -# input from graph -applychain(::Tuple{}, g::GNNGraph) = g -applychain(fs::Tuple, g::GNNGraph) = applychain(tail(fs), applylayer(first(fs), g)) +function _applychain(layers, g::GNNGraph) # type-unstable path, helps compile times + for l in layers + g = _applylayer(l, g) + end + g +end -# explicit input -applychain(::Tuple{}, g::GNNGraph, x) = x -applychain(fs::Tuple, g::GNNGraph, x) = applychain(tail(fs), g, applylayer(first(fs), g, x)) +# # explicit input +_applylayer(l, g::GNNGraph, x) = l(x) +_applylayer(l::GNNLayer, g::GNNGraph, x) = l(g, x) -(c::GNNChain)(g::GNNGraph, x) = applychain(Tuple(c.layers), g, x) -(c::GNNChain)(g::GNNGraph) = applychain(Tuple(c.layers), g) +# input from graph +_applylayer(l, g::GNNGraph) = GNNGraph(g, ndata=l(node_features(g))) +_applylayer(l::GNNLayer, g::GNNGraph) = l(g) +# # Handle Flux.Parallel +_applylayer(l::Parallel, g::GNNGraph) = GNNGraph(g, ndata=_applylayer(l, g, node_features(g))) -Base.getindex(c::GNNChain, i::AbstractArray) = GNNChain(c.layers[i]...) -Base.getindex(c::GNNChain{<:NamedTuple}, i::AbstractArray) = - GNNChain(; NamedTuple{Base.keys(c)[i]}(Tuple(c.layers)[i])...) 
+function _applylayer(l::Parallel, g::GNNGraph, x::AbstractArray) + closures = map(f -> (x -> _applylayer(f, g, x)), l.layers) + return Parallel(l.connection, closures)(x) +end + +Base.getindex(c::GNNChain, i::AbstractArray) = GNNChain(c.layers[i]) +Base.getindex(c::GNNChain{<:NamedTuple}, i::AbstractArray) = + GNNChain(NamedTuple{keys(c)[i]}(Tuple(c.layers)[i])) function Base.show(io::IO, c::GNNChain) print(io, "GNNChain(") _show_layers(io, c.layers) print(io, ")") end + _show_layers(io, layers::Tuple) = join(io, layers, ", ") _show_layers(io, layers::NamedTuple) = join(io, ["$k = $v" for (k, v) in pairs(layers)], ", ") - +_show_layers(io, layers::AbstractVector) = (print(io, "["); join(io, layers, ", "); print(io, "]")) """ DotDecoder() @@ -181,5 +184,5 @@ struct DotDecoder <: GNNLayer end function (::DotDecoder)(g, x) check_num_nodes(g, x) - apply_edges(xi_dot_xj, g, xi=x, xj=x) + return apply_edges(xi_dot_xj, g, xi=x, xj=x) end diff --git a/test.jl b/test.jl new file mode 100644 index 000000000..bd946bd23 --- /dev/null +++ b/test.jl @@ -0,0 +1,7 @@ +using GraphNeuralNetworks, Flux + +chain = GNNChain(GraphConv(2=>2)) + +params, restructure = Flux.destructure(chain) + +restructure(params) \ No newline at end of file diff --git a/test/examples/node_classification_cora.jl b/test/examples/node_classification_cora.jl index 7056e3d21..0334c8ef5 100644 --- a/test/examples/node_classification_cora.jl +++ b/test/examples/node_classification_cora.jl @@ -53,7 +53,7 @@ function train(Layer; verbose=false, kws...) Dense(nhidden, nout)) |> device ps = Flux.params(model) - opt = ADAM(args.η) + opt = Adam(args.η) ## TRAINING From dfc2444e14a83c277503e1e223da2a4729c8e750 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Fri, 22 Jul 2022 17:28:30 +0200 Subject: [PATCH 2/7] cleanup --- test.jl | 7 ------- 1 file changed, 7 deletions(-) delete mode 100644 test.jl diff --git a/test.jl b/test.jl deleted file mode 100644 index bd946bd23..000000000 --- a/test.jl +++ /dev/null @@ -1,7 +0,0 @@ -using GraphNeuralNetworks, Flux - -chain = GNNChain(GraphConv(2=>2)) - -params, restructure = Flux.destructure(chain) - -restructure(params) \ No newline at end of file From d5bdcdd24e1b176491fb5fc1229dce3cbfb69132 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Fri, 22 Jul 2022 17:32:07 +0200 Subject: [PATCH 3/7] cleanup --- src/layers/basic.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 1faae2c3a..db56ff591 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -106,25 +106,26 @@ end (c::GNNChain)(g::GNNGraph) = _applychain(c.layers, g) ## TODO see if this is faster for small chains -# @generated function applychain(layers::Tuple{Vararg{<:Any,N}}, g::GNNGraph, x) where {N} +## see https://github.com/FluxML/Flux.jl/pull/1809#discussion_r781691180 +# @generated function _applychain(layers::Tuple{Vararg{<:Any,N}}, g::GNNGraph, x) where {N} # symbols = vcat(:x, [gensym() for _ in 1:N]) # calls = [:($(symbols[i+1]) = _applylayer(layers[$i], $(symbols[i]))) for i in 1:N] # Expr(:block, calls...) 
# end - +# _applychain(layers::NamedTuple, g, x) = _applychain(Tuple(layers), x) function _applychain(layers, g::GNNGraph, x) # type-unstable path, helps compile times for l in layers x = _applylayer(l, g, x) end - x + return x end function _applychain(layers, g::GNNGraph) # type-unstable path, helps compile times for l in layers g = _applylayer(l, g) end - g + return g end # # explicit input From db8489f715a9ad2bd3a47840cd9bc3e42e50e4bf Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Fri, 22 Jul 2022 17:39:49 +0200 Subject: [PATCH 4/7] update compat bounds --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 91ecb24f2..64c20388e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "GraphNeuralNetworks" uuid = "cffab07f-9bc2-4db1-8861-388f63bf7694" authors = ["Carlo Lucibello and contributors"] -version = "0.4.4" +version = "0.4.5" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -29,7 +29,7 @@ Adapt = "3" CUDA = "3.3" ChainRulesCore = "1" DataStructures = "0.18" -Flux = "0.13" +Flux = "0.13.4" Functors = "0.2, 0.3" Graphs = "1.4" KrylovKit = "0.5" From 9c17125fe7b20677496b3b846ec2c39f09bcafad Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 28 Jul 2022 09:11:03 +0200 Subject: [PATCH 5/7] improve docstring --- src/layers/basic.jl | 37 +++++++++++++++++++++++++++++++------ 1 file changed, 31 insertions(+), 6 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index db56ff591..f9bab8a9e 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -71,18 +71,43 @@ and if names are given, `m[:name] == m[1]` etc. # Examples ```juliarepl -julia> m = GNNChain(GCNConv(2=>5), BatchNorm(5), x -> relu.(x), Dense(5, 4)); +julia> using Flux, GraphNeuralNetworks + +julia> m = GNNChain(GCNConv(2=>5), + BatchNorm(5), + x -> relu.(x), + Dense(5, 4)) +GNNChain(GCNConv(2 => 5), BatchNorm(5), #7, Dense(5 => 4)) julia> x = randn(Float32, 2, 3); -julia> g = GNNGraph([1,1,2,3], [2,3,1,1]); +julia> g = rand_graph(3, 6) +GNNGraph: + num_nodes = 3 + num_edges = 6 julia> m(g, x) 4×3 Matrix{Float32}: - 0.157941 0.15443 0.193471 - 0.0819516 0.0503105 0.122523 - 0.225933 0.267901 0.241878 - -0.0134364 -0.0120716 -0.0172505 + -0.795592 -0.795592 -0.795592 + -0.736409 -0.736409 -0.736409 + 0.994925 0.994925 0.994925 + 0.857549 0.857549 0.857549 + +julia> m2 = GNNChain(enc = m, + dec = DotDecoder()) + m2 = GNNChain(enc = m, + dec = DotDecoder()) + +julia> m2 = GNNChain(enc = m, + dec = DotDecoder()) +GNNChain(enc = GNNChain(GCNConv(2 => 5), BatchNorm(5), #7, Dense(5 => 4)), dec = DotDecoder()) + +julia> m2(g, x) +1×6 Matrix{Float32}: + 2.90053 2.90053 2.90053 2.90053 2.90053 2.90053 + +julia> m2[:enc](g, x) == m(g, x) +true ``` """ struct GNNChain{T<:Union{Tuple, NamedTuple, AbstractVector}} <: GNNLayer From 10d73db1029774d0c7dce8cbcdac667bec8d7793 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 28 Jul 2022 09:16:50 +0200 Subject: [PATCH 6/7] add tests --- src/layers/basic.jl | 5 ----- test/layers/basic.jl | 14 ++++++++++++++ 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index f9bab8a9e..d8b958950 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -93,11 +93,6 @@ julia> m(g, x) 0.994925 0.994925 0.994925 0.857549 0.857549 0.857549 -julia> m2 = GNNChain(enc = m, - dec = DotDecoder()) - m2 = GNNChain(enc = m, - dec = DotDecoder()) - julia> m2 = GNNChain(enc = m, dec = DotDecoder()) GNNChain(enc = GNNChain(GCNConv(2 => 5), 
BatchNorm(5), #7, Dense(5 => 4)), dec = DotDecoder()) diff --git a/test/layers/basic.jl b/test/layers/basic.jl index c0a2d8382..71d4414bb 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -17,6 +17,20 @@ test_layer(gnn, g, rtol=1e-5, exclude_grad_fields=[:μ, :σ²]) + @testset "constructor with names" begin + m = GNNChain(GCNConv(2=>5), + BatchNorm(5), + x -> relu.(x), + Dense(5, 4)) + x = randn(Float32, 2, 3); + g = rand_graph(3, 6) + + m2 = GNNChain(enc = m, + dec = DotDecoder()) + + @test m2[:enc] === m + @test m2(g, x) == m2[:dec](g, m2[:enc](g, x)) + end @testset "Parallel" begin AddResidual(l) = Parallel(+, identity, l) From 1558ffa3209f755e6602905457f55d93e67690f0 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 28 Jul 2022 09:26:04 +0200 Subject: [PATCH 7/7] more tests --- test/layers/basic.jl | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/test/layers/basic.jl b/test/layers/basic.jl index 71d4414bb..15422c855 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -1,11 +1,13 @@ @testset "basic" begin @testset "GNNChain" begin n, din, d, dout = 10, 3, 4, 2 + deg = 4 - g = GNNGraph(random_regular_graph(n, 4), + g = GNNGraph(random_regular_graph(n, deg), graph_type=GRAPH_T, ndata= randn(Float32, din, n)) - + x = g.ndata.x + gnn = GNNChain(GCNConv(din => d), BatchNorm(d), x -> tanh.(x), @@ -18,13 +20,11 @@ test_layer(gnn, g, rtol=1e-5, exclude_grad_fields=[:μ, :σ²]) @testset "constructor with names" begin - m = GNNChain(GCNConv(2=>5), - BatchNorm(5), + m = GNNChain(GCNConv(din=>d), + BatchNorm(d), x -> relu.(x), - Dense(5, 4)) - x = randn(Float32, 2, 3); - g = rand_graph(3, 6) - + Dense(d, dout)) + m2 = GNNChain(enc = m, dec = DotDecoder()) @@ -32,6 +32,15 @@ @test m2(g, x) == m2[:dec](g, m2[:enc](g, x)) end + @testset "constructor with vector" begin + m = GNNChain(GCNConv(din=>d), + BatchNorm(d), + x -> relu.(x), + Dense(d, dout)) + m2 = GNNChain([m.layers...]) + @test m2(g, x) == m(g, x) + end + @testset "Parallel" begin AddResidual(l) = Parallel(+, identity, l)
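
The reworked `GNNChain` in this series accepts positional layers, named layers (keyword arguments), or a vector of layers, and a named sub-chain can be retrieved by indexing with a `Symbol`. The sketch below strings those pieces together; it follows the examples in the updated docstring and the added tests, with illustrative layer sizes and a random graph, so exact numeric outputs will differ from run to run.

```julia
using Flux, GraphNeuralNetworks

# A chain mixing GNN layers and ordinary Flux layers, as in the updated docstring.
m = GNNChain(GCNConv(2 => 5),
             BatchNorm(5),
             x -> relu.(x),
             Dense(5, 4))

g = rand_graph(3, 6)       # 3 nodes, 6 directed edges
x = randn(Float32, 2, 3)   # 2 input features per node

m(g, x)                    # 4×3 matrix of node embeddings

# Named construction: sub-chains are reachable by name.
m2 = GNNChain(enc = m,
              dec = DotDecoder())

m2[:enc] === m                              # true
m2(g, x) == m2[:dec](g, m2[:enc](g, x))     # true

# A vector of layers is also accepted in this version.
m3 = GNNChain([m.layers...])
m3(g, x) == m(g, x)                         # true
```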
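
Because `GNNChain` is now a plain `@functor` container rather than relying on a custom `Flux.functor` definition, standard Flux utilities such as `Flux.destructure` compose with it directly, which is what the temporary `test.jl` in this series checks before being removed again. Below is a minimal sketch along the same lines; the graph size and feature dimensions are made up for illustration.

```julia
using Flux, GraphNeuralNetworks

chain = GNNChain(GraphConv(2 => 2))

# Flatten all trainable parameters into a single vector and obtain a
# closure that rebuilds an equivalent chain from such a vector.
p, restructure = Flux.destructure(chain)

chain2 = restructure(p)

# The rebuilt chain behaves like the original, since `p` came from `chain`.
g = rand_graph(4, 8)        # hypothetical small graph: 4 nodes, 8 edges
x = randn(Float32, 2, 4)    # 2 features per node
chain2(g, x) ≈ chain(g, x)  # expected to be true
```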