Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding MPI test #518

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48"
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
LLVM_jll = "86de99a1-58d6-5da7-8064-bd56ce2e322c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down
61 changes: 61 additions & 0 deletions test/mpi.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
using MPI
using Enzyme
using Test

struct Context
x::Vector{Float64}
end

function halo(context)
x = context.x
np = MPI.Comm_size(MPI.COMM_WORLD)
rank = MPI.Comm_rank(MPI.COMM_WORLD)
requests = Vector{MPI.Request}()
if rank != 0
buf = @view x[1:1]
push!(requests, MPI.Isend(x[2:2], MPI.COMM_WORLD; dest=rank-1, tag=0))
push!(requests, MPI.Irecv!(buf, MPI.COMM_WORLD; source=rank-1, tag=0))
Comment on lines +15 to +17
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
buf = @view x[1:1]
push!(requests, MPI.Isend(x[2:2], MPI.COMM_WORLD; dest=rank-1, tag=0))
push!(requests, MPI.Irecv!(buf, MPI.COMM_WORLD; source=rank-1, tag=0))
buf = @view x[1:1]
push!(requests, MPI.Isend(x[2:2], MPI.COMM_WORLD; dest=rank - 1, tag=0))
push!(requests, MPI.Irecv!(buf, MPI.COMM_WORLD; source=rank - 1, tag=0))

end
if rank != np-1
buf = @view x[end:end]
push!(requests, MPI.Isend(x[end-1:end-1], MPI.COMM_WORLD; dest=rank+1, tag=0))
push!(requests, MPI.Irecv!(buf, MPI.COMM_WORLD; source=rank+1, tag=0))
Comment on lines +19 to +22
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
if rank != np-1
buf = @view x[end:end]
push!(requests, MPI.Isend(x[end-1:end-1], MPI.COMM_WORLD; dest=rank+1, tag=0))
push!(requests, MPI.Irecv!(buf, MPI.COMM_WORLD; source=rank+1, tag=0))
if rank != np - 1
buf = @view x[end:end]
push!(requests,
MPI.Isend(x[(end - 1):(end - 1)], MPI.COMM_WORLD; dest=rank + 1, tag=0))
push!(requests, MPI.Irecv!(buf, MPI.COMM_WORLD; source=rank + 1, tag=0))

end
for request in requests
MPI.Wait(request)
end
return nothing
end

MPI.Init()
np = MPI.Comm_size(MPI.COMM_WORLD)
rank = MPI.Comm_rank(MPI.COMM_WORLD)
n = np*10
n1 = Int(round(rank / np * (n+np))) - rank
n2 = Int(round((rank + 1) / np * (n+np))) - rank
nl = rank == 0 ? n1+1 : n1
nr = rank == np-1 ? n2-1 : n2
nlocal = nr-nl+1
Comment on lines +33 to +38
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
n = np*10
n1 = Int(round(rank / np * (n+np))) - rank
n2 = Int(round((rank + 1) / np * (n+np))) - rank
nl = rank == 0 ? n1+1 : n1
nr = rank == np-1 ? n2-1 : n2
nlocal = nr-nl+1
n = np * 10
n1 = Int(round(rank / np * (n + np))) - rank
n2 = Int(round((rank + 1) / np * (n + np))) - rank
nl = rank == 0 ? n1 + 1 : n1
nr = rank == np - 1 ? n2 - 1 : n2
nlocal = nr - nl + 1

context = Context(zeros(nlocal))
fill!(context.x, Float64(rank))
halo(context)
if rank != 0
@test context.x[1] == Float64(rank-1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@test context.x[1] == Float64(rank-1)
@test context.x[1] == Float64(rank - 1)

end
if rank != np-1
@test context.x[end] == Float64(rank+1)
Comment on lines +45 to +46
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
if rank != np-1
@test context.x[end] == Float64(rank+1)
if rank != np - 1
@test context.x[end] == Float64(rank + 1)

end

dcontext = Context(zeros(nlocal))
fill!(dcontext.x, Float64(rank))
autodiff(Reverse, halo, Duplicated(context, dcontext))
MPI.Barrier(MPI.COMM_WORLD)
if rank != 0
@test dcontext.x[2] == Float64(rank + rank - 1)
end
if rank != np-1
@test dcontext.x[end-1] == Float64(rank + rank + 1)
Comment on lines +56 to +57
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
if rank != np-1
@test dcontext.x[end-1] == Float64(rank + rank + 1)
if rank != np - 1
@test dcontext.x[end - 1] == Float64(rank + rank + 1)

end
if !isinteractive()
MPI.Finalize()
end
51 changes: 34 additions & 17 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ using Aqua
using Statistics
using LinearAlgebra
using InlineStrings
using MPI

using Enzyme_jll
@info "Testing against" Enzyme_jll.libEnzyme
Expand Down Expand Up @@ -1176,7 +1177,7 @@ end

bias = Float32[0.0;;;]
res = Enzyme.autodiff(Reverse, f, Active, Active(x[1]), Const(bias))

@test bias[1][1] ≈ 0.0
@test res[1][1] ≈ cos(x[1])
end
Expand Down Expand Up @@ -1585,7 +1586,7 @@ end

@inline function myquantile(v::AbstractVector, p::Real; alpha)
n = length(v)

m = 1.0 + p * (1.0 - alpha - 1.0)
aleph = n*p + oftype(p, m)
j = clamp(trunc(Int, aleph), 1, n-1)
michel2323 marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -1598,7 +1599,7 @@ end
a = @inbounds v[j]
b = @inbounds v[j + 1]
end

return a + γ*(b-a)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
return a + γ*(b-a)
return a + γ * (b - a)

end

Expand Down Expand Up @@ -1820,18 +1821,18 @@ end
@test 1.0 ≈ Enzyme.autodiff(Forward, inactive_gen, Duplicated(1E4, 1.0))[1]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@test 1.0 Enzyme.autodiff(Forward, inactive_gen, Duplicated(1E4, 1.0))[1]
@test 1.0 Enzyme.autodiff(Forward, inactive_gen, Duplicated(1E4, 1.0))[1]


function whocallsmorethan30args(R)
temp = diag(R)
R_inv = [temp[1] 0. 0. 0. 0. 0.;
0. temp[2] 0. 0. 0. 0.;
0. 0. temp[3] 0. 0. 0.;
0. 0. 0. temp[4] 0. 0.;
0. 0. 0. 0. temp[5] 0.;
temp = diag(R)
R_inv = [temp[1] 0. 0. 0. 0. 0.;
0. temp[2] 0. 0. 0. 0.;
0. 0. temp[3] 0. 0. 0.;
0. 0. 0. temp[4] 0. 0.;
0. 0. 0. 0. temp[5] 0.;
]
Comment on lines +1825 to 1830
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
R_inv = [temp[1] 0. 0. 0. 0. 0.;
0. temp[2] 0. 0. 0. 0.;
0. 0. temp[3] 0. 0. 0.;
0. 0. 0. temp[4] 0. 0.;
0. 0. 0. 0. temp[5] 0.;
]
R_inv = [temp[1] 0.0 0.0 0.0 0.0 0.0;
0.0 temp[2] 0.0 0.0 0.0 0.0;
0.0 0.0 temp[3] 0.0 0.0 0.0;
0.0 0.0 0.0 temp[4] 0.0 0.0;
0.0 0.0 0.0 0.0 temp[5] 0.0]


return sum(R_inv)
end
R = zeros(6,6)

R = zeros(6,6)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
R = zeros(6,6)
R = zeros(6, 6)

dR = zeros(6, 6)

@static if VERSION ≥ v"1.11-"
Expand Down Expand Up @@ -2719,7 +2720,7 @@ end
end
# TODO: Add test for NoShadowException
end

function indirectfltret(a)::DataType
a[] *= 2
return Float64
Expand Down Expand Up @@ -3419,6 +3420,22 @@ end
)
@test ad_eta[1] ≈ 0.0
end
@testset "MPI" begin
testdir = @__DIR__
# Test parsing
mpi_test = false
try
include("mpi.jl")
mpiexec() do cmd
run(`$cmd -n 2 $(Base.julia_cmd()) --project=$testdir $testdir/mpi.jl`)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
run(`$cmd -n 2 $(Base.julia_cmd()) --project=$testdir $testdir/mpi.jl`)
return run(`$cmd -n 2 $(Base.julia_cmd()) --project=$testdir $testdir/mpi.jl`)

end
mpi_test = true
catch
mpi_test = false
end
@test mpi_test
end


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

function absset(out, x)
@inbounds out[1] = (x,)
Expand Down Expand Up @@ -3698,10 +3715,10 @@ end
Duplicated(inters, dinters),
)

@test dinters[1].k ≈ 0.1
@test dinters[1].t0 ≈ 1.0
@test dinters[2].k ≈ 0.3
@test dinters[2].t0 ≈ 2.0
@test dinters[1].k ≈ 0.1
@test dinters[1].t0 ≈ 1.0
@test dinters[2].k ≈ 0.3
@test dinters[2].t0 ≈ 2.0
end

@testset "Statistics" begin
Expand Down
Loading