Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace Val-types by singleton types in lu and qr #40623

Merged
merged 9 commits into from
May 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ Standard library changes
* The shape of an `UpperHessenberg` matrix is preserved under certain arithmetic operations, e.g. when multiplying or dividing by an `UpperTriangular` matrix. ([#40039])
* `cis(A)` now supports matrix arguments ([#40194]).
* `dot` now supports `UniformScaling` with `AbstractMatrix` ([#40250]).
* `qr[!]` and `lu[!]` now support `LinearAlgebra.PivotingStrategy` (singleton type) values
as their optional `pivot` argument: defaults are `qr(A, NoPivot())` (vs.
`qr(A, ColumnNorm())` for pivoting) and `lu(A, RowMaximum())` (vs. `lu(A, NoPivot())`
without pivoting); the former `Val{true/false}`-based calls are deprecated. ([#40623])
* `det(M::AbstractMatrix{BigInt})` now calls `det_bareiss(M)`, which uses the [Bareiss](https://en.wikipedia.org/wiki/Bareiss_algorithm) algorithm to calculate precise values.([#40868]).

#### Markdown
Expand Down
7 changes: 7 additions & 0 deletions stdlib/LinearAlgebra/src/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,22 @@ export
BunchKaufman,
Cholesky,
CholeskyPivoted,
ColumnNorm,
Eigen,
GeneralizedEigen,
GeneralizedSVD,
GeneralizedSchur,
Hessenberg,
LU,
LDLt,
NoPivot,
QR,
QRPivoted,
LQ,
Schur,
SVD,
Hermitian,
RowMaximum,
Symmetric,
LowerTriangular,
UpperTriangular,
Expand Down Expand Up @@ -164,6 +167,10 @@ abstract type Algorithm end
struct DivideAndConquer <: Algorithm end
struct QRIteration <: Algorithm end

abstract type PivotingStrategy end
struct NoPivot <: PivotingStrategy end
struct RowMaximum <: PivotingStrategy end
struct ColumnNorm <: PivotingStrategy end

# Check that stride of matrix/vector is 1
# Writing like this to avoid splatting penalty when called with multiple arguments,
Expand Down
2 changes: 1 addition & 1 deletion stdlib/LinearAlgebra/src/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1371,7 +1371,7 @@ function factorize(A::StridedMatrix{T}) where T
end
return lu(A)
end
qr(A, Val(true))
qr(A, ColumnNorm())
end
factorize(A::Adjoint) = adjoint(factorize(parent(A)))
factorize(A::Transpose) = transpose(factorize(parent(A)))
Expand Down
6 changes: 3 additions & 3 deletions stdlib/LinearAlgebra/src/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ size(F::Adjoint{<:Any,<:Factorization}) = reverse(size(parent(F)))
size(F::Transpose{<:Any,<:Factorization}) = reverse(size(parent(F)))

checkpositivedefinite(info) = info == 0 || throw(PosDefException(info))
checknonsingular(info, pivoted::Val{true}) = info == 0 || throw(SingularException(info))
checknonsingular(info, pivoted::Val{false}) = info == 0 || throw(ZeroPivotException(info))
checknonsingular(info) = checknonsingular(info, Val{true}())
checknonsingular(info, ::RowMaximum) = info == 0 || throw(SingularException(info))
checknonsingular(info, ::NoPivot) = info == 0 || throw(ZeroPivotException(info))
checknonsingular(info) = checknonsingular(info, RowMaximum())

"""
issuccess(F::Factorization)
Expand Down
2 changes: 1 addition & 1 deletion stdlib/LinearAlgebra/src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1141,7 +1141,7 @@ function (\)(A::AbstractMatrix, B::AbstractVecOrMat)
end
return lu(A) \ B
end
return qr(A,Val(true)) \ B
return qr(A, ColumnNorm()) \ B
end

(\)(a::AbstractVector, b::AbstractArray) = pinv(a) * b
Expand Down
52 changes: 32 additions & 20 deletions stdlib/LinearAlgebra/src/lu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,22 +76,26 @@ adjoint(F::LU) = Adjoint(F)
transpose(F::LU) = Transpose(F)

# StridedMatrix
function lu!(A::StridedMatrix{T}, pivot::Union{Val{false}, Val{true}} = Val(true);
check::Bool = true) where T<:BlasFloat
if pivot === Val(false)
return generic_lufact!(A, pivot; check = check)
end
lu!(A::StridedMatrix{<:BlasFloat}; check::Bool = true) = lu!(A, RowMaximum(); check=check)
function lu!(A::StridedMatrix{T}, ::RowMaximum; check::Bool = true) where {T<:BlasFloat}
lpt = LAPACK.getrf!(A)
check && checknonsingular(lpt[3])
return LU{T,typeof(A)}(lpt[1], lpt[2], lpt[3])
end
function lu!(A::HermOrSym, pivot::Union{Val{false}, Val{true}} = Val(true); check::Bool = true)
function lu!(A::StridedMatrix{<:BlasFloat}, pivot::NoPivot; check::Bool = true)
return generic_lufact!(A, pivot; check = check)
end
function lu!(A::HermOrSym, pivot::Union{RowMaximum,NoPivot} = RowMaximum(); check::Bool = true)
copytri!(A.data, A.uplo, isa(A, Hermitian))
lu!(A.data, pivot; check = check)
end
# for backward compatibility
# TODO: remove towards Julia v2
@deprecate lu!(A::Union{StridedMatrix,HermOrSym,Tridiagonal}, ::Val{true}; check::Bool = true) lu!(A, RowMaximum(); check=check)
@deprecate lu!(A::Union{StridedMatrix,HermOrSym,Tridiagonal}, ::Val{false}; check::Bool = true) lu!(A, NoPivot(); check=check)

"""
lu!(A, pivot=Val(true); check = true) -> LU
lu!(A, pivot = RowMaximum(); check = true) -> LU

`lu!` is the same as [`lu`](@ref), but saves space by overwriting the
input `A`, instead of creating a copy. An [`InexactError`](@ref)
Expand Down Expand Up @@ -127,19 +131,22 @@ Stacktrace:
[...]
```
"""
lu!(A::StridedMatrix, pivot::Union{Val{false}, Val{true}} = Val(true); check::Bool = true) =
lu!(A::StridedMatrix, pivot::Union{RowMaximum,NoPivot} = RowMaximum(); check::Bool = true) =
generic_lufact!(A, pivot; check = check)
function generic_lufact!(A::StridedMatrix{T}, ::Val{Pivot} = Val(true);
check::Bool = true) where {T,Pivot}
function generic_lufact!(A::StridedMatrix{T}, pivot::Union{RowMaximum,NoPivot} = RowMaximum();
check::Bool = true) where {T}
# Extract values
m, n = size(A)
minmn = min(m,n)

# Initialize variables
info = 0
ipiv = Vector{BlasInt}(undef, minmn)
@inbounds begin
for k = 1:minmn
# find index max
kp = k
if Pivot && k < m
if pivot === RowMaximum() && k < m
amax = abs(A[k, k])
for i = k+1:m
absi = abs(A[i,k])
Expand Down Expand Up @@ -175,7 +182,7 @@ function generic_lufact!(A::StridedMatrix{T}, ::Val{Pivot} = Val(true);
end
end
end
check && checknonsingular(info, Val{Pivot}())
check && checknonsingular(info, pivot)
return LU{T,typeof(A)}(A, ipiv, convert(BlasInt, info))
end

Expand All @@ -200,7 +207,7 @@ end

# for all other types we must promote to a type which is stable under division
"""
lu(A, pivot=Val(true); check = true) -> F::LU
lu(A, pivot = RowMaximum(); check = true) -> F::LU

Compute the LU factorization of `A`.

Expand All @@ -211,7 +218,7 @@ validity (via [`issuccess`](@ref)) lies with the user.
In most cases, if `A` is a subtype `S` of `AbstractMatrix{T}` with an element
type `T` supporting `+`, `-`, `*` and `/`, the return type is `LU{T,S{T}}`. If
pivoting is chosen (default) the element type should also support [`abs`](@ref) and
[`<`](@ref).
[`<`](@ref). Pivoting can be turned off by passing `pivot = NoPivot()`.

The individual components of the factorization `F` can be accessed via [`getproperty`](@ref):

Expand Down Expand Up @@ -267,11 +274,14 @@ julia> l == F.L && u == F.U && p == F.p
true
```
"""
function lu(A::AbstractMatrix{T}, pivot::Union{Val{false}, Val{true}}=Val(true);
check::Bool = true) where T
function lu(A::AbstractMatrix{T}, pivot::Union{RowMaximum,NoPivot} = RowMaximum(); check::Bool = true) where {T}
S = lutype(T)
lu!(copy_oftype(A, S), pivot; check = check)
end
# TODO: remove for Julia v2.0
@deprecate lu(A::AbstractMatrix, ::Val{true}; check::Bool = true) lu(A, RowMaximum(); check=check)
@deprecate lu(A::AbstractMatrix, ::Val{false}; check::Bool = true) lu(A, NoPivot(); check=check)


lu(S::LU) = S
function lu(x::Number; check::Bool=true)
Expand Down Expand Up @@ -481,9 +491,11 @@ inv(A::LU{<:BlasFloat,<:StridedMatrix}) = inv!(copy(A))
# Tridiagonal

# See dgttrf.f
function lu!(A::Tridiagonal{T,V}, pivot::Union{Val{false}, Val{true}} = Val(true);
check::Bool = true) where {T,V}
function lu!(A::Tridiagonal{T,V}, pivot::Union{RowMaximum,NoPivot} = RowMaximum(); check::Bool = true) where {T,V}
# Extract values
n = size(A, 1)

# Initialize variables
info = 0
ipiv = Vector{BlasInt}(undef, n)
dl = A.dl
Expand All @@ -500,7 +512,7 @@ function lu!(A::Tridiagonal{T,V}, pivot::Union{Val{false}, Val{true}} = Val(true
end
for i = 1:n-2
# pivot or not?
if pivot === Val(false) || abs(d[i]) >= abs(dl[i])
if pivot === NoPivot() || abs(d[i]) >= abs(dl[i])
# No interchange
if d[i] != 0
fact = dl[i]/d[i]
Expand All @@ -523,7 +535,7 @@ function lu!(A::Tridiagonal{T,V}, pivot::Union{Val{false}, Val{true}} = Val(true
end
if n > 1
i = n-1
if pivot === Val(false) || abs(d[i]) >= abs(dl[i])
if pivot === NoPivot() || abs(d[i]) >= abs(dl[i])
if d[i] != 0
fact = dl[i]/d[i]
dl[i] = fact
Expand Down
29 changes: 18 additions & 11 deletions stdlib/LinearAlgebra/src/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -246,17 +246,17 @@ function qrfactPivotedUnblocked!(A::AbstractMatrix)
end

# LAPACK version
qr!(A::StridedMatrix{<:BlasFloat}, ::Val{false} = Val(false); blocksize=36) =
qr!(A::StridedMatrix{<:BlasFloat}, ::NoPivot; blocksize=36) =
QRCompactWY(LAPACK.geqrt!(A, min(min(size(A)...), blocksize))...)
qr!(A::StridedMatrix{<:BlasFloat}, ::Val{true}) = QRPivoted(LAPACK.geqp3!(A)...)
qr!(A::StridedMatrix{<:BlasFloat}, ::ColumnNorm) = QRPivoted(LAPACK.geqp3!(A)...)

# Generic fallbacks

"""
qr!(A, pivot=Val(false); blocksize)
qr!(A, pivot = NoPivot(); blocksize)

`qr!` is the same as [`qr`](@ref) when `A` is a subtype of
[`StridedMatrix`](@ref), but saves space by overwriting the input `A`, instead of creating a copy.
`qr!` is the same as [`qr`](@ref) when `A` is a subtype of [`StridedMatrix`](@ref),
but saves space by overwriting the input `A`, instead of creating a copy.
An [`InexactError`](@ref) exception is thrown if the factorization produces a number not
representable by the element type of `A`, e.g. for integer types.

Expand Down Expand Up @@ -292,14 +292,17 @@ Stacktrace:
[...]
```
"""
qr!(A::AbstractMatrix, ::Val{false}) = qrfactUnblocked!(A)
qr!(A::AbstractMatrix, ::Val{true}) = qrfactPivotedUnblocked!(A)
qr!(A::AbstractMatrix) = qr!(A, Val(false))
qr!(A::AbstractMatrix, ::NoPivot) = qrfactUnblocked!(A)
qr!(A::AbstractMatrix, ::ColumnNorm) = qrfactPivotedUnblocked!(A)
qr!(A::AbstractMatrix) = qr!(A, NoPivot())
# TODO: Remove in Julia v2.0
@deprecate qr!(A::AbstractMatrix, ::Val{true}) qr!(A, ColumnNorm())
@deprecate qr!(A::AbstractMatrix, ::Val{false}) qr!(A, NoPivot())

_qreltype(::Type{T}) where T = typeof(zero(T)/sqrt(abs2(one(T))))

"""
qr(A, pivot=Val(false); blocksize) -> F
qr(A, pivot = NoPivot(); blocksize) -> F

Compute the QR factorization of the matrix `A`: an orthogonal (or unitary if `A` is
complex-valued) matrix `Q`, and an upper triangular matrix `R` such that
Expand All @@ -310,7 +313,7 @@ A = Q R

The returned object `F` stores the factorization in a packed format:

- if `pivot == Val(true)` then `F` is a [`QRPivoted`](@ref) object,
- if `pivot == ColumnNorm()` then `F` is a [`QRPivoted`](@ref) object,

- otherwise if the element type of `A` is a BLAS type ([`Float32`](@ref), [`Float64`](@ref),
`ComplexF32` or `ComplexF64`), then `F` is a [`QRCompactWY`](@ref) object,
Expand Down Expand Up @@ -340,7 +343,7 @@ and `F.Q*A` are supported. A `Q` matrix can be converted into a regular matrix w
orthogonal matrix.

The block size for QR decomposition can be specified by keyword argument
`blocksize :: Integer` when `pivot == Val(false)` and `A isa StridedMatrix{<:BlasFloat}`.
`blocksize :: Integer` when `pivot == NoPivot()` and `A isa StridedMatrix{<:BlasFloat}`.
It is ignored when `blocksize > minimum(size(A))`. See [`QRCompactWY`](@ref).

!!! compat "Julia 1.4"
Expand Down Expand Up @@ -382,6 +385,10 @@ function qr(A::AbstractMatrix{T}, arg...; kwargs...) where T
copyto!(AA, A)
return qr!(AA, arg...; kwargs...)
end
# TODO: remove in Julia v2.0
@deprecate qr(A::AbstractMatrix, ::Val{false}; kwargs...) qr(A, NoPivot(); kwargs...)
@deprecate qr(A::AbstractMatrix, ::Val{true}; kwargs...) qr(A, ColumnNorm(); kwargs...)

qr(x::Number) = qr(fill(x,1,1))
function qr(v::AbstractVector)
require_one_based_indexing(v)
Expand Down
2 changes: 1 addition & 1 deletion stdlib/LinearAlgebra/test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ end
D = Diagonal(randn(5))
Q = qr(randn(5, 5)).Q
@test D * Q' == Array(D) * Q'
Q = qr(randn(5, 5), Val(true)).Q
Q = qr(randn(5, 5), ColumnNorm()).Q
@test_throws ArgumentError lmul!(Q, D)
end

Expand Down
4 changes: 2 additions & 2 deletions stdlib/LinearAlgebra/test/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -387,13 +387,13 @@ LinearAlgebra.Transpose(a::ModInt{n}) where {n} = transpose(a)
A = [ModInt{2}(1) ModInt{2}(0); ModInt{2}(1) ModInt{2}(1)]
b = [ModInt{2}(1), ModInt{2}(0)]

@test A*(lu(A, Val(false))\b) == b
@test A*(lu(A, NoPivot())\b) == b

# Needed for pivoting:
Base.abs(a::ModInt{n}) where {n} = a
Base.:<(a::ModInt{n}, b::ModInt{n}) where {n} = a.k < b.k

@test A*(lu(A, Val(true))\b) == b
@test A*(lu(A, RowMaximum())\b) == b
end

@testset "Issue 18742" begin
Expand Down
2 changes: 1 addition & 1 deletion stdlib/LinearAlgebra/test/lq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ rectangularQ(Q::LinearAlgebra.LQPackedQ) = convert(Array, Q)
lqa = lq(a)
x = lqa\b
l,q = lqa.L, lqa.Q
qra = qr(a, Val(true))
qra = qr(a, ColumnNorm())
@testset "Basic ops" begin
@test size(lqa,1) == size(a,1)
@test size(lqa,3) == 1
Expand Down
24 changes: 12 additions & 12 deletions stdlib/LinearAlgebra/test/lu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ dimg = randn(n)/2
lua = factorize(a)
@test_throws ErrorException lua.Z
l,u,p = lua.L, lua.U, lua.p
ll,ul,pl = lu(a)
ll,ul,pl = @inferred lu(a)
@test ll * ul ≈ a[pl,:]
@test l*u ≈ a[p,:]
@test (l*u)[invperm(p),:] ≈ a
Expand All @@ -85,9 +85,9 @@ dimg = randn(n)/2
end
κd = cond(Array(d),1)
@testset "Tridiagonal LU" begin
lud = lu(d)
lud = @inferred lu(d)
@test LinearAlgebra.issuccess(lud)
@test lu(lud) == lud
@test @inferred(lu(lud)) == lud
@test_throws ErrorException lud.Z
@test lud.L*lud.U ≈ lud.P*Array(d)
@test lud.L*lud.U ≈ Array(d)[lud.p,:]
Expand Down Expand Up @@ -199,14 +199,14 @@ dimg = randn(n)/2
@test lua.L*lua.U ≈ lua.P*a[:,1:n1]
end
@testset "Fat LU" begin
lua = lu(a[1:n1,:])
lua = @inferred lu(a[1:n1,:])
@test lua.L*lua.U ≈ lua.P*a[1:n1,:]
end
end

@testset "LU of Symmetric/Hermitian" begin
for HS in (Hermitian(a'a), Symmetric(a'a))
luhs = lu(HS)
luhs = @inferred lu(HS)
@test luhs.L*luhs.U ≈ luhs.P*Matrix(HS)
end
end
Expand All @@ -229,12 +229,12 @@ end
@test_throws SingularException lu!(copy(A); check = true)
@test !issuccess(lu(A; check = false))
@test !issuccess(lu!(copy(A); check = false))
@test_throws ZeroPivotException lu(A, Val(false))
@test_throws ZeroPivotException lu!(copy(A), Val(false))
@test_throws ZeroPivotException lu(A, Val(false); check = true)
@test_throws ZeroPivotException lu!(copy(A), Val(false); check = true)
@test !issuccess(lu(A, Val(false); check = false))
@test !issuccess(lu!(copy(A), Val(false); check = false))
@test_throws ZeroPivotException lu(A, NoPivot())
@test_throws ZeroPivotException lu!(copy(A), NoPivot())
@test_throws ZeroPivotException lu(A, NoPivot(); check = true)
@test_throws ZeroPivotException lu!(copy(A), NoPivot(); check = true)
@test !issuccess(lu(A, NoPivot(); check = false))
@test !issuccess(lu!(copy(A), NoPivot(); check = false))
F = lu(A; check = false)
@test sprint((io, x) -> show(io, "text/plain", x), F) ==
"Failed factorization of type $(typeof(F))"
Expand Down Expand Up @@ -320,7 +320,7 @@ include("trickyarithmetic.jl")
@testset "lu with type whose sum is another type" begin
A = TrickyArithmetic.A[1 2; 3 4]
ElT = TrickyArithmetic.D{TrickyArithmetic.C,TrickyArithmetic.C}
B = lu(A, Val(false))
B = lu(A, NoPivot())
@test B isa LinearAlgebra.LU{ElT,Matrix{ElT}}
end

Expand Down
Loading