Skip to content

Commit

Permalink
Merge pull request #547 from oscardssmith/os/precsisfresh
Browse files Browse the repository at this point in the history
use isfresh like mechinisim for precs
  • Loading branch information
ChrisRackauckas authored Oct 19, 2024
2 parents d7f0fb3 + 8d33f02 commit b7e50f0
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 27 deletions.
3 changes: 2 additions & 1 deletion ext/LinearSolveHYPREExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,13 @@ function SciMLBase.init(prob::LinearProblem, alg::HYPREAlgorithm,
assumptions)
Tc = typeof(cacheval)
isfresh = true
precsisfresh = false

cache = LinearCache{
typeof(A), typeof(b), typeof(u0), typeof(p), typeof(alg), Tc,
typeof(Pl), typeof(Pr), typeof(reltol),
typeof(__issquare(assumptions)), typeof(sensealg)
}(A, b, u0, p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol,
}(A, b, u0, p, alg, cacheval, isfresh, precsisfresh, Pl, Pr, abstol, reltol,
maxiters, verbose, assumptions, sensealg)
return cache
end
Expand Down
6 changes: 6 additions & 0 deletions ext/LinearSolveIterativeSolversExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,12 @@ function LinearSolve.init_cacheval(alg::IterativeSolversJL, A, b, u, Pl, Pr, max
end

function SciMLBase.solve!(cache::LinearCache, alg::IterativeSolversJL; kwargs...)
if cache.precsisfresh && !isnothing(alg.precs)
Pl, Pr = alg.precs(cache.Pl, cache.Pr)
cache.Pl = Pl
cache.Pr = Pr
cache.precsisfresh = false
end
if cache.isfresh || !(alg isa IterativeSolvers.GMRESIterable)
solver = LinearSolve.init_cacheval(alg, cache.A, cache.b, cache.u, cache.Pl,
cache.Pr,
Expand Down
46 changes: 20 additions & 26 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ mutable struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq, S}
alg::Talg
cacheval::Tc # store alg cache here
isfresh::Bool # false => cacheval is set wrt A, true => update cacheval wrt A
precsisfresh::Bool # false => PR,PL is set wrt A, true => update PR,PL wrt A
Pl::Tl # preconditioners
Pr::Tr
abstol::Ttol
Expand All @@ -85,18 +86,10 @@ end

function Base.setproperty!(cache::LinearCache, name::Symbol, x)
if name === :A
if hasproperty(cache.alg, :precs) && !isnothing(cache.alg.precs)
Pl, Pr = cache.alg.precs(x, cache.p)
setfield!(cache, :Pl, Pl)
setfield!(cache, :Pr, Pr)
end
setfield!(cache, :isfresh, true)
setfield!(cache, :precsisfresh, true)
elseif name === :p
if hasproperty(cache.alg, :precs) && !isnothing(cache.alg.precs)
Pl, Pr = cache.alg.precs(cache.A, x)
setfield!(cache, :Pl, Pl)
setfield!(cache, :Pr, Pr)
end
setfield!(cache, :precsisfresh, true)
elseif name === :b
# In case there is something that needs to be done when b is updated
update_cacheval!(cache, :b, x)
Expand Down Expand Up @@ -208,11 +201,12 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
cacheval = init_cacheval(alg, A, b, u0_, Pl, Pr, maxiters, abstol, reltol, verbose,
assumptions)
isfresh = true
precsisfresh = false
Tc = typeof(cacheval)

cache = LinearCache{typeof(A), typeof(b), typeof(u0_), typeof(p), typeof(alg), Tc,
typeof(Pl), typeof(Pr), typeof(reltol), typeof(assumptions.issq),
typeof(sensealg)}(A, b, u0_, p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol,
typeof(sensealg)}(A, b, u0_, p, alg, cacheval, isfresh, precsisfresh, Pl, Pr, abstol, reltol,
maxiters, verbose, assumptions, sensealg)
return cache
end
Expand All @@ -223,27 +217,26 @@ function SciMLBase.reinit!(cache::LinearCache;
b = cache.b,
u = cache.u,
p = nothing,
reinit_cache = false,)
reinit_cache = false,
reuse_precs = false)
(; alg, cacheval, abstol, reltol, maxiters, verbose, assumptions, sensealg) = cache

precs = (hasproperty(alg, :precs) && !isnothing(alg.precs)) ? alg.precs : DEFAULT_PRECS
Pl, Pr = if isnothing(A) || isnothing(p)
if isnothing(A)
A = cache.A
end
if isnothing(p)
p = cache.p
end
precs(A, p)
else
(cache.Pl, cache.Pr)
end
isfresh = true

isfresh = !isnothing(A)
precsisfresh = !reuse_precs && (isfresh || !isnothing(p))
isfresh |= cache.isfresh
precsisfresh |= cache.precsisfresh

A = isnothing(A) ? cache.A : A
b = isnothing(b) ? cache.b : b
u = isnothing(u) ? cache.u : u
p = isnothing(p) ? cache.p : p
Pl = cache.Pl
Pr = cache.Pr
if reinit_cache
return LinearCache{typeof(A), typeof(b), typeof(u), typeof(p), typeof(alg), typeof(cacheval),
typeof(Pl), typeof(Pr), typeof(reltol), typeof(assumptions.issq),
typeof(sensealg)}(A, b, u, p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol,
typeof(sensealg)}(A, b, u, p, alg, cacheval, precsisfresh, isfresh, Pl, Pr, abstol, reltol,
maxiters, verbose, assumptions, sensealg)
else
cache.A = A
Expand All @@ -253,6 +246,7 @@ function SciMLBase.reinit!(cache::LinearCache;
cache.Pl = Pl
cache.Pr = Pr
cache.isfresh = true
cache.precsisfresh = precsisfresh
end
end

Expand Down
6 changes: 6 additions & 0 deletions src/iterative_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,12 @@ function init_cacheval(alg::KrylovJL, A, b, u, Pl, Pr, maxiters::Int, abstol, re
end

function SciMLBase.solve!(cache::LinearCache, alg::KrylovJL; kwargs...)
if cache.precsisfresh && !isnothing(alg.precs)
Pl, Pr = alg.precs(cache.A, cache.p)
cache.Pl = Pl
cache.Pr = Pr
cache.precsisfresh = false
end
if cache.isfresh
solver = init_cacheval(alg, cache.A, cache.b, cache.u, cache.Pl, cache.Pr,
cache.maxiters, cache.abstol, cache.reltol, cache.verbose,
Expand Down
24 changes: 24 additions & 0 deletions test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,30 @@ end
end
end

@testset "Reuse precs" begin
num_precs_calls = 0

function countingprecs(A, p = nothing)
num_precs_calls += 1
(BlockJacobiPreconditioner(A, 2), I)
end

n = 10
A = spdiagm(-1 => -ones(n - 1), 0 => fill(10.0, n), 1 => -ones(n - 1))
b = rand(n)
p = LinearProblem(A, b)
x0 = solve(p, KrylovJL_CG(precs = countingprecs, ldiv = false))
cache = x0.cache
x0 = copy(x0)
for i in 4:(n - 3)
A[i, i + 3] -= 1.0e-4
A[i - 3, i] -= 1.0e-4
end
LinearSolve.reinit!(cache; A, reuse_precs = true)
x1 = copy(solve!(cache))
@test all(x0 .< x1) && num_precs_calls == 1
end

if VERSION >= v"1.9-"
@testset "IterativeSolversJL" begin
kwargs = (; gmres_restart = 5)
Expand Down

0 comments on commit b7e50f0

Please sign in to comment.