diff --git a/Project.toml b/Project.toml index f9014da21..bf84ffc90 100644 --- a/Project.toml +++ b/Project.toml @@ -16,6 +16,7 @@ KLU = "ef3ab10e-7fda-4108-b977-705223b18434" Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7" Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Preferences = "21216c6a-2e73-6563-6e65-726566657250" RecursiveFactorization = "f2c3362d-daeb-58d1-803e-2bc74f2840b4" @@ -37,7 +38,6 @@ HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771" IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" -MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2" @@ -49,7 +49,6 @@ LinearSolveHYPREExt = "HYPRE" LinearSolveIterativeSolversExt = "IterativeSolvers" LinearSolveKernelAbstractionsExt = "KernelAbstractions" LinearSolveKrylovKitExt = "KrylovKit" -LinearSolveMKLExt = "MKL_jll" LinearSolveMetalExt = "Metal" LinearSolvePardisoExt = "Pardiso" @@ -91,7 +90,6 @@ IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" -MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7" MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" MultiFloats = "bdf0d083-296b-4888-a5b6-7498122e68a5" @@ -101,4 +99,4 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "MKL_jll", "BlockDiagonals", "Enzyme", "FiniteDiff"] +test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff"] diff --git a/src/LinearSolve.jl b/src/LinearSolve.jl index 6cdfe8794..62ac8cf31 100644 --- a/src/LinearSolve.jl +++ b/src/LinearSolve.jl @@ -40,12 +40,16 @@ PrecompileTools.@recompile_invalidations begin import Krylov using SciMLBase + + using MKL_jll end using Reexport @reexport using SciMLBase using SciMLBase: _unwrap_val +const usemkl = MKL_jll.is_available() + abstract type SciMLLinearSolveAlgorithm <: SciMLBase.AbstractLinearAlgorithm end abstract type AbstractFactorization <: SciMLLinearSolveAlgorithm end abstract type AbstractKrylovSubspaceMethod <: SciMLLinearSolveAlgorithm end diff --git a/src/default.jl b/src/default.jl index 17335958a..94ec4d658 100644 --- a/src/default.jl +++ b/src/default.jl @@ -1,6 +1,6 @@ needs_concrete_A(alg::DefaultLinearSolver) = true mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, - T13, T14, T15, T16, T17} + T13, T14, T15, T16, T17, T18} LUFactorization::T1 QRFactorization::T2 DiagonalFactorization::T3 @@ -18,6 +18,7 @@ mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, CholeskyFactorization::T15 NormalCholeskyFactorization::T16 AppleAccelerateLUFactorization::T17 + MKLLUFactorization::T18 end # Legacy fallback @@ -162,19 +163,24 @@ function defaultalg(A, b, assump::OperatorAssumptions) DefaultAlgorithmChoice.GenericLUFactorization elseif VERSION >= v"1.8" && appleaccelerate_isavailable() DefaultAlgorithmChoice.AppleAccelerateLUFactorization - elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500)) && + elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500) || + (usemkl && length(b) <= 200)) && (A === nothing ? eltype(b) <: Union{Float32, Float64} : eltype(A) <: Union{Float32, Float64}) DefaultAlgorithmChoice.RFLUFactorization #elseif A === nothing || A isa Matrix # alg = FastLUFactorization() + elseif usemkl + DefaultAlgorithmChoice.MKLLUFactorization else - DefaultAlgorithmChoice.GenericLUFactorization + DefaultAlgorithmChoice.LUFactorization end elseif __conditioning(assump) === OperatorCondition.VeryIllConditioned DefaultAlgorithmChoice.QRFactorization elseif __conditioning(assump) === OperatorCondition.SuperIllConditioned DefaultAlgorithmChoice.SVDFactorization + elseif usemkl + DefaultAlgorithmChoice.MKLLUFactorization else DefaultAlgorithmChoice.LUFactorization end @@ -209,6 +215,8 @@ function algchoice_to_alg(alg::Symbol) LDLtFactorization() elseif alg === :LUFactorization LUFactorization() + elseif alg === :MKLLUFactorization + MKLLUFactorization() elseif alg === :QRFactorization QRFactorization() elseif alg === :DiagonalFactorization diff --git a/ext/LinearSolveMKLExt.jl b/src/mkl.jl similarity index 92% rename from ext/LinearSolveMKLExt.jl rename to src/mkl.jl index a094abe7e..1a4a018e1 100644 --- a/ext/LinearSolveMKLExt.jl +++ b/src/mkl.jl @@ -1,16 +1,3 @@ -module LinearSolveMKLExt - -using MKL_jll -using LinearAlgebra: BlasInt, LU -using LinearAlgebra.LAPACK: require_one_based_indexing, - chkfinite, chkstride1, - @blasfunc, chkargsok -using LinearAlgebra -const usemkl = MKL_jll.is_available() - -using LinearSolve -using LinearSolve: ArrayInterface, MKLLUFactorization, @get_cacheval, LinearCache, SciMLBase - function getrf!(A::AbstractMatrix{<:Float64}; ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))), info = Ref{BlasInt}(), @@ -140,6 +127,4 @@ function SciMLBase.solve!(cache::LinearCache, alg::MKLLUFactorization; SciMLBase.build_linear_solution(alg, cache.u, nothing, cache) =# -end - -end +end \ No newline at end of file