Skip to content

Commit

Permalink
Support subcommunicator
Browse files Browse the repository at this point in the history
  • Loading branch information
WeiqunZhang committed Oct 20, 2024
1 parent cfa4888 commit 21c4761
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 16 deletions.
16 changes: 8 additions & 8 deletions Src/FFT/AMReX_FFT.H
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ private:
template <typename FA>
static typename FA::FABType::value_type *
get_fab (FA& fa) {
auto myproc = ParallelDescriptor::MyProc();
auto myproc = ParallelContext::MyProcSub();
if (myproc < fa.size()) {
return fa.fabPtr(myproc);
} else {
Expand Down Expand Up @@ -378,8 +378,8 @@ R2C<T,D>::R2C (Box const& domain, Info const& info)
AMREX_ALWAYS_ASSERT(! m_info.batch_mode);
#endif

int myproc = ParallelDescriptor::MyProc();
int nprocs = ParallelDescriptor::NProcs();
int myproc = ParallelContext::MyProcSub();
int nprocs = ParallelContext::NProcsSub();

auto bax = amrex::decompose(m_real_domain, nprocs, {AMREX_D_DECL(false,true,true)});
DistributionMapping dmx = detail::make_iota_distromap(bax.size());
Expand Down Expand Up @@ -635,8 +635,8 @@ void R2C<T,D>::exec_r2c (Plan plan, MF& in, cMF& out)
if (! plan.defined) { return; }

#if defined(AMREX_USE_GPU)
auto* pin = in[ParallelDescriptor::MyProc()].dataPtr();
auto* pout = out[ParallelDescriptor::MyProc()].dataPtr();
auto* pin = in[ParallelContext::MyProcSub()].dataPtr();
auto* pout = out[ParallelContext::MyProcSub()].dataPtr();
#else
amrex::ignore_unused(in,out);
#endif
Expand Down Expand Up @@ -666,8 +666,8 @@ void R2C<T,D>::exec_c2r (Plan plan, cMF& in, MF& out)
if (! plan.defined) { return; }

#if defined(AMREX_USE_GPU)
auto* pin = in[ParallelDescriptor::MyProc()].dataPtr();
auto* pout = out[ParallelDescriptor::MyProc()].dataPtr();
auto* pin = in[ParallelContext::MyProcSub()].dataPtr();
auto* pout = out[ParallelContext::MyProcSub()].dataPtr();
#else
amrex::ignore_unused(in,out);
#endif
Expand Down Expand Up @@ -699,7 +699,7 @@ void R2C<T,D>::exec_c2c (Plan2 plan, cMF& inout)

amrex::ignore_unused(inout);
#if defined(AMREX_USE_GPU)
auto* p = inout[ParallelDescriptor::MyProc()].dataPtr();
auto* p = inout[ParallelContext::MyProcSub()].dataPtr();
#endif

#if defined(AMREX_USE_CUDA)
Expand Down
6 changes: 4 additions & 2 deletions Src/FFT/AMReX_FFT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ namespace amrex::FFT::detail

// Build a DistributionMapping for n boxes in which box i is assigned to the
// process whose rank inside the current sub-communicator is i.
//
// n: number of boxes to map; must not exceed the number of processes in the
//    current sub-communicator (checked with AMREX_ASSERT).
// Returns a DistributionMapping constructed from the translated rank vector.
DistributionMapping make_iota_distromap (Long n)
{
    // Local ranks 0..n-1 are looked up below, so they must all exist in the
    // current sub-communicator.
    // NOTE(review): this bound presumably implies n <= ParallelDescriptor::NProcs()
    // as well, since a sub-communicator should not be larger than the full one.
    AMREX_ASSERT(n <= ParallelContext::NProcsSub());

    Vector<int> pm(n);
    // Every entry is written here, so no separate pre-fill of pm is needed.
    for (int i = 0; i < n; ++i) {
        pm[i] = ParallelContext::local_to_global_rank(i);
    }
    return DistributionMapping(std::move(pm));
}

Expand Down
2 changes: 1 addition & 1 deletion Src/FFT/AMReX_FFT_Poisson.H
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ void PoissonHybrid<MF>::solve (MF& soln, MF const& rhs)

Box cdomain = m_geom.Domain();
cdomain.setBig(0,cdomain.length(0)/2);
auto cba = amrex::decompose(cdomain, ParallelDescriptor::NProcs(),
auto cba = amrex::decompose(cdomain, ParallelContext::NProcsSub(),
{AMREX_D_DECL(true,true,false)});
DistributionMapping dm = detail::make_iota_distromap(cba.size());
FabArray<BaseFab<GpuComplex<T> > > spmf(cba, dm, 1, 0);
Expand Down
19 changes: 14 additions & 5 deletions Tests/FFT/Poisson/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,18 +117,27 @@ int main (int argc, char* argv[])
ParallelFor(res, [=] AMREX_GPU_DEVICE (int b, int i, int j, int k)
{
auto const& phia = phi_ma[b];
auto lap = AMREX_D_TERM
(((phia(i-1,j,k)-2.*phia(i,j,k)+phia(i+1,j,k)) / (dx[0]*dx[0])),
+ ((phia(i,j-1,k)-2.*phia(i,j,k)+phia(i,j+1,k)) / (dx[1]*dx[1])),
+ ((phia(i,j,k-1)-2.*phia(i,j,k)+phia(i,j,k+1)) / (dx[2]*dx[2])));
auto lap = (phia(i-1,j,k)-2._rt*phia(i,j,k)+phia(i+1,j,k)) / (dx[0]*dx[0]);
#if (AMREX_SPACEDIM >= 2)
lap += (phia(i,j-1,k)-2._rt*phia(i,j,k)+phia(i,j+1,k)) / (dx[1]*dx[1]);
#endif
#if (AMREX_SPACEDIM == 3)
if ((solver_type == 1) && (k == 0)) { // Neumann
lap += (-phia(i,j,k)+phia(i,j,k+1)) / (dx[2]*dx[2]);
} else if ((solver_type == 1) && ((k+1) == n_cell_z)) { // Neumann
lap += (phia(i,j,k-1)-phia(i,j,k)) / (dx[2]*dx[2]);
} else {
lap += (phia(i,j,k-1)-2._rt*phia(i,j,k)+phia(i,j,k+1)) / (dx[2]*dx[2]);
}
#endif
res_ma[b](i,j,k) = rhs_ma[b](i,j,k) - lap;
});
auto bnorm = rhs.norminf();
auto rnorm = res.norminf();
amrex::Print() << " rhs inf norm " << bnorm << "\n"
<< " res inf norm " << rnorm << "\n";
#ifdef AMREX_USE_FLOAT
auto eps = 1.e-3f;
auto eps = 2.e-3f;
#else
auto eps = 1.e-11;
#endif
Expand Down

0 comments on commit 21c4761

Please sign in to comment.