From 38fcc3ada076b629120fa9a6822ee6ddd2687924 Mon Sep 17 00:00:00 2001 From: Fredrik Ekre Date: Thu, 28 Jun 2018 18:44:25 +0200 Subject: [PATCH] QR fixes --- src/qr.jl | 48 ++++++++++++++++++++++-------------------------- test/qr.jl | 29 ++++------------------------- test/runtests.jl | 2 +- 3 files changed, 27 insertions(+), 52 deletions(-) diff --git a/src/qr.jl b/src/qr.jl index 6c31fde4f..822e8684f 100644 --- a/src/qr.jl +++ b/src/qr.jl @@ -1,36 +1,37 @@ -_thin_must_hold(thin) = - thin || throw(ArgumentError("For the sake of type stability, `thin = true` must hold.")) +# define our own struct since LinearAlgebra.QR are restricted to Matrix +struct QR{Q,R} + Q::Q + R::R +end + +# iteration for destructuring into components +Base.iterate(S::QR) = (S.Q, Val(:R)) +Base.iterate(S::QR, ::Val{:R}) = (S.R, Val(:done)) +Base.iterate(S::QR, ::Val{:done}) = nothing """ - qr(A::StaticMatrix, pivot=Val{false}; thin=true) -> Q, R, [p] + qr(A::StaticMatrix, pivot=Val(false)) -> Q, R, [p] Compute the QR factorization of `A` such that `A = Q*R` or `A[:,p] = Q*R`, see [`qr`](@ref). -This function does not support `thin=false` keyword option due to type inference instability. -To use this option call `qr(A, pivot, Val{false})` instead. """ -@inline function qr(A::StaticMatrix, pivot::Union{Val{false}, Val{true}} = Val(false); thin::Bool=true) - _thin_must_hold(thin) - return _qr(Size(A), A, pivot, Val(true)) +@inline function qr(A::StaticMatrix, pivot::Union{Val{false}, Val{true}} = Val(false)) + Q, R = _qr(Size(A), A, pivot) + return QR(Q, R) end - -@inline qr(A::StaticMatrix, pivot::Union{Val{false}, Val{true}}, thin::Union{Val{false}, Val{true}}) = _qr(Size(A), A, pivot, thin) - - _qreltype(::Type{T}) where T = typeof(zero(T)/sqrt(abs2(one(T)))) -@generated function _qr(::Size{sA}, A::StaticMatrix{<:Any, <:Any, TA}, pivot::Union{Val{false}, Val{true}} = Val(false), thin::Union{Val{false}, Val{true}} = Val(true)) where {sA, TA} - - isthin = thin == Val(true) +@generated function _qr(::Size{sA}, A::StaticMatrix{<:Any, <:Any, TA}, pivot::Union{Val{false}, Val{true}} = Val(false)) where {sA, TA} - SizeQ = Size( sA[1], isthin ? diagsize(Size(A)) : sA[1] ) + SizeQ = Size( sA[1], diagsize(Size(A)) ) SizeR = Size( diagsize(Size(A)), sA[2] ) if pivot == Val(true) return quote @_inline_meta - Q0, R0, p0 = qr(Matrix(A), pivot, thin=$isthin) + QRp = qr(Matrix(A), pivot) + Q0, R0, p0 = QRp.Q, QRp.R, QRp.p T = _qreltype(TA) return similar_type(A, T, $(SizeQ))(Q0), similar_type(A, T, $(SizeR))(R0), @@ -40,14 +41,13 @@ _qreltype(::Type{T}) where T = typeof(zero(T)/sqrt(abs2(one(T)))) if (sA[1]*sA[1] + sA[1]*sA[2])÷2 * diagsize(Size(A)) < 17*17*17 return quote @_inline_meta - return qr_unrolled(Size(A), A, pivot, thin) + return qr_unrolled(Size(A), A, pivot) end else return quote @_inline_meta Q0R0 = qr(Matrix(A), pivot) - Q0 = Q0R0.Q - R0 = Q0R0.R + Q0, R0 = Matrix(Q0R0.Q), Q0R0.R T = _qreltype(TA) return similar_type(A, T, $(SizeQ))(Q0), similar_type(A, T, $(SizeR))(R0) @@ -64,7 +64,7 @@ end # in the case of `thin=false` Q is full, but R is still reduced, see [`qr`](@ref). # # For original source code see below. -@generated function qr_unrolled(::Size{sA}, A::StaticMatrix{<:Any, <:Any, TA}, pivot::Val{false}, thin::Union{Val{false}, Val{true}} = Val(true)) where {sA, TA} +@generated function qr_unrolled(::Size{sA}, A::StaticMatrix{<:Any, <:Any, TA}, pivot::Val{false}) where {sA, TA} m, n = sA[1], sA[2] Q = [Symbol("Q_$(i)_$(j)") for i = 1:m, j = 1:m] @@ -124,11 +124,7 @@ end end # truncate Q and R sizes in LAPACK consilient way - if thin == Val(true) - mQ, nQ = m, min(m, n) - else - mQ, nQ = m, m - end + mQ, nQ = m, min(m, n) mR, nR = min(m, n), n return quote diff --git a/test/qr.jl b/test/qr.jl index c68db8d02..896b6047b 100644 --- a/test/qr.jl +++ b/test/qr.jl @@ -13,39 +13,20 @@ srand(42) T = eltype(arr) - # thin=true case QR = @inferred qr(arr) - @test QR isa Tuple - @test length(QR) == 2 - Q, R = QR + @test QR isa StaticArrays.QR + Q, R = QR # destructing via iteration @test Q isa StaticMatrix @test R isa StaticMatrix @test eltype(Q) == eltype(R) == typeof((one(T)*zero(T) + zero(T))/norm([one(T)])) - Q_ref,R_ref = qr(Matrix(arr)) - @test abs.(Q) ≈ abs.(Q_ref) # QR is unique up to diag(Q) signs + Q_ref, R_ref = qr(Matrix(arr)) + @test abs.(Q) ≈ abs.(Matrix(Q_ref)) # QR is unique up to diag(Q) signs @test abs.(R) ≈ abs.(R_ref) @test Q*R ≈ arr @test Q'*Q ≈ one(Q'*Q) @test istriu(R) - # fat (thin=false) case - QR = @inferred qr(arr, Val(false), Val(false)) - @test QR isa Tuple - @test length(QR) == 2 - Q, R = QR - @test Q isa StaticMatrix - @test R isa StaticMatrix - @test eltype(Q) == eltype(R) == typeof((one(T)*zero(T) + zero(T))/norm([one(T)])) - - Q_ref,R_ref = qr(Matrix(arr)) - @test abs.(Q) ≈ abs.(Q_ref) # QR is unique up to diag(Q) signs - @test abs.(R) ≈ abs.(R_ref) - R0 = vcat(R, @SMatrix(zeros(size(arr)[1]-size(R)[1], size(R)[2])) ) - @test Q*R0 ≈ arr - @test Q'*Q ≈ one(Q'*Q) - @test istriu(R) - # # pivot=true cases are not released yet # pivot = Val(true) # QRp = @inferred qr(arr, pivot) @@ -62,8 +43,6 @@ srand(42) # @test p == p_ref end - @test_throws ArgumentError qr(@SMatrix randn(1,2); thin=false) - for eltya in (Float32, Float64, BigFloat, Int), rel in (real, complex), sz in [(3,3), (3,4), (4,3)] diff --git a/test/runtests.jl b/test/runtests.jl index ffa81c014..4cdb06a42 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -37,7 +37,7 @@ include("expm.jl") include("sqrtm.jl") include("lyap.jl") include("lu.jl") -# srand(42); include("qr.jl") # hmm ... +srand(42); include("qr.jl") srand(42); include("chol.jl") # hermitian_type(::Type{Any}) for block algorithm include("deque.jl") include("io.jl")