diff --git a/ext/LinearSolveCUDAExt.jl b/ext/LinearSolveCUDAExt.jl index 08acdaa28..c036d53e0 100644 --- a/ext/LinearSolveCUDAExt.jl +++ b/ext/LinearSolveCUDAExt.jl @@ -1,31 +1,26 @@ module LinearSolveCUDAExt -using CUDA, LinearAlgebra, LinearSolve, SciMLBase +using CUDA +using LinearSolve +using LinearSolve.LinearAlgebra, LinearSolve.SciMLBase, LinearSolve.ArrayInterface using SciMLBase: AbstractSciMLOperator function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadFactorization; kwargs...) if cache.isfresh - fact = LinearSolve.do_factorization(alg, CUDA.CuArray(cache.A), cache.b, cache.u) - cache = LinearSolve.set_cacheval(cache, fact) + fact = qr(CUDA.CuArray(cache.A)) + cache.cacheval = fact cache.isfresh = false end - - copyto!(cache.u, cache.b) - y = Array(ldiv!(cache.cacheval, CUDA.CuArray(cache.u))) + y = Array(ldiv!(CUDA.CuArray(cache.u), cache.cacheval, CUDA.CuArray(cache.b))) + cache.u .= y SciMLBase.build_linear_solution(alg, y, nothing, cache) end -function LinearSolve.do_factorization(alg::CudaOffloadFactorization, A, b, u) - A isa Union{AbstractMatrix, AbstractSciMLOperator} || - error("LU is not defined for $(typeof(A))") - - if A isa Union{MatrixOperator, DiffEqArrayOperator} - A = A.A - end - - fact = qr(CUDA.CuArray(A)) - return fact +function LinearSolve.init_cacheval(alg::CudaOffloadFactorization, A, b, u, Pl, Pr, + maxiters::Int, abstol, reltol, verbose::Bool, + assumptions::OperatorAssumptions) + qr(CUDA.CuArray(A)) end end diff --git a/test/gpu/cuda.jl b/test/gpu/cuda.jl index 7402275eb..5d901de5a 100644 --- a/test/gpu/cuda.jl +++ b/test/gpu/cuda.jl @@ -28,16 +28,16 @@ function test_interface(alg, prob1, prob2) @test A1 * y ≈ b1 cache = SciMLBase.init(prob1, alg; cache_kwargs...) # initialize cache - y = solve(cache) - @test A1 * y ≈ b1 + solve!(cache) + @test A1 * cache.u ≈ b1 - cache = LinearSolve.set_A(cache, copy(A2)) - y = solve(cache) - @test A2 * y ≈ b1 + cache.A = copy(A2) + solve!(cache) + @test A2 * cache.u ≈ b1 - cache = LinearSolve.set_b(cache, b2) - y = solve(cache) - @test A2 * y ≈ b2 + cache.b = copy(b2) + solve!(cache) + @test A2 * cache.u ≈ b2 return end