layers
CarloLucibello committed Jul 29, 2024
1 parent fc67808 commit fddb701
Showing 5 changed files with 217 additions and 28 deletions.
10 changes: 5 additions & 5 deletions GNNLux/src/GNNLux.jl
@@ -1,8 +1,8 @@
module GNNLux
using ConcreteStructs: @concrete
using NNlib: NNlib, sigmoid, relu
using NNlib: NNlib, sigmoid, relu, swish
using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer
using Lux: Lux, Dense, glorot_uniform, zeros32, StatefulLuxLayer
using Lux: Lux, Chain, Dense, glorot_uniform, zeros32, StatefulLuxLayer
using Reexport: @reexport
using Random: AbstractRNG
using GNNlib: GNNlib
@@ -18,9 +18,9 @@ export AGNNConv,
CGConv,
ChebConv,
EdgeConv,
# EGNNConv,
# DConv,
# GATConv,
EGNNConv,
DConv,
GATConv,
# GATv2Conv,
# GatedGraphConv,
GCNConv,
162 changes: 162 additions & 0 deletions GNNLux/src/layers/conv.jl
@@ -255,3 +255,165 @@ function (l::EdgeConv)(g::AbstractGNNGraph, x, ps, st)
end


@concrete struct EGNNConv <: GNNContainerLayer{(:ϕe, :ϕx, :ϕh)}
ϕe
ϕx
ϕh
num_features
residual::Bool
end

function EGNNConv(ch::Pair{Int, Int}, hidden_size = 2 * ch[1]; residual = false)
return EGNNConv((ch[1], 0) => ch[2]; hidden_size, residual)
end

# Follows the reference implementation at https://github.com/vgsatorras/egnn/blob/main/models/egnn_clean/egnn_clean.py
function EGNNConv(ch::Pair{NTuple{2, Int}, Int}; hidden_size::Int = 2 * ch[1][1],
residual = false)
(in_size, edge_feat_size), out_size = ch
act_fn = swish

# +1 for the radial feature: ||x_i - x_j||^2
ϕe = Chain(Dense(in_size * 2 + edge_feat_size + 1 => hidden_size, act_fn),
Dense(hidden_size => hidden_size, act_fn))

ϕh = Chain(Dense(in_size + hidden_size => hidden_size, swish),
Dense(hidden_size => out_size))

ϕx = Chain(Dense(hidden_size => hidden_size, swish),
Dense(hidden_size => 1, use_bias = false))

num_features = (in = in_size, edge = edge_feat_size, out = out_size,
hidden = hidden_size)
if residual
@assert in_size==out_size "Residual connection only possible if in_size == out_size"
end
return EGNNConv(ϕe, ϕx, ϕh, num_features, residual)
end

LuxCore.outputsize(l::EGNNConv) = (l.num_features.out,)

(l::EGNNConv)(g, h, x, ps, st) = l(g, h, x, nothing, ps, st)

function (l::EGNNConv)(g, h, x, e, ps, st)
ϕe = StatefulLuxLayer{true}(l.ϕe, ps.ϕe, _getstate(st, :ϕe))
ϕx = StatefulLuxLayer{true}(l.ϕx, ps.ϕx, _getstate(st, :ϕx))
ϕh = StatefulLuxLayer{true}(l.ϕh, ps.ϕh, _getstate(st, :ϕh))
m = (; ϕe, ϕx, ϕh, l.residual, l.num_features)
return GNNlib.egnn_conv(m, g, h, x, e), st
end

function Base.show(io::IO, l::EGNNConv)
ne = l.num_features.edge
nin = l.num_features.in
nout = l.num_features.out
nh = l.num_features.hidden
print(io, "EGNNConv(($nin, $ne) => $nout; hidden_size=$nh")
if l.residual
print(io, ", residual=true")
end
print(io, ")")
end
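
A minimal usage sketch for the new Lux-style EGNNConv (not part of the diff; it mirrors the conventions of the test file further down and assumes GNNLux, GNNGraphs and StableRNGs are loaded):

# illustration only
rng = StableRNG(1234)
g = rand_graph(10, 40, seed = 1234)
h = randn(rng, Float32, 5, g.num_nodes)    # node features
x = randn(rng, Float32, 3, g.num_nodes)    # node coordinates
l = EGNNConv(5 => 7, 8)                    # hidden_size = 8
ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)
(hnew, xnew), st = l(g, h, x, ps, st)      # hnew is 7×10, xnew is 3×10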

@concrete struct DConv <: GNNLayer
in_dims::Int
out_dims::Int
k::Int
init_weight
init_bias
use_bias::Bool
end

function DConv(ch::Pair{Int, Int}, k::Int;
init_weight = glorot_uniform,
init_bias = zeros32,
use_bias = true)
in, out = ch
return DConv(in, out, k, init_weight, init_bias, use_bias)
end

function LuxCore.initialparameters(rng::AbstractRNG, l::DConv)
weights = l.init_weight(rng, 2, l.k, l.out_dims, l.in_dims)
if l.use_bias
bias = l.init_bias(rng, l.out_dims)
return (; weights, bias)
else
return (; weights)
end
end

LuxCore.parameterlength(l::DConv) = l.use_bias ? 2 * l.in_dims * l.out_dims * l.k + l.out_dims :
                                                 2 * l.in_dims * l.out_dims * l.k

function (l::DConv)(g, x, ps, st)
m = (; ps.weights, bias = _getbias(ps), l.k)
return GNNlib.d_conv(m, g, x), st
end

function Base.show(io::IO, l::DConv)
print(io, "DConv($(l.in) => $(l.out), k=$(l.k))")
end
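
A minimal usage sketch for the Lux DConv (not part of the diff; names and sizes follow the test file below):

# illustration only
rng = StableRNG(1234)
g = rand_graph(10, 40, seed = 1234)
x = randn(rng, Float32, 3, g.num_nodes)
l = DConv(3 => 5, 2)                       # 2 diffusion steps
ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)
y, st = l(g, x, ps, st)                    # y is a 5×10 matrix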

@concrete struct GATConv <: GNNLayer
dense_x
dense_e
init_weight
init_bias
use_bias::Bool
σ
negative_slope
channel::Pair{NTuple{2, Int}, Int}
heads::Int
concat::Bool
add_self_loops::Bool
dropout
end


GATConv(ch::Pair{Int, Int}, args...; kws...) = GATConv((ch[1], 0) => ch[2], args...; kws...)

function GATConv(ch::Pair{NTuple{2, Int}, Int}, σ = identity;
heads::Int = 1, concat::Bool = true, negative_slope = 0.2,
init_weight = glorot_uniform, init_bias = zeros32,
use_bias::Bool = true,
add_self_loops = true, dropout=0.0)
(in, ein), out = ch
if add_self_loops
@assert ein==0 "Using edge features and setting add_self_loops=true at the same time is not yet supported."
end

dense_x = Dense(in => out * heads, use_bias = false)
dense_e = ein > 0 ? Dense(ein => out * heads, use_bias = false) : nothing
negative_slope = convert(Float32, negative_slope)
return GATConv(dense_x, dense_e, init_weight, init_bias, use_bias,
σ, negative_slope, ch, heads, concat, add_self_loops, dropout)
end

# Flux.trainable(l::GATConv) = (dense_x = l.dense_x, dense_e = l.dense_e, bias = l.bias, a = l.a)
function LuxCore.initialparameters(rng::AbstractRNG, l::GATConv)
    (in, ein), out = l.channel
    dense_x = LuxCore.initialparameters(rng, l.dense_x)
    a = l.init_weight(rng, ein > 0 ? 3out : 2out, l.heads)
    ps = (; dense_x, a)
    if ein > 0
        ps = (ps..., dense_e = LuxCore.initialparameters(rng, l.dense_e))
    end
    if l.use_bias
        ps = (ps..., bias = l.init_bias(rng, l.concat ? out * l.heads : out))
    end
    return ps
end

(l::GATConv)(g, x, ps, st) = l(g, x, nothing, ps, st)

function (l::GATConv)(g, x, e, ps, st)
    dense_x = StatefulLuxLayer{true}(l.dense_x, ps.dense_x, _getstate(st, :dense_x))
    dense_e = l.dense_e === nothing ? nothing :
              StatefulLuxLayer{true}(l.dense_e, ps.dense_e, _getstate(st, :dense_e))
    # Pass a named tuple carrying the fields GNNlib.gat_conv reads from the layer,
    # mirroring the EGNNConv wrapper above.
    m = (; l.σ, l.negative_slope, l.channel, l.heads, l.concat, l.add_self_loops, l.dropout,
         dense_x, dense_e, ps.a, bias = _getbias(ps))
    return GNNlib.gat_conv(m, g, x, e), st
end

function Base.show(io::IO, l::GATConv)
(in, ein), out = l.channel
print(io, "GATConv(", ein == 0 ? in : (in, ein), " => ", out ÷ l.heads)
l.σ == identity || print(io, ", ", l.σ)
print(io, ", negative_slope=", l.negative_slope)
print(io, ")")
end
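
A usage sketch for the Lux GATConv (not part of the diff; it mirrors the test below and assumes the forward pass is wired to GNNlib.gat_conv in the same way as the other layers in this file):

# illustration only
rng = StableRNG(1234)
g = rand_graph(10, 40, seed = 1234)
x = randn(rng, Float32, 6, g.num_nodes)
l = GATConv(6 => 8, heads = 2)
ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)
y, st = l(g, x, ps, st)
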
51 changes: 39 additions & 12 deletions GNNLux/test/layers/conv_tests.jl
@@ -1,36 +1,63 @@
@testitem "layers/conv" setup=[SharedTestSetup] begin
rng = StableRNG(1234)
g = rand_graph(10, 40, seed=1234)
x = randn(rng, Float32, 3, 10)
in_dims = 3
out_dims = 5
x = randn(rng, Float32, in_dims, 10)

@testset "GCNConv" begin
l = GCNConv(3 => 5, relu)
test_lux_layer(rng, l, g, x, outputsize=(5,))
l = GCNConv(in_dims => out_dims, relu)
test_lux_layer(rng, l, g, x, outputsize=(out_dims,))
end

@testset "ChebConv" begin
l = ChebConv(3 => 5, 2)
test_lux_layer(rng, l, g, x, outputsize=(5,))
l = ChebConv(in_dims => out_dims, 2)
test_lux_layer(rng, l, g, x, outputsize=(out_dims,))
end

@testset "GraphConv" begin
l = GraphConv(3 => 5, relu)
test_lux_layer(rng, l, g, x, outputsize=(5,))
l = GraphConv(in_dims => out_dims, relu)
test_lux_layer(rng, l, g, x, outputsize=(out_dims,))
end

@testset "AGNNConv" begin
l = AGNNConv(init_beta=1.0f0)
test_lux_layer(rng, l, g, x, sizey=(3,10))
test_lux_layer(rng, l, g, x, sizey=(in_dims, 10))
end

@testset "EdgeConv" begin
nn = Chain(Dense(6 => 5, relu), Dense(5 => 5))
nn = Chain(Dense(2*in_dims => 5, relu), Dense(5 => out_dims))
l = EdgeConv(nn, aggr = +)
test_lux_layer(rng, l, g, x, sizey=(5,10), container=true)
test_lux_layer(rng, l, g, x, sizey=(out_dims,10), container=true)
end

@testset "CGConv" begin
l = CGConv(3 => 3, residual = true)
test_lux_layer(rng, l, g, x, outputsize=(3,), container=true)
l = CGConv(in_dims => in_dims, residual = true)
test_lux_layer(rng, l, g, x, outputsize=(in_dims,), container=true)
end

@testset "DConv" begin
l = DConv(in_dims => out_dims, 2)
        test_lux_layer(rng, l, g, x, outputsize=(out_dims,))
end

@testset "EGNNConv" begin
hin = 6
hout = 7
hidden = 8
l = EGNNConv(hin => hout, hidden)
ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)
h = randn(rng, Float32, hin, g.num_nodes)
(hnew, xnew), stnew = l(g, h, x, ps, st)
@test size(hnew) == (hout, g.num_nodes)
@test size(xnew) == (in_dims, g.num_nodes)
end

@testset "GATConv" begin
x = randn(rng, Float32, 6, 10)
l = GATConv(6 => 8, heads=2)
test_lux_layer(rng, l, g, x, outputsize=(8,))
end
end

4 changes: 2 additions & 2 deletions GNNlib/src/layers/conv.jl
@@ -703,14 +703,14 @@ function d_conv(l, g::GNNGraph, x::AbstractMatrix)
h = l.weights[1,1,:,:] * x .+ l.weights[2,1,:,:] * x

T0 = x
if l.K > 1
if l.k > 1
# T1_in = T0 * deg_in * A'
#T1_out = T0 * deg_out' * A
T1_out = propagate(w_mul_xj, g, +; xj = T0*deg_out')
T1_in = propagate(w_mul_xj, gt, +; xj = T0*deg_in)
h = h .+ l.weights[1,2,:,:] * T1_in .+ l.weights[2,2,:,:] * T1_out
end
for i in 2:l.K
for i in 2:l.k
T2_in = propagate(w_mul_xj, gt, +; xj = T1_in*deg_in)
T2_in = 2 * T2_in - T0
T2_out = propagate(w_mul_xj, g ,+; xj = T1_out*deg_out')
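
For reference (not part of the diff), the loop above builds the higher-order diffusion terms with a Chebyshev-style recurrence, T_s = 2*T_{s-1}*P - T_{s-2}, applied separately to the in- and out-degree-normalized propagation. Below is a dense-matrix sketch with a hypothetical helper, where X is a features × nodes matrix and P stands for a degree-normalized adjacency such as D⁻¹A:

# illustration only, not library code
function diffusion_terms(P::AbstractMatrix, X::AbstractMatrix, k::Int)
    Ts = [X]                                          # T0 = X
    k >= 2 && push!(Ts, X * P)                        # T1 = X * P
    for s in 3:k
        push!(Ts, 2 .* (Ts[end] * P) .- Ts[end - 1])  # T_s = 2*T_{s-1}*P - T_{s-2}
    end
    return Ts
end
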
18 changes: 9 additions & 9 deletions src/layers/conv.jl
@@ -1536,14 +1536,14 @@ function Base.show(io::IO, l::TransformerConv)
end

"""
DConv(ch::Pair{Int, Int}, K::Int; init = glorot_uniform, bias = true)
DConv(ch::Pair{Int, Int}, k::Int; init = glorot_uniform, bias = true)
Diffusion convolution layer from the paper [Diffusion Convolutional Recurrent Neural Networks: Data-Driven Traffic Forecasting](https://arxiv.org/pdf/1707.01926).
# Arguments
- `ch`: Pair of input and output dimensions.
- `K`: Number of diffusion steps.
- `k`: Number of diffusion steps.
- `init`: Weights' initializer. Default `glorot_uniform`.
- `bias`: Add learnable bias. Default `true`.
@@ -1552,7 +1552,7 @@ Diffusion convolution layer from the paper [Diffusion Convolutional Recurrent Ne
julia> g = GNNGraph(rand(10, 10), ndata = rand(Float32, 2, 10));
julia> dconv = DConv(2 => 4, 4)
DConv(2 => 4, K=4)
DConv(2 => 4, 4)
julia> y = dconv(g, g.ndata.x);
@@ -1565,20 +1565,20 @@ struct DConv <: GNNLayer
out::Int
weights::AbstractArray
bias::AbstractArray
K::Int
k::Int
end

@functor DConv

function DConv(ch::Pair{Int, Int}, K::Int; init = glorot_uniform, bias = true)
function DConv(ch::Pair{Int, Int}, k::Int; init = glorot_uniform, bias = true)
in, out = ch
weights = init(2, K, out, in)
weights = init(2, k, out, in)
b = bias ? Flux.create_bias(weights, true, out) : false
DConv(in, out, weights, b, K)
return DConv(in, out, weights, b, k)
end

(l::DConv)(g, x) = GNNlib.d_conv(l, g, x)

function Base.show(io::IO, l::DConv)
print(io, "DConv($(l.in) => $(l.out), K=$(l.K))")
end
print(io, "DConv($(l.in) => $(l.out), $(l.k))")
end
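
A quick check (not part of the diff) that the Flux layer keeps the (2, k, out, in) weight layout that the Lux initialparameters above produce, and that the updated show method is used:

# illustration only
using Flux, GraphNeuralNetworks
dconv = DConv(2 => 4, 3)
size(dconv.weights)    # (2, 3, 4, 2)
println(dconv)         # prints "DConv(2 => 4, 3)"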
