Commit 2383bb3
Rescale the excited states at inference time to improve numerical stability

PiperOrigin-RevId: 643385267
Change-Id: I92646ed81b622c4e049b98ae74e32786978590e3
dpfau committed Aug 22, 2024
1 parent 31bb082 commit 2383bb3
Showing 4 changed files with 52 additions and 4 deletions.
8 changes: 6 additions & 2 deletions ferminet/networks.py
@@ -1486,9 +1486,13 @@ def apply(
           jnp.reshape(orbital, (options.states, -1) + orbital.shape[1:])
           for orbital in orbitals
       ]
-      return batch_logdet_matmul(orbitals)
+      result = batch_logdet_matmul(orbitals)
     else:
-      return network_blocks.logdet_matmul(orbitals)
+      result = network_blocks.logdet_matmul(orbitals)
+    if 'state_scale' in params:
+      # only used at inference time for excited states
+      result = result[0], result[1] + params['state_scale']
+    return result

   return Network(
       options=options, init=init, apply=apply, orbitals=orbitals_apply
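
The change above (mirrored in ferminet/psiformer.py below) stores an optional per-state offset in params and adds it to the log-magnitude half of the network output. A minimal sketch of that shift, assuming the network returns a (sign, log|psi|) pair with one entry per excited state; the helper name and shapes are illustrative, not part of the commit:

import jax.numpy as jnp

def shift_log_output(result, params):
  """Sketch of the optional per-state rescaling applied in the diff above."""
  sign, log_abs = result  # log_abs: per-state log magnitudes, e.g. shape [states]
  if 'state_scale' in params:
    # Adding a constant to log|psi_k| rescales psi_k by exp(constant), which
    # changes nothing physical but keeps downstream arithmetic well scaled.
    log_abs = log_abs + params['state_scale']
  return sign, log_abs

# Example: two states whose log magnitudes differ by 70 are brought onto a
# comparable scale near zero.
sign = jnp.array([1.0, -1.0])
log_abs = jnp.array([250.0, 180.0])
params = {'state_scale': jnp.array([-250.0, -180.0])}
print(shift_log_output((sign, log_abs), params))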
8 changes: 6 additions & 2 deletions ferminet/psiformer.py
@@ -445,9 +445,13 @@ def network_apply(
           jnp.reshape(orbital, (options.states, -1) + orbital.shape[1:])
           for orbital in orbitals
       ]
-      return batch_logdet_matmul(orbitals)
+      result = batch_logdet_matmul(orbitals)
     else:
-      return network_blocks.logdet_matmul(orbitals)
+      result = network_blocks.logdet_matmul(orbitals)
+    if 'state_scale' in params:
+      # only used at inference time for excited states
+      result = result[0], result[1] + params['state_scale']
+    return result

   return networks.Network(
       options=options,
21 changes: 21 additions & 0 deletions ferminet/tests/train_test.py
@@ -149,6 +149,27 @@ def test_random_pretraining(self):
     # ensure they actually run without a top-level error.
     train.train(cfg)

+  def test_inference_step(self):
+    cfg = diatomic.get_config()
+    cfg.system.molecule_name = 'LiH'
+    cfg.network.ferminet.hidden_dims = ((16, 4),) * 2
+    cfg.network.determinants = 2
+    cfg.batch_size = 32
+    cfg.system.states = 2
+    cfg.pretrain.iterations = 10
+    cfg.mcmc.burn_in = 10
+    cfg.optim.iterations = 3
+
+    cfg.log.save_path = self.create_tempdir().full_path
+    cfg.log.save_frequency = 0  # Save at every step.
+    cfg = base_config.resolve(cfg)
+    # Trivial training run
+    train.train(cfg)
+
+    # Update config and run inference
+    cfg.optim.optimizer = 'none'
+    cfg = base_config.resolve(cfg)
+    train.train(cfg)

 MOL_STRINGS = [
     'H 0 0 -1; H 0 0 1',
19 changes: 19 additions & 0 deletions ferminet/train.py
@@ -870,6 +870,25 @@ def learning_rate_schedule(t_: jnp.ndarray) -> jnp.ndarray:
     logging.info('Setting initial iteration to 0.')
     t_init = 0

+  # Excited states inference only: rescale each state to be roughly
+  # comparable, to avoid large outlier values in the local energy matrix.
+  # This is not a factor in training, as the outliers are only off-diagonal.
+  # This only becomes a significant factor for systems with >25 electrons.
+  if cfg.system.states > 0 and 'state_scale' not in params:
+    state_matrix = utils.select_output(
+        networks.make_state_matrix(signed_network,
+                                   cfg.system.states), 1)
+    batch_state_matrix = jax.vmap(state_matrix, (None, 0, 0, 0, 0))
+    pmap_state_matrix = constants.pmap(batch_state_matrix)
+    log_psi_vals = pmap_state_matrix(
+        params, data.positions, data.spins, data.atoms, data.charges)
+    state_scale = np.mean(log_psi_vals, axis=[0, 1, 2])
+    state_scale = jax.experimental.multihost_utils.broadcast_one_to_all(
+        state_scale)
+    state_scale = np.tile(state_scale[None], [jax.local_device_count(), 1])
+    if isinstance(params, dict):  # Always true, but prevents type errors
+      params['state_scale'] = -state_scale
+
 if writer_manager is None:
   writer_manager = writers.Writer(
       name='train_stats',
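
A minimal sketch of what the train.py block above computes, assuming log_psi_vals has shape [devices, batch, states, states] (the log magnitude of every state evaluated on every state's walkers); the toy numbers are illustrative, not from the commit:

import numpy as np

devices, batch, states = 2, 4, 3
rng = np.random.default_rng(0)
# Toy log|psi| values whose typical size differs strongly per state (last axis).
log_psi_vals = rng.normal(loc=[50.0, 120.0, 300.0], scale=1.0,
                          size=(devices, batch, states, states))

# One mean log magnitude per state, averaged over the device, batch and
# sampled-state axes, mirroring np.mean(log_psi_vals, axis=[0, 1, 2]).
state_scale = np.mean(log_psi_vals, axis=(0, 1, 2))

# Replicate per local device and negate, so that adding params['state_scale']
# in the network's apply function recenters each state's log|psi| near zero.
state_scale = np.tile(state_scale[None], [devices, 1])
print((-state_scale).shape)  # (devices, states): one offset row per device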
