Skip to content

Commit

Permalink
QR fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
fredrikekre committed Jun 28, 2018
1 parent 28e7bdf commit 38fcc3a
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 52 deletions.
48 changes: 22 additions & 26 deletions src/qr.jl
Original file line number Diff line number Diff line change
@@ -1,36 +1,37 @@
_thin_must_hold(thin) =
thin || throw(ArgumentError("For the sake of type stability, `thin = true` must hold."))
# define our own struct since LinearAlgebra.QR are restricted to Matrix
struct QR{Q,R}
Q::Q
R::R
end

# iteration for destructuring into components
Base.iterate(S::QR) = (S.Q, Val(:R))
Base.iterate(S::QR, ::Val{:R}) = (S.R, Val(:done))
Base.iterate(S::QR, ::Val{:done}) = nothing

"""
qr(A::StaticMatrix, pivot=Val{false}; thin=true) -> Q, R, [p]
qr(A::StaticMatrix, pivot=Val(false)) -> Q, R, [p]
Compute the QR factorization of `A` such that `A = Q*R` or `A[:,p] = Q*R`, see [`qr`](@ref).
This function does not support `thin=false` keyword option due to type inference instability.
To use this option call `qr(A, pivot, Val{false})` instead.
"""
@inline function qr(A::StaticMatrix, pivot::Union{Val{false}, Val{true}} = Val(false); thin::Bool=true)
_thin_must_hold(thin)
return _qr(Size(A), A, pivot, Val(true))
@inline function qr(A::StaticMatrix, pivot::Union{Val{false}, Val{true}} = Val(false))
Q, R = _qr(Size(A), A, pivot)
return QR(Q, R)
end


@inline qr(A::StaticMatrix, pivot::Union{Val{false}, Val{true}}, thin::Union{Val{false}, Val{true}}) = _qr(Size(A), A, pivot, thin)


_qreltype(::Type{T}) where T = typeof(zero(T)/sqrt(abs2(one(T))))


@generated function _qr(::Size{sA}, A::StaticMatrix{<:Any, <:Any, TA}, pivot::Union{Val{false}, Val{true}} = Val(false), thin::Union{Val{false}, Val{true}} = Val(true)) where {sA, TA}

isthin = thin == Val(true)
@generated function _qr(::Size{sA}, A::StaticMatrix{<:Any, <:Any, TA}, pivot::Union{Val{false}, Val{true}} = Val(false)) where {sA, TA}

SizeQ = Size( sA[1], isthin ? diagsize(Size(A)) : sA[1] )
SizeQ = Size( sA[1], diagsize(Size(A)) )
SizeR = Size( diagsize(Size(A)), sA[2] )

if pivot == Val(true)
return quote
@_inline_meta
Q0, R0, p0 = qr(Matrix(A), pivot, thin=$isthin)
QRp = qr(Matrix(A), pivot)
Q0, R0, p0 = QRp.Q, QRp.R, QRp.p
T = _qreltype(TA)
return similar_type(A, T, $(SizeQ))(Q0),
similar_type(A, T, $(SizeR))(R0),
Expand All @@ -40,14 +41,13 @@ _qreltype(::Type{T}) where T = typeof(zero(T)/sqrt(abs2(one(T))))
if (sA[1]*sA[1] + sA[1]*sA[2])÷2 * diagsize(Size(A)) < 17*17*17
return quote
@_inline_meta
return qr_unrolled(Size(A), A, pivot, thin)
return qr_unrolled(Size(A), A, pivot)
end
else
return quote
@_inline_meta
Q0R0 = qr(Matrix(A), pivot)
Q0 = Q0R0.Q
R0 = Q0R0.R
Q0, R0 = Matrix(Q0R0.Q), Q0R0.R
T = _qreltype(TA)
return similar_type(A, T, $(SizeQ))(Q0),
similar_type(A, T, $(SizeR))(R0)
Expand All @@ -64,7 +64,7 @@ end
# in the case of `thin=false` Q is full, but R is still reduced, see [`qr`](@ref).
#
# For original source code see below.
@generated function qr_unrolled(::Size{sA}, A::StaticMatrix{<:Any, <:Any, TA}, pivot::Val{false}, thin::Union{Val{false}, Val{true}} = Val(true)) where {sA, TA}
@generated function qr_unrolled(::Size{sA}, A::StaticMatrix{<:Any, <:Any, TA}, pivot::Val{false}) where {sA, TA}
m, n = sA[1], sA[2]

Q = [Symbol("Q_$(i)_$(j)") for i = 1:m, j = 1:m]
Expand Down Expand Up @@ -124,11 +124,7 @@ end
end

# truncate Q and R sizes in LAPACK consilient way
if thin == Val(true)
mQ, nQ = m, min(m, n)
else
mQ, nQ = m, m
end
mQ, nQ = m, min(m, n)
mR, nR = min(m, n), n

return quote
Expand Down
29 changes: 4 additions & 25 deletions test/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,39 +13,20 @@ srand(42)

T = eltype(arr)

# thin=true case
QR = @inferred qr(arr)
@test QR isa Tuple
@test length(QR) == 2
Q, R = QR
@test QR isa StaticArrays.QR
Q, R = QR # destructing via iteration
@test Q isa StaticMatrix
@test R isa StaticMatrix
@test eltype(Q) == eltype(R) == typeof((one(T)*zero(T) + zero(T))/norm([one(T)]))

Q_ref,R_ref = qr(Matrix(arr))
@test abs.(Q) abs.(Q_ref) # QR is unique up to diag(Q) signs
Q_ref, R_ref = qr(Matrix(arr))
@test abs.(Q) abs.(Matrix(Q_ref)) # QR is unique up to diag(Q) signs
@test abs.(R) abs.(R_ref)
@test Q*R arr
@test Q'*Q one(Q'*Q)
@test istriu(R)

# fat (thin=false) case
QR = @inferred qr(arr, Val(false), Val(false))
@test QR isa Tuple
@test length(QR) == 2
Q, R = QR
@test Q isa StaticMatrix
@test R isa StaticMatrix
@test eltype(Q) == eltype(R) == typeof((one(T)*zero(T) + zero(T))/norm([one(T)]))

Q_ref,R_ref = qr(Matrix(arr))
@test abs.(Q) abs.(Q_ref) # QR is unique up to diag(Q) signs
@test abs.(R) abs.(R_ref)
R0 = vcat(R, @SMatrix(zeros(size(arr)[1]-size(R)[1], size(R)[2])) )
@test Q*R0 arr
@test Q'*Q one(Q'*Q)
@test istriu(R)

# # pivot=true cases are not released yet
# pivot = Val(true)
# QRp = @inferred qr(arr, pivot)
Expand All @@ -62,8 +43,6 @@ srand(42)
# @test p == p_ref
end

@test_throws ArgumentError qr(@SMatrix randn(1,2); thin=false)

for eltya in (Float32, Float64, BigFloat, Int),
rel in (real, complex),
sz in [(3,3), (3,4), (4,3)]
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ include("expm.jl")
include("sqrtm.jl")
include("lyap.jl")
include("lu.jl")
# srand(42); include("qr.jl") # hmm ...
srand(42); include("qr.jl")
srand(42); include("chol.jl") # hermitian_type(::Type{Any}) for block algorithm
include("deque.jl")
include("io.jl")
Expand Down

0 comments on commit 38fcc3a

Please sign in to comment.