Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up one electron integral python code according to new convention #263

Merged
merged 5 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 104 additions & 97 deletions gpu4pyscf/gto/moleintor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,39 +24,26 @@
from gpu4pyscf.scf.int4c2e import BasisProdCache
from gpu4pyscf.df.int3c2e import sort_mol, _split_l_ctr_groups, get_pairing
from gpu4pyscf.gto.mole import basis_seg_contraction
from gpu4pyscf.__config__ import _num_devices, _streams

BLKSIZE = 128

libgvhf = load_library('libgvhf')
libgint = load_library('libgint')

class VHFOpt(_vhf.VHFOpt):
def __init__(self, mol, intor, prescreen='CVHFnoscreen',
def __init__(self, mol, intor='int2e', prescreen='CVHFnoscreen',
qcondname='CVHFsetnr_direct_scf', dmcondname=None):
# use local basis_seg_contraction for efficiency
# TODO: switch _mol and mol
self.mol = basis_seg_contraction(mol,allow_replica=True)
self._mol = mol

'''
# Note mol._bas will be sorted in .build() method. VHFOpt should be
# initialized after mol._bas updated.
'''
self.nao = self.mol.nao
self.mol = mol
self._sorted_mol = None

self._intor = intor
self._prescreen = prescreen
self._qcondname = qcondname
self._dmcondname = dmcondname

self.bpcache = None

self.cart_ao_idx = None
self.sph_ao_idx = None

self.cart_ao_loc = []
self.sph_ao_loc = []
self.cart2sph = None

self.angular = None

Expand All @@ -69,8 +56,8 @@ def __init__(self, mol, intor, prescreen='CVHFnoscreen',

def clear(self):
_vhf.VHFOpt.__del__(self)
if self.bpcache is not None:
libgvhf.GINTdel_basis_prod(ctypes.byref(self.bpcache))
for n, bpcache in self._bpcache.items():
libgvhf.GINTdel_basis_prod(ctypes.byref(bpcache))
return self

def __del__(self):
Expand All @@ -79,20 +66,21 @@ def __del__(self):
except AttributeError:
pass

def build(self, cutoff=1e-14, group_size=None, diag_block_with_triu=False, aosym=False):
_mol = self._mol
mol = self.mol
def build(self, cutoff=1e-13, group_size=BLKSIZE, diag_block_with_triu=False, aosym=True):
original_mol = self.mol
mol = basis_seg_contraction(original_mol, allow_replica=True)

log = logger.new_logger(_mol, _mol.verbose)
log = logger.new_logger(original_mol, original_mol.verbose)
cput0 = log.init_timer()
sorted_mol, sorted_idx, uniq_l_ctr, l_ctr_counts = sort_mol(mol, log=log)
self.sorted_mol = sorted_mol
_sorted_mol, sorted_idx, uniq_l_ctr, l_ctr_counts = sort_mol(mol, log=log)
self._sorted_mol = _sorted_mol

if group_size is not None :
uniq_l_ctr, l_ctr_counts = _split_l_ctr_groups(uniq_l_ctr, l_ctr_counts, group_size)
self.nctr = len(uniq_l_ctr)
self.l_ctr_counts = l_ctr_counts

# Initialize vhfopt after reordering mol._bas
_vhf.VHFOpt.__init__(self, sorted_mol, self._intor, self._prescreen,
_vhf.VHFOpt.__init__(self, _sorted_mol, self._intor, self._prescreen,
self._qcondname, self._dmcondname)
self.direct_scf_tol = cutoff

Expand All @@ -104,39 +92,22 @@ def build(self, cutoff=1e-14, group_size=None, diag_block_with_triu=False, aosym
l_ctr_offsets, l_ctr_offsets, q_cond,
diag_block_with_triu=diag_block_with_triu, aosym=aosym)
self.log_qs = log_qs.copy()
cput1 = log.timer_debug1('Get pairing', *cput1)
cput1 = log.timer_debug1('Get AO pairing', *cput1)

# contraction coefficient for ao basis
cart_ao_loc = sorted_mol.ao_loc_nr(cart=True)
sph_ao_loc = sorted_mol.ao_loc_nr(cart=False)
cart_ao_loc = _sorted_mol.ao_loc_nr(cart=True)
sph_ao_loc = _sorted_mol.ao_loc_nr(cart=False)
self.cart_ao_loc = [cart_ao_loc[cp] for cp in l_ctr_offsets]
self.sph_ao_loc = [sph_ao_loc[cp] for cp in l_ctr_offsets]
self.angular = [l[0] for l in uniq_l_ctr]

cart_ao_loc = mol.ao_loc_nr(cart=True)
sph_ao_loc = mol.ao_loc_nr(cart=False)
nao = sph_ao_loc[-1]
ao_idx = np.array_split(np.arange(nao), sph_ao_loc[1:-1])
self.sph_ao_idx = np.hstack([ao_idx[i] for i in sorted_idx])

# cartesian ao index
nao = cart_ao_loc[-1]
ao_idx = np.array_split(np.arange(nao), cart_ao_loc[1:-1])
self.cart_ao_idx = np.hstack([ao_idx[i] for i in sorted_idx])
self.cart2sph = block_c2s_diag(self.angular, l_ctr_counts)
cput1 = log.timer_debug1('AO cart2sph coeff', *cput1)

if _mol.cart:
ncart = cart_ao_loc[-1]
inv_idx = np.argsort(self.cart_ao_idx, kind='stable').astype(np.int32)
self.coeff = cp.eye(ncart)[:,inv_idx]
else:
inv_idx = np.argsort(self.sph_ao_idx, kind='stable').astype(np.int32)
self.coeff = self.cart2sph[:, inv_idx]
cput1 = log.timer_debug1('AO cart2sph coeff', *cput1)
# Sorted AO indices
ao_loc = mol.ao_loc_nr(cart=original_mol.cart)
ao_idx = np.array_split(np.arange(original_mol.nao), ao_loc[1:-1])
self._ao_idx = np.hstack([ao_idx[i] for i in sorted_idx])
cput1 = log.timer_debug1('AO indices', *cput1)

ao_loc = sorted_mol.ao_loc_nr(cart=True)
cput1 = log.timer_debug1('Get AO pairs', *cput1)
ao_loc = cart_ao_loc

self.pair2bra = pair2bra
self.pair2ket = pair2ket
Expand All @@ -156,16 +127,20 @@ def get_n_hermite_density_of_angular_pair(l):
n_density_per_angular_pair = (bas_pairs_locs[1:] - bas_pairs_locs[:-1]) * n_density_per_pair
self.density_offset = np.append(0, np.cumsum(n_density_per_angular_pair)).astype(np.int32)

self.bpcache = ctypes.POINTER(BasisProdCache)()
scale_shellpair_diag = 1.0
libgint.GINTinit_basis_prod(
ctypes.byref(self.bpcache), ctypes.c_double(scale_shellpair_diag),
ao_loc.ctypes.data_as(ctypes.c_void_p),
bas_pair2shls.ctypes.data_as(ctypes.c_void_p),
bas_pairs_locs.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(ncptype),
sorted_mol._atm.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(sorted_mol.natm),
sorted_mol._bas.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(sorted_mol.nbas),
sorted_mol._env.ctypes.data_as(ctypes.c_void_p))
self._bpcache = {}
for n in range(_num_devices):
with cp.cuda.Device(n), _streams[n]:
bpcache = ctypes.POINTER(BasisProdCache)()
scale_shellpair_diag = 1.0
libgint.GINTinit_basis_prod(
ctypes.byref(bpcache), ctypes.c_double(scale_shellpair_diag),
ao_loc.ctypes.data_as(ctypes.c_void_p),
bas_pair2shls.ctypes.data_as(ctypes.c_void_p),
bas_pairs_locs.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(ncptype),
_sorted_mol._atm.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(_sorted_mol.natm),
_sorted_mol._bas.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(_sorted_mol.nbas),
_sorted_mol._env.ctypes.data_as(ctypes.c_void_p))
self._bpcache[n] = bpcache

cput1 = log.timer_debug1('Initialize GPU cache', *cput1)
ncptype = len(self.log_qs)
Expand All @@ -176,22 +151,44 @@ def get_n_hermite_density_of_angular_pair(l):
nl = int(round(np.sqrt(ncptype)))
self.cp_idx, self.cp_jdx = np.unravel_index(np.arange(ncptype), (nl, nl))

if _mol.cart:
if original_mol.cart:
self.ao_loc = self.cart_ao_loc
self.ao_idx = self.cart_ao_idx
else:
self.ao_loc = self.sph_ao_loc
self.ao_idx = self.sph_ao_idx

def sort_orbitals(self, mat, axis=()):
    '''Reorder the selected axes of ``mat`` with the sorted AO index.

    Each axis listed in ``axis`` is gathered with ``self._ao_idx``; every
    other axis keeps its original order.

    Args:
        mat: array-like object exposing ``ndim``/``shape`` and accepting a
            tuple of broadcastable integer index arrays.
        axis: iterable of axis numbers to reorder. Negative numbers count
            from the last axis. The default (empty) reorders nothing.

    Returns:
        The result of indexing ``mat`` with the assembled index arrays;
        it has the same shape as ``mat``.

    Raises:
        AssertionError: if a selected axis does not have
            ``len(self._ao_idx)`` entries.
    '''
    idx = self._ao_idx
    ndim = mat.ndim
    # Map negative axis numbers onto their non-negative equivalents so that
    # e.g. axis=[-1] selects the last axis instead of being ignored.
    selected = {a + ndim if a < 0 else a for a in axis}

    fancy_index = []
    for dim, n in enumerate(mat.shape):
        if dim in selected:
            assert n == len(idx)
            indices = idx
        else:
            indices = np.arange(n)
        # Length-1 everywhere except along `dim`, so that the per-axis
        # index arrays broadcast against each other into an open mesh.
        idx_shape = (1,) * dim + (-1,) + (1,) * (ndim - dim - 1)
        fancy_index.append(indices.reshape(idx_shape))
    return mat[tuple(fancy_index)]

@property
def bpcache(self):
    '''Entry of ``self._bpcache`` for the current CUDA device.

    The lookup key is the ``id`` of ``cp.cuda.Device()``. The plain
    subscript lookup fails if ``self._bpcache`` has no entry for that id.
    '''
    return self._bpcache[cp.cuda.Device().id]

@property
def cart2sph(self):
    '''Value of ``block_c2s_diag(self.angular, self.l_ctr_counts)``.

    The call is made on every access; nothing is stored on the instance.
    '''
    angular = self.angular
    group_counts = self.l_ctr_counts
    return block_c2s_diag(angular, group_counts)
# end of class VHFOpt


def get_int3c1e(mol, grids, direct_scf_tol):
def get_int3c1e(mol, grids, intopt):
omega = mol.omega
assert omega >= 0.0, "Short-range one electron integrals with GPU acceleration is not implemented."

intopt = VHFOpt(mol, 'int2e')
intopt.build(direct_scf_tol, diag_block_with_triu=True, aosym=True, group_size=BLKSIZE)

nao = mol.nao
ngrids = grids.shape[0]
total_double_number = ngrids * nao * nao
Expand Down Expand Up @@ -221,7 +218,7 @@ def get_int3c1e(mol, grids, direct_scf_tol):
lj = intopt.angular[cpj]

stream = cp.cuda.get_current_stream()
nao_cart = intopt.mol.nao
nao_cart = intopt._sorted_mol.nao

log_q_ij = intopt.log_qs[cp_ij_id]

Expand Down Expand Up @@ -265,21 +262,18 @@ def get_int3c1e(mol, grids, direct_scf_tol):

row, col = np.tril_indices(nao)
int3c_grid_slice[:, row, col] = int3c_grid_slice[:, col, row]
ao_idx = np.argsort(intopt.ao_idx)
ao_idx = np.argsort(intopt._ao_idx)
grid_idx = np.arange(ngrids_of_split)
int3c_grid_slice = int3c_grid_slice[np.ix_(grid_idx, ao_idx, ao_idx)]

int3c_grid_slice.get(out = int3c[i_grid_split : i_grid_split + ngrids_of_split, :, :])

return int3c

def get_int3c1e_charge_contracted(mol, grids, charges, direct_scf_tol):
def get_int3c1e_charge_contracted(mol, grids, charges, intopt):
wxj6000 marked this conversation as resolved.
Show resolved Hide resolved
omega = mol.omega
assert omega >= 0.0, "Short-range one electron integrals with GPU acceleration is not implemented."

intopt = VHFOpt(mol, 'int2e')
intopt.build(direct_scf_tol, diag_block_with_triu=True, aosym=True, group_size=BLKSIZE)

nao = mol.nao

assert charges.ndim == 1 and charges.shape[0] == grids.shape[0]
Expand All @@ -296,7 +290,7 @@ def get_int3c1e_charge_contracted(mol, grids, charges, direct_scf_tol):
lj = intopt.angular[cpj]

stream = cp.cuda.get_current_stream()
nao_cart = intopt.mol.nao
nao_cart = intopt._sorted_mol.nao

log_q_ij = intopt.log_qs[cp_ij_id]

Expand All @@ -312,8 +306,8 @@ def get_int3c1e_charge_contracted(mol, grids, charges, direct_scf_tol):
strides = np.array([ni, ni*nj], dtype=np.int32)

ngrids = grids.shape[0]
# n_charge_sum_per_thread = 1 means every thread processes one pair and one grid
# n_charge_sum_per_thread = ngrids or a larger number guarantees one thread processes one pair and all grid points
# n_charge_sum_per_thread = 1 # means every thread processes one pair and one grid
# n_charge_sum_per_thread = ngrids # or a larger number guarantees one thread processes one pair and all grid points
n_charge_sum_per_thread = 10

int1e_angular_slice = cp.zeros([j1-j0, i1-i0], order='C')
Expand Down Expand Up @@ -346,26 +340,30 @@ def get_int3c1e_charge_contracted(mol, grids, charges, direct_scf_tol):

row, col = np.tril_indices(nao)
int1e[row, col] = int1e[col, row]
ao_idx = np.argsort(intopt.ao_idx)
ao_idx = np.argsort(intopt._ao_idx)
int1e = int1e[np.ix_(ao_idx, ao_idx)]

return cp.asnumpy(int1e)
return int1e

def get_int3c1e_density_contracted(mol, grids, dm, direct_scf_tol):
def get_int3c1e_density_contracted(mol, grids, dm, intopt):
omega = mol.omega
assert omega >= 0.0, "Short-range one electron integrals with GPU acceleration is not implemented."

dm = cp.asarray(dm)
if dm.ndim == 3:
wxj6000 marked this conversation as resolved.
Show resolved Hide resolved
if dm.shape[0] > 2:
print("Warning: There are more than two density matrices to contract with one electron integrals, "
"it's not from an unrestricted calculation, and we're unsure about your purpose. "
"We sum the density matrices up, please check if that's expected.")
dm = cp.einsum("ijk->jk", dm)

assert dm.ndim == 2
assert dm.shape[0] == dm.shape[1] and dm.shape[0] == mol.nao

intopt = VHFOpt(mol, 'int2e')
intopt.build(direct_scf_tol, diag_block_with_triu=False, aosym=True, group_size=BLKSIZE)

nao_cart = intopt.mol.nao
nao_cart = intopt._sorted_mol.nao
ngrids = grids.shape[0]

dm = dm[np.ix_(intopt.ao_idx, intopt.ao_idx)] # intopt.ao_idx is in spherical basis
dm = intopt.sort_orbitals(dm, [0,1])
if not mol.cart:
cart2sph_transformation_matrix = intopt.cart2sph
# TODO: This part is inefficient (O(N^3)), should be changed to the O(N^2) algorithm
Expand All @@ -374,9 +372,9 @@ def get_int3c1e_density_contracted(mol, grids, dm, direct_scf_tol):

dm = cp.asnumpy(dm)

ao_loc_sorted_order = intopt.sorted_mol.ao_loc_nr(cart = True)
ao_loc_sorted_order = intopt._sorted_mol.ao_loc_nr(cart = True)
l_ij = intopt.l_ij.T.flatten()
bas_coords = intopt.sorted_mol.atom_coords()[intopt.sorted_mol._bas[:, ATOM_OF]].flatten()
bas_coords = intopt._sorted_mol.atom_coords()[intopt._sorted_mol._bas[:, ATOM_OF]].flatten()

n_total_hermite_density = intopt.density_offset[-1]
dm_pair_ordered = np.zeros(n_total_hermite_density)
Expand Down Expand Up @@ -412,8 +410,8 @@ def get_int3c1e_density_contracted(mol, grids, dm, direct_scf_tol):
nbins = 1
bins_locs_ij = np.array([0, len(log_q_ij)], dtype=np.int32)

# n_pair_sum_per_thread = 1 means every thread processes one pair and one grid
# n_pair_sum_per_thread = nao_cart or a larger number guarantees one thread processes one grid and all pairs of the same type
# n_pair_sum_per_thread = 1 # means every thread processes one pair and one grid
# n_pair_sum_per_thread = nao_cart # or a larger number guarantees one thread processes one grid and all pairs of the same type
n_pair_sum_per_thread = nao_cart

err = libgint.GINTfill_int3c1e_density_contracted(
Expand All @@ -433,19 +431,28 @@ def get_int3c1e_density_contracted(mol, grids, dm, direct_scf_tol):
if err != 0:
raise RuntimeError('GINTfill_int3c1e_density_contracted failed')

return cp.asnumpy(int3c_density_contracted)
return int3c_density_contracted

def intor(mol, intor, grids, dm=None, charges=None, direct_scf_tol=1e-13):
assert intor == 'int1e_grids' and grids is not None
def intor(mol, intor, grids, dm=None, charges=None, direct_scf_tol=1e-13, intopt=None):
assert intor == 'int1e_grids'
assert grids is not None
assert dm is None or charges is None, \
"Are you sure you want to contract the one electron integrals with both charge and density? " + \
"If so, pass in density, obtain the result with n_charge and contract with the charges yourself."

if intopt is None:
intopt = VHFOpt(mol)
intopt.build(direct_scf_tol)
else:
assert isinstance(intopt, VHFOpt), \
f"Please make sure intopt is a {VHFOpt.__module__}.{VHFOpt.__name__} object."
assert hasattr(intopt, "density_offset"), "Please call build() function for VHFOpt object first."

if dm is None and charges is None:
return get_int3c1e(mol, grids, direct_scf_tol)
return get_int3c1e(mol, grids, intopt)
elif dm is not None:
return get_int3c1e_density_contracted(mol, grids, dm, direct_scf_tol)
return get_int3c1e_density_contracted(mol, grids, dm, intopt)
elif charges is not None:
return get_int3c1e_charge_contracted(mol, grids, charges, direct_scf_tol)
return get_int3c1e_charge_contracted(mol, grids, charges, intopt)
else:
raise ValueError(f"Logic error in {__file__} {__name__}")
Loading
Loading