Skip to content

Commit

Permalink
Revamp
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonOresten committed Sep 17, 2024
1 parent 5fa6dfc commit 14f17b7
Show file tree
Hide file tree
Showing 11 changed files with 304 additions and 229 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
name = "BatchedTransformations"
uuid = "8ba27c4b-52b5-4b10-bc66-a4fda05aa11b"
authors = ["Anton Oresten <anton.oresten42@gmail.com> and contributors"]
version = "0.4.0"
version = "0.5.0"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

Expand Down
1 change: 0 additions & 1 deletion ext/FunctorsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,5 @@ using Functors: @functor
@functor Linear
@functor Translation
@functor Affine
@functor Rotation

end
14 changes: 8 additions & 6 deletions src/BatchedTransformations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,18 @@ module BatchedTransformations

include("core.jl")
export Transformation, transform, inverse_transform
export batchsize
export Identity
export Composed, compose
export outer, inner
export Inverse, inverse

include("geometric/geometric.jl")
export GeometricTransformation, AbstractAffine, AbstractLinear
export Translation, Linear, Affine
export Rotation, Rigid
export linear, translation
include("batched/batched.jl")
export BatchedTransformation
export batchsize, batchreshape, batchunsqueeze
export AbstractAffine, translation, linear
export Translation
export Homomorphic, Endomorphic, Automorphic
export Linear, Orthonormal, Rotation, Reflection
export Affine, Rigid

end
131 changes: 131 additions & 0 deletions src/batched/affine.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
abstract type AbstractAffine <: GeometricTransformation end

function translation end
function linear end

Base.iterate(affine::AbstractAffine, state=0) = state == 0 ? (translation(affine), 1) : (state == 1 ? (linear(affine), nothing) : nothing)

abstract type Homomorphic end
abstract type Endomorphic <: Homomorphic end
abstract type Automorphic <: Endomorphic end

struct Linear{M<:Homomorphic,A<:AbstractArray} <: AbstractAffine
values::A
end

abstract type Orthonormal{Det} <: Automorphic end

const Rotation = Linear{Orthonormal{1}}
const Reflection = Linear{Orthonormal{-1}}

@inline Linear{M}(values::A) where {M,A} = Linear{M,A}(values)
@inline Linear{M}(linear::Linear) where M = Linear{M}(values(linear))

@inline function Linear{M}(values::A) where {M<:Endomorphic,A<:AbstractArray}
size(values, 1) == size(values, 2) || error("rotation values must have size (n, n, batchdims...)")
Linear{M,A}(values)
end

@inline function Linear(values::A) where A<:AbstractArray
M = size(values, 1) == size(values, 2) ? Endomorphic : Homomorphic
Linear{M,A}(values)
end

@inline compose(l2::Linear{M1}, l1::Linear{M2}) where {M1<:Homomorphic,M2<:Homomorphic} = Linear{typejoin(M1,M2)}(l2 * values(l1))

@inline linear(linear::Linear) = linear
@inline translation(::Linear) = Identity()

@inline Base.values(linear::Linear) = linear.values
@inline Base.:(==)(l1::Linear, l2::Linear) = values(l1) == values(l2)

batchsize(linear::Linear) = size(values(linear))[3:end]

function batchreshape(linear::Linear{M}, args...) where M
A = values(linear)
Linear{M}(reshape(A, size(A, 1), size(A, 2), args...))
end

function batchunsqueeze(linear::Linear{M}; dims::Int) where M
@assert dims > 0
Linear{M}(unsqueeze(values(linear), dims=dims+2))
end

transform(l::Linear, x::AbstractArray) = values(l) x

transform(linear::Linear, x::AbstractVecOrMat) = batched_mul_large_small(values(linear), x)

inverse_transform(t::Linear{<:Orthonormal}, x::AbstractArray) = batched_mul_T1(values(t), x)

Base.inv(t::Linear{M}) where M<:Automorphic = Linear{M}(mapslices(inv, values(t), dims=(1,2)))

Base.inv(t::Linear{M,<:AbstractArray{<:Any,2}}) where M<:Orthonormal = Linear{M}(transpose(values(t)))
Base.inv(t::Linear{M,<:AbstractArray{<:Any,3}}) where M<:Orthonormal = Linear{M}(batched_transpose(values(t)))
Base.inv(t::Linear{M}) where M<:Orthonormal = Linear{M}(permutedims(values(t), (2, 1, 3:ndims(values(t))...)))


struct Translation{A<:AbstractArray} <: AbstractAffine
values::A

function Translation{A}(values::A) where A<:AbstractArray
size(values, 2) == 1 || error("translation values must have size (n, 1, batchdims...)")
new{A}(values)
end
end

Translation(values::A) where A = Translation{A}(values)

@inline linear(::Translation) = Identity()
@inline translation(translation::Translation) = translation

@inline Base.values(translation::Translation) = translation.values
@inline Base.:(==)(t1::Translation, t2::Translation) = values(t1) == values(t2)

batchsize(translation::Translation) = size(values(translation))[3:end]

function batchreshape(translation::Translation, args...)
b = values(translation)
Translation(reshape(b, size(b, 1), 1, args...))
end

function batchunsqueeze(translation::Translation; dims::Int)
@assert dims > 0
Translation(unsqueeze(values(translation), dims=dims+2))
end

transform(t::Translation, x::AbstractArray) = x .+ values(t)
inverse_transform(t::Translation, x::AbstractArray) = x .- values(t)

Base.inv(t::Translation) = Translation(-values(t))

@inline compose(t2::Translation, t1::Translation) = Translation(t2 * values(t1))


struct Affine{T<:Translation,L<:Linear{<:Automorphic}} <: AbstractAffine
composed::Composed{T,L}
end

const Rigid = Affine{<:Translation,<:Rotation}

@inline linear(affine::Affine) = inner(affine.composed)
@inline translation(affine::Affine) = outer(affine.composed)

@inline Base.:(==)(affine1::Affine, affine2::Affine) = affine1.composed == affine2.composed

function batchunsqueeze((translation,linear)::Affine; dims::Int)
batchunsqueeze(translation; dims) batchunsqueeze(linear; dims)
end

transform(affine::Affine, x::AbstractArray) = transform(affine.composed, x)
inverse_transform(affine::Affine, x::AbstractArray) = inverse_transform(affine.composed, x)

Base.inv(affine::Affine) = inv(affine.composed)

Base.show(io::IO, affine::Affine) = print(io, "$(translation(affine))$(linear(affine))")

@inline compose(translation::Translation, linear::Linear) = Affine(Composed(translation, linear))
@inline compose(linear::Linear, translation::Translation) = Translation(linear * values(translation)) linear

@inline compose((t2,l2)::AbstractAffine, (t1,l1)::AbstractAffine) = (t2 (l2 t1)) l1
@inline compose((t2,l2)::AbstractAffine, l1::Linear) = t2 (l2 l1)
@inline compose(t2::Translation, (t1,l1)::AbstractAffine) = (t2 t1) l1
17 changes: 17 additions & 0 deletions src/batched/batched.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
using NNlib: , batched_mul, batched_transpose
using MLUtils: unsqueeze

include("batched_utils.jl")

function batchsize end
function batchreshape end
function batchunsqueeze end

batchsize(t::Transformation, d::Integer) = batchsize(t)[d]
batchsize(t::Inverse{<:Transformation}) = batchsize(t.parent)

abstract type GeometricTransformation <: Transformation end

include("affine.jl")
include("rand.jl")

10 changes: 9 additions & 1 deletion src/geometric/batched_utils.jl → src/batched/batched_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@ function batched_mul_T2(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where {T
return reshape(z, size(z, 1), size(z, 2), batch_size...)
end

function batched_mul_large_small(A::AbstractArray, x::AbstractVecOrMat)
batch_size = size(A)[3:end]
A′ = reshape(A, size(A, 1), size(A, 2), :)
y′ = A′ reshape(x, size(x, 1), size(x, 2))
y = reshape(y′, size(A, 1), size(x, 2), batch_size...)
return y
end

# might need custom chain rule
# could do map(det, eachslice(data, dims=size(data)[3:end], drop=false))
# could also try map(det, eachslice(data, dims=size(data)[3:end], drop=false))
batched_det(data::AbstractArray{<:Real}) = mapslices(det, data, dims=(1,2))
34 changes: 17 additions & 17 deletions src/geometric/rand.jl → src/batched/rand.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
using LinearAlgebra: qr, Diagonal, diag, det
using Random: AbstractRNG, default_rng

function Base.rand(rng::AbstractRNG, T::Type{<:Real}, ::Type{Linear}, (n, m)::Pair{<:Integer,<:Integer}, batch_size::Dims=())
values = randn(rng, T, m, n, batch_size...)
function Base.rand(rng::AbstractRNG, T::Type{<:Real}, ::Type{Linear}, (n, m)::Pair{<:Integer,<:Integer}, batchdims::Dims=())
values = randn(rng, T, m, n, batchdims...)
return Linear(values)
end

Base.rand(rng::AbstractRNG, T::Type{<:Real}, ::Type{Linear}, n::Integer, batch_size::Dims=()) =
rand(rng, T, Linear, n => n, batch_size)
Base.rand(rng::AbstractRNG, T::Type{<:Real}, ::Type{Linear}, n::Integer, batchdims::Dims=()) =
rand(rng, T, Linear, n => n, batchdims)

function Base.rand(rng::AbstractRNG, T::Type{<:Real}, ::Type{Translation}, n::Integer, batch_size::Dims=())
values = randn(rng, T, n, 1, batch_size...)
function Base.rand(rng::AbstractRNG, T::Type{<:Real}, ::Type{Translation}, n::Integer, batchdims::Dims=())
values = randn(rng, T, n, 1, batchdims...)
return Translation(values)
end

function Base.rand(rng::AbstractRNG, T::Type{<:Real}, ::Type{Affine}, n::Integer, batch_size::Dims=())
translation = rand(rng, T, Translation, n, batch_size)
linear = rand(rng, T, Linear, n, batch_size)
function Base.rand(rng::AbstractRNG, T::Type{<:Real}, ::Type{Affine}, n::Integer, batchdims::Dims=())
translation = rand(rng, T, Translation, n, batchdims)
linear = Linear{Automorphic}(values(rand(rng, T, Linear, n, batchdims))) # doesn't actually guarantee invertibility :/
return translation linear
end

Expand All @@ -28,17 +28,17 @@ function rand_rotation(rng::AbstractRNG, T::Type{<:Real}, n::Integer)
return Q
end

function Base.rand(rng::AbstractRNG, T::Type{<:Real}, ::Type{Rotation}, n::Integer, batch_size::Dims=())
values = reshape(stack([rand_rotation(rng, T, n) for _ in 1:prod(batch_size)]), n, n, batch_size...)
function Base.rand(rng::AbstractRNG, T::Type{<:Real}, ::Type{Rotation}, n::Integer, batchdims::Dims=())
values = reshape(stack([rand_rotation(rng, T, n) for _ in 1:prod(batchdims)]), n, n, batchdims...)
return Rotation(values)
end

function Base.rand(rng::AbstractRNG, T::Type{<:Real}, ::Type{Rigid}, n::Integer, batch_size::Dims=())
translations = rand(rng, T, Translation, n, batch_size)
rotations = rand(rng, T, Rotation, n, batch_size)
function Base.rand(rng::AbstractRNG, T::Type{<:Real}, ::Type{Rigid}, n::Integer, batchdims::Dims=())
translations = rand(rng, T, Translation, n, batchdims)
rotations = rand(rng, T, Rotation, n, batchdims)
return translations rotations
end

Base.rand(T::Type{<:Real}, Tr::Type{<:Transformation}, dims, batch_size::Dims=()) = rand(default_rng(), T, Tr, dims, batch_size)
Base.rand(Tr::Type{<:Transformation}, dims, batch_size::Dims) = rand(Float64, Tr, dims, batch_size)
Base.rand(Tr::Type{<:Transformation}, dims::Integer) = rand(Tr, dims, ())
Base.rand(T::Type{<:Real}, Tr::Type{<:Transformation}, dims, batchdims::Dims=()) = rand(default_rng(), T, Tr, dims, batchdims)
Base.rand(Tr::Type{<:Transformation}, dims, batchdims::Dims=()) = rand(Float64, Tr, dims, batchdims)
Base.rand(Tr::Type{<:Transformation}, dims::Integer) = rand(Tr, dims, ()) # needed to avoid ambiguity with other rand methods
12 changes: 5 additions & 7 deletions src/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ that can be applied to an array. A `Transformation` `t` can be applied to
abstract type Transformation end

function compose end
function batchsize end

"""
transform(t, x)
Expand All @@ -32,10 +31,10 @@ Base.show(io::IO, ::MIME"text/plain", t::Transformation) = print(io, summary(t))
"""
struct Identity <: Transformation end

transform(::Identity, x) = x
inverse_transform(::Identity, x) = x
@inline transform(::Identity, x) = x
@inline inverse_transform(::Identity, x) = x

Base.inv(::Identity) = Identity()
@inline Base.inv(::Identity) = Identity()

@inline compose(::Identity, ::Identity) = Identity()
@inline compose(::Identity, t::Transformation) = t
Expand Down Expand Up @@ -92,12 +91,11 @@ struct Inverse{T<:Transformation} <: Transformation
parent::T
end

Base.:(==)(t1::Inverse, t2::Inverse) = t1.parent == t2.parent

batchsize(t::Inverse) = batchsize(t.parent)
@inline Base.:(==)(t1::Inverse, t2::Inverse) = t1.parent == t2.parent

@inline inverse(t::Transformation) = Inverse(t)
@inline inverse(t::Inverse) = t.parent
@inline inverse(t::Identity) = t

@inline transform(t::Inverse, x) = inverse_transform(t.parent, x)

Expand Down
Loading

0 comments on commit 14f17b7

Please sign in to comment.