Skip to content

Commit

Permalink
Use Val(true/false) for qrfact(!) and lufact(!)
Browse files Browse the repository at this point in the history
  • Loading branch information
andyferris committed Jun 24, 2017
1 parent 110cc5f commit 6e61a0a
Show file tree
Hide file tree
Showing 8 changed files with 55 additions and 51 deletions.
6 changes: 3 additions & 3 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,9 @@ Library improvements
* Added `unique!` which is an inplace version of `unique` ([#20549]).

* Uses of `Val{c}` in `Base` has been replaced with `Val{c}()`, which is now easily
accessible via the `@pure` constructor `Val(c)`. Functions are defined as
`f(::Val{c}) = ...` and called by `f(Val(c))`. Notable affected function are:
`ntuple`, `fill_to_length`, `Base.literal_pow`.
accessible via the `@pure` constructor `Val(c)`. Functions are defined as
`f(::Val{c}) = ...` and called by `f(Val(c))`. Notable affected functions include:
`ntuple`, `Base.literal_pow`, `sqrtm`, `lufact`, `lufact!`, `qrfact`, `qrfact!`.

Compiler/Runtime improvements
-----------------------------
Expand Down
4 changes: 4 additions & 0 deletions base/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1490,6 +1490,10 @@ export conv, conv2, deconv, filt, filt!, xcorr
@deprecate literal_pow{N}(a,b , ::Type{Val{N}}) literal_pow(f, Val(N))
#@deprecate IteratorsMD.split(t, V::Type{Val{n}}) IteratorsMD.split(t, Val(n))
@deprecate sqrtm{T,realmatrix}(A::UpperTriangular{T},::Type{Val{realmatrix}}) sqrtm(A, Val(realmatrix))
@deprecate lufact(A, pivot::Union{Type{Val{false}}, Type{Val{true}}}) lufact(A, pivot)
@deprecate lufact!(A, pivot::Union{Type{Val{false}}, Type{Val{true}}}) lufact!(A, pivot)
@deprecate qrfact(A, pivot::Union{Type{Val{false}}, Type{Val{true}}}) qrfact(A, pivot)
@deprecate qrfact!(A, pivot::Union{Type{Val{false}}, Type{Val{true}}}) qrfact!(A, pivot)

# END 0.7 deprecations

Expand Down
32 changes: 16 additions & 16 deletions base/linalg/lu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,24 @@ end
LU(factors::AbstractMatrix{T}, ipiv::Vector{BlasInt}, info::BlasInt) where {T} = LU{T,typeof(factors)}(factors, ipiv, info)

# StridedMatrix
function lufact!(A::StridedMatrix{T}, pivot::Union{Type{Val{false}}, Type{Val{true}}} = Val{true}) where T<:BlasFloat
if pivot === Val{false}
function lufact!(A::StridedMatrix{T}, pivot::Union{Val{false}, Val{true}} = Val(true)) where T<:BlasFloat
if pivot === Val(false)
return generic_lufact!(A, pivot)
end
lpt = LAPACK.getrf!(A)
return LU{T,typeof(A)}(lpt[1], lpt[2], lpt[3])
end

"""
lufact!(A, pivot=Val{true}) -> LU
lufact!(A, pivot=Val(true)) -> LU
`lufact!` is the same as [`lufact`](@ref), but saves space by overwriting the
input `A`, instead of creating a copy. An [`InexactError`](@ref)
exception is thrown if the factorization produces a number not representable by the
element type of `A`, e.g. for integer types.
"""
lufact!(A::StridedMatrix, pivot::Union{Type{Val{false}}, Type{Val{true}}} = Val{true}) = generic_lufact!(A, pivot)
function generic_lufact!(A::StridedMatrix{T}, ::Type{Val{Pivot}} = Val{true}) where {T,Pivot}
lufact!(A::StridedMatrix, pivot::Union{Val{false}, Val{true}} = Val(true)) = generic_lufact!(A, pivot)
function generic_lufact!(A::StridedMatrix{T}, ::Val{Pivot} = Val(true)) where {T,Pivot}
m, n = size(A)
minmn = min(m,n)
info = 0
Expand Down Expand Up @@ -79,12 +79,12 @@ end

# floating point types doesn't have to be promoted for LU, but should default to pivoting
lufact(A::Union{AbstractMatrix{T}, AbstractMatrix{Complex{T}}},
pivot::Union{Type{Val{false}}, Type{Val{true}}} = Val{true}) where {T<:AbstractFloat} =
pivot::Union{Val{false}, Val{true}} = Val(true)) where {T<:AbstractFloat} =
lufact!(copy(A), pivot)

# for all other types we must promote to a type which is stable under division
"""
lufact(A [,pivot=Val{true}]) -> F::LU
lufact(A [,pivot=Val(true)]) -> F::LU
Compute the LU factorization of `A`.
Expand Down Expand Up @@ -135,7 +135,7 @@ julia> F[:L] * F[:U] == A[F[:p], :]
true
```
"""
function lufact(A::AbstractMatrix{T}, pivot::Union{Type{Val{false}}, Type{Val{true}}}) where T
function lufact(A::AbstractMatrix{T}, pivot::Union{Val{false}, Val{true}}) where T
S = typeof(zero(T)/one(T))
AA = similar(A, S, size(A))
copy!(AA, A)
Expand All @@ -146,13 +146,13 @@ function lufact(A::AbstractMatrix{T}) where T
S = typeof(zero(T)/one(T))
AA = similar(A, S, size(A))
copy!(AA, A)
F = lufact!(AA, Val{false})
F = lufact!(AA, Val(false))
if F.info == 0
return F
else
AA = similar(A, S, size(A))
copy!(AA, A)
return lufact!(AA, Val{true})
return lufact!(AA, Val(true))
end
end

Expand All @@ -162,11 +162,11 @@ lufact(F::LU) = F
lu(x::Number) = (one(x), x, 1)

"""
lu(A, pivot=Val{true}) -> L, U, p
lu(A, pivot=Val(true)) -> L, U, p
Compute the LU factorization of `A`, such that `A[p,:] = L*U`.
By default, pivoting is used. This can be overridden by passing
`Val{false}` for the second argument.
`Val(false)` for the second argument.
See also [`lufact`](@ref).
Expand All @@ -185,7 +185,7 @@ julia> A[p, :] == L * U
true
```
"""
function lu(A::AbstractMatrix, pivot::Union{Type{Val{false}}, Type{Val{true}}} = Val{true})
function lu(A::AbstractMatrix, pivot::Union{Val{false}, Val{true}} = Val(true))
F = lufact(A, pivot)
F[:L], F[:U], F[:p]
end
Expand Down Expand Up @@ -319,7 +319,7 @@ end
# Tridiagonal

# See dgttrf.f
function lufact!(A::Tridiagonal{T}, pivot::Union{Type{Val{false}}, Type{Val{true}}} = Val{true}) where T
function lufact!(A::Tridiagonal{T}, pivot::Union{Val{false}, Val{true}} = Val(true)) where T
n = size(A, 1)
info = 0
ipiv = Vector{BlasInt}(n)
Expand All @@ -334,7 +334,7 @@ function lufact!(A::Tridiagonal{T}, pivot::Union{Type{Val{false}}, Type{Val{true
end
for i = 1:n-2
# pivot or not?
if pivot === Val{false} || abs(d[i]) >= abs(dl[i])
if pivot === Val(false) || abs(d[i]) >= abs(dl[i])
# No interchange
if d[i] != 0
fact = dl[i]/d[i]
Expand All @@ -357,7 +357,7 @@ function lufact!(A::Tridiagonal{T}, pivot::Union{Type{Val{false}}, Type{Val{true
end
if n > 1
i = n-1
if pivot === Val{false} || abs(d[i]) >= abs(dl[i])
if pivot === Val(false) || abs(d[i]) >= abs(dl[i])
if d[i] != 0
fact = dl[i]/d[i]
dl[i] = fact
Expand Down
30 changes: 15 additions & 15 deletions base/linalg/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -194,26 +194,26 @@ function qrfactPivotedUnblocked!(A::StridedMatrix)
end

# LAPACK version
qrfact!(A::StridedMatrix{<:BlasFloat}, ::Type{Val{false}}) = QRCompactWY(LAPACK.geqrt!(A, min(minimum(size(A)), 36))...)
qrfact!(A::StridedMatrix{<:BlasFloat}, ::Type{Val{true}}) = QRPivoted(LAPACK.geqp3!(A)...)
qrfact!(A::StridedMatrix{<:BlasFloat}) = qrfact!(A, Val{false})
qrfact!(A::StridedMatrix{<:BlasFloat}, ::Val{false}) = QRCompactWY(LAPACK.geqrt!(A, min(minimum(size(A)), 36))...)
qrfact!(A::StridedMatrix{<:BlasFloat}, ::Val{true}) = QRPivoted(LAPACK.geqp3!(A)...)
qrfact!(A::StridedMatrix{<:BlasFloat}) = qrfact!(A, Val(false))

# Generic fallbacks

"""
qrfact!(A, pivot=Val{false})
qrfact!(A, pivot=Val(false))
`qrfact!` is the same as [`qrfact`](@ref) when `A` is a subtype of
`StridedMatrix`, but saves space by overwriting the input `A`, instead of creating a copy.
An [`InexactError`](@ref) exception is thrown if the factorization produces a number not
representable by the element type of `A`, e.g. for integer types.
"""
qrfact!(A::StridedMatrix, ::Type{Val{false}}) = qrfactUnblocked!(A)
qrfact!(A::StridedMatrix, ::Type{Val{true}}) = qrfactPivotedUnblocked!(A)
qrfact!(A::StridedMatrix) = qrfact!(A, Val{false})
qrfact!(A::StridedMatrix, ::Val{false}) = qrfactUnblocked!(A)
qrfact!(A::StridedMatrix, ::Val{true}) = qrfactPivotedUnblocked!(A)
qrfact!(A::StridedMatrix) = qrfact!(A, Val(false))

"""
qrfact(A, pivot=Val{false}) -> F
qrfact(A, pivot=Val(false)) -> F
Compute the QR factorization of the matrix `A`: an orthogonal (or unitary if `A` is
complex-valued) matrix `Q`, and an upper triangular matrix `R` such that
Expand All @@ -224,7 +224,7 @@ A = Q R
The returned object `F` stores the factorization in a packed format:
- if `pivot == Val{true}` then `F` is a [`QRPivoted`](@ref) object,
- if `pivot == Val(true)` then `F` is a [`QRPivoted`](@ref) object,
- otherwise if the element type of `A` is a BLAS type ([`Float32`](@ref), [`Float64`](@ref),
`Complex64` or `Complex128`), then `F` is a [`QRCompactWY`](@ref) object,
Expand Down Expand Up @@ -283,21 +283,21 @@ end
qrfact(x::Number) = qrfact(fill(x,1,1))

"""
qr(A, pivot=Val{false}; thin::Bool=true) -> Q, R, [p]
qr(A, pivot=Val(false); thin::Bool=true) -> Q, R, [p]
Compute the (pivoted) QR factorization of `A` such that either `A = Q*R` or `A[:,p] = Q*R`.
Also see [`qrfact`](@ref).
The default is to compute a thin factorization. Note that `R` is not
extended with zeros when the full `Q` is requested.
"""
qr(A::Union{Number, AbstractMatrix}, pivot::Union{Type{Val{false}}, Type{Val{true}}}=Val{false}; thin::Bool=true) =
qr(A::Union{Number, AbstractMatrix}, pivot::Union{Val{false}, Val{true}}=Val(false); thin::Bool=true) =
_qr(A, pivot, thin=thin)
function _qr(A::Union{Number, AbstractMatrix}, ::Type{Val{false}}; thin::Bool=true)
F = qrfact(A, Val{false})
function _qr(A::Union{Number, AbstractMatrix}, ::Val{false}; thin::Bool=true)
F = qrfact(A, Val(false))
full(getq(F), thin=thin), F[:R]::Matrix{eltype(F)}
end
function _qr(A::Union{Number, AbstractMatrix}, ::Type{Val{true}}; thin::Bool=true)
F = qrfact(A, Val{true})
function _qr(A::Union{Number, AbstractMatrix}, ::Val{true}; thin::Bool=true)
F = qrfact(A, Val(true))
full(getq(F), thin=thin), F[:R]::Matrix{eltype(F)}, F[:p]::Vector{BlasInt}
end

Expand Down
4 changes: 2 additions & 2 deletions base/sparse/spqr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ function qmult(method::Integer, QR::Factorization{Tv}, X::Dense{Tv}) where Tv<:V
end


qrfact(A::SparseMatrixCSC, ::Type{Val{true}}) = factorize(ORDERING_DEFAULT, DEFAULT_TOL, Sparse(A, 0))
qrfact(A::SparseMatrixCSC, ::Val{true}) = factorize(ORDERING_DEFAULT, DEFAULT_TOL, Sparse(A, 0))

"""
qrfact(A) -> SPQR.Factorization
Expand All @@ -147,7 +147,7 @@ The main application of this type is to solve least squares problems with [`\\`]
calls the C library SPQR and a few additional functions from the library are wrapped but not
exported.
"""
qrfact(A::SparseMatrixCSC) = qrfact(A, Val{true})
qrfact(A::SparseMatrixCSC) = qrfact(A, Val(true))

# With a real lhs and complex rhs with the same precision, we can reinterpret
# the complex rhs as a real rhs with twice the number of columns
Expand Down
12 changes: 6 additions & 6 deletions test/linalg/cholesky.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,9 @@ using Base.LinAlg: BlasComplex, BlasFloat, BlasReal, QRPivoted, PosDefException

#pivoted upper Cholesky
if eltya != BigFloat
cz = cholfact(Hermitian(zeros(eltya,n,n)), Val{true})
cz = cholfact(Hermitian(zeros(eltya,n,n)), Val(true))
@test_throws Base.LinAlg.RankDeficientException Base.LinAlg.chkfullrank(cz)
cpapd = cholfact(apdh, Val{true})
cpapd = cholfact(apdh, Val(true))
@test rank(cpapd) == n
@test all(diff(diag(real(cpapd.factors))).<=0.) # diagonal should be non-increasing
if isreal(apd)
Expand Down Expand Up @@ -175,11 +175,11 @@ using Base.LinAlg: BlasComplex, BlasFloat, BlasReal, QRPivoted, PosDefException

if eltya != BigFloat && eltyb != BigFloat # Note! Need to implement pivoted Cholesky decomposition in julia

cpapd = cholfact(apdh, Val{true})
cpapd = cholfact(apdh, Val(true))
@test norm(apd * (cpapd\b) - b)/norm(b) <= ε*κ*n # Ad hoc, revisit
@test norm(apd * (cpapd\b[1:n]) - b[1:n])/norm(b[1:n]) <= ε*κ*n

lpapd = cholfact(apdhL, Val{true})
lpapd = cholfact(apdhL, Val(true))
@test norm(apd * (lpapd\b) - b)/norm(b) <= ε*κ*n # Ad hoc, revisit
@test norm(apd * (lpapd\b[1:n]) - b[1:n])/norm(b[1:n]) <= ε*κ*n

Expand Down Expand Up @@ -251,7 +251,7 @@ end
0.25336108035924787 + 0.975317836492159im 0.0628393808469436 - 0.1253397353973715im
0.11192755545114 - 0.1603741874112385im 0.8439562576196216 + 1.0850814110398734im
-1.0568488936791578 - 0.06025820467086475im 0.12696236014017806 - 0.09853584666755086im]
cholfact(Hermitian(apd, :L), Val{true}) \ b
cholfact(Hermitian(apd, :L), Val(true)) \ b
r = factorize(apd)[:U]
E = abs.(apd - r'*r)
ε = eps(abs(float(one(Complex64))))
Expand All @@ -273,7 +273,7 @@ end
end

@testset "fail for non-BLAS element types" begin
@test_throws ArgumentError cholfact!(Hermitian(rand(Float16, 5,5)), Val{true})
@test_throws ArgumentError cholfact!(Hermitian(rand(Float16, 5,5)), Val(true))
end

@testset "throw for non positive definite matrix" begin
Expand Down
4 changes: 2 additions & 2 deletions test/linalg/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -335,10 +335,10 @@ Base.transpose(a::ModInt{n}) where {n} = a # see Issue 20978
A = [ModInt{2}(1) ModInt{2}(0); ModInt{2}(1) ModInt{2}(1)]
b = [ModInt{2}(1), ModInt{2}(0)]

@test A*(lufact(A, Val{false})\b) == b
@test A*(lufact(A, Val(false))\b) == b

# Needed for pivoting:
Base.abs(a::ModInt{n}) where {n} = a
Base.:<(a::ModInt{n}, b::ModInt{n}) where {n} = a.k < b.k

@test A*(lufact(A, Val{true})\b) == b
@test A*(lufact(A, Val(true))\b) == b
14 changes: 7 additions & 7 deletions test/linalg/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ debug && println("QR decomposition (without pivoting)")
@test sprint(show,qra) == "$(typeof(qra)) with factors Q and R:\n$qstring\n$rstring"

debug && println("Thin QR decomposition (without pivoting)")
qra = @inferred qrfact(a[:,1:n1], Val{false})
@inferred qr(a[:,1:n1], Val{false})
qra = @inferred qrfact(a[:,1:n1], Val(false))
@inferred qr(a[:,1:n1], Val(false))
q,r = qra[:Q], qra[:R]
@test_throws KeyError qra[:Z]
@test q'*full(q, thin=false) eye(n)
Expand All @@ -82,8 +82,8 @@ debug && println("Thin QR decomposition (without pivoting)")
end

debug && println("(Automatic) Fat (pivoted) QR decomposition")
@inferred qrfact(a, Val{true})
@inferred qr(a, Val{true})
@inferred qrfact(a, Val(true))
@inferred qr(a, Val(true))

qrpa = factorize(a[1:n1,:])
q,r = qrpa[:Q], qrpa[:R]
Expand Down Expand Up @@ -134,7 +134,7 @@ debug && println("Matmul with QR factorizations")
@test_throws DimensionMismatch Base.LinAlg.A_mul_B!(q,zeros(eltya,n1+1))
@test_throws DimensionMismatch Base.LinAlg.Ac_mul_B!(q,zeros(eltya,n1+1))

qra = qrfact(a[:,1:n1], Val{false})
qra = qrfact(a[:,1:n1], Val(false))
q, r = qra[:Q], qra[:R]
@test A_mul_B!(full(q, thin=false)',q) eye(n)
@test_throws DimensionMismatch A_mul_B!(eye(eltya,n+1),q)
Expand All @@ -149,8 +149,8 @@ end
# Because transpose(x) == x
@test_throws ErrorException transpose(qrfact(randn(3,3)))
@test_throws ErrorException ctranspose(qrfact(randn(3,3)))
@test_throws ErrorException transpose(qrfact(randn(3,3), Val{false}))
@test_throws ErrorException ctranspose(qrfact(randn(3,3), Val{false}))
@test_throws ErrorException transpose(qrfact(randn(3,3), Val(false)))
@test_throws ErrorException ctranspose(qrfact(randn(3,3), Val(false)))
@test_throws ErrorException transpose(qrfact(big.(randn(3,3))))
@test_throws ErrorException ctranspose(qrfact(big.(randn(3,3))))

Expand Down

0 comments on commit 6e61a0a

Please sign in to comment.