Skip to content

Commit

Permalink
Refactor: remove DiagoIterAssist<T, Device>::diagH_subspace from dav-…
Browse files Browse the repository at this point in the history
…subspace (deepmodeling#4470)

* remove DiagoIterAssist<T, Device>::diagH_subspace from dav-subspace

* [pre-commit.ci lite] apply automatic fixes

---------

Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
  • Loading branch information
haozhihan and pre-commit-ci-lite[bot] authored Jun 22, 2024
1 parent 74c1664 commit 3c8764f
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 98 deletions.
18 changes: 8 additions & 10 deletions source/module_hsolver/diago_dav_subspace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ Diago_DavSubspace<T, Device>::~Diago_DavSubspace()
}

template <typename T, typename Device>
int Diago_DavSubspace<T, Device>::diag_once(const Func& hpsi_func,
int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
T* psi_in,
const int psi_in_dmax,
Real* eigenvalue_in_hsolver,
Expand Down Expand Up @@ -254,15 +254,15 @@ int Diago_DavSubspace<T, Device>::diag_once(const Func& hpsi_func,
}
}

} while (1);
} while (true);

ModuleBase::timer::tick("Diago_DavSubspace", "diag_once");

return dav_iter;
}

template <typename T, typename Device>
void Diago_DavSubspace<T, Device>::cal_grad(const Func& hpsi_func,
void Diago_DavSubspace<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
const int& dim,
const int& nbase,
const int& notconv,
Expand Down Expand Up @@ -728,12 +728,10 @@ void Diago_DavSubspace<T, Device>::refresh(const int& dim,
}

template <typename T, typename Device>
int Diago_DavSubspace<T, Device>::diag(const Func& hpsi_func,
int Diago_DavSubspace<T, Device>::diag(const HPsiFunc& hpsi_func,
const SubspaceFunc& subspace_func,
T* psi_in,

hamilt::Hamilt<T, Device>* phm_in,
psi::Psi<T, Device>& psi,

const int psi_in_dmax,
Real* eigenvalue_in_hsolver,
const std::vector<bool>& is_occupied,
const bool& scf_type)
Expand Down Expand Up @@ -790,10 +788,10 @@ int Diago_DavSubspace<T, Device>::diag(const Func& hpsi_func,
{
if (this->is_subspace || ntry > 0)
{
DiagoIterAssist<T, Device>::diagH_subspace(phm_in, psi, psi, eigenvalue_in_hsolver, psi.get_nbands());
subspace_func(psi_in, psi_in, eigenvalue_in_hsolver, this->n_band, psi_in_dmax);
}

sum_iter += this->diag_once(hpsi_func, psi_in, psi.get_nbasis(), eigenvalue_in_hsolver, is_occupied);
sum_iter += this->diag_once(hpsi_func, psi_in, psi_in_dmax, eigenvalue_in_hsolver, is_occupied);

++ntry;

Expand Down
15 changes: 7 additions & 8 deletions source/module_hsolver/diago_dav_subspace.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,13 @@ class Diago_DavSubspace : public DiagH<T, Device>

virtual ~Diago_DavSubspace() override;

using Func = std::function<void(T*, T*, const int, const int, const int, const int)>;
using HPsiFunc = std::function<void(T*, T*, const int, const int, const int, const int)>;
using SubspaceFunc = std::function<void(T*, T*, Real*, const int, const int)>;

int diag(const Func& hpsi_func,
int diag(const HPsiFunc& hpsi_func,
const SubspaceFunc& subspace_func,
T* psi_in,

hamilt::Hamilt<T, Device>* phm_in,
psi::Psi<T, Device>& phi,

const int psi_in_dmax,
Real* eigenvalue_in,
const std::vector<bool>& is_occupied,
const bool& scf_type);
Expand Down Expand Up @@ -89,7 +88,7 @@ class Diago_DavSubspace : public DiagH<T, Device>
base_device::DEVICE_CPU* cpu_ctx = {};
base_device::AbacusDevice_t device = {};

void cal_grad(const Func& hpsi_func,
void cal_grad(const HPsiFunc& hpsi_func,
const int& dim,
const int& nbase,
const int& notconv,
Expand Down Expand Up @@ -121,7 +120,7 @@ class Diago_DavSubspace : public DiagH<T, Device>
bool init,
bool is_subspace);

int diag_once(const Func& hpsi_func,
int diag_once(const HPsiFunc& hpsi_func,
T* psi_in,
const int psi_in_dmax,
Real* eigenvalue_in,
Expand Down
157 changes: 77 additions & 80 deletions source/module_hsolver/hsolver_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,33 +86,34 @@ void HSolverPW<T, Device>::initDiagh(const psi::Psi<T, Device>& psi)
}
else if (this->method == "dav")
{
// #ifdef __MPI
// const diag_comm_info comm_info = {POOL_WORLD, GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL};
// #else
// const diag_comm_info comm_info = {GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL};
// #endif

// if (this->pdiagh != nullptr)
// {
// if (this->pdiagh->method != this->method)
// {
// delete (DiagoDavid<T, Device>*)this->pdiagh;

// this->pdiagh = new DiagoDavid<T, Device>(precondition.data(),
// GlobalV::PW_DIAG_NDIM,
// GlobalV::use_paw,
// comm_info);

// this->pdiagh->method = this->method;
// }
// }
// else
// {
// this->pdiagh
// = new DiagoDavid<T, Device>(precondition.data(), GlobalV::PW_DIAG_NDIM, GlobalV::use_paw, comm_info);

// this->pdiagh->method = this->method;
// }
// #ifdef __MPI
// const diag_comm_info comm_info = {POOL_WORLD, GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL};
// #else
// const diag_comm_info comm_info = {GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL};
// #endif

// if (this->pdiagh != nullptr)
// {
// if (this->pdiagh->method != this->method)
// {
// delete (DiagoDavid<T, Device>*)this->pdiagh;

// this->pdiagh = new DiagoDavid<T, Device>(precondition.data(),
// GlobalV::PW_DIAG_NDIM,
// GlobalV::use_paw,
// comm_info);

// this->pdiagh->method = this->method;
// }
// }
// else
// {
// this->pdiagh
// = new DiagoDavid<T, Device>(precondition.data(), GlobalV::PW_DIAG_NDIM, GlobalV::use_paw,
// comm_info);

// this->pdiagh->method = this->method;
// }
}
else if (this->method == "dav_subspace")
{
Expand Down Expand Up @@ -731,28 +732,28 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm, psi::P
REQUIRES_OK(ndim == 2, "dims of psi_in should be less than or equal to 2");
// Convert a Tensor object to a psi::Psi object
auto psi_in_wrapper = psi::Psi<T, Device>(psi_in.data<T>(),
1,
psi_in.shape().dim_size(0),
psi_in.shape().dim_size(1),
ngk_pointer);
1,
psi_in.shape().dim_size(0),
psi_in.shape().dim_size(1),
ngk_pointer);
auto psi_out_wrapper = psi::Psi<T, Device>(psi_out.data<T>(),
1,
psi_out.shape().dim_size(0),
psi_out.shape().dim_size(1),
ngk_pointer);
1,
psi_out.shape().dim_size(0),
psi_out.shape().dim_size(1),
ngk_pointer);
auto eigen = ct::Tensor(ct::DataTypeToEnum<Real>::value,
ct::DeviceType::CpuDevice,
ct::TensorShape({psi_in.shape().dim_size(0)}));

DiagoIterAssist<T, Device>::diagH_subspace(hamilt_, psi_in_wrapper, psi_out_wrapper, eigen.data<Real>());
};
DiagoCG<T, Device> cg(GlobalV::BASIS_TYPE,
GlobalV::CALCULATION,
DiagoIterAssist<T, Device>::need_subspace,
subspace_func,
DiagoIterAssist<T, Device>::PW_DIAG_THR,
DiagoIterAssist<T, Device>::PW_DIAG_NMAX,
GlobalV::NPROC_IN_POOL);
GlobalV::CALCULATION,
DiagoIterAssist<T, Device>::need_subspace,
subspace_func,
DiagoIterAssist<T, Device>::PW_DIAG_THR,
DiagoIterAssist<T, Device>::PW_DIAG_NMAX,
GlobalV::NPROC_IN_POOL);

// warp the hpsi_func and spsi_func into a lambda function
using ct_Device = typename ct::PsiToContainer<Device>::type;
Expand Down Expand Up @@ -833,18 +834,14 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm, psi::P
const diag_comm_info comm_info = {GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL};
#endif
Diago_DavSubspace<T, Device> dav_subspace(this->precondition,

psi.get_nbands(),
psi.get_k_first() ? psi.get_current_nbas()
: psi.get_nk() * psi.get_nbasis(),

GlobalV::PW_DIAG_NDIM,
DiagoIterAssist<T, Device>::PW_DIAG_THR,
DiagoIterAssist<T, Device>::PW_DIAG_NMAX,
DiagoIterAssist<T, Device>::need_subspace,
comm_info);

// this->pdiagh->method = this->method;
psi.get_nbands(),
psi.get_k_first() ? psi.get_current_nbas()
: psi.get_nk() * psi.get_nbasis(),
GlobalV::PW_DIAG_NDIM,
DiagoIterAssist<T, Device>::PW_DIAG_THR,
DiagoIterAssist<T, Device>::PW_DIAG_NMAX,
DiagoIterAssist<T, Device>::need_subspace,
comm_info);

bool scf;
if (GlobalV::CALCULATION == "nscf")
Expand All @@ -858,14 +855,12 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm, psi::P

auto ngk_pointer = psi.get_ngk_pointer();

std::function<void(T*, T*, const int, const int, const int, const int)> hpsi_func = [hm, ngk_pointer](
T* hpsi_out,
T* psi_in,
const int nband_in,
const int nbasis_in,
const int band_index1,
const int band_index2)
{
auto hpsi_func = [hm, ngk_pointer](T* hpsi_out,
T* psi_in,
const int nband_in,
const int nbasis_in,
const int band_index1,
const int band_index2) {
ModuleBase::timer::tick("DavSubspace", "hpsi_func");

// Convert "pointer data stucture" to a psi::Psi object
Expand All @@ -880,20 +875,26 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm, psi::P
ModuleBase::timer::tick("DavSubspace", "hpsi_func");
};

auto subspace_func = [hm, ngk_pointer](T* psi_out,
T* psi_in,
Real* eigenvalue_in_hsolver,
const int nband_in,
const int nbasis_max_in) {
// Convert "pointer data stucture" to a psi::Psi object
auto psi_in_wrapper = psi::Psi<T, Device>(psi_in, 1, nband_in, nbasis_max_in, ngk_pointer);
auto psi_out_wrapper = psi::Psi<T, Device>(psi_out, 1, nband_in, nbasis_max_in, ngk_pointer);

DiagoIterAssist<T, Device>::diagH_subspace(hm,
psi_in_wrapper,
psi_out_wrapper,
eigenvalue_in_hsolver,
nband_in);
};

DiagoIterAssist<T, Device>::avg_iter
+= static_cast<double>(dav_subspace.diag(

hpsi_func,
psi.get_pointer(),

hm,
psi,
eigenvalue,
is_occupied,
scf));
DiagoIterAssist<T, Device>::avg_iter += static_cast<double>(
dav_subspace
.diag(hpsi_func, subspace_func, psi.get_pointer(), psi.get_nbasis(), eigenvalue, is_occupied, scf));

// delete reinterpret_cast<Diago_DavSubspace<T, Device>*>(this->pdiagh);
this->pdiagh = nullptr;
}
else if (this->method == "bpcg")
Expand All @@ -918,14 +919,10 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm, psi::P
// do diag and add davidson iteration counts up to avg_iter
const Real david_diag_thr = DiagoIterAssist<T, Device>::PW_DIAG_THR;
const int david_maxiter = DiagoIterAssist<T, Device>::PW_DIAG_NMAX;

DiagoDavid<T, Device> david(precondition.data(),
GlobalV::PW_DIAG_NDIM,
GlobalV::use_paw,
comm_info);

DiagoDavid<T, Device> david(precondition.data(), GlobalV::PW_DIAG_NDIM, GlobalV::use_paw, comm_info);
DiagoIterAssist<T, Device>::avg_iter += static_cast<double>(
david.diag(hm, psi, eigenvalue, david_diag_thr, david_maxiter, ntry_max, notconv_max)
);
david.diag(hm, psi, eigenvalue, david_diag_thr, david_maxiter, ntry_max, notconv_max));
}
return;
}
Expand Down

0 comments on commit 3c8764f

Please sign in to comment.