diff --git a/ferminet/hamiltonian.py b/ferminet/hamiltonian.py index d732dc8..4ef68bd 100644 --- a/ferminet/hamiltonian.py +++ b/ferminet/hamiltonian.py @@ -139,20 +139,18 @@ def grad_phase_closure(x): return result elif laplacian_method == 'folx': - if complex_output: - raise NotImplementedError('Forward laplacian not yet supported for' - 'complex-valued outputs.') - else: - def _lapl_over_f(params, data): - f_closure = lambda x: logabs_f(params, - x, - data.spins, - data.atoms, - data.charges) - f_wrapped = folx.forward_laplacian(f_closure, sparsity_threshold=6) - output = f_wrapped(data.positions) - return - (output.laplacian + - jnp.sum(output.jacobian.dense_array ** 2)) / 2 + def _lapl_over_f(params, data): + f_closure = lambda x: f(params, x, data.spins, data.atoms, data.charges) + f_wrapped = folx.forward_laplacian(f_closure, sparsity_threshold=6) + output = f_wrapped(data.positions) + result = - (output[1].laplacian + + jnp.sum(output[1].jacobian.dense_array ** 2)) / 2 + if complex_output: + result -= 0.5j * output[0].laplacian + result += 0.5 * jnp.sum(output[0].jacobian.dense_array ** 2) + result -= 1.j * jnp.sum(output[0].jacobian.dense_array * + output[1].jacobian.dense_array) + return result else: raise NotImplementedError(f'Laplacian method {laplacian_method} ' 'not implemented.') @@ -163,6 +161,7 @@ def _lapl_over_f(params, data): def excited_kinetic_energy_matrix( f: networks.FermiNetLike, states: int, + complex_output: bool = False, laplacian_method: str = 'default') -> KineticEnergy: """Creates a f'n which evaluates the matrix of local kinetic energies. @@ -170,6 +169,7 @@ def excited_kinetic_energy_matrix( f: A network which returns a tuple of sign(psi) and log(|psi|) arrays, where each array contains one element per excited state. states: the number of excited states + complex_output: If true, the output of f is complex-valued. laplacian_method: Laplacian calculation method. One of: 'default': take jvp(grad), looping over inputs 'folx': use Microsoft's implementation of forward laplacian @@ -188,10 +188,37 @@ def _lapl_all_states(params, pos, spins, atoms, charges): grad_f_closure = lambda x: grad_f(params, x, spins, atoms, charges) primal, dgrad_f = jax.linearize(grad_f_closure, pos) + if complex_output: + grad_phase = jax.jacrev(utils.select_output(f, 0), argnums=1) + def grad_phase_closure(x): + return grad_phase(params, x, spins, atoms, charges) + phase_primal, dgrad_phase = jax.linearize(grad_phase_closure, pos) + hessian_diagonal = ( + lambda i: dgrad_f(eye[i])[:, i] + 1.j * dgrad_phase(eye[i])[:, i] + ) + else: + phase_primal = 1.0 + hessian_diagonal = lambda i: dgrad_f(eye[i])[:, i] + + if complex_output: + if pos.dtype == jnp.float32: + dtype = jnp.complex64 + elif pos.dtype == jnp.float64: + dtype = jnp.complex128 + else: + raise ValueError(f'Unsupported dtype for input: {pos.dtype}') + else: + dtype = pos.dtype + result = -0.5 * lax.fori_loop( - 0, n, lambda i, val: val + dgrad_f(eye[i])[:, i], jnp.zeros(states)) + 0, n, lambda i, val: val + hessian_diagonal(i), + jnp.zeros(states, dtype=dtype)) + result -= 0.5 * jnp.sum(primal ** 2, axis=-1) + if complex_output: + result += 0.5 * jnp.sum(phase_primal ** 2, axis=-1) + result -= 1.j * jnp.sum(primal * phase_primal, axis=-1) - return result - 0.5 * jnp.sum(primal ** 2, axis=-1) + return result def _lapl_over_f(params, data): """Return the kinetic energy (divided by psi) summed over excited states.""" @@ -208,16 +235,28 @@ def _lapl_over_f(params, data): # CAUTION!! Only the first array of spins is being passed! f_closure = lambda x: f(params, x, spins_[0], data.atoms, data.charges) f_wrapped = folx.forward_laplacian(f_closure, sparsity_threshold=6) - sign_mat, log_out = folx.batched_vmap(f_wrapped, 1)(pos_) + sign_out, log_out = folx.batched_vmap(f_wrapped, 1)(pos_) log_mat = log_out.x lapl = -(log_out.laplacian + jnp.sum(log_out.jacobian.dense_array ** 2, axis=-2)) / 2 + if complex_output: + sign_mat = sign_out.x + lapl -= 0.5j * sign_out.laplacian + lapl += 0.5 * jnp.sum(sign_out.jacobian.dense_array ** 2, axis=-2) + lapl -= 1.j * jnp.sum(sign_out.jacobian.dense_array * + log_out.jacobian.dense_array, axis=-2) + else: + sign_mat = sign_out else: raise NotImplementedError(f'Laplacian method {laplacian_method} ' 'not implemented with excited states.') + # psi_i(r_j) # subtract off largest value to avoid under/overflow - psi_mat = sign_mat * jnp.exp(log_mat - jnp.max(log_mat)) # psi_i(r_j) + if complex_output: + psi_mat = jnp.exp(log_mat + 1.j * sign_mat - jnp.max(log_mat)) + else: + psi_mat = sign_mat * jnp.exp(log_mat - jnp.max(log_mat)) kpsi_mat = lapl * psi_mat # K psi_i(r_j) return psi_mat, kpsi_mat @@ -312,13 +351,13 @@ def local_energy( energy of the wavefunction given the parameters params, RNG state key, and a single MCMC configuration in data. """ - if complex_output and states > 1: - raise NotImplementedError( - 'Excited states not implemented with complex output') del nspins if states: - ke = excited_kinetic_energy_matrix(f, states, laplacian_method) + ke = excited_kinetic_energy_matrix(f, + states, + complex_output, + laplacian_method) else: ke = local_kinetic_energy(f, use_scan=use_scan, diff --git a/ferminet/networks.py b/ferminet/networks.py index f8ddda5..e4ffce3 100644 --- a/ferminet/networks.py +++ b/ferminet/networks.py @@ -1288,7 +1288,9 @@ def state_matrix(params, pos, spins, atoms, charges, **kwargs): return state_matrix -def make_total_ansatz(signed_network: FermiNetLike, n: int) -> FermiNetLike: +def make_total_ansatz(signed_network: FermiNetLike, + n: int, + complex_output: bool = False) -> FermiNetLike: """Construct a single-output ansatz which gives the meta-Slater determinant. Let signed_network(params, pos, spins, options) be a function which returns @@ -1302,6 +1304,8 @@ def make_total_ansatz(signed_network: FermiNetLike, n: int) -> FermiNetLike: Args: signed_network: A function with the same calling convention as the FermiNet. n: the number of excited states, needed to know how to shape the determinant + complex_output: If true, the output of the network is complex, and the + individual states return phase angles rather than signs. Returns: A function with a single output which combines the individual excited states @@ -1311,11 +1315,17 @@ def make_total_ansatz(signed_network: FermiNetLike, n: int) -> FermiNetLike: def total_ansatz(params, pos, spins, atoms, charges, **kwargs): """Evaluate meta_determinant for a given ansatz.""" - sign_mat, log_mat = state_matrix( + sign_in, log_in = state_matrix( params, pos, spins, atoms=atoms, charges=charges, **kwargs) - logmax = jnp.max(log_mat) # logsumexp trick - sign_out, log_out = jnp.linalg.slogdet(sign_mat * jnp.exp(log_mat - logmax)) + logmax = jnp.max(log_in) # logsumexp trick + if complex_output: + # sign_in is a phase angle rather than a sign for complex networks + mat_in = jnp.exp(log_in + 1.j * sign_in - logmax) + sign_out, log_out = jnp.linalg.slogdet(mat_in) + sign_out = jnp.angle(sign_out) + else: + sign_out, log_out = jnp.linalg.slogdet(sign_in * jnp.exp(log_in - logmax)) log_out += n * logmax return sign_out, log_out diff --git a/ferminet/tests/train_test.py b/ferminet/tests/train_test.py index 41efee9..9cfad83 100644 --- a/ferminet/tests/train_test.py +++ b/ferminet/tests/train_test.py @@ -65,11 +65,12 @@ def _config_params(): 'states': 0, 'laplacian': 'default', } - for states, laplacian in itertools.product((0, 2), ('default', 'folx')): + for states, laplacian, complex_ in itertools.product( + (0, 2), ('default', 'folx'), (True, False)): yield { 'system': 'Li', 'optimizer': 'kfac', - 'complex_': False, + 'complex_': complex_, 'states': states, 'laplacian': laplacian } diff --git a/ferminet/train.py b/ferminet/train.py index 5c58a66..1200772 100644 --- a/ferminet/train.py +++ b/ferminet/train.py @@ -459,6 +459,7 @@ def train(cfg: ml_collections.ConfigDict, writer_manager=None): else: envelope = envelopes.make_isotropic_envelope() + use_complex = cfg.network.get('complex', False) if cfg.network.network_type == 'ferminet': network = networks.make_fermi_net( nspins, @@ -472,7 +473,7 @@ def train(cfg: ml_collections.ConfigDict, writer_manager=None): bias_orbitals=cfg.network.bias_orbitals, full_det=cfg.network.full_det, rescale_inputs=cfg.network.get('rescale_inputs', False), - complex_output=cfg.network.get('complex', False), + complex_output=use_complex, **cfg.network.ferminet, ) elif cfg.network.network_type == 'psiformer': @@ -487,7 +488,7 @@ def train(cfg: ml_collections.ConfigDict, writer_manager=None): jastrow=cfg.network.get('jastrow', 'default'), bias_orbitals=cfg.network.bias_orbitals, rescale_inputs=cfg.network.get('rescale_inputs', False), - complex_output=cfg.network.get('complex', False), + complex_output=use_complex, **cfg.network.psiformer, ) key, subkey = jax.random.split(key) @@ -498,7 +499,8 @@ def train(cfg: ml_collections.ConfigDict, writer_manager=None): if cfg.system.get('states', 0): logabs_network = utils.select_output( networks.make_total_ansatz(signed_network, - cfg.system.get('states', 0)), 1) + cfg.system.get('states', 0), + complex_output=use_complex), 1) else: logabs_network = lambda *args, **kwargs: signed_network(*args, **kwargs)[1] batch_network = jax.vmap( @@ -508,12 +510,23 @@ def train(cfg: ml_collections.ConfigDict, writer_manager=None): # Exclusively when computing the gradient wrt the energy for complex # wavefunctions, it is necessary to have log(psi) rather than log(|psi|). # This is unused if the wavefunction is real-valued. - def log_network(*args, **kwargs): - if not cfg.network.get('complex', False): - raise ValueError('This function should never be used if the ' - 'wavefunction is real-valued.') - phase, mag = signed_network(*args, **kwargs) - return mag + 1.j * phase + if cfg.system.get('states', 0): + def log_network(*args, **kwargs): + if not use_complex: + raise ValueError('This function should never be used if the ' + 'wavefunction is real-valued.') + meta_net = networks.make_total_ansatz(signed_network, + cfg.system.get('states', 0), + complex_output=True) + phase, mag = meta_net(*args, **kwargs) + return mag + 1.j * phase + else: + def log_network(*args, **kwargs): + if not use_complex: + raise ValueError('This function should never be used if the ' + 'wavefunction is real-valued.') + phase, mag = signed_network(*args, **kwargs) + return mag + 1.j * phase # Set up checkpointing and restore params/data if necessary # Mirror behaviour of checkpoints in TF FermiNet. @@ -696,28 +709,28 @@ def log_network(*args, **kwargs): charges=charges, nspins=nspins, use_scan=False, - complex_output=cfg.network.get('complex', False), + complex_output=use_complex, laplacian_method=laplacian_method, states=cfg.system.get('states', 0), pp_type=cfg.system.get('pp', {'type': 'ccecp'}).get('type'), pp_symbols=pp_symbols if cfg.system.get('use_pp') else None) if cfg.optim.objective == 'vmc': evaluate_loss = qmc_loss_functions.make_loss( - log_network if cfg.network.get('complex', False) else logabs_network, + log_network if use_complex else logabs_network, local_energy, clip_local_energy=cfg.optim.clip_local_energy, clip_from_median=cfg.optim.clip_median, center_at_clipped_energy=cfg.optim.center_at_clip, - complex_output=cfg.network.get('complex', False), + complex_output=use_complex, ) elif cfg.optim.objective == 'wqmc': evaluate_loss = qmc_loss_functions.make_wqmc_loss( - log_network if cfg.network.get('complex', False) else logabs_network, + log_network if use_complex else logabs_network, local_energy, clip_local_energy=cfg.optim.clip_local_energy, clip_from_median=cfg.optim.clip_median, center_at_clipped_energy=cfg.optim.center_at_clip, - complex_output=cfg.network.get('complex', False), + complex_output=use_complex, vmc_weight=cfg.optim.get('vmc_weight', 1.0) ) else: