From 5fa6dfc4ce5d91f77bb5c99aa566af69aa46666b Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Mon, 16 Sep 2024 10:44:07 +0200 Subject: [PATCH] Restructure --- src/BatchedTransformations.jl | 10 +-- src/compose.jl | 34 ---------- src/core.jl | 114 ++++++++++++++++++++++++++++++++++ src/geometric/geometric.jl | 2 +- src/identity.jl | 10 --- src/inverse.jl | 34 ---------- src/transformations.jl | 26 -------- test/runtests.jl | 8 +-- 8 files changed, 121 insertions(+), 117 deletions(-) delete mode 100644 src/compose.jl create mode 100644 src/core.jl delete mode 100644 src/identity.jl delete mode 100644 src/inverse.jl delete mode 100644 src/transformations.jl diff --git a/src/BatchedTransformations.jl b/src/BatchedTransformations.jl index 30bd0c3..0e50631 100644 --- a/src/BatchedTransformations.jl +++ b/src/BatchedTransformations.jl @@ -1,18 +1,12 @@ module BatchedTransformations -include("transformations.jl") +include("core.jl") export Transformation, transform, inverse_transform export batchsize - -include("identity.jl") export Identity - -include("inverse.jl") -export Inverse, inverse - -include("compose.jl") export Composed, compose export outer, inner +export Inverse, inverse include("geometric/geometric.jl") export GeometricTransformation, AbstractAffine, AbstractLinear diff --git a/src/compose.jl b/src/compose.jl deleted file mode 100644 index 806f131..0000000 --- a/src/compose.jl +++ /dev/null @@ -1,34 +0,0 @@ -""" - Composed{Outer<:Transformation,Inner<:Transformation} - -A `Composed` contains two transformations `outer` and `inner` that are composed, -where `inner` gets applied first, and then `outer`.. -It can be constructed with `compose(outer, inner)` or `outer ∘ inner`, unless -the `compose` function is overloaded for the specific types. -""" -struct Composed{Outer<:Transformation,Inner<:Transformation} <: Transformation - outer::Outer - inner::Inner -end - -""" - compose(t2, t1) - t2 ∘ t1 -""" -@inline compose(outer::Transformation, inner::Transformation) = Composed(outer, inner) - -@inline Base.:(∘)(outer::Transformation, inner::Transformation) = compose(outer, inner) - -@inline Base.:(==)(a1::Composed, a2::Composed) = a1.outer == a2.outer && a1.inner == a2.inner - -@inline outer(composed::Composed) = composed.outer -@inline inner(composed::Composed) = composed.inner - -@inline transform(t::Composed, x) = transform(outer(t), transform(inner(t), x)) - -@inline inverse_transform(t::Composed, x) = inverse_transform(inner(t), inverse_transform(outer(t), x)) - -@inline Base.inv(t::Composed) = inv(inner(t)) ∘ inv(outer(t)) - -# enables `outer, inner = compose(outer, inner)` syntax -Base.iterate(t::Composed, state=1) = state == 1 ? (t.outer, 2) : (state == 2 ? (t.inner, nothing) : nothing) diff --git a/src/core.jl b/src/core.jl new file mode 100644 index 0000000..d3ae4bb --- /dev/null +++ b/src/core.jl @@ -0,0 +1,114 @@ +""" + Transformation + +An abstract type whose concrete subtypes contain batches of transformations +that can be applied to an array. A `Transformation` `t` can be applied to +`x` with `transform(t, x)`, `t * x`, and t(x). +""" +abstract type Transformation end + +function compose end +function batchsize end + +""" + transform(t, x) + t * x + t(x) +""" +transform(t::Transformation, x) = error("transform not defined for $(typeof(t)) and $(typeof(x))") + +@inline Base.:(*)(t::Transformation, x) = transform(t, x) +@inline (t::Transformation)(x) = transform(t, x) + +Base.inv(t::Transformation) = error("inverse not defined for $(typeof(t))") + +@inline inverse_transform(t::Transformation, x) = transform(inv(t), x) + +Base.show(io::IO, ::MIME"text/plain", t::Transformation) = print(io, summary(t)) + + +""" + Identity <: Transformation +""" +struct Identity <: Transformation end + +transform(::Identity, x) = x +inverse_transform(::Identity, x) = x + +Base.inv(::Identity) = Identity() + +@inline compose(::Identity, ::Identity) = Identity() +@inline compose(::Identity, t::Transformation) = t +@inline compose(t::Transformation, ::Identity) = t + + +""" + Composed{Outer<:Transformation,Inner<:Transformation} + +A `Composed` contains two transformations `outer` and `inner` that are composed, +where `inner` gets applied first, and then `outer`.. +It can be constructed with `compose(outer, inner)` or `outer ∘ inner`, unless +the `compose` function is overloaded for the specific types. +""" +struct Composed{Outer<:Transformation,Inner<:Transformation} <: Transformation + outer::Outer + inner::Inner +end + +""" + compose(t2, t1) + t2 ∘ t1 +""" +@inline compose(outer::Transformation, inner::Transformation) = Composed(outer, inner) + +@inline Base.:(∘)(outer::Transformation, inner::Transformation) = compose(outer, inner) + +@inline Base.:(==)(a1::Composed, a2::Composed) = a1.outer == a2.outer && a1.inner == a2.inner + +@inline outer(composed::Composed) = composed.outer +@inline inner(composed::Composed) = composed.inner + +@inline transform(t::Composed, x) = transform(outer(t), transform(inner(t), x)) + +@inline inverse_transform(t::Composed, x) = inverse_transform(inner(t), inverse_transform(outer(t), x)) + +@inline Base.inv(t::Composed) = inv(inner(t)) ∘ inv(outer(t)) + +# enables `outer, inner = compose(outer, inner)` syntax +Base.iterate(t::Composed, state=1) = state == 1 ? (t.outer, 2) : (state == 2 ? (t.inner, nothing) : nothing) + + +""" + Inverse{T<:Transformation} <: Transformation + +An `Inverse` represents a *lazy* inverse of a `Transformation` t. + +`inverse(t)` is a lazy inverse that defaults to `inv(t)` when evaluated. +`transform(inverse(t), x)` is equivalent to `inverse_transform(t, x)`. +This allows for specialized inverse transform implementations that don't +require the inverse to be computed explicitly. +""" +struct Inverse{T<:Transformation} <: Transformation + parent::T +end + +Base.:(==)(t1::Inverse, t2::Inverse) = t1.parent == t2.parent + +batchsize(t::Inverse) = batchsize(t.parent) + +@inline inverse(t::Transformation) = Inverse(t) +@inline inverse(t::Inverse) = t.parent + +@inline transform(t::Inverse, x) = inverse_transform(t.parent, x) + +@inline Base.inv(t::Inverse) = t.parent + +@inline function compose(t2::Inverse{T}, t1::T) where T<:Transformation + t2.parent === t1 && return Identity() + Composed(t2, t1) +end + +@inline function compose(t2::T, t1::Inverse{T}) where T<:Transformation + t2 === t1.parent && return Identity() + Composed(t2, t1) +end diff --git a/src/geometric/geometric.jl b/src/geometric/geometric.jl index f66f6ff..696653a 100644 --- a/src/geometric/geometric.jl +++ b/src/geometric/geometric.jl @@ -9,7 +9,7 @@ abstract type AbstractAffine <: GeometricTransformation end function translation end function linear end -Base.iterate(affine::AbstractAffine, state=0) = state == 0 ? (translation(affine), 1) : (state == 1 ? (linear(affine), nothing) : nothing) +Base.iterate(affine::AbstractAffine, args...) = iterate(affine.composed, args...) abstract type AbstractLinear <: AbstractAffine end diff --git a/src/identity.jl b/src/identity.jl deleted file mode 100644 index 10d80d3..0000000 --- a/src/identity.jl +++ /dev/null @@ -1,10 +0,0 @@ -struct Identity <: Transformation end - -transform(::Identity, x) = x -inverse_transform(::Identity, x) = x - -Base.inv(::Identity) = Identity() - -@inline compose(::Identity, ::Identity) = Identity() -@inline compose(::Identity, t::Transformation) = t -@inline compose(t::Transformation, ::Identity) = t \ No newline at end of file diff --git a/src/inverse.jl b/src/inverse.jl deleted file mode 100644 index 8a5b856..0000000 --- a/src/inverse.jl +++ /dev/null @@ -1,34 +0,0 @@ -""" - Inverse{T<:Transformation} - -An `Inverse` represents a *lazy* inverse of a `Transformation` t. - -`inverse(t)` is a lazy inverse that defaults to `inv(t)` when evaluated. -`transform(inverse(t), x)` is equivalent to `inverse_transform(t, x)`. -This allows for specialized inverse transform implementations that don't -require the inverse to be computed explicitly. -""" -struct Inverse{T<:Transformation} <: Transformation - parent::T -end - -Base.:(==)(t1::Inverse, t2::Inverse) = t1.parent == t2.parent - -batchsize(t::Inverse) = batchsize(t.parent) - -@inline inverse(t::Transformation) = Inverse(t) -@inline inverse(t::Inverse) = t.parent - -@inline transform(t::Inverse, x) = inverse_transform(t.parent, x) - -@inline Base.inv(t::Inverse) = t.parent - -function compose(t2::Inverse{T}, t1::T) where T<:Transformation - t2.parent === t1 && return Identity() - Composed(t2, t1) -end - -function compose(t2::T, t1::Inverse{T}) where T<:Transformation - t2 === t1.parent && return Identity() - Composed(t2, t1) -end diff --git a/src/transformations.jl b/src/transformations.jl deleted file mode 100644 index 97c4b6b..0000000 --- a/src/transformations.jl +++ /dev/null @@ -1,26 +0,0 @@ -""" - Transformation - -An abstract type whose concrete subtypes contain batches of transformations -that can be applied to an array. A `Transformation` `t` can be applied to -`x` with `transform(t, x)`, `t * x`, and t(x). -""" -abstract type Transformation end - -batchsize(t::Transformation) = error("batchsize not defined for $(typeof(t))") - -""" - transform(t, x) - t * x - t(x) -""" -transform(t::Transformation, x) = error("transform not defined for $(typeof(t)) and $(typeof(x))") - -Base.inv(t::Transformation) = error("inverse not defined for $(typeof(t)) ") -@inline inverse_transform(t::Transformation, x) = transform(inv(t), x) - -@inline Base.:(*)(t::Transformation, x) = transform(t, x) -@inline (t::Transformation)(x) = transform(t, x) - -Base.show(io::IO, ::MIME"text/plain", t::Transformation) = print(io, summary(t)) - diff --git a/test/runtests.jl b/test/runtests.jl index b5da0b5..be1a08d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,9 +13,9 @@ using ChainRulesTestUtils: test_rrule include("ext/FunctorsExt.jl") @testset "transformations.jl" begin - struct FooTransformation{A<:AbstractArray} <: Transformation; values::A end - t = FooTransformation(rand(Float64, ())) - x = rand(3, 2, 4) + struct FooTransformation <: Transformation end + t = FooTransformation() + x = "Bar" @test_throws ErrorException transform(t, x) @test_throws ErrorException inv(t) @test_throws ErrorException inverse_transform(t, x) @@ -25,7 +25,7 @@ using ChainRulesTestUtils: test_rrule io = IOBuffer() show(io, MIME("text/plain"), t) str = String(take!(io)) - @test str == "FooTransformation{Array{Float64, 0}}" + @test str == "FooTransformation" end