Skip to content

Commit

Permalink
Add getindex method to LU{Triangular} for extracting factors
Browse files Browse the repository at this point in the history
  • Loading branch information
andreasnoack committed Jul 13, 2015
1 parent e3dfa56 commit 777b81d
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 12 deletions.
41 changes: 30 additions & 11 deletions base/linalg/lu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,23 +113,18 @@ function ipiv2perm{T}(v::AbstractVector{T}, maxi::Integer)
return p
end

function getindex{T,S<:StridedMatrix}(A::LU{T,S}, d::Symbol)
m, n = size(A)
function getindex{T,S<:StridedMatrix}(F::LU{T,S}, d::Symbol)
m, n = size(F)
if d == :L
L = tril!(A.factors[1:m, 1:min(m,n)])
L = tril!(F.factors[1:m, 1:min(m,n)])
for i = 1:min(m,n); L[i,i] = one(T); end
return L
elseif d == :U
return triu!(A.factors[1:min(m,n), 1:n])
return triu!(F.factors[1:min(m,n), 1:n])
elseif d == :p
return ipiv2perm(A.ipiv, m)
return ipiv2perm(F.ipiv, m)
elseif d == :P
p = A[:p]
P = zeros(T, m, m)
for i in 1:m
P[i,p[i]] = one(T)
end
return P
return eye(T, m)[:,invperm(F[:p])]
else
throw(KeyError(d))
end
Expand Down Expand Up @@ -263,6 +258,30 @@ end

factorize(A::Tridiagonal) = lufact(A)

function getindex{T}(F::Base.LinAlg.LU{T,Tridiagonal{T}}, d::Symbol)
m, n = size(F)
if d == :L
L = full(Bidiagonal(ones(T, n), F.factors.dl, false))
for i = 2:n
tmp = L[F.ipiv[i], 1:i - 1]
L[F.ipiv[i], 1:i - 1] = L[i, 1:i - 1]
L[i, 1:i - 1] = tmp
end
return L
elseif d == :U
U = full(Bidiagonal(F.factors.d, F.factors.du, true))
for i = 1:n - 2
U[i,i + 2] = F.factors.du2[i]
end
return U
elseif d == :p
return ipiv2perm(F.ipiv, m)
elseif d == :P
return eye(T, m)[:,invperm(F[:p])]
end
throw(KeyError(d))
end

# See dgtts2.f
function A_ldiv_B!{T}(A::LU{T,Tridiagonal{T}}, B::AbstractVecOrMat)
n = size(A,1)
Expand Down
5 changes: 4 additions & 1 deletion test/linalg/lu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

debug = false

using Base.Test

n = 10

# Split n into 2 parts for tests needing two matrices
Expand Down Expand Up @@ -57,6 +59,8 @@ debug && println("(Automatic) Square LU decomposition")
debug && println("Tridiagonal LU")
κd = cond(full(d),1)
lud = lufact(d)
@test_approx_eq lud[:L]*lud[:U] lud[:P]*full(d)
@test_approx_eq lud[:L]*lud[:U] full(d)[lud[:p],:]
@test norm(d*(lud\b) - b, 1) < ε*κd*n*2 # Two because the right hand side has two columns
if eltya <: Real
@test norm((lud.'\b) - full(d.')\b, 1) < ε*κd*n*2 # Two because the right hand side has two columns
Expand Down Expand Up @@ -112,6 +116,5 @@ for elty in (Float32, Float64, Complex64, Complex128)
# @test norm(F[:vectors]*Diagonal(F[:values])/F[:vectors] - A) > 0.01
end


@test @inferred(logdet(Complex64[1.0f0 0.5f0; 0.5f0 -1.0f0])) === 0.22314355f0 + 3.1415927f0im
@test_throws DomainError logdet([1 1; 1 -1])

0 comments on commit 777b81d

Please sign in to comment.