Skip to content

Commit

Permalink
Merge pull request #5196 from stevengj/sparse_ldiv_matrix
Browse files Browse the repository at this point in the history
support multiple/matrix right-hand sides in sparse A \ B
  • Loading branch information
ViralBShah committed Dec 25, 2013
2 parents 3231f73 + e0a5598 commit cef7983
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 48 deletions.
60 changes: 35 additions & 25 deletions base/linalg/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ function *{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, B::SparseMatrixCSC{Tv,Ti})
end

## solvers
function A_ldiv_B!(A::SparseMatrixCSC, b::AbstractVector)
function A_ldiv_B!(A::SparseMatrixCSC, b::AbstractVecOrMat)
if istril(A)
if istriu(A) return A_ldiv_B!(Diagonal(A.nzval), b) end
return fwdTriSolve!(A, b)
Expand All @@ -186,7 +186,7 @@ function A_ldiv_B!(A::SparseMatrixCSC, b::AbstractVector)
return A_ldiv_B!(lufact(A),b)
end

function fwdTriSolve!(A::SparseMatrixCSC, b::AbstractVector)
function fwdTriSolve!(A::SparseMatrixCSC, b::AbstractVecOrMat)
# forward substitution for CSC matrices
n = length(b)
ncol = chksquare(A)
Expand All @@ -196,39 +196,49 @@ function fwdTriSolve!(A::SparseMatrixCSC, b::AbstractVector)
ja = A.rowval
ia = A.colptr

for j = 1:n - 1
i1 = ia[j]
i2 = ia[j+1]-1
b[j] /= aa[i1]
bj = b[j]
for i = i1+1:i2
b[ja[i]] -= bj*aa[i]
joff = 0
for k = 1:size(b,2)
for j = 1:n-1
jb = joff + j
i1 = ia[j]
i2 = ia[j+1]-1
b[jb] /= aa[i1]
bj = b[jb]
for i = i1+1:i2
b[joff+ja[i]] -= bj*aa[i]
end
end
joff += n
b[joff] /= aa[end]
end
b[end] /= aa[end]
return b
end

function bwdTriSolve!(A::SparseMatrixCSC, b::AbstractVector)
function bwdTriSolve!(A::SparseMatrixCSC, b::AbstractVecOrMat)
# backward substitution for CSC matrices
n = length(b)
ncol = chksquare(A)
if n != ncol throw(DimensionMismatch("A is $(ncol)X$(ncol) and b has length $(n)")) end

aa = A.nzval
ja = A.rowval
ia = A.colptr

for j = n:-1:2
i1 = ia[j]
i2 = ia[j+1]-1
b[j] /= aa[i2]
bj = b[j]
for i = i2-1:-1:i1
b[ja[i]] -= bj*aa[i]
end
end
b[1] /= aa[1]
aa = A.nzval
ja = A.rowval
ia = A.colptr

joff = 0
for k = 1:size(b,2)
for j = n:-1:2
jb = joff + j
i1 = ia[j]
i2 = ia[j+1]-1
b[jb] /= aa[i2]
bj = b[jb]
for i = i2-1:-1:i1
b[joff+ja[i]] -= bj*aa[i]
end
end
b[joff+1] /= aa[1]
joff += n
end
return b
end

Expand Down
69 changes: 46 additions & 23 deletions base/linalg/umfpack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -200,28 +200,48 @@ for (sym_r,sym_c,num_r,num_c,sol_r,sol_c,det_r,det_z,lunz,get_num_r,get_num_z,it
U.numeric = tmp[1]
return U
end
function solve{Tv<:Float64,Ti<:$itype}(lu::UmfpackLU{Tv,Ti}, b::Vector{Tv}, typ::Integer)
function solve{Tv<:Float64,Ti<:$itype}(lu::UmfpackLU{Tv,Ti}, b::VecOrMat{Tv}, typ::Integer)
umfpack_numeric!(lu)
x = similar(b)
@isok ccall(($sol_r, :libumfpack), Ti,
(Ti, Ptr{Ti}, Ptr{Ti}, Ptr{Float64}, Ptr{Float64},
Ptr{Float64}, Ptr{Void}, Ptr{Float64}, Ptr{Float64}),
typ, lu.colptr, lu.rowval, lu.nzval, x, b, lu.numeric, umf_ctrl, umf_info)
joff = 1
for k = 1:size(b,2)
@isok ccall(($sol_r, :libumfpack), Ti,
(Ti, Ptr{Ti}, Ptr{Ti}, Ptr{Float64}, Ptr{Float64},
Ptr{Float64}, Ptr{Void}, Ptr{Float64}, Ptr{Float64}),
typ, lu.colptr, lu.rowval, lu.nzval, pointer(x,joff), pointer(b,joff), lu.numeric, umf_ctrl, umf_info)
joff += size(b,1)
end
x
end
function solve{Tv<:Complex128,Ti<:$itype}(lu::UmfpackLU{Tv,Ti}, b::Vector{Tv}, typ::Integer)
function solve{Tv<:Complex128,Ti<:$itype}(lu::UmfpackLU{Tv,Ti}, b::VecOrMat{Tv}, typ::Integer)
umfpack_numeric!(lu)
xr = similar(b, Float64)
xi = similar(b, Float64)
@isok ccall(($sol_c, :libumfpack),
Ti,
(Ti, Ptr{Ti}, Ptr{Ti}, Ptr{Float64}, Ptr{Float64},
Ptr{Float64}, Ptr{Float64}, Ptr{Float64}, Ptr{Float64},
Ptr{Void}, Ptr{Float64}, Ptr{Float64}),
typ, lu.colptr, lu.rowval, real(lu.nzval), imag(lu.nzval),
xr, xi, real(b), imag(b),
lu.numeric, umf_ctrl, umf_info)
complex(xr,xi)
x = similar(b)
n = size(b,1)
br = Array(Float64, n)
bi = Array(Float64, n)
xr = Array(Float64, n)
xi = Array(Float64, n)
joff = 0
for k = 1:size(b,2)
for j = 1:n
bj = b[joff+j]
br[j] = real(bj)
bi[j] = imag(bj)
end
@isok ccall(($sol_c, :libumfpack),
Ti,
(Ti, Ptr{Ti}, Ptr{Ti}, Ptr{Float64}, Ptr{Float64},
Ptr{Float64}, Ptr{Float64}, Ptr{Float64}, Ptr{Float64},
Ptr{Void}, Ptr{Float64}, Ptr{Float64}),
typ, lu.colptr, lu.rowval, real(lu.nzval), imag(lu.nzval),
xr, xi, br, bi,
lu.numeric, umf_ctrl, umf_info)
for j = 1:n
x[joff+j] = complex(xr[j],xi[j])
end
joff += n
end
x
end
function det{Tv<:Float64,Ti<:$itype}(lu::UmfpackLU{Tv,Ti})
mx = Array(Tv,1)
Expand Down Expand Up @@ -278,29 +298,32 @@ for (sym_r,sym_c,num_r,num_c,sol_r,sol_c,det_r,det_z,lunz,get_num_r,get_num_z,it
end

### Solve with Factorization
A_ldiv_B!{T<:UMFVTypes}(lu::UmfpackLU{T}, b::Vector{T}) = solve(lu, b, UMFPACK_A)
A_ldiv_B!{T<:UMFVTypes}(lu::UmfpackLU{T}, b::VecOrMat{T}) = solve(lu, b, UMFPACK_A)
function A_ldiv_B!{Tlu<:Real,Tb<:Complex}(lu::UmfpackLU{Tlu}, b::Vector{Tb})
r = solve(lu, [convert(Tlu,real(be)) for be in b], UMFPACK_A)
i = solve(lu, [convert(Tlu,imag(be)) for be in b], UMFPACK_A)
Tb[r[k]+im*i[k] for k = 1:length(r)]
end
A_ldiv_B!{Tlu<:UMFVTypes,Tb<:Number}(lu::UmfpackLU{Tlu}, b::Vector{Tb}) = A_ldiv_B!(lu, convert(Vector{Tlu}, b))
A_ldiv_B!{Tlu<:UMFVTypes,Tb<:Number}(lu::UmfpackLU{Tlu}, b::StridedVecOrMat{Tb}) = A_ldiv_B!(lu, convert(Array{Tlu}, b))
A_ldiv_B!{Tlu<:UMFVTypes,Tb<:Number}(lu::UmfpackLU{Tlu}, b::AbstractVecOrMat{Tb}) = A_ldiv_B!(lu, convert(Array{Tlu}, b))

Ac_ldiv_B!{T<:UMFVTypes}(lu::UmfpackLU{T}, b::Vector{T}) = solve(lu, b, UMFPACK_At)
Ac_ldiv_B!{T<:UMFVTypes}(lu::UmfpackLU{T}, b::VecOrMat{T}) = solve(lu, b, UMFPACK_At)
function Ac_ldiv_B!{Tlu<:Real,Tb<:Complex}(lu::UmfpackLU{Tlu}, b::Vector{Tb})
r = solve(lu, [convert(Float64,real(be)) for be in b], UMFPACK_At)
i = solve(lu, [convert(Float64,imag(be)) for be in b], UMFPACK_At)
Tb[r[k]+im*i[k] for k = 1:length(r)]
end
Ac_ldiv_B!{Tlu<:UMFVTypes,Tb<:Number}(lu::UmfpackLU{Tlu}, b::Vector{Tb}) = Ac_ldiv_B!(lu, convert(Vector{Tlu}, b))
Ac_ldiv_B!{Tlu<:UMFVTypes,Tb<:Number}(lu::UmfpackLU{Tlu}, b::StridedVecOrMat{Tb}) = Ac_ldiv_B!(lu, convert(Array{Tlu}, b))
Ac_ldiv_B!{Tlu<:UMFVTypes,Tb<:Number}(lu::UmfpackLU{Tlu}, b::AbstractVecOrMat{Tb}) = Ac_ldiv_B!(lu, convert(Array{Tlu}, b))

At_ldiv_B!{T<:UMFVTypes}(lu::UmfpackLU{T}, b::Vector{T}) = solve(lu, b, UMFPACK_Aat)
At_ldiv_B!{T<:UMFVTypes}(lu::UmfpackLU{T}, b::VecOrMat{T}) = solve(lu, b, UMFPACK_Aat)
function At_ldiv_B!{Tlu<:Real,Tb<:Complex}(lu::UmfpackLU{Tlu}, b::Vector{Tb})
r = solve(lu, [convert(Float64,real(be)) for be in b], UMFPACK_Aat)
i = solve(lu, [convert(Float64,imag(be)) for be in b], UMFPACK_Aat)
Tb[r[k]+im*i[k] for k = 1:length(r)]
end
At_ldiv_B!{Tlu<:UMFVTypes,Tb<:Number}(lu::UmfpackLU{Tlu}, b::Vector{Tb}) = At_ldiv_B!(lu, convert(Vector{Tlu}, b))
At_ldiv_B!{Tlu<:UMFVTypes,Tb<:Number}(lu::UmfpackLU{Tlu}, b::StridedVecOrMat{Tb}) = At_ldiv_B!(lu, convert(Array{Tlu}, b))
At_ldiv_B!{Tlu<:UMFVTypes,Tb<:Number}(lu::UmfpackLU{Tlu}, b::AbstractVecOrMat{Tb}) = At_ldiv_B!(lu, convert(Array{Tlu}, b))

function getindex(lu::UmfpackLU, d::Symbol)
L,U,p,q,Rs = umf_extract(lu)
Expand Down

0 comments on commit cef7983

Please sign in to comment.