-
-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
A[c|t]_ldiv_B! specializations for UmfpackLU-StridedVecOrMat, less generalized linear indexing and meta-fu #20046
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -383,43 +383,56 @@ function nnz(lu::UmfpackLU) | |
end | ||
|
||
### Solve with Factorization | ||
for (f!, umfpack) in ((:A_ldiv_B!, :UMFPACK_A), | ||
(:Ac_ldiv_B!, :UMFPACK_At), | ||
(:At_ldiv_B!, :UMFPACK_Aat)) | ||
@eval begin | ||
function $f!{T<:UMFVTypes}(x::StridedVecOrMat{T}, lu::UmfpackLU{T}, b::StridedVecOrMat{T}) | ||
n = size(x, 2) | ||
if n != size(b, 2) | ||
throw(DimensionMismatch("in and output arrays must have the same number of columns")) | ||
end | ||
for j in 1:n | ||
solve!(view(x, :, j), lu, view(b, :, j), $umfpack) | ||
end | ||
return x | ||
end | ||
$f!{T<:UMFVTypes}(lu::UmfpackLU{T}, b::StridedVector{T}) = $f!(b, lu, copy(b)) | ||
$f!{T<:UMFVTypes}(lu::UmfpackLU{T}, b::StridedMatrix{T}) = $f!(b, lu, copy(b)) | ||
|
||
function $f!{Tb<:Complex}(x::StridedVector{Tb}, lu::UmfpackLU{Float64}, b::StridedVector{Tb}) | ||
m, n = size(x, 1), size(x, 2) | ||
if n != size(b, 2) | ||
throw(DimensionMismatch("in and output arrays must have the same number of columns")) | ||
end | ||
# TODO: Optionally let user allocate these and pass in somehow | ||
r = similar(b, Float64, m) | ||
i = similar(b, Float64, m) | ||
for j in 1:n | ||
solve!(r, lu, convert(Vector{Float64}, real(view(b, :, j))), $umfpack) | ||
solve!(i, lu, convert(Vector{Float64}, imag(view(b, :, j))), $umfpack) | ||
|
||
map!((t,s) -> t + im*s, view(x, :, j), r, i) | ||
end | ||
return x | ||
end | ||
$f!{Tb<:Complex}(lu::UmfpackLU{Float64}, b::StridedVector{Tb}) = $f!(b, lu, copy(b)) | ||
A_ldiv_B!{T<:UMFVTypes}(lu::UmfpackLU{T}, b::StridedVecOrMat{T}) = A_ldiv_B!(b, lu, copy(b)) | ||
At_ldiv_B!{T<:UMFVTypes}(lu::UmfpackLU{T}, b::StridedVecOrMat{T}) = At_ldiv_B!(b, lu, copy(b)) | ||
Ac_ldiv_B!{T<:UMFVTypes}(lu::UmfpackLU{T}, b::StridedVecOrMat{T}) = Ac_ldiv_B!(b, lu, copy(b)) | ||
|
||
A_ldiv_B!{T<:UMFVTypes}(X::StridedVecOrMat{T}, lu::UmfpackLU{T}, B::StridedVecOrMat{T}) = | ||
_Aq_ldiv_B!(X, lu, B, UMFPACK_A) | ||
At_ldiv_B!{T<:UMFVTypes}(X::StridedVecOrMat{T}, lu::UmfpackLU{T}, B::StridedVecOrMat{T}) = | ||
_Aq_ldiv_B!(X, lu, B, UMFPACK_At) | ||
Ac_ldiv_B!{T<:UMFVTypes}(X::StridedVecOrMat{T}, lu::UmfpackLU{T}, B::StridedVecOrMat{T}) = | ||
_Aq_ldiv_B!(X, lu, B, UMFPACK_Aat) | ||
|
||
A_ldiv_B!{Tb<:Complex}(lu::UmfpackLU{Float64}, b::StridedVecOrMat{Tb}) = A_ldiv_B!(b, lu, copy(b)) | ||
At_ldiv_B!{Tb<:Complex}(lu::UmfpackLU{Float64}, b::StridedVecOrMat{Tb}) = At_ldiv_B!(b, lu, copy(b)) | ||
Ac_ldiv_B!{Tb<:Complex}(lu::UmfpackLU{Float64}, b::StridedVecOrMat{Tb}) = Ac_ldiv_B!(b, lu, copy(b)) | ||
|
||
A_ldiv_B!{Tb<:Complex}(X::StridedVecOrMat{Tb}, lu::UmfpackLU{Float64}, B::StridedVecOrMat{Tb}) = | ||
_Aq_ldiv_B!(X, lu, B, UMFPACK_A) | ||
At_ldiv_B!{Tb<:Complex}(X::StridedVecOrMat{Tb}, lu::UmfpackLU{Float64}, B::StridedVecOrMat{Tb}) = | ||
_Aq_ldiv_B!(X, lu, B, UMFPACK_At) | ||
Ac_ldiv_B!{Tb<:Complex}(X::StridedVecOrMat{Tb}, lu::UmfpackLU{Float64}, B::StridedVecOrMat{Tb}) = | ||
_Aq_ldiv_B!(X, lu, B, UMFPACK_Aat) | ||
|
||
_Aq_ldiv_B!(X::StridedVecOrMat, lu::UmfpackLU, B::StridedVecOrMat, transtype) = | ||
(_AqldivB_checkshapecompat(X, B); _AqldivB_kernel!(X, lu, B, transtype); return X) | ||
|
||
_AqldivB_checkshapecompat(X::StridedVecOrMat, B::StridedVecOrMat) = | ||
size(X, 2) == size(B, 2) || throw(DimensionMismatch("input and output must have same column count")) | ||
|
||
_AqldivB_kernel!{T<:UMFVTypes}(x::StridedVector{T}, lu::UmfpackLU{T}, b::StridedVector{T}, transtype) = | ||
solve!(x, lu, b, transtype) | ||
_AqldivB_kernel!{T<:UMFVTypes}(X::StridedMatrix{T}, lu::UmfpackLU{T}, B::StridedMatrix{T}, transtype) = | ||
for col in 1:size(X, 1) solve!(view(X, :, col), lu, view(B, :, col), transtype) end | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. one-liner for loops are not very readable There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there a better way to avoid introducing another three primarily empty/extraneous lines? Thanks! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if a one-liner has a nontrivial block of code or is doing multiple things, it shouldn't be a one-liner There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Expanded in #20055, and likewise a few of the other one-liners. Thanks! |
||
|
||
function _AqldivB_kernel!{Tb<:Complex}(x::StridedVector{Tb}, lu::UmfpackLU{Float64}, b::StridedVector{Tb}, transtype) | ||
r, i = similar(b, Float64), similar(b, Float64) | ||
solve!(r, lu, Vector{Float64}(real(b)), transtype) | ||
solve!(i, lu, Vector{Float64}(imag(b)), transtype) | ||
map!(complex, x, r, i) | ||
end | ||
function _AqldivB_kernel!{Tb<:Complex}(X::StridedMatrix{Tb}, lu::UmfpackLU{Float64}, B::StridedMatrix{Tb}, transtype) | ||
r = similar(B, Float64, size(B, 1)) | ||
i = similar(B, Float64, size(B, 1)) | ||
for j in 1:size(B, 2) | ||
solve!(r, lu, Vector{Float64}(real(view(B, :, j))), transtype) | ||
solve!(i, lu, Vector{Float64}(imag(view(B, :, j))), transtype) | ||
map!(complex, view(X, :, j), r, i) | ||
end | ||
end | ||
|
||
|
||
function getindex(lu::UmfpackLU, d::Symbol) | ||
L,U,p,q,Rs = umf_extract(lu) | ||
if d == :L | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why the rewording from "number of columns" to "column count" ? was better before IMO
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reworded to reduce line length overrun, but happy to change that in a collection of fixups? Thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed in #20055. Thanks!