From 14f17b7989c33cb095656d1a1acb1cf6e329ed35 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Wed, 18 Sep 2024 01:12:40 +0200 Subject: [PATCH] Revamp --- Project.toml | 3 +- ext/FunctorsExt.jl | 1 - src/BatchedTransformations.jl | 14 +- src/batched/affine.jl | 131 ++++++++++++ src/batched/batched.jl | 17 ++ src/{geometric => batched}/batched_utils.jl | 10 +- src/{geometric => batched}/rand.jl | 34 ++-- src/core.jl | 12 +- src/geometric/geometric.jl | 101 ---------- test/ext/FunctorsExt.jl | 1 - test/runtests.jl | 209 +++++++++++--------- 11 files changed, 304 insertions(+), 229 deletions(-) create mode 100644 src/batched/affine.jl create mode 100644 src/batched/batched.jl rename src/{geometric => batched}/batched_utils.jl (72%) rename src/{geometric => batched}/rand.jl (52%) delete mode 100644 src/geometric/geometric.jl diff --git a/Project.toml b/Project.toml index 511fe56..2e092d4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,10 +1,11 @@ name = "BatchedTransformations" uuid = "8ba27c4b-52b5-4b10-bc66-a4fda05aa11b" authors = ["Anton Oresten and contributors"] -version = "0.4.0" +version = "0.5.0" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/ext/FunctorsExt.jl b/ext/FunctorsExt.jl index ba694f2..75f323f 100644 --- a/ext/FunctorsExt.jl +++ b/ext/FunctorsExt.jl @@ -8,6 +8,5 @@ using Functors: @functor @functor Linear @functor Translation @functor Affine -@functor Rotation end \ No newline at end of file diff --git a/src/BatchedTransformations.jl b/src/BatchedTransformations.jl index 0e50631..b083f87 100644 --- a/src/BatchedTransformations.jl +++ b/src/BatchedTransformations.jl @@ -2,16 +2,18 @@ module BatchedTransformations include("core.jl") export Transformation, transform, inverse_transform -export batchsize export Identity export Composed, compose export outer, inner export Inverse, inverse -include("geometric/geometric.jl") -export GeometricTransformation, AbstractAffine, AbstractLinear -export Translation, Linear, Affine -export Rotation, Rigid -export linear, translation +include("batched/batched.jl") +export BatchedTransformation +export batchsize, batchreshape, batchunsqueeze +export AbstractAffine, translation, linear +export Translation +export Homomorphic, Endomorphic, Automorphic +export Linear, Orthonormal, Rotation, Reflection +export Affine, Rigid end \ No newline at end of file diff --git a/src/batched/affine.jl b/src/batched/affine.jl new file mode 100644 index 0000000..0d23bc8 --- /dev/null +++ b/src/batched/affine.jl @@ -0,0 +1,131 @@ +abstract type AbstractAffine <: GeometricTransformation end + +function translation end +function linear end + +Base.iterate(affine::AbstractAffine, state=0) = state == 0 ? (translation(affine), 1) : (state == 1 ? (linear(affine), nothing) : nothing) + +abstract type Homomorphic end +abstract type Endomorphic <: Homomorphic end +abstract type Automorphic <: Endomorphic end + +struct Linear{M<:Homomorphic,A<:AbstractArray} <: AbstractAffine + values::A +end + +abstract type Orthonormal{Det} <: Automorphic end + +const Rotation = Linear{Orthonormal{1}} +const Reflection = Linear{Orthonormal{-1}} + +@inline Linear{M}(values::A) where {M,A} = Linear{M,A}(values) +@inline Linear{M}(linear::Linear) where M = Linear{M}(values(linear)) + +@inline function Linear{M}(values::A) where {M<:Endomorphic,A<:AbstractArray} + size(values, 1) == size(values, 2) || error("rotation values must have size (n, n, batchdims...)") + Linear{M,A}(values) +end + +@inline function Linear(values::A) where A<:AbstractArray + M = size(values, 1) == size(values, 2) ? Endomorphic : Homomorphic + Linear{M,A}(values) +end + +@inline compose(l2::Linear{M1}, l1::Linear{M2}) where {M1<:Homomorphic,M2<:Homomorphic} = Linear{typejoin(M1,M2)}(l2 * values(l1)) + +@inline linear(linear::Linear) = linear +@inline translation(::Linear) = Identity() + +@inline Base.values(linear::Linear) = linear.values +@inline Base.:(==)(l1::Linear, l2::Linear) = values(l1) == values(l2) + +batchsize(linear::Linear) = size(values(linear))[3:end] + +function batchreshape(linear::Linear{M}, args...) where M + A = values(linear) + Linear{M}(reshape(A, size(A, 1), size(A, 2), args...)) +end + +function batchunsqueeze(linear::Linear{M}; dims::Int) where M + @assert dims > 0 + Linear{M}(unsqueeze(values(linear), dims=dims+2)) +end + +transform(l::Linear, x::AbstractArray) = values(l) ⊠ x + +transform(linear::Linear, x::AbstractVecOrMat) = batched_mul_large_small(values(linear), x) + +inverse_transform(t::Linear{<:Orthonormal}, x::AbstractArray) = batched_mul_T1(values(t), x) + +Base.inv(t::Linear{M}) where M<:Automorphic = Linear{M}(mapslices(inv, values(t), dims=(1,2))) + +Base.inv(t::Linear{M,<:AbstractArray{<:Any,2}}) where M<:Orthonormal = Linear{M}(transpose(values(t))) +Base.inv(t::Linear{M,<:AbstractArray{<:Any,3}}) where M<:Orthonormal = Linear{M}(batched_transpose(values(t))) +Base.inv(t::Linear{M}) where M<:Orthonormal = Linear{M}(permutedims(values(t), (2, 1, 3:ndims(values(t))...))) + + +struct Translation{A<:AbstractArray} <: AbstractAffine + values::A + + function Translation{A}(values::A) where A<:AbstractArray + size(values, 2) == 1 || error("translation values must have size (n, 1, batchdims...)") + new{A}(values) + end +end + +Translation(values::A) where A = Translation{A}(values) + +@inline linear(::Translation) = Identity() +@inline translation(translation::Translation) = translation + +@inline Base.values(translation::Translation) = translation.values +@inline Base.:(==)(t1::Translation, t2::Translation) = values(t1) == values(t2) + +batchsize(translation::Translation) = size(values(translation))[3:end] + +function batchreshape(translation::Translation, args...) + b = values(translation) + Translation(reshape(b, size(b, 1), 1, args...)) +end + +function batchunsqueeze(translation::Translation; dims::Int) + @assert dims > 0 + Translation(unsqueeze(values(translation), dims=dims+2)) +end + +transform(t::Translation, x::AbstractArray) = x .+ values(t) +inverse_transform(t::Translation, x::AbstractArray) = x .- values(t) + +Base.inv(t::Translation) = Translation(-values(t)) + +@inline compose(t2::Translation, t1::Translation) = Translation(t2 * values(t1)) + + +struct Affine{T<:Translation,L<:Linear{<:Automorphic}} <: AbstractAffine + composed::Composed{T,L} +end + +const Rigid = Affine{<:Translation,<:Rotation} + +@inline linear(affine::Affine) = inner(affine.composed) +@inline translation(affine::Affine) = outer(affine.composed) + +@inline Base.:(==)(affine1::Affine, affine2::Affine) = affine1.composed == affine2.composed + +function batchunsqueeze((translation,linear)::Affine; dims::Int) + batchunsqueeze(translation; dims) ∘ batchunsqueeze(linear; dims) +end + +transform(affine::Affine, x::AbstractArray) = transform(affine.composed, x) +inverse_transform(affine::Affine, x::AbstractArray) = inverse_transform(affine.composed, x) + +Base.inv(affine::Affine) = inv(affine.composed) + +Base.show(io::IO, affine::Affine) = print(io, "$(translation(affine)) ∘ $(linear(affine))") + +@inline compose(translation::Translation, linear::Linear) = Affine(Composed(translation, linear)) +@inline compose(linear::Linear, translation::Translation) = Translation(linear * values(translation)) ∘ linear + +@inline compose((t2,l2)::AbstractAffine, (t1,l1)::AbstractAffine) = (t2 ∘ (l2 ∘ t1)) ∘ l1 +@inline compose((t2,l2)::AbstractAffine, l1::Linear) = t2 ∘ (l2 ∘ l1) +@inline compose(t2::Translation, (t1,l1)::AbstractAffine) = (t2 ∘ t1) ∘ l1 diff --git a/src/batched/batched.jl b/src/batched/batched.jl new file mode 100644 index 0000000..d3cf450 --- /dev/null +++ b/src/batched/batched.jl @@ -0,0 +1,17 @@ +using NNlib: ⊠, batched_mul, batched_transpose +using MLUtils: unsqueeze + +include("batched_utils.jl") + +function batchsize end +function batchreshape end +function batchunsqueeze end + +batchsize(t::Transformation, d::Integer) = batchsize(t)[d] +batchsize(t::Inverse{<:Transformation}) = batchsize(t.parent) + +abstract type GeometricTransformation <: Transformation end + +include("affine.jl") +include("rand.jl") + diff --git a/src/geometric/batched_utils.jl b/src/batched/batched_utils.jl similarity index 72% rename from src/geometric/batched_utils.jl rename to src/batched/batched_utils.jl index 010ae36..ae5662b 100644 --- a/src/geometric/batched_utils.jl +++ b/src/batched/batched_utils.jl @@ -16,6 +16,14 @@ function batched_mul_T2(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where {T return reshape(z, size(z, 1), size(z, 2), batch_size...) end +function batched_mul_large_small(A::AbstractArray, x::AbstractVecOrMat) + batch_size = size(A)[3:end] + A′ = reshape(A, size(A, 1), size(A, 2), :) + y′ = A′ ⊠ reshape(x, size(x, 1), size(x, 2)) + y = reshape(y′, size(A, 1), size(x, 2), batch_size...) + return y +end + # might need custom chain rule -# could do map(det, eachslice(data, dims=size(data)[3:end], drop=false)) +# could also try map(det, eachslice(data, dims=size(data)[3:end], drop=false)) batched_det(data::AbstractArray{<:Real}) = mapslices(det, data, dims=(1,2)) diff --git a/src/geometric/rand.jl b/src/batched/rand.jl similarity index 52% rename from src/geometric/rand.jl rename to src/batched/rand.jl index 3c57268..3801d0b 100644 --- a/src/geometric/rand.jl +++ b/src/batched/rand.jl @@ -1,22 +1,22 @@ using LinearAlgebra: qr, Diagonal, diag, det using Random: AbstractRNG, default_rng -function Base.rand(rng::AbstractRNG, T::Type{<:Real}, ::Type{Linear}, (n, m)::Pair{<:Integer,<:Integer}, batch_size::Dims=()) - values = randn(rng, T, m, n, batch_size...) +function Base.rand(rng::AbstractRNG, T::Type{<:Real}, ::Type{Linear}, (n, m)::Pair{<:Integer,<:Integer}, batchdims::Dims=()) + values = randn(rng, T, m, n, batchdims...) return Linear(values) end -Base.rand(rng::AbstractRNG, T::Type{<:Real}, ::Type{Linear}, n::Integer, batch_size::Dims=()) = - rand(rng, T, Linear, n => n, batch_size) +Base.rand(rng::AbstractRNG, T::Type{<:Real}, ::Type{Linear}, n::Integer, batchdims::Dims=()) = + rand(rng, T, Linear, n => n, batchdims) -function Base.rand(rng::AbstractRNG, T::Type{<:Real}, ::Type{Translation}, n::Integer, batch_size::Dims=()) - values = randn(rng, T, n, 1, batch_size...) +function Base.rand(rng::AbstractRNG, T::Type{<:Real}, ::Type{Translation}, n::Integer, batchdims::Dims=()) + values = randn(rng, T, n, 1, batchdims...) return Translation(values) end -function Base.rand(rng::AbstractRNG, T::Type{<:Real}, ::Type{Affine}, n::Integer, batch_size::Dims=()) - translation = rand(rng, T, Translation, n, batch_size) - linear = rand(rng, T, Linear, n, batch_size) +function Base.rand(rng::AbstractRNG, T::Type{<:Real}, ::Type{Affine}, n::Integer, batchdims::Dims=()) + translation = rand(rng, T, Translation, n, batchdims) + linear = Linear{Automorphic}(values(rand(rng, T, Linear, n, batchdims))) # doesn't actually guarantee invertibility :/ return translation ∘ linear end @@ -28,17 +28,17 @@ function rand_rotation(rng::AbstractRNG, T::Type{<:Real}, n::Integer) return Q end -function Base.rand(rng::AbstractRNG, T::Type{<:Real}, ::Type{Rotation}, n::Integer, batch_size::Dims=()) - values = reshape(stack([rand_rotation(rng, T, n) for _ in 1:prod(batch_size)]), n, n, batch_size...) +function Base.rand(rng::AbstractRNG, T::Type{<:Real}, ::Type{Rotation}, n::Integer, batchdims::Dims=()) + values = reshape(stack([rand_rotation(rng, T, n) for _ in 1:prod(batchdims)]), n, n, batchdims...) return Rotation(values) end -function Base.rand(rng::AbstractRNG, T::Type{<:Real}, ::Type{Rigid}, n::Integer, batch_size::Dims=()) - translations = rand(rng, T, Translation, n, batch_size) - rotations = rand(rng, T, Rotation, n, batch_size) +function Base.rand(rng::AbstractRNG, T::Type{<:Real}, ::Type{Rigid}, n::Integer, batchdims::Dims=()) + translations = rand(rng, T, Translation, n, batchdims) + rotations = rand(rng, T, Rotation, n, batchdims) return translations ∘ rotations end -Base.rand(T::Type{<:Real}, Tr::Type{<:Transformation}, dims, batch_size::Dims=()) = rand(default_rng(), T, Tr, dims, batch_size) -Base.rand(Tr::Type{<:Transformation}, dims, batch_size::Dims) = rand(Float64, Tr, dims, batch_size) -Base.rand(Tr::Type{<:Transformation}, dims::Integer) = rand(Tr, dims, ()) \ No newline at end of file +Base.rand(T::Type{<:Real}, Tr::Type{<:Transformation}, dims, batchdims::Dims=()) = rand(default_rng(), T, Tr, dims, batchdims) +Base.rand(Tr::Type{<:Transformation}, dims, batchdims::Dims=()) = rand(Float64, Tr, dims, batchdims) +Base.rand(Tr::Type{<:Transformation}, dims::Integer) = rand(Tr, dims, ()) # needed to avoid ambiguity with other rand methods \ No newline at end of file diff --git a/src/core.jl b/src/core.jl index d3ae4bb..356ffdc 100644 --- a/src/core.jl +++ b/src/core.jl @@ -8,7 +8,6 @@ that can be applied to an array. A `Transformation` `t` can be applied to abstract type Transformation end function compose end -function batchsize end """ transform(t, x) @@ -32,10 +31,10 @@ Base.show(io::IO, ::MIME"text/plain", t::Transformation) = print(io, summary(t)) """ struct Identity <: Transformation end -transform(::Identity, x) = x -inverse_transform(::Identity, x) = x +@inline transform(::Identity, x) = x +@inline inverse_transform(::Identity, x) = x -Base.inv(::Identity) = Identity() +@inline Base.inv(::Identity) = Identity() @inline compose(::Identity, ::Identity) = Identity() @inline compose(::Identity, t::Transformation) = t @@ -92,12 +91,11 @@ struct Inverse{T<:Transformation} <: Transformation parent::T end -Base.:(==)(t1::Inverse, t2::Inverse) = t1.parent == t2.parent - -batchsize(t::Inverse) = batchsize(t.parent) +@inline Base.:(==)(t1::Inverse, t2::Inverse) = t1.parent == t2.parent @inline inverse(t::Transformation) = Inverse(t) @inline inverse(t::Inverse) = t.parent +@inline inverse(t::Identity) = t @inline transform(t::Inverse, x) = inverse_transform(t.parent, x) diff --git a/src/geometric/geometric.jl b/src/geometric/geometric.jl deleted file mode 100644 index 696653a..0000000 --- a/src/geometric/geometric.jl +++ /dev/null @@ -1,101 +0,0 @@ -using NNlib: ⊠, batched_mul, batched_transpose - -include("batched_utils.jl") - -abstract type GeometricTransformation <: Transformation end - -abstract type AbstractAffine <: GeometricTransformation end - -function translation end -function linear end - -Base.iterate(affine::AbstractAffine, args...) = iterate(affine.composed, args...) - -abstract type AbstractLinear <: AbstractAffine end - -@inline linear(linear::AbstractLinear) = linear -@inline translation(::AbstractLinear) = Identity() - -@inline Base.values(linear::AbstractLinear) = linear.values -@inline Base.:(==)(l1::AbstractLinear, l2::AbstractLinear) = values(l1) == values(l2) -@inline batchsize(linear::AbstractLinear) = size(values(linear))[3:end] - -transform(l::AbstractLinear, x::AbstractArray) = values(l) ⊠ x - -function transform(linear::AbstractLinear, x::AbstractVecOrMat) - A = values(linear) - batch_size = size(A)[3:end] - A′ = reshape(A, size(A, 1), size(A, 2), :) - y′ = A′ ⊠ x - y = reshape(y′, size(A, 1), size(y′, 2), batch_size...) - return y -end - -@inline compose(l2::AbstractLinear, l1::AbstractLinear) = Linear(l2 * values(l1)) - - -struct Linear{A<:AbstractArray} <: AbstractLinear - values::A -end - -Base.inv(t::Linear) = Linear(mapslices(inv, values(t), dims=(1,2))) - - -struct Translation{A<:AbstractArray} <: AbstractAffine - values::A -end - -@inline linear(::Translation) = Identity() -@inline translation(translation::Translation) = translation - -@inline Base.values(translation::Translation) = translation.values -@inline Base.:(==)(t1::Translation, t2::Translation) = values(t1) == values(t2) -@inline batchsize(translation::Translation) = size(values(translation))[3:end] - -transform(t::Translation, x::AbstractArray) = x .+ values(t) -inverse_transform(t::Translation, x::AbstractArray) = x .- values(t) - -Base.inv(t::Translation) = Translation(-values(t)) - -@inline compose(t2::Translation, t1::Translation) = Translation(t2 * values(t1)) - - -struct Affine{T<:Translation,L<:AbstractLinear} <: AbstractAffine - composed::Composed{T,L} -end - -@inline linear(affine::Affine) = inner(affine.composed) -@inline translation(affine::Affine) = outer(affine.composed) - -@inline Base.:(==)(affine1::Affine, affine2::Affine) = affine1.composed == affine2.composed - -transform(affine::Affine, x::AbstractArray) = transform(affine.composed, x) -inverse_transform(affine::Affine, x::AbstractArray) = inverse_transform(affine.composed, x) - -Base.inv(affine::Affine) = inv(affine.composed) - -Base.show(io::IO, affine::Affine) = print(io, "$(translation(affine)) ∘ $(linear(affine))") - -@inline compose(translation::Translation, linear::AbstractLinear) = Affine(Composed(translation, linear)) -@inline compose(linear::AbstractLinear, translation::Translation) = Translation(linear * values(translation)) ∘ linear - -@inline compose((t2,l2)::AbstractAffine, (t1,l1)::AbstractAffine) = (t2 ∘ (l2 ∘ t1)) ∘ l1 -@inline compose((t2,l2)::AbstractAffine, l1::AbstractLinear) = t2 ∘ (l2 ∘ l1) -@inline compose(t2::Translation, (t1,l1)::AbstractAffine) = (t2 ∘ t1) ∘ l1 - - -struct Rotation{A<:AbstractArray} <: AbstractLinear - values::A -end - -inverse_transform(r::Rotation, x::AbstractArray) = batched_mul_T1(values(r), x) - -Base.inv(t::Rotation{<:AbstractArray{<:Any,3}}) = Rotation(batched_transpose(values(t))) -Base.inv(t::Rotation{<:AbstractArray{<:Any,N}}) where N = Rotation(permutedims(values(t), (2, 1, 3:N...))) - -@inline compose(r2::Rotation, r1::Rotation) = Rotation(r2 * values(r1)) - -const Rigid = Affine{<:Translation,<:Rotation} - - -include("rand.jl") diff --git a/test/ext/FunctorsExt.jl b/test/ext/FunctorsExt.jl index 4f56f76..c13b568 100644 --- a/test/ext/FunctorsExt.jl +++ b/test/ext/FunctorsExt.jl @@ -6,7 +6,6 @@ using Functors: functor @test !isempty(functor(Composed(rand(Float32, Translation, 3), rand(Float32, Translation, 3)))[1]) @test !isempty(functor(rand(Float32, Translation, 3))[1]) @test !isempty(functor(rand(Float32, Linear, 3 => 3))[1]) - @test !isempty(functor(rand(Float32, Rotation, 3))[1]) @test !isempty(functor(rand(Float32, Affine, 3))[1]) end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index be1a08d..ba86a11 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -12,113 +12,134 @@ using ChainRulesTestUtils: test_rrule include("ext/ChainRulesCoreExt.jl") include("ext/FunctorsExt.jl") - @testset "transformations.jl" begin - struct FooTransformation <: Transformation end - t = FooTransformation() - x = "Bar" - @test_throws ErrorException transform(t, x) - @test_throws ErrorException inv(t) - @test_throws ErrorException inverse_transform(t, x) - @test_throws ErrorException t * x - @test_throws ErrorException t(x) - - io = IOBuffer() - show(io, MIME("text/plain"), t) - str = String(take!(io)) - @test str == "FooTransformation" - - end - - @testset "inverse.jl" begin - m, n = 3, 3 # linear map needs to be square - batch_size = (2, 4) - x = rand(Float32, n, 2, batch_size...) + @testset "core.jl" begin + + @testset "Transformation" begin + struct Foo <: Transformation end + t = Foo() + x = "Bar" + @test_throws ErrorException transform(t, x) + @test_throws ErrorException inv(t) + @test_throws ErrorException inverse_transform(t, x) + @test_throws ErrorException t * x + @test_throws ErrorException t(x) + + io = IOBuffer() + show(io, MIME("text/plain"), t) + str = String(take!(io)) + @test str == "Foo" + end - l = rand(Float32, Linear, n => m, batch_size) - @test inverse(inverse(l)) === l - @test inverse(l) * x == inv(l) * x - end + @testset "Identity" begin + @test Identity() ∘ Identity() == Identity() + @test inverse(Identity()) == Identity() + end - @testset "compose.jl" begin - m, n = 3, 3 # out, in - batch_size = (2, 4) - x = rand(Float32, n, 5, batch_size...) - - t = rand(Float32, Translation, m, batch_size) - l = rand(Float32, Linear, n => m, batch_size) - c = compose(t, l) - @test t ∘ l == c - #@test c(x) == t(l)(x) - @test c * x == t(l(x)) - @test inv(c)(x) ≈ inv(l) * (inv(t) * x) - end + @testset "Inverse" begin + n = 3 + batchdims = (2, 4) + x = rand(Float32, n, 2, batchdims...) - @testset "affine.jl" begin - m, n = 3, 3 # out, in - batch_size = (2, 4) - x = rand(Float32, n, 5, batch_size...) - - @testset "Linear" begin - l = rand(Float32, Linear, n => m, batch_size) - @test linear(l) isa Linear - @test values(l) isa AbstractArray - @test l * x == values(l) ⊠ x - @test (inv(l) ∘ l) * x ≈ x - @test inv(l) * (l * x) ≈ x + l = rand(Float32, Translation, n, batchdims) + @test inverse(inverse(l)) === l + @test inverse(l) * x == inv(l) * x end - @testset "Translation" begin - t = rand(Float32, Translation, n, batch_size) - @test translation(t) isa Translation - @test values(t) isa AbstractArray - @test t * x == x .+ values(t) - @test (inv(t) ∘ t) * x ≈ x - @test inv(t) * (t * x) ≈ x - end + @testset "Composed" begin + n = 3 + batchdims = (2, 4) + x = rand(Float32, n, 5, batchdims...) - @testset "Affine" begin - affine = rand(Float32, Affine, n, batch_size) - @test linear(affine) isa Linear - @test translation(affine) isa Translation - @test affine * x == values(linear(affine)) ⊠ x .+ values(translation(affine)) - @test (inv(affine) ∘ affine) * x ≈ x - @test inv(affine) * (affine * x) ≈ x + t2 = rand(Float32, Translation, n, batchdims) + t1 = rand(Float32, Translation, n, batchdims) + c = Composed(t2, t1) + @test c * x == t2(t1(x)) + @test inv(c)(x) ≈ inv(t2) * (inv(t1) * x) end - n = 3 - batch_size = (2, 4) - x = rand(Float32, n, 5, batch_size...) - - @testset "Rotation" begin - rotation = rand(Float32, Rotation, n, batch_size) - @test linear(rotation) isa Rotation - @test values(rotation) isa AbstractArray - @test rotation * x == values(linear(rotation)) ⊠ x - @test (inv(rotation) ∘ rotation) * x ≈ x - @test inv(rotation) * (rotation * x) ≈ x - - # NNlib.batched_transpose only supports one batch dimension - @test !isa(values(inv(rand(Float32, Rotation, n, (2,)))), Array) - @test isa(values(inv(rand(Float32, Rotation, n, (2,1)))), Array) - end + end + + @testset "batched.jl" begin + + @testset "affine.jl" begin + n = 3 + batchdims = (2, 4) + x = rand(Float32, n, 5, batchdims...) + + @testset "Linear" begin + l = rand(Float32, Linear, n, batchdims) + @test linear(l) isa Linear + @test values(l) isa AbstractArray + @test l * x == values(l) ⊠ x + + @test_throws ErrorException inv(l) + invertible_l = Linear{Automorphic}(l) + @test (inv(invertible_l) ∘ invertible_l) * x ≈ x + @test inv(invertible_l) * (invertible_l * x) ≈ x + + @test batchsize(l) == batchdims + @test batchsize(batchreshape(l, 1, batchdims...)) == (1, batchdims...) + @test batchsize(batchunsqueeze(l, dims=1)) == (1, batchdims...) + end + + @testset "Translation" begin + t = rand(Float32, Translation, n, batchdims) + @test translation(t) isa Translation + @test values(t) isa AbstractArray + @test t * x == x .+ values(t) + @test (inv(t) ∘ t) * x ≈ x + @test inv(t) * (t * x) ≈ x + + @test batchsize(t) == batchdims + @test batchsize(batchreshape(t, 1, batchdims...)) == (1, batchdims...) + @test batchsize(batchunsqueeze(t, dims=1)) == (1, batchdims...) + end + + @testset "Affine" begin + affine = rand(Float32, Affine, n, batchdims) + @test linear(affine) isa Linear + @test translation(affine) isa Translation + @test affine * x == values(linear(affine)) ⊠ x .+ values(translation(affine)) + @test (inv(affine) ∘ affine) * x ≈ x + @test inv(affine) * (affine * x) ≈ x + end + + n = 3 + batchdims = (2, 4) + x = rand(Float32, n, 5, batchdims...) + + @testset "Rotation" begin + rotation = rand(Float32, Rotation, n, batchdims) + @test linear(rotation) isa Rotation + @test values(rotation) isa AbstractArray + @test rotation * x == values(linear(rotation)) ⊠ x + @test (inv(rotation) ∘ rotation) * x ≈ x + @test inv(rotation) * (rotation * x) ≈ x + + # NNlib.batched_transpose only supports one batch dimension + @test !isa(values(inv(rand(Float32, Rotation, n, (2,)))), Array) + @test isa(values(inv(rand(Float32, Rotation, n, (2,1)))), Array) + end + + @testset "Rigid" begin + rigid = rand(Float32, Rigid, n, batchdims) + @test linear(rigid) isa Rotation + @test translation(rigid) isa Translation + @test rigid * x == values(linear(rigid)) ⊠ x .+ values(translation(rigid)) + @test (inv(rigid) ∘ rigid) * x ≈ x + @test inv(rigid) * (rigid * x) ≈ x + @test rigid ∘ rigid isa Rigid + end - @testset "Rigid" begin - rigid = rand(Float32, Rigid, n, batch_size) - @test linear(rigid) isa Rotation - @test translation(rigid) isa Translation - @test rigid * x == values(linear(rigid)) ⊠ x .+ values(translation(rigid)) - @test (inv(rigid) ∘ rigid) * x ≈ x - @test inv(rigid) * (rigid * x) ≈ x - @test rigid ∘ rigid isa Rigid end - end + @testset "rand.jl" begin - @testset "rand.jl" begin + @testset "Rotation" begin + rotations = rand(Float32, Rotation, 3, (2, 4)) + @test BatchedTransformations.batched_det(values(rotations)) ≈ ones(Float32, 1, 1, 2, 4) + end - @testset "Rotation" begin - rotations = rand(Float32, Rotation, 3, (2, 4)) - @test BatchedTransformations.batched_det(values(rotations)) ≈ ones(Float32, 1, 1, 2, 4) end end