Skip to content

Commit

Permalink
Avoided the .transpose() in the densities or coefficient inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
PabloAMC committed Aug 9, 2023
1 parent 80f5fc5 commit 0c7aef0
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 173 deletions.
6 changes: 3 additions & 3 deletions examples/basic_examples/example_lda_functional_02.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,18 @@
# First a features method, which takes a molecule and returns an array of features
# It computes what in the article appears as potential e_\theta(r), as well as the
# input to the neural network to compute the density.
def lsda_density(molecule: Molecule, clip_cte: float = 1e-27, *_, **__):
def lsda_density(molecule: Molecule, clip_cte: float = 1e-27):
r"""Auxiliary function to generate the features of LSDA."""
# Molecule can compute the density matrix.
rho = molecule.density()
# To avoid numerical issues in JAX we limit too small numbers.
rho = jnp.clip(rho, a_min = clip_cte)
# Now we can implement the LDA energy density equation in the paper.
lda_e = -3./2. * (3. / (4*jnp.pi)) ** (1 / 3) * (rho**(4/3)).sum(axis = 0, keepdims = True)
lda_e = -3/2 * (3/(4*jnp.pi)) ** (1/3) * (rho**(4/3)).sum(axis = 1, keepdims = True)
# For simplicity we do not include the exchange polarization correction
# check function exchange_polarization_correction in functional.py
# The output of features must be an Array of dimension n_grid x n_features.
return lda_e.T
return lda_e

# Then we have to define a function that takes the output of features and returns the energy density.
# Its first argument represents the instance of the functional. Note how we sum over the dimensions
Expand Down
4 changes: 2 additions & 2 deletions examples/basic_examples/example_neural_functional_03.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@
def densities(molecule: Molecule, *_, **__):
rho = jnp.clip(molecule.density(), a_min = 1e-27)
kinetic = jnp.clip(molecule.kinetic_density(), a_min = 1e-27)
return jnp.concatenate((rho, kinetic)).T
return jnp.concatenate((rho, kinetic), axis = 1)

out_features = 4
def nn_coefficients(instance, rhoinputs, *_, **__):
def nn_coefficients(instance, rhoinputs):
r"""
Instance is an instance of the class Functional or NeuralFunctional.
rhoinputs is the input to the neural network, in the form of an array.
Expand Down
75 changes: 19 additions & 56 deletions examples/train_complex_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ast import Return
from functools import partial
from flax.core import unfreeze, freeze
from typing import Dict, Optional, Union
Expand All @@ -13,7 +14,7 @@
from molecule import Molecule

from train import make_train_kernel, molecule_predictor
from functional import DispersionFunctional, Functional, NeuralFunctional, canonicalize_inputs, dm21_combine, dm21_hfgrads, dm21_coefficient_inputs, densities
from functional import DispersionFunctional, NeuralFunctional, canonicalize_inputs, dm21_coefficient_inputs, densities, dm21_combine_cinputs, dm21_combine_densities, dm21_hfgrads_cinputs, dm21_hfgrads_densities
from interface.pyscf import loader

from torch.utils.tensorboard import SummaryWriter
Expand All @@ -27,7 +28,7 @@

dirpath = os.path.dirname(os.path.dirname(__file__))
training_data_dirpath = os.path.normpath(dirpath + "/data/training/dissociation/")
training_files = ["H2_extrapolation.h5"]
training_files = ["H2_extrapolation_train.h5"]
# Alternatively, use "H2plus_extrapolation.h5". To generate it, first run the following in data_processing.py:
#distances = [0.5, 0.75, 1, 1.25, 1.5]
#process_dissociation(atom1 = 'H', atom2 = 'H', charge = 0, spin = 0, file = 'H2_dissociation.xlsx', energy_column_name='cc-pV5Z', training_distances=distances)
Expand All @@ -48,7 +49,7 @@
activation = gelu
loadcheckpoint = False

def function(instance, rhoinputs, localfeatures, *_, **__):
def nn_coefficients(instance, rhoinputs, *_, **__):
x = canonicalize_inputs(rhoinputs) # Making sure dimensions are correct

# Initial layer: log -> dense -> tanh
Expand All @@ -71,9 +72,8 @@ def function(instance, rhoinputs, localfeatures, *_, **__):
x = activation(x) # activation = jax.nn.gelu
instance.sow('intermediates', 'residual_elu_'+str(i), x)

x = instance.head(x, out_features, sigmoid_scale_factor)
return instance.head(x, out_features, sigmoid_scale_factor)

return jnp.einsum('ri,ri->r', x, localfeatures)

def nn_dispersion(instance: nn.Module, x, *_, **__):

Expand Down Expand Up @@ -127,57 +127,21 @@ def nn_dispersion(instance: nn.Module, x, *_, **__):

return Cab/(1+jnp.exp(-(Rab/Rab0-1)))

def features(molecule: Molecule, functional_type: Optional[Union[str, Dict]] = 'LDA', clip_cte: float = 1e-27, *args, **kwargs):

r"""
Generates all features except the HF energy features.
Parameters
----------
molecule: Molecule
functional_type: Optional[Union[str, Dict]]
Either one of 'LDA', 'GGA', 'MGGA' or Dictionary
{'u_range': range(), 'w_range': range()} that generates
a functional. See `default_functionals` function.
clip_cte: float
A numerical threshold to avoid numerical precision issues
Default: 1e-27
Returns
----------
Tuple[Array, Array]
The features and local features, similar to those used by DM21
"""

features = dm21_coefficient_inputs(molecule, *args, **kwargs)
localfeatures = densities(molecule, functional_type, clip_cte)

# We return them with the first index being the position r and the second the feature.
return features, localfeatures

features = partial(features, functional_type = 'MGGA')

def combine(features, ehf):

features, local_features = features

# Remember that DM concatenates the hf density in the x features by spin...
features = jnp.concatenate([features, ehf[:,0].T, ehf[:,1].T], axis=1)

# ... and in the y features by omega.
local_features = jnp.concatenate([local_features, ehf[:,0].T, ehf[:,1].T], axis=1)
return features,local_features
def combine_densities(densities, ehf):
    r"""Append the Hartree-Fock energy densities as extra density columns.

    Parameters
    ----------
    densities : Array
        Two-dimensional array with the grid index on axis 0; the HF columns
        are concatenated to it along axis 1.
    ehf : Array
        Three-dimensional array whose LAST axis is the grid axis, i.e.
        ``ehf.shape[2] == densities.shape[0]``. Presumably laid out as
        (n_omega, n_spin, n_grid), as produced by
        ``molecule.HF_energy_density(omegas)`` -- TODO confirm.

    Returns
    -------
    Array
        Shape ``(n_grid, densities.shape[1] + ehf.shape[0] * ehf.shape[1])``.
    """
    n_grid = ehf.shape[2]
    n_hf_columns = ehf.shape[0] * ehf.shape[1]
    # A bare reshape to (n_grid, n_hf_columns) only re-chunks the flat buffer
    # and would mix values that belong to different grid points. The grid axis
    # has to be moved to the front before flattening the remaining two axes.
    # Axis 1 is kept outermost so the columns come out in the same order as
    # the pre-refactor concatenation ``[ehf[:, 0].T, ehf[:, 1].T]``.
    ehf = jnp.transpose(ehf, (2, 1, 0)).reshape((n_grid, n_hf_columns))
    return jnp.concatenate((densities, ehf), axis=1)

omegas = [0., 0.4]
functional = NeuralFunctional(function = function,
features = features,
nograd_features = lambda molecule, *_, **__: molecule.HF_energy_density(omegas),
featuregrads=lambda self, params, molecule, features, nograd_features, *_, **__: dm21_hfgrads(self, params,
molecule, features, nograd_features,
omegas = omegas),
combine = combine)
# Build the neural functional from its component callables.
# Every lambda below accepts and ignores extra positional/keyword arguments
# through *_ and **__.
# `omegas` is looked up when a lambda is called, not when it is defined, so
# rebinding the module-level name afterwards changes what these hooks use.
functional = NeuralFunctional(coefficients = nn_coefficients,
energy_densities=partial(densities, functional_type = 'MGGA'),
# Same quantity as nograd_coefficient_inputs: molecule.HF_energy_density(omegas).
nograd_densities=lambda molecule, *_, **__: molecule.HF_energy_density(omegas),
# Thin adapter that forwards its arguments to dm21_hfgrads_densities and appends omegas.
densitygrads = lambda self, params, molecule, nograd_densities, cinputs, grad_densities, *_, **__: dm21_hfgrads_densities(self, params, molecule, nograd_densities, cinputs, grad_densities, omegas),
combine_densities = combine_densities,
coefficient_inputs=dm21_coefficient_inputs,
nograd_coefficient_inputs = lambda molecule, *_, **__: molecule.HF_energy_density(omegas),
# NOTE: inside this lambda `densities` is the lambda's own parameter; it shadows
# the imported `densities` function used for energy_densities above.
coefficient_input_grads = lambda self, params, molecule, nograd_cinputs, grad_cinputs, densities, *_, **__: dm21_hfgrads_cinputs(self, params, molecule, nograd_cinputs, grad_cinputs, densities, omegas),
combine_inputs = dm21_combine_cinputs
)

DispersionNN = DispersionFunctional(dispersion = nn_dispersion)

Expand All @@ -188,8 +152,7 @@ def combine(features, ehf):
# We generate the features from the molecule we created before, to initialize the parameters
key, = split(key, 1)
rhoinputs = jax.random.normal(key, shape = [2, 11])
localfeatures = jax.random.normal(key, shape = [2, out_features])
params = functional.init(key, rhoinputs, localfeatures)
params = functional.init(key, rhoinputs)

dispersioninputs = jax.random.normal(key, shape = [2, 4])
dparams = DispersionNN.init(key, dispersioninputs)
Expand Down
103 changes: 33 additions & 70 deletions functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,11 +443,11 @@ class Molecule
tau = molecule.kinetic_density()

grad_rho_norm = jnp.sum(grad_rho**2, axis=-1)
grad_rho_norm_sumspin = jnp.sum(grad_rho.sum(axis=0, keepdims=True) ** 2, axis=-1)
grad_rho_norm_sumspin = jnp.sum(grad_rho.sum(axis=1, keepdims=True) ** 2, axis=-1)

features = jnp.concatenate((rho, grad_rho_norm_sumspin, grad_rho_norm, tau), axis=0)
features = jnp.concatenate((rho, grad_rho_norm_sumspin, grad_rho_norm, tau), axis=1)

return features.T
return features

def dm21_densities(molecule: Molecule, functional_type: Optional[Union[str, Dict[str, int]]] = 'LDA', clip_cte: float = 1e-27, *_, **__):
r"""
Expand Down Expand Up @@ -514,44 +514,13 @@ class Molecule
log_w_sigma = jnp.where(jnp.greater(log_rho, jnp.log2(clip_cte)), log_1t_sigma - jnp.log2(1 + beta*(2**log_1t_sigma)) + jnp.log2(beta), 0)

# Compute the local features
localfeatures = jnp.empty((0, log_rho.shape[-1]))
localfeatures = jnp.empty((log_rho.shape[0], 0))
for i, j in itertools.product(u_range, w_range):
mgga_term = jnp.expand_dims((2**(4/3.*log_rho + i * log_u_sigma + j * log_w_sigma)).sum(axis=0), axis = 0) \
mgga_term = (2**(4/3.*log_rho + i * log_u_sigma + j * log_w_sigma)).sum(axis=1, keepdims = True) \
* jnp.where(jnp.logical_and(i==0, j==0), -2 * jnp.pi * (3 / (4 * jnp.pi)) ** (4 / 3), 1) # to match DM21
localfeatures = jnp.concatenate((localfeatures, mgga_term), axis=0)
localfeatures = jnp.concatenate((localfeatures, mgga_term), axis=1)

return localfeatures.T

def dm21_densities_old(molecule: Molecule, functional_type: Optional[Union[str, Dict]] = 'LDA', clip_cte: float = 1e-27, *args, **kwargs):

r"""
Generates all features except the HF energy features.
Parameters
----------
molecule: Molecule
functional_type: Optional[Union[str, Dict]]
Either one of 'LDA', 'GGA', 'MGGA' or Dictionary
{'u_range': range(), 'w_range': range()} that generates
a functional. See `default_functionals` function.
clip_cte: float
A numerical threshold to avoid numerical precision issues
Default: 1e-27
Returns
----------
Tuple[Array, Array]
The features and local features, similar to those used by DM21
"""
raise DeprecationWarning('This function is deprecated, use dm21_densities instead.')

features = dm21_coefficient_inputs(molecule, *args, **kwargs)
localfeatures = dm21_densities(molecule, functional_type, clip_cte)

# We return them with the first index being the position r and the second the feature.
return features, localfeatures
return localfeatures

def dm21_combine_cinputs(cinputs, ehf):
r"""
Expand Down Expand Up @@ -827,10 +796,10 @@ def exchange_polarization_correction(e_PF, rho):
Array, shape (n_grid)
The ready to be integrated electronic energy density.
"""
zeta = (rho[0] - rho[1])/ rho.sum(axis = 0)
zeta = (rho[:,0] - rho[:,1])/ rho.sum(axis = 1)
def fzeta(z): return ((1-z)**(4/3) + (1+z)**(4/3) - 2) / (2*(2**(1/3) - 1))
# Eq. 2.71 in "Time-Dependent Density-Functional Theory" by Carsten A. Ullrich
return e_PF[0] + (e_PF[1]-e_PF[0])*fzeta(zeta)
return e_PF[:,0] + (e_PF[:,1]-e_PF[:,0])*fzeta(zeta)


def correlation_polarization_correction(e_PF: Array, rho: Array, clip_cte: float = 1e-27):
Expand Down Expand Up @@ -858,13 +827,13 @@ def correlation_polarization_correction(e_PF: Array, rho: Array, clip_cte: float
The ready to be integrated electronic energy density.
"""

e_tilde_PF = jnp.einsum('sr,r->sr', e_PF, rho.sum(axis = 0))
e_tilde_PF = jnp.einsum('rs,r->rs', e_PF, rho.sum(axis = 1))

log_rho = jnp.log2(jnp.clip(rho.sum(axis = 0), a_min = clip_cte))
log_rho = jnp.log2(jnp.clip(rho.sum(axis = 1), a_min = clip_cte))
#assert not jnp.isnan(log_rho).any() and not jnp.isinf(log_rho).any()
log_rs = jnp.log2((3/(4*jnp.pi))**(1/3)) - log_rho/3.

zeta = jnp.where(rho.sum(axis = 0) > clip_cte, (rho[0] - rho[1]) / (rho.sum(axis = 0)), 0.)
zeta = jnp.where(rho.sum(axis = 1) > clip_cte, (rho[:,0] - rho[:,1]) / (rho.sum(axis = 1)), 0.)
def fzeta(z):
zm = 2**(4*jnp.log2(1-z)/3)
zp = 2**(4*jnp.log2(1+z)/3)
Expand All @@ -889,7 +858,7 @@ def fzeta(z):
fz = jnp.round(fzeta(zeta), int(math.log10(clip_cte)))
z4 = jnp.round(2**(4*jnp.log2(jnp.clip(zeta, a_min = clip_cte))), int(math.log10(clip_cte)))

e_tilde = e_tilde_PF[0] + alphac * (fz/(grad(grad(fzeta))(0.)))* (1-z4) + (e_tilde_PF[1]-e_tilde_PF[0]) * fz*z4
e_tilde = e_tilde_PF[:,0] + alphac * (fz/(grad(grad(fzeta))(0.)))* (1-z4) + (e_tilde_PF[:,1]-e_tilde_PF[:,0]) * fz*z4
#assert not jnp.isnan(e_tilde).any() and not jnp.isinf(e_tilde).any()

return e_tilde
Expand Down Expand Up @@ -962,18 +931,18 @@ class Molecule
log_1t_sigma - jnp.log2(1 + beta*(2**log_1t_sigma)) + jnp.log2(beta), 0)

# Compute the local features
localfeatures = jnp.empty((0, log_rho.shape[-1]))
localfeatures = jnp.empty((log_rho.shape[0], 0))
for i, j in itertools.product(u_range, w_range):
mgga_term = 2**(4/3.*log_rho + i * log_u_sigma + j * log_w_sigma)

# First we concatenate the exchange terms
localfeatures = jnp.concatenate((localfeatures, mgga_term), axis=0)
localfeatures = jnp.concatenate((localfeatures, mgga_term), axis=1)

######### Correlation features ###############

grad_rho_norm_sq_ss = jnp.sum((grad_rho.sum(axis = 0))**2, axis=-1)
grad_rho_norm_sq_ss = jnp.sum((grad_rho.sum(axis = 1))**2, axis=-1)
log_grad_rho_norm_ss = jnp.log2(jnp.clip(grad_rho_norm_sq_ss, a_min = clip_cte))/2
log_rho_ss = jnp.log2(jnp.clip(rho.sum(axis = 0), a_min = clip_cte))
log_rho_ss = jnp.log2(jnp.clip(rho.sum(axis = 1), a_min = clip_cte))
log_x_ss = log_grad_rho_norm_ss - 4/3.*log_rho_ss

log_u_ss = jnp.where(jnp.greater(log_rho_ss,jnp.log2(clip_cte)),
Expand All @@ -982,34 +951,28 @@ class Molecule
log_u_ab = jnp.where(jnp.greater(log_rho_ss,jnp.log2(clip_cte)),
log_x_ss - 1 - jnp.log2(1 + beta*(2**(log_x_ss-1))) + jnp.log2(beta), 0)

log_u_c = jnp.stack((log_u_ss, log_u_ab), axis = 0)
log_u_c = jnp.stack((log_u_ss, log_u_ab), axis = 1)


log_tau_ss = jnp.log2(jnp.clip(tau.sum(axis = 0), a_min = clip_cte))
log_tau_ss = jnp.log2(jnp.clip(tau.sum(axis = 1), a_min = clip_cte))
log_1t_ss = log_tau_ss - 5/3.*log_rho_ss
log_w_ss = jnp.where(jnp.greater(log_rho.sum(axis = 0), jnp.log2(clip_cte)),
log_w_ss = jnp.where(jnp.greater(log_rho.sum(axis = 1), jnp.log2(clip_cte)),
log_1t_ss - jnp.log2(1 + beta*(2**log_1t_ss)) + jnp.log2(beta), 0)

log_w_ab = jnp.where(jnp.greater(log_rho.sum(axis = 0), jnp.log2(clip_cte)),
log_w_ab = jnp.where(jnp.greater(log_rho.sum(axis = 1), jnp.log2(clip_cte)),
log_1t_ss - 1 - jnp.log2(1 + beta*(2**(log_1t_ss-1))) + jnp.log2(beta), 0)

log_w_c = jnp.stack((log_w_ss, log_w_ab), axis = 0)


A_ = jnp.array([[0.031091],
[0.015545]])
alpha1 = jnp.array([[0.21370],
[0.20548]])
beta1 = jnp.array([[7.5957],
[14.1189]])
beta2 = jnp.array([[3.5876],
[6.1977]])
beta3 = jnp.array([[1.6382],
[3.3662]])
beta4 = jnp.array([[0.49294],
[0.62517]])
log_w_c = jnp.stack((log_w_ss, log_w_ab), axis = 1)


A_ = jnp.array([[0.031091,0.015545]])
alpha1 = jnp.array([[0.21370,0.20548]])
beta1 = jnp.array([[7.5957,14.1189]])
beta2 = jnp.array([[3.5876,6.1977]])
beta3 = jnp.array([[1.6382,3.3662]])
beta4 = jnp.array([[0.49294,0.62517]])

log_rho = jnp.log2(jnp.clip(rho.sum(axis = 0), a_min = clip_cte))
log_rho = jnp.log2(jnp.clip(rho.sum(axis = 1, keepdims = True), a_min = clip_cte))
log_rs = jnp.log2((3/(4*jnp.pi))**(1/3)) - log_rho/3.
brs_1_2 = 2**(log_rs/2 + jnp.log2(beta1))
ars = 2**(log_rs + jnp.log2(alpha1))
Expand All @@ -1025,9 +988,9 @@ class Molecule
2**(jnp.log2(e_PW92) + i * log_u_c + j * log_w_c), 0)

# Then we concatenate the correlation terms
localfeatures = jnp.concatenate((localfeatures, mgga_term), axis=0)
localfeatures = jnp.concatenate((localfeatures, mgga_term), axis=1)

return localfeatures.T
return localfeatures


############# Dispersion functional #############
Expand Down
10 changes: 5 additions & 5 deletions molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def density(rdm1: Array, ao: Array, precision: Precision = Precision.HIGHEST) ->
The density. Shape: (n_grid_points, n_spin)
"""

return jnp.einsum("...ab,ra,rb->...r", rdm1, ao, ao, precision=precision)
return jnp.einsum("...ab,ra,rb->r...", rdm1, ao, ao, precision=precision)

@partial(jax.jit, static_argnames="precision")
def grad_density(
Expand Down Expand Up @@ -259,13 +259,13 @@ def grad_density(
The density gradient. Shape: (n_grid_points, n_spin, 3)
"""

return 2 * jnp.einsum("...ab,ra,rbj->...rj", rdm1, ao, grad_ao, precision=precision)
return 2 * jnp.einsum("...ab,ra,rbj->r...j", rdm1, ao, grad_ao, precision=precision)

@partial(jax.jit, static_argnames="precision")
def lapl_density(rdm1: Array, ao: Array, grad_ao: Array, grad_2_ao: PyTree, precision: Precision = Precision.HIGHEST):

return 2* jnp.einsum("...ab,raj,rbj->...r", rdm1, grad_ao, grad_ao, precision=precision) + \
2 * jnp.einsum("...ab,ra,rbi->...r", rdm1, ao, grad_2_ao, precision=precision)
return 2* jnp.einsum("...ab,raj,rbj->r...", rdm1, grad_ao, grad_ao, precision=precision) + \
2 * jnp.einsum("...ab,ra,rbi->r...", rdm1, ao, grad_2_ao, precision=precision)

@partial(jax.jit, static_argnames="precision")
def kinetic_density(rdm1: Array, grad_ao: Array, precision: Precision = Precision.HIGHEST) -> Array:
Expand All @@ -291,7 +291,7 @@ def kinetic_density(rdm1: Array, grad_ao: Array, precision: Precision = Precisio
The kinetic energy density. Shape: (n_grid_points, n_spin)
"""

return 0.5 * jnp.einsum("...ab,raj,rbj->...r", rdm1, grad_ao, grad_ao, precision=precision)
return 0.5 * jnp.einsum("...ab,raj,rbj->r...", rdm1, grad_ao, grad_ao, precision=precision)

@partial(jax.jit, static_argnames=["precision"])
def HF_energy_density(rdm1: Array, ao: Array, chi: Array, precision: Precision = Precision.HIGHEST):
Expand Down
Loading

0 comments on commit 0c7aef0

Please sign in to comment.