Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make MKL the default when it's available #387

Merged
merged 5 commits into from
Oct 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ KLU = "ef3ab10e-7fda-4108-b977-705223b18434"
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
RecursiveFactorization = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"
Expand All @@ -37,7 +38,6 @@ HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2"

Expand All @@ -49,7 +49,6 @@ LinearSolveHYPREExt = "HYPRE"
LinearSolveIterativeSolversExt = "IterativeSolvers"
LinearSolveKernelAbstractionsExt = "KernelAbstractions"
LinearSolveKrylovKitExt = "KrylovKit"
LinearSolveMKLExt = "MKL_jll"
LinearSolveMetalExt = "Metal"
LinearSolvePardisoExt = "Pardiso"

Expand Down Expand Up @@ -91,7 +90,6 @@ IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
MultiFloats = "bdf0d083-296b-4888-a5b6-7498122e68a5"
Expand All @@ -101,4 +99,4 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "MKL_jll", "BlockDiagonals", "Enzyme", "FiniteDiff"]
test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff"]
6 changes: 6 additions & 0 deletions src/LinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,16 @@ PrecompileTools.@recompile_invalidations begin
import Krylov

using SciMLBase

using MKL_jll
end

using Reexport
@reexport using SciMLBase
using SciMLBase: _unwrap_val

const usemkl = MKL_jll.is_available()

abstract type SciMLLinearSolveAlgorithm <: SciMLBase.AbstractLinearAlgorithm end
abstract type AbstractFactorization <: SciMLLinearSolveAlgorithm end
abstract type AbstractKrylovSubspaceMethod <: SciMLLinearSolveAlgorithm end
Expand Down Expand Up @@ -91,6 +95,7 @@ EnumX.@enumx DefaultAlgorithmChoice begin
CholeskyFactorization
NormalCholeskyFactorization
AppleAccelerateLUFactorization
MKLLUFactorization
end

struct DefaultLinearSolver <: SciMLLinearSolveAlgorithm
Expand All @@ -100,6 +105,7 @@ end
include("common.jl")
include("factorization.jl")
include("appleaccelerate.jl")
include("mkl.jl")
include("simplelu.jl")
include("simplegmres.jl")
include("iterative_wrappers.jl")
Expand Down
14 changes: 11 additions & 3 deletions src/default.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
needs_concrete_A(alg::DefaultLinearSolver) = true
mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12,
T13, T14, T15, T16, T17}
T13, T14, T15, T16, T17, T18}
LUFactorization::T1
QRFactorization::T2
DiagonalFactorization::T3
Expand All @@ -18,6 +18,7 @@ mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10,
CholeskyFactorization::T15
NormalCholeskyFactorization::T16
AppleAccelerateLUFactorization::T17
MKLLUFactorization::T18
end

# Legacy fallback
Expand Down Expand Up @@ -162,19 +163,24 @@ function defaultalg(A, b, assump::OperatorAssumptions)
DefaultAlgorithmChoice.GenericLUFactorization
elseif VERSION >= v"1.8" && appleaccelerate_isavailable()
DefaultAlgorithmChoice.AppleAccelerateLUFactorization
elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500)) &&
elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500) ||
(usemkl && length(b) <= 200)) &&
(A === nothing ? eltype(b) <: Union{Float32, Float64} :
eltype(A) <: Union{Float32, Float64})
DefaultAlgorithmChoice.RFLUFactorization
#elseif A === nothing || A isa Matrix
# alg = FastLUFactorization()
elseif usemkl
DefaultAlgorithmChoice.MKLLUFactorization
else
DefaultAlgorithmChoice.GenericLUFactorization
DefaultAlgorithmChoice.LUFactorization
end
elseif __conditioning(assump) === OperatorCondition.VeryIllConditioned
DefaultAlgorithmChoice.QRFactorization
elseif __conditioning(assump) === OperatorCondition.SuperIllConditioned
DefaultAlgorithmChoice.SVDFactorization
elseif usemkl
DefaultAlgorithmChoice.MKLLUFactorization
else
DefaultAlgorithmChoice.LUFactorization
end
Expand Down Expand Up @@ -209,6 +215,8 @@ function algchoice_to_alg(alg::Symbol)
LDLtFactorization()
elseif alg === :LUFactorization
LUFactorization()
elseif alg === :MKLLUFactorization
MKLLUFactorization()
elseif alg === :QRFactorization
QRFactorization()
elseif alg === :DiagonalFactorization
Expand Down
10 changes: 0 additions & 10 deletions src/extension_algs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -339,16 +339,6 @@ A wrapper over the IterativeSolvers.jl MINRES.
"""
function IterativeSolversJL_MINRES end

"""
```julia
MKLLUFactorization()
```

A wrapper over Intel's Math Kernel Library (MKL). Direct calls to MKL in a way that pre-allocates workspace
to avoid allocations and does not require libblastrampoline.
"""
struct MKLLUFactorization <: AbstractFactorization end

"""
```julia
MetalLUFactorization()
Expand Down
3 changes: 0 additions & 3 deletions src/init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@ function __init__()
@require KrylovKit="0b1a1467-8014-51b9-945f-bf0ae24f4b77" begin
include("../ext/LinearSolveKrylovKitExt.jl")
end
@require MKL_jll="856f044c-d86e-5d09-b602-aeab76dc8ba7" begin
include("../ext/LinearSolveMKLExt.jl")
end
@require Enzyme="7da242da-08ed-463a-9acd-ee780be4f1d9" begin
include("../ext/LinearSolveEnzymeExt.jl")
end
Expand Down
32 changes: 16 additions & 16 deletions ext/LinearSolveMKLExt.jl → src/mkl.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
module LinearSolveMKLExt
"""
```julia
MKLLUFactorization()
```

using MKL_jll
using LinearAlgebra: BlasInt, LU
using LinearAlgebra.LAPACK: require_one_based_indexing,
chkfinite, chkstride1,
@blasfunc, chkargsok
using LinearAlgebra
const usemkl = MKL_jll.is_available()

using LinearSolve
using LinearSolve: ArrayInterface, MKLLUFactorization, @get_cacheval, LinearCache, SciMLBase
A wrapper over Intel's Math Kernel Library (MKL). Direct calls to MKL in a way that pre-allocates workspace
to avoid allocations and does not require libblastrampoline.
"""
struct MKLLUFactorization <: AbstractFactorization end

function getrf!(A::AbstractMatrix{<:Float64};
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
Expand Down Expand Up @@ -104,10 +101,15 @@ end
default_alias_A(::MKLLUFactorization, ::Any, ::Any) = false
default_alias_b(::MKLLUFactorization, ::Any, ::Any) = false

function LinearSolve.init_cacheval(alg::MKLLUFactorization, A, b, u, Pl, Pr,
const PREALLOCATED_MKL_LU = begin
A = rand(0, 0)
luinst = ArrayInterface.lu_instance(A), Ref{BlasInt}()
end

function init_cacheval(alg::MKLLUFactorization, A, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
ArrayInterface.lu_instance(convert(AbstractMatrix, A)), Ref{BlasInt}()
PREALLOCATED_MKL_LU
end

function SciMLBase.solve!(cache::LinearCache, alg::MKLLUFactorization;
Expand Down Expand Up @@ -140,6 +142,4 @@ function SciMLBase.solve!(cache::LinearCache, alg::MKLLUFactorization;

SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
=#
end

end
end
10 changes: 8 additions & 2 deletions test/default_algs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,14 @@ solve(prob)
prob = LinearProblem(rand(50, 50), rand(50))
solve(prob)

@test LinearSolve.defaultalg(nothing, zeros(600)).alg ===
LinearSolve.DefaultAlgorithmChoice.GenericLUFactorization
if LinearSolve.usemkl
@test LinearSolve.defaultalg(nothing, zeros(600)).alg ===
LinearSolve.DefaultAlgorithmChoice.MKLLUFactorization
else
@test LinearSolve.defaultalg(nothing, zeros(600)).alg ===
LinearSolve.DefaultAlgorithmChoice.LUFactorization
end

prob = LinearProblem(rand(600, 600), rand(600))
solve(prob)

Expand Down
Loading