Skip to content

Commit

Permalink
pardiso: format
Browse files Browse the repository at this point in the history
  • Loading branch information
j-fu committed Jun 5, 2024
1 parent 372ea1e commit 0c0de1c
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 109 deletions.
11 changes: 5 additions & 6 deletions ext/LinearSolvePardisoExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ function LinearSolve.init_cacheval(alg::PardisoJL,

if isnothing(vendor)
if Pardiso.panua_is_available()
vendor=:Panua
vendor = :Panua

Check warning on line 30 in ext/LinearSolvePardisoExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolvePardisoExt.jl#L30

Added line #L30 was not covered by tests
else
vendor=:MKL
vendor = :MKL
end
end

Expand All @@ -42,7 +42,7 @@ function LinearSolve.init_cacheval(alg::PardisoJL,

# for mkl 1 means conjugated and 2 means transposed.
# https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-c/2024-0/pardiso-iparm-parameter.html#IPARM37
transposed_iparm = 2
transposed_iparm = 2

solver
else
Expand All @@ -53,15 +53,15 @@ function LinearSolve.init_cacheval(alg::PardisoJL,
solver = Pardiso.PardisoSolver()
Pardiso.pardisoinit(solver)
solver_type !== nothing && Pardiso.set_solver!(solver, solver_type)

Check warning on line 55 in ext/LinearSolvePardisoExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolvePardisoExt.jl#L51-L55

Added lines #L51 - L55 were not covered by tests

solver

Check warning on line 57 in ext/LinearSolvePardisoExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolvePardisoExt.jl#L57

Added line #L57 was not covered by tests
else
error("Panua Pardiso is not available.")

Check warning on line 59 in ext/LinearSolvePardisoExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolvePardisoExt.jl#L59

Added line #L59 was not covered by tests
end
else
error("Pardiso vendor must be either `:MKL` or `:Panua`")
end

if matrix_type !== nothing
Pardiso.set_matrixtype!(solver, matrix_type)
else
Expand Down Expand Up @@ -125,7 +125,6 @@ function LinearSolve.init_cacheval(alg::PardisoJL,
b)
end

>>>>>>> main
return solver
end

Expand Down
17 changes: 8 additions & 9 deletions src/extension_algs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ All values default to `nothing` and the solver internally determines the values
given the input types, and these keyword arguments are only for overriding the
default handling process. This should not be required by most users.
"""
MKLPardisoFactorize(; kwargs...) = PardisoJL(; vendor=:MKL, solver_type = 0, kwargs...)
MKLPardisoFactorize(; kwargs...) = PardisoJL(; vendor = :MKL, solver_type = 0, kwargs...)

"""
```julia
Expand Down Expand Up @@ -136,8 +136,7 @@ All values default to `nothing` and the solver internally determines the values
given the input types, and these keyword arguments are only for overriding the
default handling process. This should not be required by most users.
"""
MKLPardisoIterate(; kwargs...) = PardisoJL(; vendor=:MKL, solver_type = 1, kwargs...)

MKLPardisoIterate(; kwargs...) = PardisoJL(; vendor = :MKL, solver_type = 1, kwargs...)

"""
```julia
Expand Down Expand Up @@ -165,7 +164,8 @@ All values default to `nothing` and the solver internally determines the values
given the input types, and these keyword arguments are only for overriding the
default handling process. This should not be required by most users.
"""
PanuaPardisoFactorize(; kwargs...) = PardisoJL(; vendor=:Panua, solver_type = 0, kwargs...)
PanuaPardisoFactorize(; kwargs...) = PardisoJL(;

Check warning on line 167 in src/extension_algs.jl

View check run for this annotation

Codecov / codecov/patch

src/extension_algs.jl#L167

Added line #L167 was not covered by tests
vendor = :Panua, solver_type = 0, kwargs...)

"""
```julia
Expand All @@ -188,8 +188,7 @@ All values default to `nothing` and the solver internally determines the values
given the input types, and these keyword arguments are only for overriding the
default handling process. This should not be required by most users.
"""
PanuaPardisoIterate(; kwargs...) = PardisoJL(; vendor=:Panua, solver_type = 1, kwargs...)

PanuaPardisoIterate(; kwargs...) = PardisoJL(; vendor = :Panua, solver_type = 1, kwargs...)

Check warning on line 191 in src/extension_algs.jl

View check run for this annotation

Codecov / codecov/patch

src/extension_algs.jl#L191

Added line #L191 was not covered by tests

"""
```julia
Expand All @@ -198,7 +197,7 @@ PardisoJL(; nprocs::Union{Int, Nothing} = nothing,
matrix_type = nothing,
iparm::Union{Vector{Tuple{Int, Int}}, Nothing} = nothing,
dparm::Union{Vector{Tuple{Int, Int}}, Nothing} = nothing,
vendor::Union{Symbol,Nothing} = nothing
vendor::Union{Symbol, Nothing} = nothing
)
```
Expand All @@ -225,15 +224,15 @@ struct PardisoJL{T1, T2} <: LinearSolve.SciMLLinearSolveAlgorithm
cache_analysis::Bool
iparm::Union{Vector{Tuple{Int, Int}}, Nothing}
dparm::Union{Vector{Tuple{Int, Int}}, Nothing}
vendor::Union{Symbol,Nothing}
vendor::Union{Symbol, Nothing}

function PardisoJL(; nprocs::Union{Int, Nothing} = nothing,
solver_type = nothing,
matrix_type = nothing,
cache_analysis = false,
iparm::Union{Vector{Tuple{Int, Int}}, Nothing} = nothing,
dparm::Union{Vector{Tuple{Int, Int}}, Nothing} = nothing,
vendor::Union{Symbol,Nothing}=nothing )
vendor::Union{Symbol, Nothing} = nothing)
ext = Base.get_extension(@__MODULE__, :LinearSolvePardisoExt)
if ext === nothing
error("PardisoJL requires that Pardiso is loaded, i.e. `using Pardiso`")
Expand Down
203 changes: 109 additions & 94 deletions test/pardiso/pardiso.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,26 @@ cache_kwargs = (; abstol = 1e-8, reltol = 1e-8, maxiter = 30)

prob2 = LinearProblem(A2, b2)

algs=[PardisoJL()]
algs=LinearSolve.SciMLLinearSolveAlgorithm[PardisoJL()]
solvers=Pardiso.AbstractPardisoSolver[]
extended_algs=LinearSolve.SciMLLinearSolveAlgorithm[PardisoJL()]

if Pardiso.mkl_is_available()
algs=vcat(algs,[MKLPardisoFactorize(), MKLPardisoIterate()])
push!(algs,MKLPardisoFactorize())
push!(solvers,Pardiso.MKLPardisoSolver())
extended_algs=vcat(extended_algs,[MKLPardisoFactorize(), MKLPardisoIterate()])
@info "Testing MKL Pardiso"
end

if Pardiso.panua_is_available()
algs=vcat(algs,[PanuaPardisoFactorize(), PanuaPardisoIterate()])
push!(algs,PanuaPardisoFactorize())
push!(solvers,Pardiso.PardisoSolver())
extended_algs=vcat(extended_algs,[PanuaPardisoFactorize(), PanuaPardisoIterate()])
@info "Testing Panua Pardiso"
end


for alg in algs
for alg in extended_algs
u = solve(prob1, alg; cache_kwargs...).u
@test A1 * u b1

Expand All @@ -38,7 +46,6 @@ for alg in algs
@test A2 * u b2
end

return


Random.seed!(10)
Expand All @@ -48,24 +55,27 @@ b1 = rand(n);
b2 = rand(n);
prob = LinearProblem(copy(A), copy(b1))

prob = LinearProblem(copy(A), copy(b1))

linsolve = init(prob, UMFPACKFactorization())
sol11 = solve(linsolve)
linsolve = LinearSolve.set_b(sol11.cache, copy(b2))
sol12 = solve(linsolve)
linsolve = LinearSolve.set_A(sol12.cache, copy(A2))
sol13 = solve(linsolve)

linsolve = init(prob, MKLPardisoFactorize())
sol31 = solve(linsolve)
linsolve = LinearSolve.set_b(sol31.cache, copy(b2))
sol32 = solve(linsolve)
linsolve = LinearSolve.set_A(sol32.cache, copy(A2))
sol33 = solve(linsolve)

@test sol11.u sol31.u
@test sol12.u sol32.u
@test sol13.u sol33.u

for alg in algs
linsolve = init(prob, alg)
sol31 = solve(linsolve)
linsolve = LinearSolve.set_b(sol31.cache, copy(b2))
sol32 = solve(linsolve)
linsolve = LinearSolve.set_A(sol32.cache, copy(A2))
sol33 = solve(linsolve)
@test sol11.u sol31.u
@test sol12.u sol32.u
@test sol13.u sol33.u
end


# Test for problem from #497
Expand All @@ -78,87 +88,92 @@ function makeA()
return(A)
end

A=makeA()
u0=fill(0.1,size(A,2))
linprob = LinearProblem(A, A*u0)
u = LinearSolve.solve(linprob, PardisoJL())
@test norm(u-u0) < 1.0e-14

for alg in algs
A=makeA()
u0=fill(0.1,size(A,2))
linprob = LinearProblem(A, A*u0)
u = LinearSolve.solve(linprob, alg)
@test norm(u-u0) < 1.0e-14
end


# Testing and demonstrating Pardiso.set_iparm! for MKLPardisoSolver
solver = Pardiso.MKLPardisoSolver()
iparm = [
(1, 1),
(2, 2),
(3, 0),
(4, 0),
(5, 0),
(6, 0),
(7, 0),
(8, 20),
(9, 0),
(10, 13),
(11, 1),
(12, 1),
(13, 1),
(14, 0),
(15, 0),
(16, 0),
(17, 0),
(18, -1),
(19, -1),
(20, 0),
(21, 0),
(22, 0),
(23, 0),
(24, 10),
(25, 0),
(26, 0),
(27, 1),
(28, 0),
(29, 0),
(30, 0),
(31, 0),
(32, 0),
(33, 0),
(34, 0),
(35, 0),
(36, 0),
(37, 0),
(38, 0),
(39, 0),
(40, 0),
(41, 0),
(42, 0),
(43, 0),
(44, 0),
(45, 0),
(46, 0),
(47, 0),
(48, 0),
(49, 0),
(50, 0),
(51, 0),
(52, 0),
(53, 0),
(54, 0),
(55, 0),
(56, 0),
(57, 0),
(58, 0),
(59, 0),
(60, 0),
(61, 0),
(62, 0),
(63, 0),
(64, 0)
]

for i in iparm
Pardiso.set_iparm!(solver, i...)
end

for i in Base.OneTo(length(iparm))
@test Pardiso.get_iparm(solver, i) == iparm[i][2]

# Testing and demonstrating Pardiso.set_iparm! for MKLPardisoSolver
for solver in solvers
iparm = [
(1, 1),
(2, 2),
(3, 0),
(4, 0),
(5, 0),
(6, 0),
(7, 0),
(8, 20),
(9, 0),
(10, 13),
(11, 1),
(12, 1),
(13, 1),
(14, 0),
(15, 0),
(16, 0),
(17, 0),
(18, -1),
(19, -1),
(20, 0),
(21, 0),
(22, 0),
(23, 0),
(24, 10),
(25, 0),
(26, 0),
(27, 1),
(28, 0),
(29, 0),
(30, 0),
(31, 0),
(32, 0),
(33, 0),
(34, 0),
(35, 0),
(36, 0),
(37, 0),
(38, 0),
(39, 0),
(40, 0),
(41, 0),
(42, 0),
(43, 0),
(44, 0),
(45, 0),
(46, 0),
(47, 0),
(48, 0),
(49, 0),
(50, 0),
(51, 0),
(52, 0),
(53, 0),
(54, 0),
(55, 0),
(56, 0),
(57, 0),
(58, 0),
(59, 0),
(60, 0),
(61, 0),
(62, 0),
(63, 0),
(64, 0)
]

for i in iparm
Pardiso.set_iparm!(solver, i...)
end

for i in Base.OneTo(length(iparm))
@test Pardiso.get_iparm(solver, i) == iparm[i][2]
end
end

0 comments on commit 0c0de1c

Please sign in to comment.