Add ensemble penalty method of Wheeler, Kleiner and Wagner (https://arxiv.org/abs/2312.00693) for excited states

PiperOrigin-RevId: 643416256
Change-Id: I069fe28f8a981c4fcb23b865bf0f2bad16791228
dpfau committed Aug 22, 2024
1 parent 2b034ad commit 336bbee
Showing 7 changed files with 344 additions and 45 deletions.
14 changes: 11 additions & 3 deletions README.md
@@ -175,10 +175,18 @@ in the config file.

## Excited States

-Excited state properties of systems can be calculated using the [Natural Excited
-States for VMC (NES-VMC) algorithm](https://arxiv.org/abs/2308.16848).
+Excited state properties of systems can be calculated using either the
+[Natural Excited States for VMC (NES-VMC) algorithm](https://arxiv.org/abs/2308.16848)
+or an [ensemble penalty method](https://arxiv.org/abs/2312.00693).
To enable the calculation of `k` states of a system, simply set
-`cfg.system.states=k` in the config file.
+`cfg.system.states=k` in the config file. By default, NES-VMC is used; to
+enable the ensemble penalty method instead, add
+`cfg.optim.objective='vmc_overlap'` to the config. NES-VMC has no free
+parameters, while the ensemble penalty method takes a free choice of weights
+on the energies and on the overlap penalty, set via `cfg.optim.overlap`. If
+the energy weights are not set in the config, the weight on state k is
+automatically set to 1/k. We have found NES-VMC to be generally more accurate
+than the ensemble penalty method, but include both for completeness.

## Output

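As a usage sketch (not itself part of this commit), a config enabling the new options for three states might look like the following; the weight and penalty values are illustrative, not recommendations:

```python
# Sketch: user config enabling the ensemble penalty method for 3 states.
from ferminet import base_config


def get_config():
  cfg = base_config.default()
  cfg.system.states = 3  # number of states to optimise jointly
  cfg.optim.objective = 'vmc_overlap'  # ensemble penalty method
  cfg.optim.overlap.penalty = 1.0  # strength of the overlap penalty term
  cfg.optim.overlap.weights = (1.0, 0.5, 1 / 3)  # per-state energy weights
  return cfg
```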
14 changes: 13 additions & 1 deletion ferminet/base_config.py
@@ -53,7 +53,11 @@ def default() -> ml_collections.ConfigDict:
      # importlib.import_module.
      'config_module': __name__,
      'optim': {
-          'objective': 'vmc',  # objective type. Either 'vmc' or 'wqmc'
+          # Objective type. One of:
+          # 'vmc': minimise <H> by standard VMC energy minimisation
+          # 'wqmc': minimise <H> by Wasserstein QMC
+          # 'vmc_overlap': minimise \sum_i w_i <H_i> + \lambda \sum_{i<j} <psi_i psi_j>^2
+          'objective': 'vmc',
          'iterations': 1000000,  # number of iterations
          'optimizer': 'kfac',  # one of adam, kfac, lamb, none
          'laplacian': 'default',  # one of 'default' or 'folx' (forward laplacian)
@@ -84,6 +88,14 @@ def default() -> ml_collections.ConfigDict:
          # magnitude. Useful for removing non-singlet states from excited
          # state calculations.
          'spin_energy': 0.0,
+          # If 'objective' is 'vmc_overlap', these parameters control the
+          # penalty term.
+          'overlap': {
+              # Weights on each state. Generated automatically if not provided.
+              'weights': None,
+              # Strength of the penalty term.
+              'penalty': 1.0,
+          },
          # KFAC hyperparameters. See KFAC documentation for details.
          'kfac': {
              'invert_every': 1,
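How the automatic weights are generated when `weights` is None is not shown in this diff. A minimal sketch, assuming the README's "1/k for state k" rule; whether the weights are then normalised is an assumption:

```python
# Sketch (assumption): automatic per-state energy weights when
# cfg.optim.overlap.weights is None. Normalisation is assumed, not
# confirmed by this diff.
import numpy as np


def default_overlap_weights(states: int) -> tuple[float, ...]:
  weights = 1.0 / np.arange(1, states + 1)  # 1, 1/2, ..., 1/states
  return tuple(weights / weights.sum())  # normalise to sum to 1
```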
51 changes: 34 additions & 17 deletions ferminet/hamiltonian.py
@@ -324,6 +324,7 @@ def local_energy(
    complex_output: bool = False,
    laplacian_method: str = 'default',
    states: int = 0,
+    state_specific: bool = False,
    pp_type: str = 'ccecp',
    pp_symbols: Sequence[str] | None = None,
) -> LocalEnergy:
@@ -341,6 +342,9 @@ def local_energy(
      'folx': use Microsoft's implementation of forward laplacian
    states: Number of excited states to compute. If 0, compute ground state with
      default machinery. If 1, compute ground state with excited state machinery.
+    state_specific: Only used for excited states (states > 0). If true, the
+      local energy is computed separately for each output of the network,
+      rather than as a full local energy matrix.
    pp_type: type of pseudopotential to use. Only used if pp_symbols is
      provided.
    pp_symbols: sequence of element symbols for which the pseudopotential is
@@ -353,17 +357,6 @@
"""
del nspins

if states:
ke = excited_kinetic_energy_matrix(f,
states,
complex_output,
laplacian_method)
else:
ke = local_kinetic_energy(f,
use_scan=use_scan,
complex_output=complex_output,
laplacian_method=laplacian_method)

if not pp_symbols:
effective_charges = charges
use_pp = False
@@ -416,14 +409,38 @@ def _e_l(
            pp_nonlocal, (None, None, None, data_vmap_dims, 0, 0))
        pot_spectrum += vmap_pp_nonloc(key, f, params, data_, ae, r_ae)

-      # Compute kinetic energy and matrix of states
-      psi_mat, kin_mat = ke(params, data)
-
-      # Combine terms
-      hpsi_mat = kin_mat + psi_mat * pot_spectrum
-      energy_mat = jnp.linalg.solve(psi_mat, hpsi_mat)
-      total_energy = jnp.trace(energy_mat)
+      if state_specific:
+        # For simplicity, we only implement a folx version of the kinetic
+        # energy calculation here.
+        # TODO(pfau): factor out code repeated here and in _lapl_over_f
+        pos_ = jnp.reshape(data.positions, [states, -1])
+        spins_ = jnp.reshape(data.spins, [states, -1])
+        f_closure = lambda x: f(params, x, spins_[0], data.atoms,
+                                data.charges)
+        f_wrapped = folx.forward_laplacian(f_closure, sparsity_threshold=6)
+        sign_out, log_out = folx.batched_vmap(f_wrapped, 1)(pos_)
+        kin = -(log_out.laplacian +
+                jnp.sum(log_out.jacobian.dense_array ** 2, axis=-2)) / 2
+        if complex_output:
+          kin -= 0.5j * sign_out.laplacian
+          kin += 0.5 * jnp.sum(sign_out.jacobian.dense_array ** 2, axis=-2)
+          kin -= 1.j * jnp.sum(sign_out.jacobian.dense_array *
+                               log_out.jacobian.dense_array, axis=-2)
+        total_energy = jnp.diag(kin) + pot_spectrum[:, 0]
+        energy_mat = None
+      else:
+        # Compute kinetic energy and matrix of states
+        ke = excited_kinetic_energy_matrix(
+            f, states, complex_output, laplacian_method)
+        psi_mat, kin_mat = ke(params, data)
+        hpsi_mat = kin_mat + psi_mat * pot_spectrum
+        energy_mat = jnp.linalg.solve(psi_mat, hpsi_mat)
+        total_energy = jnp.trace(energy_mat)
+    else:
+      ke = local_kinetic_energy(f,
+                                use_scan=use_scan,
+                                complex_output=complex_output,
+                                laplacian_method=laplacian_method)
      ae, _, r_ae, r_ee = networks.construct_input_features(
          data.positions, data.atoms
      )
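For reference, the `state_specific` branch assembles the standard log-domain local kinetic energy per state from the folx Laplacian and Jacobian. For a real-valued output $\psi$ this is the identity:

```latex
% Log-domain local kinetic energy, as computed from log_out above.
-\frac{1}{2}\frac{\nabla^2\psi}{\psi}
  = -\frac{1}{2}\left(\nabla^2\log|\psi|
  + \lVert\nabla\log|\psi|\rVert^2\right)
```

The `complex_output` terms extend the same identity to $\log\psi = \log|\psi| + i\varphi$, with the derivatives of the phase $\varphi$ supplied by `sign_out`.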
190 changes: 185 additions & 5 deletions ferminet/loss.py
@@ -31,17 +31,23 @@ class AuxiliaryLossData:
"""Auxiliary data returned by total_energy.
Attributes:
energy: mean energy over batch, and over all devices if inside a pmap.
variance: mean variance over batch, and over all devices if inside a pmap.
local_energy: local energy for each MCMC configuration.
clipped_energy: local energy after clipping has been applied
grad_local_energy: gradient of the local energy.
local_energy_mat: for excited states, the local energy matrix.
s_ij: Matrix of overlaps between wavefunctions.
mean_s_ij: Mean value of the overlap between wavefunctions across walkers.
"""
energy: jax.Array # for some losses, the energy and loss are not the same
variance: jax.Array
local_energy: jax.Array
clipped_energy: jax.Array
grad_local_energy: jax.Array | None
local_energy_mat: jax.Array | None
grad_local_energy: jax.Array | None = None
local_energy_mat: jax.Array | None = None
s_ij: jax.Array | None = None
mean_s_ij: jax.Array | None = None


class LossFn(Protocol):
@@ -105,7 +111,7 @@ def clip_local_values(
    device.
  """

-  batch_mean = lambda values: constants.pmean(jnp.mean(values))
+  batch_mean = lambda values: constants.pmean(jnp.mean(values, axis=0))

  def clip_at_total_variation(values, center, scale):
    tv = batch_mean(jnp.abs(values - center))
@@ -114,7 +120,13 @@ def clip_at_total_variation(values, center, scale):
    if clip_from_median:
      # More natural place to center the clipping, but expensive due to both
      # the median and all_gather (at least on multihost)
-      clip_center = jnp.median(constants.all_gather(local_values).real)
+      all_local_values = constants.all_gather(local_values).real
+      shape = all_local_values.shape
+      if all_local_values.ndim == 3:  # energy_and_overlap case
+        all_local_values = all_local_values.reshape([-1, shape[-1]])
+      else:
+        all_local_values = all_local_values.reshape([-1])
+      clip_center = jnp.median(all_local_values, axis=0)
    else:
      clip_center = mean_local_values
    # roughly, the total variation of the local energies
@@ -207,10 +219,10 @@ def total_energy(
    loss_diff = e_l - loss
    variance = constants.pmean(jnp.mean(loss_diff * jnp.conj(loss_diff)))
    return loss, AuxiliaryLossData(
+        energy=loss,
        variance=variance.real,
        local_energy=e_l,
        clipped_energy=e_l,
-        grad_local_energy=None,
        local_energy_mat=e_l_mat,
    )

@@ -354,6 +366,7 @@ def batch_local_energy_pos(pos):
    grad_e_l = jax.grad(batch_local_energy_pos)(data.positions)
    grad_e_l = jnp.tanh(jax.lax.stop_gradient(grad_e_l))
    return loss, AuxiliaryLossData(
+        energy=loss,
        variance=variance.real,
        local_energy=e_l,
        clipped_energy=e_l,
@@ -412,3 +425,170 @@ def log_q(params_, pos_, spins_, atoms_, charges_):
    return primals_out, tangents_out

  return total_energy


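In the notation of the code comments below, the objective implemented by `make_energy_overlap_loss` is a weighted sum of state energies plus a penalty on the squared pairwise overlaps, with $w_i$ given by `overlap_weight` and $\lambda$ by `overlap_penalty`:

```latex
% Ensemble penalty objective (weights w_i, penalty strength lambda).
\mathcal{L} = \sum_i w_i \langle E_{L,i}\rangle
  + \lambda \sum_{i<j} S_{ij}^2,
\qquad
S_{ij}^2 = \frac{\langle\psi_i\psi_j\rangle^2}
                {\langle\psi_i^2\rangle\langle\psi_j^2\rangle}
```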
+def make_energy_overlap_loss(network: networks.LogFermiNetLike,
+                             local_energy: hamiltonian.LocalEnergy,
+                             clip_local_energy: float = 0.0,
+                             clip_from_median: bool = True,
+                             center_at_clipped_energy: bool = True,
+                             overlap_penalty: float = 1.0,
+                             overlap_weight: Tuple[float, ...] = (1.0,),
+                             complex_output: bool = False) -> LossFn:
"""Creates the loss function for the penalty method for excited states.
Args:
network: callable which evaluates the log of the magnitude of the
wavefunction (square root of the log probability distribution) at the
single MCMC configuration given the network parameters. For the overlap
loss, this returns an entire state matrix - all pairs of states and
walkers.
local_energy: callable which evaluates the local energy.
clip_local_energy: If greater than zero, clip local energies that are
outside [E_L - n D, E_L + n D], where E_L is the mean local energy, n is
this value and D the mean absolute deviation of the local energies from
the mean, to the boundaries. The clipped local energies are only used to
evaluate gradients.
clip_from_median: If true, center the clipping window at the median rather
than the mean. Potentially expensive in multi-host training, but more
accurate.
center_at_clipped_energy: If true, center the local energy differences
passed back to the gradient around the clipped local energy, so the mean
difference across the batch is guaranteed to be zero. Seems to
significantly improve performance with pseudopotentials.
overlap_penalty: The strength of the penalty term that controls the
tradeoff between minimizing the weighted energies and keeping the states
orthogonal.
overlap_weight: The weight to apply to each individual energy in the overall
optimization.
complex_output: If true, the network output is complex-valued.
Returns:
LossFn callable which evaluates the total energy of the system.
"""
+
+  data_axes = networks.FermiNetData(positions=0, spins=0, atoms=0, charges=0)
+  batch_local_energy = jax.vmap(
+      local_energy, in_axes=(None, 0, data_axes), out_axes=(0, 0))
+  batch_network = jax.vmap(network, in_axes=(None, 0, 0, 0, 0), out_axes=0)
+  overlap_weight = jnp.array(overlap_weight)
+
+  # TODO(pfau): how much of this can be factored out with make_loss?
+  @jax.custom_jvp
+  def total_energy_and_overlap(
+      params: networks.ParamTree,
+      key: chex.PRNGKey,
+      data: networks.FermiNetData,
+  ) -> Tuple[jnp.ndarray, AuxiliaryLossData]:
+    """Evaluates the weighted energy and overlap penalty for a batch."""
+
+    # Energy term. Largely similar to make_loss, but simplified.
+    keys = jax.random.split(key, num=data.positions.shape[0])
+    e_l, _ = batch_local_energy(params, keys, data)
+    loss = constants.pmean(jnp.mean(e_l, axis=0))
+    loss_diff = e_l - loss
+    variance = constants.pmean(
+        jnp.mean(loss_diff * jnp.conj(loss_diff), axis=0))
+    weighted_energy = jnp.dot(loss, overlap_weight)
+
+    # Overlap matrix. To compute S_ij^2 = <psi_i psi_j>^2 / (<psi_i^2><psi_j^2>)
+    # by Monte Carlo, you can split the terms into a product of MC estimates:
+    #   E_{x_i ~ psi_i^2} [ psi_j(x_i) / psi_i(x_i) ] *
+    #   E_{x_j ~ psi_j^2} [ psi_i(x_j) / psi_j(x_j) ]
+    # Since x_i and x_j are sampled independently, the product of empirical
+    # estimates is an unbiased estimate of the product of expectations.
+
+    # TODO(pfau): Avoid the double call to batch_network here and in the jvp.
+    sign_psi, log_psi = batch_network(params,
+                                      data.positions,
+                                      data.spins,
+                                      data.atoms,
+                                      data.charges)
+    sign_psi_diag = jax.vmap(jnp.diag)(sign_psi)[..., None]
+    log_psi_diag = jax.vmap(jnp.diag)(log_psi)[..., None]
+    s_ij_local = sign_psi * sign_psi_diag * jnp.exp(log_psi - log_psi_diag)
+    s_ij = constants.pmean(jnp.mean(s_ij_local, axis=0))
+    total_overlap = jnp.sum(jnp.triu(s_ij * s_ij.T, 1))
+
+    return (weighted_energy + overlap_penalty * total_overlap,
+            AuxiliaryLossData(
+                energy=loss,
+                variance=variance.real,
+                local_energy=e_l,
+                clipped_energy=loss,
+                s_ij=s_ij_local,
+                mean_s_ij=s_ij,
+                local_energy_mat=e_l))
+
+  @total_energy_and_overlap.defjvp
+  def total_energy_and_overlap_jvp(primals, tangents):  # pylint: disable=unused-variable
+    """Custom Jacobian-vector product for unbiased local energy gradients."""
+    if complex_output:
+      raise NotImplementedError('Complex output is not supported with penalty '
+                                'method gradients for excited states.')
+
+    params, key, data = primals
+    batch_loss, aux_data = total_energy_and_overlap(params, key, data)
+    energy = aux_data.energy.real
+
+    if clip_local_energy > 0.0:
+      clipped_energy, energy_diff = clip_local_values(
+          aux_data.local_energy,
+          energy,
+          clip_local_energy,
+          clip_from_median,
+          center_at_clipped_energy)
+      aux_data.clipped_energy = jnp.dot(clipped_energy, overlap_weight)
+    else:
+      energy_diff = aux_data.local_energy - energy
+
+    # To take the gradient of the squared overlap between psi_i and psi_j, we
+    # can use a derivation similar to that for the gradient of the energy,
+    # which gives
+    #   \nabla_i S_ij^2 =
+    #     2 E_{x_j ~ psi_j^2} [ psi_i(x_j) / psi_j(x_j) ] *
+    #     E_{x_i ~ psi_i^2} [ (psi_j(x_i)/psi_i(x_i) - <psi_j(x_i)/psi_i(x_i)>)
+    #                         \nabla_i log psi_i(x_i) ]
+    # where \nabla_i denotes the gradient wrt the parameters of psi_i.
+    # Again, because the two expectations are taken over independent samples,
+    # the product of empirical estimates is unbiased.
+
+    if clip_local_energy > 0.0:
+      clipped_overlap, overlap_diff = clip_local_values(
+          aux_data.s_ij,
+          aux_data.mean_s_ij,
+          clip_local_energy,
+          clip_from_median,
+          center_at_clipped_energy)
+    else:
+      clipped_overlap = aux_data.s_ij
+      overlap_diff = clipped_overlap - aux_data.mean_s_ij
+
+    overlap_diff = 2 * jnp.sum(jnp.triu(
+        clipped_overlap * overlap_diff.transpose((0, 2, 1)), 1), axis=1)
+
+    # Due to the simultaneous requirements of KFAC (the calling convention
+    # must be (params, rng, data)) and the Laplacian calculation (we only
+    # want to take the Laplacian wrt electron positions), we need to change
+    # the calling convention between total_energy_and_overlap and
+    # batch_network.
+    data = primals[2]
+    data_tangents = tangents[2]
+    primals = (primals[0], data.positions, data.spins, data.atoms,
+               data.charges)
+    tangents = (tangents[0], data_tangents.positions, data_tangents.spins,
+                data_tangents.atoms, data_tangents.charges)
+
+    psi_primal, psi_tangent = jax.jvp(batch_network, primals, tangents)
+    _, log_primal = psi_primal
+    _, log_tangent = psi_tangent
+    kfac_jax.register_normal_predictive_distribution(
+        jax.vmap(jnp.diag)(log_primal))
+    device_batch_size = jnp.shape(aux_data.local_energy)[0]
+    tangent_loss = energy_diff * overlap_weight + overlap_penalty * overlap_diff
+    tangents_out = (
+        jnp.sum(jax.vmap(jnp.diag)(log_tangent) * tangent_loss) /
+        device_batch_size,
+        aux_data)
+
+    primals_out = batch_loss.real, aux_data
+    return primals_out, tangents_out
+
+  return total_energy_and_overlap
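Written out, the overlap gradient estimator in the custom JVP above is (with $\nabla_i$ the gradient with respect to the parameters of $\psi_i$, matching the comment in the code):

```latex
% Unbiased gradient of the squared overlap as a product of independent
% Monte Carlo estimates over walkers x_i ~ psi_i^2 and x_j ~ psi_j^2.
\nabla_i S_{ij}^2 =
  2\,\mathbb{E}_{x_j \sim \psi_j^2}\!\left[\frac{\psi_i(x_j)}{\psi_j(x_j)}\right]
  \mathbb{E}_{x_i \sim \psi_i^2}\!\left[
    \left(\frac{\psi_j(x_i)}{\psi_i(x_i)}
      - \left\langle\frac{\psi_j(x_i)}{\psi_i(x_i)}\right\rangle\right)
    \nabla_i \log \psi_i(x_i)\right]
```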
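Finally, a wiring sketch for the new loss. The training-loop call site is not part of this diff, and `signed_network`, `local_energy_fn`, and `cfg` are hypothetical handles:

```python
# Sketch: constructing the penalty-method loss from the pieces in this
# commit. signed_network returns (sign, log) matrices over all pairs of
# states and walkers; local_energy_fn is a hamiltonian.LocalEnergy.
loss_fn = make_energy_overlap_loss(
    network=signed_network,
    local_energy=local_energy_fn,
    clip_local_energy=5.0,  # illustrative clipping scale
    clip_from_median=True,
    center_at_clipped_energy=True,
    overlap_penalty=cfg.optim.overlap.penalty,
    overlap_weight=tuple(cfg.optim.overlap.weights),
    complex_output=False,  # the custom JVP requires real-valued output
)
```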
