diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml
index 9c7935911..0f93ea574 100644
--- a/.JuliaFormatter.toml
+++ b/.JuliaFormatter.toml
@@ -1,2 +1,3 @@
 style = "sciml"
-format_markdown = true
\ No newline at end of file
+format_markdown = true
+annotate_untyped_fields_with_any = false
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index e454bf595..1b6ed4dea 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,3 +5,5 @@ Manifest.toml
 *.swp
 
+.vscode
+wip
\ No newline at end of file
diff --git a/Project.toml b/Project.toml
index 57f017c8b..7e84a65a8 100644
--- a/Project.toml
+++ b/Project.toml
@@ -5,6 +5,7 @@ version = "2.6.0"
 
 [deps]
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
+ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
 DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
 EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
 FastLapackInterface = "29a986be-02c6-4525-aec4-84b980013641"
@@ -28,18 +29,22 @@ SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
 UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
 
 [weakdeps]
+BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
 IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
+KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
 KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
 MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
 Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
 Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2"
 
 [extensions]
+LinearSolveBlockDiagonalsExt = "BlockDiagonals"
 LinearSolveCUDAExt = "CUDA"
 LinearSolveHYPREExt = "HYPRE"
 LinearSolveIterativeSolversExt = "IterativeSolvers"
+LinearSolveKernelAbstractionsExt = "KernelAbstractions"
 LinearSolveKrylovKitExt = "KrylovKit"
 LinearSolveMKLExt = "MKL_jll"
 LinearSolveMetalExt = "Metal"
@@ -47,6 +52,7 @@ LinearSolvePardisoExt = "Pardiso"
 
 [compat]
 ArrayInterface = "7.4.11"
+BlockDiagonals = "0.1"
 DocStringExtensions = "0.8, 0.9"
 EnumX = "1"
 FastLapackInterface = "1, 2"
@@ -54,6 +60,7 @@ GPUArraysCore = "0.1"
 HYPRE = "1.4.0"
 IterativeSolvers = "0.9.2"
 KLU = "0.3.0, 0.4"
+KernelAbstractions = "0.9"
 Krylov = "0.9"
 KrylovKit = "0.5, 0.6"
 PrecompileTools = "1"
@@ -69,15 +76,17 @@ UnPack = "1"
 julia = "1.6"
 
 [extras]
+BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
 InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
 IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
 JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
+KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
 KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
 MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
-Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
 MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
+Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
 MultiFloats = "bdf0d083-296b-4888-a5b6-7498122e68a5"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -85,4 +94,4 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "MKL_jll"]
+test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "MKL_jll", "BlockDiagonals"]
diff --git a/docs/src/solvers/solvers.md b/docs/src/solvers/solvers.md
index f7e52c2a7..fe8a8b6e4 100644
--- a/docs/src/solvers/solvers.md
+++ b/docs/src/solvers/solvers.md
@@ -72,6 +72,14 @@ choice of Krylov method should be the one most constrained to the type of operator one
 has, for example if positive definite then `Krylov_CG()`, but if no good properties then
 use `Krylov_GMRES()`.
 
+!!! tip
+
+    If your materialized operator is a uniform block diagonal matrix, then you can use
+    `SimpleGMRES(; blocksize = <known block size>)` to further improve performance.
+    This often shows up in neural networks, where the Jacobian with respect to the inputs
+    is (almost always) a uniform block diagonal matrix whose block size equals the input
+    size divided by the batch size.
+
 ## Full List of Methods
 
 ### RecursiveFactorization.jl
@@ -106,6 +114,7 @@ LinearSolve.jl contains some linear solvers built in for specialized cases.
 
 ```@docs
 SimpleLUFactorization
 DiagonalFactorization
+SimpleGMRES
 ```
 
 ### FastLapackInterface.jl
diff --git a/ext/LinearSolveBlockDiagonalsExt.jl b/ext/LinearSolveBlockDiagonalsExt.jl
new file mode 100644
index 000000000..1e9b053eb
--- /dev/null
+++ b/ext/LinearSolveBlockDiagonalsExt.jl
@@ -0,0 +1,24 @@
+module LinearSolveBlockDiagonalsExt
+
+using LinearSolve, BlockDiagonals
+
+function LinearSolve.init_cacheval(alg::SimpleGMRES{false}, A::BlockDiagonal, b, args...;
+    kwargs...)
+    @assert ndims(A) == 2 "ndims(A) == $(ndims(A)). `A` must have ndims == 2."
+    # We need to perform this check even when `zeroinit == true`, since the type of the
+    # cache is dependent on whether we are able to use the specialized dispatch.
+    bsizes = blocksizes(A)
+    usize = first(first(bsizes))
+    uniform_blocks = true
+    for bsize in bsizes
+        if bsize[1] != usize || bsize[2] != usize
+            uniform_blocks = false
+            break
+        end
+    end
+    # Can't help but perform dynamic dispatch here
+    return LinearSolve._init_cacheval(Val(uniform_blocks), alg, A, b, args...;
+        blocksize = usize, kwargs...)
+end
+
+end
diff --git a/ext/LinearSolveKernelAbstractionsExt.jl b/ext/LinearSolveKernelAbstractionsExt.jl
new file mode 100644
index 000000000..ba620382f
--- /dev/null
+++ b/ext/LinearSolveKernelAbstractionsExt.jl
@@ -0,0 +1,24 @@
+module LinearSolveKernelAbstractionsExt
+
+using LinearSolve, KernelAbstractions
+
+LinearSolve.__is_extension_loaded(::Val{:KernelAbstractions}) = true
+
+using GPUArraysCore
+
+function LinearSolve._fast_sym_givens!(c, s, R, nr::Int, inner_iter::Int, bsize::Int, Hbis)
+    backend = get_backend(Hbis)
+    kernel! = __fast_sym_givens_kernel!(backend)
+    kernel!(c[inner_iter], s[inner_iter], R[nr + inner_iter], Hbis; ndrange=bsize)
+    return c, s, R
+end
+
+@kernel function __fast_sym_givens_kernel!(c, s, R, @Const(Hbis))
+    idx = @index(Global)
+    @inbounds _c, _s, _ρ = LinearSolve._sym_givens(R[idx], Hbis[idx])
+    @inbounds c[idx] = _c
+    @inbounds s[idx] = _s
+    @inbounds R[idx] = _ρ
+end
+
+end
diff --git a/src/LinearSolve.jl b/src/LinearSolve.jl
index fb82c7535..644c86288 100644
--- a/src/LinearSolve.jl
+++ b/src/LinearSolve.jl
@@ -28,15 +28,16 @@ PrecompileTools.@recompile_invalidations begin
     import InteractiveUtils
 
     using LinearAlgebra: BlasInt, LU
-    using LinearAlgebra.LAPACK: require_one_based_indexing, chkfinite, chkstride1, 
+    using LinearAlgebra.LAPACK: require_one_based_indexing, chkfinite, chkstride1,
         @blasfunc, chkargsok
 
     import GPUArraysCore
     import Preferences
+    import ConcreteStructs: @concrete
 
     # wrap
     import Krylov
-
+    using SciMLBase
 end
@@ -62,6 +63,11 @@ _isidentity_struct(λ::Number) = isone(λ)
 _isidentity_struct(A::UniformScaling) = isone(A.λ)
 _isidentity_struct(::SciMLOperators.IdentityOperator) = true
 
+# Dispatch-friendly way to check if an extension is loaded
+__is_extension_loaded(::Val) = false
+
+function _fast_sym_givens! end
+
 # Code
 const INCLUDE_SPARSE = Preferences.@load_preference("include_sparse", Base.USE_GPL_LIBS)
@@ -92,6 +98,7 @@ end
 include("common.jl")
 include("factorization.jl")
 include("simplelu.jl")
+include("simplegmres.jl")
 include("iterative_wrappers.jl")
 include("preconditioners.jl")
 include("solve_function.jl")
@@ -176,6 +183,8 @@ export KrylovJL, KrylovJL_CG, KrylovJL_MINRES, KrylovJL_GMRES,
     IterativeSolversJL_BICGSTAB, IterativeSolversJL_MINRES,
     KrylovKitJL, KrylovKitJL_CG, KrylovKitJL_GMRES
 
+export SimpleGMRES
+
 export HYPREAlgorithm
 export CudaOffloadFactorization
 export MKLPardisoFactorize, MKLPardisoIterate
diff --git a/src/iterative_wrappers.jl b/src/iterative_wrappers.jl
index 402c71609..b37571cb5 100644
--- a/src/iterative_wrappers.jl
+++ b/src/iterative_wrappers.jl
@@ -253,7 +253,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::KrylovJL; kwargs...)
         Krylov.solve!(args...; M = M,
             kwargs...)
     elseif cache.cacheval isa Krylov.GmresSolver
-        Krylov.solve!(args...; M = M, N = N,
+        Krylov.solve!(args...; M = M, N = N, restart = alg.gmres_restart > 0,
             kwargs...)
     elseif cache.cacheval isa Krylov.BicgstabSolver
         Krylov.solve!(args...; M = M, N = N,
diff --git a/src/simplegmres.jl b/src/simplegmres.jl
new file mode 100644
index 000000000..924cdaeea
--- /dev/null
+++ b/src/simplegmres.jl
@@ -0,0 +1,600 @@
+"""
+    SimpleGMRES(; restart::Bool = true, blocksize::Int = 0, warm_start::Bool = false,
+        memory::Int = 20)
+
+A simple GMRES implementation for square non-Hermitian linear systems.
+
+This implementation handles block diagonal matrices with uniformly sized square blocks
+via specialized dispatches.
+
+## Keyword Arguments
+
+* `restart::Bool`: If `true`, then the solver restarts after `memory` iterations.
+* `memory::Int = 20`: The number of iterations before restarting. If `restart` is
+  `false`, this value is used for the initial allocation, and the storage is later
+  expanded if more memory is required.
+* `blocksize::Int = 0`: If `blocksize > 0`, the solver assumes that the matrix has a
+  uniformly sized block diagonal structure with square blocks of size `blocksize`.
+  Misusing this option will lead to incorrect results.
+    * If this is set `≤ 0` and a block diagonal matrix is encountered at runtime, then
+      the solver will check whether the specialized dispatch can be used.
+
+!!! warning
+
+    Most users should use the `KrylovJL_GMRES` solver instead of this implementation.
+
+!!! tip
+
+    We can automatically detect if the matrix is a block diagonal matrix with uniformly
+    sized square blocks, in which case a specialized dispatch is used. However, on most
+    modern systems a single large matrix-vector multiplication is faster than multiple
+    smaller matrix-vector multiplications (as performed for a block diagonal matrix).
+    We recommend making the matrix dense (if its size permits) and specifying the
+    `blocksize` argument explicitly.
+"""
+struct SimpleGMRES{UBD} <: AbstractKrylovSubspaceMethod
+    restart::Bool
+    memory::Int
+    blocksize::Int
+    warm_start::Bool
+
+    function SimpleGMRES(; restart::Bool = true, blocksize::Int = 0,
+        warm_start::Bool = false, memory::Int = 20)
+        return new{blocksize > 0}(restart, memory, blocksize, warm_start)
+    end
+end
+
+@concrete mutable struct SimpleGMRESCache{UBD}
+    memory::Int
+    n::Int
+    restart::Bool
+    maxiters::Int
+    blocksize::Int
+    ε
+    PlisI::Bool
+    PrisI::Bool
+    Pl
+    Pr
+    Δx
+    q
+    p
+    x
+    A
+    b
+    abstol
+    reltol
+    w
+    V
+    s
+    c
+    z
+    R
+    β
+    warm_start::Bool
+end
+
+"""
+    (c, s, ρ) = _sym_givens(a, b)
+
+Numerically stable symmetric Givens reflection.
+Given reals `a` and `b`, return `(c, s, ρ)` such that
+
+    [ c  s ] [ a ] = [ ρ ]
+    [ s -c ] [ b ]   [ 0 ].
+"""
+function _sym_givens(a::T, b::T) where {T <: AbstractFloat}
+    # This is taken from Krylov.jl
+    if b == 0
+        c = ifelse(a == 0, one(T), sign(a)) # In Julia, sign(0) = 0.
+        s = zero(T)
+        ρ = abs(a)
+    elseif a == 0
+        c = zero(T)
+        s = sign(b)
+        ρ = abs(b)
+    elseif abs(b) > abs(a)
+        t = a / b
+        s = sign(b) / sqrt(one(T) + t * t)
+        c = s * t
+        ρ = b / s # Computationally better than ρ = a / c since |c| ≤ |s|.
+    else
+        t = b / a
+        c = sign(a) / sqrt(one(T) + t * t)
+        s = c * t
+        ρ = a / c # Computationally better than ρ = b / s since |s| ≤ |c|.
+    end
+    return (c, s, ρ)
+end
+
+function _sym_givens!(c, s, R, nr::Int, inner_iter::Int, bsize::Int, Hbis)
+    if __is_extension_loaded(Val(:KernelAbstractions))
+        return _fast_sym_givens!(c, s, R, nr, inner_iter, bsize, Hbis)
+    end
+    __res = _sym_givens.(R[nr + inner_iter], Hbis)
+    GPUArraysCore.@allowscalar foreach(1:bsize) do i
+        c[inner_iter][i] = __res[i][1]
+        s[inner_iter][i] = __res[i][2]
+        R[nr + inner_iter][i] = __res[i][3]
+    end
+    return c, s, R
+end
+
+_no_preconditioner(::Nothing) = true
+_no_preconditioner(::IdentityOperator) = true
+_no_preconditioner(::UniformScaling) = true
+_no_preconditioner(_) = false
+
+_norm2(x) = norm(x, 2)
+_norm2(x, dims) = .√(sum(abs2, x; dims))
+
+default_alias_A(::SimpleGMRES, ::Any, ::Any) = false
+default_alias_b(::SimpleGMRES, ::Any, ::Any) = false
+
+function SciMLBase.solve!(cache::LinearCache, alg::SimpleGMRES; kwargs...)
+    if cache.isfresh
+        solver = init_cacheval(alg, cache.A, cache.b, cache.u, cache.Pl, cache.Pr,
+            cache.maxiters, cache.abstol, cache.reltol, cache.verbose,
+            cache.assumptions; zeroinit = false)
+        cache.cacheval = solver
+        cache.isfresh = false
+    end
+    return SciMLBase.solve!(cache.cacheval, cache)
+end
+
+function init_cacheval(alg::SimpleGMRES{UBD}, args...; kwargs...) where {UBD}
+    return _init_cacheval(Val(UBD), alg, args...; kwargs...)
+end
+
+function _init_cacheval(::Val{false}, alg::SimpleGMRES, A, b, u, Pl, Pr, maxiters::Int,
+    abstol, reltol, ::Bool, ::OperatorAssumptions; zeroinit = true, kwargs...)
+    @unpack memory, restart, blocksize, warm_start = alg
+
+    if zeroinit
+        return SimpleGMRESCache{false}(memory, 0, restart, maxiters, blocksize,
+            zero(eltype(u)) * reltol + abstol, false, false, Pl, Pr, similar(u, 0),
+            similar(u, 0), similar(u, 0), u, A, b, abstol, reltol, similar(u, 0),
+            Vector{typeof(u)}(undef, 0), Vector{eltype(u)}(undef, 0),
+            Vector{eltype(u)}(undef, 0), Vector{eltype(u)}(undef, 0),
+            Vector{eltype(u)}(undef, 0), zero(eltype(u)), warm_start)
+    end
+
+    T = eltype(u)
+    n = LinearAlgebra.checksquare(A)
+    @assert n==length(b) "The size of `A` and `b` must match."
+    memory = min(memory, maxiters)
+
+    PlisI = _no_preconditioner(Pl)
+    PrisI = _no_preconditioner(Pr)
+
+    Δx = restart ? similar(u, n) : similar(u, 0)
+    q = PlisI ? similar(u, 0) : similar(u, n)
+    p = PrisI ? similar(u, 0) : similar(u, n)
+    x = u
+    x .= zero(T)
+
+    w = similar(u, n)
+    V = [similar(u) for _ in 1:memory]
+    s = Vector{eltype(x)}(undef, memory)
+    c = Vector{eltype(x)}(undef, memory)
+
+    z = Vector{eltype(x)}(undef, memory)
+    R = Vector{eltype(x)}(undef, (memory * (memory + 1)) ÷ 2)
+
+    q = PlisI ? w : q
+    r₀ = PlisI ? w : q
+
+    # Initial residual r₀.
+    if warm_start
+        mul!(w, A, Δx)
+        axpby!(one(T), b, -one(T), w)
+        restart && axpy!(one(T), Δx, x)
+    else
+        w .= b
+    end
+    PlisI || ldiv!(r₀, Pl, w) # r₀ = Pl(b - Ax₀)
+    β = _norm2(r₀)            # β = ‖r₀‖₂
+
+    rNorm = β
+    ε = abstol + reltol * rNorm
+
+    return SimpleGMRESCache{false}(memory, n, restart, maxiters, blocksize, ε, PlisI, PrisI,
+        Pl, Pr, Δx, q, p, x, A, b, abstol, reltol, w, V, s, c, z, R, β, warm_start)
+end
+
+function SciMLBase.solve!(cache::SimpleGMRESCache{false}, lincache::LinearCache)
+    @unpack memory, n, restart, maxiters, blocksize, ε, PlisI, PrisI, Pl, Pr = cache
+    @unpack Δx, q, p, x, A, b, abstol, reltol, w, V, s, c, z, R, β, warm_start = cache
+
+    T = eltype(x)
+    q = PlisI ? w : q
+    r₀ = PlisI ? w : q
+    xr = restart ? Δx : x
+
+    if β == 0
+        return SciMLBase.build_linear_solution(lincache.alg, x, r₀, lincache;
+            retcode = ReturnCode.Success)
+    end
+
+    rNorm = β
+    npass = 0 # Number of passes
+
+    iter = 0 # Cumulative number of iterations
+    inner_iter = 0 # Number of iterations in a pass
+
+    # Tolerance for breakdown detection.
+    btol = eps(T)^(3 / 4)
+
+    # Stopping criterion
+    breakdown = false
+    inconsistent = false
+    solved = rNorm ≤ ε
+    inner_maxiters = maxiters
+    tired = iter ≥ maxiters
+    inner_tired = inner_iter ≥ inner_maxiters
+    status = ReturnCode.Default
+
+    while !(solved || tired || breakdown)
+        # Initialize workspace.
+        nr = 0 # Number of coefficients stored in Rₖ.
+
+        if restart
+            xr .= zero(T) # xr === Δx when restart is set to true
+            if npass ≥ 1
+                mul!(w, A, x)
+                axpby!(one(T), b, -one(T), w)
+                PlisI || ldiv!(r₀, Pl, w)
+            end
+        end
+
+        # Initial ζ₁ and V₁
+        β = _norm2(r₀)
+        z[1] = β
+        V[1] .= r₀ / β
+
+        npass = npass + 1
+        inner_iter = 0
+        inner_tired = false
+
+        while !(solved || inner_tired || breakdown)
+            # Update iteration index
+            inner_iter += 1
+            # Update workspace if more storage is required and restart is set to false
+            if !restart && (inner_iter > memory)
+                append!(R, zeros(T, inner_iter))
+                push!(s, zero(T))
+                push!(c, zero(T))
+            end
+
+            # Continue the Arnoldi process.
+            p = PrisI ? V[inner_iter] : p
+            PrisI || ldiv!(p, Pr, V[inner_iter]) # p ← Nvₖ
+            mul!(w, A, p)                        # w ← ANvₖ
+            PlisI || ldiv!(q, Pl, w)             # q ← MANvₖ
+            for i in 1:inner_iter
+                R[nr + i] = dot(V[i], q)   # hᵢₖ = (vᵢ)ᴴq
+                axpy!(-R[nr + i], V[i], q) # q ← q - hᵢₖvᵢ
+            end
+
+            # Compute hₖ₊₁.ₖ
+            Hbis = _norm2(q) # hₖ₊₁.ₖ = ‖vₖ₊₁‖₂
+
+            # Update the QR factorization of Hₖ₊₁.ₖ.
+            # Apply previous Givens reflections Ωᵢ.
+            # [cᵢ  sᵢ] [ r̄ᵢ.ₖ ] = [ rᵢ.ₖ ]
+            # [s̄ᵢ -cᵢ] [rᵢ₊₁.ₖ]   [r̄ᵢ₊₁.ₖ]
+            for i in 1:(inner_iter - 1)
+                Rtmp = c[i] * R[nr + i] + s[i] * R[nr + i + 1]
+                R[nr + i + 1] = conj(s[i]) * R[nr + i] - c[i] * R[nr + i + 1]
+                R[nr + i] = Rtmp
+            end
+
+            # Compute and apply current Givens reflection Ωₖ.
+            # [cₖ  sₖ] [ r̄ₖ.ₖ ] = [rₖ.ₖ]
+            # [s̄ₖ -cₖ] [hₖ₊₁.ₖ]   [ 0 ]
+            (c[inner_iter], s[inner_iter], R[nr + inner_iter]) = _sym_givens(R[nr + inner_iter],
+                Hbis)
+
+            # Update zₖ = (Qₖ)ᴴβe₁
+            ζₖ₊₁ = conj(s[inner_iter]) * z[inner_iter]
+            z[inner_iter] = c[inner_iter] * z[inner_iter]
+
+            # Update residual norm estimate.
+            # ‖ Pl(b - Axₖ) ‖₂ = |ζₖ₊₁|
+            rNorm = abs(ζₖ₊₁)
+
+            # Update the number of coefficients in Rₖ
+            nr = nr + inner_iter
+
+            # Stopping conditions that do not depend on user input.
+            # This is to guard against tolerances that are unreasonably small.
+            resid_decrease_mach = (rNorm + one(T) ≤ one(T))
+
+            # Update stopping criterion.
+            resid_decrease_lim = rNorm ≤ ε
+            breakdown = Hbis ≤ btol
+            solved = resid_decrease_lim || resid_decrease_mach
+            inner_tired = restart ? inner_iter ≥ min(memory, inner_maxiters) :
+                          inner_iter ≥ inner_maxiters
+
+            # Compute vₖ₊₁.
+            if !(solved || inner_tired || breakdown)
+                if !restart && (inner_iter ≥ memory)
+                    push!(V, similar(first(V)))
+                    push!(z, zero(T))
+                end
+                @. V[inner_iter + 1] = q / Hbis # hₖ₊₁.ₖvₖ₊₁ = q
+                z[inner_iter + 1] = ζₖ₊₁
+            end
+        end
+
+        # Compute yₖ by solving Rₖyₖ = zₖ with backward substitution.
+        y = z # yᵢ = zᵢ
+        for i in inner_iter:-1:1
+            pos = nr + i - inner_iter # position of rᵢ.ₖ
+            for j in inner_iter:-1:(i + 1)
+                y[i] = y[i] - R[pos] * y[j] # yᵢ ← yᵢ - rᵢⱼyⱼ
+                pos = pos - j + 1           # position of rᵢ.ⱼ₋₁
+            end
+            # Rₖ can be singular if the system is inconsistent
+            if abs(R[pos]) ≤ btol
+                y[i] = zero(T)
+                inconsistent = true
+            else
+                y[i] = y[i] / R[pos] # yᵢ ← yᵢ / rᵢᵢ
+            end
+        end
+
+        # Form xₖ = NVₖyₖ
+        for i in 1:inner_iter
+            axpy!(y[i], V[i], xr)
+        end
+        if !PrisI
+            p .= xr
+            ldiv!(xr, Pr, p)
+        end
+        restart && axpy!(one(T), xr, x)
+
+        # Update inner_maxiters, iter, and tired variables.
+        inner_maxiters = inner_maxiters - inner_iter
+        iter = iter + inner_iter
+        tired = iter ≥ maxiters
+    end
+
+    # Termination status
+    tired && (status = ReturnCode.MaxIters)
+    solved && (status = ReturnCode.Success)
+    inconsistent && (status = ReturnCode.Infeasible)
+
+    # Update x
+    warm_start && !restart && axpy!(one(T), Δx, x)
+    cache.warm_start = false
+
+    return SciMLBase.build_linear_solution(lincache.alg, x, rNorm, lincache;
+        retcode = status, iters = iter)
+end
+
+function _init_cacheval(::Val{true}, alg::SimpleGMRES, A, b, u, Pl, Pr, maxiters::Int,
+    abstol, reltol, ::Bool, ::OperatorAssumptions; zeroinit = true,
+    blocksize = alg.blocksize)
+    @unpack memory, restart, warm_start = alg
+
+    if zeroinit
+        return SimpleGMRESCache{true}(memory, 0, restart, maxiters, blocksize,
+            zero(eltype(u)) * reltol + abstol, false, false, Pl, Pr, similar(u, 0),
+            similar(u, 0), similar(u, 0), u, A, b, abstol, reltol, similar(u, 0),
+            [u], [u], [u], [u], [u], zero(eltype(u)), warm_start)
+    end
+
+    T = eltype(u)
+    n = LinearAlgebra.checksquare(A)
+    @assert mod(n, blocksize)==0 "The blocksize must divide the size of the matrix."
+    @assert n==length(b) "The size of `A` and `b` must match."
+    memory = min(memory, maxiters)
+    bsize = n ÷ blocksize
+
+    PlisI = _no_preconditioner(Pl)
+    PrisI = _no_preconditioner(Pr)
+
+    Δx = restart ? similar(u, n) : similar(u, 0)
+    q = PlisI ? similar(u, 0) : similar(u, n)
+    p = PrisI ? similar(u, 0) : similar(u, n)
+    x = u
+    x .= zero(T)
+
+    w = similar(u, n)
+    V = [similar(u) for _ in 1:memory]
+    s = [similar(u, bsize) for _ in 1:memory]
+    c = [similar(u, bsize) for _ in 1:memory]
+
+    z = [similar(u, bsize) for _ in 1:memory]
+    R = [similar(u, bsize) for _ in 1:((memory * (memory + 1)) ÷ 2)]
+
+    q = PlisI ? w : q
+    r₀ = PlisI ? w : q
+
+    # Initial residual r₀.
+    if warm_start
+        mul!(w, A, Δx)
+        axpby!(one(T), b, -one(T), w)
+        restart && axpy!(one(T), Δx, x)
+    else
+        w .= b
+    end
+    PlisI || ldiv!(r₀, Pl, w) # r₀ = Pl(b - Ax₀)
+    β = _norm2(r₀)            # β = ‖r₀‖₂
+
+    rNorm = β
+    ε = abstol + reltol * rNorm
+
+    return SimpleGMRESCache{true}(memory, n, restart, maxiters, blocksize, ε, PlisI, PrisI,
+        Pl, Pr, Δx, q, p, x, A, b, abstol, reltol, w, V, s, c, z, R, β, warm_start)
+end
+
+function SciMLBase.solve!(cache::SimpleGMRESCache{true}, lincache::LinearCache)
+    @unpack memory, n, restart, maxiters, blocksize, ε, PlisI, PrisI, Pl, Pr = cache
+    @unpack Δx, q, p, x, A, b, abstol, reltol, w, V, s, c, z, R, β, warm_start = cache
+    bsize = n ÷ blocksize
+
+    __batch = Base.Fix2(reshape, (blocksize, bsize))
+
+    T = eltype(x)
+    q = PlisI ? w : q
+    r₀ = PlisI ? w : q
+    xr = restart ? Δx : x
+
+    if β == 0
+        return SciMLBase.build_linear_solution(lincache.alg, x, r₀, lincache;
+            retcode = ReturnCode.Success)
+    end
+
+    rNorm = β
+    npass = 0 # Number of passes
+
+    iter = 0 # Cumulative number of iterations
+    inner_iter = 0 # Number of iterations in a pass
+
+    # Tolerance for breakdown detection.
+    btol = eps(T)^(3 / 4)
+
+    # Stopping criterion
+    breakdown = false
+    inconsistent = false
+    solved = rNorm ≤ ε
+    inner_maxiters = maxiters
+    tired = iter ≥ maxiters
+    inner_tired = inner_iter ≥ inner_maxiters
+    status = ReturnCode.Default
+
+    while !(solved || tired || breakdown)
+        # Initialize workspace.
+        # TODO: Check that not zeroing out (V, s, c, R, z) doesn't lead to incorrect results.
+        nr = 0 # Number of coefficients stored in Rₖ.
+
+        if restart
+            xr .= zero(T) # xr === Δx when restart is set to true
+            if npass ≥ 1
+                mul!(w, A, x)
+                axpby!(one(T), b, -one(T), w)
+                PlisI || ldiv!(r₀, Pl, w)
+            end
+        end
+
+        # Initial ζ₁ and V₁
+        β = _norm2(__batch(r₀), 1)
+        z[1] .= vec(β)
+        V[1] .= vec(__batch(r₀) ./ β)
+
+        npass = npass + 1
+        inner_iter = 0
+        inner_tired = false
+
+        while !(solved || inner_tired || breakdown)
+            # Update iteration index
+            inner_iter += 1
+            # Update workspace if more storage is required and restart is set to false
+            if !restart && (inner_iter > memory)
+                append!(R, [similar(first(R), bsize) for _ in 1:inner_iter])
+                push!(s, similar(first(s), bsize))
+                push!(c, similar(first(c), bsize))
+            end
+
+            # Continue the Arnoldi process.
+            p = PrisI ? V[inner_iter] : p
+            PrisI || ldiv!(p, Pr, V[inner_iter]) # p ← Nvₖ
+            mul!(w, A, p)                        # w ← ANvₖ
+            PlisI || ldiv!(q, Pl, w)             # q ← MANvₖ
+            for i in 1:inner_iter
+                sum!(R[nr + i]', __batch(V[i]) .* __batch(q))
+                q .-= vec(R[nr + i]' .* __batch(V[i])) # q ← q - hᵢₖvᵢ
+            end
+
+            # Compute hₖ₊₁.ₖ
+            Hbis = vec(_norm2(__batch(q), 1)) # hₖ₊₁.ₖ = ‖vₖ₊₁‖₂
+
+            # Update the QR factorization of Hₖ₊₁.ₖ.
+            # Apply previous Givens reflections Ωᵢ.
+            # [cᵢ  sᵢ] [ r̄ᵢ.ₖ ] = [ rᵢ.ₖ ]
+            # [s̄ᵢ -cᵢ] [rᵢ₊₁.ₖ]   [r̄ᵢ₊₁.ₖ]
+            for i in 1:(inner_iter - 1)
+                Rtmp = c[i] .* R[nr + i] .+ s[i] .* R[nr + i + 1]
+                R[nr + i + 1] .= conj.(s[i]) .* R[nr + i] .- c[i] .* R[nr + i + 1]
+                R[nr + i] .= Rtmp
+            end
+
+            # Compute and apply current Givens reflection Ωₖ.
+            # [cₖ  sₖ] [ r̄ₖ.ₖ ] = [rₖ.ₖ]
+            # [s̄ₖ -cₖ] [hₖ₊₁.ₖ]   [ 0 ]
+            _sym_givens!(c, s, R, nr, inner_iter, bsize, Hbis)
+
+            # Update zₖ = (Qₖ)ᴴβe₁
+            ζₖ₊₁ = conj.(s[inner_iter]) .* z[inner_iter]
+            z[inner_iter] .= c[inner_iter] .* z[inner_iter]
+
+            # Update residual norm estimate.
+            # ‖ Pl(b - Axₖ) ‖₂ = |ζₖ₊₁|
+            rNorm = maximum(abs, ζₖ₊₁)
+
+            # Update the number of coefficients in Rₖ
+            nr = nr + inner_iter
+
+            # Stopping conditions that do not depend on user input.
+            # This is to guard against tolerances that are unreasonably small.
+            resid_decrease_mach = (rNorm + one(T) ≤ one(T))
+
+            # Update stopping criterion.
+            resid_decrease_lim = rNorm ≤ ε
+            breakdown = maximum(Hbis) ≤ btol
+            solved = resid_decrease_lim || resid_decrease_mach
+            inner_tired = restart ? inner_iter ≥ min(memory, inner_maxiters) :
+                          inner_iter ≥ inner_maxiters
+
+            # Compute vₖ₊₁.
+            if !(solved || inner_tired || breakdown)
+                if !restart && (inner_iter ≥ memory)
+                    push!(V, similar(first(V)))
+                    push!(z, similar(first(z), bsize))
+                end
+                V[inner_iter + 1] .= vec(__batch(q) ./ Hbis') # hₖ₊₁.ₖvₖ₊₁ = q
+                z[inner_iter + 1] .= ζₖ₊₁
+            end
+        end
+
+        # Compute yₖ by solving Rₖyₖ = zₖ with backward substitution.
+        y = z # yᵢ = zᵢ
+        for i in inner_iter:-1:1
+            pos = nr + i - inner_iter # position of rᵢ.ₖ
+            for j in inner_iter:-1:(i + 1)
+                y[i] .= y[i] .- R[pos] .* y[j] # yᵢ ← yᵢ - rᵢⱼyⱼ
+                pos = pos - j + 1              # position of rᵢ.ⱼ₋₁
+            end
+            # Rₖ can be singular if the system is inconsistent
+            y[i] .= ifelse.(abs.(R[pos]) .≤ btol, zero(T), y[i] ./ R[pos]) # yᵢ ← yᵢ / rᵢᵢ
+            inconsistent = inconsistent || any(abs.(R[pos]) .≤ btol)
+        end
+
+        # Form xₖ = NVₖyₖ
+        for i in 1:inner_iter
+            xr .+= vec(__batch(V[i]) .* y[i]')
+        end
+        if !PrisI
+            p .= xr
+            ldiv!(xr, Pr, p)
+        end
+        restart && axpy!(one(T), xr, x)
+
+        # Update inner_maxiters, iter, and tired variables.
+        inner_maxiters = inner_maxiters - inner_iter
+        iter = iter + inner_iter
+        tired = iter ≥ maxiters
+    end
+
+    # Termination status
+    tired && (status = ReturnCode.MaxIters)
+    solved && (status = ReturnCode.Success)
+    inconsistent && (status = ReturnCode.Infeasible)
+
+    # Update x
+    warm_start && !restart && axpy!(one(T), Δx, x)
+
+    return SciMLBase.build_linear_solution(lincache.alg, x, rNorm, lincache;
+        retcode = status, iters = iter)
+end
diff --git a/test/basictests.jl b/test/basictests.jl
index a58d4987d..401568d19 100644
--- a/test/basictests.jl
+++ b/test/basictests.jl
@@ -235,6 +235,10 @@ end
         end
     end
 
+    @testset "Simple GMRES: restart = $restart" for restart in (true, false)
+        test_interface(SimpleGMRES(; restart), prob1, prob2)
+    end
+
     @testset "KrylovJL" begin
         kwargs = (; gmres_restart = 5)
         for alg in (("Default", KrylovJL(kwargs...)),
@@ -412,7 +416,7 @@ end
 
 @testset "DirectLdiv!" begin
     function get_operator(A, u; add_inverse = true)
-        
+
         function f(u, p, t)
             println("using FunctionOperator OOP mul")
             A * u
@@ -470,3 +474,28 @@ lp = LinearProblem(A, b; u0 = view(u0, :));
 truesol = solve(lp, LUFactorization())
 krylovsol = solve(lp, KrylovJL_GMRES())
 @test truesol ≈ krylovsol
+
+# Block Diagonals
+using BlockDiagonals
+
+@testset "Block Diagonal Specialization" begin
+    A = BlockDiagonal([rand(2, 2) for _ in 1:3])
+    b = rand(size(A, 1))
+
+    if VERSION > v"1.9-"
+        x1 = zero(b)
+        x2 = zero(b)
+        prob1 = LinearProblem(A, b, x1)
+        prob2 = LinearProblem(A, b, x2)
+        test_interface(SimpleGMRES(), prob1, prob2)
+    end
+
+    x1 = zero(b)
+    x2 = zero(b)
+    prob1 = LinearProblem(Array(A), b, x1)
+    prob2 = LinearProblem(Array(A), b, x2)
+
+    test_interface(SimpleGMRES(; blocksize=2), prob1, prob2)
+
+    @test solve(prob1, SimpleGMRES(; blocksize=2)).u ≈ solve(prob2, SimpleGMRES()).u
+end
diff --git a/test/gpu/Project.toml b/test/gpu/Project.toml
index 8ea63055c..7fc6e3847 100644
--- a/test/gpu/Project.toml
+++ b/test/gpu/Project.toml
@@ -1,4 +1,5 @@
 [deps]
+BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
diff --git a/test/gpu/cuda.jl b/test/gpu/cuda.jl
index 9ae035501..042576300 100644
--- a/test/gpu/cuda.jl
+++ b/test/gpu/cuda.jl
@@ -42,10 +42,32 @@ function test_interface(alg, prob1, prob2)
     return
 end
 
-test_interface(CudaOffloadFactorization(), prob1, prob2)
+@testset "CudaOffloadFactorization" begin
+    test_interface(CudaOffloadFactorization(), prob1, prob2)
+end
+
+@testset "Simple GMRES: restart = $restart" for restart in (true, false)
+    test_interface(SimpleGMRES(; restart), prob1, prob2)
+end
 
 A1 = prob1.A;
 b1 = prob1.b;
 x1 = prob1.u0;
 y = solve(prob1)
 @test A1 * y ≈ b1
+
+using BlockDiagonals
+
+@testset "Block Diagonal Specialization" begin
+    A = BlockDiagonal([rand(2, 2) for _ in 1:3]) |> cu
+    b = rand(size(A, 1)) |> cu
+
+    x1 = zero(b)
+    x2 = zero(b)
+    prob1 = LinearProblem(A, b, x1)
+    prob2 = LinearProblem(A, b, x2)
+
+    test_interface(SimpleGMRES(; blocksize=2), prob1, prob2)
+
+    @test solve(prob1, SimpleGMRES(; blocksize=2)).u ≈ solve(prob2, SimpleGMRES()).u
+end
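
For anyone who wants to exercise the new solver locally, here is a minimal usage sketch. It is not part of the patch, and the matrix sizes and the `4I` diagonal shift are arbitrary illustrative choices; it runs both the automatic `BlockDiagonal` detection path and the explicit `blocksize` path recommended by the docs tip above.

```julia
using LinearAlgebra, LinearSolve, BlockDiagonals

# 8 uniform 4×4 blocks; adding 4I keeps each block well-conditioned.
A = BlockDiagonal([rand(4, 4) + 4I for _ in 1:8])
b = rand(size(A, 1))

# Path 1: the BlockDiagonals extension detects the uniform block structure.
sol_auto = solve(LinearProblem(A, b), SimpleGMRES())

# Path 2: materialize the operator densely and pass the block size explicitly,
# as the docs tip recommends when the size permits.
sol_dense = solve(LinearProblem(Array(A), b), SimpleGMRES(; blocksize = 4))

@assert sol_auto.u ≈ sol_dense.u ≈ Array(A) \ b
```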
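The `_sym_givens` docstring above states the reflection identity; the following standalone check (assuming the patch's `src/simplegmres.jl` definitions are in scope) confirms it on a concrete pair where the exact answer is known:

```julia
# Check [c s; s -c] * [a, b] == [ρ, 0] up to roundoff, with (a, b) = (3, -4).
a, b = 3.0, -4.0
c, s, ρ = _sym_givens(a, b) # assumes the patch's definition is loaded
@assert [c s; s -c] * [a, b] ≈ [ρ, 0.0]
@assert ρ ≈ hypot(a, b) # ρ is the 2-norm of (a, b), here 5.0
```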
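The uniform-block fast path works by reshaping a length-`n` vector into a `blocksize × bsize` matrix so that each column holds one block's segment, letting `_norm2(x, dims)` and the Givens updates act on all blocks at once. A small sketch of that batching trick, mirroring `__batch` from the patch with hypothetical sizes:

```julia
using LinearAlgebra: norm

blocksize, bsize = 4, 8 # hypothetical: 8 blocks of size 4
x = rand(blocksize * bsize)
__batch = Base.Fix2(reshape, (blocksize, bsize))

# Per-block 2-norms for all blocks at once, as in `_norm2(x, dims)` from the patch.
β = .√(sum(abs2, __batch(x); dims = 1))
@assert vec(β) ≈ [norm(x[((i - 1) * blocksize + 1):(i * blocksize)]) for i in 1:bsize]
```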