diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 024e62df2..7d4844450 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -88,23 +88,22 @@ function GCNConv(ch::Pair{Int, Int}, σ = identity; GCNConv(W, b, σ, add_self_loops, use_edge_weight) end -check_gcnconv_input(g::GNNGraph{<:ADJMAT_T}, edge_weight::AbstractVector) = +check_gcnconv_input(g::AbstractGNNGraph{<:ADJMAT_T}, edge_weight::AbstractVector) = throw(ArgumentError("Providing external edge_weight is not yet supported for adjacency matrix graphs")) -function check_gcnconv_input(g::GNNGraph, edge_weight::AbstractVector) +function check_gcnconv_input(g::AbstractGNNGraph, edge_weight::AbstractVector) if length(edge_weight) !== g.num_edges throw(ArgumentError("Wrong number of edge weights (expected $(g.num_edges) but given $(length(edge_weight)))")) end end -check_gcnconv_input(g::GNNGraph, edge_weight::Nothing) = nothing +check_gcnconv_input(g::AbstractGNNGraph, edge_weight::Nothing) = nothing - -function (l::GCNConv)(g::GNNGraph, - x::AbstractMatrix{T}, +function (l::GCNConv)(g::AbstractGNNGraph, + x, edge_weight::EW = nothing, norm_fn::Function = d -> 1 ./ sqrt.(d) - ) where {T, EW <: Union{Nothing, AbstractVector}} + ) where {EW <: Union{Nothing, AbstractVector}} check_gcnconv_input(g, edge_weight) @@ -118,26 +117,35 @@ function (l::GCNConv)(g::GNNGraph, end end Dout, Din = size(l.weight) - if Dout < Din + if Dout < Din && !(g isa GNNHeteroGraph) # multiply before convolution if it is more convenient, otherwise multiply after + # (this works only for homogenous graph) x = l.weight * x end - if edge_weight !== nothing - d = degree(g, T; dir = :in, edge_weight) + + xj, xi = expand_srcdst(g, x) # expand only after potential multiplication + T = eltype(xi) + + if g isa GNNHeteroGraph + d = degree(g, g.etypes[1], T; dir = :in) else - d = degree(g, T; dir = :in, edge_weight = l.use_edge_weight) + if edge_weight !== nothing + d = degree(g, T; dir = :in, edge_weight) + else + d = degree(g, T; dir = :in, edge_weight = l.use_edge_weight) + end end c = norm_fn(d) - x = x .* c' + !(g isa GNNHeteroGraph) ? xj = xj .* c' : Nothing if edge_weight !== nothing - x = propagate(e_mul_xj, g, +, xj = x, e = edge_weight) + x = propagate(e_mul_xj, g, +, xj = xj, e = edge_weight) elseif l.use_edge_weight - x = propagate(w_mul_xj, g, +, xj = x) + x = propagate(w_mul_xj, g, +, xj = xj) else - x = propagate(copy_xj, g, +, xj = x) + x = propagate(copy_xj, g, +, xj = xj) end x = x .* c' - if Dout >= Din + if Dout >= Din || g isa GNNHeteroGraph x = l.weight * x end return l.σ.(x .+ l.bias) diff --git a/test/layers/heteroconv.jl b/test/layers/heteroconv.jl index d032b9c5c..761c5ca77 100644 --- a/test/layers/heteroconv.jl +++ b/test/layers/heteroconv.jl @@ -156,4 +156,13 @@ y = layers(hg, x); @test size(y.A) == (2, 2) && size(y.B) == (2, 3) end + + @testset "GCNConv" begin + g = rand_bipartite_heterograph((2,3), 6) + x = (A = rand(Float32, 4,2), B = rand(Float32, 4, 3)) + layers = HeteroGraphConv( (:A, :to, :B) => GCNConv(4 => 2, relu), + (:B, :to, :A) => GCNConv(4 => 2, relu)); + y = layers(g, x); + @test size(y.A) == (2,2) && size(y.B) == (2,3) + end end