Skip to content

Commit

Permalink
Merge pull request #12137 from JuliaLang/anj/lu
Browse files Browse the repository at this point in the history
Add getindex method to LU{Tridiagonal} for extracting factors (and full)
  • Loading branch information
andreasnoack committed Jul 14, 2015
2 parents e2b76ad + 857b367 commit 0c2d664
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 12 deletions.
76 changes: 65 additions & 11 deletions base/linalg/lu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,23 +113,18 @@ function ipiv2perm{T}(v::AbstractVector{T}, maxi::Integer)
return p
end

function getindex{T,S<:StridedMatrix}(A::LU{T,S}, d::Symbol)
m, n = size(A)
function getindex{T,S<:StridedMatrix}(F::LU{T,S}, d::Symbol)
m, n = size(F)
if d == :L
L = tril!(A.factors[1:m, 1:min(m,n)])
L = tril!(F.factors[1:m, 1:min(m,n)])
for i = 1:min(m,n); L[i,i] = one(T); end
return L
elseif d == :U
return triu!(A.factors[1:min(m,n), 1:n])
return triu!(F.factors[1:min(m,n), 1:n])
elseif d == :p
return ipiv2perm(A.ipiv, m)
return ipiv2perm(F.ipiv, m)
elseif d == :P
p = A[:p]
P = zeros(T, m, m)
for i in 1:m
P[i,p[i]] = one(T)
end
return P
return eye(T, m)[:,invperm(F[:p])]
else
throw(KeyError(d))
end
Expand Down Expand Up @@ -263,6 +258,65 @@ end

factorize(A::Tridiagonal) = lufact(A)

function getindex{T}(F::Base.LinAlg.LU{T,Tridiagonal{T}}, d::Symbol)
m, n = size(F)
if d == :L
L = full(Bidiagonal(ones(T, n), F.factors.dl, false))
for i = 2:n
tmp = L[F.ipiv[i], 1:i - 1]
L[F.ipiv[i], 1:i - 1] = L[i, 1:i - 1]
L[i, 1:i - 1] = tmp
end
return L
elseif d == :U
U = full(Bidiagonal(F.factors.d, F.factors.du, true))
for i = 1:n - 2
U[i,i + 2] = F.factors.du2[i]
end
return U
elseif d == :p
return ipiv2perm(F.ipiv, m)
elseif d == :P
return eye(T, m)[:,invperm(F[:p])]
end
throw(KeyError(d))
end

function full{T}(F::Base.LinAlg.LU{T,Tridiagonal{T}})
n = size(F, 1)

dl = copy(F.factors.dl)
d = copy(F.factors.d)
du = copy(F.factors.du)
du2 = copy(F.factors.du2)

for i = n - 1:-1:1
li = dl[i]
dl[i] = li*d[i]
d[i + 1] += li*du[i]
if i < n - 1
du[i + 1] += li*du2[i]
end

if F.ipiv[i] != i
tmp = dl[i]
dl[i] = d[i]
d[i] = tmp

tmp = d[i + 1]
d[i + 1] = du[i]
du[i] = tmp

if i < n - 1
tmp = du[i + 1]
du[i + 1] = du2[i]
du2[i] = tmp
end
end
end
return Tridiagonal(dl, d, du)
end

# See dgtts2.f
function A_ldiv_B!{T}(A::LU{T,Tridiagonal{T}}, B::AbstractVecOrMat)
n = size(A,1)
Expand Down
6 changes: 5 additions & 1 deletion test/linalg/lu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

debug = false

using Base.Test

n = 10

# Split n into 2 parts for tests needing two matrices
Expand Down Expand Up @@ -57,6 +59,9 @@ debug && println("(Automatic) Square LU decomposition")
debug && println("Tridiagonal LU")
κd = cond(full(d),1)
lud = lufact(d)
@test_approx_eq lud[:L]*lud[:U] lud[:P]*full(d)
@test_approx_eq lud[:L]*lud[:U] full(d)[lud[:p],:]
@test_approx_eq full(lud) d
@test norm(d*(lud\b) - b, 1) < ε*κd*n*2 # Two because the right hand side has two columns
if eltya <: Real
@test norm((lud.'\b) - full(d.')\b, 1) < ε*κd*n*2 # Two because the right hand side has two columns
Expand Down Expand Up @@ -112,6 +117,5 @@ for elty in (Float32, Float64, Complex64, Complex128)
# @test norm(F[:vectors]*Diagonal(F[:values])/F[:vectors] - A) > 0.01
end


@test @inferred(logdet(Complex64[1.0f0 0.5f0; 0.5f0 -1.0f0])) === 0.22314355f0 + 3.1415927f0im
@test_throws DomainError logdet([1 1; 1 -1])

0 comments on commit 0c2d664

Please sign in to comment.