From 35d4fae6d0199d7e3a1ae4f2061697264ac85583 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=BCrgen=20Fuhrmann?= Date: Sat, 19 Oct 2024 21:07:27 +0200 Subject: [PATCH 1/2] Allow to use Pardiso with AbstractSparseMatrixCSC --- ext/LinearSolvePardisoExt.jl | 5 ++--- src/extension_algs.jl | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/ext/LinearSolvePardisoExt.jl b/ext/LinearSolvePardisoExt.jl index 0b4cfbb1..0318bb8a 100644 --- a/ext/LinearSolvePardisoExt.jl +++ b/ext/LinearSolvePardisoExt.jl @@ -134,12 +134,11 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::PardisoJL; kwargs if cache.isfresh phase = alg.cache_analysis ? Pardiso.NUM_FACT : Pardiso.ANALYSIS_NUM_FACT Pardiso.set_phase!(cache.cacheval, phase) - Pardiso.pardiso(cache.cacheval, A, eltype(A)[]) + Pardiso.pardiso(cache.cacheval, SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)), eltype(A)[]) cache.isfresh = false end Pardiso.set_phase!(cache.cacheval, Pardiso.SOLVE_ITERATIVE_REFINE) - Pardiso.pardiso(cache.cacheval, u, A, b) - + Pardiso.pardiso(cache.cacheval, u, SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)), b) return SciMLBase.build_linear_solution(alg, cache.u, nothing, cache) end diff --git a/src/extension_algs.jl b/src/extension_algs.jl index 7534d2fa..2559a210 100644 --- a/src/extension_algs.jl +++ b/src/extension_algs.jl @@ -217,7 +217,7 @@ All values default to `nothing` and the solver internally determines the values given the input types, and these keyword arguments are only for overriding the default handling process. This should not be required by most users. """ -struct PardisoJL{T1, T2} <: LinearSolve.SciMLLinearSolveAlgorithm +struct PardisoJL{T1, T2} <: AbstractSparseFactorization nprocs::Union{Int, Nothing} solver_type::T1 matrix_type::T2 From c88b634f945b95023bae7698d73b8e8f4702f820 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=BCrgen=20Fuhrmann?= Date: Sun, 20 Oct 2024 21:45:15 +0200 Subject: [PATCH 2/2] Add a test for PardisoExt working with AbstractSparseMatrixCSC --- test/pardiso/pardiso.jl | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/test/pardiso/pardiso.jl b/test/pardiso/pardiso.jl index a961a53d..c6af3cf4 100644 --- a/test/pardiso/pardiso.jl +++ b/test/pardiso/pardiso.jl @@ -177,3 +177,40 @@ for solver in solvers @test Pardiso.get_iparm(solver, i) == iparm[i][2] end end + +@testset "AbstractSparseMatrixCSC" begin + struct MySparseMatrixCSC2{Tv, Ti} <: SparseArrays.AbstractSparseMatrixCSC{Tv, Ti} + csc::SparseMatrixCSC{Tv, Ti} + end + + Base.size(m::MySparseMatrixCSC2) = size(m.csc) + SparseArrays.getcolptr(m::MySparseMatrixCSC2) = SparseArrays.getcolptr(m.csc) + SparseArrays.rowvals(m::MySparseMatrixCSC2) = SparseArrays.rowvals(m.csc) + SparseArrays.nonzeros(m::MySparseMatrixCSC2) = SparseArrays.nonzeros(m.csc) + + for alg in algs + N = 100 + u0 = ones(N) + A0 = spdiagm(1 => -ones(N - 1), 0 => fill(10.0, N), -1 => -ones(N - 1)) + b0 = A0 * u0 + B0 = MySparseMatrixCSC2(A0) + A1 = spdiagm(1 => -ones(N - 1), 0 => fill(100.0, N), -1 => -ones(N - 1)) + b1=A1*u0 + B1= MySparseMatrixCSC2(A1) + + + pr = LinearProblem(B0, b0) + # test default algorithn + u=solve(pr,alg) + @test norm(u - u0, Inf) < 1.0e-13 + + # test factorization with reinit! + pr = LinearProblem(B0, b0) + cache=init(pr,alg) + u=solve!(cache) + @test norm(u - u0, Inf) < 1.0e-13 + reinit!(cache; A=B1, b=b1) + u=solve!(cache) + @test norm(u - u0, Inf) < 1.0e-13 + end +end