Skip to content

Commit

Permalink
Add a fast method for diag of Cholesky matrices (#53767)
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel Karrasch <daniel.karrasch@posteo.de>
  • Loading branch information
2 people authored and pull[bot] committed Oct 1, 2024
1 parent 95872b9 commit c123952
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 1 deletion.
27 changes: 27 additions & 0 deletions stdlib/LinearAlgebra/src/cholesky.jl
Original file line number Diff line number Diff line change
Expand Up @@ -860,3 +860,30 @@ then `CC = cholesky(C.U'C.U - v*v')` but the computation of `CC` only uses
`O(n^2)` operations.
"""
lowrankdowndate(C::Cholesky, v::AbstractVector) = lowrankdowndate!(copy(C), copy(v))

function diag(C::Cholesky{T}, k::Int = 0) where {T}
N = size(C, 1)
absk = abs(k)
iabsk = N - absk
z = Vector{T}(undef, iabsk)
UL = C.factors
if C.uplo == 'U'
for i in 1:iabsk
z[i] = zero(T)
for j in 1:min(i, i+absk)
z[i] += UL[j, i]'UL[j, i+absk]
end
end
else
for i in 1:iabsk
z[i] = zero(T)
for j in 1:min(i, i+absk)
z[i] += UL[i, j]*UL[i+absk, j]'
end
end
end
if !(T <: Real) && k < 0
z .= adjoint.(z)
end
return z
end
2 changes: 1 addition & 1 deletion stdlib/LinearAlgebra/src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1086,7 +1086,7 @@ julia> tr(A)
5
```
"""
function tr(A::AbstractMatrix)
function tr(A)
checksquare(A)
sum(diag(A))
end
Expand Down
10 changes: 10 additions & 0 deletions stdlib/LinearAlgebra/test/cholesky.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ function unary_ops_tests(a, ca, tol; n=size(a, 1))
@test_throws ErrorException ca.Z
@test size(ca) == size(a)
@test Array(copy(ca)) a
@test tr(ca) tr(a) skip=ca isa CholeskyPivoted
end

function factor_recreation_tests(a_U, a_L)
Expand Down Expand Up @@ -561,4 +562,13 @@ end
end
end

@testset "diag" begin
for T in (Float64, ComplexF64), k in (0, 1, -3), uplo in (:U, :L)
A = randn(T, 100, 100)
P = Hermitian(A' * A, uplo)
C = cholesky(P)
@test diag(P, k) diag(C, k)
end
end

end # module TestCholesky

0 comments on commit c123952

Please sign in to comment.