
Add geometric vector perceptrons (GVPs) #4

Open · wants to merge 19 commits into main
Conversation

bicycle1885
Collaborator

Here is an example implementation of a graph neural network that uses geometric vector perceptrons. It still needs to be checked carefully.

using Flux
using GraphNeuralNetworks
using MessagePassingIPA: GeometricVectorPerceptron, VectorNorm
using Statistics: mean

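# A GVP-based message-passing layer: messages come from a stack of geometric
# vector perceptrons applied to concatenated node and edge features, are
# aggregated by the mean, and update the node embeddings through a residual
# connection with dropout and normalization.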
struct GNN
    gvpstack::Chain
    sdropout::Dropout
    vdropout::Dropout
    snorm::LayerNorm
    vnorm::VectorNorm
end

Flux.@functor GNN

# s: scalar; v: vector
# n: node; e: edge
function GNN((sn, vn)::Tuple{Integer, Integer}, (se, ve)::Tuple{Integer, Integer}; dropout::Real = 0.1, σ::Function = relu)
    gvpstack = Chain(
        GeometricVectorPerceptron(sn + se => sn, vn + ve => vn, σ, σ),
        GeometricVectorPerceptron(sn => sn, vn => vn, σ, σ),
        GeometricVectorPerceptron(sn => sn, vn => vn),
    )
    sdropout = Dropout(dropout)
    vdropout = Dropout(dropout, dims = (2, 3))  # dropout whole vectors
    snorm = LayerNorm(sn)
    vnorm = VectorNorm()
    GNN(gvpstack, sdropout, vdropout, snorm, vnorm)
end

function (gnn::GNN)(g, (sn, vn), (se, ve))
    # run message passing
    function message(_, xj, e)
        s = cat(xj.s, e.s, dims = 1)
        v = cat(xj.v, e.v, dims = 2)
        gnn.gvpstack((s, v))
    end
    xj = (s = sn, v = vn)
    e = (s = se, v = ve)
    msgs = apply_edges(message, g; xj, e)
    s, v = aggregate_neighbors(g, mean, msgs)

    # update node embeddings
    sn = gnn.snorm(sn + gnn.sdropout(s))
    vn = gnn.vnorm(vn + gnn.vdropout(v))
    sn, vn
end

n = 10  # number of nodes
m = 8n  # number of edges
graph = rand_graph(n, m)
gnn = GNN((5, 8), (6, 9))
# (scalar embeddings, vector embeddings) for nodes and edges
node_embeddings = randn(Float32, 5, n), randn(Float32, 3, 8, n)
edge_embeddings = randn(Float32, 6, m), randn(Float32, 3, 9, m)

# update node embeddings
node_embeddings = gnn(graph, node_embeddings, edge_embeddings)
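
# optional sanity check (a sketch, not part of the PR): GVPs are designed so
# that scalar features are rotation-invariant and vector features are
# rotation-equivariant; this assumes GeometricVectorPerceptron and VectorNorm
# preserve that property. `rotate`, `R`, and the `*0` inputs are names
# introduced here only for illustration. Flux's Dropout in auto mode is
# inactive outside gradient calls, so both forward passes are deterministic.
using LinearAlgebra

sn0, vn0 = randn(Float32, 5, n), randn(Float32, 3, 8, n)  # fresh node features
se0, ve0 = randn(Float32, 6, m), randn(Float32, 3, 9, m)  # fresh edge features

# random proper rotation (orthogonal, det = +1)
Q, _ = qr(randn(Float32, 3, 3))
R = Matrix(Q)
R = det(R) < 0 ? -R : R

# rotate the 3D axis (first dimension) of a (3, channels, batch) array
rotate(v) = reshape(R * reshape(v, 3, :), size(v))

s1, v1 = gnn(graph, (sn0, vn0), (se0, ve0))
s2, v2 = gnn(graph, (sn0, rotate(vn0)), (se0, rotate(ve0)))
@show isapprox(s1, s2; rtol = 1f-3)          # scalars: expected invariant
@show isapprox(rotate(v1), v2; rtol = 1f-3)  # vectors: expected equivariant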

# check autodiff on GPU
using CUDA
gnn, graph, node_embeddings, edge_embeddings = gpu((gnn, graph, node_embeddings, edge_embeddings))
Flux.gradient(() -> sum(gnn(graph, node_embeddings, edge_embeddings)[1]), Flux.params(gnn))
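
# the same gradient in explicit (Optimisers.jl-style) form: a sketch assuming
# Flux >= 0.14, where the implicit `Flux.params` style above is deprecated.
# It reuses the GPU-moved objects and the dummy `sum` loss; `opt_state` is a
# name introduced here only for illustration.
opt_state = Flux.setup(Adam(1f-3), gnn)
loss, grads = Flux.withgradient(gnn) do model
    sum(model(graph, node_embeddings, edge_embeddings)[1])
end
Flux.update!(opt_state, gnn, grads[1])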
