diff --git a/GNNLux/src/GNNLux.jl b/GNNLux/src/GNNLux.jl index ecac67b5a..3a1d0188e 100644 --- a/GNNLux/src/GNNLux.jl +++ b/GNNLux/src/GNNLux.jl @@ -1,8 +1,8 @@ module GNNLux using ConcreteStructs: @concrete -using NNlib: NNlib, sigmoid, relu +using NNlib: NNlib, sigmoid, relu, swish using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer -using Lux: Lux, Dense, glorot_uniform, zeros32, StatefulLuxLayer +using Lux: Lux, Chain, Dense, glorot_uniform, zeros32, StatefulLuxLayer using Reexport: @reexport using Random: AbstractRNG using GNNlib: GNNlib @@ -18,9 +18,9 @@ export AGNNConv, CGConv, ChebConv, EdgeConv, - # EGNNConv, - # DConv, - # GATConv, + EGNNConv, + DConv, + GATConv, # GATv2Conv, # GatedGraphConv, GCNConv, diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index 2dc638d95..049be7331 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -255,3 +255,165 @@ function (l::EdgeConv)(g::AbstractGNNGraph, x, ps, st) end +@concrete struct EGNNConv <: GNNContainerLayer{(:ϕe, :ϕx, :ϕh)} + ϕe + ϕx + ϕh + num_features + residual::Bool +end + +function EGNNConv(ch::Pair{Int, Int}, hidden_size = 2 * ch[1]; residual = false) + return EGNNConv((ch[1], 0) => ch[2]; hidden_size, residual) +end + +#Follows reference implementation at https://github.com/vgsatorras/egnn/blob/main/models/egnn_clean/egnn_clean.py +function EGNNConv(ch::Pair{NTuple{2, Int}, Int}; hidden_size::Int = 2 * ch[1][1], + residual = false) + (in_size, edge_feat_size), out_size = ch + act_fn = swish + + # +1 for the radial feature: ||x_i - x_j||^2 + ϕe = Chain(Dense(in_size * 2 + edge_feat_size + 1 => hidden_size, act_fn), + Dense(hidden_size => hidden_size, act_fn)) + + ϕh = Chain(Dense(in_size + hidden_size => hidden_size, swish), + Dense(hidden_size => out_size)) + + ϕx = Chain(Dense(hidden_size => hidden_size, swish), + Dense(hidden_size => 1, use_bias = false)) + + num_features = (in = in_size, edge = edge_feat_size, out = out_size, + hidden = hidden_size) + if residual + @assert in_size==out_size "Residual connection only possible if in_size == out_size" + end + return EGNNConv(ϕe, ϕx, ϕh, num_features, residual) +end + +LuxCore.outputsize(l::EGNNConv) = (l.num_features.out,) + +(l::EGNNConv)(g, h, x, ps, st) = l(g, h, x, nothing, ps, st) + +function (l::EGNNConv)(g, h, x, e, ps, st) + ϕe = StatefulLuxLayer{true}(l.ϕe, ps.ϕe, _getstate(st, :ϕe)) + ϕx = StatefulLuxLayer{true}(l.ϕx, ps.ϕx, _getstate(st, :ϕx)) + ϕh = StatefulLuxLayer{true}(l.ϕh, ps.ϕh, _getstate(st, :ϕh)) + m = (; ϕe, ϕx, ϕh, l.residual, l.num_features) + return GNNlib.egnn_conv(m, g, h, x, e), st +end + +function Base.show(io::IO, l::EGNNConv) + ne = l.num_features.edge + nin = l.num_features.in + nout = l.num_features.out + nh = l.num_features.hidden + print(io, "EGNNConv(($nin, $ne) => $nout; hidden_size=$nh") + if l.residual + print(io, ", residual=true") + end + print(io, ")") +end + +@concrete struct DConv <: GNNLayer + in_dims::Int + out_dims::Int + k::Int + init_weight + init_bias + use_bias::Bool +end + +function DConv(ch::Pair{Int, Int}, k::Int; + init_weight = glorot_uniform, + init_bias = zeros32, + use_bias = true) + in, out = ch + return DConv(in, out, k, init_weight, init_bias, use_bias) +end + +function LuxCore.initialparameters(rng::AbstractRNG, l::DConv) + weights = l.init_weight(rng, 2, l.k, l.out_dims, l.in_dims) + if l.use_bias + bias = l.init_bias(rng, l.out_dims) + return (; weights, bias) + else + return (; weights) + end +end + +LuxCore.parameterlength(l::DConv) = l.use_bias ? l.in_dims * l.out_dims * l.k + l.out_dims : + l.in_dims * l.out_dims * l.k + +function (l::DConv)(g, x, ps, st) + m = (; ps.weights, bias = _getbias(ps), l.k) + return GNNlib.d_conv(m, g, x), st +end + +function Base.show(io::IO, l::DConv) + print(io, "DConv($(l.in) => $(l.out), k=$(l.k))") +end + +@concrete struct GATConv <: GNNLayer + dense_x + dense_e + init_weight + init_bias + use_bias::Bool + σ + negative_slope + channel::Pair{NTuple{2, Int}, Int} + heads::Int + concat::Bool + add_self_loops::Bool + dropout +end + + +GATConv(ch::Pair{Int, Int}, args...; kws...) = GATConv((ch[1], 0) => ch[2], args...; kws...) + +function GATConv(ch::Pair{NTuple{2, Int}, Int}, σ = identity; + heads::Int = 1, concat::Bool = true, negative_slope = 0.2, + init_weight = glorot_uniform, init_bias = zeros32, + use_bias::Bool = true, + add_self_loops = true, dropout=0.0) + (in, ein), out = ch + if add_self_loops + @assert ein==0 "Using edge features and setting add_self_loops=true at the same time is not yet supported." + end + + dense_x = Dense(in => out * heads, use_bias = false) + dense_e = ein > 0 ? Dense(ein => out * heads, use_bias = false) : nothing + negative_slope = convert(Float32, negative_slope) + return GATConv(dense_x, dense_e, init_weight, init_bias, use_bias, + σ, negative_slope, ch, heads, concat, add_self_loops, dropout) +end + +# Flux.trainable(l::GATConv) = (dense_x = l.dense_x, dense_e = l.dense_e, bias = l.bias, a = l.a) +function LuxCore.initialparameters(rng::AbstractRNG, l::GATConv) + (in, ein), out = l.channel + dense_x = initialparameters(rng, l.dense_x) + a = init_weight(ein > 0 ? 3out : 2out, heads) + ps = (; dense_x, a) + if ein > 0 + ps = (ps..., dense_e = initialparameters(rng, l.dense_e)) + end + if use_bias + ps = (ps..., bias = l.init_bias(rng, concat ? out * l.heads : out)) + end + return ps +end + +(l::GATConv)(g, x, ps, st) = l(g, x, nothing, ps, st) + +function (l::GATConv)(g, x, e, ps, st) + return GNNlib.gat_conv(l, g, x, e), st +end + +function Base.show(io::IO, l::GATConv) + (in, ein), out = l.channel + print(io, "GATConv(", ein == 0 ? in : (in, ein), " => ", out ÷ l.heads) + l.σ == identity || print(io, ", ", l.σ) + print(io, ", negative_slope=", l.negative_slope) + print(io, ")") +end diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl index 520fcc570..dd48f0d94 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -1,36 +1,63 @@ @testitem "layers/conv" setup=[SharedTestSetup] begin rng = StableRNG(1234) g = rand_graph(10, 40, seed=1234) - x = randn(rng, Float32, 3, 10) + in_dims = 3 + out_dims = 5 + x = randn(rng, Float32, in_dims, 10) @testset "GCNConv" begin - l = GCNConv(3 => 5, relu) - test_lux_layer(rng, l, g, x, outputsize=(5,)) + l = GCNConv(in_dims => out_dims, relu) + test_lux_layer(rng, l, g, x, outputsize=(out_dims,)) end @testset "ChebConv" begin - l = ChebConv(3 => 5, 2) - test_lux_layer(rng, l, g, x, outputsize=(5,)) + l = ChebConv(in_dims => out_dims, 2) + test_lux_layer(rng, l, g, x, outputsize=(out_dims,)) end @testset "GraphConv" begin - l = GraphConv(3 => 5, relu) - test_lux_layer(rng, l, g, x, outputsize=(5,)) + l = GraphConv(in_dims => out_dims, relu) + test_lux_layer(rng, l, g, x, outputsize=(out_dims,)) end @testset "AGNNConv" begin l = AGNNConv(init_beta=1.0f0) - test_lux_layer(rng, l, g, x, sizey=(3,10)) + test_lux_layer(rng, l, g, x, sizey=(in_dims, 10)) end @testset "EdgeConv" begin - nn = Chain(Dense(6 => 5, relu), Dense(5 => 5)) + nn = Chain(Dense(2*in_dims => 5, relu), Dense(5 => out_dims)) l = EdgeConv(nn, aggr = +) - test_lux_layer(rng, l, g, x, sizey=(5,10), container=true) + test_lux_layer(rng, l, g, x, sizey=(out_dims,10), container=true) end @testset "CGConv" begin - l = CGConv(3 => 3, residual = true) - test_lux_layer(rng, l, g, x, outputsize=(3,), container=true) + l = CGConv(in_dims => in_dims, residual = true) + test_lux_layer(rng, l, g, x, outputsize=(in_dims,), container=true) + end + + @testset "DConv" begin + l = DConv(in_dims => out_dims, 2) + test_lux_layer(rng, l, g, x, outputsize=(5,)) + end + + @testset "EGNNConv" begin + hin = 6 + hout = 7 + hidden = 8 + l = EGNNConv(hin => hout, hidden) + ps = LuxCore.initialparameters(rng, l) + st = LuxCore.initialstates(rng, l) + h = randn(rng, Float32, hin, g.num_nodes) + (hnew, xnew), stnew = l(g, h, x, ps, st) + @test size(hnew) == (hout, g.num_nodes) + @test size(xnew) == (in_dims, g.num_nodes) + end + + @testset "GATConv" begin + x = randn(rng, Float32, 6, 10) + l = GATConv(6 => 8, heads=2) + test_lux_layer(rng, l, g, x, outputsize=(8,)) end end + diff --git a/GNNlib/src/layers/conv.jl b/GNNlib/src/layers/conv.jl index cd3606291..431561337 100644 --- a/GNNlib/src/layers/conv.jl +++ b/GNNlib/src/layers/conv.jl @@ -703,14 +703,14 @@ function d_conv(l, g::GNNGraph, x::AbstractMatrix) h = l.weights[1,1,:,:] * x .+ l.weights[2,1,:,:] * x T0 = x - if l.K > 1 + if l.k > 1 # T1_in = T0 * deg_in * A' #T1_out = T0 * deg_out' * A T1_out = propagate(w_mul_xj, g, +; xj = T0*deg_out') T1_in = propagate(w_mul_xj, gt, +; xj = T0*deg_in) h = h .+ l.weights[1,2,:,:] * T1_in .+ l.weights[2,2,:,:] * T1_out end - for i in 2:l.K + for i in 2:l.k T2_in = propagate(w_mul_xj, gt, +; xj = T1_in*deg_in) T2_in = 2 * T2_in - T0 T2_out = propagate(w_mul_xj, g ,+; xj = T1_out*deg_out') diff --git a/src/layers/conv.jl b/src/layers/conv.jl index ddfa4e945..89a3da750 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -1536,14 +1536,14 @@ function Base.show(io::IO, l::TransformerConv) end """ - DConv(ch::Pair{Int, Int}, K::Int; init = glorot_uniform, bias = true) + DConv(ch::Pair{Int, Int}, k::Int; init = glorot_uniform, bias = true) Diffusion convolution layer from the paper [Diffusion Convolutional Recurrent Neural Networks: Data-Driven Traffic Forecasting](https://arxiv.org/pdf/1707.01926). # Arguments - `ch`: Pair of input and output dimensions. -- `K`: Number of diffusion steps. +- `k`: Number of diffusion steps. - `init`: Weights' initializer. Default `glorot_uniform`. - `bias`: Add learnable bias. Default `true`. @@ -1552,7 +1552,7 @@ Diffusion convolution layer from the paper [Diffusion Convolutional Recurrent Ne julia> g = GNNGraph(rand(10, 10), ndata = rand(Float32, 2, 10)); julia> dconv = DConv(2 => 4, 4) -DConv(2 => 4, K=4) +DConv(2 => 4, 4) julia> y = dconv(g, g.ndata.x); @@ -1565,20 +1565,20 @@ struct DConv <: GNNLayer out::Int weights::AbstractArray bias::AbstractArray - K::Int + k::Int end @functor DConv -function DConv(ch::Pair{Int, Int}, K::Int; init = glorot_uniform, bias = true) +function DConv(ch::Pair{Int, Int}, k::Int; init = glorot_uniform, bias = true) in, out = ch - weights = init(2, K, out, in) + weights = init(2, k, out, in) b = bias ? Flux.create_bias(weights, true, out) : false - DConv(in, out, weights, b, K) + return DConv(in, out, weights, b, k) end (l::DConv)(g, x) = GNNlib.d_conv(l, g, x) function Base.show(io::IO, l::DConv) - print(io, "DConv($(l.in) => $(l.out), K=$(l.K))") -end \ No newline at end of file + print(io, "DConv($(l.in) => $(l.out), $(l.k))") +end