Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: typestable factorizations #9575

Merged
merged 1 commit into from
Jan 28, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ Compiler improvements
Library improvements
--------------------

* Factorization api is now type-stable, functions dispatch on `Val{false}` or `Val{true}` instead of a boolean value ([#9575]).

* `convert` now checks for overflow when truncating integers or converting between
signed and unsigned ([#5413]).

Expand Down
40 changes: 27 additions & 13 deletions base/linalg/cholesky.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,17 @@ function chol!{T}(A::AbstractMatrix{T}, uplo::Symbol)
return uplo == :U ? UpperTriangular(A) : LowerTriangular(A)
end

function cholfact!{T<:BlasFloat}(A::StridedMatrix{T}, uplo::Symbol=:U; pivot=false, tol=0.0)
cholfact!{T<:BlasFloat}(A::StridedMatrix{T}, uplo::Symbol=:U, pivot::Union(Type{Val{false}}, Type{Val{true}})=Val{false}; tol=0.0) =
_cholfact!(A, pivot, uplo, tol=tol)
function _cholfact!{T<:BlasFloat}(A::StridedMatrix{T}, ::Type{Val{false}}, uplo::Symbol=:U; tol=0.0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are the underscore versions necessary? Couldn't the underscore versions just be used as the ordinary cholfact! methods instead of having cholfact! calling _cholfact!?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would require changing the API as dispatching on Val{false} and Val{true} would require pivot to not be a default argument and hence change places with uplo. It would still require 3 functions, and since _cholfact are not exported anyway I preferred to not changed the API

What do you think?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah. I hadn't thought about that. Thanks for the explanation. Let's stick with your solution. I'll merge.

uplochar = char_uplo(uplo)
if pivot
A, piv, rank, info = LAPACK.pstrf!(uplochar, A, tol)
return CholeskyPivoted{T,typeof(A)}(A, uplochar, piv, rank, tol, info)
end
return Cholesky(chol!(A, uplo).data, uplo)
end
function _cholfact!{T<:BlasFloat}(A::StridedMatrix{T}, ::Type{Val{true}}, uplo::Symbol=:U; tol=0.0)
uplochar = char_uplo(uplo)
A, piv, rank, info = LAPACK.pstrf!(uplochar, A, tol)
return CholeskyPivoted{T,StridedMatrix{T}}(A, uplochar, piv, rank, tol, info)
end
cholfact!(A::AbstractMatrix, uplo::Symbol=:U) = Cholesky(chol!(A, uplo).data, uplo)

function cholfact!{T<:BlasFloat,S,UpLo}(C::Cholesky{T,S,UpLo})
Expand All @@ -100,14 +103,25 @@ function cholfact!{T<:BlasFloat,S,UpLo}(C::Cholesky{T,S,UpLo})
C
end

cholfact{T<:BlasFloat}(A::StridedMatrix{T}, uplo::Symbol=:U; pivot=false, tol=0.0) = cholfact!(copy(A), uplo, pivot=pivot, tol=tol)
function cholfact{T}(A::StridedMatrix{T}, uplo::Symbol=:U; pivot=false, tol=0.0)
S = promote_type(typeof(chol(one(T))),Float32)
S <: BlasFloat && return cholfact!(convert(AbstractMatrix{S}, A), uplo, pivot = pivot, tol = tol)
pivot && throw(ArgumentError("pivot only supported for Float32, Float64, Complex{Float32} and Complex{Float64}"))
S != T && return cholfact!(convert(AbstractMatrix{S}, A), uplo)
return cholfact!(copy(A), uplo)
end
cholfact{T<:BlasFloat}(A::StridedMatrix{T}, uplo::Symbol=:U, pivot::Union(Type{Val{false}}, Type{Val{true}})=Val{false}; tol=0.0) =
cholfact!(copy(A), uplo, pivot, tol=tol)


copy_oftype{T}(A::StridedMatrix{T}, ::Type{T}) = copy(A)
copy_oftype{T,S}(A::StridedMatrix{T}, ::Type{S}) = convert(AbstractMatrix{S}, A)
cholfact{T}(A::StridedMatrix{T}, uplo::Symbol=:U, pivot::Union(Type{Val{false}}, Type{Val{true}})=Val{false}; tol=0.0) =
_cholfact(copy_oftype(A, promote_type(typeof(chol(one(T))),Float32)), pivot, uplo, tol=tol)
_cholfact{T<:BlasFloat}(A::StridedMatrix{T}, pivot::Type{Val{true}}, uplo::Symbol=:U; tol=0.0) =
cholfact!(A, uplo, pivot, tol = tol)
_cholfact{T<:BlasFloat}(A::StridedMatrix{T}, pivot::Type{Val{false}}, uplo::Symbol=:U; tol=0.0) =
cholfact!(A, uplo, pivot, tol = tol)

_cholfact{T}(A::StridedMatrix{T}, ::Type{Val{false}}, uplo::Symbol=:U; tol=0.0) =
cholfact!(A, uplo)
_cholfact{T}(A::StridedMatrix{T}, ::Type{Val{true}}, uplo::Symbol=:U; tol=0.0) =
throw(ArgumentError("pivot only supported for Float32, Float64, Complex{Float32} and Complex{Float64}"))


function cholfact(x::Number, uplo::Symbol=:U)
xf = fill(chol!(x, uplo), 1, 1)
Cholesky(xf, uplo)
Expand Down
21 changes: 12 additions & 9 deletions base/linalg/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -369,20 +369,23 @@ function factorize{T}(A::Matrix{T})
end
return lufact(A)
end
qrfact(A,pivot=typeof(zero(T)/sqrt(zero(T) + zero(T)))<:BlasFloat) # Generic pivoted QR not implemented yet
qrfact(A,typeof(zero(T)/sqrt(zero(T) + zero(T)))<:BlasFloat?Val{true}:Val{false}) # Generic pivoted QR not implemented yet
end

(\)(a::Vector, B::StridedVecOrMat) = (\)(reshape(a, length(a), 1), B)
function (\)(A::StridedMatrix, B::StridedVecOrMat)
m, n = size(A)
if m == n
if istril(A)
return istriu(A) ? \(Diagonal(A),B) : \(LowerTriangular(A),B)

for (T1,PIVOT) in ((BlasFloat,Val{true}),(Any,Val{false}))
@eval function (\){T<:$T1}(A::StridedMatrix{T}, B::StridedVecOrMat)
m, n = size(A)
if m == n
if istril(A)
return istriu(A) ? \(Diagonal(A),B) : \(LowerTriangular(A),B)
end
istriu(A) && return \(UpperTriangular(A),B)
return \(lufact(A),B)
end
istriu(A) && return \(UpperTriangular(A),B)
return \(lufact(A),B)
return qrfact(A,$PIVOT)\B
end
return qrfact(A,pivot=eltype(A)<:BlasFloat)\B
end

## Moore-Penrose inverse
Expand Down
28 changes: 16 additions & 12 deletions base/linalg/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,8 @@ immutable QRPivoted{T,S<:AbstractMatrix} <: Factorization{T}
end
QRPivoted{T}(factors::AbstractMatrix{T}, τ::Vector{T}, jpvt::Vector{BlasInt}) = QRPivoted{T,typeof(factors)}(factors, τ, jpvt)

qrfact!{T<:BlasFloat}(A::StridedMatrix{T}; pivot=false) = pivot ? QRPivoted(LAPACK.geqp3!(A)...) : QRCompactWY(LAPACK.geqrt!(A, min(minimum(size(A)), 36))...)
function qrfact!{T}(A::AbstractMatrix{T}; pivot=false)
pivot && warn("pivoting only implemented for Float32, Float64, Complex64 and Complex128")
function qrfact!{T}(A::AbstractMatrix{T}, pivot::Union(Type{Val{false}}, Type{Val{true}})=Val{false})
pivot==Val{true} && warn("pivoting only implemented for Float32, Float64, Complex64 and Complex128")
m, n = size(A)
τ = zeros(T, min(m,n))
@inbounds begin
Expand All @@ -64,17 +63,22 @@ function qrfact!{T}(A::AbstractMatrix{T}; pivot=false)
end
QR(A, τ)
end
qrfact{T<:BlasFloat}(A::StridedMatrix{T}; pivot=false) = qrfact!(copy(A),pivot=pivot)
qrfact{T}(A::StridedMatrix{T}; pivot=false) = (S = typeof(one(T)/norm(one(T)));S != T ? qrfact!(convert(AbstractMatrix{S},A), pivot=pivot) : qrfact!(copy(A),pivot=pivot))
qrfact!{T<:BlasFloat}(A::StridedMatrix{T}, pivot::Union(Type{Val{false}}, Type{Val{true}})=Val{false}) = pivot==Val{true} ? QRPivoted(LAPACK.geqp3!(A)...) : QRCompactWY(LAPACK.geqrt!(A, min(minimum(size(A)), 36))...)
qrfact{T<:BlasFloat}(A::StridedMatrix{T}, pivot::Union(Type{Val{false}}, Type{Val{true}})=Val{false}) = qrfact!(copy(A), pivot)
copy_oftype{T}(A::StridedMatrix{T}, ::Type{T}) = copy(A)
copy_oftype{T,S}(A::StridedMatrix{T}, ::Type{S}) = convert(AbstractMatrix{S}, A)
qrfact{T}(A::StridedMatrix{T}, pivot::Union(Type{Val{false}}, Type{Val{true}})=Val{false}) = qrfact!(copy_oftype(A, typeof(one(T)/norm(one(T)))), pivot)
qrfact(x::Number) = qrfact(fill(x,1,1))

function qr(A::Union(Number, AbstractMatrix); pivot::Bool=false, thin::Bool=true)
F = qrfact(A, pivot=pivot)
if pivot
full(F[:Q], thin=thin), F[:R], F[:p]
else
full(F[:Q], thin=thin), F[:R]
end
qr(A::Union(Number, AbstractMatrix), pivot::Union(Type{Val{false}}, Type{Val{true}})=Val{false}; thin::Bool=true) =
_qr(A, pivot, thin=thin)
function _qr(A::Union(Number, AbstractMatrix), ::Type{Val{false}}; thin::Bool=true)
F = qrfact(A, Val{false})
full(F[:Q], thin=thin), F[:R]
end
function _qr(A::Union(Number, AbstractMatrix), ::Type{Val{true}}; thin::Bool=true)
F = qrfact(A, Val{true})
full(F[:Q], thin=thin), F[:R], F[:p]
end

convert{T}(::Type{QR{T}},A::QR) = QR(convert(AbstractMatrix{T}, A.factors), convert(Vector{T}, A.τ))
Expand Down
24 changes: 12 additions & 12 deletions base/linalg/lu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ end
LU{T}(factors::AbstractMatrix{T}, ipiv::Vector{BlasInt}, info::BlasInt) = LU{T,typeof(factors)}(factors, ipiv, info)

# StridedMatrix
function lufact!{T<:BlasFloat}(A::StridedMatrix{T}; pivot = true)
!pivot && return generic_lufact!(A, pivot=pivot)
function lufact!{T<:BlasFloat}(A::StridedMatrix{T}, pivot::Union(Type{Val{false}}, Type{Val{true}}) = Val{true})
pivot==Val{false} && return generic_lufact!(A, pivot)
lpt = LAPACK.getrf!(A)
return LU{T,typeof(A)}(lpt[1], lpt[2], lpt[3])
end
lufact!(A::StridedMatrix; pivot = true) = generic_lufact!(A, pivot=pivot)
function generic_lufact!{T}(A::StridedMatrix{T}; pivot = true)
lufact!(A::StridedMatrix, pivot::Union(Type{Val{false}}, Type{Val{true}}) = Val{true}) = generic_lufact!(A, pivot)
function generic_lufact!{T}(A::StridedMatrix{T}, pivot::Union(Type{Val{false}}, Type{Val{true}}) = Val{true})
m, n = size(A)
minmn = min(m,n)
info = 0
Expand All @@ -25,7 +25,7 @@ function generic_lufact!{T}(A::StridedMatrix{T}; pivot = true)
for k = 1:minmn
# find index max
kp = k
if pivot
if pivot==Val{true}
amax = real(zero(T))
for i = k:m
absi = abs(A[i,k])
Expand Down Expand Up @@ -63,14 +63,14 @@ function generic_lufact!{T}(A::StridedMatrix{T}; pivot = true)
end
LU{T,typeof(A)}(A, ipiv, convert(BlasInt, info))
end
lufact{T<:BlasFloat}(A::AbstractMatrix{T}; pivot = true) = lufact!(copy(A), pivot=pivot)
lufact{T}(A::AbstractMatrix{T}; pivot = true) = (S = typeof(zero(T)/one(T)); S != T ? lufact!(convert(AbstractMatrix{S}, A), pivot=pivot) : lufact!(copy(A), pivot=pivot))
lufact{T<:BlasFloat}(A::AbstractMatrix{T}, pivot::Union(Type{Val{false}}, Type{Val{true}}) = Val{true}) = lufact!(copy(A), pivot)
lufact{T}(A::AbstractMatrix{T}, pivot::Union(Type{Val{false}}, Type{Val{true}}) = Val{true}) = (S = typeof(zero(T)/one(T)); S != T ? lufact!(convert(AbstractMatrix{S}, A), pivot) : lufact!(copy(A), pivot))
lufact(x::Number) = LU(fill(x, 1, 1), BlasInt[1], x == 0 ? one(BlasInt) : zero(BlasInt))
lufact(F::LU) = F

lu(x::Number) = (one(x), x, 1)
function lu(A::AbstractMatrix; pivot = true)
F = lufact(A, pivot = pivot)
function lu(A::AbstractMatrix, pivot::Union(Type{Val{false}}, Type{Val{true}}) = Val{true})
F = lufact(A, pivot)
F[:L], F[:U], F[:p]
end

Expand Down Expand Up @@ -156,7 +156,7 @@ cond(A::LU, p::Number) = norm(A[:L]*A[:U],p)*norm(inv(A),p)
# Tridiagonal

# See dgttrf.f
function lufact!{T}(A::Tridiagonal{T}; pivot = true)
function lufact!{T}(A::Tridiagonal{T}, pivot::Union(Type{Val{false}}, Type{Val{true}}) = Val{true})
n = size(A, 1)
info = 0
ipiv = Array(BlasInt, n)
Expand All @@ -171,7 +171,7 @@ function lufact!{T}(A::Tridiagonal{T}; pivot = true)
end
for i = 1:n-2
# pivot or not?
if !pivot || abs(d[i]) >= abs(dl[i])
if pivot==Val{false} || abs(d[i]) >= abs(dl[i])
# No interchange
if d[i] != 0
fact = dl[i]/d[i]
Expand All @@ -194,7 +194,7 @@ function lufact!{T}(A::Tridiagonal{T}; pivot = true)
end
if n > 1
i = n-1
if !pivot || abs(d[i]) >= abs(dl[i])
if pivot==Val{false} || abs(d[i]) >= abs(dl[i])
if d[i] != 0
fact = dl[i]/d[i]
dl[i] = fact
Expand Down
22 changes: 11 additions & 11 deletions doc/helpdb.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6459,7 +6459,7 @@ popdisplay(d::Display)

"),

("Base","lufact","lufact(A[, pivot=true]) -> F
("Base","lufact","lufact(A[, pivot=Val{true}]) -> F

Compute the LU factorization of \"A\". The return type of \"F\"
depends on the type of \"A\". In most cases, if \"A\" is a subtype
Expand Down Expand Up @@ -6536,18 +6536,18 @@ popdisplay(d::Display)

"),

("Base","cholfact","cholfact(A, [LU,][pivot=false,][tol=-1.0]) -> Cholesky
("Base","cholfact","cholfact(A [,LU=:U [,pivot=Val{false}]][;tol=-1.0]) -> Cholesky

Compute the Cholesky factorization of a dense symmetric positive
(semi)definite matrix \"A\" and return either a \"Cholesky\" if
\"pivot=false\" or \"CholeskyPivoted\" if \"pivot=true\". \"LU\"
\"pivot==Val{false}\" or \"CholeskyPivoted\" if \"pivot==Val{true}\". \"LU\"
may be \":L\" for using the lower part or \":U\" for the upper
part. The default is to use \":U\". The triangular matrix can be
obtained from the factorization \"F\" with: \"F[:L]\" and
\"F[:U]\". The following functions are available for \"Cholesky\"
objects: \"size\", \"\\\", \"inv\", \"det\". For
\"CholeskyPivoted\" there is also defined a \"rank\". If
\"pivot=false\" a \"PosDefException\" exception is thrown in case
\"pivot==Val{false}\" a \"PosDefException\" exception is thrown in case
the matrix is not positive definite. The argument \"tol\"
determines the tolerance for determining the rank. For negative
values, the tolerance is the machine precision.
Expand All @@ -6574,7 +6574,7 @@ popdisplay(d::Display)

"),

("Base","cholfact!","cholfact!(A, [LU,][pivot=false,][tol=-1.0]) -> Cholesky
("Base","cholfact!","cholfact!(A [,LU=:U,[pivot=Val{false}]][;tol=-1.0]) -> Cholesky

\"cholfact!\" is the same as \"cholfact()\", but saves space by
overwriting the input \"A\", instead of creating a copy.
Expand All @@ -6592,7 +6592,7 @@ popdisplay(d::Display)

"),

("Base","qr","qr(A, [pivot=false,][thin=true]) -> Q, R, [p]
("Base","qr","qr(A [,pivot=Val{false}][;thin=true]) -> Q, R, [p]

Compute the (pivoted) QR factorization of \"A\" such that either
\"A = Q*R\" or \"A[:,p] = Q*R\". Also see \"qrfact\". The default
Expand All @@ -6601,20 +6601,20 @@ popdisplay(d::Display)

"),

("Base","qrfact","qrfact(A[, pivot=false]) -> F
("Base","qrfact","qrfact(A[, pivot=Val{false}]) -> F

Computes the QR factorization of \"A\". The return type of \"F\"
depends on the element type of \"A\" and whether pivoting is
specified (with \"pivot=true\").
specified (with \"pivot==Val{true}\").

+------------------+-------------------+-----------+---------------------------------------+
| Return type | \\\"eltype(A)\\\" | \\\"pivot\\\" | Relationship between \\\"F\\\" and \\\"A\\\" |
+------------------+-------------------+-----------+---------------------------------------+
| \\\"QR\\\" | not \\\"BlasFloat\\\" | either | \\\"A==F[:Q]*F[:R]\\\" |
+------------------+-------------------+-----------+---------------------------------------+
| \\\"QRCompactWY\\\" | \\\"BlasFloat\\\" | \\\"false\\\" | \\\"A==F[:Q]*F[:R]\\\" |
| \\\"QRCompactWY\\\" | \\\"BlasFloat\\\" | \\\"Val{false}\\\" | \\\"A==F[:Q]*F[:R]\\\" |
+------------------+-------------------+-----------+---------------------------------------+
| \\\"QRPivoted\\\" | \\\"BlasFloat\\\" | \\\"true\\\" | \\\"A[:,F[:p]]==F[:Q]*F[:R]\\\" |
| \\\"QRPivoted\\\" | \\\"BlasFloat\\\" | \\\"Val{true}\\\" | \\\"A[:,F[:p]]==F[:Q]*F[:R]\\\" |
+------------------+-------------------+-----------+---------------------------------------+

\"BlasFloat\" refers to any of: \"Float32\", \"Float64\",
Expand Down Expand Up @@ -6681,7 +6681,7 @@ popdisplay(d::Display)

"),

("Base","qrfact!","qrfact!(A[, pivot=false])
("Base","qrfact!","qrfact!(A[, pivot=Val{false}])

\"qrfact!\" is the same as \"qrfact()\", but saves space by
overwriting the input \"A\", instead of creating a copy.
Expand Down
Loading