Skip to content

Commit

Permalink
more layer
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Jul 30, 2024
1 parent 4b4477e commit 67a51f7
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 2 deletions.
5 changes: 3 additions & 2 deletions GNNLux/src/GNNLux.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
module GNNLux
using ConcreteStructs: @concrete
using NNlib: NNlib, sigmoid, relu, swish
using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer
using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer, parameterlength, statelength, outputsize,
initialparameters, initialstates, parameterlength, statelength
using Lux: Lux, Chain, Dense, glorot_uniform, zeros32, StatefulLuxLayer
using Reexport: @reexport
using Random: AbstractRNG
Expand All @@ -22,7 +23,7 @@ export AGNNConv,
DConv,
GATConv,
GATv2Conv,
# GatedGraphConv,
GatedGraphConv,
GCNConv,
# GINConv,
# GMMConv,
Expand Down
42 changes: 42 additions & 0 deletions GNNLux/src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -515,4 +515,46 @@ function Base.show(io::IO, l::GATv2Conv)
l.σ == identity || print(io, ", ", l.σ)
print(io, ", negative_slope=", l.negative_slope)
print(io, ")")
end


@concrete struct GatedGraphConv <: GRULayer
gru
init_weight
dims::Int
num_layers::Int
aggr
end


function GatedGraphConv(dims::Int, num_layers::Int;
aggr = +, init_weight = glorot_uniform)
gru = GRUCell(dims => dims)
return GatedGraphConv(gru, init_weight, dims, num_layers, aggr)
end

LucCore.outputsize(l::GatedGraphConv) = (l.dims,)

function LuxCore.initialparameters(rng::AbstractRNG, l::GatedGraphConv)
gru = LuxCore.initialparameters(rng, l.gru)
weight = l.init_weight(rng, l.dims, l.dims)
return (; gru, weight)
end

LuxCore.parameterlength(l::GatedGraphConv) = parameterlength(l.gru) + l.dims^2

function LuxCore.initialstates(rng::AbstractRNG, l::GatedGraphConv)
return (; gru = LuxCore.initialstates(rng, l.gru))
end

LuxCore.statelength(l::GatedGraphConv) = statelength(l.gru)

function (l::GatedGraphConv)(g, H, ps, st)
GNNlib.gated_graph_conv(l, g, H)
end

function Base.show(io::IO, l::GatedGraphConv)
print(io, "GatedGraphConv($(l.dims), $(l.num_layers)")
print(io, ", aggr=", l.aggr)
print(io, ")")
end

0 comments on commit 67a51f7

Please sign in to comment.