From 40743e101edc7abbaa90d066e5cb82ae7409c969 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Fri, 29 Dec 2023 00:40:17 -0500 Subject: [PATCH] Fix and test solvers for non-square operators Fixes https://github.com/SciML/LinearSolve.jl/issues/414 --- src/LinearSolve.jl | 2 ++ src/default.jl | 16 +++++++++------- src/iterative_wrappers.jl | 18 ++++++++++++++++-- test/default_algs.jl | 39 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 66 insertions(+), 9 deletions(-) diff --git a/src/LinearSolve.jl b/src/LinearSolve.jl index fc039d13f..74f0efa7d 100644 --- a/src/LinearSolve.jl +++ b/src/LinearSolve.jl @@ -103,6 +103,8 @@ EnumX.@enumx DefaultAlgorithmChoice begin AppleAccelerateLUFactorization MKLLUFactorization QRFactorizationPivoted + KrylovJL_CRAIGMR + KrylovJL_LSMR end struct DefaultLinearSolver <: SciMLLinearSolveAlgorithm diff --git a/src/default.jl b/src/default.jl index 0d2daf19c..8972b2a38 100644 --- a/src/default.jl +++ b/src/default.jl @@ -1,6 +1,6 @@ needs_concrete_A(alg::DefaultLinearSolver) = true mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, - T13, T14, T15, T16, T17, T18, T19} + T13, T14, T15, T16, T17, T18, T19, T20, T21} LUFactorization::T1 QRFactorization::T2 DiagonalFactorization::T3 @@ -20,6 +20,8 @@ mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, AppleAccelerateLUFactorization::T17 MKLLUFactorization::T18 QRFactorizationPivoted::T19 + KrylovJL_CRAIGMR::T20 + KrylovJL_LSMR::T21 end # Legacy fallback @@ -254,11 +256,11 @@ function algchoice_to_alg(alg::Symbol) elseif alg === :AppleAccelerateLUFactorization AppleAccelerateLUFactorization() elseif alg === :QRFactorizationPivoted - @static if VERSION ≥ v"1.7beta" - QRFactorization(ColumnNorm()) - else - QRFactorization(Val(true)) - end + QRFactorization(ColumnNorm()) + elseif alg === :KrylovJL_CRAIGMR + KrylovJL_CRAIGMR() + elseif alg === :KrylovJL_LSMR + KrylovJL_LSMR() else error("Algorithm choice symbol $alg not allowed in the default") end @@ -387,7 +389,7 @@ end quote getproperty(cache.cacheval,$(Meta.quot(alg)))' \ dy end - elseif alg in Symbol.((DefaultAlgorithmChoice.KrylovJL_GMRES,)) + elseif alg in Symbol.((DefaultAlgorithmChoice.KrylovJL_GMRES,DefaultAlgorithmChoice.KrylovJL_LSMR, DefaultAlgorithmChoice.KrylovJL_CRAIGMR)) quote invprob = LinearSolve.LinearProblem(transpose(cache.A), dy) solve(invprob, cache.alg; diff --git a/src/iterative_wrappers.jl b/src/iterative_wrappers.jl index 294cfe7f1..319d373e2 100644 --- a/src/iterative_wrappers.jl +++ b/src/iterative_wrappers.jl @@ -243,7 +243,21 @@ function SciMLBase.solve!(cache::LinearCache, alg::KrylovJL; kwargs...) itmax = cache.maxiters verbose = cache.verbose ? 1 : 0 - args = (@get_cacheval(cache, :KrylovJL_GMRES), cache.A, cache.b) + cacheval = if cache.alg isa DefaultLinearSolver + if alg.KrylovAlg === Krylov.gmres! + @get_cacheval(cache, :KrylovJL_GMRES) + elseif alg.KrylovAlg === Krylov.craigmr! + @get_cacheval(cache, :KrylovJL_CRAIGMR) + elseif alg.KrylovAlg === Krylov.lsmr! + @get_cacheval(cache, :KrylovJL_LSMR) + else + error("Default linear solver can only be these three choices! Report this bug!") + end + else + cache.cacheval + end + + args = (cacheval, cache.A, cache.b) kwargs = (atol = atol, rtol = rtol, itmax = itmax, verbose = verbose, ldiv = true, history = true, alg.kwargs...) @@ -268,7 +282,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::KrylovJL; kwargs...) end stats = @get_cacheval(cache, :KrylovJL_GMRES).stats - resid = stats.residuals |> last + resid = !isempty(stats.residuals) ? last(stats.residuals) : zero(eltype(stats.residuals)) retcode = if !stats.solved if stats.status == "maximum number of iterations exceeded" diff --git a/test/default_algs.jl b/test/default_algs.jl index 11d536136..5aad1a30c 100644 --- a/test/default_algs.jl +++ b/test/default_algs.jl @@ -67,3 +67,42 @@ prob = LinearProblem(sparse(A), b) prob = LinearProblem(big.(rand(10, 10)), big.(zeros(10))) solve(prob) + +## Operator defaults +## https://github.com/SciML/LinearSolve.jl/issues/414 + +m, n = 2, 2 +A = rand(m, n) +b = rand(m) +x = rand(n) +f = (du, u, p, t) -> mul!(du, A, u) +fadj = (du, u, p, t) -> mul!(du, A', u) +fo = FunctionOperator(f, x, b; op_adjoint = fadj) +prob = LinearProblem(fo, b) +sol1 = solve(prob) +sol2 = solve(prob, LinearSolve.KrylovJL_GMRES()) +@test sol1.u == sol2.u + +m, n = 3, 2 +A = rand(m, n) +b = rand(m) +x = rand(n) +f = (du, u, p, t) -> mul!(du, A, u) +fadj = (du, u, p, t) -> mul!(du, A', u) +fo = FunctionOperator(f, x, b; op_adjoint = fadj) +prob = LinearProblem(fo, b) +sol1 = solve(prob) +sol2 = solve(prob, LinearSolve.KrylovJL_LSMR()) +@test sol1.u == sol2.u + +m, n = 2, 3 +A = rand(m, n) +b = rand(m) +x = rand(n) +f = (du, u, p, t) -> mul!(du, A, u) +fadj = (du, u, p, t) -> mul!(du, A', u) +fo = FunctionOperator(f, x, b; op_adjoint = fadj) +prob = LinearProblem(fo, b) +sol1 = solve(prob) +sol2 = solve(prob, LinearSolve.KrylovJL_CRAIGMR()) +@test sol1.u == sol2.u