Skip to content

Commit

Permalink
feat: Add GCNConv support for HeteroGraphConv (#367)
Browse files Browse the repository at this point in the history
* WIP: working state but with degree placeholder

* clean up comments

* add test

* include edge type in degree calc for gnnheterograph

* add self loops for gnnheterograph

* add TODO comment

* add TODO comment

* fix failing test

* update the new add_self_loops_behavior

* change empty string to nothing for memory optimization

* add GCNConv support for HeteroGraphConv

* fix tests

* GCN tests passing

* add small optimization to reduce ifs

* avoid repeated code

* add PR review suggestion

* run all tests
  • Loading branch information
askorupka authored Mar 10, 2024
1 parent 95e8392 commit 0c641a2
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 16 deletions.
40 changes: 24 additions & 16 deletions src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,23 +88,22 @@ function GCNConv(ch::Pair{Int, Int}, σ = identity;
GCNConv(W, b, σ, add_self_loops, use_edge_weight)
end

check_gcnconv_input(g::GNNGraph{<:ADJMAT_T}, edge_weight::AbstractVector) =
check_gcnconv_input(g::AbstractGNNGraph{<:ADJMAT_T}, edge_weight::AbstractVector) =
throw(ArgumentError("Providing external edge_weight is not yet supported for adjacency matrix graphs"))

function check_gcnconv_input(g::GNNGraph, edge_weight::AbstractVector)
function check_gcnconv_input(g::AbstractGNNGraph, edge_weight::AbstractVector)
if length(edge_weight) !== g.num_edges
throw(ArgumentError("Wrong number of edge weights (expected $(g.num_edges) but given $(length(edge_weight)))"))
end
end

check_gcnconv_input(g::GNNGraph, edge_weight::Nothing) = nothing
check_gcnconv_input(g::AbstractGNNGraph, edge_weight::Nothing) = nothing


function (l::GCNConv)(g::GNNGraph,
x::AbstractMatrix{T},
function (l::GCNConv)(g::AbstractGNNGraph,
x,
edge_weight::EW = nothing,
norm_fn::Function = d -> 1 ./ sqrt.(d)
) where {T, EW <: Union{Nothing, AbstractVector}}
) where {EW <: Union{Nothing, AbstractVector}}

check_gcnconv_input(g, edge_weight)

Expand All @@ -118,26 +117,35 @@ function (l::GCNConv)(g::GNNGraph,
end
end
Dout, Din = size(l.weight)
if Dout < Din
if Dout < Din && !(g isa GNNHeteroGraph)
# multiply before convolution if it is more convenient, otherwise multiply after
# (this works only for homogenous graph)
x = l.weight * x
end
if edge_weight !== nothing
d = degree(g, T; dir = :in, edge_weight)

xj, xi = expand_srcdst(g, x) # expand only after potential multiplication
T = eltype(xi)

if g isa GNNHeteroGraph
d = degree(g, g.etypes[1], T; dir = :in)
else
d = degree(g, T; dir = :in, edge_weight = l.use_edge_weight)
if edge_weight !== nothing
d = degree(g, T; dir = :in, edge_weight)
else
d = degree(g, T; dir = :in, edge_weight = l.use_edge_weight)
end
end
c = norm_fn(d)
x = x .* c'
!(g isa GNNHeteroGraph) ? xj = xj .* c' : Nothing
if edge_weight !== nothing
x = propagate(e_mul_xj, g, +, xj = x, e = edge_weight)
x = propagate(e_mul_xj, g, +, xj = xj, e = edge_weight)
elseif l.use_edge_weight
x = propagate(w_mul_xj, g, +, xj = x)
x = propagate(w_mul_xj, g, +, xj = xj)
else
x = propagate(copy_xj, g, +, xj = x)
x = propagate(copy_xj, g, +, xj = xj)
end
x = x .* c'
if Dout >= Din
if Dout >= Din || g isa GNNHeteroGraph
x = l.weight * x
end
return l.σ.(x .+ l.bias)
Expand Down
9 changes: 9 additions & 0 deletions test/layers/heteroconv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,4 +156,13 @@
y = layers(hg, x);
@test size(y.A) == (2, 2) && size(y.B) == (2, 3)
end

@testset "GCNConv" begin
g = rand_bipartite_heterograph((2,3), 6)
x = (A = rand(Float32, 4,2), B = rand(Float32, 4, 3))
layers = HeteroGraphConv( (:A, :to, :B) => GCNConv(4 => 2, relu),
(:B, :to, :A) => GCNConv(4 => 2, relu));
y = layers(g, x);
@test size(y.A) == (2,2) && size(y.B) == (2,3)
end
end

0 comments on commit 0c641a2

Please sign in to comment.