Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 15, 2023
1 parent d738dcc commit b1e4bf4
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 3 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -94,4 +94,4 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "MKL_jll"]
test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "MKL_jll", "BlockDiagonals"]
2 changes: 2 additions & 0 deletions src/simplegmres.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ function _init_cacheval(::Val{false}, alg::SimpleGMRES, A, b, u, Pl, Pr, maxiter
q = PlisI ? similar(u, 0) : similar(u, n)
p = PrisI ? similar(u, 0) : similar(u, n)
x = u
x .= zero(T)

w = similar(u, n)
V = [similar(u) for _ in 1:memory]
Expand Down Expand Up @@ -398,6 +399,7 @@ function _init_cacheval(::Val{true}, alg::SimpleGMRES, A, b, u, Pl, Pr, maxiters
q = PlisI ? similar(u, 0) : similar(u, n)
p = PrisI ? similar(u, 0) : similar(u, n)
x = u
x .= zero(T)

w = similar(u, n)
V = [similar(u) for _ in 1:memory]
Expand Down
31 changes: 30 additions & 1 deletion test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,10 @@ end
end
end

@testset "Simple GMRES: restart = $restart" for restart in (true, false)
test_interface(SimpleGMRES(; restart), prob1, prob2)
end

@testset "KrylovJL" begin
kwargs = (; gmres_restart = 5)
for alg in (("Default", KrylovJL(kwargs...)),
Expand Down Expand Up @@ -412,7 +416,7 @@ end

@testset "DirectLdiv!" begin
function get_operator(A, u; add_inverse = true)

function f(u, p, t)
println("using FunctionOperator OOP mul")
A * u
Expand Down Expand Up @@ -470,3 +474,28 @@ lp = LinearProblem(A, b; u0 = view(u0, :));
truesol = solve(lp, LUFactorization())
krylovsol = solve(lp, KrylovJL_GMRES())
@test truesol krylovsol

# Block Diagonals
using BlockDiagonals

@testset "Block Diagonal Specialization" begin
A = BlockDiagonal([rand(2, 2) for _ in 1:3])
b = rand(size(A, 1))

if VERSION > v"1.9-"
x1 = zero(b)
x2 = zero(b)
prob1 = LinearProblem(A, b, x1)
prob2 = LinearProblem(A, b, x2)
test_interface(SimpleGMRES(), prob1, prob2)
end

x1 = zero(b)
x2 = zero(b)
prob1 = LinearProblem(Array(A), b, x1)
prob2 = LinearProblem(Array(A), b, x2)

test_interface(SimpleGMRES(; blocksize=2), prob1, prob2)

@test solve(prob1, SimpleGMRES(; blocksize=2)).u solve(prob2, SimpleGMRES()).u
end
1 change: 1 addition & 0 deletions test/gpu/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
Expand Down
24 changes: 23 additions & 1 deletion test/gpu/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,32 @@ function test_interface(alg, prob1, prob2)
return
end

test_interface(CudaOffloadFactorization(), prob1, prob2)
@testset "CudaOffloadFactorization" begin
test_interface(CudaOffloadFactorization(), prob1, prob2)
end

@testset "Simple GMRES: restart = $restart" for restart in (true, false)
test_interface(SimpleGMRES(; restart), prob1, prob2)
end

A1 = prob1.A;
b1 = prob1.b;
x1 = prob1.u0;
y = solve(prob1)
@test A1 * y b1

using BlockDiagonals

@testset "Block Diagonal Specialization" begin
A = BlockDiagonal([rand(2, 2) for _ in 1:3]) |> cu
b = rand(size(A, 1)) |> cu

x1 = zero(b)
x2 = zero(b)
prob1 = LinearProblem(A, b, x1)
prob2 = LinearProblem(A, b, x2)

test_interface(SimpleGMRES(; blocksize=2), prob1, prob2)

@test solve(prob1, SimpleGMRES(; blocksize=2)).u solve(prob2, SimpleGMRES()).u
end

0 comments on commit b1e4bf4

Please sign in to comment.