diff --git a/test/Project.toml b/test/Project.toml index 818e0ac708..b8b32f556e 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -12,6 +12,7 @@ InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48" LLVM = "929cbde3-209d-540e-8aea-75f648917ca0" LLVM_jll = "86de99a1-58d6-5da7-8064-bd56ce2e322c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/test/mpi.jl b/test/mpi.jl new file mode 100644 index 0000000000..2988764472 --- /dev/null +++ b/test/mpi.jl @@ -0,0 +1,61 @@ +using MPI +using Enzyme +using Test + +struct Context + x::Vector{Float64} +end + +function halo(context) + x = context.x + np = MPI.Comm_size(MPI.COMM_WORLD) + rank = MPI.Comm_rank(MPI.COMM_WORLD) + requests = Vector{MPI.Request}() + if rank != 0 + buf = @view x[1:1] + push!(requests, MPI.Isend(x[2:2], MPI.COMM_WORLD; dest=rank-1, tag=0)) + push!(requests, MPI.Irecv!(buf, MPI.COMM_WORLD; source=rank-1, tag=0)) + end + if rank != np-1 + buf = @view x[end:end] + push!(requests, MPI.Isend(x[end-1:end-1], MPI.COMM_WORLD; dest=rank+1, tag=0)) + push!(requests, MPI.Irecv!(buf, MPI.COMM_WORLD; source=rank+1, tag=0)) + end + for request in requests + MPI.Wait(request) + end + return nothing +end + +MPI.Init() +np = MPI.Comm_size(MPI.COMM_WORLD) +rank = MPI.Comm_rank(MPI.COMM_WORLD) +n = np*10 +n1 = Int(round(rank / np * (n+np))) - rank +n2 = Int(round((rank + 1) / np * (n+np))) - rank +nl = rank == 0 ? n1+1 : n1 +nr = rank == np-1 ? n2-1 : n2 +nlocal = nr-nl+1 +context = Context(zeros(nlocal)) +fill!(context.x, Float64(rank)) +halo(context) +if rank != 0 + @test context.x[1] == Float64(rank-1) +end +if rank != np-1 + @test context.x[end] == Float64(rank+1) +end + +dcontext = Context(zeros(nlocal)) +fill!(dcontext.x, Float64(rank)) +autodiff(Reverse, halo, Duplicated(context, dcontext)) +MPI.Barrier(MPI.COMM_WORLD) +if rank != 0 + @test dcontext.x[2] == Float64(rank + rank - 1) +end +if rank != np-1 + @test dcontext.x[end-1] == Float64(rank + rank + 1) +end +if !isinteractive() + MPI.Finalize() +end diff --git a/test/runtests.jl b/test/runtests.jl index c9cde02c28..9c05f68083 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,6 +10,7 @@ using Aqua using Statistics using LinearAlgebra using InlineStrings +using MPI using Enzyme_jll @info "Testing against" Enzyme_jll.libEnzyme @@ -1176,7 +1177,7 @@ end bias = Float32[0.0;;;] res = Enzyme.autodiff(Reverse, f, Active, Active(x[1]), Const(bias)) - + @test bias[1][1] ≈ 0.0 @test res[1][1] ≈ cos(x[1]) end @@ -1585,7 +1586,7 @@ end @inline function myquantile(v::AbstractVector, p::Real; alpha) n = length(v) - + m = 1.0 + p * (1.0 - alpha - 1.0) aleph = n*p + oftype(p, m) j = clamp(trunc(Int, aleph), 1, n-1) @@ -1598,7 +1599,7 @@ end a = @inbounds v[j] b = @inbounds v[j + 1] end - + return a + γ*(b-a) end @@ -1820,18 +1821,18 @@ end @test 1.0 ≈ Enzyme.autodiff(Forward, inactive_gen, Duplicated(1E4, 1.0))[1] function whocallsmorethan30args(R) - temp = diag(R) - R_inv = [temp[1] 0. 0. 0. 0. 0.; - 0. temp[2] 0. 0. 0. 0.; - 0. 0. temp[3] 0. 0. 0.; - 0. 0. 0. temp[4] 0. 0.; - 0. 0. 0. 0. temp[5] 0.; + temp = diag(R) + R_inv = [temp[1] 0. 0. 0. 0. 0.; + 0. temp[2] 0. 0. 0. 0.; + 0. 0. temp[3] 0. 0. 0.; + 0. 0. 0. temp[4] 0. 0.; + 0. 0. 0. 0. temp[5] 0.; ] - + return sum(R_inv) end - - R = zeros(6,6) + + R = zeros(6,6) dR = zeros(6, 6) @static if VERSION ≥ v"1.11-" @@ -2719,7 +2720,7 @@ end end # TODO: Add test for NoShadowException end - + function indirectfltret(a)::DataType a[] *= 2 return Float64 @@ -3419,6 +3420,22 @@ end ) @test ad_eta[1] ≈ 0.0 end +@testset "MPI" begin + testdir = @__DIR__ + # Test parsing + mpi_test = false + try + include("mpi.jl") + mpiexec() do cmd + run(`$cmd -n 2 $(Base.julia_cmd()) --project=$testdir $testdir/mpi.jl`) + end + mpi_test = true + catch + mpi_test = false + end + @test mpi_test +end + function absset(out, x) @inbounds out[1] = (x,) @@ -3698,10 +3715,10 @@ end Duplicated(inters, dinters), ) - @test dinters[1].k ≈ 0.1 - @test dinters[1].t0 ≈ 1.0 - @test dinters[2].k ≈ 0.3 - @test dinters[2].t0 ≈ 2.0 + @test dinters[1].k ≈ 0.1 + @test dinters[1].t0 ≈ 1.0 + @test dinters[2].k ≈ 0.3 + @test dinters[2].t0 ≈ 2.0 end @testset "Statistics" begin