Skip to content

Commit

Permalink
Clean up conversions between special and/or annotated matrix types.
Browse files Browse the repository at this point in the history
  • Loading branch information
Sacha0 committed Sep 26, 2017
1 parent 76bf3a7 commit ba3c22d
Showing 1 changed file with 34 additions and 64 deletions.
98 changes: 34 additions & 64 deletions base/linalg/special.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

# Methods operating on different special matrix types


# Interconversion between special matrix types

# conversions from Diagonal to other special matrix types
convert(::Type{Bidiagonal}, A::Diagonal) =
Bidiagonal(A.diag, fill!(similar(A.diag, length(A.diag)-1), 0), :U)
convert(::Type{SymTridiagonal}, A::Diagonal) =
Expand All @@ -11,86 +14,53 @@ convert(::Type{Tridiagonal}, A::Diagonal) =
Tridiagonal(fill!(similar(A.diag, length(A.diag)-1), 0), A.diag,
fill!(similar(A.diag, length(A.diag)-1), 0))

function convert(::Type{Diagonal}, A::Union{Bidiagonal, SymTridiagonal})
if !iszero(A.ev)
# conversions from Bidiagonal to other special matrix types
convert(::Type{Diagonal}, A::Bidiagonal) =
iszero(A.ev) ? Diagonal(A.dv) :
throw(ArgumentError("matrix cannot be represented as Diagonal"))
end
Diagonal(A.dv)
end

function convert(::Type{SymTridiagonal}, A::Bidiagonal)
if !iszero(A.ev)
convert(::Type{SymTridiagonal}, A::Bidiagonal) =
iszero(A.ev) ? SymTridiagonal(A.dv, A.ev) :
throw(ArgumentError("matrix cannot be represented as SymTridiagonal"))
end
SymTridiagonal(A.dv, A.ev)
end

convert(::Type{Tridiagonal}, A::Bidiagonal{T}) where {T} =
convert(::Type{Tridiagonal}, A::Bidiagonal) =
Tridiagonal(A.uplo == 'U' ? fill!(similar(A.ev), 0) : A.ev, A.dv,
A.uplo == 'U' ? A.ev : fill!(similar(A.ev), 0))

function convert(::Type{Bidiagonal}, A::SymTridiagonal)
if !iszero(A.ev)
# conversions from SymTridiagonal to other special matrix types
convert(::Type{Diagonal}, A::SymTridiagonal) =
iszero(A.ev) ? Diagonal(A.dv) :
throw(ArgumentError("matrix cannot be represented as Diagonal"))
convert(::Type{Bidiagonal}, A::SymTridiagonal) =
iszero(A.ev) ? Bidiagonal(A.dv, A.ev, :U) :
throw(ArgumentError("matrix cannot be represented as Bidiagonal"))
end
Bidiagonal(A.dv, A.ev, :U)
end
convert(::Type{Tridiagonal}, A::SymTridiagonal) =
Tridiagonal(copy(A.ev), A.dv, A.ev)

function convert(::Type{Diagonal}, A::Tridiagonal)
if !(iszero(A.dl) && iszero(A.du))
# conversions from Tridiagonal to other special matrix types
convert(::Type{Diagonal}, A::Tridiagonal) =
iszero(A.dl) && iszero(A.du) ? Diagonal(A.d) :
throw(ArgumentError("matrix cannot be represented as Diagonal"))
end
Diagonal(A.d)
end

function convert(::Type{Bidiagonal}, A::Tridiagonal)
if iszero(A.dl)
return Bidiagonal(A.d, A.du, :U)
elseif iszero(A.du)
return Bidiagonal(A.d, A.dl, :L)
else
convert(::Type{Bidiagonal}, A::Tridiagonal) =
iszero(A.dl) ? Bidiagonal(A.d, A.du, :U) :
iszero(A.du) ? Bidiagonal(A.d, A.dl, :L) :
throw(ArgumentError("matrix cannot be represented as Bidiagonal"))
end
end

function convert(::Type{SymTridiagonal}, A::Tridiagonal)
if A.dl != A.du
convert(::Type{SymTridiagonal}, A::Tridiagonal) =
A.dl == A.du ? SymTridiagonal(A.d, A.dl) :
throw(ArgumentError("matrix cannot be represented as SymTridiagonal"))
end
SymTridiagonal(A.d, A.dl)
end

function convert(::Type{Tridiagonal}, A::SymTridiagonal)
Tridiagonal(copy(A.ev), A.dv, A.ev)
end

function convert(::Type{Diagonal}, A::AbstractTriangular)
if !isdiag(A)
# conversions from AbstractTriangular to special matrix types
convert(::Type{Diagonal}, A::AbstractTriangular) =
isdiag(A) ? Diagonal(diag(A)) :
throw(ArgumentError("matrix cannot be represented as Diagonal"))
end
Diagonal(diag(A))
end

function convert(::Type{Bidiagonal}, A::AbstractTriangular)
if isbanded(A, 0, 1) # is upper bidiagonal
return Bidiagonal(diag(A), diag(A, 1), :U)
elseif isbanded(A, -1, 0) # is lower bidiagonal
return Bidiagonal(diag(A), diag(A, -1), :L)
else
convert(::Type{Bidiagonal}, A::AbstractTriangular) =
isbanded(A, 0, 1) ? Bidiagonal(diag(A), diag(A, 1), :U) : # is upper bidiagonal
isbanded(A, -1, 0) ? Bidiagonal(diag(A), diag(A, -1), :L) : # is lower bidiagonal
throw(ArgumentError("matrix cannot be represented as Bidiagonal"))
end
end

convert(::Type{SymTridiagonal}, A::AbstractTriangular) =
convert(SymTridiagonal, convert(Tridiagonal, A))

function convert(::Type{Tridiagonal}, A::AbstractTriangular)
if isbanded(A, -1, 1) # is tridiagonal
return Tridiagonal(diag(A, -1), diag(A), diag(A, 1))
else
convert(::Type{Tridiagonal}, A::AbstractTriangular) =
isbanded(A, -1, 1) ? Tridiagonal(diag(A, -1), diag(A), diag(A, 1)) : # is tridiagonal
throw(ArgumentError("matrix cannot be represented as Tridiagonal"))
end
end


# Constructs two method definitions taking into account (assumed) commutativity
# e.g. @commutative f(x::S, y::T) where {S,T} = x+y is the same is defining
Expand Down

0 comments on commit ba3c22d

Please sign in to comment.