Commit
add VectorNorm layer
bicycle1885 committed Mar 2, 2024
1 parent a7fbada commit bcf9088
Showing 3 changed files with 27 additions and 6 deletions.
1 change: 1 addition & 0 deletions Project.toml
@@ -7,6 +7,7 @@ version = "1.0.0-DEV"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
Flux = "0.14"
18 changes: 15 additions & 3 deletions src/MessagePassingIPA.jl
@@ -3,6 +3,7 @@ module MessagePassingIPA
using Flux: Flux, Dense, flatten, unsqueeze, chunk, batched_mul, batched_vec, batched_transpose, softplus
using GraphNeuralNetworks: GNNGraph, apply_edges, softmax_edge_neighbors, aggregate_neighbors
using LinearAlgebra: normalize
using Statistics: mean

# Algorithm 21 (x1: N, x2: Ca, x3: C)
function rigid_from_3points(x1::AbstractVector, x2::AbstractVector, x3::AbstractVector)
@@ -281,16 +282,27 @@ GeometricVectorPerceptron(sin::Integer, vin::Integer, σ::Function = identity; b
function (gvp::GeometricVectorPerceptron)(s::AbstractArray{T, 2}, V::AbstractArray{T, 3}) where T
@assert size(V, 1) == 3
V_h = batched_mul(V, gvp.W_h)
s′ = gvp.scalar(cat(norm1(V_h), s, dims = 1))
s′ = gvp.scalar(cat(norm1drop(V_h), s, dims = 1))
V_μ = batched_mul(V_h, gvp.W_μ)
V′ = gvp.σ.(unsqueeze(norm1(V_μ), dims = 1)) .* V_μ
V′ = gvp.σ.(unsqueeze(norm1drop(V_μ), dims = 1)) .* V_μ
s′, V′
end

# This makes chaining by Flux's Chain easier.
(gvp::GeometricVectorPerceptron)((s, V)::Tuple{AbstractArray{T, 2}, AbstractArray{T, 3}}) where T = gvp(s, V)

# Normalization for vector features
struct VectorNorm
#eps::Float32
end

function (norm::VectorNorm)(V::AbstractArray{T, 3}) where T
@assert size(V, 1) == 3
V ./ sqrt.(mean(sum(abs2, V, dims = 1), dims = 2))
end

# L2 norm along the first dimension
norm1(X) = dropdims(sqrt.(sum(abs2, X, dims = 1)), dims = 1)
norm1(X) = sqrt.(sum(abs2, X, dims = 1))
norm1drop(X) = dropdims(norm1(X), dims = 1)

end
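For reference, a minimal usage sketch of the new layer (not part of the commit; the channel and node counts below are arbitrary). VectorNorm divides each node's 3-vector channels by the root-mean-square of their L2 norms, so the output is invariant to rescaling the input:

using MessagePassingIPA: VectorNorm
using Statistics: mean

vnorm = VectorNorm()
V = randn(Float32, 3, 16, 10)  # 3 spatial dims × 16 vector channels × 10 nodes
V′ = vnorm(V)                  # same shape as V
@assert vnorm(100 .* V) ≈ V′   # invariant to rescaling the input
@assert all(sqrt.(mean(sum(abs2, V′, dims = 1), dims = 2)) .≈ 1)  # per-node RMS of channel norms is 1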
14 changes: 11 additions & 3 deletions test/runtests.jl
@@ -1,7 +1,8 @@
using MessagePassingIPA: RigidTransformation, InvariantPointAttention, transform, inverse_transform, compose, rigid_from_3points, GeometricVectorPerceptron
using MessagePassingIPA: RigidTransformation, InvariantPointAttention, transform, inverse_transform, compose, rigid_from_3points, GeometricVectorPerceptron, VectorNorm
using GraphNeuralNetworks: rand_graph
using Flux: relu, batched_mul
using Rotations: RotMatrix
using Statistics: mean
using Test

@testset "MessagePassingIPA.jl" begin
@@ -70,8 +71,6 @@ using Test

# check returned type and size
s′, V′ = gvp(s, V)
@show typeof(s′)
@show typeof(V′)
@test s′ isa Array{Float32, 2}
@test V′ isa Array{Float32, 3}
@test size(s′) == (sout, n)
@@ -87,4 +86,13 @@ using Test
gvp = GeometricVectorPerceptron(12, 24, σ)
@test gvp isa GeometricVectorPerceptron
end

@testset "VectorNorm" begin
norm = VectorNorm()
V = randn(Float32, 3, 8, 128)
@test norm(V) isa Array{Float32, 3}
@test size(norm(V)) == size(V)
@test norm(V) ≈ norm(100 * V)
@test all(sqrt.(mean(sum(abs2, norm(V), dims = 1), dims = 2)) .≈ 1)
end
end
