Skip to content

Commit

Permalink
Merge pull request #3225 from JuliaLang/anj/geig
Browse files Browse the repository at this point in the history
Add generalized eigenvalue decomposition
  • Loading branch information
andreasnoack committed May 28, 2013
2 parents bd8c04b + 2d16535 commit 9a46e26
Show file tree
Hide file tree
Showing 5 changed files with 333 additions and 6 deletions.
68 changes: 67 additions & 1 deletion base/linalg/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ function eigfact!{T<:BlasReal}(A::StridedMatrix{T})

WR, WI, VL, VR = LAPACK.geev!('N', 'V', A)
if all(WI .== 0.) return Eigen(WR, VR) end
evec = complex(zeros(T, n, n))
evec = zeros(Complex{T}, n, n)
j = 1
while j <= n
if WI[j] == 0.0
Expand Down Expand Up @@ -476,6 +476,72 @@ end
inv(A::Eigen) = scale(A.vectors, 1.0/A.values)*A.vectors'
det(A::Eigen) = prod(A.values)

# Generalized eigenvalue problem.
type GeneralizedEigen{T,V}
values::Vector{V}
vectors::Matrix{T}
end

function getindex(A::GeneralizedEigen, d::Symbol)
if d == :values return A.values end
if d == :vectors return A.vectors end
error("No such type field")
end

function eigfact!{T<:BlasReal}(A::StridedMatrix{T}, B::StridedMatrix{T})
if ishermitian(A) & ishermitian(B) return eigfact!(Hermitian(A), Hermitian(B)) end
n = size(A, 1)
alphar, alphai, beta, ~, vr = LAPACK.ggev!('N', 'V', A, B)
if all(alphai .== 0)
return GeneralizedEigen(alphar ./ beta, vr)
else
vecs = zeros(Complex{T}, n, n)
j = 1
while j <= n
if alphai[j] == 0.0
vecs[:,j] = vr[:,j]
else
vecs[:,j] = vr[:,j] + im*vr[:,j+1]
vecs[:,j+1] = vr[:,j] - im*vr[:,j+1]
j += 1
end
j += 1
end
return GeneralizedEigen(complex(alphar, alphai)./beta, vecs)
end
end
function eigfact!{T<:BlasComplex}(A::StridedMatrix{T}, B::StridedMatrix{T})
if ishermitian(A) & ishermitian(B) return eigfact!(Hermitian(A), Hermitian(B)) end
alpha, beta, ~, vr = LAPACK.ggev!('N', 'V', A, B)
return GeneralizedEigen(alpha./beta, vr)
end
eigfact!(A::StridedMatrix, B::StridedMatrix) = eigfact!(float(A), float(B))
eigfact{T<:BlasFloat}(A::StridedMatrix{T}, B::StridedMatrix{T}) = eigfact!(copy(A), copy(B))
eigfact(A::StridedMatrix, B::StridedMatrix) = eigfact!(float(A), float(B))

function eig(A::StridedMatrix, B::StridedMatrix)
F = eigfact(A, B)
return F[:values], F[:vectors]
end

function eigvals!{T<:BlasReal}(A::StridedMatrix{T}, B::StridedMatrix{T})
if ishermitian(A) & ishermitian(B) return eigvals!(Hermitian(A), Hermitian(B)) end
alphar, alphai, beta, vl, vr = LAPACK.ggev!('N', 'N', A, B)
if all(alphai .== 0)
return alphar./beta
else
return complex(alphar, alphai)./beta
end
end
function eigvals!{T<:BlasComplex}(A::StridedMatrix{T}, B::StridedMatrix{T})
if ishermitian(A) & ishermitian(B) return eigvals!(Hermitian(A), Hermitian(B)) end
alpha, beta, vl, vr = LAPACK.ggev!('N', 'N', A, B)
return alpha./beta
end
eigvals!(A::AbstractMatrix, B::AbstractMatrix) = eigvals!(float(A), float(B))
eigvals{T<:BlasFloat}(A::AbstractMatrix{T}, B::AbstractMatrix{T}) = eigvals!(copy(A), copy(B))
eigvals(A::AbstractMatrix, B::AbstractMatrix) = eigvals!(float(A), float(B))

# SVD
type SVD{T<:BlasFloat,Tr} <: Factorization{T}
U::Matrix{T}
Expand Down
10 changes: 9 additions & 1 deletion base/linalg/hermitian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ end
Hermitian{T<:BlasFloat}(S::Matrix{T}, uplo::Char) = Hermitian{T}(S, uplo)
Hermitian(A::StridedMatrix) = Hermitian(A, 'U')

copy(A::Hermitian) = Hermitian(copy(A.S), A.uplo)
size(A::Hermitian, args...) = size(A.S, args...)
print_matrix(io::IO, A::Hermitian) = print_matrix(io, full(A))
full(A::Hermitian) = A.S
Expand All @@ -32,13 +33,20 @@ end
inv(A::Hermitian) = inv(BunchKaufman(copy(A.S), A.uplo))

eigfact!(A::Hermitian) = Eigen(LAPACK.syevr!('V', 'A', A.uplo, A.S, 0.0, 0.0, 0, 0, -1.0)...)
eigfact(A::Hermitian) = Eigen(LAPACK.syevr!('V', 'A', A.uplo, copy(A.S), 0.0, 0.0, 0, 0, -1.0)...)
eigfact(A::Hermitian) = eigfact!(copy(A))
eigvals(A::Hermitian, il::Int, ih::Int) = LAPACK.syevr!('N', 'I', A.uplo, copy(A.S), 0.0, 0.0, il, ih, -1.0)[1]
eigvals(A::Hermitian, vl::Real, vh::Real) = LAPACK.syevr!('N', 'V', A.uplo, copy(A.S), vl, vh, 0, 0, -1.0)[1]
eigvals(A::Hermitian) = eigvals(A, 1, size(A, 1))
eigmax(A::Hermitian) = eigvals(A, size(A, 1), size(A, 1))[1]
eigmin(A::Hermitian) = eigvals(A, 1, 1)[1]

function eigfact!(A::Hermitian, B::Hermitian)
vals, vecs, _ = LAPACK.sygvd!(1, 'V', A.uplo, A.S, B.uplo == A.uplo ? B.S : B.S')
return GeneralizedEigen(vals, vecs)
end
eigfact(A::Hermitian, B::Hermitian) = eigfact!(copy(A), copy(B))
eigvals!(A::Hermitian, B::Hermitian) = LAPACK.sygvd!(1, 'N', A.uplo, A.S, B.uplo == A.uplo ? B.S : B.S')[1]

function expm(A::Hermitian)
F = eigfact(A)
scale(F[:vectors], exp(F[:values])) * F[:vectors]'
Expand Down
236 changes: 234 additions & 2 deletions base/linalg/lapack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -845,6 +845,125 @@ for (geev, gesvd, gesdd, ggsvd, elty, relty) in
end
end
end
for (ggev, elty) in
((:dggev_,:Float64),
(:sggev_,:Float32))
@eval begin
# SUBROUTINE DGGEV( JOBVL, JOBVR, N, A, LDA, B, LDB, ALPHAR, ALPHAI,
# $ BETA, VL, LDVL, VR, LDVR, WORK, LWORK, INFO )
# *
# * -- LAPACK driver routine (version 3.2) --
# * -- LAPACK is a software package provided by Univ. of Tennessee, --
# * -- Univ. of California Berkeley, Univ. of Colorado Denver and NAG Ltd..--
# * November 2006
# *
# * .. Scalar Arguments ..
# CHARACTER JOBVL, JOBVR
# INTEGER INFO, LDA, LDB, LDVL, LDVR, LWORK, N
# * ..
# * .. Array Arguments ..
# DOUBLE PRECISION A( LDA, * ), ALPHAI( * ), ALPHAR( * ),
# $ B( LDB, * ), BETA( * ), VL( LDVL, * ),
# $ VR( LDVR, * ), WORK( * )
function ggev!(jobvl::BlasChar, jobvr::BlasChar, A::StridedMatrix{$elty}, B::StridedMatrix{$elty})
chkstride1(A,B)
n = size(A, 1)
if size(A, 2) != n | size(B, 1) != size(B, 2) throw(DimensionMismatch("Matrices must be square")) end
if size(B, 1) != n throw(DimensionMismatch("Matrices must have same size")) end
lda = max(1, n)
ldb = max(1, n)
alphar = Array($elty, n)
alphai = Array($elty, n)
beta = Array($elty, n)
ldvl = jobvl == 'V' ? n : 1
vl = Array($elty, ldvl, n)
ldvr = jobvr == 'V' ? n : 1
vr = Array($elty, ldvr, n)
work = Array($elty, 1)
lwork = -one(BlasInt)
info = Array(BlasInt, 1)
for i = 1:2
ccall(($(string(ggev)), liblapack), Void,
(Ptr{BlasChar}, Ptr{BlasChar}, Ptr{BlasInt}, Ptr{$elty},
Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty},
Ptr{$elty}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt},
Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt},
Ptr{BlasInt}),
&jobvl, &jobvr, &n, A,
&lda, B, &ldb, alphar,
alphai, beta, vl, &ldvl,
vr, &ldvr, work, &lwork,
info)
if i == 1
lwork = blas_int(work[1])
work = Array($elty, lwork)
end
end
if info[1] != 0; throw(LAPACKException(info[1])); end
return alphar, alphai, beta, vl, vr
end
end
end
for (ggev, elty, relty) in
((:zggev_,:Complex128,:Float64),
(:cggev_,:Complex64,:Float32))
@eval begin
# SUBROUTINE ZGGEV( JOBVL, JOBVR, N, A, LDA, B, LDB, ALPHA, BETA,
# $ VL, LDVL, VR, LDVR, WORK, LWORK, RWORK, INFO )
# *
# * -- LAPACK driver routine (version 3.2) --
# * -- LAPACK is a software package provided by Univ. of Tennessee, --
# * -- Univ. of California Berkeley, Univ. of Colorado Denver and NAG Ltd..--
# * November 2006
# *
# * .. Scalar Arguments ..
# CHARACTER JOBVL, JOBVR
# INTEGER INFO, LDA, LDB, LDVL, LDVR, LWORK, N
# * ..
# * .. Array Arguments ..
# DOUBLE PRECISION RWORK( * )
# COMPLEX*16 A( LDA, * ), ALPHA( * ), B( LDB, * ),
# $ BETA( * ), VL( LDVL, * ), VR( LDVR, * ),
# $ WORK( * )
function ggev!(jobvl::BlasChar, jobvr::BlasChar, A::StridedMatrix{$elty}, B::StridedMatrix{$elty})
chkstride1(A,B)
n = size(A, 1)
if size(A, 2) != n | size(B, 1) != size(B, 2) throw(DimensionMismatch("Matrices must be square")) end
if size(B, 1) != n throw(DimensionMismatch("Matrices must have same size")) end
lda = max(1, n)
ldb = max(1, n)
alpha = Array($elty, n)
beta = Array($elty, n)
ldvl = jobvl == 'V' ? n : 1
vl = Array($elty, ldvl, n)
ldvr = jobvr == 'V' ? n : 1
vr = Array($elty, ldvr, n)
work = Array($elty, 1)
lwork = -one(BlasInt)
rwork = Array($relty, 8n)
info = Array(BlasInt, 1)
for i = 1:2
ccall(($(string(ggev)), liblapack), Void,
(Ptr{BlasChar}, Ptr{BlasChar}, Ptr{BlasInt}, Ptr{$elty},
Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty},
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty},
Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$relty},
Ptr{BlasInt}),
&jobvl, &jobvr, &n, A,
&lda, B, &ldb, alpha,
beta, vl, &ldvl, vr,
&ldvr, work, &lwork, rwork,
info)
if i == 1
lwork = blas_int(real(work[1]))
work = Array($elty, lwork)
end
end
if info[1] != 0; throw(LAPACKException(info[1])); end
return alpha, beta, vl, vr
end
end
end

# (GT) General tridiagonal, decomposition, solver and direct solver
for (gtsv, gttrf, gttrs, elty) in
Expand Down Expand Up @@ -1684,8 +1803,121 @@ for (syconv, syev, sysv, sytrf, sytri, sytrs, elty, relty) in
end
end
end


for (sygvd, elty) in
((:dsygvd_,:Float64),
(:ssygvd_,:Float32))
@eval begin
# SUBROUTINE DSYGVD( ITYPE, JOBZ, UPLO, N, A, LDA, B, LDB, W, WORK,
# $ LWORK, IWORK, LIWORK, INFO )
# *
# * -- LAPACK driver routine (version 3.3.1) --
# * -- LAPACK is a software package provided by Univ. of Tennessee, --
# * -- Univ. of California Berkeley, Univ. of Colorado Denver and NAG Ltd..--
# * -- April 2011 --
# *
# * .. Scalar Arguments ..
# CHARACTER JOBZ, UPLO
# INTEGER INFO, ITYPE, LDA, LDB, LIWORK, LWORK, N
# * ..
# * .. Array Arguments ..
# INTEGER IWORK( * )
# DOUBLE PRECISION A( LDA, * ), B( LDB, * ), W( * ), WORK( * )
function sygvd!(itype::Integer, jobz::BlasChar, uplo::BlasChar, A::StridedMatrix{$elty}, B::StridedMatrix{$elty})
chkstride1(A,B)
n = size(A, 1)
if size(A, 2) != n | size(B, 1) != size(B, 2) throw(DimensionMismatch("Matrices must be square")) end
if size(B, 1) != n throw(DimensionMismatch("Matrices must have same size")) end
lda = max(1, n)
ldb = max(1, n)
w = Array($elty, n)
work = Array($elty, 1)
lwork = -one(BlasInt)
iwork = Array(BlasInt, 1)
liwork = -one(BlasInt)
info = Array(BlasInt, 1)
for i = 1:2
ccall(($(string(sygvd)),liblapack), Void,
(Ptr{BlasInt}, Ptr{BlasChar}, Ptr{BlasChar}, Ptr{BlasInt},
Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt},
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}, Ptr{BlasInt},
Ptr{BlasInt}, Ptr{BlasInt}),
&itype, &jobz, &uplo, &n,
A, &lda, B, &ldb,
w, work, &lwork, iwork,
&liwork, info)
if i == 1
lwork = blas_int(work[1])
work = Array($elty, lwork)
liwork = iwork[1]
iwork = Array(BlasInt, liwork)
end
end
if info[1] < 0 throw(LAPACKException(info[1])) end
if info[1] > 0 throw(SingularException(info[1])) end
return w, A, B
end
end
end
for (sygvd, elty, relty) in
((:zhegvd_,:Complex128,:Float64),
(:chegvd_,:Complex64,:Float32))
@eval begin
# SUBROUTINE ZHEGVD( ITYPE, JOBZ, UPLO, N, A, LDA, B, LDB, W, WORK,
# $ LWORK, RWORK, LRWORK, IWORK, LIWORK, INFO )
# *
# * -- LAPACK driver routine (version 3.3.1) --
# * -- LAPACK is a software package provided by Univ. of Tennessee, --
# * -- Univ. of California Berkeley, Univ. of Colorado Denver and NAG Ltd..--
# * -- April 2011 --
# *
# * .. Scalar Arguments ..
# CHARACTER JOBZ, UPLO
# INTEGER INFO, ITYPE, LDA, LDB, LIWORK, LRWORK, LWORK, N
# * ..
# * .. Array Arguments ..
# INTEGER IWORK( * )
# DOUBLE PRECISION RWORK( * ), W( * )
# COMPLEX*16 A( LDA, * ), B( LDB, * ), WORK( * )
function sygvd!(itype::Integer, jobz::BlasChar, uplo::BlasChar, A::StridedMatrix{$elty}, B::StridedMatrix{$elty})
chkstride1(A,B)
n = size(A, 1)
if size(A, 2) != n | size(B, 1) != size(B, 2) throw(DimensionMismatch("Matrices must be square")) end
if size(B, 1) != n throw(DimensionMismatch("Matrices must have same size")) end
lda = max(1, n)
ldb = max(1, n)
w = Array($relty, n)
work = Array($elty, 1)
lwork = -one(BlasInt)
iwork = Array(BlasInt, 1)
liwork = -one(BlasInt)
rwork = Array($relty)
lrwork = -one(BlasInt)
info = Array(BlasInt, 1)
for i = 1:2
ccall(($(string(sygvd)),liblapack), Void,
(Ptr{BlasInt}, Ptr{BlasChar}, Ptr{BlasChar}, Ptr{BlasInt},
Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt},
Ptr{$relty}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$relty},
Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
&itype, &jobz, &uplo, &n,
A, &lda, B, &ldb,
w, work, &lwork, rwork,
&lrwork, iwork, &liwork, info)
if i == 1
lwork = blas_int(real(work[1]))
work = Array($elty, lwork)
liwork = iwork[1]
iwork = Array(BlasInt, liwork)
lrwork = blas_int(rwork[1])
rwork = Array($relty, lrwork)
end
end
if info[1] < 0 throw(LAPACKException(info[1])) end
if info[1] > 0 throw(SingularException(info[1])) end
return w, A, B
end
end
end

#Find the leading dimension
ld = x->max(1,stride(x,2))
Expand Down
12 changes: 10 additions & 2 deletions doc/stdlib/linalg.rst
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ Linear algebra functions in Julia are largely implemented by calling functions f

Compute eigenvalues and eigenvectors of A

.. function:: eig(A, B) -> D, V

Compute generalized eigenvalues and vectors of A and B

.. function:: eigvals(A)

Returns the eigenvalues of ``A``.
Expand All @@ -115,9 +119,13 @@ Linear algebra functions in Julia are largely implemented by calling functions f

Compute the eigenvalue decomposition of ``A`` and return an ``Eigen`` object. If ``F`` is the factorization object, the eigenvalues can be accessed with ``F[:values]`` and the eigenvectors with ``F[:vectors]``. The following functions are available for ``Eigen`` objects: ``inv``, ``det``.

.. function:: eigfact!(A)
.. function:: eigfact(A, B)

Compute the generalized eigenvalue decomposition of ``A`` and ``B`` and return an ``GeneralizedEigen`` object. If ``F`` is the factorization object, the eigenvalues can be accessed with ``F[:values]`` and the eigenvectors with ``F[:vectors]``.

.. function:: eigfact!(A, [B])

``eigfact!`` is the same as ``eigfact`` but saves space by overwriting the input A, instead of creating a copy.
``eigfact!`` is the same as ``eigfact`` but saves space by overwriting the input A (and B), instead of creating a copy.

.. function:: hessfact(A)

Expand Down
Loading

0 comments on commit 9a46e26

Please sign in to comment.