Skip to content

Commit

Permalink
Partially fix type instability in \ (#9191). Also, avoid copying the …
Browse files Browse the repository at this point in the history
…rhs in A_ldiv_B!

(cherry picked from commit 506907f)
  • Loading branch information
andreasnoack committed Jan 14, 2015
1 parent cf99eba commit 1763fef
Showing 1 changed file with 19 additions and 8 deletions.
27 changes: 19 additions & 8 deletions base/linalg/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -284,8 +284,8 @@ function A_mul_Bc{TA,TB}(A::AbstractArray{TA}, B::Union(QRCompactWYQ{TB},QRPacke
convert(AbstractMatrix{TAB}, B))
end

A_ldiv_B!{T<:BlasFloat}(A::QRCompactWY{T}, B::StridedVector{T}) = A_ldiv_B!(Triangular(A[:R], :U), sub(Ac_mul_B!(A[:Q], B), 1:size(A, 2)))
A_ldiv_B!{T<:BlasFloat}(A::QRCompactWY{T}, B::StridedMatrix{T}) = A_ldiv_B!(Triangular(A[:R], :U), sub(Ac_mul_B!(A[:Q], B), 1:size(A, 2), 1:size(B, 2)))
A_ldiv_B!{T<:BlasFloat}(A::QRCompactWY{T}, b::StridedVector{T}) = (A_ldiv_B!(Triangular(A[:R], :U), sub(Ac_mul_B!(A[:Q], b), 1:size(A, 2))); b)
A_ldiv_B!{T<:BlasFloat}(A::QRCompactWY{T}, B::StridedMatrix{T}) = (A_ldiv_B!(Triangular(A[:R], :U), sub(Ac_mul_B!(A[:Q], B), 1:size(A, 2), 1:size(B, 2))); B)

# Julia implementation similarly to xgelsy
function A_ldiv_B!{T<:BlasFloat}(A::QRPivoted{T}, B::StridedMatrix{T}, rcond::Real)
Expand Down Expand Up @@ -316,7 +316,8 @@ function A_ldiv_B!{T<:BlasFloat}(A::QRPivoted{T}, B::StridedMatrix{T}, rcond::Re
A_ldiv_B!(Triangular(C[1:rnk,1:rnk],:U),sub(Ac_mul_B!(getq(A),sub(B, 1:mA, 1:nrhs)),1:rnk,1:nrhs))
B[rnk+1:end,:] = zero(T)
LAPACK.ormrz!('L', iseltype(B, Complex) ? 'C' : 'T', C, τ, sub(B,1:nA,1:nrhs))
return isa(A,QRPivoted) ? B[invperm(A[:p]::Vector{BlasInt}),:] : B[1:nA,:], rnk
B[1:nA,:] = sub(B, 1:nA, :)[invperm(A[:p]::Vector{BlasInt}),:]
return B, rnk
end
A_ldiv_B!{T<:BlasFloat}(A::QRPivoted{T}, B::StridedVector{T}) = vec(A_ldiv_B!(A,reshape(B,length(B),1)))
A_ldiv_B!{T<:BlasFloat}(A::QRPivoted{T}, B::StridedVecOrMat{T}) = A_ldiv_B!(A, B, maximum(size(A))*eps(real(float(one(eltype(B))))))[1]
Expand Down Expand Up @@ -369,22 +370,32 @@ function A_ldiv_B!{T}(A::QR{T},B::StridedMatrix{T})
end
end
end
return B[1:n,:]
return B
end
A_ldiv_B!(A::QR, B::StridedVector) = A_ldiv_B!(A, reshape(B, length(B), 1))[:]
A_ldiv_B!(A::QRPivoted, B::StridedVector) = A_ldiv_B!(QR(A.factors,A.τ),B)[invperm(A.jpvt)]
A_ldiv_B!(A::QRPivoted, B::StridedMatrix) = A_ldiv_B!(QR(A.factors,A.τ),B)[invperm(A.jpvt),:]
function A_ldiv_B!(A::QRPivoted, b::StridedVector)
A_ldiv_B!(QR(A.factors,A.τ), b)
b[1:size(A.factors, 2)] = sub(b, 1:size(A.factors, 2))[invperm(A.jpvt)]
b
end
function A_ldiv_B!(A::QRPivoted, B::StridedMatrix)
A_ldiv_B!(QR(A.factors, A.τ), B)
B[1:size(A.factors, 2),:] = sub(B, 1:size(A.factors, 2), :)[invperm(A.jpvt)]
B
end
function \{TA,Tb}(A::Union(QR{TA},QRCompactWY{TA},QRPivoted{TA}),b::StridedVector{Tb})
S = promote_type(TA,Tb)
m,n = size(A)
m == length(b) || throw(DimensionMismatch("left hand side has $m rows, but right hand side has length $(length(b))"))
n > m ? A_ldiv_B!(convert(Factorization{S},A),[b,zeros(S,n-m)]) : A_ldiv_B!(convert(Factorization{S},A), S == Tb ? copy(b) : convert(AbstractVector{S}, b))
x = n > m ? A_ldiv_B!(convert(Factorization{S},A),[b,zeros(S,n-m)]) : A_ldiv_B!(convert(Factorization{S},A), S == Tb ? copy(b) : convert(AbstractVector{S}, b))
return length(x) > n ? x[1:n] : x
end
function \{TA,TB}(A::Union(QR{TA},QRCompactWY{TA},QRPivoted{TA}),B::StridedMatrix{TB})
S = promote_type(TA,TB)
m,n = size(A)
m == size(B,1) || throw(DimensionMismatch("left hand side has $m rows, but right hand side has $(size(B,1)) rows"))
n > m ? A_ldiv_B!(convert(Factorization{S},A),[B;zeros(S,n-m,size(B,2))]) : A_ldiv_B!(convert(Factorization{S},A), S == TB ? copy(B) : convert(AbstractMatrix{S}, B))
X = n > m ? A_ldiv_B!(convert(Factorization{S},A),[B;zeros(S,n-m,size(B,2))]) : A_ldiv_B!(convert(Factorization{S},A), S == TB ? copy(B) : convert(AbstractMatrix{S}, B))
return size(X, 1) > n ? X[1:n,:] : X
end

##TODO: Add methods for rank(A::QRP{T}) and adjust the (\) method accordingly
Expand Down

0 comments on commit 1763fef

Please sign in to comment.