Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add geometric vector perceptrons (GVPs) #4

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Open
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ version = "1.0.0-DEV"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
Flux = "0.14"
Expand Down
164 changes: 163 additions & 1 deletion src/MessagePassingIPA.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
module MessagePassingIPA

using Flux: Flux, Dense, flatten, unsqueeze, chunk, batched_mul, batched_vec, batched_transpose, softplus
using Flux: Flux, Dense, Chain, flatten, unsqueeze, chunk, batched_mul, batched_vec, batched_transpose, softplus, sigmoid, relu
using GraphNeuralNetworks: GNNGraph, apply_edges, softmax_edge_neighbors, aggregate_neighbors
using LinearAlgebra: normalize
using Statistics: mean

# Algorithm 21 (x1: N, x2: Ca, x3: C)
function rigid_from_3points(x1::AbstractVector, x2::AbstractVector, x3::AbstractVector)
Expand Down Expand Up @@ -226,4 +227,165 @@ end

# Sum `x` over `dims` and remove the reduced singleton dimensions.
function sumdrop(x; dims)
    summed = sum(x; dims)
    return dropdims(summed; dims)
end


# Geometric vector perceptron
# ---------------------------

# Geometric vector perceptron layer (Jing et al., 2020).
#
# Fields are type-parametrized so Flux/Julia can specialize on the concrete
# weight and activation types instead of boxing them behind abstract fields
# (abstract-typed struct fields prevent specialization).
struct GeometricVectorPerceptron{A1 <: AbstractMatrix, A2 <: AbstractMatrix, D <: Dense, F1, F2, G <: Union{Dense, Nothing}}
    W_h::A1    # vin × h weight of the intermediate vector mapping
    W_μ::A2    # h × vout weight of the output vector mapping
    scalar::D  # dense layer applied to [vector norms; scalar features]
    sσ::F1     # scalar nonlinearity
    vσ::F2     # vector nonlinearity
    vgate::G   # optional vector-gating layer (`nothing` when gating is disabled)
end

# Register the layer with Flux so its trainable parameters are collected.
Flux.@functor GeometricVectorPerceptron

"""
    GeometricVectorPerceptron(
        (sin, vin) => (sout, vout),
        (sσ, vσ) = (identity, identity);
        bias = true,
        vector_gate = false
    )

Create a geometric vector perceptron layer.

This layer takes a pair of scalar and vector feature arrays that have the size
of `sin × batchsize` and `3 × vin × batchsize`, respectively, and returns a pair
of scalar and vector feature arrays that have the size of `sout × batchsize` and
`3 × vout × batchsize`, respectively. The scalar features are invariant whereas
the vector features are equivariant under any rotation and reflection.

# Arguments
- `sin`, `vin`: scalar and vector input dimensions
- `sout`, `vout`: scalar and vector output dimensions
- `sσ`, `vσ`: scalar and vector nonlinearities
- `bias`: includes a bias term iff `bias = true`
- `vector_gate`: includes vector gating iff `vector_gate = true`
- `init`: weight initialization function (defaults to `Flux.glorot_uniform`)

# References
- Jing, Bowen, et al. "Learning from protein structure with geometric vector perceptrons." arXiv preprint arXiv:2009.01411 (2020).
- Jing, Bowen, et al. "Equivariant graph neural networks for 3d macromolecular structure." arXiv preprint arXiv:2106.03843 (2021).
"""
function GeometricVectorPerceptron(
    ((sin, vin), (sout, vout)),
    (sσ, vσ) = (identity, identity);
    bias::Bool = true,
    vector_gate::Bool = false,
    init = Flux.glorot_uniform
)
    h = max(vin, vout)  # intermediate dimension for vector mapping
    W_h = init(vin, h)
    W_μ = init(h, vout)
    scalar = Dense(sin + h => sout; bias, init)
    # Build the gate in a single expression so `vgate` never changes type
    # (avoids the type-unstable `vgate = nothing; if … end` pattern).
    vgate = vector_gate ? Dense(sout => vout, sigmoid; init) : nothing
    GeometricVectorPerceptron(W_h, W_μ, scalar, sσ, vσ, vgate)
end

# Apply the perceptron to scalar features `s` (sin × batch) and vector
# features `V` (3 × vin × batch); returns the transformed pair (s′, V′).
function (gvp::GeometricVectorPerceptron)(s::AbstractArray{T, 2}, V::AbstractArray{T, 3}) where T
    @assert size(V, 1) == 3
    # Intermediate and output vector representations.
    Vh = batched_mul(V, gvp.W_h)
    Vμ = batched_mul(Vh, gvp.W_μ)
    # Scalar path: concatenate the intermediate vector norms with the scalars.
    m = gvp.scalar(cat(norm1drop(Vh), s, dims = 1))
    sout = gvp.sσ.(m)
    # Vector path: gate by the scalar output when vector gating is enabled,
    # otherwise scale each vector channel by a nonlinearity of its norm.
    vout = if gvp.vgate === nothing
        gvp.vσ.(unsqueeze(norm1drop(Vμ), dims = 1)) .* Vμ
    else
        unsqueeze(gvp.vgate(gvp.vσ.(m)), dims = 1) .* Vμ
    end
    sout, vout
end

# This makes chaining by Flux's Chain easier.
# Accept a (scalar, vector) tuple so the layer composes with Flux's Chain.
function (gvp::GeometricVectorPerceptron)((s, V)::Tuple{AbstractArray{T, 2}, AbstractArray{T, 3}}) where T
    return gvp(s, V)
end

# Graph neural network built from a stack of geometric vector perceptrons.
# The field is type-parametrized: a bare `::Chain` field is non-concrete
# (Chain is parametric), which would prevent specialization.
struct GeometricVectorPerceptronGNN{C <: Chain}
    gvpstack::C  # GVP layers applied to each edge's concatenated features
end

# Register the GNN with Flux so the inner GVP stack's parameters are collected.
Flux.@functor GeometricVectorPerceptronGNN

"""
    GeometricVectorPerceptronGNN(
        (sn, vn),
        (se, ve),
        (sσ, vσ) = (relu, relu);
        vector_gate = false,
        n_intermediate_layers = 1,
    )

Create a graph neural network with geometric vector perceptrons.

This layer first concatenates the node and the edge features and then propagates
them over the graph. It returns a pair of scalar and vector feature arrays that
have the same size as the input node features.

# Arguments
- `sn`, `vn`: scalar and vector dimensions of node features
- `se`, `ve`: scalar and vector dimensions of edge features
- `sσ`, `vσ`: scalar and vector nonlinearities
- `vector_gate`: includes vector gating iff `vector_gate = true`
- `n_intermediate_layers`: number of intermediate layers between the input and the output geometric vector perceptrons
"""
function GeometricVectorPerceptronGNN(
    (sn, vn)::Tuple{Integer, Integer},
    (se, ve)::Tuple{Integer, Integer},
    (sσ, vσ)::Tuple{Function, Function} = (relu, relu);
    vector_gate::Bool = false,
    n_intermediate_layers::Integer = 1,
)
    gvpstack = Chain(
        # input layer: each message concatenates destination, source and edge
        # features, hence the 2sn + se scalar and 2vn + ve vector input dims
        GeometricVectorPerceptron((2sn + se, 2vn + ve) => (sn, vn), (sσ, vσ); vector_gate),
        # intermediate layers
        [
            GeometricVectorPerceptron((sn, vn) => (sn, vn), (sσ, vσ); vector_gate)
            for _ in 1:n_intermediate_layers
        ]...,
        # output layer (identity nonlinearities)
        GeometricVectorPerceptron((sn, vn) => (sn, vn)),
    )
    GeometricVectorPerceptronGNN(gvpstack)
end

# Propagate node features (sn, vn) and edge features (se, ve) over graph `g`
# with the internal GVP stack, then mean-aggregate messages per node.
# Returns a (scalar, vector) pair sized like the input node features.
function (gnn::GeometricVectorPerceptronGNN)(
    g::GNNGraph,
    (sn, vn)::Tuple{<:AbstractArray{T, 2}, <:AbstractArray{T, 3}},
    (se, ve)::Tuple{<:AbstractArray{T, 2}, <:AbstractArray{T, 3}},
) where T
    # Each message concatenates destination, source and edge features and
    # pushes the pair through the GVP stack.
    compose(xi, xj, e) = gnn.gvpstack((
        cat(xi.s, xj.s, e.s, dims = 1),
        cat(xi.v, xj.v, e.v, dims = 2),
    ))
    nodes = (s = sn, v = vn)
    edges = (s = se, v = ve)
    msgs = apply_edges(compose, g; xi = nodes, xj = nodes, e = edges)
    # Average incoming messages for every node; returns (s, v).
    aggregate_neighbors(g, mean, msgs)
end

# RMS-style normalization for vector features: every batch element's vectors
# are divided by the root-mean-square of its per-channel L2 norms.
struct VectorNorm
    ϵ::Float32  # small constant guarding against division by zero
end

# Keyword constructor; `eps` is converted to Float32 by the default constructor.
function VectorNorm(; eps::Real = 1f-5)
    return VectorNorm(eps)
end

# Normalize `V` (3 × channels × batch). Rotation-equivariant, and scale
# invariant up to the ϵ term.
function (norm::VectorNorm)(V::AbstractArray{T, 3}) where T
    @assert size(V, 1) == 3
    squared_norms = sum(abs2, V, dims = 1)      # 1 × channels × batch
    rms = sqrt.(mean(squared_norms, dims = 2))  # 1 × 1 × batch
    return V ./ (rms .+ norm.ϵ)
end

# L2 norm along the first dimension, keeping it as a singleton.
function norm1(X)
    return sqrt.(sum(abs2, X; dims = 1))
end

# Same as `norm1` but with the reduced first dimension dropped.
function norm1drop(X)
    return dropdims(norm1(X); dims = 1)
end

end
77 changes: 76 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
using MessagePassingIPA: RigidTransformation, InvariantPointAttention, transform, inverse_transform, compose, rigid_from_3points
using MessagePassingIPA:
RigidTransformation, InvariantPointAttention,
transform, inverse_transform, compose, rigid_from_3points,
GeometricVectorPerceptron, GeometricVectorPerceptronGNN, VectorNorm
using GraphNeuralNetworks: rand_graph
using Flux: relu, batched_mul
using Rotations: RotMatrix
using Statistics: mean
using Test

@testset "MessagePassingIPA.jl" begin
Expand Down Expand Up @@ -56,4 +61,74 @@ using Test
rigid2 = RigidTransformation(rigid_from_3points(x1, x2, x3)...)
@test ipa(g, s, z, rigid1) ≈ ipa(g, s, z, rigid2)
end

@testset "GeometricVectorPerceptron" begin
# scalar and vector features
# NOTE(review): `sin`/`vin` shadow `Base.sin` inside this testset — harmless here.
n = 12
sin, sout = 8, 12
vin, vout = 10, 14
s = randn(Float32, sin, n)
V = randn(Float32, 3, vin, n)

# exercise both the plain and the vector-gated variants
for vector_gate in [false, true]
gvp = GeometricVectorPerceptron(
(sin, vin) => (sout, vout),
(relu, identity);
vector_gate
)

# check returned type and size
s′, V′ = gvp(s, V)
@test s′ isa Array{Float32, 2}
@test V′ isa Array{Float32, 3}
@test size(s′) == (sout, n)
@test size(V′) == (3, vout, n)

# check invariance and equivariance: rotating the input vectors must leave
# the scalar output unchanged and rotate the vector output identically
R = rand(RotMatrix{3, Float32})
s″, V″ = gvp(s, batched_mul(R, V))
@test s″ ≈ s′
@test V″ ≈ batched_mul(R, V′)
end
end

@testset "GeometricVectorPerceptronGNN" begin
# random graph with n nodes and m directed edges
n = 10
m = 8n
g = rand_graph(n, m)

sn, vn = 8, 12
se, ve = 10, 14
gnn = GeometricVectorPerceptronGNN((sn, vn), (se, ve))
node_embeddings = randn(Float32, sn, n), randn(Float32, 3, vn, n)
edge_embeddings = randn(Float32, se, m), randn(Float32, 3, ve, m)

# check returned type and size
results = gnn(g, node_embeddings, edge_embeddings)
@test results isa Tuple{Array{Float32, 2}, Array{Float32, 3}}
s′, v′ = results
@test size(s′) == (sn, n)
@test size(v′) == (3, vn, n)

# check invariance and equivariance: rotate all node and edge vectors and
# expect invariant scalars and equivariantly rotated vectors
R = rand(RotMatrix{3, Float32})
node_embeddings = (node_embeddings[1], batched_mul(R, node_embeddings[2]))
edge_embeddings = (edge_embeddings[1], batched_mul(R, edge_embeddings[2]))
s″, v″ = gnn(g, node_embeddings, edge_embeddings)
@test s″ ≈ s′
@test v″ ≈ batched_mul(R, v′)
end

@testset "VectorNorm" begin
norm = VectorNorm()
V = randn(Float32, 3, 5, 10)
@test norm(V) isa Array{Float32, 3}
@test size(norm(V)) == size(V)
# scale-invariant up to the small ϵ term
@test norm(V) ≈ norm(100 * V)
# after normalization the RMS of the per-channel vector norms is ≈ 1
@test all(sqrt.(mean(sum(abs2, norm(V), dims = 1), dims = 2)) .≈ 1)

# zero values: the ϵ in the denominator must prevent NaNs
V = zeros(Float32, 3, 5, 10)
@test all(!isnan, norm(V))
end
end
Loading