From 0fa3fc542454074850adabfda71de7bb8d22da4c Mon Sep 17 00:00:00 2001 From: Alexis Montoison Date: Tue, 30 Mar 2021 00:03:33 -0400 Subject: [PATCH] Add a normal equations solver --- src/KKT/cholmod.jl | 59 ++++++++++++++++++++++++++++++--------------- test/KKT/cholmod.jl | 7 +++++- 2 files changed, 46 insertions(+), 20 deletions(-) diff --git a/src/KKT/cholmod.jl b/src/KKT/cholmod.jl index 0a3e9957..d41dcf90 100644 --- a/src/KKT/cholmod.jl +++ b/src/KKT/cholmod.jl @@ -210,22 +210,25 @@ mutable struct CholmodSPD <: CholmodSolver # Regularization regP::Vector{Float64} # primal regD::Vector{Float64} # dual + variant::Bool # Factorization F::CHOLMOD.Factor{Float64} # Constructor and initial memory allocation # TODO: symbolic only + allocation - function CholmodSPD(A::AbstractMatrix{Float64}) + function CholmodSPD(A::AbstractMatrix{Float64}; variant::Bool=false) m, n = size(A) - θ = ones(Float64, n) + θ = ones(n) + regP = zeros(n) + regD = ones(m) - S = sparse(A * A') + spdiagm(0 => ones(m)) + S = !variant ? sparse(A * A') + spdiagm(0 => ones(m)) : sparse(A' * A) + spdiagm(0 => ones(n)) # TODO: PSD-ness checks F = cholesky(Symmetric(S)) - return new(m, n, A, θ, zeros(Float64, n), ones(Float64, m), F) + return new(m, n, A, θ, regP, regD, variant, F) end end @@ -260,11 +263,15 @@ function update!( kkt.regP .= regP # Primal regularization is disabled for normal equations kkt.regD .= regD - # Re-compute factorization - # D = (Θ^{-1} + Rp)^{-1} - D = Diagonal(one(Float64) ./ (kkt.θ .+ kkt.regP)) - Rd = spdiagm(0 => kkt.regD) - S = kkt.A * D * kkt.A' + Rd + if !kkt.variant + M⁻¹ = Diagonal(1.0 ./ (kkt.θ .+ kkt.regP)) + N = spdiagm(0 => kkt.regD) + S = kkt.A * M⁻¹ * kkt.A' + N + else + N⁻¹ = Diagonal(1.0 ./ kkt.regD) + M = spdiagm(0 => kkt.θ .+ kkt.regP) + S = kkt.A' * N⁻¹ * kkt.A + M + end # Update factorization cholesky!(kkt.F, Symmetric(S), check=false) @@ -284,18 +291,32 @@ function solve!( ξp::Vector{Float64}, ξd::Vector{Float64} ) m, n = kkt.m, kkt.n - - d = one(Float64) ./ (kkt.θ .+ kkt.regP) - D = Diagonal(d) - # Set-up right-hand side - ξ_ = ξp .+ kkt.A * (D * ξd) + if !kkt.variant + # Compute M⁻¹ + M⁻¹ = Diagonal(1.0 ./ (kkt.θ .+ kkt.regP)) - # Solve augmented system - dy .= (kkt.F \ ξ_) + # Set-up right-hand side + ξ_ = ξp .+ kkt.A * (M⁻¹ * ξd) + + # Solve augmented system + dy .= (kkt.F \ ξ_) - # Recover dx - dx .= D * (kkt.A' * dy - ξd) + # Recover dx + dx .= M⁻¹ * (kkt.A' * dy - ξd) + else + # Compute N⁻¹ + N⁻¹ = Diagonal(1.0 ./ kkt.regD) + + # Set-up right-hand side + ξ_ = kkt.A' * (N⁻¹ * ξp) .- ξd + + # Solve augmented system + dx .= (kkt.F \ ξ_) + + # Recover dx + dy .= N⁻¹ * (ξp - kkt.A * dx) + end # TODO: Iterative refinement # * Max number of refine steps @@ -305,6 +326,6 @@ function solve!( # resD = - D \ dx + kkt.A' * dy - ξd # println("\n|resP| = $(norm(resP, Inf))\n|resD| = $(norm(resD, Inf))") - + return nothing end diff --git a/test/KKT/cholmod.jl b/test/KKT/cholmod.jl index 633fcfa3..4129c262 100644 --- a/test/KKT/cholmod.jl +++ b/test/KKT/cholmod.jl @@ -11,7 +11,12 @@ end @testset "Cholesky" begin - kkt = KKT.CholmodSolver(A, normal_equations=true) + kkt = KKT.CholmodSolver(A, normal_equations=true, variant=false) + KKT.run_ls_tests(A, kkt) + end + + @testset "Cholesky" begin + kkt = KKT.CholmodSolver(A, normal_equations=true, variant=true) KKT.run_ls_tests(A, kkt) end