From 515e3cb156d327dbe9972c4ab6cde0f1367b02fb Mon Sep 17 00:00:00 2001 From: Weiqun Zhang Date: Sat, 16 Mar 2024 18:25:31 -0700 Subject: [PATCH] Curl Curl Solver: Option to use PCG instead of LU --- Src/LinearSolvers/CMakeLists.txt | 1 + Src/LinearSolvers/MLMG/AMReX_MLCurlCurl.H | 3 + Src/LinearSolvers/MLMG/AMReX_MLCurlCurl.cpp | 30 ++- Src/LinearSolvers/MLMG/AMReX_MLCurlCurl_K.H | 282 ++++++++++++-------- Src/LinearSolvers/MLMG/AMReX_PCGSolver.H | 72 +++++ Src/LinearSolvers/MLMG/Make.package | 2 +- Tests/LinearSolvers/CurlCurl/MyTest.H | 1 + Tests/LinearSolvers/CurlCurl/MyTest.cpp | 3 + 8 files changed, 277 insertions(+), 117 deletions(-) create mode 100644 Src/LinearSolvers/MLMG/AMReX_PCGSolver.H diff --git a/Src/LinearSolvers/CMakeLists.txt b/Src/LinearSolvers/CMakeLists.txt index cae0b2028f0..6287ef4b422 100644 --- a/Src/LinearSolvers/CMakeLists.txt +++ b/Src/LinearSolvers/CMakeLists.txt @@ -21,6 +21,7 @@ foreach(D IN LISTS AMReX_SPACEDIM) MLMG/AMReX_MLCellABecLap_K.H MLMG/AMReX_MLCellABecLap_${D}D_K.H MLMG/AMReX_MLCGSolver.H + MLMG/AMReX_PCGSolver.H MLMG/AMReX_MLABecLaplacian.H MLMG/AMReX_MLABecLap_K.H MLMG/AMReX_MLABecLap_${D}D_K.H diff --git a/Src/LinearSolvers/MLMG/AMReX_MLCurlCurl.H b/Src/LinearSolvers/MLMG/AMReX_MLCurlCurl.H index 8d461d3bb04..ce8859eae11 100644 --- a/Src/LinearSolvers/MLMG/AMReX_MLCurlCurl.H +++ b/Src/LinearSolvers/MLMG/AMReX_MLCurlCurl.H @@ -58,6 +58,8 @@ public: return std::string("curl of curl"); } + bool setUsePCG (bool flag) { return std::exchange(m_use_pcg, flag); } + void setLevelBC (int amrlev, const MF* levelbcdata, const MF* robinbc_a = nullptr, const MF* robinbc_b = nullptr, @@ -137,6 +139,7 @@ private: Vector>>>> m_lusolver; Vector,3>>> m_bcoefs; + bool m_use_pcg = false; }; } diff --git a/Src/LinearSolvers/MLMG/AMReX_MLCurlCurl.cpp b/Src/LinearSolvers/MLMG/AMReX_MLCurlCurl.cpp index 740d8d5cc00..4167920f57b 100644 --- a/Src/LinearSolvers/MLMG/AMReX_MLCurlCurl.cpp +++ b/Src/LinearSolvers/MLMG/AMReX_MLCurlCurl.cpp @@ -353,22 +353,36 @@ void MLCurlCurl::smooth4 (int amrlev, int mglev, MF& sol, MF const& rhs, auto* plusolver = m_lusolver[amrlev][mglev]->dataPtr(); ParallelFor(nmf, [=] AMREX_GPU_DEVICE (int bno, int i, int j, int k) { - mlcurlcurl_gs4(i,j,k,ex[bno],ey[bno],ez[bno],rhsx[bno],rhsy[bno],rhsz[bno], + mlcurlcurl_gs4_lu(i,j,k,ex[bno],ey[bno],ez[bno], + rhsx[bno],rhsy[bno],rhsz[bno], #if (AMREX_SPACEDIM == 2) - b, + b, #endif - adxinv,color,*plusolver,dinfo,sinfo); + adxinv,color,*plusolver,dinfo,sinfo); }); } else { auto const& bcx = m_bcoefs[amrlev][mglev][0]->const_arrays(); auto const& bcy = m_bcoefs[amrlev][mglev][1]->const_arrays(); auto const& bcz = m_bcoefs[amrlev][mglev][2]->const_arrays(); - ParallelFor(nmf, [=] AMREX_GPU_DEVICE (int bno, int i, int j, int k) - { + if (m_use_pcg) { + ParallelFor(nmf, [=] AMREX_GPU_DEVICE (int bno, int i, int j, int k) + { - mlcurlcurl_gs4(i,j,k,ex[bno],ey[bno],ez[bno],rhsx[bno],rhsy[bno],rhsz[bno], - adxinv,color,bcx[bno],bcy[bno],bcz[bno],dinfo,sinfo); - }); + mlcurlcurl_gs4(i,j,k,ex[bno],ey[bno],ez[bno], + rhsx[bno],rhsy[bno],rhsz[bno], + adxinv,color,bcx[bno],bcy[bno],bcz[bno], + dinfo,sinfo); + }); + } else { + ParallelFor(nmf, [=] AMREX_GPU_DEVICE (int bno, int i, int j, int k) + { + + mlcurlcurl_gs4(i,j,k,ex[bno],ey[bno],ez[bno], + rhsx[bno],rhsy[bno],rhsz[bno], + adxinv,color,bcx[bno],bcy[bno],bcz[bno], + dinfo,sinfo); + }); + } } Gpu::streamSynchronize(); } diff --git a/Src/LinearSolvers/MLMG/AMReX_MLCurlCurl_K.H b/Src/LinearSolvers/MLMG/AMReX_MLCurlCurl_K.H index 0c1118f7dd3..e243b245f51 100644 --- a/Src/LinearSolvers/MLMG/AMReX_MLCurlCurl_K.H +++ b/Src/LinearSolvers/MLMG/AMReX_MLCurlCurl_K.H @@ -4,6 +4,7 @@ #include #include +#include namespace amrex { @@ -427,20 +428,20 @@ void mlcurlcurl_adotx_z (int i, int j, int k, Array4 const& Az, } AMREX_GPU_DEVICE AMREX_FORCE_INLINE -void mlcurlcurl_gs4 (int i, int j, int k, - Array4 const& ex, - Array4 const& ey, - Array4 const& ez, - Array4 const& rhsx, - Array4 const& rhsy, - Array4 const& rhsz, +void mlcurlcurl_gs4_lu (int i, int j, int k, + Array4 const& ex, + Array4 const& ey, + Array4 const& ez, + Array4 const& rhsx, + Array4 const& rhsy, + Array4 const& rhsz, #if (AMREX_SPACEDIM == 2) - Real beta, + Real beta, #endif - GpuArray const& adxinv, - int color, LUSolver const& lusolver, - CurlCurlDirichletInfo const& dinfo, - CurlCurlSymmetryInfo const& sinfo) + GpuArray const& adxinv, + int color, LUSolver const& lusolver, + CurlCurlDirichletInfo const& dinfo, + CurlCurlSymmetryInfo const& sinfo) { if (dinfo.is_dirichlet_node(i,j,k)) { return; } @@ -598,6 +599,7 @@ void mlcurlcurl_gs4 (int i, int j, int k, #endif } +template AMREX_GPU_DEVICE AMREX_FORCE_INLINE void mlcurlcurl_gs4 (int i, int j, int k, Array4 const& ex, @@ -661,55 +663,81 @@ void mlcurlcurl_gs4 (int i, int j, int k, + dxy * (-ex(i-1,j+1,k ) +ex(i ,j+1,k )))}; - GpuArray x; + GpuArray beta; if (sinfo.xlo_is_symmetric(i)) { b[0] = -b[1]; - x[0] = x[1] = betax(i,j,k); + beta[0] = beta[1] = betax(i,j,k); } else if (sinfo.xhi_is_symmetric(i)) { b[1] = -b[0]; - x[0] = x[1] = betax(i-1,j,k); + beta[0] = beta[1] = betax(i-1,j,k); } else { - x[0] = betax(i-1,j,k); - x[1] = betax(i ,j,k); + beta[0] = betax(i-1,j,k); + beta[1] = betax(i ,j,k); } if (sinfo.ylo_is_symmetric(j)) { b[2] = -b[3]; - x[2] = x[3] = betay(i,j,k); + beta[2] = beta[3] = betay(i,j,k); } else if (sinfo.yhi_is_symmetric(j)) { b[3] = -b[2]; - x[2] = x[3] = betay(i,j-1,k); + beta[2] = beta[3] = betay(i,j-1,k); } else { - x[2] = betay(i,j-1,k); - x[3] = betay(i,j ,k); + beta[2] = betay(i,j-1,k); + beta[3] = betay(i,j ,k); } - LUSolver<4,Real> lusolver - ({dyy*Real(2.0) + x[0], - Real(0.0), - -dxy, - dxy, - // - Real(0.0), - dyy*Real(2.0) + x[1], - dxy, - -dxy, - // - -dxy, - dxy, - dxx*Real(2.0) + x[2], - Real(0.0), - // - dxy, - -dxy, - Real(0.0), - dxx*Real(2.0) + x[3]}); - lusolver(x.data(), b.data()); - ex(i-1,j ,k ) = x[0]; - ex(i ,j ,k ) = x[1]; - ey(i ,j-1,k ) = x[2]; - ey(i ,j ,k ) = x[3]; + if constexpr (PCG) { + Real diagInv[4] = {Real(1.0) / (dyy*Real(2.0) + beta[0]), + Real(1.0) / (dyy*Real(2.0) + beta[1]), + Real(1.0) / (dxx*Real(2.0) + beta[2]), + Real(1.0) / (dxx*Real(2.0) + beta[3])}; + auto precond = [&] (Real * AMREX_RESTRICT z, + Real const* AMREX_RESTRICT r) + { + for (int m = 0; m < 4; ++m) { z[m] = r[m] * diagInv[m]; } + }; + auto mat = [&] (Real * AMREX_RESTRICT Av, + Real const* AMREX_RESTRICT v) + { + Av[0] = (dyy*Real(2.0) + beta[0]) * v[0] - dxy * v[2] + dxy * v[3]; + Av[1] = (dyy*Real(2.0) + beta[1]) * v[1] + dxy * v[2] - dxy * v[3]; + Av[2] = -dxy * v[0] + dxy * v[1] + (dxx*Real(2.0) + beta[2]) * v[2]; + Av[3] = dxy * v[0] - dxy * v[1] + (dxx*Real(2.0) + beta[3]) * v[3]; + }; + Real sol[4] = {0, 0, 0, 0}; + pcg_solve<4>(sol, b.data(), mat, precond, 8, Real(1.e-8)); + ex(i-1,j ,k ) = sol[0]; + ex(i ,j ,k ) = sol[1]; + ey(i ,j-1,k ) = sol[2]; + ey(i ,j ,k ) = sol[3]; + } else { + LUSolver<4,Real> lusolver + ({dyy*Real(2.0) + beta[0], + Real(0.0), + -dxy, + dxy, + // + Real(0.0), + dyy*Real(2.0) + beta[1], + dxy, + -dxy, + // + -dxy, + dxy, + dxx*Real(2.0) + beta[2], + Real(0.0), + // + dxy, + -dxy, + Real(0.0), + dxx*Real(2.0) + beta[3]}); + lusolver(beta.data(), b.data()); + ex(i-1,j ,k ) = beta[0]; + ex(i ,j ,k ) = beta[1]; + ey(i ,j-1,k ) = beta[2]; + ey(i ,j ,k ) = beta[3]; + } #else @@ -772,90 +800,128 @@ void mlcurlcurl_gs4 (int i, int j, int k, + dyz * (-ey(i ,j-1,k+1) +ey(i ,j ,k+1)))}; - GpuArray x; + GpuArray beta; if (sinfo.xlo_is_symmetric(i)) { b[0] = -b[1]; - x[0] = x[1] = betax(i,j,k); + beta[0] = beta[1] = betax(i,j,k); } else if (sinfo.xhi_is_symmetric(i)) { b[1] = -b[0]; - x[0] = x[1] = betax(i-1,j,k); + beta[0] = beta[1] = betax(i-1,j,k); } else { - x[0] = betax(i-1,j,k); - x[1] = betax(i ,j,k); + beta[0] = betax(i-1,j,k); + beta[1] = betax(i ,j,k); } if (sinfo.ylo_is_symmetric(j)) { b[2] = -b[3]; - x[2] = x[3] = betay(i,j,k); + beta[2] = beta[3] = betay(i,j,k); } else if (sinfo.yhi_is_symmetric(j)) { b[3] = -b[2]; - x[2] = x[3] = betay(i,j-1,k); + beta[2] = beta[3] = betay(i,j-1,k); } else { - x[2] = betay(i,j-1,k); - x[3] = betay(i,j ,k); + beta[2] = betay(i,j-1,k); + beta[3] = betay(i,j ,k); } if (sinfo.zlo_is_symmetric(k)) { b[4] = -b[5]; - x[4] = x[5] = betaz(i,j,k); + beta[4] = beta[5] = betaz(i,j,k); } else if (sinfo.zhi_is_symmetric(k)) { b[5] = -b[4]; - x[4] = x[5] = betaz(i,j,k-1); + beta[4] = beta[5] = betaz(i,j,k-1); } else { - x[4] = betaz(i,j,k-1); - x[5] = betaz(i,j,k ); + beta[4] = betaz(i,j,k-1); + beta[5] = betaz(i,j,k ); } - LUSolver<6,Real> lusolver - ({(dyy+dzz)*Real(2.0) + x[0], - Real(0.0), - -dxy, - dxy, - -dxz, - dxz, - // - Real(0.0), - (dyy+dzz)*Real(2.0) + x[1], - dxy, - -dxy, - dxz, - -dxz, - // - -dxy, - dxy, - (dxx+dzz)*Real(2.0) + x[2], - Real(0.0), - -dyz, - dyz, - // - dxy, - -dxy, - Real(0.0), - (dxx+dzz)*Real(2.0) + x[3], - dyz, - -dyz, - // - -dxz, - dxz, - -dyz, - dyz, - (dxx+dyy)*Real(2.0) + x[4], - Real(0.0), - // - dxz, - -dxz, - dyz, - -dyz, - Real(0.0), - (dxx+dyy)*Real(2.0) + x[5]}); - lusolver(x.data(), b.data()); - ex(i-1,j ,k ) = x[0]; - ex(i ,j ,k ) = x[1]; - ey(i ,j-1,k ) = x[2]; - ey(i ,j ,k ) = x[3]; - ez(i ,j ,k-1) = x[4]; - ez(i ,j ,k ) = x[5]; + if constexpr (PCG) { + Real diagInv[6] = {Real(1.0) / ((dyy+dzz)*Real(2.0) + beta[0]), + Real(1.0) / ((dyy+dzz)*Real(2.0) + beta[1]), + Real(1.0) / ((dxx+dzz)*Real(2.0) + beta[2]), + Real(1.0) / ((dxx+dzz)*Real(2.0) + beta[3]), + Real(1.0) / ((dxx+dyy)*Real(2.0) + beta[4]), + Real(1.0) / ((dxx+dyy)*Real(2.0) + beta[5])}; + auto precond = [&] (Real * AMREX_RESTRICT z, + Real const* AMREX_RESTRICT r) + { + for (int m = 0; m < 6; ++m) { z[m] = r[m] * diagInv[m]; } + }; + auto mat = [&] (Real * AMREX_RESTRICT Av, + Real const* AMREX_RESTRICT v) + { + Av[0] = ((dyy+dzz)*Real(2.0) + beta[0]) * v[0] - dxy * v[2] + + dxy * v[3] - dxz * v[4] + dxz * v[5]; + Av[1] = ((dyy+dzz)*Real(2.0) + beta[1]) * v[1] + dxy * v[2] + - dxy * v[3] + dxz * v[4] - dxz * v[5]; + Av[2] = -dxy * v[0] + dxy * v[1] + ((dxx+dzz)*Real(2.0) + beta[2]) * v[2] + - dyz * v[4] + dyz * v[5]; + Av[3] = dxy * v[0] - dxy * v[1] + ((dxx+dzz)*Real(2.0) + beta[3]) * v[3] + + dyz * v[4] - dyz * v[5]; + Av[4] = -dxz * v[0] + dxz * v[1] - dyz * v[2] + dyz * v[3] + + ((dxx+dyy)*Real(2.0) + beta[4]) * v[4]; + Av[5] = dxz * v[0] - dxz * v[1] + dyz * v[2] - dyz * v[3] + + ((dxx+dyy)*Real(2.0) + beta[5]) * v[5]; + }; + Real sol[6] = {0, 0, 0, 0, 0, 0}; + pcg_solve<6>(sol, b.data(), mat, precond, 8, Real(1.e-8)); + ex(i-1,j ,k ) = sol[0]; + ex(i ,j ,k ) = sol[1]; + ey(i ,j-1,k ) = sol[2]; + ey(i ,j ,k ) = sol[3]; + ez(i ,j ,k-1) = sol[4]; + ez(i ,j ,k ) = sol[5]; + } else { + LUSolver<6,Real> lusolver + ({(dyy+dzz)*Real(2.0) + beta[0], + Real(0.0), + -dxy, + dxy, + -dxz, + dxz, + // + Real(0.0), + (dyy+dzz)*Real(2.0) + beta[1], + dxy, + -dxy, + dxz, + -dxz, + // + -dxy, + dxy, + (dxx+dzz)*Real(2.0) + beta[2], + Real(0.0), + -dyz, + dyz, + // + dxy, + -dxy, + Real(0.0), + (dxx+dzz)*Real(2.0) + beta[3], + dyz, + -dyz, + // + -dxz, + dxz, + -dyz, + dyz, + (dxx+dyy)*Real(2.0) + beta[4], + Real(0.0), + // + dxz, + -dxz, + dyz, + -dyz, + Real(0.0), + (dxx+dyy)*Real(2.0) + beta[5]}); + lusolver(beta.data(), b.data()); + ex(i-1,j ,k ) = beta[0]; + ex(i ,j ,k ) = beta[1]; + ey(i ,j-1,k ) = beta[2]; + ey(i ,j ,k ) = beta[3]; + ez(i ,j ,k-1) = beta[4]; + ez(i ,j ,k ) = beta[5]; + } #endif } diff --git a/Src/LinearSolvers/MLMG/AMReX_PCGSolver.H b/Src/LinearSolvers/MLMG/AMReX_PCGSolver.H new file mode 100644 index 00000000000..de3a6cdede7 --- /dev/null +++ b/Src/LinearSolvers/MLMG/AMReX_PCGSolver.H @@ -0,0 +1,72 @@ +#ifndef AMREX_PCG_SOLVER_H_ +#define AMREX_PCG_SOLVER_H_ +#include + +#include +#include +#include +#include + +namespace amrex { + +/** + * \brief Preconditioned conjugate gradient solver + * + * \param x initial guess + * \param r initial residual + * \param mat matrix + * \param precond preconditioner + * \param maxiter max number of iterations + * \param rel_tol relative tolerance + */ +template +AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE +int pcg_solve (T* AMREX_RESTRICT x, T* AMREX_RESTRICT r, + M const& mat, P const& precond, int maxiter, T rel_tol) +{ + static_assert(std::is_floating_point_v); + + T rnorm0 = 0; + for (int i = 0; i < N; ++i) { + rnorm0 = std::max(rnorm0, std::abs(r[i])); + } + if (rnorm0 == 0) { return 0; } + + int iter = 0; + T rho_prev = T(1.0); // initialized to quite gcc warning + T p[N]; + for (iter = 1; iter <= maxiter; ++iter) { + T z[N]; + precond(z, r); + T rho = 0; + for (int i = 0; i < N; ++i) { rho += r[i]*z[i]; } + if (rho == 0) { break; } + if (iter == 1) { + for (int i = 0; i < N; ++i) { p[i] = z[i]; } + } else { + auto rr = rho * (T(1.0)/rho_prev); + for (int i = 0; i < N; ++i) { + p[i] = z[i] + rr * p[i]; + } + } + T q[N]; + mat(q, p); + T pq = 0; + for (int i = 0; i < N; ++i) { pq += p[i]*q[i]; } + if (pq == 0) { break; } + T alpha = rho * (T(1.0)/pq); + T rnorm = 0; + for (int i = 0; i < N; ++i) { + x[i] += alpha * p[i]; + r[i] -= alpha * q[i]; + rnorm = std::max(rnorm, std::abs(r[i])); + } + if (rnorm <= rnorm0*rel_tol) { break; } + rho_prev = rho; + } + return iter; +} + +} + +#endif diff --git a/Src/LinearSolvers/MLMG/Make.package b/Src/LinearSolvers/MLMG/Make.package index a8f267d4c26..9496f3edc5a 100644 --- a/Src/LinearSolvers/MLMG/Make.package +++ b/Src/LinearSolvers/MLMG/Make.package @@ -22,7 +22,7 @@ CEXE_sources += AMReX_MLNodeLinOp.cpp CEXE_headers += AMReX_MLCellABecLap.H CEXE_headers += AMReX_MLCellABecLap_K.H AMReX_MLCellABecLap_$(DIM)D_K.H -CEXE_headers += AMReX_MLCGSolver.H +CEXE_headers += AMReX_MLCGSolver.H AMReX_PCGSolver.H CEXE_headers += AMReX_MLABecLaplacian.H CEXE_headers += AMReX_MLABecLap_K.H AMReX_MLABecLap_$(DIM)D_K.H diff --git a/Tests/LinearSolvers/CurlCurl/MyTest.H b/Tests/LinearSolvers/CurlCurl/MyTest.H index 73b260b470e..618c9689fb4 100644 --- a/Tests/LinearSolvers/CurlCurl/MyTest.H +++ b/Tests/LinearSolvers/CurlCurl/MyTest.H @@ -30,6 +30,7 @@ private: bool consolidation = true; int max_coarsening_level = 30; + bool use_pcg = false; bool use_gmres = false; bool gmres_use_precond = true; diff --git a/Tests/LinearSolvers/CurlCurl/MyTest.cpp b/Tests/LinearSolvers/CurlCurl/MyTest.cpp index af0544db396..0daf07cf3ef 100644 --- a/Tests/LinearSolvers/CurlCurl/MyTest.cpp +++ b/Tests/LinearSolvers/CurlCurl/MyTest.cpp @@ -46,6 +46,8 @@ MyTest::solve () } mlcc.prepareRHS({&rhs}); + if (use_pcg) { mlcc.setUsePCG(true); } + using V = Array; MLMGT mlmg(mlcc); mlmg.setMaxIter(max_iter); @@ -104,6 +106,7 @@ MyTest::readParameters () pp.query("consolidation", consolidation); pp.query("max_coarsening_level", max_coarsening_level); + pp.query("use_pcg", use_pcg); pp.query("use_gmres", use_gmres); pp.query("gmres_use_precond", gmres_use_precond);