From 38a580f909de1b6850f18ce5f23f1a81853fe606 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Mon, 16 Sep 2024 03:01:26 +0200 Subject: [PATCH] Redesign --- Project.toml | 2 +- ext/ChainRulesCoreExt.jl | 32 ++++--- ext/FunctorsExt.jl | 7 +- src/BatchedTransformations.jl | 18 ++-- src/affine/affine.jl | 24 ----- src/affine/linear.jl | 42 --------- src/affine/translation.jl | 20 ---- src/compose.jl | 29 +++--- src/{affine => geometric}/batched_utils.jl | 2 - src/geometric/geometric.jl | 101 +++++++++++++++++++++ src/geometric/rand.jl | 44 +++++++++ src/identity.jl | 10 ++ src/inverse.jl | 30 ++++-- src/rand.jl | 40 -------- src/transformations.jl | 20 ++-- test/ext/ChainRulesCoreExt.jl | 4 +- test/ext/FunctorsExt.jl | 11 ++- test/runtests.jl | 56 ++++++------ 18 files changed, 272 insertions(+), 220 deletions(-) delete mode 100644 src/affine/affine.jl delete mode 100644 src/affine/linear.jl delete mode 100644 src/affine/translation.jl rename src/{affine => geometric}/batched_utils.jl (95%) create mode 100644 src/geometric/geometric.jl create mode 100644 src/geometric/rand.jl create mode 100644 src/identity.jl delete mode 100644 src/rand.jl diff --git a/Project.toml b/Project.toml index 1c15fa7..511fe56 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "BatchedTransformations" uuid = "8ba27c4b-52b5-4b10-bc66-a4fda05aa11b" authors = ["Anton Oresten and contributors"] -version = "0.3.0" +version = "0.4.0" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/ext/ChainRulesCoreExt.jl b/ext/ChainRulesCoreExt.jl index e1e97c7..d418f4b 100644 --- a/ext/ChainRulesCoreExt.jl +++ b/ext/ChainRulesCoreExt.jl @@ -8,9 +8,9 @@ using ChainRulesCore using BatchedTransformations: batched_mul, batched_mul_T1, batched_mul_T2 -function ChainRulesCore.rrule(::typeof(transform), affine_maps::AbstractAffineMaps, x::AbstractArray) - translations, linear_maps = outer(affine_maps), linear(affine_maps) - t, R = values(translations), values(linear_maps) +function ChainRulesCore.rrule(::typeof(transform), affine::Affine, x::AbstractArray) + translation, linear = affine.composed.outer, affine.composed.inner + t, R = values(translation), values(linear) y = batched_mul(R, x) .+ t @@ -21,21 +21,22 @@ function ChainRulesCore.rrule(::typeof(transform), affine_maps::AbstractAffineMa Δt = @thunk(sum(Δy, dims=2)) Δx = @thunk(batched_mul_T1(R, Δy)) - Δtranslations = Tangent{typeof(translations)}(; values=Δt) - Δlinear_maps = Tangent{typeof(linear_maps)}(; values=ΔR) - Δaffine_maps = Tangent{typeof(affine_maps)}(; outer=Δtranslations, inner=Δlinear_maps) + Δtranslation = Tangent{typeof(translation)}(; values=Δt) + Δlinear = Tangent{typeof(linear)}(; values=ΔR) + Δcomposed = Tangent{typeof(affine.composed)}(; outer=Δtranslation, inner=Δlinear) + Δaffine = Tangent{typeof(affine)}(; composed=Δcomposed) - return NoTangent(), Δaffine_maps, Δx + return NoTangent(), Δaffine, Δx end return y, transform_pullback end -function ChainRulesCore.rrule(::typeof(inverse_transform), rigid::RigidTransformations, x::AbstractArray) - translations, rotations = translation(rigid), linear(rigid) - z = inverse_transform(translations, x) # x .- t - y = inverse_transform(rotations, z) # R' * (x .- t) - t, R = values(translations), values(rotations) +function ChainRulesCore.rrule(::typeof(inverse_transform), rigid::Rigid, x::AbstractArray) + translation, rotation = rigid.composed.outer, rigid.composed.inner + z = inverse_transform(translation, x) # x .- t + y = inverse_transform(rotation, z) # R' * (x .- t) + t, R = values(translation), values(rotation) function inverse_transform_pullback(_Δy) Δy = unthunk(_Δy) @@ -44,9 +45,10 @@ function ChainRulesCore.rrule(::typeof(inverse_transform), rigid::RigidTransform Δx = @thunk(batched_mul(R, Δy)) Δt = @thunk(-sum(Δx, dims=2)) # t is in the same position as x, but negated and broadcasted - Δtranslations = Tangent{typeof(translations)}(; values=Δt) - Δrotations = Tangent{typeof(rotations)}(; values=ΔR) - Δrigid = Tangent{typeof(rigid)}(; outer=Δtranslations, inner=Δrotations) + Δtranslation = Tangent{typeof(translation)}(; values=Δt) + Δrotation = Tangent{typeof(rotation)}(; values=ΔR) + Δcomposed = Tangent{typeof(rigid.composed)}(; outer=Δtranslation, inner=Δrotation) + Δrigid = Tangent{typeof(rigid)}(; composed=Δcomposed) return NoTangent(), Δrigid, Δx end diff --git a/ext/FunctorsExt.jl b/ext/FunctorsExt.jl index c9aada6..ba694f2 100644 --- a/ext/FunctorsExt.jl +++ b/ext/FunctorsExt.jl @@ -5,8 +5,9 @@ using Functors: @functor @functor Inverse @functor Composed -@functor LinearMaps -@functor Rotations -@functor Translations +@functor Linear +@functor Translation +@functor Affine +@functor Rotation end \ No newline at end of file diff --git a/src/BatchedTransformations.jl b/src/BatchedTransformations.jl index a85b266..30bd0c3 100644 --- a/src/BatchedTransformations.jl +++ b/src/BatchedTransformations.jl @@ -1,7 +1,11 @@ module BatchedTransformations include("transformations.jl") -export Transformations, transform, inverse_transform +export Transformation, transform, inverse_transform +export batchsize + +include("identity.jl") +export Identity include("inverse.jl") export Inverse, inverse @@ -10,12 +14,10 @@ include("compose.jl") export Composed, compose export outer, inner -include("affine/affine.jl") -export AbstractLinearMaps, LinearMaps, Rotations -export Translations -export AbstractAffineMaps, AffineMaps, RigidTransformations -export translation, linear - -include("rand.jl") +include("geometric/geometric.jl") +export GeometricTransformation, AbstractAffine, AbstractLinear +export Translation, Linear, Affine +export Rotation, Rigid +export linear, translation end \ No newline at end of file diff --git a/src/affine/affine.jl b/src/affine/affine.jl deleted file mode 100644 index 3434fcd..0000000 --- a/src/affine/affine.jl +++ /dev/null @@ -1,24 +0,0 @@ -include("batched_utils.jl") -include("linear.jl") -include("translation.jl") - -const AffineMaps = Composed{<:Translations,<:LinearMaps} -const RigidTransformations = Composed{<:Translations,<:Rotations} - -const AbstractAffineMaps = Union{AffineMaps,RigidTransformations} - -translation(a::AbstractAffineMaps) = outer(a) -linear(a::AbstractAffineMaps) = inner(a) - -function Base.inv(a::AbstractAffineMaps) - inv_t, inv_l = inv(translation(a)), inv(linear(a)) - return inv_l ∘ inv_t -end - -@inline compose(l2::AbstractLinearMaps, t1::Translations) = Translations(l2 * values(t1)) ∘ l2 - -@inline compose((t2,l2)::AbstractAffineMaps, (t1,l1)::AbstractAffineMaps) = (t2 ∘ (l2 ∘ t1)) ∘ l1 -@inline compose((t2,l2)::AbstractAffineMaps, l1::AbstractLinearMaps) = t2 ∘ (l2 ∘ l1) -@inline compose((t2,l2)::AbstractAffineMaps, t1::Translations) = t2 ∘ (l2 ∘ t1) -@inline compose(l2::AbstractLinearMaps, (t1,l1)::AbstractAffineMaps) = (l2 ∘ t1) ∘ l1 -@inline compose(t2::Translations, (t1,l1)::AbstractAffineMaps) = (t2 ∘ t1) ∘ l1 diff --git a/src/affine/linear.jl b/src/affine/linear.jl deleted file mode 100644 index 31fd631..0000000 --- a/src/affine/linear.jl +++ /dev/null @@ -1,42 +0,0 @@ -""" - AbstractLinearMaps <: Transformations -""" -abstract type AbstractLinearMaps <: Transformations end - -Base.values(t::AbstractLinearMaps) = t.values - -linear(l::AbstractLinearMaps) = l - -transform(l::AbstractLinearMaps, x::AbstractArray) = values(l) ⊠ x - -@inline compose(l2::AbstractLinearMaps, l1::AbstractLinearMaps) = LinearMaps(l2 * values(l1)) - -""" - LinearMaps{A<:AbstractArray} <: AbstractLinearMaps - -Contains a batch of linear maps mapping from n-dimensional to m-dimensional space, -represented by an array of size `(m, n, b1, b2, ...)`. -""" -struct LinearMaps{A<:AbstractArray} <: AbstractLinearMaps - values::A -end - -Base.inv(t::LinearMaps) = LinearMaps(mapslices(inv, values(t), dims=(1,2))) - - -""" - Rotations{A<:AbstractArray} <: AbstractLinearMaps - -Contains a batch of n-dimensional rotations matrices, -represented by an array of size `(n, n, b1, b2, ...)`. -""" -struct Rotations{A<:AbstractArray} <: AbstractLinearMaps - values::A -end - -Base.inv(t::Rotations{<:AbstractArray{<:Any,3}}) = Rotations(batched_transpose(values(t))) -Base.inv(t::Rotations{<:AbstractArray{<:Any,N}}) where N = Rotations(permutedims(values(t), (2, 1, 3:N...))) - -inverse_transform(r::Rotations, x::AbstractArray) = batched_mul_T1(values(r), x) - -@inline compose(l2::Rotations, l1::Rotations) = Rotations(l2 * values(l1)) \ No newline at end of file diff --git a/src/affine/translation.jl b/src/affine/translation.jl deleted file mode 100644 index d341688..0000000 --- a/src/affine/translation.jl +++ /dev/null @@ -1,20 +0,0 @@ -""" - Translations{A<:AbstractArray} <: Transformations - -Contains a batch of n-dimensional translation vectors, -represented by an array of size `(n, 1, b1, b2, ...)`. -""" -struct Translations{A<:AbstractArray} <: Transformations - values::A -end - -Base.values(t::Translations) = t.values - -translation(t::Translations) = t - -transform(t::Translations, x::AbstractArray) = x .+ values(t) -inverse_transform(t::Translations, x::AbstractArray) = x .- values(t) - -Base.inv(t::Translations) = Translations(-values(t)) - -@inline compose(t2::Translations, t1::Translations) = Translations(t2 * values(t1)) \ No newline at end of file diff --git a/src/compose.jl b/src/compose.jl index 4b38780..806f131 100644 --- a/src/compose.jl +++ b/src/compose.jl @@ -1,11 +1,12 @@ """ - Composed{Outer<:Transformations,Inner<:Transformations} + Composed{Outer<:Transformation,Inner<:Transformation} -A `Composed` contains two transformations `t2` and `t1` that are composed. -It can be constructed with `compose(t2, t1)` and `t2 ∘ t1`, where `t1` is the -transformation to be applied first, and `t2` second. +A `Composed` contains two transformations `outer` and `inner` that are composed, +where `inner` gets applied first, and then `outer`.. +It can be constructed with `compose(outer, inner)` or `outer ∘ inner`, unless +the `compose` function is overloaded for the specific types. """ -struct Composed{Outer<:Transformations,Inner<:Transformations} <: Transformations +struct Composed{Outer<:Transformation,Inner<:Transformation} <: Transformation outer::Outer inner::Inner end @@ -14,18 +15,20 @@ end compose(t2, t1) t2 ∘ t1 """ -@inline compose(outer::Transformations, inner::Transformations) = Composed(outer, inner) +@inline compose(outer::Transformation, inner::Transformation) = Composed(outer, inner) -@inline Base.:(∘)(outer::Transformations, inner::Transformations) = compose(outer, inner) +@inline Base.:(∘)(outer::Transformation, inner::Transformation) = compose(outer, inner) -outer(composed::Composed) = composed.outer -inner(composed::Composed) = composed.inner +@inline Base.:(==)(a1::Composed, a2::Composed) = a1.outer == a2.outer && a1.inner == a2.inner -transform(t::Composed, x) = transform(outer(t), transform(inner(t), x)) +@inline outer(composed::Composed) = composed.outer +@inline inner(composed::Composed) = composed.inner -inverse_transform(t::Composed, x) = inverse_transform(inner(t), inverse_transform(outer(t), x)) +@inline transform(t::Composed, x) = transform(outer(t), transform(inner(t), x)) -Base.inv(t::Composed) = compose(inv(inner(t)), inv(outer(t))) +@inline inverse_transform(t::Composed, x) = inverse_transform(inner(t), inverse_transform(outer(t), x)) -# t2, t1 = compose(t2, t1) +@inline Base.inv(t::Composed) = inv(inner(t)) ∘ inv(outer(t)) + +# enables `outer, inner = compose(outer, inner)` syntax Base.iterate(t::Composed, state=1) = state == 1 ? (t.outer, 2) : (state == 2 ? (t.inner, nothing) : nothing) diff --git a/src/affine/batched_utils.jl b/src/geometric/batched_utils.jl similarity index 95% rename from src/affine/batched_utils.jl rename to src/geometric/batched_utils.jl index 789580b..010ae36 100644 --- a/src/affine/batched_utils.jl +++ b/src/geometric/batched_utils.jl @@ -1,5 +1,3 @@ -using NNlib: ⊠, batched_mul, batched_transpose - function batched_mul_T1(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where {T1,T2,N} batch_size = size(x)[3:end] @assert batch_size == size(y)[3:end] "batch size has to be the same for the two arrays." diff --git a/src/geometric/geometric.jl b/src/geometric/geometric.jl new file mode 100644 index 0000000..f66f6ff --- /dev/null +++ b/src/geometric/geometric.jl @@ -0,0 +1,101 @@ +using NNlib: ⊠, batched_mul, batched_transpose + +include("batched_utils.jl") + +abstract type GeometricTransformation <: Transformation end + +abstract type AbstractAffine <: GeometricTransformation end + +function translation end +function linear end + +Base.iterate(affine::AbstractAffine, state=0) = state == 0 ? (translation(affine), 1) : (state == 1 ? (linear(affine), nothing) : nothing) + +abstract type AbstractLinear <: AbstractAffine end + +@inline linear(linear::AbstractLinear) = linear +@inline translation(::AbstractLinear) = Identity() + +@inline Base.values(linear::AbstractLinear) = linear.values +@inline Base.:(==)(l1::AbstractLinear, l2::AbstractLinear) = values(l1) == values(l2) +@inline batchsize(linear::AbstractLinear) = size(values(linear))[3:end] + +transform(l::AbstractLinear, x::AbstractArray) = values(l) ⊠ x + +function transform(linear::AbstractLinear, x::AbstractVecOrMat) + A = values(linear) + batch_size = size(A)[3:end] + A′ = reshape(A, size(A, 1), size(A, 2), :) + y′ = A′ ⊠ x + y = reshape(y′, size(A, 1), size(y′, 2), batch_size...) + return y +end + +@inline compose(l2::AbstractLinear, l1::AbstractLinear) = Linear(l2 * values(l1)) + + +struct Linear{A<:AbstractArray} <: AbstractLinear + values::A +end + +Base.inv(t::Linear) = Linear(mapslices(inv, values(t), dims=(1,2))) + + +struct Translation{A<:AbstractArray} <: AbstractAffine + values::A +end + +@inline linear(::Translation) = Identity() +@inline translation(translation::Translation) = translation + +@inline Base.values(translation::Translation) = translation.values +@inline Base.:(==)(t1::Translation, t2::Translation) = values(t1) == values(t2) +@inline batchsize(translation::Translation) = size(values(translation))[3:end] + +transform(t::Translation, x::AbstractArray) = x .+ values(t) +inverse_transform(t::Translation, x::AbstractArray) = x .- values(t) + +Base.inv(t::Translation) = Translation(-values(t)) + +@inline compose(t2::Translation, t1::Translation) = Translation(t2 * values(t1)) + + +struct Affine{T<:Translation,L<:AbstractLinear} <: AbstractAffine + composed::Composed{T,L} +end + +@inline linear(affine::Affine) = inner(affine.composed) +@inline translation(affine::Affine) = outer(affine.composed) + +@inline Base.:(==)(affine1::Affine, affine2::Affine) = affine1.composed == affine2.composed + +transform(affine::Affine, x::AbstractArray) = transform(affine.composed, x) +inverse_transform(affine::Affine, x::AbstractArray) = inverse_transform(affine.composed, x) + +Base.inv(affine::Affine) = inv(affine.composed) + +Base.show(io::IO, affine::Affine) = print(io, "$(translation(affine)) ∘ $(linear(affine))") + +@inline compose(translation::Translation, linear::AbstractLinear) = Affine(Composed(translation, linear)) +@inline compose(linear::AbstractLinear, translation::Translation) = Translation(linear * values(translation)) ∘ linear + +@inline compose((t2,l2)::AbstractAffine, (t1,l1)::AbstractAffine) = (t2 ∘ (l2 ∘ t1)) ∘ l1 +@inline compose((t2,l2)::AbstractAffine, l1::AbstractLinear) = t2 ∘ (l2 ∘ l1) +@inline compose(t2::Translation, (t1,l1)::AbstractAffine) = (t2 ∘ t1) ∘ l1 + + +struct Rotation{A<:AbstractArray} <: AbstractLinear + values::A +end + +inverse_transform(r::Rotation, x::AbstractArray) = batched_mul_T1(values(r), x) + +Base.inv(t::Rotation{<:AbstractArray{<:Any,3}}) = Rotation(batched_transpose(values(t))) +Base.inv(t::Rotation{<:AbstractArray{<:Any,N}}) where N = Rotation(permutedims(values(t), (2, 1, 3:N...))) + +@inline compose(r2::Rotation, r1::Rotation) = Rotation(r2 * values(r1)) + +const Rigid = Affine{<:Translation,<:Rotation} + + +include("rand.jl") diff --git a/src/geometric/rand.jl b/src/geometric/rand.jl new file mode 100644 index 0000000..3c57268 --- /dev/null +++ b/src/geometric/rand.jl @@ -0,0 +1,44 @@ +using LinearAlgebra: qr, Diagonal, diag, det +using Random: AbstractRNG, default_rng + +function Base.rand(rng::AbstractRNG, T::Type{<:Real}, ::Type{Linear}, (n, m)::Pair{<:Integer,<:Integer}, batch_size::Dims=()) + values = randn(rng, T, m, n, batch_size...) + return Linear(values) +end + +Base.rand(rng::AbstractRNG, T::Type{<:Real}, ::Type{Linear}, n::Integer, batch_size::Dims=()) = + rand(rng, T, Linear, n => n, batch_size) + +function Base.rand(rng::AbstractRNG, T::Type{<:Real}, ::Type{Translation}, n::Integer, batch_size::Dims=()) + values = randn(rng, T, n, 1, batch_size...) + return Translation(values) +end + +function Base.rand(rng::AbstractRNG, T::Type{<:Real}, ::Type{Affine}, n::Integer, batch_size::Dims=()) + translation = rand(rng, T, Translation, n, batch_size) + linear = rand(rng, T, Linear, n, batch_size) + return translation ∘ linear +end + +function rand_rotation(rng::AbstractRNG, T::Type{<:Real}, n::Integer) + A = randn(rng, T, n, n) + Q, R = qr(A) + Q = Q * Diagonal(sign.(diag(R))) + det(Q) < 0 && (Q[:, end] *= -1) + return Q +end + +function Base.rand(rng::AbstractRNG, T::Type{<:Real}, ::Type{Rotation}, n::Integer, batch_size::Dims=()) + values = reshape(stack([rand_rotation(rng, T, n) for _ in 1:prod(batch_size)]), n, n, batch_size...) + return Rotation(values) +end + +function Base.rand(rng::AbstractRNG, T::Type{<:Real}, ::Type{Rigid}, n::Integer, batch_size::Dims=()) + translations = rand(rng, T, Translation, n, batch_size) + rotations = rand(rng, T, Rotation, n, batch_size) + return translations ∘ rotations +end + +Base.rand(T::Type{<:Real}, Tr::Type{<:Transformation}, dims, batch_size::Dims=()) = rand(default_rng(), T, Tr, dims, batch_size) +Base.rand(Tr::Type{<:Transformation}, dims, batch_size::Dims) = rand(Float64, Tr, dims, batch_size) +Base.rand(Tr::Type{<:Transformation}, dims::Integer) = rand(Tr, dims, ()) \ No newline at end of file diff --git a/src/identity.jl b/src/identity.jl new file mode 100644 index 0000000..10d80d3 --- /dev/null +++ b/src/identity.jl @@ -0,0 +1,10 @@ +struct Identity <: Transformation end + +transform(::Identity, x) = x +inverse_transform(::Identity, x) = x + +Base.inv(::Identity) = Identity() + +@inline compose(::Identity, ::Identity) = Identity() +@inline compose(::Identity, t::Transformation) = t +@inline compose(t::Transformation, ::Identity) = t \ No newline at end of file diff --git a/src/inverse.jl b/src/inverse.jl index 15a949d..8a5b856 100644 --- a/src/inverse.jl +++ b/src/inverse.jl @@ -1,20 +1,34 @@ """ - Inverse{T<:Transformations} + Inverse{T<:Transformation} -An `Inverse` represents a *lazy* inverse of a `Transformations` t. +An `Inverse` represents a *lazy* inverse of a `Transformation` t. `inverse(t)` is a lazy inverse that defaults to `inv(t)` when evaluated. `transform(inverse(t), x)` is equivalent to `inverse_transform(t, x)`. This allows for specialized inverse transform implementations that don't require the inverse to be computed explicitly. """ -struct Inverse{T<:Transformations} <: Transformations - t::T +struct Inverse{T<:Transformation} <: Transformation + parent::T end -@inline inverse(t::Transformations) = Inverse(t) -@inline inverse(t::Inverse) = t.t +Base.:(==)(t1::Inverse, t2::Inverse) = t1.parent == t2.parent -@inline transform(t::Inverse, x) = inverse_transform(t.t, x) +batchsize(t::Inverse) = batchsize(t.parent) -@inline Base.inv(t::Inverse) = t.t +@inline inverse(t::Transformation) = Inverse(t) +@inline inverse(t::Inverse) = t.parent + +@inline transform(t::Inverse, x) = inverse_transform(t.parent, x) + +@inline Base.inv(t::Inverse) = t.parent + +function compose(t2::Inverse{T}, t1::T) where T<:Transformation + t2.parent === t1 && return Identity() + Composed(t2, t1) +end + +function compose(t2::T, t1::Inverse{T}) where T<:Transformation + t2 === t1.parent && return Identity() + Composed(t2, t1) +end diff --git a/src/rand.jl b/src/rand.jl deleted file mode 100644 index 52d0e50..0000000 --- a/src/rand.jl +++ /dev/null @@ -1,40 +0,0 @@ -using LinearAlgebra: qr, Diagonal, diag, det -using Random: AbstractRNG, default_rng - -function Base.rand(rng::AbstractRNG, T::Type{<:Real}, ::Type{LinearMaps}, (n, m)::Pair{<:Integer,<:Integer}, batch_size::Dims) - values = rand(rng, T, m, n, batch_size...) - return LinearMaps(values) -end - -function Base.rand(rng::AbstractRNG, T::Type{<:Real}, ::Type{Translations}, n::Integer, batch_size::Dims) - values = rand(rng, T, n, 1, batch_size...) - return Translations(values) -end - -function Base.rand(rng::AbstractRNG, T::Type{<:Real}, ::Type{AffineMaps}, (n, m)::Pair{<:Integer,<:Integer}, batch_size::Dims) - translation = rand(rng, T, Translations, m, batch_size) - linear = rand(rng, T, LinearMaps, n => m, batch_size) - return translation ∘ linear -end - -function rand_rotation(rng::AbstractRNG, T::Type{<:Real}, n::Integer) - A = randn(rng, T, n, n) - Q, R = qr(A) - Q = Q * Diagonal(sign.(diag(R))) - det(Q) < 0 && (Q[:, end] *= -1) - return Q -end - -function Base.rand(rng::AbstractRNG, T::Type{<:Real}, ::Type{Rotations}, n::Integer, batch_size::Dims) - values = reshape(stack([rand_rotation(rng, T, n) for _ in 1:prod(batch_size)]), n, n, batch_size...) - return Rotations(values) -end - -function Base.rand(rng::AbstractRNG, T::Type{<:Real}, ::Type{RigidTransformations}, n::Integer, batch_size::Dims) - translations = rand(rng, T, Translations, n, batch_size) - rotations = rand(rng, T, Rotations, n, batch_size) - return translations ∘ rotations -end - -Base.rand(T::Type{<:Real}, Tr::Type{<:Transformations}, dims, batch_size::Dims) = rand(default_rng(), T, Tr, dims, batch_size) -Base.rand(Tr::Type{<:Transformations}, dims, batch_size::Dims) = rand(Float64, Tr, dims, batch_size) diff --git a/src/transformations.jl b/src/transformations.jl index c9ab45d..97c4b6b 100644 --- a/src/transformations.jl +++ b/src/transformations.jl @@ -1,24 +1,26 @@ """ - Transformations + Transformation An abstract type whose concrete subtypes contain batches of transformations -that can be applied to an array. A `Transformations` `t` can be applied to +that can be applied to an array. A `Transformation` `t` can be applied to `x` with `transform(t, x)`, `t * x`, and t(x). """ -abstract type Transformations end +abstract type Transformation end + +batchsize(t::Transformation) = error("batchsize not defined for $(typeof(t))") """ transform(t, x) t * x t(x) """ -transform(t::Transformations, x) = error("transform not defined for $(typeof(t)) and $(typeof(x))") +transform(t::Transformation, x) = error("transform not defined for $(typeof(t)) and $(typeof(x))") -Base.inv(t::Transformations) = error("inverse not defined for $(typeof(t)) ") -@inline inverse_transform(t::Transformations, x) = transform(inv(t), x) +Base.inv(t::Transformation) = error("inverse not defined for $(typeof(t)) ") +@inline inverse_transform(t::Transformation, x) = transform(inv(t), x) -@inline Base.:(*)(t::Transformations, x) = transform(t, x) -@inline (t::Transformations)(x) = transform(t, x) +@inline Base.:(*)(t::Transformation, x) = transform(t, x) +@inline (t::Transformation)(x) = transform(t, x) -Base.show(io::IO, ::MIME"text/plain", t::Transformations) = print(io, summary(t)) +Base.show(io::IO, ::MIME"text/plain", t::Transformation) = print(io, summary(t)) diff --git a/test/ext/ChainRulesCoreExt.jl b/test/ext/ChainRulesCoreExt.jl index 81bc189..861e25f 100644 --- a/test/ext/ChainRulesCoreExt.jl +++ b/test/ext/ChainRulesCoreExt.jl @@ -4,8 +4,8 @@ using ChainRulesTestUtils: test_rrule @testset "ChainRulesCoreExt.jl" begin n = 3 batch_size = (1, 2) - affine = rand(Float64, AffineMaps, n => 4, batch_size) - rigid = rand(Float64, RigidTransformations, n, batch_size) + affine = rand(Float64, Affine, n, batch_size) + rigid = rand(Float64, Rigid, n, batch_size) x = rand(Float64, n, 2, batch_size...) test_rrule(transform, affine, x) diff --git a/test/ext/FunctorsExt.jl b/test/ext/FunctorsExt.jl index 79bb340..4f56f76 100644 --- a/test/ext/FunctorsExt.jl +++ b/test/ext/FunctorsExt.jl @@ -2,10 +2,11 @@ using Functors: functor @testset "FunctorsExt.jl" begin - @test !isempty(functor(inverse(rand(Float32, Translations, 3, (1,))))[1]) - @test !isempty(functor(compose(rand(Float32, Translations, 3, (1,)), rand(LinearMaps, 3 => 3, (1,))))[1]) - @test !isempty(functor(rand(Float32, Translations, 3, (1,)))[1]) - @test !isempty(functor(rand(Float32, LinearMaps, 3 => 3, (1,)))[1]) - @test !isempty(functor(rand(Float32, Rotations, 3, (1,)))[1]) + @test !isempty(functor(Inverse(rand(Float32, Translation, 3)))[1]) + @test !isempty(functor(Composed(rand(Float32, Translation, 3), rand(Float32, Translation, 3)))[1]) + @test !isempty(functor(rand(Float32, Translation, 3))[1]) + @test !isempty(functor(rand(Float32, Linear, 3 => 3))[1]) + @test !isempty(functor(rand(Float32, Rotation, 3))[1]) + @test !isempty(functor(rand(Float32, Affine, 3))[1]) end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 9147eee..b5da0b5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,8 +13,8 @@ using ChainRulesTestUtils: test_rrule include("ext/FunctorsExt.jl") @testset "transformations.jl" begin - struct FooTransformations{A<:AbstractArray} <: Transformations; values::A end - t = FooTransformations(rand(Float64, ())) + struct FooTransformation{A<:AbstractArray} <: Transformation; values::A end + t = FooTransformation(rand(Float64, ())) x = rand(3, 2, 4) @test_throws ErrorException transform(t, x) @test_throws ErrorException inv(t) @@ -25,7 +25,7 @@ using ChainRulesTestUtils: test_rrule io = IOBuffer() show(io, MIME("text/plain"), t) str = String(take!(io)) - @test str == "FooTransformations{Array{Float64, 0}}" + @test str == "FooTransformation{Array{Float64, 0}}" end @@ -34,7 +34,7 @@ using ChainRulesTestUtils: test_rrule batch_size = (2, 4) x = rand(Float32, n, 2, batch_size...) - l = rand(Float32, LinearMaps, n => m, batch_size) + l = rand(Float32, Linear, n => m, batch_size) @test inverse(inverse(l)) === l @test inverse(l) * x == inv(l) * x end @@ -44,8 +44,8 @@ using ChainRulesTestUtils: test_rrule batch_size = (2, 4) x = rand(Float32, n, 5, batch_size...) - t = rand(Float32, Translations, m, batch_size) - l = rand(Float32, LinearMaps, n => m, batch_size) + t = rand(Float32, Translation, m, batch_size) + l = rand(Float32, Linear, n => m, batch_size) c = compose(t, l) @test t ∘ l == c #@test c(x) == t(l)(x) @@ -58,28 +58,28 @@ using ChainRulesTestUtils: test_rrule batch_size = (2, 4) x = rand(Float32, n, 5, batch_size...) - @testset "LinearMaps" begin - l = rand(Float32, LinearMaps, n => m, batch_size) - @test linear(l) isa LinearMaps + @testset "Linear" begin + l = rand(Float32, Linear, n => m, batch_size) + @test linear(l) isa Linear @test values(l) isa AbstractArray @test l * x == values(l) ⊠ x @test (inv(l) ∘ l) * x ≈ x @test inv(l) * (l * x) ≈ x end - @testset "Translations" begin - t = rand(Float32, Translations, n, batch_size) - @test translation(t) isa Translations + @testset "Translation" begin + t = rand(Float32, Translation, n, batch_size) + @test translation(t) isa Translation @test values(t) isa AbstractArray @test t * x == x .+ values(t) @test (inv(t) ∘ t) * x ≈ x @test inv(t) * (t * x) ≈ x end - @testset "AffineMaps" begin - affine = rand(Float32, AffineMaps, n => m, batch_size) - @test linear(affine) isa LinearMaps - @test translation(affine) isa Translations + @testset "Affine" begin + affine = rand(Float32, Affine, n, batch_size) + @test linear(affine) isa Linear + @test translation(affine) isa Translation @test affine * x == values(linear(affine)) ⊠ x .+ values(translation(affine)) @test (inv(affine) ∘ affine) * x ≈ x @test inv(affine) * (affine * x) ≈ x @@ -89,35 +89,35 @@ using ChainRulesTestUtils: test_rrule batch_size = (2, 4) x = rand(Float32, n, 5, batch_size...) - @testset "Rotations" begin - rotation = rand(Float32, Rotations, n, batch_size) - @test linear(rotation) isa Rotations + @testset "Rotation" begin + rotation = rand(Float32, Rotation, n, batch_size) + @test linear(rotation) isa Rotation @test values(rotation) isa AbstractArray @test rotation * x == values(linear(rotation)) ⊠ x @test (inv(rotation) ∘ rotation) * x ≈ x @test inv(rotation) * (rotation * x) ≈ x # NNlib.batched_transpose only supports one batch dimension - @test !isa(values(inv(rand(Float32, Rotations, n, (2,)))), Array) - @test isa(values(inv(rand(Float32, Rotations, n, (2,1)))), Array) + @test !isa(values(inv(rand(Float32, Rotation, n, (2,)))), Array) + @test isa(values(inv(rand(Float32, Rotation, n, (2,1)))), Array) end - @testset "RigidTransformations" begin - rigid = rand(Float32, RigidTransformations, n, batch_size) - @test linear(rigid) isa Rotations - @test translation(rigid) isa Translations + @testset "Rigid" begin + rigid = rand(Float32, Rigid, n, batch_size) + @test linear(rigid) isa Rotation + @test translation(rigid) isa Translation @test rigid * x == values(linear(rigid)) ⊠ x .+ values(translation(rigid)) @test (inv(rigid) ∘ rigid) * x ≈ x @test inv(rigid) * (rigid * x) ≈ x - @test rigid ∘ rigid isa RigidTransformations + @test rigid ∘ rigid isa Rigid end end @testset "rand.jl" begin - @testset "Rotations" begin - rotations = rand(Float32, Rotations, 3, (2, 4)) + @testset "Rotation" begin + rotations = rand(Float32, Rotation, 3, (2, 4)) @test BatchedTransformations.batched_det(values(rotations)) ≈ ones(Float32, 1, 1, 2, 4) end