Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use isfresh like mechinisim for precs #547

Merged
merged 5 commits into from
Oct 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion ext/LinearSolveHYPREExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,13 @@ function SciMLBase.init(prob::LinearProblem, alg::HYPREAlgorithm,
assumptions)
Tc = typeof(cacheval)
isfresh = true
precsisfresh = false

cache = LinearCache{
typeof(A), typeof(b), typeof(u0), typeof(p), typeof(alg), Tc,
typeof(Pl), typeof(Pr), typeof(reltol),
typeof(__issquare(assumptions)), typeof(sensealg)
}(A, b, u0, p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol,
}(A, b, u0, p, alg, cacheval, isfresh, precsisfresh, Pl, Pr, abstol, reltol,
maxiters, verbose, assumptions, sensealg)
return cache
end
Expand Down
6 changes: 6 additions & 0 deletions ext/LinearSolveIterativeSolversExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,12 @@ function LinearSolve.init_cacheval(alg::IterativeSolversJL, A, b, u, Pl, Pr, max
end

function SciMLBase.solve!(cache::LinearCache, alg::IterativeSolversJL; kwargs...)
if cache.precsisfresh && !isnothing(alg.precs)
Pl, Pr = alg.precs(cache.Pl, cache.Pr)
cache.Pl = Pl
cache.Pr = Pr
cache.precsisfresh = false
end
if cache.isfresh || !(alg isa IterativeSolvers.GMRESIterable)
solver = LinearSolve.init_cacheval(alg, cache.A, cache.b, cache.u, cache.Pl,
cache.Pr,
Expand Down
46 changes: 20 additions & 26 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ mutable struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq, S}
alg::Talg
cacheval::Tc # store alg cache here
isfresh::Bool # false => cacheval is set wrt A, true => update cacheval wrt A
precsisfresh::Bool # false => PR,PL is set wrt A, true => update PR,PL wrt A
Pl::Tl # preconditioners
Pr::Tr
abstol::Ttol
Expand All @@ -85,18 +86,10 @@ end

function Base.setproperty!(cache::LinearCache, name::Symbol, x)
if name === :A
if hasproperty(cache.alg, :precs) && !isnothing(cache.alg.precs)
Pl, Pr = cache.alg.precs(x, cache.p)
setfield!(cache, :Pl, Pl)
setfield!(cache, :Pr, Pr)
end
setfield!(cache, :isfresh, true)
setfield!(cache, :precsisfresh, true)
elseif name === :p
if hasproperty(cache.alg, :precs) && !isnothing(cache.alg.precs)
Pl, Pr = cache.alg.precs(cache.A, x)
setfield!(cache, :Pl, Pl)
setfield!(cache, :Pr, Pr)
end
setfield!(cache, :precsisfresh, true)
elseif name === :b
# In case there is something that needs to be done when b is updated
update_cacheval!(cache, :b, x)
Expand Down Expand Up @@ -208,11 +201,12 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
cacheval = init_cacheval(alg, A, b, u0_, Pl, Pr, maxiters, abstol, reltol, verbose,
assumptions)
isfresh = true
precsisfresh = false
Tc = typeof(cacheval)

cache = LinearCache{typeof(A), typeof(b), typeof(u0_), typeof(p), typeof(alg), Tc,
typeof(Pl), typeof(Pr), typeof(reltol), typeof(assumptions.issq),
typeof(sensealg)}(A, b, u0_, p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol,
typeof(sensealg)}(A, b, u0_, p, alg, cacheval, isfresh, precsisfresh, Pl, Pr, abstol, reltol,
maxiters, verbose, assumptions, sensealg)
return cache
end
Expand All @@ -223,27 +217,26 @@ function SciMLBase.reinit!(cache::LinearCache;
b = cache.b,
u = cache.u,
p = nothing,
reinit_cache = false,)
reinit_cache = false,
reuse_precs = false)
(; alg, cacheval, abstol, reltol, maxiters, verbose, assumptions, sensealg) = cache

precs = (hasproperty(alg, :precs) && !isnothing(alg.precs)) ? alg.precs : DEFAULT_PRECS
Pl, Pr = if isnothing(A) || isnothing(p)
if isnothing(A)
A = cache.A
end
if isnothing(p)
p = cache.p
end
precs(A, p)
else
(cache.Pl, cache.Pr)
end
isfresh = true

isfresh = !isnothing(A)
precsisfresh = !reuse_precs && (isfresh || !isnothing(p))
isfresh |= cache.isfresh
precsisfresh |= cache.precsisfresh

A = isnothing(A) ? cache.A : A
b = isnothing(b) ? cache.b : b
u = isnothing(u) ? cache.u : u
p = isnothing(p) ? cache.p : p
Pl = cache.Pl
Pr = cache.Pr
if reinit_cache
return LinearCache{typeof(A), typeof(b), typeof(u), typeof(p), typeof(alg), typeof(cacheval),
typeof(Pl), typeof(Pr), typeof(reltol), typeof(assumptions.issq),
typeof(sensealg)}(A, b, u, p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol,
typeof(sensealg)}(A, b, u, p, alg, cacheval, precsisfresh, isfresh, Pl, Pr, abstol, reltol,
maxiters, verbose, assumptions, sensealg)
else
cache.A = A
Expand All @@ -253,6 +246,7 @@ function SciMLBase.reinit!(cache::LinearCache;
cache.Pl = Pl
cache.Pr = Pr
cache.isfresh = true
cache.precsisfresh = precsisfresh
end
end

Expand Down
6 changes: 6 additions & 0 deletions src/iterative_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,12 @@ function init_cacheval(alg::KrylovJL, A, b, u, Pl, Pr, maxiters::Int, abstol, re
end

function SciMLBase.solve!(cache::LinearCache, alg::KrylovJL; kwargs...)
if cache.precsisfresh && !isnothing(alg.precs)
Pl, Pr = alg.precs(cache.A, cache.p)
cache.Pl = Pl
cache.Pr = Pr
cache.precsisfresh = false
end
if cache.isfresh
solver = init_cacheval(alg, cache.A, cache.b, cache.u, cache.Pl, cache.Pr,
cache.maxiters, cache.abstol, cache.reltol, cache.verbose,
Expand Down
24 changes: 24 additions & 0 deletions test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,30 @@ end
end
end

@testset "Reuse precs" begin
num_precs_calls = 0

function countingprecs(A, p = nothing)
num_precs_calls += 1
(BlockJacobiPreconditioner(A, 2), I)
end

n = 10
A = spdiagm(-1 => -ones(n - 1), 0 => fill(10.0, n), 1 => -ones(n - 1))
b = rand(n)
p = LinearProblem(A, b)
x0 = solve(p, KrylovJL_CG(precs = countingprecs, ldiv = false))
cache = x0.cache
x0 = copy(x0)
for i in 4:(n - 3)
A[i, i + 3] -= 1.0e-4
A[i - 3, i] -= 1.0e-4
end
LinearSolve.reinit!(cache; A, reuse_precs = true)
x1 = copy(solve!(cache))
@test all(x0 .< x1) && num_precs_calls == 1
end

if VERSION >= v"1.9-"
@testset "IterativeSolversJL" begin
kwargs = (; gmres_restart = 5)
Expand Down
Loading