From e4696c37f01202f7abbc76bcd4d1f14bfefe42da Mon Sep 17 00:00:00 2001 From: Kenta Sato Date: Wed, 28 Feb 2024 22:23:43 +0000 Subject: [PATCH 01/19] add GeometricVectorPerceptron --- src/MessagePassingIPA.jl | 26 ++++++++++++++++++++++++++ test/runtests.jl | 29 ++++++++++++++++++++++++++++- 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/src/MessagePassingIPA.jl b/src/MessagePassingIPA.jl index 482935b..6ad1b30 100644 --- a/src/MessagePassingIPA.jl +++ b/src/MessagePassingIPA.jl @@ -226,4 +226,30 @@ end sumdrop(x; dims) = dropdims(sum(x; dims); dims) +struct GeometricVectorPerceptron + W_h + W_μ + scalar::Dense + vσ +end + +function GeometricVectorPerceptron((sin, sout), (vin, vout), sσ::Function = identity, vσ::Function = identity; bias = true) + h = max(vin, vout) # intermediate dimension for vector mapping + W_h = randn(Float32, vin, h) + W_μ = randn(Float32, h, vout) + scalar = Dense(sin + h => sout, sσ; bias) + GeometricVectorPerceptron(W_h, W_μ, scalar, vσ) +end + +function (gvp::GeometricVectorPerceptron)(s::AbstractArray, V::AbstractArray) + V_h = batched_mul(V, gvp.W_h) + s′ = gvp.scalar(cat(norm1(V_h), s, dims = 1)) + V_μ = batched_mul(V_h, gvp.W_μ) + V′ = gvp.vσ(unsqueeze(norm1(V_μ), dims = 1)) .* V_μ + s′, V′ +end + +# L2 norm along the first dimension +norm1(X) = dropdims(sqrt.(sum(abs2.(X), dims = 1)), dims = 1) + end diff --git a/test/runtests.jl b/test/runtests.jl index 4b25282..bb4690d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,6 @@ -using MessagePassingIPA: RigidTransformation, InvariantPointAttention, transform, inverse_transform, compose, rigid_from_3points +using MessagePassingIPA: RigidTransformation, InvariantPointAttention, transform, inverse_transform, compose, rigid_from_3points, GeometricVectorPerceptron using GraphNeuralNetworks: rand_graph +using Flux: relu, batched_mul using Rotations: RotMatrix using Test @@ -56,4 +57,30 @@ using Test rigid2 = RigidTransformation(rigid_from_3points(x1, x2, x3)...) 
         @test ipa(g, s, z, rigid1) ≈ ipa(g, s, z, rigid2)
     end
+
+    @testset "GeometricVectorPerceptron" begin
+        sin, sout = 8, 12
+        vin, vout = 10, 14
+        σ = relu
+        gvp = GeometricVectorPerceptron(sin => sout, vin => vout, σ, σ)
+        n = 12
+        # scalar and vector features
+        s = randn(Float32, sin, n)
+        V = randn(Float32, 3, vin, n)
+
+        # check returned type and size
+        s′, V′ = gvp(s, V)
+        @show typeof(s′)
+        @show typeof(V′)
+        @test s′ isa Array{Float32, 2}
+        @test V′ isa Array{Float32, 3}
+        @test size(s′) == (sout, n)
+        @test size(V′) == (3, vout, n)
+
+        # check invariance and equivariance
+        R = rand(RotMatrix{3, Float32})
+        s″, V″ = gvp(s, batched_mul(R, V))
+        @test s″ ≈ s′
+        @test V″ ≈ batched_mul(R, V′)
+    end
 end

From 6689dd4ade5bcc99e032e37ea0291951cbf14277 Mon Sep 17 00:00:00 2001
From: Kenta Sato
Date: Wed, 28 Feb 2024 22:25:56 +0000
Subject: [PATCH 02/19] functor

---
 src/MessagePassingIPA.jl | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/MessagePassingIPA.jl b/src/MessagePassingIPA.jl
index 6ad1b30..1bf587f 100644
--- a/src/MessagePassingIPA.jl
+++ b/src/MessagePassingIPA.jl
@@ -233,6 +233,8 @@ struct GeometricVectorPerceptron
     vσ
 end
 
+Flux.@functor GeometricVectorPerceptron
+
 function GeometricVectorPerceptron((sin, sout), (vin, vout), sσ::Function = identity, vσ::Function = identity; bias = true)

From da209b4e07ddf0a5b4222e823d78de185ca7bce0 Mon Sep 17 00:00:00 2001
From: Kenta Sato
Date: Thu, 29 Feb 2024 13:47:26 +0000
Subject: [PATCH 03/19] add constructor

---
 src/MessagePassingIPA.jl | 3 +++
 test/runtests.jl         | 4 ++++
 2 files changed, 7 insertions(+)

diff --git a/src/MessagePassingIPA.jl b/src/MessagePassingIPA.jl
index 1bf587f..1d86d3c 100644
--- a/src/MessagePassingIPA.jl
+++ b/src/MessagePassingIPA.jl
@@ -243,6 +243,9 @@ function GeometricVectorPerceptron((sin, sout), (vin, vout), sσ::Function = ide
     GeometricVectorPerceptron(W_h, W_μ, scalar, vσ)
 end
 
+GeometricVectorPerceptron(sin::Integer, vin::Integer, σ::Function = identity; bias = true) =
+    GeometricVectorPerceptron(sin => sin, vin => vin, σ, σ; bias)
+
 function (gvp::GeometricVectorPerceptron)(s::AbstractArray, V::AbstractArray)
diff --git a/test/runtests.jl b/test/runtests.jl
index bb4690d..091f2c2 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -82,5 +82,9 @@ using Test
         s″, V″ = gvp(s, batched_mul(R, V))
         @test s″ ≈ s′
         @test V″ ≈ batched_mul(R, V′)
+
+        # utility constructor where #inputs == #outputs
+        gvp = GeometricVectorPerceptron(12, 24, σ)
+        @test gvp isa GeometricVectorPerceptron
     end
 end

From f7b75e16f75b113a99d898299650be4607a1c602 Mon Sep 17 00:00:00 2001
From: Kenta Sato
Date: Thu, 29 Feb 2024 15:14:46 +0000
Subject: [PATCH 04/19] add docs

---
 src/MessagePassingIPA.jl | 38 +++++++++++++++++++++++++++++++++++++-
 1 file changed, 37 insertions(+), 1 deletion(-)

diff --git a/src/MessagePassingIPA.jl b/src/MessagePassingIPA.jl
index 1d86d3c..59baeb9 100644
--- a/src/MessagePassingIPA.jl
+++ b/src/MessagePassingIPA.jl
@@ -226,6 +226,10 @@ end
 
 sumdrop(x; dims) = dropdims(sum(x; dims); dims)
 
+
+# Geometric vector perceptron
+# ---------------------------
+
 struct GeometricVectorPerceptron
     W_h
     W_μ
@@ -235,6 +239,32 @@ end
 
 Flux.@functor GeometricVectorPerceptron
 
+"""
+    GeometricVectorPerceptron(
+        sin => sout,
+        vin => vout,
+        sσ = identity,
+        vσ = identity;
+        bias = true
+    )
+
+Create a geometric vector perceptron layer.
+
+This layer takes a pair of scalar and vector feature arrays that have the size
+of `sin × batchsize` and `3 × vin × batchsize`, respectively, and returns a pair
+of scalar and vector feature arrays that have the size of `sout × batchsize` and
+`3 × vout × batchsize`, respectively. The scalar features are invariant whereas
+the vector features are equivariant under any rotation and reflection.
+
+# Arguments
+- `sin => sout`: scalar input and output dimensions
+- `vin => vout`: vector input and output dimensions
+- `sσ`: scalar nonlinearity
+- `vσ`: vector nonlinearity
+- `bias`: includes a bias term iff `bias = true``
+
+Jing, Bowen, et al. "Learning from protein structure with geometric vector perceptrons." arXiv preprint arXiv:2009.01411 (2020).
+"""
 function GeometricVectorPerceptron((sin, sout), (vin, vout), sσ::Function = identity, vσ::Function = identity; bias = true)
     h = max(vin, vout) # intermediate dimension for vector mapping
     W_h = randn(Float32, vin, h)
@@ -246,7 +276,10 @@ end
 GeometricVectorPerceptron(sin::Integer, vin::Integer, σ::Function = identity; bias = true) =
     GeometricVectorPerceptron(sin => sin, vin => vin, σ, σ; bias)
 
-function (gvp::GeometricVectorPerceptron)(s::AbstractArray, V::AbstractArray)
+# s: scalar features (sin × batch)
+# V: vector features (3 × vin × batch)
+function (gvp::GeometricVectorPerceptron)(s::AbstractArray{T, 2}, V::AbstractArray{T, 3}) where T
+    @assert size(V, 1) == 3
     V_h = batched_mul(V, gvp.W_h)
     s′ = gvp.scalar(cat(norm1(V_h), s, dims = 1))
     V_μ = batched_mul(V_h, gvp.W_μ)
@@ -254,6 +287,9 @@ end
     s′, V′
 end
 
+# This makes chaining by Flux.Chain easier.
+(gvp::GeometricVectorPerceptron)((s, V)::Tuple{AbstractArray{T, 2}, AbstractArray{T, 3}}) where T = gvp(s, V)
+
 # L2 norm along the first dimension
 norm1(X) = dropdims(sqrt.(sum(abs2.(X), dims = 1)), dims = 1)
 
 end

From c1319f59f964644de1d63e9bf87748232c8621d9 Mon Sep 17 00:00:00 2001
From: Kenta Sato
Date: Thu, 29 Feb 2024 15:17:59 +0000
Subject: [PATCH 05/19] --amend

---
 src/MessagePassingIPA.jl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/MessagePassingIPA.jl b/src/MessagePassingIPA.jl
index 59baeb9..9de9804 100644
--- a/src/MessagePassingIPA.jl
+++ b/src/MessagePassingIPA.jl
@@ -287,10 +287,10 @@ end
     s′, V′
 end
 
-# This makes chaining by Flux.Chain easier.
+# This makes chaining by Flux's Chain easier.
 (gvp::GeometricVectorPerceptron)((s, V)::Tuple{AbstractArray{T, 2}, AbstractArray{T, 3}}) where T = gvp(s, V)
 
 # L2 norm along the first dimension
-norm1(X) = dropdims(sqrt.(sum(abs2.(X), dims = 1)), dims = 1)
+norm1(X) = dropdims(sqrt.(sum(abs2, X, dims = 1)), dims = 1)
 
 end

From a7fbada20bd2014ca33d333a5f8491efa959bec5 Mon Sep 17 00:00:00 2001
From: Kenta Sato
Date: Fri, 1 Mar 2024 19:56:16 +0000
Subject: [PATCH 06/19] type constraint

---
 src/MessagePassingIPA.jl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/MessagePassingIPA.jl b/src/MessagePassingIPA.jl
index 9de9804..57c65c7 100644
--- a/src/MessagePassingIPA.jl
+++ b/src/MessagePassingIPA.jl
@@ -265,7 +265,7 @@ the vector features are equivariant under any rotation and reflection.
 
 Jing, Bowen, et al. "Learning from protein structure with geometric vector perceptrons." arXiv preprint arXiv:2009.01411 (2020).
""" -function GeometricVectorPerceptron((sin, sout), (vin, vout), sσ::Function = identity, vσ::Function = identity; bias = true) +function GeometricVectorPerceptron((sin, sout), (vin, vout), sσ::Function = identity, vσ::Function = identity; bias::Bool = true) h = max(vin, vout) # intermediate dimension for vector mapping W_h = randn(Float32, vin, h) W_μ = randn(Float32, h, vout) @@ -273,7 +273,7 @@ function GeometricVectorPerceptron((sin, sout), (vin, vout), sσ::Function = ide GeometricVectorPerceptron(W_h, W_μ, scalar, vσ) end -GeometricVectorPerceptron(sin::Integer, vin::Integer, σ::Function = identity; bias = true) = +GeometricVectorPerceptron(sin::Integer, vin::Integer, σ::Function = identity; bias::Bool = true) = GeometricVectorPerceptron(sin => sin, vin => vin, σ, σ; bias) # s: scalar features (sin × batch) From bcf9088a421e696dcaf3e4e47743bfda65851521 Mon Sep 17 00:00:00 2001 From: Kenta Sato Date: Sat, 2 Mar 2024 20:19:22 +0000 Subject: [PATCH 07/19] add VectorNorm layer --- Project.toml | 1 + src/MessagePassingIPA.jl | 18 +++++++++++++++--- test/runtests.jl | 14 +++++++++++--- 3 files changed, 27 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index 80a87dc..e45379f 100644 --- a/Project.toml +++ b/Project.toml @@ -7,6 +7,7 @@ version = "1.0.0-DEV" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] Flux = "0.14" diff --git a/src/MessagePassingIPA.jl b/src/MessagePassingIPA.jl index 57c65c7..8b5dd0f 100644 --- a/src/MessagePassingIPA.jl +++ b/src/MessagePassingIPA.jl @@ -3,6 +3,7 @@ module MessagePassingIPA using Flux: Flux, Dense, flatten, unsqueeze, chunk, batched_mul, batched_vec, batched_transpose, softplus using GraphNeuralNetworks: GNNGraph, apply_edges, softmax_edge_neighbors, aggregate_neighbors using LinearAlgebra: normalize +using Statistics: mean # Algorithm 21 (x1: N, x2: Ca, x3: C) function rigid_from_3points(x1::AbstractVector, x2::AbstractVector, x3::AbstractVector) @@ -281,16 +282,27 @@ GeometricVectorPerceptron(sin::Integer, vin::Integer, σ::Function = identity; b function (gvp::GeometricVectorPerceptron)(s::AbstractArray{T, 2}, V::AbstractArray{T, 3}) where T @assert size(V, 1) == 3 V_h = batched_mul(V, gvp.W_h) - s′ = gvp.scalar(cat(norm1(V_h), s, dims = 1)) + s′ = gvp.scalar(cat(norm1drop(V_h), s, dims = 1)) V_μ = batched_mul(V_h, gvp.W_μ) - V′ = gvp.vσ(unsqueeze(norm1(V_μ), dims = 1)) .* V_μ + V′ = gvp.vσ(unsqueeze(norm1drop(V_μ), dims = 1)) .* V_μ s′, V′ end # This makes chaining by Flux's Chain easier. 
(gvp::GeometricVectorPerceptron)((s, V)::Tuple{AbstractArray{T, 2}, AbstractArray{T, 3}}) where T = gvp(s, V) +# Normalization for vector features +struct VectorNorm + #eps::Float32 +end + +function (norm::VectorNorm)(V::AbstractArray{T, 3}) where T + @assert size(V, 1) == 3 + V ./ sqrt.(mean(sum(abs2, V, dims = 1), dims = 2)) +end + # L2 norm along the first dimension -norm1(X) = dropdims(sqrt.(sum(abs2, X, dims = 1)), dims = 1) +norm1(X) = sqrt.(sum(abs2, X, dims = 1)) +norm1drop(X) = dropdims(norm1(X), dims = 1) end diff --git a/test/runtests.jl b/test/runtests.jl index 091f2c2..a156b8b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,7 +1,8 @@ -using MessagePassingIPA: RigidTransformation, InvariantPointAttention, transform, inverse_transform, compose, rigid_from_3points, GeometricVectorPerceptron +using MessagePassingIPA: RigidTransformation, InvariantPointAttention, transform, inverse_transform, compose, rigid_from_3points, GeometricVectorPerceptron, VectorNorm using GraphNeuralNetworks: rand_graph using Flux: relu, batched_mul using Rotations: RotMatrix +using Statistics: mean using Test @testset "MessagePassingIPA.jl" begin @@ -70,8 +71,6 @@ using Test # check returned type and size s′, V′ = gvp(s, V) - @show typeof(s′) - @show typeof(V′) @test s′ isa Array{Float32, 2} @test V′ isa Array{Float32, 3} @test size(s′) == (sout, n) @@ -87,4 +86,13 @@ using Test gvp = GeometricVectorPerceptron(12, 24, σ) @test gvp isa GeometricVectorPerceptron end + + @testset "VectorNorm" begin + norm = VectorNorm() + V = randn(Float32, 3, 8, 128) + @test norm(V) isa Array{Float32, 3} + @test size(norm(V)) == size(V) + @test norm(V) ≈ norm(100 * V) + @test all(sqrt.(mean(sum(abs2, norm(V), dims = 1), dims = 2)) .≈ 1) + end end From 7bb4af66bfa1ffd5511d40b8afa08236d9644bb6 Mon Sep 17 00:00:00 2001 From: Kenta Sato Date: Sat, 2 Mar 2024 21:45:39 +0000 Subject: [PATCH 08/19] add epsilon --- src/MessagePassingIPA.jl | 6 ++++-- test/runtests.jl | 6 +++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/MessagePassingIPA.jl b/src/MessagePassingIPA.jl index 8b5dd0f..ec91ecd 100644 --- a/src/MessagePassingIPA.jl +++ b/src/MessagePassingIPA.jl @@ -293,12 +293,14 @@ end # Normalization for vector features struct VectorNorm - #eps::Float32 + ϵ::Float32 end +VectorNorm(; eps::Real = 1f-5) = VectorNorm(eps) + function (norm::VectorNorm)(V::AbstractArray{T, 3}) where T @assert size(V, 1) == 3 - V ./ sqrt.(mean(sum(abs2, V, dims = 1), dims = 2)) + V ./ (sqrt.(mean(sum(abs2, V, dims = 1), dims = 2)) .+ norm.ϵ) end # L2 norm along the first dimension diff --git a/test/runtests.jl b/test/runtests.jl index a156b8b..4b428a8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -89,10 +89,14 @@ using Test @testset "VectorNorm" begin norm = VectorNorm() - V = randn(Float32, 3, 8, 128) + V = randn(Float32, 3, 5, 10) @test norm(V) isa Array{Float32, 3} @test size(norm(V)) == size(V) @test norm(V) ≈ norm(100 * V) @test all(sqrt.(mean(sum(abs2, norm(V), dims = 1), dims = 2)) .≈ 1) + + # zero values + V = zeros(Float32, 3, 5, 10) + @test all(!isnan, norm(V)) end end From 02889d5e6e9f8991dcd13a4b3c17f2144e96d309 Mon Sep 17 00:00:00 2001 From: Kenta Sato Date: Mon, 4 Mar 2024 01:00:22 +0000 Subject: [PATCH 09/19] add more explicit type constraints --- src/MessagePassingIPA.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/MessagePassingIPA.jl b/src/MessagePassingIPA.jl index ec91ecd..7667a08 100644 --- a/src/MessagePassingIPA.jl +++ b/src/MessagePassingIPA.jl @@ -232,10 
+232,10 @@ sumdrop(x; dims) = dropdims(sum(x; dims); dims)
 # ---------------------------
 
 struct GeometricVectorPerceptron
-    W_h
-    W_μ
+    W_h::AbstractMatrix
+    W_μ::AbstractMatrix
     scalar::Dense
-    vσ
+    vσ::Function
 end
 
 Flux.@functor GeometricVectorPerceptron

From 7c7a4851b7534859735e0e8a3a0e2e162676 Mon Sep 17 00:00:00 2001
From: Kenta Sato
Date: Mon, 4 Mar 2024 15:40:24 +0000
Subject: [PATCH 10/19] use glorot_uniform weight initialization

---
 src/MessagePassingIPA.jl | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/src/MessagePassingIPA.jl b/src/MessagePassingIPA.jl
index 7667a08..c5fedbd 100644
--- a/src/MessagePassingIPA.jl
+++ b/src/MessagePassingIPA.jl
@@ -266,11 +266,18 @@ the vector features are equivariant under any rotation and reflection.
 
 Jing, Bowen, et al. "Learning from protein structure with geometric vector perceptrons." arXiv preprint arXiv:2009.01411 (2020).
 """
-function GeometricVectorPerceptron((sin, sout), (vin, vout), sσ::Function = identity, vσ::Function = identity; bias::Bool = true)
+function GeometricVectorPerceptron(
+    (sin, sout),
+    (vin, vout),
+    sσ::Function = identity,
+    vσ::Function = identity;
+    bias::Bool = true,
+    init = Flux.glorot_uniform
+)
     h = max(vin, vout) # intermediate dimension for vector mapping
-    W_h = randn(Float32, vin, h)
-    W_μ = randn(Float32, h, vout)
-    scalar = Dense(sin + h => sout, sσ; bias)
+    W_h = init(vin, h)
+    W_μ = init(h, vout)
+    scalar = Dense(sin + h => sout, sσ; bias, init)
     GeometricVectorPerceptron(W_h, W_μ, scalar, vσ)
 end

From 292388629f3721cff09c16d655553ea2558fae93 Mon Sep 17 00:00:00 2001
From: Kenta Sato
Date: Tue, 5 Mar 2024 10:20:13 +0000
Subject: [PATCH 11/19] fix

---
 src/MessagePassingIPA.jl | 24 ++++++++----------------
 test/runtests.jl         |  7 +------
 2 files changed, 9 insertions(+), 22 deletions(-)

diff --git a/src/MessagePassingIPA.jl b/src/MessagePassingIPA.jl
index c5fedbd..85cb5e3 100644
--- a/src/MessagePassingIPA.jl
+++ b/src/MessagePassingIPA.jl
@@ -242,10 +242,8 @@ Flux.@functor GeometricVectorPerceptron
 
 """
     GeometricVectorPerceptron(
-        sin => sout,
-        vin => vout,
-        sσ = identity,
-        vσ = identity;
+        (sin, vin) => (sout, vout),
+        (sσ, vσ) = (identity, identity);
         bias = true
     )
 
@@ -258,16 +256,16 @@ of scalar and vector feature arrays that have the size of `sout × batchsize` an
 the vector features are equivariant under any rotation and reflection.
 
 # Arguments
-- `sin => sout`: scalar input and output dimensions
-- `vin => vout`: vector input and output dimensions
-- `sσ`: scalar nonlinearity
-- `vσ`: vector nonlinearity
-- `bias`: includes a bias term iff `bias = true``
+- `sin`, `vin`: scalar and vector input dimensions
+- `sout`, `vout`: scalar and vector output dimensions
+- `sσ`, `vσ`: scalar and vector nonlinearities
+- `bias`: includes a bias term iff `bias = true`
 
 Jing, Bowen, et al. "Learning from protein structure with geometric vector perceptrons." arXiv preprint arXiv:2009.01411 (2020).
""" function GeometricVectorPerceptron( - (sin, sout), - (vin, vout), - sσ::Function = identity, - vσ::Function = identity; + ((sin, vin), (sout, vout)), + (sσ, vσ) = (identity, identity); bias::Bool = true, init = Flux.glorot_uniform ) @@ -281,9 +276,6 @@ function GeometricVectorPerceptron( GeometricVectorPerceptron(W_h, W_μ, scalar, vσ) end -GeometricVectorPerceptron(sin::Integer, vin::Integer, σ::Function = identity; bias::Bool = true) = - GeometricVectorPerceptron(sin => sin, vin => vin, σ, σ; bias) - # s: scalar features (sin × batch) # V: vector feautres (3 × vin × batch) function (gvp::GeometricVectorPerceptron)(s::AbstractArray{T, 2}, V::AbstractArray{T, 3}) where T diff --git a/test/runtests.jl b/test/runtests.jl index 4b428a8..eabdb8f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -62,8 +62,7 @@ using Test @testset "GeometricVectorPerceptron" begin sin, sout = 8, 12 vin, vout = 10, 14 - σ = relu - gvp = GeometricVectorPerceptron(sin => sout, vin => vout, σ, σ) + gvp = GeometricVectorPerceptron((sin, vin) => (sout, vout), (relu, identity)) n = 12 # scalar and vector feautres s = randn(Float32, sin, n) @@ -81,10 +80,6 @@ using Test s″, V″ = gvp(s, batched_mul(R, V)) @test s″ ≈ s′ @test V″ ≈ batched_mul(R, V′) - - # utility constructor where #inputs == #outputs - gvp = GeometricVectorPerceptron(12, 24, σ) - @test gvp isa GeometricVectorPerceptron end @testset "VectorNorm" begin From a506ae8366bf46311dd9c8267fb057bab06a2a97 Mon Sep 17 00:00:00 2001 From: Kenta Sato Date: Tue, 5 Mar 2024 17:03:19 +0000 Subject: [PATCH 12/19] add vector gating --- src/MessagePassingIPA.jl | 22 +++++++++++++++++----- test/runtests.jl | 35 +++++++++++++++++++++-------------- 2 files changed, 38 insertions(+), 19 deletions(-) diff --git a/src/MessagePassingIPA.jl b/src/MessagePassingIPA.jl index 85cb5e3..b75ad3d 100644 --- a/src/MessagePassingIPA.jl +++ b/src/MessagePassingIPA.jl @@ -1,6 +1,6 @@ module MessagePassingIPA -using Flux: Flux, Dense, flatten, unsqueeze, chunk, batched_mul, batched_vec, batched_transpose, softplus +using Flux: Flux, Dense, flatten, unsqueeze, chunk, batched_mul, batched_vec, batched_transpose, softplus, sigmoid using GraphNeuralNetworks: GNNGraph, apply_edges, softmax_edge_neighbors, aggregate_neighbors using LinearAlgebra: normalize using Statistics: mean @@ -235,7 +235,9 @@ struct GeometricVectorPerceptron W_h::AbstractMatrix W_μ::AbstractMatrix scalar::Dense + sσ::Function vσ::Function + vgate::Union{Dense, Nothing} end Flux.@functor GeometricVectorPerceptron @@ -267,13 +269,18 @@ function GeometricVectorPerceptron( ((sin, vin), (sout, vout)), (sσ, vσ) = (identity, identity); bias::Bool = true, + vector_gate::Bool = false, init = Flux.glorot_uniform ) h = max(vin, vout) # intermediate dimension for vector mapping W_h = init(vin, h) W_μ = init(h, vout) - scalar = Dense(sin + h => sout, sσ; bias, init) - GeometricVectorPerceptron(W_h, W_μ, scalar, vσ) + scalar = Dense(sin + h => sout; bias, init) + vgate = nothing + if vector_gate + vgate = Dense(sout => vout, sigmoid; init) + end + GeometricVectorPerceptron(W_h, W_μ, scalar, sσ, vσ, vgate) end # s: scalar features (sin × batch) @@ -281,9 +288,14 @@ end function (gvp::GeometricVectorPerceptron)(s::AbstractArray{T, 2}, V::AbstractArray{T, 3}) where T @assert size(V, 1) == 3 V_h = batched_mul(V, gvp.W_h) - s′ = gvp.scalar(cat(norm1drop(V_h), s, dims = 1)) + s_m = gvp.scalar(cat(norm1drop(V_h), s, dims = 1)) V_μ = batched_mul(V_h, gvp.W_μ) - V′ = gvp.vσ(unsqueeze(norm1drop(V_μ), dims = 1)) .* V_μ + s′ = gvp.sσ.(s_m) + if 
+    if gvp.vgate === nothing
+        V′ = gvp.vσ.(unsqueeze(norm1drop(V_μ), dims = 1)) .* V_μ
+    else
+        V′ = unsqueeze(gvp.vgate(gvp.vσ.(s_m)), dims = 1) .* V_μ
+    end
     s′, V′
 end
diff --git a/test/runtests.jl b/test/runtests.jl
index eabdb8f..3692ca1 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -60,26 +60,33 @@ using Test
     end
 
     @testset "GeometricVectorPerceptron" begin
+        # scalar and vector features
+        n = 12
         sin, sout = 8, 12
         vin, vout = 10, 14
-        gvp = GeometricVectorPerceptron((sin, vin) => (sout, vout), (relu, identity))
-        n = 12
-        # scalar and vector features
         s = randn(Float32, sin, n)
         V = randn(Float32, 3, vin, n)
 
+        for vector_gate in [false, true]
+            gvp = GeometricVectorPerceptron(
+                (sin, vin) => (sout, vout),
+                (relu, identity);
+                vector_gate
+            )
+
+            # check returned type and size
+            s′, V′ = gvp(s, V)
+            @test s′ isa Array{Float32, 2}
+            @test V′ isa Array{Float32, 3}
+            @test size(s′) == (sout, n)
+            @test size(V′) == (3, vout, n)
-        # check returned type and size
-        s′, V′ = gvp(s, V)
-        @test s′ isa Array{Float32, 2}
-        @test V′ isa Array{Float32, 3}
-        @test size(s′) == (sout, n)
-        @test size(V′) == (3, vout, n)
+            # check invariance and equivariance
+            R = rand(RotMatrix{3, Float32})
+            s″, V″ = gvp(s, batched_mul(R, V))
+            @test s″ ≈ s′
+            @test V″ ≈ batched_mul(R, V′)
+        end
-        # check invariance and equivariance
-        R = rand(RotMatrix{3, Float32})
-        s″, V″ = gvp(s, batched_mul(R, V))
-        @test s″ ≈ s′
-        @test V″ ≈ batched_mul(R, V′)
     end
 end

From e448973c8df5008418d314744ec5ce4e401ab665 Mon Sep 17 00:00:00 2001
From: Kenta Sato
Date: Tue, 5 Mar 2024 17:07:12 +0000
Subject: [PATCH 13/19] --amend

---
 src/MessagePassingIPA.jl | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/MessagePassingIPA.jl b/src/MessagePassingIPA.jl
index b75ad3d..e3117e9 100644
--- a/src/MessagePassingIPA.jl
+++ b/src/MessagePassingIPA.jl
@@ -246,7 +246,8 @@ Flux.@functor GeometricVectorPerceptron
     GeometricVectorPerceptron(
         (sin, vin) => (sout, vout),
         (sσ, vσ) = (identity, identity);
-        bias = true
+        bias = true,
+        vector_gate = false
     )
 
 Create a geometric vector perceptron layer.
@@ -262,8 +263,11 @@ the vector features are equivariant under any rotation and reflection.
 - `sout`, `vout`: scalar and vector output dimensions
 - `sσ`, `vσ`: scalar and vector nonlinearities
 - `bias`: includes a bias term iff `bias = true`
+- `vector_gate`: includes vector gating iff `vector_gate = true`
 
-Jing, Bowen, et al. "Learning from protein structure with geometric vector perceptrons." arXiv preprint arXiv:2009.01411 (2020).
+# References
+- Jing, Bowen, et al. "Learning from protein structure with geometric vector perceptrons." arXiv preprint arXiv:2009.01411 (2020).
+- Jing, Bowen, et al. "Equivariant graph neural networks for 3d macromolecular structure." arXiv preprint arXiv:2106.03843 (2021).
""" function GeometricVectorPerceptron( ((sin, vin), (sout, vout)), From 979c7a07ccb7963e6dec46ee7bf0410673e8c6fb Mon Sep 17 00:00:00 2001 From: Kenta Sato Date: Wed, 6 Mar 2024 09:17:12 +0000 Subject: [PATCH 14/19] add GeometricVectorPerceptronGNN --- src/MessagePassingIPA.jl | 66 +++++++++++++++++++++++++++++++++++++++- test/runtests.jl | 23 +++++++++++++- 2 files changed, 87 insertions(+), 2 deletions(-) diff --git a/src/MessagePassingIPA.jl b/src/MessagePassingIPA.jl index e3117e9..b51fb18 100644 --- a/src/MessagePassingIPA.jl +++ b/src/MessagePassingIPA.jl @@ -1,6 +1,6 @@ module MessagePassingIPA -using Flux: Flux, Dense, flatten, unsqueeze, chunk, batched_mul, batched_vec, batched_transpose, softplus, sigmoid +using Flux: Flux, Dense, Chain, flatten, unsqueeze, chunk, batched_mul, batched_vec, batched_transpose, softplus, sigmoid, relu using GraphNeuralNetworks: GNNGraph, apply_edges, softmax_edge_neighbors, aggregate_neighbors using LinearAlgebra: normalize using Statistics: mean @@ -306,6 +306,70 @@ end # This makes chaining by Flux's Chain easier. (gvp::GeometricVectorPerceptron)((s, V)::Tuple{AbstractArray{T, 2}, AbstractArray{T, 3}}) where T = gvp(s, V) +struct GeometricVectorPerceptronGNN + gvpstack::Chain +end + +""" + GeometricVectorPerceptronGNN( + (sn, vn), + (se, ve), + (sσ, vσ) = (relu, relu); + n_hidden_layers = 1, + vector_gate = false, + ) + +Create a graph neural network with geometric vector perceptrons. + +This layer first concatenates the node and the edge features and then propagates +them over the graph. It returns a pair of scalr and vector feature arrays that +have the same size of input node features. + +# Arguments +- `sn`, `vn`: scalar and vector dimensions of node features +- `se`, `ve`: scalar and vector dimensions of edge features +- `sσ`, `sσ`: scalar and vector nonlinearlities +- `vector_gate`: includes vector gating iff `vector_gate = true` +- `n_intermediate_layers`: number of intermediate layers between the input and the output geometric vector perceptrons +""" +function GeometricVectorPerceptronGNN( + (sn, vn)::Tuple{Integer, Integer}, + (se, ve)::Tuple{Integer, Integer}, + (sσ, vσ)::Tuple{Function, Function} = (relu, relu); + vector_gate::Bool = false, + n_intermediate_layers::Integer = 1, +) + gvpstack = Chain( + # input layer + GeometricVectorPerceptron((sn + se, vn + ve) => (sn, vn), (sσ, vσ); vector_gate), + # intermediate layers + [ + GeometricVectorPerceptron((sn, vn) => (sn, vn), (sσ, vσ); vector_gate) + for _ in 1:n_intermediate_layers + ]..., + # output layers + GeometricVectorPerceptron((sn, vn) => (sn, vn)), + ) + GeometricVectorPerceptronGNN(gvpstack) +end + +function (gnn::GeometricVectorPerceptronGNN)( + g::GNNGraph, + (sn, vn)::Tuple{<:AbstractArray{T, 2}, <:AbstractArray{T, 3}}, + (se, ve)::Tuple{<:AbstractArray{T, 2}, <:AbstractArray{T, 3}}, +) where T + # run message passing + function message(_, xj, e) + s = cat(xj.s, e.s, dims = 1) + v = cat(xj.v, e.v, dims = 2) + gnn.gvpstack((s, v)) + end + xj = (s = sn, v = vn) + e = (s = se, v = ve) + msgs = apply_edges(message, g; xj, e) + aggregate_neighbors(g, mean, msgs) # return (s, v) +end + # Normalization for vector features struct VectorNorm ϵ::Float32 diff --git a/test/runtests.jl b/test/runtests.jl index 3692ca1..df603ce 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,4 +1,7 @@ -using MessagePassingIPA: RigidTransformation, InvariantPointAttention, transform, inverse_transform, compose, rigid_from_3points, GeometricVectorPerceptron, VectorNorm +using MessagePassingIPA: 
+    RigidTransformation, InvariantPointAttention,
+    transform, inverse_transform, compose, rigid_from_3points,
+    GeometricVectorPerceptron, GeometricVectorPerceptronGNN, VectorNorm
 using GraphNeuralNetworks: rand_graph
 using Flux: relu, batched_mul
 using Rotations: RotMatrix
 using Statistics: mean
 using Test
@@ -89,6 +92,24 @@ using Test
         end
     end
 
+    @testset "GeometricVectorPerceptronGNN" begin
+        n = 10
+        m = 8n
+        g = rand_graph(n, m)
+
+        sn, vn = 8, 12
+        se, ve = 10, 14
+        gnn = GeometricVectorPerceptronGNN((sn, vn), (se, ve))
+        node_embeddings = randn(Float32, sn, n), randn(Float32, 3, vn, n)
+        edge_embeddings = randn(Float32, se, m), randn(Float32, 3, ve, m)
+        node_embeddings = gnn(g, node_embeddings, edge_embeddings)
+        @test node_embeddings isa Tuple{Array{Float32, 2}, Array{Float32, 3}}
+        s, v = node_embeddings
+        @test size(s) == (sn, n)
+        @test size(v) == (3, vn, n)
+    end
+
     @testset "VectorNorm" begin
         norm = VectorNorm()
         V = randn(Float32, 3, 5, 10)

From ffffcfbcf97376e0359d015221056fc7ad3d4f8a Mon Sep 17 00:00:00 2001
From: Kenta Sato
Date: Wed, 6 Mar 2024 10:23:01 +0000
Subject: [PATCH 15/19] check in/equivariance

---
 test/runtests.jl | 21 +++++++++++++++------
 1 file changed, 15 insertions(+), 6 deletions(-)

diff --git a/test/runtests.jl b/test/runtests.jl
index df603ce..2c8e601 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -100,14 +100,23 @@ using Test
         sn, vn = 8, 12
         se, ve = 10, 14
         gnn = GeometricVectorPerceptronGNN((sn, vn), (se, ve))
         node_embeddings = randn(Float32, sn, n), randn(Float32, 3, vn, n)
         edge_embeddings = randn(Float32, se, m), randn(Float32, 3, ve, m)
-        node_embeddings = gnn(g, node_embeddings, edge_embeddings)
-        @test node_embeddings isa Tuple{Array{Float32, 2}, Array{Float32, 3}}
-        s, v = node_embeddings
-        @test size(s) == (sn, n)
-        @test size(v) == (3, vn, n)
+
+        # check returned type and size
+        results = gnn(g, node_embeddings, edge_embeddings)
+        @test results isa Tuple{Array{Float32, 2}, Array{Float32, 3}}
+        s′, v′ = results
+        @test size(s′) == (sn, n)
+        @test size(v′) == (3, vn, n)
+
+        # check invariance and equivariance
+        R = rand(RotMatrix{3, Float32})
+        node_embeddings = (node_embeddings[1], batched_mul(R, node_embeddings[2]))
+        edge_embeddings = (edge_embeddings[1], batched_mul(R, edge_embeddings[2]))
+        s″, v″ = gnn(g, node_embeddings, edge_embeddings)
+        @test s″ ≈ s′
+        @test v″ ≈ batched_mul(R, v′)
     end

From 3adbe651f5b81fbba90ac8a3968bc41b3f190104 Mon Sep 17 00:00:00 2001
From: Kenta Sato
Date: Wed, 6 Mar 2024 10:26:49 +0000
Subject: [PATCH 16/19] fix typo

---
 src/MessagePassingIPA.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/MessagePassingIPA.jl b/src/MessagePassingIPA.jl
index b51fb18..cef2321 100644
--- a/src/MessagePassingIPA.jl
+++ b/src/MessagePassingIPA.jl
@@ -328,7 +328,7 @@ have the same size as the input node features.
 # Arguments
# Arguments - `sn`, `vn`: scalar and vector dimensions of node features - `se`, `ve`: scalar and vector dimensions of edge features -- `sσ`, `sσ`: scalar and vector nonlinearlities +- `sσ`, `vσ`: scalar and vector nonlinearlities - `vector_gate`: includes vector gating iff `vector_gate = true` - `n_intermediate_layers`: number of intermediate layers between the input and the output geometric vector perceptrons """ From ed26b45fff5e9ff769b5804726c58e52a6aac891 Mon Sep 17 00:00:00 2001 From: Kenta Sato Date: Wed, 6 Mar 2024 10:27:18 +0000 Subject: [PATCH 17/19] fix --- src/MessagePassingIPA.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/MessagePassingIPA.jl b/src/MessagePassingIPA.jl index cef2321..fd2383b 100644 --- a/src/MessagePassingIPA.jl +++ b/src/MessagePassingIPA.jl @@ -315,8 +315,8 @@ end (sn, vn), (se, ve), (sσ, vσ) = (relu, relu); - n_hidden_layers = 1, vector_gate = false, + n_intermediate_layers = 1, ) Create a graph neural network with geometric vector perceptrons. From cfc373651fdd2e055140b0f597b0745bc1b2e96d Mon Sep 17 00:00:00 2001 From: Kenta Sato Date: Mon, 11 Mar 2024 23:00:28 +0000 Subject: [PATCH 18/19] add functor --- src/MessagePassingIPA.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/MessagePassingIPA.jl b/src/MessagePassingIPA.jl index fd2383b..459e9a7 100644 --- a/src/MessagePassingIPA.jl +++ b/src/MessagePassingIPA.jl @@ -310,6 +310,8 @@ struct GeometricVectorPerceptronGNN gvpstack::Chain end +Flux.@functor GeometricVectorPerceptronGNN + """ GeometricVectorPerceptronGNN( (sn, vn), From a39c8665a22865ed38352120a4d3e7153c0bf561 Mon Sep 17 00:00:00 2001 From: Kenta Sato Date: Fri, 15 Mar 2024 17:48:30 +0000 Subject: [PATCH 19/19] concatenate node i features --- src/MessagePassingIPA.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/MessagePassingIPA.jl b/src/MessagePassingIPA.jl index 459e9a7..a21e71a 100644 --- a/src/MessagePassingIPA.jl +++ b/src/MessagePassingIPA.jl @@ -343,7 +343,7 @@ function GeometricVectorPerceptronGNN( ) gvpstack = Chain( # input layer - GeometricVectorPerceptron((sn + se, vn + ve) => (sn, vn), (sσ, vσ); vector_gate), + GeometricVectorPerceptron((2sn + se, 2vn + ve) => (sn, vn), (sσ, vσ); vector_gate), # intermediate layers [ GeometricVectorPerceptron((sn, vn) => (sn, vn), (sσ, vσ); vector_gate) @@ -361,14 +361,14 @@ function (gnn::GeometricVectorPerceptronGNN)( (se, ve)::Tuple{<:AbstractArray{T, 2}, <:AbstractArray{T, 3}}, ) where T # run message passing - function message(_, xj, e) - s = cat(xj.s, e.s, dims = 1) - v = cat(xj.v, e.v, dims = 2) + function message(xi, xj, e) + s = cat(xi.s, xj.s, e.s, dims = 1) + v = cat(xi.v, xj.v, e.v, dims = 2) gnn.gvpstack((s, v)) end - xj = (s = sn, v = vn) + xi = xj = (s = sn, v = vn) e = (s = se, v = ve) - msgs = apply_edges(message, g; xj, e) + msgs = apply_edges(message, g; xi, xj, e) aggregate_neighbors(g, mean, msgs) # return (s, v) end