diff --git a/src/WoodburyMatrices.jl b/src/WoodburyMatrices.jl index 0300ee2..b4965eb 100644 --- a/src/WoodburyMatrices.jl +++ b/src/WoodburyMatrices.jl @@ -98,10 +98,23 @@ _ldiv!(dest, W, A, B) = _ldiv!(dest, W, defaultfactor(W, A), B) defaultfactor(::AbstractWoodbury, A) = lu(A) -det(W::AbstractWoodbury) = det(W.A)*det(W.C)/det(W.Cp) -logdet(W::AbstractWoodbury) = logdet(W.A) + logdet(W.C) - logdet(W.Cp) +function det(W::AbstractWoodbury) + ret = det(W.A) + if !isempty(W.C) + ret *= det(W.C)/det(W.Cp) + end + return ret +end +function logdet(W::AbstractWoodbury) + ret = logdet(W.A) + if !isempty(W.C) + ret += logdet(W.C) - logdet(W.Cp) + end + return ret +end function logabsdet(W::AbstractWoodbury) lad_A = logabsdet(W.A) + isempty(W.C) && return lad_A lad_C = logabsdet(W.C) lad_Cp = logabsdet(W.Cp) lad = lad_A[1] + lad_C[1] - lad_Cp[1] diff --git a/test/symwoodbury.jl b/test/symwoodbury.jl index 0706f3c..ae76fef 100644 --- a/test/symwoodbury.jl +++ b/test/symwoodbury.jl @@ -226,4 +226,16 @@ W = SymWoodbury([randpsd(50) for _ in 1:3]...) @test W \ [13, 14, 15, 16] ≈ Matrix(W) \ [13, 14, 15, 16] end +@testset "Empty B" begin + A = Diagonal( Float64[1,2,3,4] ) + B = Matrix{Float64}(undef, 4, 0) + D = Diagonal(Float64[]) + + W = SymWoodbury( A, B, D) + @test W \ [13, 14, 15, 16] ≈ Matrix(W) \ [13, 14, 15, 16] + @test det(W) ≈ det(Matrix(W)) + @test logdet(W) ≈ logdet(Matrix(W)) + @test all(logabsdet(W) .≈ logabsdet(Matrix(W))) +end + end # @testset "SymWoodbury" diff --git a/test/woodbury.jl b/test/woodbury.jl index 8fd49fc..eeea2ee 100644 --- a/test/woodbury.jl +++ b/test/woodbury.jl @@ -184,4 +184,16 @@ W = Woodbury([randpsd(50) for _ in 1:4]...) @test logdet(W) ≈ log(det(W)) ≈ logdet(Array(W)) @test all(logabsdet(W) .≈ logabsdet(Array(W))) +@testset "Empty U" begin + A = Diagonal( Float64[1,2,3,4] ) + B = Matrix{Float64}(undef, 4, 0) + D = Diagonal(Float64[]) + + W = Woodbury( A, B, D, B') + @test W \ [13, 14, 15, 16] ≈ Matrix(W) \ [13, 14, 15, 16] + @test det(W) ≈ det(Matrix(W)) + @test logdet(W) ≈ logdet(Matrix(W)) + @test all(logabsdet(W) .≈ logabsdet(Matrix(W))) +end + end # @testset "Woodbury"