From 3c8764fb1d93303956acd44d37e11afa83e199dc Mon Sep 17 00:00:00 2001 From: Haozhi Han Date: Sat, 22 Jun 2024 21:18:19 +0800 Subject: [PATCH] Refactor: remove DiagoIterAssist::diagH_subspace from dav-subspace (#4470) * remove DiagoIterAssist::diagH_subspace from dav-subspace * [pre-commit.ci lite] apply automatic fixes --------- Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com> --- source/module_hsolver/diago_dav_subspace.cpp | 18 +-- source/module_hsolver/diago_dav_subspace.h | 15 +- source/module_hsolver/hsolver_pw.cpp | 157 +++++++++---------- 3 files changed, 92 insertions(+), 98 deletions(-) diff --git a/source/module_hsolver/diago_dav_subspace.cpp b/source/module_hsolver/diago_dav_subspace.cpp index 83ad2daf9b..4cca1365a1 100644 --- a/source/module_hsolver/diago_dav_subspace.cpp +++ b/source/module_hsolver/diago_dav_subspace.cpp @@ -83,7 +83,7 @@ Diago_DavSubspace::~Diago_DavSubspace() } template -int Diago_DavSubspace::diag_once(const Func& hpsi_func, +int Diago_DavSubspace::diag_once(const HPsiFunc& hpsi_func, T* psi_in, const int psi_in_dmax, Real* eigenvalue_in_hsolver, @@ -254,7 +254,7 @@ int Diago_DavSubspace::diag_once(const Func& hpsi_func, } } - } while (1); + } while (true); ModuleBase::timer::tick("Diago_DavSubspace", "diag_once"); @@ -262,7 +262,7 @@ int Diago_DavSubspace::diag_once(const Func& hpsi_func, } template -void Diago_DavSubspace::cal_grad(const Func& hpsi_func, +void Diago_DavSubspace::cal_grad(const HPsiFunc& hpsi_func, const int& dim, const int& nbase, const int& notconv, @@ -728,12 +728,10 @@ void Diago_DavSubspace::refresh(const int& dim, } template -int Diago_DavSubspace::diag(const Func& hpsi_func, +int Diago_DavSubspace::diag(const HPsiFunc& hpsi_func, + const SubspaceFunc& subspace_func, T* psi_in, - - hamilt::Hamilt* phm_in, - psi::Psi& psi, - + const int psi_in_dmax, Real* eigenvalue_in_hsolver, const std::vector& is_occupied, const bool& scf_type) @@ -790,10 +788,10 @@ int Diago_DavSubspace::diag(const Func& hpsi_func, { if (this->is_subspace || ntry > 0) { - DiagoIterAssist::diagH_subspace(phm_in, psi, psi, eigenvalue_in_hsolver, psi.get_nbands()); + subspace_func(psi_in, psi_in, eigenvalue_in_hsolver, this->n_band, psi_in_dmax); } - sum_iter += this->diag_once(hpsi_func, psi_in, psi.get_nbasis(), eigenvalue_in_hsolver, is_occupied); + sum_iter += this->diag_once(hpsi_func, psi_in, psi_in_dmax, eigenvalue_in_hsolver, is_occupied); ++ntry; diff --git a/source/module_hsolver/diago_dav_subspace.h b/source/module_hsolver/diago_dav_subspace.h index 880962827c..00ce15403b 100644 --- a/source/module_hsolver/diago_dav_subspace.h +++ b/source/module_hsolver/diago_dav_subspace.h @@ -29,14 +29,13 @@ class Diago_DavSubspace : public DiagH virtual ~Diago_DavSubspace() override; - using Func = std::function; + using HPsiFunc = std::function; + using SubspaceFunc = std::function; - int diag(const Func& hpsi_func, + int diag(const HPsiFunc& hpsi_func, + const SubspaceFunc& subspace_func, T* psi_in, - - hamilt::Hamilt* phm_in, - psi::Psi& phi, - + const int psi_in_dmax, Real* eigenvalue_in, const std::vector& is_occupied, const bool& scf_type); @@ -89,7 +88,7 @@ class Diago_DavSubspace : public DiagH base_device::DEVICE_CPU* cpu_ctx = {}; base_device::AbacusDevice_t device = {}; - void cal_grad(const Func& hpsi_func, + void cal_grad(const HPsiFunc& hpsi_func, const int& dim, const int& nbase, const int& notconv, @@ -121,7 +120,7 @@ class Diago_DavSubspace : public DiagH bool init, bool is_subspace); - int diag_once(const Func& hpsi_func, + int diag_once(const HPsiFunc& hpsi_func, T* psi_in, const int psi_in_dmax, Real* eigenvalue_in, diff --git a/source/module_hsolver/hsolver_pw.cpp b/source/module_hsolver/hsolver_pw.cpp index ff4af28080..1fbe846d22 100644 --- a/source/module_hsolver/hsolver_pw.cpp +++ b/source/module_hsolver/hsolver_pw.cpp @@ -86,33 +86,34 @@ void HSolverPW::initDiagh(const psi::Psi& psi) } else if (this->method == "dav") { -// #ifdef __MPI -// const diag_comm_info comm_info = {POOL_WORLD, GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL}; -// #else -// const diag_comm_info comm_info = {GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL}; -// #endif - -// if (this->pdiagh != nullptr) -// { -// if (this->pdiagh->method != this->method) -// { -// delete (DiagoDavid*)this->pdiagh; - -// this->pdiagh = new DiagoDavid(precondition.data(), -// GlobalV::PW_DIAG_NDIM, -// GlobalV::use_paw, -// comm_info); - -// this->pdiagh->method = this->method; -// } -// } -// else -// { -// this->pdiagh -// = new DiagoDavid(precondition.data(), GlobalV::PW_DIAG_NDIM, GlobalV::use_paw, comm_info); - -// this->pdiagh->method = this->method; -// } + // #ifdef __MPI + // const diag_comm_info comm_info = {POOL_WORLD, GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL}; + // #else + // const diag_comm_info comm_info = {GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL}; + // #endif + + // if (this->pdiagh != nullptr) + // { + // if (this->pdiagh->method != this->method) + // { + // delete (DiagoDavid*)this->pdiagh; + + // this->pdiagh = new DiagoDavid(precondition.data(), + // GlobalV::PW_DIAG_NDIM, + // GlobalV::use_paw, + // comm_info); + + // this->pdiagh->method = this->method; + // } + // } + // else + // { + // this->pdiagh + // = new DiagoDavid(precondition.data(), GlobalV::PW_DIAG_NDIM, GlobalV::use_paw, + // comm_info); + + // this->pdiagh->method = this->method; + // } } else if (this->method == "dav_subspace") { @@ -731,15 +732,15 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, psi::P REQUIRES_OK(ndim == 2, "dims of psi_in should be less than or equal to 2"); // Convert a Tensor object to a psi::Psi object auto psi_in_wrapper = psi::Psi(psi_in.data(), - 1, - psi_in.shape().dim_size(0), - psi_in.shape().dim_size(1), - ngk_pointer); + 1, + psi_in.shape().dim_size(0), + psi_in.shape().dim_size(1), + ngk_pointer); auto psi_out_wrapper = psi::Psi(psi_out.data(), - 1, - psi_out.shape().dim_size(0), - psi_out.shape().dim_size(1), - ngk_pointer); + 1, + psi_out.shape().dim_size(0), + psi_out.shape().dim_size(1), + ngk_pointer); auto eigen = ct::Tensor(ct::DataTypeToEnum::value, ct::DeviceType::CpuDevice, ct::TensorShape({psi_in.shape().dim_size(0)})); @@ -747,12 +748,12 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, psi::P DiagoIterAssist::diagH_subspace(hamilt_, psi_in_wrapper, psi_out_wrapper, eigen.data()); }; DiagoCG cg(GlobalV::BASIS_TYPE, - GlobalV::CALCULATION, - DiagoIterAssist::need_subspace, - subspace_func, - DiagoIterAssist::PW_DIAG_THR, - DiagoIterAssist::PW_DIAG_NMAX, - GlobalV::NPROC_IN_POOL); + GlobalV::CALCULATION, + DiagoIterAssist::need_subspace, + subspace_func, + DiagoIterAssist::PW_DIAG_THR, + DiagoIterAssist::PW_DIAG_NMAX, + GlobalV::NPROC_IN_POOL); // warp the hpsi_func and spsi_func into a lambda function using ct_Device = typename ct::PsiToContainer::type; @@ -833,18 +834,14 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, psi::P const diag_comm_info comm_info = {GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL}; #endif Diago_DavSubspace dav_subspace(this->precondition, - - psi.get_nbands(), - psi.get_k_first() ? psi.get_current_nbas() - : psi.get_nk() * psi.get_nbasis(), - - GlobalV::PW_DIAG_NDIM, - DiagoIterAssist::PW_DIAG_THR, - DiagoIterAssist::PW_DIAG_NMAX, - DiagoIterAssist::need_subspace, - comm_info); - - // this->pdiagh->method = this->method; + psi.get_nbands(), + psi.get_k_first() ? psi.get_current_nbas() + : psi.get_nk() * psi.get_nbasis(), + GlobalV::PW_DIAG_NDIM, + DiagoIterAssist::PW_DIAG_THR, + DiagoIterAssist::PW_DIAG_NMAX, + DiagoIterAssist::need_subspace, + comm_info); bool scf; if (GlobalV::CALCULATION == "nscf") @@ -858,14 +855,12 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, psi::P auto ngk_pointer = psi.get_ngk_pointer(); - std::function hpsi_func = [hm, ngk_pointer]( - T* hpsi_out, - T* psi_in, - const int nband_in, - const int nbasis_in, - const int band_index1, - const int band_index2) - { + auto hpsi_func = [hm, ngk_pointer](T* hpsi_out, + T* psi_in, + const int nband_in, + const int nbasis_in, + const int band_index1, + const int band_index2) { ModuleBase::timer::tick("DavSubspace", "hpsi_func"); // Convert "pointer data stucture" to a psi::Psi object @@ -880,20 +875,26 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, psi::P ModuleBase::timer::tick("DavSubspace", "hpsi_func"); }; + auto subspace_func = [hm, ngk_pointer](T* psi_out, + T* psi_in, + Real* eigenvalue_in_hsolver, + const int nband_in, + const int nbasis_max_in) { + // Convert "pointer data stucture" to a psi::Psi object + auto psi_in_wrapper = psi::Psi(psi_in, 1, nband_in, nbasis_max_in, ngk_pointer); + auto psi_out_wrapper = psi::Psi(psi_out, 1, nband_in, nbasis_max_in, ngk_pointer); + + DiagoIterAssist::diagH_subspace(hm, + psi_in_wrapper, + psi_out_wrapper, + eigenvalue_in_hsolver, + nband_in); + }; - DiagoIterAssist::avg_iter - += static_cast(dav_subspace.diag( - - hpsi_func, - psi.get_pointer(), - - hm, - psi, - eigenvalue, - is_occupied, - scf)); + DiagoIterAssist::avg_iter += static_cast( + dav_subspace + .diag(hpsi_func, subspace_func, psi.get_pointer(), psi.get_nbasis(), eigenvalue, is_occupied, scf)); - // delete reinterpret_cast*>(this->pdiagh); this->pdiagh = nullptr; } else if (this->method == "bpcg") @@ -918,14 +919,10 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, psi::P // do diag and add davidson iteration counts up to avg_iter const Real david_diag_thr = DiagoIterAssist::PW_DIAG_THR; const int david_maxiter = DiagoIterAssist::PW_DIAG_NMAX; - - DiagoDavid david(precondition.data(), - GlobalV::PW_DIAG_NDIM, - GlobalV::use_paw, - comm_info); + + DiagoDavid david(precondition.data(), GlobalV::PW_DIAG_NDIM, GlobalV::use_paw, comm_info); DiagoIterAssist::avg_iter += static_cast( - david.diag(hm, psi, eigenvalue, david_diag_thr, david_maxiter, ntry_max, notconv_max) - ); + david.diag(hm, psi, eigenvalue, david_diag_thr, david_maxiter, ntry_max, notconv_max)); } return; }