Skip to content

Commit

Permalink
Redesign
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonOresten committed Sep 16, 2024
1 parent 4264c07 commit 38a580f
Show file tree
Hide file tree
Showing 18 changed files with 272 additions and 220 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "BatchedTransformations"
uuid = "8ba27c4b-52b5-4b10-bc66-a4fda05aa11b"
authors = ["Anton Oresten <anton.oresten42@gmail.com> and contributors"]
version = "0.3.0"
version = "0.4.0"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
32 changes: 17 additions & 15 deletions ext/ChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ using ChainRulesCore

using BatchedTransformations: batched_mul, batched_mul_T1, batched_mul_T2

function ChainRulesCore.rrule(::typeof(transform), affine_maps::AbstractAffineMaps, x::AbstractArray)
translations, linear_maps = outer(affine_maps), linear(affine_maps)
t, R = values(translations), values(linear_maps)
function ChainRulesCore.rrule(::typeof(transform), affine::Affine, x::AbstractArray)
translation, linear = affine.composed.outer, affine.composed.inner
t, R = values(translation), values(linear)

y = batched_mul(R, x) .+ t

Expand All @@ -21,21 +21,22 @@ function ChainRulesCore.rrule(::typeof(transform), affine_maps::AbstractAffineMa
Δt = @thunk(sum(Δy, dims=2))
Δx = @thunk(batched_mul_T1(R, Δy))

Δtranslations = Tangent{typeof(translations)}(; values=Δt)
Δlinear_maps = Tangent{typeof(linear_maps)}(; values=ΔR)
Δaffine_maps = Tangent{typeof(affine_maps)}(; outer=Δtranslations, inner=Δlinear_maps)
Δtranslation = Tangent{typeof(translation)}(; values=Δt)
Δlinear = Tangent{typeof(linear)}(; values=ΔR)
Δcomposed = Tangent{typeof(affine.composed)}(; outer=Δtranslation, inner=Δlinear)
Δaffine = Tangent{typeof(affine)}(; composed=Δcomposed)

return NoTangent(), Δaffine_maps, Δx
return NoTangent(), Δaffine, Δx
end

return y, transform_pullback
end

function ChainRulesCore.rrule(::typeof(inverse_transform), rigid::RigidTransformations, x::AbstractArray)
translations, rotations = translation(rigid), linear(rigid)
z = inverse_transform(translations, x) # x .- t
y = inverse_transform(rotations, z) # R' * (x .- t)
t, R = values(translations), values(rotations)
function ChainRulesCore.rrule(::typeof(inverse_transform), rigid::Rigid, x::AbstractArray)
translation, rotation = rigid.composed.outer, rigid.composed.inner
z = inverse_transform(translation, x) # x .- t
y = inverse_transform(rotation, z) # R' * (x .- t)
t, R = values(translation), values(rotation)

function inverse_transform_pullback(_Δy)
Δy = unthunk(_Δy)
Expand All @@ -44,9 +45,10 @@ function ChainRulesCore.rrule(::typeof(inverse_transform), rigid::RigidTransform
Δx = @thunk(batched_mul(R, Δy))
Δt = @thunk(-sum(Δx, dims=2)) # t is in the same position as x, but negated and broadcasted

Δtranslations = Tangent{typeof(translations)}(; values=Δt)
Δrotations = Tangent{typeof(rotations)}(; values=ΔR)
Δrigid = Tangent{typeof(rigid)}(; outer=Δtranslations, inner=Δrotations)
Δtranslation = Tangent{typeof(translation)}(; values=Δt)
Δrotation = Tangent{typeof(rotation)}(; values=ΔR)
Δcomposed = Tangent{typeof(rigid.composed)}(; outer=Δtranslation, inner=Δrotation)
Δrigid = Tangent{typeof(rigid)}(; composed=Δcomposed)

return NoTangent(), Δrigid, Δx
end
Expand Down
7 changes: 4 additions & 3 deletions ext/FunctorsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ using Functors: @functor

@functor Inverse
@functor Composed
@functor LinearMaps
@functor Rotations
@functor Translations
@functor Linear
@functor Translation
@functor Affine
@functor Rotation

end
18 changes: 10 additions & 8 deletions src/BatchedTransformations.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
module BatchedTransformations

include("transformations.jl")
export Transformations, transform, inverse_transform
export Transformation, transform, inverse_transform
export batchsize

include("identity.jl")
export Identity

include("inverse.jl")
export Inverse, inverse
Expand All @@ -10,12 +14,10 @@ include("compose.jl")
export Composed, compose
export outer, inner

include("affine/affine.jl")
export AbstractLinearMaps, LinearMaps, Rotations
export Translations
export AbstractAffineMaps, AffineMaps, RigidTransformations
export translation, linear

include("rand.jl")
include("geometric/geometric.jl")
export GeometricTransformation, AbstractAffine, AbstractLinear
export Translation, Linear, Affine
export Rotation, Rigid
export linear, translation

end
24 changes: 0 additions & 24 deletions src/affine/affine.jl

This file was deleted.

42 changes: 0 additions & 42 deletions src/affine/linear.jl

This file was deleted.

20 changes: 0 additions & 20 deletions src/affine/translation.jl

This file was deleted.

29 changes: 16 additions & 13 deletions src/compose.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""
Composed{Outer<:Transformations,Inner<:Transformations}
Composed{Outer<:Transformation,Inner<:Transformation}
A `Composed` contains two transformations `t2` and `t1` that are composed.
It can be constructed with `compose(t2, t1)` and `t2 ∘ t1`, where `t1` is the
transformation to be applied first, and `t2` second.
A `Composed` contains two transformations `outer` and `inner` that are composed,
where `inner` gets applied first, and then `outer`..
It can be constructed with `compose(outer, inner)` or `outer ∘ inner`, unless
the `compose` function is overloaded for the specific types.
"""
struct Composed{Outer<:Transformations,Inner<:Transformations} <: Transformations
struct Composed{Outer<:Transformation,Inner<:Transformation} <: Transformation
outer::Outer
inner::Inner
end
Expand All @@ -14,18 +15,20 @@ end
compose(t2, t1)
t2 ∘ t1
"""
@inline compose(outer::Transformations, inner::Transformations) = Composed(outer, inner)
@inline compose(outer::Transformation, inner::Transformation) = Composed(outer, inner)

@inline Base.:()(outer::Transformations, inner::Transformations) = compose(outer, inner)
@inline Base.:()(outer::Transformation, inner::Transformation) = compose(outer, inner)

outer(composed::Composed) = composed.outer
inner(composed::Composed) = composed.inner
@inline Base.:(==)(a1::Composed, a2::Composed) = a1.outer == a2.outer && a1.inner == a2.inner

transform(t::Composed, x) = transform(outer(t), transform(inner(t), x))
@inline outer(composed::Composed) = composed.outer
@inline inner(composed::Composed) = composed.inner

inverse_transform(t::Composed, x) = inverse_transform(inner(t), inverse_transform(outer(t), x))
@inline transform(t::Composed, x) = transform(outer(t), transform(inner(t), x))

Base.inv(t::Composed) = compose(inv(inner(t)), inv(outer(t)))
@inline inverse_transform(t::Composed, x) = inverse_transform(inner(t), inverse_transform(outer(t), x))

# t2, t1 = compose(t2, t1)
@inline Base.inv(t::Composed) = inv(inner(t)) inv(outer(t))

# enables `outer, inner = compose(outer, inner)` syntax
Base.iterate(t::Composed, state=1) = state == 1 ? (t.outer, 2) : (state == 2 ? (t.inner, nothing) : nothing)
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
using NNlib: , batched_mul, batched_transpose

function batched_mul_T1(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where {T1,T2,N}
batch_size = size(x)[3:end]
@assert batch_size == size(y)[3:end] "batch size has to be the same for the two arrays."
Expand Down
101 changes: 101 additions & 0 deletions src/geometric/geometric.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
using NNlib: , batched_mul, batched_transpose

include("batched_utils.jl")

abstract type GeometricTransformation <: Transformation end

abstract type AbstractAffine <: GeometricTransformation end

function translation end
function linear end

Base.iterate(affine::AbstractAffine, state=0) = state == 0 ? (translation(affine), 1) : (state == 1 ? (linear(affine), nothing) : nothing)

abstract type AbstractLinear <: AbstractAffine end

@inline linear(linear::AbstractLinear) = linear
@inline translation(::AbstractLinear) = Identity()

@inline Base.values(linear::AbstractLinear) = linear.values
@inline Base.:(==)(l1::AbstractLinear, l2::AbstractLinear) = values(l1) == values(l2)
@inline batchsize(linear::AbstractLinear) = size(values(linear))[3:end]

transform(l::AbstractLinear, x::AbstractArray) = values(l) x

function transform(linear::AbstractLinear, x::AbstractVecOrMat)
A = values(linear)
batch_size = size(A)[3:end]
A′ = reshape(A, size(A, 1), size(A, 2), :)
y′ = A′ x
y = reshape(y′, size(A, 1), size(y′, 2), batch_size...)
return y
end

@inline compose(l2::AbstractLinear, l1::AbstractLinear) = Linear(l2 * values(l1))


struct Linear{A<:AbstractArray} <: AbstractLinear
values::A
end

Base.inv(t::Linear) = Linear(mapslices(inv, values(t), dims=(1,2)))


struct Translation{A<:AbstractArray} <: AbstractAffine
values::A
end

@inline linear(::Translation) = Identity()
@inline translation(translation::Translation) = translation

@inline Base.values(translation::Translation) = translation.values
@inline Base.:(==)(t1::Translation, t2::Translation) = values(t1) == values(t2)
@inline batchsize(translation::Translation) = size(values(translation))[3:end]

transform(t::Translation, x::AbstractArray) = x .+ values(t)
inverse_transform(t::Translation, x::AbstractArray) = x .- values(t)

Base.inv(t::Translation) = Translation(-values(t))

@inline compose(t2::Translation, t1::Translation) = Translation(t2 * values(t1))


struct Affine{T<:Translation,L<:AbstractLinear} <: AbstractAffine
composed::Composed{T,L}
end

@inline linear(affine::Affine) = inner(affine.composed)
@inline translation(affine::Affine) = outer(affine.composed)

@inline Base.:(==)(affine1::Affine, affine2::Affine) = affine1.composed == affine2.composed

transform(affine::Affine, x::AbstractArray) = transform(affine.composed, x)
inverse_transform(affine::Affine, x::AbstractArray) = inverse_transform(affine.composed, x)

Base.inv(affine::Affine) = inv(affine.composed)

Base.show(io::IO, affine::Affine) = print(io, "$(translation(affine))$(linear(affine))")

@inline compose(translation::Translation, linear::AbstractLinear) = Affine(Composed(translation, linear))
@inline compose(linear::AbstractLinear, translation::Translation) = Translation(linear * values(translation)) linear

@inline compose((t2,l2)::AbstractAffine, (t1,l1)::AbstractAffine) = (t2 (l2 t1)) l1
@inline compose((t2,l2)::AbstractAffine, l1::AbstractLinear) = t2 (l2 l1)
@inline compose(t2::Translation, (t1,l1)::AbstractAffine) = (t2 t1) l1


struct Rotation{A<:AbstractArray} <: AbstractLinear
values::A
end

inverse_transform(r::Rotation, x::AbstractArray) = batched_mul_T1(values(r), x)

Base.inv(t::Rotation{<:AbstractArray{<:Any,3}}) = Rotation(batched_transpose(values(t)))
Base.inv(t::Rotation{<:AbstractArray{<:Any,N}}) where N = Rotation(permutedims(values(t), (2, 1, 3:N...)))

@inline compose(r2::Rotation, r1::Rotation) = Rotation(r2 * values(r1))

const Rigid = Affine{<:Translation,<:Rotation}


include("rand.jl")
Loading

0 comments on commit 38a580f

Please sign in to comment.