Enable all combinations of folx + complex wavef'n + excited states
PiperOrigin-RevId: 642620004
Change-Id: I5d271f28b240cc59934bbd3492af5ea35a8b8ee9
dpfau committed Aug 22, 2024
1 parent 82d89b4 commit fa0f0d0
Showing 4 changed files with 105 additions and 42 deletions.
83 changes: 61 additions & 22 deletions ferminet/hamiltonian.py
@@ -139,20 +139,18 @@ def grad_phase_closure(x):
return result

elif laplacian_method == 'folx':
if complex_output:
raise NotImplementedError('Forward laplacian not yet supported for'
'complex-valued outputs.')
else:
def _lapl_over_f(params, data):
f_closure = lambda x: logabs_f(params,
x,
data.spins,
data.atoms,
data.charges)
f_wrapped = folx.forward_laplacian(f_closure, sparsity_threshold=6)
output = f_wrapped(data.positions)
return - (output.laplacian +
jnp.sum(output.jacobian.dense_array ** 2)) / 2
def _lapl_over_f(params, data):
f_closure = lambda x: f(params, x, data.spins, data.atoms, data.charges)
f_wrapped = folx.forward_laplacian(f_closure, sparsity_threshold=6)
output = f_wrapped(data.positions)
result = - (output[1].laplacian +
jnp.sum(output[1].jacobian.dense_array ** 2)) / 2
if complex_output:
result -= 0.5j * output[0].laplacian
result += 0.5 * jnp.sum(output[0].jacobian.dense_array ** 2)
result -= 1.j * jnp.sum(output[0].jacobian.dense_array *
output[1].jacobian.dense_array)
return result
else:
raise NotImplementedError(f'Laplacian method {laplacian_method} '
'not implemented.')
@@ -163,13 +161,15 @@ def _lapl_over_f(params, data):
def excited_kinetic_energy_matrix(
f: networks.FermiNetLike,
states: int,
complex_output: bool = False,
laplacian_method: str = 'default') -> KineticEnergy:
"""Creates a f'n which evaluates the matrix of local kinetic energies.
Args:
f: A network which returns a tuple of sign(psi) and log(|psi|) arrays, where
each array contains one element per excited state.
states: the number of excited states
complex_output: If true, the output of f is complex-valued.
laplacian_method: Laplacian calculation method. One of:
'default': take jvp(grad), looping over inputs
'folx': use Microsoft's implementation of forward laplacian
@@ -188,10 +188,37 @@ def _lapl_all_states(params, pos, spins, atoms, charges):
grad_f_closure = lambda x: grad_f(params, x, spins, atoms, charges)
primal, dgrad_f = jax.linearize(grad_f_closure, pos)

if complex_output:
grad_phase = jax.jacrev(utils.select_output(f, 0), argnums=1)
def grad_phase_closure(x):
return grad_phase(params, x, spins, atoms, charges)
phase_primal, dgrad_phase = jax.linearize(grad_phase_closure, pos)
hessian_diagonal = (
lambda i: dgrad_f(eye[i])[:, i] + 1.j * dgrad_phase(eye[i])[:, i]
)
else:
phase_primal = 1.0
hessian_diagonal = lambda i: dgrad_f(eye[i])[:, i]

if complex_output:
if pos.dtype == jnp.float32:
dtype = jnp.complex64
elif pos.dtype == jnp.float64:
dtype = jnp.complex128
else:
raise ValueError(f'Unsupported dtype for input: {pos.dtype}')
else:
dtype = pos.dtype

result = -0.5 * lax.fori_loop(
0, n, lambda i, val: val + dgrad_f(eye[i])[:, i], jnp.zeros(states))
0, n, lambda i, val: val + hessian_diagonal(i),
jnp.zeros(states, dtype=dtype))
result -= 0.5 * jnp.sum(primal ** 2, axis=-1)
if complex_output:
result += 0.5 * jnp.sum(phase_primal ** 2, axis=-1)
result -= 1.j * jnp.sum(primal * phase_primal, axis=-1)

return result - 0.5 * jnp.sum(primal ** 2, axis=-1)
return result

def _lapl_over_f(params, data):
"""Return the kinetic energy (divided by psi) summed over excited states."""
@@ -208,16 +235,28 @@ def _lapl_over_f(params, data):
# CAUTION!! Only the first array of spins is being passed!
f_closure = lambda x: f(params, x, spins_[0], data.atoms, data.charges)
f_wrapped = folx.forward_laplacian(f_closure, sparsity_threshold=6)
sign_mat, log_out = folx.batched_vmap(f_wrapped, 1)(pos_)
sign_out, log_out = folx.batched_vmap(f_wrapped, 1)(pos_)
log_mat = log_out.x
lapl = -(log_out.laplacian +
jnp.sum(log_out.jacobian.dense_array ** 2, axis=-2)) / 2
if complex_output:
sign_mat = sign_out.x
lapl -= 0.5j * sign_out.laplacian
lapl += 0.5 * jnp.sum(sign_out.jacobian.dense_array ** 2, axis=-2)
lapl -= 1.j * jnp.sum(sign_out.jacobian.dense_array *
log_out.jacobian.dense_array, axis=-2)
else:
sign_mat = sign_out
else:
raise NotImplementedError(f'Laplacian method {laplacian_method} '
'not implemented with excited states.')

# psi_i(r_j)
# subtract off largest value to avoid under/overflow
psi_mat = sign_mat * jnp.exp(log_mat - jnp.max(log_mat)) # psi_i(r_j)
if complex_output:
psi_mat = jnp.exp(log_mat + 1.j * sign_mat - jnp.max(log_mat))
else:
psi_mat = sign_mat * jnp.exp(log_mat - jnp.max(log_mat))
kpsi_mat = lapl * psi_mat # K psi_i(r_j)
return psi_mat, kpsi_mat

@@ -312,13 +351,13 @@ def local_energy(
energy of the wavefunction given the parameters params, RNG state key,
and a single MCMC configuration in data.
"""
if complex_output and states > 1:
raise NotImplementedError(
'Excited states not implemented with complex output')
del nspins

if states:
ke = excited_kinetic_energy_matrix(f, states, laplacian_method)
ke = excited_kinetic_energy_matrix(f,
states,
complex_output,
laplacian_method)
else:
ke = local_kinetic_energy(f,
use_scan=use_scan,
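For context on the hamiltonian.py change above: with a complex wavefunction, writing log ψ = log|ψ| + iφ gives -(∇²log ψ + (∇log ψ)²)/2 = -(∇²log|ψ| + |∇log|ψ||²)/2 - i∇²φ/2 + |∇φ|²/2 - i ∇φ·∇log|ψ|, which is exactly the combination of phase and log-magnitude terms assembled from the folx forward-Laplacian outputs in `_lapl_over_f`. A minimal sketch of that assembly, with `phase_fn` and `logabs_fn` as illustrative stand-ins for the two outputs of the signed network (not the committed code):

# Minimal sketch (not the committed code) of the complex kinetic energy
# assembled from folx forward-Laplacian outputs.
import folx
import jax.numpy as jnp

def kinetic_energy_complex(phase_fn, logabs_fn, pos):
  """-0.5 * (lap(log psi) + grad(log psi)^2), with log psi = log|psi| + i*phase."""
  phase_out = folx.forward_laplacian(phase_fn, sparsity_threshold=6)(pos)
  log_out = folx.forward_laplacian(logabs_fn, sparsity_threshold=6)(pos)
  grad_phase = phase_out.jacobian.dense_array
  grad_log = log_out.jacobian.dense_array
  ke = -0.5 * (log_out.laplacian + jnp.sum(grad_log ** 2))  # real-valued part, as before
  ke -= 0.5j * phase_out.laplacian                          # i * lap(phase) term
  ke += 0.5 * jnp.sum(grad_phase ** 2)                      # -(i * grad(phase))^2 / 2
  ke -= 1.j * jnp.sum(grad_phase * grad_log)                # cross term
  return ke
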
18 changes: 14 additions & 4 deletions ferminet/networks.py
@@ -1288,7 +1288,9 @@ def state_matrix(params, pos, spins, atoms, charges, **kwargs):
return state_matrix


def make_total_ansatz(signed_network: FermiNetLike, n: int) -> FermiNetLike:
def make_total_ansatz(signed_network: FermiNetLike,
n: int,
complex_output: bool = False) -> FermiNetLike:
"""Construct a single-output ansatz which gives the meta-Slater determinant.
Let signed_network(params, pos, spins, options) be a function which returns
@@ -1302,6 +1304,8 @@ def make_total_ansatz(signed_network: FermiNetLike, n: int) -> FermiNetLike:
Args:
signed_network: A function with the same calling convention as the FermiNet.
n: the number of excited states, needed to know how to shape the determinant
complex_output: If true, the output of the network is complex, and the
individual states return phase angles rather than signs.
Returns:
A function with a single output which combines the individual excited states
@@ -1311,11 +1315,17 @@ def make_total_ansatz(signed_network: FermiNetLike, n: int) -> FermiNetLike:

def total_ansatz(params, pos, spins, atoms, charges, **kwargs):
"""Evaluate meta_determinant for a given ansatz."""
sign_mat, log_mat = state_matrix(
sign_in, log_in = state_matrix(
params, pos, spins, atoms=atoms, charges=charges, **kwargs)

logmax = jnp.max(log_mat) # logsumexp trick
sign_out, log_out = jnp.linalg.slogdet(sign_mat * jnp.exp(log_mat - logmax))
logmax = jnp.max(log_in) # logsumexp trick
if complex_output:
# sign_in is a phase angle rather than a sign for complex networks
mat_in = jnp.exp(log_in + 1.j * sign_in - logmax)
sign_out, log_out = jnp.linalg.slogdet(mat_in)
sign_out = jnp.angle(sign_out)
else:
sign_out, log_out = jnp.linalg.slogdet(sign_in * jnp.exp(log_in - logmax))
log_out += n * logmax
return sign_out, log_out

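The complex branch added to `make_total_ansatz` above treats the first network output as a phase angle rather than a ±1 sign, builds the matrix ψ_i(r_j) = exp(log|ψ_i(r_j)| + iφ_i(r_j) - max log) for numerical stability, and returns the angle of the complex slogdet sign. A standalone sketch under those assumptions (the helper name is illustrative, not the committed code):

# Illustrative sketch (not the committed code) of the complex meta-determinant:
# rows are states i, columns are walker positions r_j.
import jax.numpy as jnp

def meta_determinant_complex(phase_mat, logabs_mat):
  """Return (phase, log|det|) of det[psi_i(r_j)] from per-state phases and log-magnitudes."""
  logmax = jnp.max(logabs_mat)                     # subtract max to avoid over/underflow
  mat = jnp.exp(logabs_mat + 1.j * phase_mat - logmax)
  sign, logdet = jnp.linalg.slogdet(mat)           # complex sign with unit modulus
  n = logabs_mat.shape[0]                          # number of states
  return jnp.angle(sign), logdet + n * logmax
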
5 changes: 3 additions & 2 deletions ferminet/tests/train_test.py
@@ -65,11 +65,12 @@ def _config_params():
'states': 0,
'laplacian': 'default',
}
for states, laplacian in itertools.product((0, 2), ('default', 'folx')):
for states, laplacian, complex_ in itertools.product(
(0, 2), ('default', 'folx'), (True, False)):
yield {
'system': 'Li',
'optimizer': 'kfac',
'complex_': False,
'complex_': complex_,
'states': states,
'laplacian': laplacian
}
41 changes: 27 additions & 14 deletions ferminet/train.py
@@ -459,6 +459,7 @@ def train(cfg: ml_collections.ConfigDict, writer_manager=None):
else:
envelope = envelopes.make_isotropic_envelope()

use_complex = cfg.network.get('complex', False)
if cfg.network.network_type == 'ferminet':
network = networks.make_fermi_net(
nspins,
@@ -472,7 +473,7 @@
bias_orbitals=cfg.network.bias_orbitals,
full_det=cfg.network.full_det,
rescale_inputs=cfg.network.get('rescale_inputs', False),
complex_output=cfg.network.get('complex', False),
complex_output=use_complex,
**cfg.network.ferminet,
)
elif cfg.network.network_type == 'psiformer':
@@ -487,7 +488,7 @@
jastrow=cfg.network.get('jastrow', 'default'),
bias_orbitals=cfg.network.bias_orbitals,
rescale_inputs=cfg.network.get('rescale_inputs', False),
complex_output=cfg.network.get('complex', False),
complex_output=use_complex,
**cfg.network.psiformer,
)
key, subkey = jax.random.split(key)
@@ -498,7 +499,8 @@
if cfg.system.get('states', 0):
logabs_network = utils.select_output(
networks.make_total_ansatz(signed_network,
cfg.system.get('states', 0)), 1)
cfg.system.get('states', 0),
complex_output=use_complex), 1)
else:
logabs_network = lambda *args, **kwargs: signed_network(*args, **kwargs)[1]
batch_network = jax.vmap(
@@ -508,12 +510,23 @@
# Exclusively when computing the gradient wrt the energy for complex
# wavefunctions, it is necessary to have log(psi) rather than log(|psi|).
# This is unused if the wavefunction is real-valued.
def log_network(*args, **kwargs):
if not cfg.network.get('complex', False):
raise ValueError('This function should never be used if the '
'wavefunction is real-valued.')
phase, mag = signed_network(*args, **kwargs)
return mag + 1.j * phase
if cfg.system.get('states', 0):
def log_network(*args, **kwargs):
if not use_complex:
raise ValueError('This function should never be used if the '
'wavefunction is real-valued.')
meta_net = networks.make_total_ansatz(signed_network,
cfg.system.get('states', 0),
complex_output=True)
phase, mag = meta_net(*args, **kwargs)
return mag + 1.j * phase
else:
def log_network(*args, **kwargs):
if not use_complex:
raise ValueError('This function should never be used if the '
'wavefunction is real-valued.')
phase, mag = signed_network(*args, **kwargs)
return mag + 1.j * phase

# Set up checkpointing and restore params/data if necessary
# Mirror behaviour of checkpoints in TF FermiNet.
@@ -696,28 +709,28 @@ def log_network(*args, **kwargs):
charges=charges,
nspins=nspins,
use_scan=False,
complex_output=cfg.network.get('complex', False),
complex_output=use_complex,
laplacian_method=laplacian_method,
states=cfg.system.get('states', 0),
pp_type=cfg.system.get('pp', {'type': 'ccecp'}).get('type'),
pp_symbols=pp_symbols if cfg.system.get('use_pp') else None)
if cfg.optim.objective == 'vmc':
evaluate_loss = qmc_loss_functions.make_loss(
log_network if cfg.network.get('complex', False) else logabs_network,
log_network if use_complex else logabs_network,
local_energy,
clip_local_energy=cfg.optim.clip_local_energy,
clip_from_median=cfg.optim.clip_median,
center_at_clipped_energy=cfg.optim.center_at_clip,
complex_output=cfg.network.get('complex', False),
complex_output=use_complex,
)
elif cfg.optim.objective == 'wqmc':
evaluate_loss = qmc_loss_functions.make_wqmc_loss(
log_network if cfg.network.get('complex', False) else logabs_network,
log_network if use_complex else logabs_network,
local_energy,
clip_local_energy=cfg.optim.clip_local_energy,
clip_from_median=cfg.optim.clip_median,
center_at_clipped_energy=cfg.optim.center_at_clip,
complex_output=cfg.network.get('complex', False),
complex_output=use_complex,
vmc_weight=cfg.optim.get('vmc_weight', 1.0)
)
else:
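For complex wavefunctions the loss needs log ψ = log|ψ| + iφ rather than log|ψ|, and with excited states that must come from the total ansatz rather than the raw signed network. A condensed sketch of the wiring in train.py above (the wrapper name is hypothetical; `make_total_ansatz` and its `complex_output` argument are the committed API):

# Condensed sketch (hypothetical helper, not in the commit) of how the complex
# log-domain network is built, wrapping the signed network in the total ansatz
# first when excited states are requested.
from ferminet import networks

def make_complex_log_network(signed_network, states):
  net = (networks.make_total_ansatz(signed_network, states, complex_output=True)
         if states else signed_network)
  def log_network(params, pos, spins, atoms, charges):
    phase, logabs = net(params, pos, spins, atoms, charges)
    return logabs + 1.j * phase
  return log_network
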
