Skip to content

Commit

Permalink
Add option to include spin-magnitude term in Hamiltonian
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 643060421
Change-Id: If7d0779f080ccf3bfdcd5a583dc8178d70fd5234
  • Loading branch information
dpfau committed Aug 22, 2024
1 parent eb616a8 commit 8e1c747
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 2 deletions.
4 changes: 4 additions & 0 deletions ferminet/base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ def default() -> ml_collections.ConfigDict:
# If using Wasserstein QMC, this parameter controls the amount of
# "default" VMC gradient to mix in. Otherwise, it is ignored.
'vmc_weight': 0.0,
# If nonzero, add a term to the Hamiltonian proportional to the spin
# magnitude. Useful for removing non-singlet states from excited
# state calculations.
'spin_energy': 0.0,
# KFAC hyperparameters. See KFAC documentation for details.
'kfac': {
'invert_every': 1,
Expand Down
18 changes: 18 additions & 0 deletions ferminet/tests/train_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,24 @@ def test_training_step(self, system, optimizer, complex_, states, laplacian):
# ensure they actually run without a top-level error.
train.train(cfg)

@parameterized.parameters([{'states': 0}, {'states': 3}])
def test_s2_energy(self, states):
cfg = diatomic.get_config()
cfg.system.molecule_name = 'LiH'
cfg.network.ferminet.hidden_dims = ((16, 4),) * 2
cfg.network.determinants = 2
cfg.batch_size = 32
cfg.system.states = states
cfg.pretrain.iterations = 10
cfg.mcmc.burn_in = 10
cfg.optim.iterations = 3
cfg.optim.spin_energy = 1.0
cfg.log.save_path = self.create_tempdir().full_path
cfg = base_config.resolve(cfg)
# Calculation is too small to test the results for accuracy. Test just to
# ensure they actually run without a top-level error.
train.train(cfg)

def test_random_pretraining(self):
cfg = diatomic.get_config()
cfg.system.molecule_name = 'LiH'
Expand Down
27 changes: 25 additions & 2 deletions ferminet/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,7 +696,7 @@ def log_network(*args, **kwargs):
cfg.system.make_local_energy_fn.rsplit('.', maxsplit=1))
local_energy_module = importlib.import_module(local_energy_module)
make_local_energy = getattr(local_energy_module, local_energy_fn) # type: hamiltonian.MakeLocalEnergy
local_energy = make_local_energy(
local_energy_fn = make_local_energy(
f=signed_network,
charges=charges,
nspins=nspins,
Expand All @@ -705,7 +705,7 @@ def log_network(*args, **kwargs):
**cfg.system.make_local_energy_kwargs)
else:
pp_symbols = cfg.system.get('pp', {'symbols': None}).get('symbols')
local_energy = hamiltonian.local_energy(
local_energy_fn = hamiltonian.local_energy(
f=signed_network,
charges=charges,
nspins=nspins,
Expand All @@ -715,6 +715,29 @@ def log_network(*args, **kwargs):
states=cfg.system.get('states', 0),
pp_type=cfg.system.get('pp', {'type': 'ccecp'}).get('type'),
pp_symbols=pp_symbols if cfg.system.get('use_pp') else None)

if cfg.optim.get('spin_energy', 0.0) > 0.0:
# Minimize <H + c * S^2> instead of just <H>
# Create a new local_energy function that takes the weighted sum of
# the local energy and the local spin magnitude.
local_s2_fn = observables.make_s2(
signed_network,
nspins=nspins,
states=cfg.system.states)
def local_energy_and_s2_fn(params, keys, data):
local_energy, aux_data = local_energy_fn(params, keys, data)
s2 = local_s2_fn(params, data, None)
weight = cfg.optim.get('spin_energy', 0.0)
if cfg.system.states:
aux_data = aux_data + weight * s2
local_energy_and_s2 = local_energy + weight * jnp.trace(s2)
else:
local_energy_and_s2 = local_energy + weight * s2
return local_energy_and_s2, aux_data
local_energy = local_energy_and_s2_fn
else:
local_energy = local_energy_fn

if cfg.optim.objective == 'vmc':
evaluate_loss = qmc_loss_functions.make_loss(
log_network if use_complex else logabs_network,
Expand Down

0 comments on commit 8e1c747

Please sign in to comment.