Skip to content

Commit

Permalink
Add lu decomposition of *diagonal matrices
Browse files Browse the repository at this point in the history
  • Loading branch information
dkarrasch committed Jun 20, 2021
1 parent 9d5f31e commit e6a9151
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 56 deletions.
2 changes: 1 addition & 1 deletion stdlib/LinearAlgebra/src/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ cases where a more specific implementation of `lu!` (or `lu`) is not available.
See also: `copy_oftype`, `copy_similar`
"""
copy_to_array(A::AbstractArray, ::Type{T}) where {T} = copyto!(Array{T}(undef, size(A)...), A)
copy_to_array(A::AbstractArray, ::Type{T}) where {T} = copyto!(Array{T}(undef, size(A)), A)

# The three copy functions above return mutable arrays with eltype T.
# To only ensure a certain eltype, and if a mutable copy is not needed, it is
Expand Down
129 changes: 96 additions & 33 deletions stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -726,45 +726,17 @@ end

#Linear solvers
ldiv!(A::Union{Bidiagonal, AbstractTriangular}, b::AbstractVector) = naivesub!(A, b)
ldiv!(A::Transpose{<:Any,<:Bidiagonal}, b::AbstractVector) = ldiv!(copy(A), b)
ldiv!(A::Adjoint{<:Any,<:Bidiagonal}, b::AbstractVector) = ldiv!(copy(A), b)
ldiv!(A::Transpose{<:Any,<:Bidiagonal}, b::AbstractVecOrMat) = ldiv!(copy(A), b)
ldiv!(A::Adjoint{<:Any,<:Bidiagonal}, b::AbstractVecOrMat) = ldiv!(copy(A), b)
function ldiv!(A::Union{Bidiagonal,AbstractTriangular}, B::AbstractMatrix)
require_one_based_indexing(A, B)
nA,mA = size(A)
mA, nA = size(A)
n = size(B, 1)
if nA != n
throw(DimensionMismatch("size of A is ($nA,$mA), corresponding dimension of B is $n"))
end
for b in eachcol(B)
ldiv!(A, b)
end
B
end
function ldiv!(adjA::Adjoint{<:Any,<:Bidiagonal}, B::AbstractMatrix)
require_one_based_indexing(adjA, B)
mA, nA = size(adjA)
n = size(B, 1)
tmp = similar(B, n)
Ac = copy(adjA)
if mA != n
throw(DimensionMismatch("size of adjoint of A is ($mA,$nA), corresponding dimension of B is $n"))
throw(DimensionMismatch("size of A is ($mA,$nA), corresponding dimension of B is $n"))
end
for b in eachcol(B)
ldiv!(Ac, b)
end
B
end
function ldiv!(tA::Transpose{<:Any,<:Bidiagonal}, B::AbstractMatrix)
require_one_based_indexing(tA, B)
mA, nA = size(tA)
n = size(B, 1)
tmp = similar(B, n)
At = copy(tA)
if mA != n
throw(DimensionMismatch("size of transpose of A is ($mA,$nA), corresponding dimension of B is $n"))
end
for b in eachcol(B)
ldiv!(At, b)
ldiv!(A, b)
end
B
end
Expand Down Expand Up @@ -810,6 +782,41 @@ function naivesub!(A::Bidiagonal{T}, b::AbstractVector, x::AbstractVector = b) w
return x
end

function rdiv!(A::StridedMatrix, B::Bidiagonal)
m, n = size(A)
if size(B, 1) != n
throw(DimensionMismatch("right hand side B needs first dimension of size $n, has size $(size(B,1))"))
end
if B.uplo == 'L'
diagB = B.dv[n]
for i = 1:m
A[i,n] /= diagB
end
for j = n-1:-1:1
diagB = B.dv[j]
offdiagB = B.ev[j]
for i = 1:m
A[i,j] = (A[i,j] - A[i,j+1]*offdiagB)/diagB
end
end
else
diagB = B.dv[1]
for i = 1:m
A[i,1] /= diagB
end
for j = 2:n
diagB = B.dv[j]
offdiagB = B.ev[j-1]
for i = 1:m
A[i,j] = (A[i,j] - A[i,j-1]*offdiagB)/diagB
end
end
end
A
end
rdiv!(A::StridedMatrix, B::Adjoint{<:Any,<:Bidiagonal}) = rdiv!(A, copy(B))
rdiv!(A::StridedMatrix, B::Transpose{<:Any,<:Bidiagonal}) = rdiv!(A, copy(B))

### Generic promotion methods and fallbacks
function \(A::Bidiagonal{<:Number}, B::AbstractVecOrMat{<:Number})
TA, TB = eltype(A), eltype(B)
Expand All @@ -834,6 +841,62 @@ end

factorize(A::Bidiagonal) = A

lu!(A::Bidiagonal{T}; check::Bool = true) where {T} =
lu!(A, pivot; check = check)
function lu!(A::Bidiagonal{T}, pivot::NoPivot; check::Bool = true) where {T}
if A.uplo == 'U'
info = something(findfirst(iszero, A.dv), 0)
else
info = 0
for i in eachindex(A.ev)
if iszero(A.ev[i])
info = i
end
A.ev[i] /= A.dv[i]
end
end
check && checknonsingular(info, pivot)
return LU{T}(A, 1:length(A.dv), info)
end
lu!(A::Bidiagonal{T}, pivot::RowMaximum; check::Bool = true) where {T} =
lu!(Tridiagonal{T}(A), pivot; check = check)

lu(A::Bidiagonal{T}, pivot::Union{RowMaximum,NoPivot} = RowMaximum(); check::Bool = true) where {T} =
lu!(copy_oftype(A, lutype(T)), pivot; check = check)

function Base.getproperty(F::LU{T,<:Bidiagonal}, d::Symbol) where {T}
m, n = size(F)
if d === :L
L = tril!(getfield(F, :factors)[1:m, 1:min(m,n)])
for i = 1:min(m,n); L[i,i] = one(T); end
return L
elseif d === :U
return triu!(getfield(F, :factors)[1:min(m,n), 1:n])
elseif d === :p
return ipiv2perm(getfield(F, :ipiv), m)
elseif d === :P
return Matrix{T}(I, m, m)[:,invperm(F.p)]
else
getfield(F, d)
end
end
function ldiv!(F::LU{<:Any,<:Bidiagonal}, B::AbstractVecOrMat)
Bi = F.factors
if Bi.uplo == 'U'
return ldiv!(Bi, B)
else
return ldiv!(Diagonal(Bi.dv), ldiv!(UnitLowerTriangular(Bi), B))
end
end
function rdiv!(B::AbstractMatrix, F::LU{<:Any,<:Bidiagonal})
Bi = F.factors
if Bi.uplo == 'U'
return rdiv!(B, Bi)
else
return rdiv!(rdiv!(B, Diagonal(Bi.dv)), UnitLowerTriangular(Bi))
end
end

# Eigensystems
eigvals(M::Bidiagonal) = M.dv
function eigvecs(M::Bidiagonal{T}) where T
Expand Down
28 changes: 28 additions & 0 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,34 @@ function getproperty(C::Cholesky{<:Any,<:Diagonal}, d::Symbol)
end
end

lu(D::Diagonal{T}, pivot::Union{RowMaximum,NoPivot} = RowMaximum(); check::Bool = true) where {T} =
lu!(copy_oftype(D, lutype(T)), pivot, check = check)
function lu!(D::Diagonal{T}, pivot::Union{RowMaximum,NoPivot} = RowMaximum(); check::Bool = true) where {T}
info = something(findfirst(iszero, D.diag), 0)
check && checknonsingular(info, pivot)
return LU{T}(D, 1:length(D.diag), info)
end

function Base.getproperty(F::LU{T,<:Diagonal}, d::Symbol) where T
m, n = size(F)
if d === :L
L = tril!(getfield(F, :factors)[1:m, 1:min(m,n)])
for i = 1:min(m,n); L[i,i] = one(T); end
return L
elseif d === :U
return triu!(getfield(F, :factors)[1:min(m,n), 1:n])
elseif d === :p
return ipiv2perm(getfield(F, :ipiv), m)
elseif d === :P
return Matrix{T}(I, m, m)[:,invperm(F.p)]
else
getfield(F, d)
end
end

ldiv!(F::LU{<:Any,<:Diagonal}, B::AbstractVecOrMat) = ldiv!(F.factors, B)
rdiv!(B::AbstractMatrix, F::LU{<:Any,<:Diagonal}) = rdiv!(B, F.factors)

Base._sum(A::Diagonal, ::Colon) = sum(A.diag)
function Base._sum(A::Diagonal, dims::Integer)
res = Base.reducedim_initarray(A, dims, zero(eltype(A)))
Expand Down
17 changes: 9 additions & 8 deletions stdlib/LinearAlgebra/src/lu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,6 @@ adjoint(F::LU) = Adjoint(F)
transpose(F::LU) = Transpose(F)

# StridedMatrix
lu(A::StridedMatrix, pivot::Union{RowMaximum,NoPivot} = RowMaximum(); check::Bool = true) =
lu!(copy_oftype(A, lutype(eltype(A))), pivot; check=check)

lu!(A::StridedMatrix{<:BlasFloat}; check::Bool = true) = lu!(A, RowMaximum(); check=check)
function lu!(A::StridedMatrix{T}, ::RowMaximum; check::Bool = true) where {T<:BlasFloat}
lpt = LAPACK.getrf!(A)
Expand All @@ -88,10 +85,6 @@ end
function lu!(A::StridedMatrix{<:BlasFloat}, pivot::NoPivot; check::Bool = true)
return generic_lufact!(A, pivot; check = check)
end

lu(A::HermOrSym, pivot::Union{RowMaximum,NoPivot} = RowMaximum(); check::Bool = true) =
lu!(copy_oftype(A, lutype(eltype(A))), pivot; check=check)

function lu!(A::HermOrSym, pivot::Union{RowMaximum,NoPivot} = RowMaximum(); check::Bool = true)
copytri!(A.data, A.uplo, isa(A, Hermitian))
lu!(A.data, pivot; check = check)
Expand Down Expand Up @@ -285,11 +278,12 @@ function lu(A::AbstractMatrix{T}, pivot::Union{RowMaximum,NoPivot} = RowMaximum(
S = lutype(T)
lu!(copy_to_array(A, S), pivot; check = check)
end
lu(A::HermOrSym, pivot::Union{RowMaximum,NoPivot} = RowMaximum(); check::Bool = true) =
lu!(copy_oftype(A, lutype(eltype(A))), pivot; check=check)
# TODO: remove for Julia v2.0
@deprecate lu(A::AbstractMatrix, ::Val{true}; check::Bool = true) lu(A, RowMaximum(); check=check)
@deprecate lu(A::AbstractMatrix, ::Val{false}; check::Bool = true) lu(A, NoPivot(); check=check)


lu(S::LU) = S
function lu(x::Number; check::Bool=true)
info = x == 0 ? one(BlasInt) : zero(BlasInt)
Expand Down Expand Up @@ -755,3 +749,10 @@ AbstractMatrix(F::LU{T,Tridiagonal{T,V}}) where {T,V} = Tridiagonal(F)
AbstractArray(F::LU{T,Tridiagonal{T,V}}) where {T,V} = AbstractMatrix(F)
Matrix(F::LU{T,Tridiagonal{T,V}}) where {T,V} = Array(AbstractArray(F))
Array(F::LU{T,Tridiagonal{T,V}}) where {T,V} = Matrix(F)

# SymTridiagonal
lu(S::SymTridiagonal{T}, pivot::Union{RowMaximum,NoPivot} = RowMaximum(); check::Bool = true) where {T} =
lu!(copy_oftype(Tridiagonal(S), lutype(T)), pivot, check = check)
function lu!(S::SymTridiagonal{T}, pivot::Union{RowMaximum,NoPivot} = RowMaximum(); check::Bool = true) where {T}
return lu!(Tridiagonal(S), pivot, check = check)
end
36 changes: 22 additions & 14 deletions stdlib/LinearAlgebra/test/lu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -399,20 +399,28 @@ end
end
end

@testset "lu(A) has a fallback for abstract matrices (#40831)" begin
# check that lu works for some structured arrays
A0 = rand(5, 5)
@test lu(Diagonal(A0)) isa LU
@test Matrix(lu(Diagonal(A0))) Diagonal(A0)
@test lu(Bidiagonal(A0, :U)) isa LU
@test Matrix(lu(Bidiagonal(A0, :U))) Bidiagonal(A0, :U)

# lu(A) copies A and then invokes lu!, make sure that the most efficient
# implementation of lu! continues to be used
A1 = Tridiagonal(rand(2), rand(3), rand(2))
@test lu(A1) isa LU{Float64, Tridiagonal{Float64, Vector{Float64}}}
@test lu(A1, RowMaximum()) isa LU{Float64, Tridiagonal{Float64, Vector{Float64}}}
@test lu(A1, RowMaximum(); check = false) isa LU{Float64, Tridiagonal{Float64, Vector{Float64}}}
@testset "lu on *diagonal matrices" begin
dl = rand(3)
d = rand(4)
Bl = Bidiagonal(d, dl, :L)
Bu = Bidiagonal(d, dl, :U)
Tri = Tridiagonal(dl, d, dl)
Sym = SymTridiagonal(d, dl)
D = Diagonal(d)
b = ones(4)
B = rand(4,4)
for A in (Bl, Bu, Tri, Sym, D), pivot in (NoPivot(), RowMaximum())
@test A\b lu(A, pivot)\b
@test B/A B/lu(A, pivot)
@test B/A B/Matrix(A)
@test Matrix(lu(A, pivot)) A
@test @inferred(lu(A)) isa LU
if A isa Union{Tridiagonal, SymTridiagonal}
@test lu(A) isa LU{Float64, Tridiagonal{Float64, Vector{Float64}}}
@test lu(A, pivot) isa LU{Float64, Tridiagonal{Float64, Vector{Float64}}}
@test lu(A, pivot; check = false) isa LU{Float64, Tridiagonal{Float64, Vector{Float64}}}
end
end
end

end # module TestLU

0 comments on commit e6a9151

Please sign in to comment.