From 8e1c747fb6b23852e5b6a09851ceffb02205732b Mon Sep 17 00:00:00 2001 From: David Pfau Date: Thu, 13 Jun 2024 19:12:11 +0100 Subject: [PATCH] Add option to include spin-magnitude term in Hamiltonian PiperOrigin-RevId: 643060421 Change-Id: If7d0779f080ccf3bfdcd5a583dc8178d70fd5234 --- ferminet/base_config.py | 4 ++++ ferminet/tests/train_test.py | 18 ++++++++++++++++++ ferminet/train.py | 27 +++++++++++++++++++++++++-- 3 files changed, 47 insertions(+), 2 deletions(-) diff --git a/ferminet/base_config.py b/ferminet/base_config.py index 03de5f6..b6e52e0 100644 --- a/ferminet/base_config.py +++ b/ferminet/base_config.py @@ -80,6 +80,10 @@ def default() -> ml_collections.ConfigDict: # If using Wasserstein QMC, this parameter controls the amount of # "default" VMC gradient to mix in. Otherwise, it is ignored. 'vmc_weight': 0.0, + # If nonzero, add a term to the Hamiltonian proportional to the spin + # magnitude. Useful for removing non-singlet states from excited + # state calculations. + 'spin_energy': 0.0, # KFAC hyperparameters. See KFAC documentation for details. 'kfac': { 'invert_every': 1, diff --git a/ferminet/tests/train_test.py b/ferminet/tests/train_test.py index 88eeb2a..7b6bf21 100644 --- a/ferminet/tests/train_test.py +++ b/ferminet/tests/train_test.py @@ -113,6 +113,24 @@ def test_training_step(self, system, optimizer, complex_, states, laplacian): # ensure they actually run without a top-level error. 
train.train(cfg) + @parameterized.parameters([{'states': 0}, {'states': 3}]) + def test_s2_energy(self, states): + cfg = diatomic.get_config() + cfg.system.molecule_name = 'LiH' + cfg.network.ferminet.hidden_dims = ((16, 4),) * 2 + cfg.network.determinants = 2 + cfg.batch_size = 32 + cfg.system.states = states + cfg.pretrain.iterations = 10 + cfg.mcmc.burn_in = 10 + cfg.optim.iterations = 3 + cfg.optim.spin_energy = 1.0 + cfg.log.save_path = self.create_tempdir().full_path + cfg = base_config.resolve(cfg) + # Calculation is too small to test the results for accuracy. Test just to + # ensure they actually run without a top-level error. + train.train(cfg) + def test_random_pretraining(self): cfg = diatomic.get_config() cfg.system.molecule_name = 'LiH' diff --git a/ferminet/train.py b/ferminet/train.py index e3bc64d..10e04cc 100644 --- a/ferminet/train.py +++ b/ferminet/train.py @@ -696,7 +696,7 @@ def log_network(*args, **kwargs): cfg.system.make_local_energy_fn.rsplit('.', maxsplit=1)) local_energy_module = importlib.import_module(local_energy_module) make_local_energy = getattr(local_energy_module, local_energy_fn) # type: hamiltonian.MakeLocalEnergy - local_energy = make_local_energy( + local_energy_fn = make_local_energy( f=signed_network, charges=charges, nspins=nspins, @@ -705,7 +705,7 @@ def log_network(*args, **kwargs): **cfg.system.make_local_energy_kwargs) else: pp_symbols = cfg.system.get('pp', {'symbols': None}).get('symbols') - local_energy = hamiltonian.local_energy( + local_energy_fn = hamiltonian.local_energy( f=signed_network, charges=charges, nspins=nspins, @@ -715,6 +715,29 @@ def log_network(*args, **kwargs): states=cfg.system.get('states', 0), pp_type=cfg.system.get('pp', {'type': 'ccecp'}).get('type'), pp_symbols=pp_symbols if cfg.system.get('use_pp') else None) + + if cfg.optim.get('spin_energy', 0.0) > 0.0: + # Minimize <H + c * S^2> instead of just <H> + # Create a new local_energy function that takes the weighted sum of + # the local energy and the local spin 
magnitude. + local_s2_fn = observables.make_s2( + signed_network, + nspins=nspins, + states=cfg.system.states) + def local_energy_and_s2_fn(params, keys, data): + local_energy, aux_data = local_energy_fn(params, keys, data) + s2 = local_s2_fn(params, data, None) + weight = cfg.optim.get('spin_energy', 0.0) + if cfg.system.states: + aux_data = aux_data + weight * s2 + local_energy_and_s2 = local_energy + weight * jnp.trace(s2) + else: + local_energy_and_s2 = local_energy + weight * s2 + return local_energy_and_s2, aux_data + local_energy = local_energy_and_s2_fn + else: + local_energy = local_energy_fn + if cfg.optim.objective == 'vmc': evaluate_loss = qmc_loss_functions.make_loss( log_network if use_complex else logabs_network,