Skip to content

Commit

Permalink
Provide sqrtm of Triangular matrices (ref #4006)
Browse files Browse the repository at this point in the history
  • Loading branch information
jiahao committed Jan 26, 2014
1 parent 89c8cce commit 9bd4618
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 28 deletions.
30 changes: 3 additions & 27 deletions base/linalg/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -292,41 +292,17 @@ function sqrtm{T<:Real}(A::StridedMatrix{T}, cond::Bool)

n = chksquare(A)
SchurF = schurfact!(complex(A))
R = zeros(eltype(SchurF[:T]), n, n)
for j = 1:n
R[j,j] = sqrt(SchurF[:T][j,j])
for i = j - 1:-1:1
r = SchurF[:T][i,j]
for k = i + 1:j - 1
r -= R[i,k]*R[k,j]
end
if r != 0
R[i,j] = r / (R[i,i] + R[j,j])
end
end
end
R = full(sqrtm(Triangular(SchurF[:T])))
retmat = SchurF[:vectors]*R*SchurF[:vectors]'
retmat2= all(imag(retmat) .== 0) ? real(retmat) : retmat
cond ? (retmat2, alpha) : retmat2
cond ? (retmat2, norm(R)^2/norm(SchurF[:T])) : retmat2
end
function sqrtm{T<:Complex}(A::StridedMatrix{T}, cond::Bool)
ishermitian(A) && return sqrtm(Hermitian(A), cond)

n = chksquare(A)
SchurF = schurfact(A)
R = zeros(eltype(SchurF[:T]), n, n)
for j = 1:n
R[j,j] = sqrt(SchurF[:T][j,j])
for i = j - 1:-1:1
r = SchurF[:T][i,j]
for k = i + 1:j - 1
r -= R[i,k]*R[k,j]
end
if r != 0
R[i,j] = r / (R[i,i] + R[j,j])
end
end
end
R = full(sqrtm(Triangular(SchurF[:T])))
retmat = SchurF[:vectors]*R*SchurF[:vectors]'
cond ? (retmat, norm(R)^2/norm(SchurF[:T])) : retmat
end
Expand Down
2 changes: 1 addition & 1 deletion base/linalg/symmetric.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
## Hermitian matrices
#Symmetric and Hermitian matrices

for ty in (:Hermitian, :Symmetric)
@eval begin
Expand Down
20 changes: 20 additions & 0 deletions base/linalg/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,26 @@ for func in (:*, :Ac_mul_B, :A_mul_Bc, :/, :A_rdiv_Bc)
end
end

function sqrtm{T}(A::Triangular{T})
n = size(A, 1)
R = zeros(T, n, n)
if A.uplo == 'U'
for j = 1:n
(T<:Complex || A[j,j]>=0) ? (R[j,j]=sqrt(A[j,j])) : throw(SingularException(j))
for i = j-1:-1:1
r = A[i,j]
for k = i+1:j-1
r -= R[i,k]*R[k,j]
end
r==0 || (R[i,j] = r / (R[i,i] + R[j,j]))
end
end
return Triangular(R)
else #A.uplo == 'L' #Not the usual case
return sqrtm(A.').'
end
end

#Generic solver using naive substitution
function naivesub!(A::Triangular, b::AbstractVector, x::AbstractVector=b)
N = size(A, 2)
Expand Down

0 comments on commit 9bd4618

Please sign in to comment.