Skip to content

Commit

Permalink
Improved tridiagonal and new Woodbury matrix algebra
Browse files Browse the repository at this point in the history
- Converts Tridiagonal to being Lapack-compatible
- Wraps Lapack's major tridiagonal routines
- Adds a new "Woodbury" type for solving equations using the Woodbury matrix identity
- Moves the functionality into appropriate functions in base
- Provides a set of tests for this functionality
  • Loading branch information
timholy committed Sep 5, 2012
1 parent 2a011e3 commit cbaad36
Show file tree
Hide file tree
Showing 10 changed files with 594 additions and 108 deletions.
6 changes: 6 additions & 0 deletions base/export.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,20 @@ export
SubOrDArray,
SubString,
TransformedString,
Tridiagonal,
VecOrMat,
Vector,
VersionNumber,
WeakKeyDict,
Woodbury,
Zip,
Stat,
Factorization,
Cholesky,
LU,
LUTridiagonal,
LDLT,
LDLTTridiagonal,
QR,
QRP,

Expand Down Expand Up @@ -634,6 +639,7 @@ export
randsym,
rank,
rref,
solve,
svd,
svdvals,
trace,
Expand Down
45 changes: 45 additions & 0 deletions base/factorizations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -281,3 +281,48 @@ end
##ToDo: Add methods for rank(A::QRP{T}) and adjust the (\) method accordingly
## Add rcond methods for Cholesky, LU, QR and QRP types
## Lower priority: Add LQ, QL and RQ factorizations


#### Factorizations for Tridiagonal ####
type LDLTTridiagonal{T} <: Factorization{T}
D::Vector{T}
E::Vector{T}
end
function LDLTTridiagonal{T<:LapackScalar}(A::Tridiagonal{T})
D = copy(A.d)
E = copy(A.dl)
_jl_lapack_pttrf(D, E)
LDLTTridiagonal(D, E)
end
LDLT(A::Tridiagonal) = LDLTTridiagonal(A)

(\){T<:LapackScalar}(C::LDLTTridiagonal{T}, B::StridedVecOrMat{T}) =
_jl_lapack_pttrs(C.D, C.E, copy(B))

type LUTridiagonal{T} <: Factorization{T}
lu::Tridiagonal{T}
ipiv::Vector{Int32}
function LUTridiagonal(lu::Tridiagonal{T}, ipiv::Vector{Int32})
m, n = size(lu)
m == numel(ipiv) ? new(lu, ipiv) : error("LU: dimension mismatch")
end
end
show(io, lu::LUTridiagonal) = print(io, "LU decomposition of ", summary(lu.lu))

function LU{T<:LapackScalar}(A::Tridiagonal{T})
lu, ipiv = _jl_lapack_gttrf(copy(A))
LUTridiagonal{T}(lu, ipiv)
end

function lu(A::Tridiagonal)
error("lu(A) is not defined when A is Tridiagonal. Use LU(A) instead.")
end

function det(lu::LUTridiagonal)
prod(lu.lu.d) * (bool(sum(lu.ipiv .!= 1:n) % 2) ? -1 : 1)
end

det(A::Tridiagonal) = det(LU(A))

(\){T<:LapackScalar}(lu::LUTridiagonal{T}, B::StridedVecOrMat{T}) =
_jl_lapack_gttrs('N', lu.lu, lu.ipiv, copy(B))
4 changes: 3 additions & 1 deletion base/linalg.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
## linalg.jl: Basic Linear Algebra interface specifications ##
## linalg.jl: Basic Linear Algebra interface specifications and
## specialized matrix types

#
# This file mostly contains commented functions which are supposed
# to be defined in type-specific linalg_<type>.jl files.
Expand Down
184 changes: 184 additions & 0 deletions base/linalg_lapack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -885,3 +885,187 @@ end
expm{T<:Union(Float32,Float64,Complex64,Complex128)}(A::StridedMatrix{T}) = expm!(copy(A))
expm{T<:Integer}(A::StridedMatrix{T}) = expm!(float(A))


#### Tridiagonal matrix routines ####
function \{T<:LapackScalar}(M::Tridiagonal{T}, rhs::StridedVecOrMat{T})
if stride(rhs, 1) == 1
x = copy(rhs)
Mc = copy(M)
Mlu, x = _jl_lapack_gtsv(Mc, x)
return x
end
solve(M, rhs) # use the Julia "fallback"
end

eig(M::Tridiagonal) = _jl_lapack_stev('V', copy(M))

# Decompositions
for (gttrf, pttrf, elty) in
((:dgttrf_,:dpttrf_,:Float64),
(:sgttrf_,:spttrf_,:Float32),
(:zgttrf_,:zpttrf_,:Complex128),
(:cgttrf_,:cpttrf_,:Complex64))
@eval begin
function _jl_lapack_gttrf(M::Tridiagonal{$elty})
info = zero(Int32)
n = int32(length(M.d))
ipiv = Array(Int32, n)
ccall(dlsym(_jl_liblapack, $string(gttrf)),
Void,
(Ptr{Int32}, Ptr{$elty}, Ptr{$elty}, Ptr{$elty}, Ptr{$elty},
Ptr{Int32}, Ptr{Int32}),
&n, M.dl, M.d, M.du, M.dutmp, ipiv, &info)
if info != 0 throw(LapackException(info)) end
M, ipiv
end
function _jl_lapack_pttrf(D::Vector{$elty}, E::Vector{$elty})
info = zero(Int32)
n = int32(length(D))
if length(E) != n-1
error("subdiagonal must be one element shorter than diagonal")
end
ccall(dlsym(_jl_liblapack, $string(pttrf)),
Void,
(Ptr{Int32}, Ptr{$elty}, Ptr{$elty}, Ptr{Int32}),
&n, D, E, &info)
if info != 0 throw(LapackException(info)) end
D, E
end
end
end
# Direct solvers
for (gtsv, ptsv, elty) in
((:dgtsv_,:dptsv_,:Float64),
(:sgtsv_,:sptsv,:Float32),
(:zgtsv_,:zptsv,:Complex128),
(:cgtsv_,:cptsv,:Complex64))
@eval begin
function _jl_lapack_gtsv(M::Tridiagonal{$elty}, B::StridedVecOrMat{$elty})
if stride(B,1) != 1
error("_jl_lapack_gtsv: matrix columns must have contiguous elements");
end
info = zero(Int32)
n = int32(length(M.d))
nrhs = int32(size(B, 2))
ldb = int32(stride(B, 2))
ccall(dlsym(_jl_liblapack, $string(gtsv)),
Void,
(Ptr{Int32}, Ptr{Int32}, Ptr{$elty}, Ptr{$elty}, Ptr{$elty}, Ptr{$elty},
Ptr{Int32}, Ptr{Int32}),
&n, &nrhs, M.dl, M.d, M.du, B, &ldb, &info)
if info != 0 throw(LapackException(info)) end
M, B
end
function _jl_lapack_ptsv(M::Tridiagonal{$elty}, B::StridedVecOrMat{$elty})
if stride(B,1) != 1
error("_jl_lapack_ptsv: matrix columns must have contiguous elements");
end
info = zero(Int32)
n = int32(length(M.d))
nrhs = int32(size(B, 2))
ldb = int32(stride(B, 2))
ccall(dlsym(_jl_liblapack, $string(ptsv)),
Void,
(Ptr{Int32}, Ptr{Int32}, Ptr{$elty}, Ptr{$elty}, Ptr{$elty},
Ptr{Int32}, Ptr{Int32}),
&n, &nrhs, M.d, M.dl, B, &ldb, &info)
if info != 0 throw(LapackException(info)) end
M, B
end
end
end
# Solvers using decompositions
for (gttrs, pttrs, elty) in
((:dgttrs_,:dpttrs_,:Float64),
(:sgttrs_,:spttrs,:Float32),
(:zgttrs_,:zpttrs,:Complex128),
(:cgttrs_,:cpttrs,:Complex64))
@eval begin
function _jl_lapack_gttrs(trans::LapackChar, M::Tridiagonal{$elty}, ipiv::Vector{Int32}, B::StridedVecOrMat{$elty})
if stride(B,1) != 1
error("_jl_lapack_gttrs: matrix columns must have contiguous elements");
end
info = zero(Int32)
n = int32(length(M.d))
nrhs = int32(size(B, 2))
ldb = int32(stride(B, 2))
ccall(dlsym(_jl_liblapack, $string(gttrs)),
Void,
(Ptr{Uint8}, Ptr{Int32}, Ptr{Int32},
Ptr{$elty}, Ptr{$elty}, Ptr{$elty}, Ptr{$elty},
Ptr{Int32}, Ptr{$elty}, Ptr{Int32}, Ptr{Int32}),
&trans, &n, &nrhs, M.dl, M.d, M.du, M.dutmp, ipiv, B, &ldb, &info)
if info != 0 throw(LapackException(info)) end
B
end
function _jl_lapack_pttrs(D::Vector{$elty}, E::Vector{$elty}, B::StridedVecOrMat{$elty})
if stride(B,1) != 1
error("_jl_lapack_pttrs: matrix columns must have contiguous elements");
end
info = zero(Int32)
n = int32(length(D))
if length(E) != n-1
error("subdiagonal must be one element shorter than diagonal")
end
nrhs = int32(size(B, 2))
ldb = int32(stride(B, 2))
ccall(dlsym(_jl_liblapack, $string(pttrs)),
Void,
(Ptr{Int32}, Ptr{Int32}, Ptr{$elty}, Ptr{$elty}, Ptr{$elty},
Ptr{Int32}, Ptr{Int32}),
&n, &nrhs, D, E, B, &ldb, &info)
if info != 0 throw(LapackException(info)) end
B
end
end
end
# Eigenvalue-eigenvector (symmetric only)
for (stev, elty) in
((:dstev_,:Float64),
(:sstev_,:Float32),
(:zstev_,:Complex128),
(:cstev_,:Complex64))
@eval begin
function _jl_lapack_stev(Z::Array, M::Tridiagonal{$elty})
n = int32(length(M.d))
if isempty(Z)
job = 'N'
ldz = 1
work = Array($elty, 0)
Ztmp = work
else
if stride(Z,1) != 1
error("_jl_lapack_stev: eigenvector matrix columns must have contiguous elements");
end
if size(Z, 1) != n
error("_jl_lapack_stev: eigenvector matrix columns are not of the correct size")
end
Ztmp = Z
job = 'V'
ldz = int32(stride(Z, 2))
work = Array($elty, max(1, 2*n-2))
end
info = zero(Int32)
ccall(dlsym(_jl_liblapack, $string(stev)),
Void,
(Ptr{Uint8}, Ptr{Int32},
Ptr{$elty}, Ptr{$elty}, Ptr{$elty},
Ptr{Int32}, Ptr{$elty}, Ptr{Int32}),
&job, &n, M.d, M.dl, Ztmp, &ldz, work, &info)
if info != 0 throw(LapackException(info)) end
M.d
end
end
end
function _jl_lapack_stev(job::LapackChar, M::Tridiagonal)
if job == 'N' || job == 'n'
Z = []
elseif job == 'V' || job == 'v'
n = length(M.d)
Z = Array(eltype(M), n, n)
else
error("Job type not recognized")
end
D = _jl_lapack_stev(Z, M)
return D, Z
end
Loading

0 comments on commit cbaad36

Please sign in to comment.