Skip to content

Commit

Permalink
add compose
Browse files Browse the repository at this point in the history
  • Loading branch information
bicycle1885 committed Nov 6, 2023
1 parent 0d76040 commit 00f84e8
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 7 deletions.
15 changes: 15 additions & 0 deletions src/MessagePassingIPA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,21 @@ inverse_transform(rigid::RigidTransformation{T}, y::AbstractArray{T,3}) where {T
y .- unsqueeze(rigid.translations, dims=2),
)

"""
compose(rigid1::RigidTransformation, rigid2::RigidTransformation)
Compose two rigid transformations.
"""
function compose(
rigid1::RigidTransformation{T},
rigid2::RigidTransformation{T},
) where {T}
rotations = batched_mul(rigid1.rotations, rigid2.rotations)
translations =
batched_vec(rigid1.rotations, rigid2.translations) + rigid1.translations
return RigidTransformation(rotations, translations)
end


# Invariant point attention
# -------------------------
Expand Down
27 changes: 20 additions & 7 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using MessagePassingIPA
using MessagePassingIPA: RigidTransformation, InvariantPointAttention, transform, inverse_transform, compose, rigid_from_3points
using GraphNeuralNetworks: rand_graph
using Rotations: RotMatrix
using Test
Expand All @@ -8,17 +8,30 @@ using Test
n = 100
rotations = stack(rand(RotMatrix{3,Float32}) for _ in 1:n)
translations = randn(Float32, 3, n)
rigid = MessagePassingIPA.RigidTransformation(rotations, translations)
rigid = RigidTransformation(rotations, translations)
x = randn(Float32, 3, 12, n)
y = MessagePassingIPA.transform(rigid, x)
y = transform(rigid, x)
@test size(x) == size(y)
@test x MessagePassingIPA.inverse_transform(rigid, y)
@test x inverse_transform(rigid, y)

n = 100
rigid1 =
RigidTransformation(stack(rand(RotMatrix{3,Float32})
for _ in 1:n), randn(Float32, 3, n))
rigid2 =
RigidTransformation(stack(rand(RotMatrix{3,Float32})
for _ in 1:n), randn(Float32, 3, n))
rigid12 = compose(rigid1, rigid2)
x = randn(Float32, 3, 12, n)
@test transform(rigid12, x) transform(rigid1, transform(rigid2, x))
y = transform(rigid12, x)
@test x inverse_transform(rigid2, inverse_transform(rigid1, y))
end

@testset "InvariantPointAttention" begin
n_dims_s = 32
n_dims_z = 16
ipa = MessagePassingIPA.InvariantPointAttention(n_dims_s, n_dims_z)
ipa = InvariantPointAttention(n_dims_s, n_dims_z)

n_nodes = 100
n_edges = 500
Expand All @@ -31,7 +44,7 @@ using Test
x1 = c .+ randn(Float32, 3, n_nodes)
x2 = c .+ randn(Float32, 3, n_nodes)
x3 = c .+ randn(Float32, 3, n_nodes)
rigid1 = MessagePassingIPA.RigidTransformation(MessagePassingIPA.rigid_from_3points(x1, x2, x3)...)
rigid1 = RigidTransformation(rigid_from_3points(x1, x2, x3)...)
@test ipa(g, s, z, rigid1) isa Matrix{Float32}
@test size(ipa(g, s, z, rigid1)) == (n_dims_s, n_nodes)

Expand All @@ -40,7 +53,7 @@ using Test
x1 = R * x1 .+ t
x2 = R * x2 .+ t
x3 = R * x3 .+ t
rigid2 = MessagePassingIPA.RigidTransformation(MessagePassingIPA.rigid_from_3points(x1, x2, x3)...)
rigid2 = RigidTransformation(rigid_from_3points(x1, x2, x3)...)
@test ipa(g, s, z, rigid1) ipa(g, s, z, rigid2)
end
end

0 comments on commit 00f84e8

Please sign in to comment.