Skip to content

Commit

Permalink
Fix and test solvers for non-square operators
Browse files Browse the repository at this point in the history
Fixes #414
  • Loading branch information
ChrisRackauckas committed Dec 29, 2023
1 parent 30ff6eb commit 40743e1
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 9 deletions.
2 changes: 2 additions & 0 deletions src/LinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ EnumX.@enumx DefaultAlgorithmChoice begin
AppleAccelerateLUFactorization
MKLLUFactorization
QRFactorizationPivoted
KrylovJL_CRAIGMR
KrylovJL_LSMR
end

struct DefaultLinearSolver <: SciMLLinearSolveAlgorithm
Expand Down
16 changes: 9 additions & 7 deletions src/default.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
needs_concrete_A(alg::DefaultLinearSolver) = true
mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12,
T13, T14, T15, T16, T17, T18, T19}
T13, T14, T15, T16, T17, T18, T19, T20, T21}
LUFactorization::T1
QRFactorization::T2
DiagonalFactorization::T3
Expand All @@ -20,6 +20,8 @@ mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10,
AppleAccelerateLUFactorization::T17
MKLLUFactorization::T18
QRFactorizationPivoted::T19
KrylovJL_CRAIGMR::T20
KrylovJL_LSMR::T21
end

# Legacy fallback
Expand Down Expand Up @@ -254,11 +256,11 @@ function algchoice_to_alg(alg::Symbol)
elseif alg === :AppleAccelerateLUFactorization
AppleAccelerateLUFactorization()
elseif alg === :QRFactorizationPivoted
@static if VERSION v"1.7beta"
QRFactorization(ColumnNorm())
else
QRFactorization(Val(true))
end
QRFactorization(ColumnNorm())
elseif alg === :KrylovJL_CRAIGMR
KrylovJL_CRAIGMR()
elseif alg === :KrylovJL_LSMR
KrylovJL_LSMR()
else
error("Algorithm choice symbol $alg not allowed in the default")
end
Expand Down Expand Up @@ -387,7 +389,7 @@ end
quote
getproperty(cache.cacheval,$(Meta.quot(alg)))' \ dy
end
elseif alg in Symbol.((DefaultAlgorithmChoice.KrylovJL_GMRES,))
elseif alg in Symbol.((DefaultAlgorithmChoice.KrylovJL_GMRES,DefaultAlgorithmChoice.KrylovJL_LSMR, DefaultAlgorithmChoice.KrylovJL_CRAIGMR))
quote
invprob = LinearSolve.LinearProblem(transpose(cache.A), dy)
solve(invprob, cache.alg;
Expand Down
18 changes: 16 additions & 2 deletions src/iterative_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,21 @@ function SciMLBase.solve!(cache::LinearCache, alg::KrylovJL; kwargs...)
itmax = cache.maxiters
verbose = cache.verbose ? 1 : 0

args = (@get_cacheval(cache, :KrylovJL_GMRES), cache.A, cache.b)
cacheval = if cache.alg isa DefaultLinearSolver
if alg.KrylovAlg === Krylov.gmres!
@get_cacheval(cache, :KrylovJL_GMRES)
elseif alg.KrylovAlg === Krylov.craigmr!
@get_cacheval(cache, :KrylovJL_CRAIGMR)
elseif alg.KrylovAlg === Krylov.lsmr!
@get_cacheval(cache, :KrylovJL_LSMR)
else
error("Default linear solver can only be these three choices! Report this bug!")

Check warning on line 254 in src/iterative_wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/iterative_wrappers.jl#L254

Added line #L254 was not covered by tests
end
else
cache.cacheval
end

args = (cacheval, cache.A, cache.b)
kwargs = (atol = atol, rtol = rtol, itmax = itmax, verbose = verbose,
ldiv = true, history = true, alg.kwargs...)

Expand All @@ -268,7 +282,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::KrylovJL; kwargs...)
end

stats = @get_cacheval(cache, :KrylovJL_GMRES).stats
resid = stats.residuals |> last
resid = !isempty(stats.residuals) ? last(stats.residuals) : zero(eltype(stats.residuals))

retcode = if !stats.solved
if stats.status == "maximum number of iterations exceeded"
Expand Down
39 changes: 39 additions & 0 deletions test/default_algs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,42 @@ prob = LinearProblem(sparse(A), b)

prob = LinearProblem(big.(rand(10, 10)), big.(zeros(10)))
solve(prob)

## Operator defaults
## https://github.com/SciML/LinearSolve.jl/issues/414

m, n = 2, 2
A = rand(m, n)
b = rand(m)
x = rand(n)
f = (du, u, p, t) -> mul!(du, A, u)
fadj = (du, u, p, t) -> mul!(du, A', u)
fo = FunctionOperator(f, x, b; op_adjoint = fadj)

Check warning on line 80 in test/default_algs.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"fo" should be "of" or "for" or "do" or "go" or "to".
prob = LinearProblem(fo, b)

Check warning on line 81 in test/default_algs.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"fo" should be "of" or "for" or "do" or "go" or "to".
sol1 = solve(prob)
sol2 = solve(prob, LinearSolve.KrylovJL_GMRES())
@test sol1.u == sol2.u

m, n = 3, 2
A = rand(m, n)
b = rand(m)
x = rand(n)
f = (du, u, p, t) -> mul!(du, A, u)
fadj = (du, u, p, t) -> mul!(du, A', u)
fo = FunctionOperator(f, x, b; op_adjoint = fadj)

Check warning on line 92 in test/default_algs.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"fo" should be "of" or "for" or "do" or "go" or "to".
prob = LinearProblem(fo, b)

Check warning on line 93 in test/default_algs.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"fo" should be "of" or "for" or "do" or "go" or "to".
sol1 = solve(prob)
sol2 = solve(prob, LinearSolve.KrylovJL_LSMR())
@test sol1.u == sol2.u

m, n = 2, 3
A = rand(m, n)
b = rand(m)
x = rand(n)
f = (du, u, p, t) -> mul!(du, A, u)
fadj = (du, u, p, t) -> mul!(du, A', u)
fo = FunctionOperator(f, x, b; op_adjoint = fadj)

Check warning on line 104 in test/default_algs.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"fo" should be "of" or "for" or "do" or "go" or "to".
prob = LinearProblem(fo, b)

Check warning on line 105 in test/default_algs.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"fo" should be "of" or "for" or "do" or "go" or "to".
sol1 = solve(prob)
sol2 = solve(prob, LinearSolve.KrylovJL_CRAIGMR())
@test sol1.u == sol2.u

0 comments on commit 40743e1

Please sign in to comment.