diff --git a/base/linalg/diagonal.jl b/base/linalg/diagonal.jl index 02ed30e91bb7c..383c6761eb3a7 100644 --- a/base/linalg/diagonal.jl +++ b/base/linalg/diagonal.jl @@ -8,10 +8,6 @@ Diagonal(A::Matrix) = Diagonal(diag(A)) size(D::Diagonal) = (length(D.diag),length(D.diag)) size(D::Diagonal,d::Integer) = d<1 ? error("dimension out of range") : (d<=2 ? length(D.diag) : 1) -convert{T}(::Type{Matrix{T}}, D::Diagonal{T}) = diagm(D.diag) -convert{T}(::Type{SymTridiagonal{T}}, D::Diagonal{T}) = SymTridiagonal(D.diag,zeros(T,length(D.diag)-1)) -convert{T}(::Type{Tridiagonal{T}}, D::Diagonal{T}) = Tridiagonal(zeros(T,length(D.diag)-1),D.diag,zeros(T,length(D.diag)-1)) - full(D::Diagonal) = diagm(D.diag) getindex(D::Diagonal, i::Integer, j::Integer) = i == j ? D.diag[i] : zero(eltype(D.diag)) diff --git a/base/linalg/special.jl b/base/linalg/special.jl index 4a0b62a127a02..4a2124292d3f8 100644 --- a/base/linalg/special.jl +++ b/base/linalg/special.jl @@ -1,5 +1,74 @@ #Methods operating on different special matrix types +#Interconversion between special matrix types +import Base.convert +convert{T}(::Type{Bidiagonal}, A::Diagonal{T})=Bidiagonal(A.diag, zeros(T, size(A.diag,1)-1), true) +convert{T}(::Type{SymTridiagonal}, A::Diagonal{T})=SymTridiagonal(A.diag, zeros(T, size(A.diag,1)-1)) +convert{T}(::Type{Tridiagonal}, A::Diagonal{T})=Tridiagonal(zeros(T, size(A.diag,1)-1), A.diag, zeros(T, size(A.diag,1)-1)) +convert(::Type{Triangular}, A::Union(Diagonal, Bidiagonal, SymTridiagonal, Tridiagonal))=Triangular(full(A)) +convert(::Type{Matrix}, D::Diagonal) = diagm(D.diag) + +function convert(::Type{Diagonal}, A::Union(Bidiagonal, SymTridiagonal)) + all(A.ev .== 0) || throw(ArgumentError("Matrix cannot be represented as Diagonal")) + Diagonal(A.dv) +end + +function convert(::Type{SymTridiagonal}, A::Bidiagonal) + all(A.ev .== 0) || throw(ArgumentError("Matrix cannot be represented as SymTridiagonal")) + SymTridiagonal(A.dv, A.ev) +end + +convert{T}(::Type{Tridiagonal}, A::Bidiagonal{T})=Tridiagonal(A.isupper?zeros(T, size(A.dv,1)-1):A.ev, A.dv, A.isupper?A.ev:zeros(T, size(A.dv,1)-1)) + +function convert(::Type{Bidiagonal}, A::SymTridiagonal) + all(A.ev .== 0) || throw(ArgumentError("Matrix cannot be represented as Bidiagonal")) + Bidiagonal(A.dv, A.ev, true) +end + +function convert(::Type{Diagonal}, A::Tridiagonal) + all(A.dl .== 0) && all(A.du .== 0) || throw(ArgumentError("Matrix cannot be represented as Diagonal")) + Diagonal(A.d) +end + +function convert(::Type{Bidiagonal}, A::Tridiagonal) + if all(A.dl .== 0) return Bidiagonal(A.d, A.du, true) + elseif all(A.du .== 0) return Bidiagonal(A.d, A.dl, true) + else throw(ArgumentError("Matrix cannot be represented as Bidiagonal")) + end +end + +function convert(::Type{SymTridiagonal}, A::Tridiagonal) + all(A.dl .== A.du) || throw(ArgumentError("Matrix cannot be represented as SymTridiagonal")) + SymTridiagonal(A.d, A.dl) +end + +function convert(::Type{Diagonal}, A::Triangular) + full(A) == diagm(diag(A)) || throw(ArgumentError("Matrix cannot be represented as Diagonal")) + Diagonal(diag(A)) +end + +function convert(::Type{Bidiagonal}, A::Triangular) + fA = full(A) + if fA == diagm(diag(A)) + diagm(diag(fA, 1), 1) + return Bidiagonal(diag(A), diag(fA,1), true) + elseif fA == diagm(diag(A)) + diagm(diag(fA, -1), -1) + return Bidiagonal(diag(A), diag(fA,-1), true) + else + throw(ArgumentError("Matrix cannot be represented as Bidiagonal")) + end +end + +convert(::Type{SymTridiagonal}, A::Triangular) = convert(SymTridiagonal, convert(Tridiagonal, A)) + +function convert(::Type{Tridiagonal}, A::Triangular) + fA = full(A) + if fA == diagm(diag(A)) + diagm(diag(fA, 1), 1) + diagm(diag(fA, -1), -1) + return Tridiagonal(diag(fA, -1), diag(A), diag(fA,1)) + else + throw(ArgumentError("Matrix cannot be represented as Tridiagonal")) + end +end + #Constructs two method definitions taking into account (assumed) commutativity # e.g. @commutative f{S,T}(x::S, y::T) = x+y is the same is defining # f{S,T}(x::S, y::T) = x+y diff --git a/test/linalg.jl b/test/linalg.jl index 8553d6bf2a63b..a3bcf82daadf9 100644 --- a/test/linalg.jl +++ b/test/linalg.jl @@ -718,6 +718,41 @@ for relty in (Float16, Float32, Float64, BigFloat), elty in (relty, Complex{relt end end +#Test interconversion between special matrix types +using Base.Test + +N=12 +A=Diagonal([1:N]*1.0) +for newtype in [Diagonal, Bidiagonal, SymTridiagonal, Tridiagonal, Triangular, Matrix] + @test full(convert(newtype, A)) == full(A) +end + +for isupper in (true, false) + A=Bidiagonal([1:N]*1.0, [1:N-1]*1.0, isupper) + for newtype in [Bidiagonal, Tridiagonal, Triangular, Matrix] + @test full(convert(newtype, A)) == full(A) + end + A=Bidiagonal([1:N]*1.0, [1:N-1]*0.0, isupper) #morally Diagonal + for newtype in [Diagonal, Bidiagonal, SymTridiagonal, Tridiagonal, Triangular, Matrix] + @test full(convert(newtype, A)) == full(A) + end +end + +A=SymTridiagonal([1:N]*1.0, [1:N-1]*1.0) +for newtype in [Tridiagonal, Matrix] + @test full(convert(newtype, A)) == full(A) +end + +A=Tridiagonal([1:N-1]*0.0, [1:N]*1.0, [1:N-1]*0.0) #morally Diagonal +for newtype in [Diagonal, Bidiagonal, SymTridiagonal, Triangular, Matrix] + @test full(convert(newtype, A)) == full(A) +end + +A=Triangular(full(Diagonal([1:N]*1.0))) #morally Diagonal +for newtype in [Diagonal, Bidiagonal, SymTridiagonal, Triangular, Matrix] + @test full(convert(newtype, A)) == full(A) +end + # Test gglse for elty in (Float32, Float64, Complex64, Complex128) A = convert(Array{elty, 2}, [1 1 1 1; 1 3 1 1; 1 -1 3 1; 1 1 1 3; 1 1 1 -1])