Skip to content

Commit

Permalink
Clean up one electron integral python code according to new convention (
Browse files Browse the repository at this point in the history
#263)

* Clean up according to new convention

* Support reuse VHFOpt, cleanup tests, return cupy instead of numpy array, cleanup esp and chelpg

* bugfix in xc_deriv2

* Code improvement

---------

Co-authored-by: xiaojie.wu <xiaojie.wu@bytedance.com>
Co-authored-by: Xiaojie Wu <wxj6000@gmail.com>
  • Loading branch information
3 people authored Dec 4, 2024
1 parent d4dce92 commit 28323a5
Show file tree
Hide file tree
Showing 6 changed files with 147 additions and 283 deletions.
201 changes: 104 additions & 97 deletions gpu4pyscf/gto/moleintor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,39 +24,26 @@
from gpu4pyscf.scf.int4c2e import BasisProdCache
from gpu4pyscf.df.int3c2e import sort_mol, _split_l_ctr_groups, get_pairing
from gpu4pyscf.gto.mole import basis_seg_contraction
from gpu4pyscf.__config__ import _num_devices, _streams

BLKSIZE = 128

libgvhf = load_library('libgvhf')
libgint = load_library('libgint')

class VHFOpt(_vhf.VHFOpt):
def __init__(self, mol, intor, prescreen='CVHFnoscreen',
def __init__(self, mol, intor='int2e', prescreen='CVHFnoscreen',
qcondname='CVHFsetnr_direct_scf', dmcondname=None):
# use local basis_seg_contraction for efficiency
# TODO: switch _mol and mol
self.mol = basis_seg_contraction(mol,allow_replica=True)
self._mol = mol

'''
# Note mol._bas will be sorted in .build() method. VHFOpt should be
# initialized after mol._bas updated.
'''
self.nao = self.mol.nao
self.mol = mol
self._sorted_mol = None

self._intor = intor
self._prescreen = prescreen
self._qcondname = qcondname
self._dmcondname = dmcondname

self.bpcache = None

self.cart_ao_idx = None
self.sph_ao_idx = None

self.cart_ao_loc = []
self.sph_ao_loc = []
self.cart2sph = None

self.angular = None

Expand All @@ -69,8 +56,8 @@ def __init__(self, mol, intor, prescreen='CVHFnoscreen',

def clear(self):
_vhf.VHFOpt.__del__(self)
if self.bpcache is not None:
libgvhf.GINTdel_basis_prod(ctypes.byref(self.bpcache))
for n, bpcache in self._bpcache.items():
libgvhf.GINTdel_basis_prod(ctypes.byref(bpcache))
return self

def __del__(self):
Expand All @@ -79,20 +66,21 @@ def __del__(self):
except AttributeError:
pass

def build(self, cutoff=1e-14, group_size=None, diag_block_with_triu=False, aosym=False):
_mol = self._mol
mol = self.mol
def build(self, cutoff=1e-13, group_size=BLKSIZE, diag_block_with_triu=False, aosym=True):
original_mol = self.mol
mol = basis_seg_contraction(original_mol, allow_replica=True)

log = logger.new_logger(_mol, _mol.verbose)
log = logger.new_logger(original_mol, original_mol.verbose)
cput0 = log.init_timer()
sorted_mol, sorted_idx, uniq_l_ctr, l_ctr_counts = sort_mol(mol, log=log)
self.sorted_mol = sorted_mol
_sorted_mol, sorted_idx, uniq_l_ctr, l_ctr_counts = sort_mol(mol, log=log)
self._sorted_mol = _sorted_mol

if group_size is not None :
uniq_l_ctr, l_ctr_counts = _split_l_ctr_groups(uniq_l_ctr, l_ctr_counts, group_size)
self.nctr = len(uniq_l_ctr)
self.l_ctr_counts = l_ctr_counts

# Initialize vhfopt after reordering mol._bas
_vhf.VHFOpt.__init__(self, sorted_mol, self._intor, self._prescreen,
_vhf.VHFOpt.__init__(self, _sorted_mol, self._intor, self._prescreen,
self._qcondname, self._dmcondname)
self.direct_scf_tol = cutoff

Expand All @@ -104,39 +92,22 @@ def build(self, cutoff=1e-14, group_size=None, diag_block_with_triu=False, aosym
l_ctr_offsets, l_ctr_offsets, q_cond,
diag_block_with_triu=diag_block_with_triu, aosym=aosym)
self.log_qs = log_qs.copy()
cput1 = log.timer_debug1('Get pairing', *cput1)
cput1 = log.timer_debug1('Get AO pairing', *cput1)

# contraction coefficient for ao basis
cart_ao_loc = sorted_mol.ao_loc_nr(cart=True)
sph_ao_loc = sorted_mol.ao_loc_nr(cart=False)
cart_ao_loc = _sorted_mol.ao_loc_nr(cart=True)
sph_ao_loc = _sorted_mol.ao_loc_nr(cart=False)
self.cart_ao_loc = [cart_ao_loc[cp] for cp in l_ctr_offsets]
self.sph_ao_loc = [sph_ao_loc[cp] for cp in l_ctr_offsets]
self.angular = [l[0] for l in uniq_l_ctr]

cart_ao_loc = mol.ao_loc_nr(cart=True)
sph_ao_loc = mol.ao_loc_nr(cart=False)
nao = sph_ao_loc[-1]
ao_idx = np.array_split(np.arange(nao), sph_ao_loc[1:-1])
self.sph_ao_idx = np.hstack([ao_idx[i] for i in sorted_idx])

# cartesian ao index
nao = cart_ao_loc[-1]
ao_idx = np.array_split(np.arange(nao), cart_ao_loc[1:-1])
self.cart_ao_idx = np.hstack([ao_idx[i] for i in sorted_idx])
self.cart2sph = block_c2s_diag(self.angular, l_ctr_counts)
cput1 = log.timer_debug1('AO cart2sph coeff', *cput1)

if _mol.cart:
ncart = cart_ao_loc[-1]
inv_idx = np.argsort(self.cart_ao_idx, kind='stable').astype(np.int32)
self.coeff = cp.eye(ncart)[:,inv_idx]
else:
inv_idx = np.argsort(self.sph_ao_idx, kind='stable').astype(np.int32)
self.coeff = self.cart2sph[:, inv_idx]
cput1 = log.timer_debug1('AO cart2sph coeff', *cput1)
# Sorted AO indices
ao_loc = mol.ao_loc_nr(cart=original_mol.cart)
ao_idx = np.array_split(np.arange(original_mol.nao), ao_loc[1:-1])
self._ao_idx = np.hstack([ao_idx[i] for i in sorted_idx])
cput1 = log.timer_debug1('AO indices', *cput1)

ao_loc = sorted_mol.ao_loc_nr(cart=True)
cput1 = log.timer_debug1('Get AO pairs', *cput1)
ao_loc = cart_ao_loc

self.pair2bra = pair2bra
self.pair2ket = pair2ket
Expand All @@ -156,16 +127,20 @@ def get_n_hermite_density_of_angular_pair(l):
n_density_per_angular_pair = (bas_pairs_locs[1:] - bas_pairs_locs[:-1]) * n_density_per_pair
self.density_offset = np.append(0, np.cumsum(n_density_per_angular_pair)).astype(np.int32)

self.bpcache = ctypes.POINTER(BasisProdCache)()
scale_shellpair_diag = 1.0
libgint.GINTinit_basis_prod(
ctypes.byref(self.bpcache), ctypes.c_double(scale_shellpair_diag),
ao_loc.ctypes.data_as(ctypes.c_void_p),
bas_pair2shls.ctypes.data_as(ctypes.c_void_p),
bas_pairs_locs.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(ncptype),
sorted_mol._atm.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(sorted_mol.natm),
sorted_mol._bas.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(sorted_mol.nbas),
sorted_mol._env.ctypes.data_as(ctypes.c_void_p))
self._bpcache = {}
for n in range(_num_devices):
with cp.cuda.Device(n), _streams[n]:
bpcache = ctypes.POINTER(BasisProdCache)()
scale_shellpair_diag = 1.0
libgint.GINTinit_basis_prod(
ctypes.byref(bpcache), ctypes.c_double(scale_shellpair_diag),
ao_loc.ctypes.data_as(ctypes.c_void_p),
bas_pair2shls.ctypes.data_as(ctypes.c_void_p),
bas_pairs_locs.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(ncptype),
_sorted_mol._atm.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(_sorted_mol.natm),
_sorted_mol._bas.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(_sorted_mol.nbas),
_sorted_mol._env.ctypes.data_as(ctypes.c_void_p))
self._bpcache[n] = bpcache

cput1 = log.timer_debug1('Initialize GPU cache', *cput1)
ncptype = len(self.log_qs)
Expand All @@ -176,22 +151,44 @@ def get_n_hermite_density_of_angular_pair(l):
nl = int(round(np.sqrt(ncptype)))
self.cp_idx, self.cp_jdx = np.unravel_index(np.arange(ncptype), (nl, nl))

if _mol.cart:
if original_mol.cart:
self.ao_loc = self.cart_ao_loc
self.ao_idx = self.cart_ao_idx
else:
self.ao_loc = self.sph_ao_loc
self.ao_idx = self.sph_ao_idx

def sort_orbitals(self, mat, axis=[]):
''' Transform given axis of a matrix into sorted AO,
and transform given auxiliary axis of a matrix into sorted auxiliary AO
'''
idx = self._ao_idx
shape_ones = (1,) * mat.ndim
fancy_index = []
for dim, n in enumerate(mat.shape):
if dim in axis:
assert n == len(idx)
indices = idx
else:
indices = np.arange(n)
idx_shape = shape_ones[:dim] + (-1,) + shape_ones[dim+1:]
fancy_index.append(indices.reshape(idx_shape))
return mat[tuple(fancy_index)]

@property
def bpcache(self):
device_id = cp.cuda.Device().id
bpcache = self._bpcache[device_id]
return bpcache

@property
def cart2sph(self):
return block_c2s_diag(self.angular, self.l_ctr_counts)
# end of class VHFOpt


def get_int3c1e(mol, grids, direct_scf_tol):
def get_int3c1e(mol, grids, intopt):
omega = mol.omega
assert omega >= 0.0, "Short-range one electron integrals with GPU acceleration is not implemented."

intopt = VHFOpt(mol, 'int2e')
intopt.build(direct_scf_tol, diag_block_with_triu=True, aosym=True, group_size=BLKSIZE)

nao = mol.nao
ngrids = grids.shape[0]
total_double_number = ngrids * nao * nao
Expand Down Expand Up @@ -221,7 +218,7 @@ def get_int3c1e(mol, grids, direct_scf_tol):
lj = intopt.angular[cpj]

stream = cp.cuda.get_current_stream()
nao_cart = intopt.mol.nao
nao_cart = intopt._sorted_mol.nao

log_q_ij = intopt.log_qs[cp_ij_id]

Expand Down Expand Up @@ -265,21 +262,18 @@ def get_int3c1e(mol, grids, direct_scf_tol):

row, col = np.tril_indices(nao)
int3c_grid_slice[:, row, col] = int3c_grid_slice[:, col, row]
ao_idx = np.argsort(intopt.ao_idx)
ao_idx = np.argsort(intopt._ao_idx)
grid_idx = np.arange(ngrids_of_split)
int3c_grid_slice = int3c_grid_slice[np.ix_(grid_idx, ao_idx, ao_idx)]

int3c_grid_slice.get(out = int3c[i_grid_split : i_grid_split + ngrids_of_split, :, :])

return int3c

def get_int3c1e_charge_contracted(mol, grids, charges, direct_scf_tol):
def get_int3c1e_charge_contracted(mol, grids, charges, intopt):
omega = mol.omega
assert omega >= 0.0, "Short-range one electron integrals with GPU acceleration is not implemented."

intopt = VHFOpt(mol, 'int2e')
intopt.build(direct_scf_tol, diag_block_with_triu=True, aosym=True, group_size=BLKSIZE)

nao = mol.nao

assert charges.ndim == 1 and charges.shape[0] == grids.shape[0]
Expand All @@ -296,7 +290,7 @@ def get_int3c1e_charge_contracted(mol, grids, charges, direct_scf_tol):
lj = intopt.angular[cpj]

stream = cp.cuda.get_current_stream()
nao_cart = intopt.mol.nao
nao_cart = intopt._sorted_mol.nao

log_q_ij = intopt.log_qs[cp_ij_id]

Expand All @@ -312,8 +306,8 @@ def get_int3c1e_charge_contracted(mol, grids, charges, direct_scf_tol):
strides = np.array([ni, ni*nj], dtype=np.int32)

ngrids = grids.shape[0]
# n_charge_sum_per_thread = 1 means every thread processes one pair and one grid
# n_charge_sum_per_thread = ngrids or larger number gaurantees one thread processes one pair and all grid points
# n_charge_sum_per_thread = 1 # means every thread processes one pair and one grid
# n_charge_sum_per_thread = ngrids # or larger number gaurantees one thread processes one pair and all grid points
n_charge_sum_per_thread = 10

int1e_angular_slice = cp.zeros([j1-j0, i1-i0], order='C')
Expand Down Expand Up @@ -346,26 +340,30 @@ def get_int3c1e_charge_contracted(mol, grids, charges, direct_scf_tol):

row, col = np.tril_indices(nao)
int1e[row, col] = int1e[col, row]
ao_idx = np.argsort(intopt.ao_idx)
ao_idx = np.argsort(intopt._ao_idx)
int1e = int1e[np.ix_(ao_idx, ao_idx)]

return cp.asnumpy(int1e)
return int1e

def get_int3c1e_density_contracted(mol, grids, dm, direct_scf_tol):
def get_int3c1e_density_contracted(mol, grids, dm, intopt):
omega = mol.omega
assert omega >= 0.0, "Short-range one electron integrals with GPU acceleration is not implemented."

dm = cp.asarray(dm)
if dm.ndim == 3:
if dm.shape[0] > 2:
print("Warning: There are more than two density matrices to contract with one electron integrals, "
"it's not from an unrestricted calculation, and we're unsure about your purpose. "
"We sum the density matrices up, please check if that's expected.")
dm = cp.einsum("ijk->jk", dm)

assert dm.ndim == 2
assert dm.shape[0] == dm.shape[1] and dm.shape[0] == mol.nao

intopt = VHFOpt(mol, 'int2e')
intopt.build(direct_scf_tol, diag_block_with_triu=False, aosym=True, group_size=BLKSIZE)

nao_cart = intopt.mol.nao
nao_cart = intopt._sorted_mol.nao
ngrids = grids.shape[0]

dm = dm[np.ix_(intopt.ao_idx, intopt.ao_idx)] # intopt.ao_idx is in spherical basis
dm = intopt.sort_orbitals(dm, [0,1])
if not mol.cart:
cart2sph_transformation_matrix = intopt.cart2sph
# TODO: This part is inefficient (O(N^3)), should be changed to the O(N^2) algorithm
Expand All @@ -374,9 +372,9 @@ def get_int3c1e_density_contracted(mol, grids, dm, direct_scf_tol):

dm = cp.asnumpy(dm)

ao_loc_sorted_order = intopt.sorted_mol.ao_loc_nr(cart = True)
ao_loc_sorted_order = intopt._sorted_mol.ao_loc_nr(cart = True)
l_ij = intopt.l_ij.T.flatten()
bas_coords = intopt.sorted_mol.atom_coords()[intopt.sorted_mol._bas[:, ATOM_OF]].flatten()
bas_coords = intopt._sorted_mol.atom_coords()[intopt._sorted_mol._bas[:, ATOM_OF]].flatten()

n_total_hermite_density = intopt.density_offset[-1]
dm_pair_ordered = np.zeros(n_total_hermite_density)
Expand Down Expand Up @@ -412,8 +410,8 @@ def get_int3c1e_density_contracted(mol, grids, dm, direct_scf_tol):
nbins = 1
bins_locs_ij = np.array([0, len(log_q_ij)], dtype=np.int32)

# n_pair_sum_per_thread = 1 means every thread processes one pair and one grid
# n_pair_sum_per_thread = nao_cart or larger number gaurantees one thread processes one grid and all pairs of the same type
# n_pair_sum_per_thread = 1 # means every thread processes one pair and one grid
# n_pair_sum_per_thread = nao_cart # or larger number gaurantees one thread processes one grid and all pairs of the same type
n_pair_sum_per_thread = nao_cart

err = libgint.GINTfill_int3c1e_density_contracted(
Expand All @@ -433,19 +431,28 @@ def get_int3c1e_density_contracted(mol, grids, dm, direct_scf_tol):
if err != 0:
raise RuntimeError('GINTfill_int3c1e_density_contracted failed')

return cp.asnumpy(int3c_density_contracted)
return int3c_density_contracted

def intor(mol, intor, grids, dm=None, charges=None, direct_scf_tol=1e-13):
assert intor == 'int1e_grids' and grids is not None
def intor(mol, intor, grids, dm=None, charges=None, direct_scf_tol=1e-13, intopt=None):
assert intor == 'int1e_grids'
assert grids is not None
assert dm is None or charges is None, \
"Are you sure you want to contract the one electron integrals with both charge and density? " + \
"If so, pass in density, obtain the result with n_charge and contract with the charges yourself."

if intopt is None:
intopt = VHFOpt(mol)
intopt.build(direct_scf_tol)
else:
assert isinstance(intopt, VHFOpt), \
f"Please make sure intopt is a {VHFOpt.__module__}.{VHFOpt.__name__} object."
assert hasattr(intopt, "density_offset"), "Please call build() function for VHFOpt object first."

if dm is None and charges is None:
return get_int3c1e(mol, grids, direct_scf_tol)
return get_int3c1e(mol, grids, intopt)
elif dm is not None:
return get_int3c1e_density_contracted(mol, grids, dm, direct_scf_tol)
return get_int3c1e_density_contracted(mol, grids, dm, intopt)
elif charges is not None:
return get_int3c1e_charge_contracted(mol, grids, charges, direct_scf_tol)
return get_int3c1e_charge_contracted(mol, grids, charges, intopt)
else:
raise ValueError(f"Logic error in {__file__} {__name__}")
Loading

0 comments on commit 28323a5

Please sign in to comment.