From e0a5598e2f1de4417492d37dbb74ed191e69826b Mon Sep 17 00:00:00 2001 From: "Steven G. Johnson" Date: Thu, 19 Dec 2013 17:08:57 -0500 Subject: [PATCH] support multiple/matrix right-hand sides in sparse A \ B --- base/linalg/sparse.jl | 60 +++++++++++++++++++++--------------- base/linalg/umfpack.jl | 69 ++++++++++++++++++++++++++++-------------- 2 files changed, 81 insertions(+), 48 deletions(-) diff --git a/base/linalg/sparse.jl b/base/linalg/sparse.jl index 945dfa6ca421f..0dd9bd9c2cf31 100644 --- a/base/linalg/sparse.jl +++ b/base/linalg/sparse.jl @@ -173,7 +173,7 @@ function *{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, B::SparseMatrixCSC{Tv,Ti}) end ## solvers -function A_ldiv_B!(A::SparseMatrixCSC, b::AbstractVector) +function A_ldiv_B!(A::SparseMatrixCSC, b::AbstractVecOrMat) if istril(A) if istriu(A) return A_ldiv_B!(Diagonal(A.nzval), b) end return fwdTriSolve!(A, b) @@ -182,7 +182,7 @@ function A_ldiv_B!(A::SparseMatrixCSC, b::AbstractVector) return A_ldiv_B!(lufact(A),b) end -function fwdTriSolve!(A::SparseMatrixCSC, b::AbstractVector) +function fwdTriSolve!(A::SparseMatrixCSC, b::AbstractVecOrMat) # forward substitution for CSC matrices n = length(b) ncol = chksquare(A) @@ -192,39 +192,49 @@ function fwdTriSolve!(A::SparseMatrixCSC, b::AbstractVector) ja = A.rowval ia = A.colptr - for j = 1:n - 1 - i1 = ia[j] - i2 = ia[j+1]-1 - b[j] /= aa[i1] - bj = b[j] - for i = i1+1:i2 - b[ja[i]] -= bj*aa[i] + joff = 0 + for k = 1:size(b,2) + for j = 1:n-1 + jb = joff + j + i1 = ia[j] + i2 = ia[j+1]-1 + b[jb] /= aa[i1] + bj = b[jb] + for i = i1+1:i2 + b[joff+ja[i]] -= bj*aa[i] + end end + joff += n + b[joff] /= aa[end] end - b[end] /= aa[end] return b end -function bwdTriSolve!(A::SparseMatrixCSC, b::AbstractVector) +function bwdTriSolve!(A::SparseMatrixCSC, b::AbstractVecOrMat) # backward substitution for CSC matrices n = length(b) ncol = chksquare(A) if n != ncol throw(DimensionMismatch("A is $(ncol)X$(ncol) and b has length $(n)")) end - aa = A.nzval - ja = A.rowval - ia = A.colptr - - for j = n:-1:2 - i1 = ia[j] - i2 = ia[j+1]-1 - b[j] /= aa[i2] - bj = b[j] - for i = i2-1:-1:i1 - b[ja[i]] -= bj*aa[i] - end - end - b[1] /= aa[1] + aa = A.nzval + ja = A.rowval + ia = A.colptr + + joff = 0 + for k = 1:size(b,2) + for j = n:-1:2 + jb = joff + j + i1 = ia[j] + i2 = ia[j+1]-1 + b[jb] /= aa[i2] + bj = b[jb] + for i = i2-1:-1:i1 + b[joff+ja[i]] -= bj*aa[i] + end + end + b[joff+1] /= aa[1] + joff += n + end return b end diff --git a/base/linalg/umfpack.jl b/base/linalg/umfpack.jl index 5d900ae0fa8c1..1b911b1a4a994 100644 --- a/base/linalg/umfpack.jl +++ b/base/linalg/umfpack.jl @@ -200,28 +200,48 @@ for (sym_r,sym_c,num_r,num_c,sol_r,sol_c,det_r,det_z,lunz,get_num_r,get_num_z,it U.numeric = tmp[1] return U end - function solve{Tv<:Float64,Ti<:$itype}(lu::UmfpackLU{Tv,Ti}, b::Vector{Tv}, typ::Integer) + function solve{Tv<:Float64,Ti<:$itype}(lu::UmfpackLU{Tv,Ti}, b::VecOrMat{Tv}, typ::Integer) umfpack_numeric!(lu) x = similar(b) - @isok ccall(($sol_r, :libumfpack), Ti, - (Ti, Ptr{Ti}, Ptr{Ti}, Ptr{Float64}, Ptr{Float64}, - Ptr{Float64}, Ptr{Void}, Ptr{Float64}, Ptr{Float64}), - typ, lu.colptr, lu.rowval, lu.nzval, x, b, lu.numeric, umf_ctrl, umf_info) + joff = 1 + for k = 1:size(b,2) + @isok ccall(($sol_r, :libumfpack), Ti, + (Ti, Ptr{Ti}, Ptr{Ti}, Ptr{Float64}, Ptr{Float64}, + Ptr{Float64}, Ptr{Void}, Ptr{Float64}, Ptr{Float64}), + typ, lu.colptr, lu.rowval, lu.nzval, pointer(x,joff), pointer(b,joff), lu.numeric, umf_ctrl, umf_info) + joff += size(b,1) + end x end - function solve{Tv<:Complex128,Ti<:$itype}(lu::UmfpackLU{Tv,Ti}, b::Vector{Tv}, typ::Integer) + function solve{Tv<:Complex128,Ti<:$itype}(lu::UmfpackLU{Tv,Ti}, b::VecOrMat{Tv}, typ::Integer) umfpack_numeric!(lu) - xr = similar(b, Float64) - xi = similar(b, Float64) - @isok ccall(($sol_c, :libumfpack), - Ti, - (Ti, Ptr{Ti}, Ptr{Ti}, Ptr{Float64}, Ptr{Float64}, - Ptr{Float64}, Ptr{Float64}, Ptr{Float64}, Ptr{Float64}, - Ptr{Void}, Ptr{Float64}, Ptr{Float64}), - typ, lu.colptr, lu.rowval, real(lu.nzval), imag(lu.nzval), - xr, xi, real(b), imag(b), - lu.numeric, umf_ctrl, umf_info) - complex(xr,xi) + x = similar(b) + n = size(b,1) + br = Array(Float64, n) + bi = Array(Float64, n) + xr = Array(Float64, n) + xi = Array(Float64, n) + joff = 0 + for k = 1:size(b,2) + for j = 1:n + bj = b[joff+j] + br[j] = real(bj) + bi[j] = imag(bj) + end + @isok ccall(($sol_c, :libumfpack), + Ti, + (Ti, Ptr{Ti}, Ptr{Ti}, Ptr{Float64}, Ptr{Float64}, + Ptr{Float64}, Ptr{Float64}, Ptr{Float64}, Ptr{Float64}, + Ptr{Void}, Ptr{Float64}, Ptr{Float64}), + typ, lu.colptr, lu.rowval, real(lu.nzval), imag(lu.nzval), + xr, xi, br, bi, + lu.numeric, umf_ctrl, umf_info) + for j = 1:n + x[joff+j] = complex(xr[j],xi[j]) + end + joff += n + end + x end function det{Tv<:Float64,Ti<:$itype}(lu::UmfpackLU{Tv,Ti}) mx = Array(Tv,1) @@ -278,29 +298,32 @@ for (sym_r,sym_c,num_r,num_c,sol_r,sol_c,det_r,det_z,lunz,get_num_r,get_num_z,it end ### Solve with Factorization -A_ldiv_B!{T<:UMFVTypes}(lu::UmfpackLU{T}, b::Vector{T}) = solve(lu, b, UMFPACK_A) +A_ldiv_B!{T<:UMFVTypes}(lu::UmfpackLU{T}, b::VecOrMat{T}) = solve(lu, b, UMFPACK_A) function A_ldiv_B!{Tlu<:Real,Tb<:Complex}(lu::UmfpackLU{Tlu}, b::Vector{Tb}) r = solve(lu, [convert(Tlu,real(be)) for be in b], UMFPACK_A) i = solve(lu, [convert(Tlu,imag(be)) for be in b], UMFPACK_A) Tb[r[k]+im*i[k] for k = 1:length(r)] end -A_ldiv_B!{Tlu<:UMFVTypes,Tb<:Number}(lu::UmfpackLU{Tlu}, b::Vector{Tb}) = A_ldiv_B!(lu, convert(Vector{Tlu}, b)) +A_ldiv_B!{Tlu<:UMFVTypes,Tb<:Number}(lu::UmfpackLU{Tlu}, b::StridedVecOrMat{Tb}) = A_ldiv_B!(lu, convert(Array{Tlu}, b)) +A_ldiv_B!{Tlu<:UMFVTypes,Tb<:Number}(lu::UmfpackLU{Tlu}, b::AbstractVecOrMat{Tb}) = A_ldiv_B!(lu, convert(Array{Tlu}, b)) -Ac_ldiv_B!{T<:UMFVTypes}(lu::UmfpackLU{T}, b::Vector{T}) = solve(lu, b, UMFPACK_At) +Ac_ldiv_B!{T<:UMFVTypes}(lu::UmfpackLU{T}, b::VecOrMat{T}) = solve(lu, b, UMFPACK_At) function Ac_ldiv_B!{Tlu<:Real,Tb<:Complex}(lu::UmfpackLU{Tlu}, b::Vector{Tb}) r = solve(lu, [convert(Float64,real(be)) for be in b], UMFPACK_At) i = solve(lu, [convert(Float64,imag(be)) for be in b], UMFPACK_At) Tb[r[k]+im*i[k] for k = 1:length(r)] end -Ac_ldiv_B!{Tlu<:UMFVTypes,Tb<:Number}(lu::UmfpackLU{Tlu}, b::Vector{Tb}) = Ac_ldiv_B!(lu, convert(Vector{Tlu}, b)) +Ac_ldiv_B!{Tlu<:UMFVTypes,Tb<:Number}(lu::UmfpackLU{Tlu}, b::StridedVecOrMat{Tb}) = Ac_ldiv_B!(lu, convert(Array{Tlu}, b)) +Ac_ldiv_B!{Tlu<:UMFVTypes,Tb<:Number}(lu::UmfpackLU{Tlu}, b::AbstractVecOrMat{Tb}) = Ac_ldiv_B!(lu, convert(Array{Tlu}, b)) -At_ldiv_B!{T<:UMFVTypes}(lu::UmfpackLU{T}, b::Vector{T}) = solve(lu, b, UMFPACK_Aat) +At_ldiv_B!{T<:UMFVTypes}(lu::UmfpackLU{T}, b::VecOrMat{T}) = solve(lu, b, UMFPACK_Aat) function At_ldiv_B!{Tlu<:Real,Tb<:Complex}(lu::UmfpackLU{Tlu}, b::Vector{Tb}) r = solve(lu, [convert(Float64,real(be)) for be in b], UMFPACK_Aat) i = solve(lu, [convert(Float64,imag(be)) for be in b], UMFPACK_Aat) Tb[r[k]+im*i[k] for k = 1:length(r)] end -At_ldiv_B!{Tlu<:UMFVTypes,Tb<:Number}(lu::UmfpackLU{Tlu}, b::Vector{Tb}) = At_ldiv_B!(lu, convert(Vector{Tlu}, b)) +At_ldiv_B!{Tlu<:UMFVTypes,Tb<:Number}(lu::UmfpackLU{Tlu}, b::StridedVecOrMat{Tb}) = At_ldiv_B!(lu, convert(Array{Tlu}, b)) +At_ldiv_B!{Tlu<:UMFVTypes,Tb<:Number}(lu::UmfpackLU{Tlu}, b::AbstractVecOrMat{Tb}) = At_ldiv_B!(lu, convert(Array{Tlu}, b)) function getindex(lu::UmfpackLU, d::Symbol) L,U,p,q,Rs = umf_extract(lu)