Skip to content

Commit

Permalink
Refactor: add smooth threshold support for david method (#5697)
Browse files Browse the repository at this point in the history
* change raw pointer to std::vector

* add ethr_band for dav method

* change unit test

* fix build bug

* fix pyabacus

* fix pyabacus dav-subspace

* fix pyabacus build

* fix pyabacus

* add & for vector
  • Loading branch information
haozhihan authored Dec 10, 2024
1 parent de90fdf commit fd3a6d8
Show file tree
Hide file tree
Showing 14 changed files with 55 additions and 33 deletions.
6 changes: 3 additions & 3 deletions python/pyabacus/src/hsolver/py_diago_dav_subspace.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,12 @@ class PyDiagoDavSubspace

int diag(
std::function<py::array_t<std::complex<double>>(py::array_t<std::complex<double>>)> mm_op,
std::vector<double> precond_vec,
std::vector<double>& precond_vec,
int dav_ndim,
double tol,
int max_iter,
bool need_subspace,
std::vector<double> diag_ethr,
std::vector<double>& diag_ethr,
bool scf_type,
hsolver::diag_comm_info comm_info
) {
Expand Down Expand Up @@ -141,7 +141,7 @@ class PyDiagoDavSubspace
comm_info
);

return obj->diag(hpsi_func, psi, nbasis, eigenvalue, diag_ethr.data(), scf_type);
return obj->diag(hpsi_func, psi, nbasis, eigenvalue, diag_ethr, scf_type);
}

private:
Expand Down
5 changes: 3 additions & 2 deletions python/pyabacus/src/hsolver/py_diago_david.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,10 @@ class PyDiagoDavid

int diag(
std::function<py::array_t<std::complex<double>>(py::array_t<std::complex<double>>)> mm_op,
std::vector<double> precond_vec,
std::vector<double>& precond_vec,
int dav_ndim,
double tol,
std::vector<double>& diag_ethr,
int max_iter,
bool use_paw,
hsolver::diag_comm_info comm_info
Expand Down Expand Up @@ -146,7 +147,7 @@ class PyDiagoDavid
comm_info
);

return obj->diag(hpsi_func, spsi_func, nbasis, psi, eigenvalue, tol, max_iter);
return obj->diag(hpsi_func, spsi_func, nbasis, psi, eigenvalue, diag_ethr, max_iter);
}

private:
Expand Down
3 changes: 3 additions & 0 deletions python/pyabacus/src/hsolver/py_hsolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ void bind_hsolver(py::module& m)
eigenvectors to be calculated.
tol : double
The tolerance for the convergence.
diag_ethr: np.ndarray
The tolerance vector.
max_iter : int
The maximum number of iterations.
use_paw : bool
Expand All @@ -130,6 +132,7 @@ void bind_hsolver(py::module& m)
"precond_vec"_a,
"dav_ndim"_a,
"tol"_a,
"diag_ethr"_a,
"max_iter"_a,
"use_paw"_a,
"comm_info"_a)
Expand Down
7 changes: 7 additions & 0 deletions python/pyabacus/src/pyabacus/hsolver/_hsolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def davidson(
dav_ndim: int = 2,
tol: float = 1e-2,
max_iter: int = 1000,
diag_ethr: Union[List[float], None] = None,
use_paw: bool = False,
# scf_type: bool = False
) -> Tuple[NDArray[np.float64], NDArray[np.complex128]]:
Expand All @@ -143,6 +144,8 @@ def davidson(
The tolerance for the convergence, by default 1e-2.
max_iter : int, optional
The maximum number of iterations, by default 1000.
diag_ethr : List[float] | None, optional
The list of thresholds of bands, by default None.
use_paw : bool, optional
Whether to use projector augmented wave (PAW) method, by default False.
Expand All @@ -164,12 +167,16 @@ def davidson(
_diago_obj_david.init_eigenvalue()

comm_info = diag_comm_info(0, 1)

if diag_ethr is None:
diag_ethr = [tol] * num_eigs

_ = _diago_obj_david.diag(
mvv_op,
precondition,
dav_ndim,
tol,
diag_ethr,
max_iter,
use_paw,
comm_info
Expand Down
4 changes: 2 additions & 2 deletions source/module_hsolver/diago_dav_subspace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
T* psi_in,
const int psi_in_dmax,
Real* eigenvalue_in_hsolver,
const double* ethr_band)
const std::vector<double>& ethr_band)
{
ModuleBase::timer::tick("Diago_DavSubspace", "diag_once");

Expand Down Expand Up @@ -726,7 +726,7 @@ int Diago_DavSubspace<T, Device>::diag(const HPsiFunc& hpsi_func,
T* psi_in,
const int psi_in_dmax,
Real* eigenvalue_in_hsolver,
const double* ethr_band,
const std::vector<double>& ethr_band,
const bool& scf_type)
{
/// record the times of trying iterative diagonalization
Expand Down
4 changes: 2 additions & 2 deletions source/module_hsolver/diago_dav_subspace.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class Diago_DavSubspace
T* psi_in,
const int psi_in_dmax,
Real* eigenvalue_in,
const double* ethr_band,
const std::vector<double>& ethr_band,
const bool& scf_type);

private:
Expand Down Expand Up @@ -135,7 +135,7 @@ class Diago_DavSubspace
T* psi_in,
const int psi_in_dmax,
Real* eigenvalue_in,
const double* ethr_band);
const std::vector<double>& ethr_band);

bool test_exit_cond(const int& ntry, const int& notconv, const bool& scf);

Expand Down
8 changes: 4 additions & 4 deletions source/module_hsolver/diago_david.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ int DiagoDavid<T, Device>::diag_once(const HPsiFunc& hpsi_func,
const int ld_psi,
T *psi_in,
Real* eigenvalue_in,
const Real david_diag_thr,
const std::vector<double>& ethr_band,
const int david_maxiter)
{
if (test_david == 1)
Expand Down Expand Up @@ -273,7 +273,7 @@ int DiagoDavid<T, Device>::diag_once(const HPsiFunc& hpsi_func,
this->notconv = 0;
for (int m = 0; m < nband; m++)
{
convflag[m] = (std::abs(this->eigenvalue[m] - eigenvalue_in[m]) < david_diag_thr);
convflag[m] = (std::abs(this->eigenvalue[m] - eigenvalue_in[m]) < ethr_band[m]);
if (!convflag[m])
{
unconv[this->notconv] = m;
Expand Down Expand Up @@ -1177,7 +1177,7 @@ int DiagoDavid<T, Device>::diag(const HPsiFunc& hpsi_func,
const int ld_psi,
T *psi_in,
Real* eigenvalue_in,
const Real david_diag_thr,
const std::vector<double>& ethr_band,
const int david_maxiter,
const int ntry_max,
const int notconv_max)
Expand All @@ -1189,7 +1189,7 @@ int DiagoDavid<T, Device>::diag(const HPsiFunc& hpsi_func,
int sum_dav_iter = 0;
do
{
sum_dav_iter += this->diag_once(hpsi_func, spsi_func, dim, nband, ld_psi, psi_in, eigenvalue_in, david_diag_thr, david_maxiter);
sum_dav_iter += this->diag_once(hpsi_func, spsi_func, dim, nband, ld_psi, psi_in, eigenvalue_in, ethr_band, david_maxiter);
++ntry;
} while (!check_block_conv(ntry, this->notconv, ntry_max, notconv_max));

Expand Down
4 changes: 2 additions & 2 deletions source/module_hsolver/diago_david.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class DiagoDavid
const int ld_psi, // Leading dimension of the psi input
T *psi_in, // Pointer to eigenvectors
Real* eigenvalue_in, // Pointer to store the resulting eigenvalues
const Real david_diag_thr, // Convergence threshold for the Davidson iteration
const std::vector<double>& ethr_band, // Convergence threshold for the Davidson iteration
const int david_maxiter, // Maximum allowed iterations for the Davidson method
const int ntry_max = 5, // Maximum number of diagonalization attempts (5 by default)
const int notconv_max = 0); // Maximum number of allowed non-converged eigenvectors
Expand Down Expand Up @@ -134,7 +134,7 @@ class DiagoDavid
const int ld_psi,
T *psi_in,
Real* eigenvalue_in,
const Real david_diag_thr,
const std::vector<double>& ethr_band,
const int david_maxiter);

void cal_grad(const HPsiFunc& hpsi_func,
Expand Down
4 changes: 2 additions & 2 deletions source/module_hsolver/hsolver_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
comm_info);

DiagoIterAssist<T, Device>::avg_iter += static_cast<double>(
dav_subspace.diag(hpsi_func, psi.get_pointer(), psi.get_nbasis(), eigenvalue, this->ethr_band.data(), scf));
dav_subspace.diag(hpsi_func, psi.get_pointer(), psi.get_nbasis(), eigenvalue, this->ethr_band, scf));
}
else if (this->method == "dav")
{
Expand Down Expand Up @@ -589,7 +589,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
ld_psi,
psi.get_pointer(),
eigenvalue,
david_diag_thr,
this->ethr_band,
david_maxiter,
ntry_max,
notconv_max));
Expand Down
10 changes: 6 additions & 4 deletions source/module_hsolver/test/diago_david_float_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,12 @@ class DiagoDavPrepare
hpsi_info info(&psi_iter_wrapper, bands_range, hpsi_out);
phm->ops->hPsi(info);
};
auto spsi_func = [phm](const std::complex<float>* psi_in, std::complex<float>* spsi_out,const int ld_psi, const int nbands){
phm->sPsi(psi_in, spsi_out, ld_psi, ld_psi, nbands);
};
dav.diag(hpsi_func,spsi_func, ld_psi, phi.get_pointer(), en, eps, maxiter);
auto spsi_func = [phm](const std::complex<float>* psi_in,
std::complex<float>* spsi_out,
const int ld_psi,
const int nbands) { phm->sPsi(psi_in, spsi_out, ld_psi, ld_psi, nbands); };
std::vector<double> ethr_band(phi.get_nbands(), eps);
dav.diag(hpsi_func,spsi_func, ld_psi, phi.get_pointer(), en, ethr_band, maxiter);

#ifdef __MPI
end = MPI_Wtime();
Expand Down
7 changes: 4 additions & 3 deletions source/module_hsolver/test/diago_david_real_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,11 @@ class DiagoDavPrepare
hpsi_info info(&psi_iter_wrapper, bands_range, hpsi_out);
phm->ops->hPsi(info);
};
auto spsi_func = [phm](const double* psi_in, double* spsi_out,const int ld_psi, const int nbands){
auto spsi_func = [phm](const double* psi_in, double* spsi_out, const int ld_psi, const int nbands) {
phm->sPsi(psi_in, spsi_out, ld_psi, ld_psi, nbands);
};
dav.diag(hpsi_func,spsi_func, ld_psi, phi.get_pointer(), en, eps, maxiter);
};
std::vector<double> ethr_band(phi.get_nbands(), eps);
dav.diag(hpsi_func,spsi_func, ld_psi, phi.get_pointer(), en, ethr_band, maxiter);

#ifdef __MPI
end = MPI_Wtime();
Expand Down
10 changes: 6 additions & 4 deletions source/module_hsolver/test/diago_david_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,12 @@ class DiagoDavPrepare
hpsi_info info(&psi_iter_wrapper, bands_range, hpsi_out);
phm->ops->hPsi(info);
};
auto spsi_func = [phm](const std::complex<double>* psi_in, std::complex<double>* spsi_out,const int ld_psi, const int nbands){
phm->sPsi(psi_in, spsi_out, ld_psi, ld_psi, nbands);
};
dav.diag(hpsi_func,spsi_func, ld_psi, phi.get_pointer(), en, eps, maxiter);
auto spsi_func = [phm](const std::complex<double>* psi_in,
std::complex<double>* spsi_out,
const int ld_psi,
const int nbands) { phm->sPsi(psi_in, spsi_out, ld_psi, ld_psi, nbands); };
std::vector<double> ethr_band(phi.get_nbands(), eps);
dav.diag(hpsi_func,spsi_func, ld_psi, phi.get_pointer(), en, ethr_band, maxiter);

#ifdef __MPI
end = MPI_Wtime();
Expand Down
4 changes: 2 additions & 2 deletions source/module_hsolver/test/hsolver_pw_sup.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,9 @@ template <typename T, typename Device>
int DiagoDavid<T, Device>::diag(const std::function<void(T*, T*, const int, const int)>& hpsi_func,
const std::function<void(T*, T*, const int, const int)>& spsi_func,
const int ld_psi,
T *psi_in,
T* psi_in,
Real* eigenvalue_in,
const Real david_diag_thr,
const std::vector<double>& ethr_band,
const int david_maxiter,
const int ntry_max,
const int notconv_max) {
Expand Down
12 changes: 9 additions & 3 deletions source/module_lr/hsolver_lrtd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,15 @@ namespace LR
// converged.
const int notconv_max = ("nscf" == PARAM.inp.calculation) ? 0 : 5;
// do diag and add davidson iteration counts up to avg_iter
hsolver::DiagoDavid<T> david(precondition.data(), nband, dim, PARAM.inp.pw_diag_ndim, PARAM.inp.use_paw, comm_info);
hsolver::DiagoDavid<T> david(precondition.data(),
nband,
dim,
PARAM.inp.pw_diag_ndim,
PARAM.inp.use_paw,
comm_info);
std::vector<double> ethr_band(nband, diag_ethr);
hsolver::DiagoIterAssist<T>::avg_iter += static_cast<double>(david.diag(hpsi_func, spsi_func,
dim, psi, eigenvalue.data(), diag_ethr, maxiter, ntry_max, 0));
dim, psi, eigenvalue.data(), ethr_band, maxiter, ntry_max, 0));
}
else if (method == "dav_subspace") //need refactor
{
Expand All @@ -102,7 +108,7 @@ namespace LR
hpsi_func, psi,
dim,
eigenvalue.data(),
ethr_band.data(),
ethr_band,
false /*scf*/));
}
else if (method == "cg")
Expand Down

0 comments on commit fd3a6d8

Please sign in to comment.