Skip to content

Commit

Permalink
Implement missing LinearAlgebra wrappers and add support for uplo par…
Browse files Browse the repository at this point in the history
…ameter (#51)
  • Loading branch information
danielwe authored Oct 23, 2023
1 parent ee2be7a commit adf4bde
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 69 deletions.
80 changes: 45 additions & 35 deletions src/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,24 +32,19 @@ end
$(Expr(:new, :(Base.LogicalIndex{T, typeof(mask)}), :mask, :(A.sum)))
end

adapt_structure(to, A::LinearAlgebra.Adjoint) =
LinearAlgebra.adjoint(adapt(to, Base.parent(A)))
adapt_structure(to, A::LinearAlgebra.Transpose) =
LinearAlgebra.transpose(adapt(to, Base.parent(A)))
adapt_structure(to, A::LinearAlgebra.LowerTriangular) =
LinearAlgebra.LowerTriangular(adapt(to, Base.parent(A)))
adapt_structure(to, A::LinearAlgebra.UnitLowerTriangular) =
LinearAlgebra.UnitLowerTriangular(adapt(to, Base.parent(A)))
adapt_structure(to, A::LinearAlgebra.UpperTriangular) =
LinearAlgebra.UpperTriangular(adapt(to, Base.parent(A)))
adapt_structure(to, A::LinearAlgebra.UnitUpperTriangular) =
LinearAlgebra.UnitUpperTriangular(adapt(to, Base.parent(A)))
adapt_structure(to, A::LinearAlgebra.Diagonal) =
LinearAlgebra.Diagonal(adapt(to, Base.parent(A)))
adapt_structure(to, A::LinearAlgebra.Tridiagonal) =
LinearAlgebra.Tridiagonal(adapt(to, A.dl), adapt(to, A.d), adapt(to, A.du))
adapt_structure(to, A::LinearAlgebra.Symmetric) =
LinearAlgebra.Symmetric(adapt(to, Base.parent(A)))
adapt_structure(to, A::Adjoint) = adjoint(adapt(to, Base.parent(A)))
adapt_structure(to, A::Transpose) = transpose(adapt(to, Base.parent(A)))
adapt_structure(to, A::LowerTriangular) = LowerTriangular(adapt(to, Base.parent(A)))
adapt_structure(to, A::UnitLowerTriangular) = UnitLowerTriangular(adapt(to, Base.parent(A)))
adapt_structure(to, A::UpperTriangular) = UpperTriangular(adapt(to, Base.parent(A)))
adapt_structure(to, A::UnitUpperTriangular) = UnitUpperTriangular(adapt(to, Base.parent(A)))
adapt_structure(to, A::Diagonal) = Diagonal(adapt(to, Base.parent(A)))
adapt_structure(to, A::Bidiagonal) = Bidiagonal(adapt(to, A.dv), adapt(to, A.ev), Symbol(A.uplo))
adapt_structure(to, A::Tridiagonal) = Tridiagonal(adapt(to, A.dl), adapt(to, A.d), adapt(to, A.du))
adapt_structure(to, A::SymTridiagonal) = SymTridiagonal(adapt(to, A.dv), adapt(to, A.ev))
adapt_structure(to, A::Symmetric) = Symmetric(adapt(to, Base.parent(A)), Symbol(A.uplo))
adapt_structure(to, A::Hermitian) = Hermitian(adapt(to, Base.parent(A)), Symbol(A.uplo))
adapt_structure(to, A::UpperHessenberg) = UpperHessenberg(adapt(to, Base.parent(A)))


# we generally don't support multiple layers of wrappers, but some occur often
Expand Down Expand Up @@ -93,15 +88,19 @@ const WrappedArray{T,N,Src,Dst} = Union{
#Base.ReshapedArray{T,N,<:Src},
#Base.ReinterpretArray{T,N,<:Any,<:Src},

LinearAlgebra.Adjoint{T,<:Dst},
LinearAlgebra.Transpose{T,<:Dst},
LinearAlgebra.LowerTriangular{T,<:Dst},
LinearAlgebra.UnitLowerTriangular{T,<:Dst},
LinearAlgebra.UpperTriangular{T,<:Dst},
LinearAlgebra.UnitUpperTriangular{T,<:Dst},
LinearAlgebra.Diagonal{T,<:Dst},
LinearAlgebra.Tridiagonal{T,<:Dst},
LinearAlgebra.Symmetric{T,<:Dst},
Adjoint{T,<:Src}, # The adjoint/transpose of a Vector has shape 1xN, so is a 2d
Transpose{T,<:Src}, # wrapper around a 1d array, hence use Src not Dst
LowerTriangular{T,<:Dst},
UnitLowerTriangular{T,<:Dst},
UpperTriangular{T,<:Dst},
UnitUpperTriangular{T,<:Dst},
Diagonal{T,<:Src},
Bidiagonal{T,<:Src},
Tridiagonal{T,<:Src},
SymTridiagonal{T,<:Src},
Symmetric{T,<:Dst},
Hermitian{T,<:Dst},
UpperHessenberg{T,<:Dst},

WrappedReinterpretArray{T,N,<:Src},
WrappedReshapedArray{T,N,<:Src},
Expand All @@ -121,20 +120,31 @@ const WrappedArray{T,N,Src,Dst} = Union{

# accessors for extracting information about the wrapper type
ndims(::Type{<:Base.LogicalIndex}) = 1
ndims(::Type{<:LinearAlgebra.Adjoint}) = 2
ndims(::Type{<:LinearAlgebra.Transpose}) = 2
ndims(::Type{<:LinearAlgebra.LowerTriangular}) = 2
ndims(::Type{<:LinearAlgebra.UnitLowerTriangular}) = 2
ndims(::Type{<:LinearAlgebra.UpperTriangular}) = 2
ndims(::Type{<:LinearAlgebra.UnitUpperTriangular}) = 2
ndims(::Type{<:LinearAlgebra.Diagonal}) = 2
ndims(::Type{<:LinearAlgebra.Tridiagonal}) = 2
ndims(::Type{<:Adjoint}) = 2
ndims(::Type{<:Transpose}) = 2
ndims(::Type{<:LowerTriangular}) = 2
ndims(::Type{<:UnitLowerTriangular}) = 2
ndims(::Type{<:UpperTriangular}) = 2
ndims(::Type{<:UnitUpperTriangular}) = 2
ndims(::Type{<:Diagonal}) = 2
ndims(::Type{<:Bidiagonal}) = 2
ndims(::Type{<:Tridiagonal}) = 2
ndims(::Type{<:SymTridiagonal}) = 2
ndims(::Type{<:Symmetric}) = 2
ndims(::Type{<:Hermitian}) = 2
ndims(::Type{<:UpperHessenberg}) = 2
ndims(::Type{<:WrappedArray{<:Any,N}}) where {N} = N

eltype(::Type{<:WrappedArray{T}}) where {T} = T # every wrapper has a T typevar

for T in [:(Base.LogicalIndex{<:Any,<:Src}),
:(PermutedDimsArray{<:Any,<:Any,<:Any,<:Any,<:Src}),
:(Adjoint{<:Any,<:Src}),
:(Transpose{<:Any,<:Src}),
:(Diagonal{<:Any,<:Src}),
:(Bidiagonal{<:Any,<:Src}),
:(Tridiagonal{<:Any,<:Src}),
:(SymTridiagonal{<:Any,<:Src}),
:(WrappedReinterpretArray{<:Any,<:Any,<:Src}),
:(WrappedReshapedArray{<:Any,<:Any,<:Src}),
:(WrappedSubArray{<:Any,<:Any,<:Src})]
Expand Down
77 changes: 43 additions & 34 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ macro test_adapt(to, src_expr, dst_expr, typ=nothing)
end
end

AnyCustomArray{T,N} = Union{CustomArray,WrappedArray{T,N,CustomArray,CustomArray{T,N}}}
AnyCustomArray{T,N} = Union{CustomArray{T,N},WrappedArray{T,N,CustomArray,CustomArray{T,N}}}
AnyCustomVector{T} = AnyCustomArray{T,1}
AnyCustomMatrix{T} = AnyCustomArray{T,2}


# basic adaption
Expand Down Expand Up @@ -128,75 +130,82 @@ end

@testset "array wrappers" begin

@test_adapt CustomArray view(mat.arr,:,:) view(mat,:,:) AnyCustomArray
@test_adapt CustomArray view(mat.arr,:,:) view(mat,:,:) AnyCustomMatrix
inds = CustomArray{Int,1}([1,2])
@test_adapt CustomArray view(mat.arr,inds.arr,:) view(mat,inds,:) AnyCustomArray
@test_adapt CustomArray view(mat.arr,inds.arr,:) view(mat,inds,:) AnyCustomMatrix

# NOTE: manual creation of PermutedDimsArray because permutedims collects
@test_adapt CustomArray PermutedDimsArray(mat.arr,(2,1)) PermutedDimsArray(mat,(2,1)) AnyCustomArray
@test_adapt CustomArray PermutedDimsArray(mat.arr,(2,1)) PermutedDimsArray(mat,(2,1)) AnyCustomMatrix

# NOTE: manual creation of ReshapedArray because Base.Array has an optimized `reshape`
@test_adapt CustomArray Base.ReshapedArray(mat.arr,(2,2),()) reshape(mat,(2,2)) AnyCustomArray
@test_adapt CustomArray Base.ReshapedArray(mat.arr,(2,2),()) reshape(mat,(2,2)) AnyCustomMatrix

@test_adapt CustomArray Base.LogicalIndex(mat_bools.arr) Base.LogicalIndex(mat_bools) AnyCustomArray
@test_adapt CustomArray Base.LogicalIndex(mat_bools.arr) Base.LogicalIndex(mat_bools) AnyCustomVector

@test_adapt CustomArray reinterpret(Int64,mat.arr) reinterpret(Int64,mat) AnyCustomArray
@test_adapt CustomArray reinterpret(Int64,mat.arr) reinterpret(Int64,mat) AnyCustomMatrix

@static if isdefined(Base, :NonReshapedReinterpretArray)
@test_adapt CustomArray reinterpret(reshape,Int64,mat.arr) reinterpret(reshape,Int64,mat) AnyCustomArray
@test_adapt CustomArray reinterpret(reshape,Int64,mat.arr) reinterpret(reshape,Int64,mat) AnyCustomMatrix
end


## doubly-wrapped

@test_adapt CustomArray reinterpret(Int64,view(mat.arr,:,:)) reinterpret(Int64,view(mat,:,:)) AnyCustomArray
@test_adapt CustomArray reinterpret(Int64,view(mat.arr,:,:)) reinterpret(Int64,view(mat,:,:)) AnyCustomMatrix

@test_adapt CustomArray reshape(view(mat.arr,:,:), (2,2)) reshape(view(mat,:,:), (2,2)) AnyCustomArray
@test_adapt CustomArray reshape(reinterpret(Int64,mat.arr), (2,2)) reshape(reinterpret(Int64,mat), (2,2)) AnyCustomArray
@test_adapt CustomArray reshape(reinterpret(Int64,view(mat.arr,:,:)), (2,2)) reshape(reinterpret(Int64,view(mat,:,:)), (2,2)) AnyCustomArray
@test_adapt CustomArray reshape(view(mat.arr,:,:), (2,2)) reshape(view(mat,:,:), (2,2)) AnyCustomMatrix
@test_adapt CustomArray reshape(reinterpret(Int64,mat.arr), (2,2)) reshape(reinterpret(Int64,mat), (2,2)) AnyCustomMatrix
@test_adapt CustomArray reshape(reinterpret(Int64,view(mat.arr,:,:)), (2,2)) reshape(reinterpret(Int64,view(mat,:,:)), (2,2)) AnyCustomMatrix

@test_adapt CustomArray view(reinterpret(Int64,mat.arr), :, :) view(reinterpret(Int64,mat), :, :) AnyCustomArray
@test_adapt CustomArray view(reinterpret(Int64,view(mat.arr,:,:)), :, :) view(reinterpret(Int64,view(mat,:,:)), :, :) AnyCustomArray
@test_adapt CustomArray view(Base.ReshapedArray(mat.arr,(2,2),()), :, :) view(reshape(mat, (2,2)), :, :) AnyCustomArray
@test_adapt CustomArray view(reshape(view(mat.arr,:,:), (2,2)), :, :) view(reshape(view(mat,:,:), (2,2)), :, :) AnyCustomArray
@test_adapt CustomArray view(reshape(reinterpret(Int64,mat.arr), (2,2)), :, :) view(reshape(reinterpret(Int64,mat), (2,2)), :, :) AnyCustomArray
@test_adapt CustomArray view(reshape(reinterpret(Int64,view(mat.arr,:,:)), (2,2)), :, :) view(reshape(reinterpret(Int64,view(mat,:,:)), (2,2)), :, :) AnyCustomArray
@test_adapt CustomArray view(reinterpret(Int64,mat.arr), :, :) view(reinterpret(Int64,mat), :, :) AnyCustomMatrix
@test_adapt CustomArray view(reinterpret(Int64,view(mat.arr,:,:)), :, :) view(reinterpret(Int64,view(mat,:,:)), :, :) AnyCustomMatrix
@test_adapt CustomArray view(Base.ReshapedArray(mat.arr,(2,2),()), :, :) view(reshape(mat, (2,2)), :, :) AnyCustomMatrix
@test_adapt CustomArray view(reshape(view(mat.arr,:,:), (2,2)), :, :) view(reshape(view(mat,:,:), (2,2)), :, :)