Add nodes and example for excited states
tautomer committed Sep 13, 2023
1 parent b659fb5 commit 5e8f3d9
Showing 9 changed files with 465 additions and 276 deletions.
67 changes: 67 additions & 0 deletions docs/source/examples/excited_states.rst
@@ -0,0 +1,67 @@
Excited States Training
=======================

hippynn can now predict excited-state energies, transition dipoles, and
non-adiabatic coupling vectors (NACR) for a given molecule.

Multi-target nodes are recommended because they are more efficient and require
fewer recursive layers.

For energies, the node can be constructed just like the ground-state
counterpart::

    energy = targets.HEnergyNode("E", network, module_kwargs={"n_target": n_states + 1})
    mol_energy = energy.mol_energy
    mol_energy.db_name = "E"

Note that a multi-target node is used here, defined by the keyword
``module_kwargs={"n_target": n_states + 1}``. Here, ``n_states`` is the number
of excited states under consideration. The extra target is for the ground
state, which is often useful to predict as well. The database name is simply
``E``, with a shape of ``(n_molecules, n_states + 1)``.
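
As a small illustration of that array layout, the sketch below converts an
``E`` array into excitation energies. It assumes that column 0 holds the ground
state and that the remaining columns hold the excited states in order; this
ordering and the helper name ``excitation_energies`` are assumptions made for
the example and are not part of hippynn.

```python
import numpy as np


def excitation_energies(e_all):
    """Return excited-state energies relative to the ground state.

    e_all: array of shape (n_molecules, n_states + 1). Column 0 is
    assumed (for this sketch only) to be the ground-state energy.
    """
    e_all = np.asarray(e_all, dtype=float)
    # Subtract the ground-state column from every excited-state column.
    return e_all[:, 1:] - e_all[:, :1]


# Two molecules, three excited states plus the ground state.
e_all = np.array([[-10.0, -8.0, -7.0, -5.0],
                  [-20.0, -19.0, -17.5, -16.0]])
gaps = excitation_energies(e_all)
print(gaps.shape)  # (2, 3)
```

The result has one column per excited state, which matches the ``n_states``
targets used by the dipole and NACR nodes described next.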

Predicting the transition dipoles is also similar to the ground-state permanent
dipole::

    charge = targets.HChargeNode("Q", network, module_kwargs={"n_target": n_states})
    dipole = physics.DipoleNode("D", (charge, positions), db_name="D")

The database name is ``D`` with a shape of ``(n_molecules, n_states, 3)``.
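
To make that shape concrete, the following NumPy sketch builds dipoles as a
charge-weighted sum of atomic positions, one dipole per state. The input shapes
(charges as ``(n_molecules, n_atoms, n_states)``, positions as
``(n_molecules, n_atoms, 3)``) and the function name ``dipoles_from_charges``
are assumptions for illustration; this is not the ``DipoleNode``
implementation.

```python
import numpy as np


def dipoles_from_charges(charges, positions):
    """Charge-weighted sum of positions for every molecule and state.

    charges:   (n_molecules, n_atoms, n_states), assumed layout
    positions: (n_molecules, n_atoms, 3)
    returns:   (n_molecules, n_states, 3)
    """
    charges = np.asarray(charges, dtype=float)
    positions = np.asarray(positions, dtype=float)
    # Sum over the atom index "a" for each molecule "m" and state "s".
    return np.einsum("mas,mad->msd", charges, positions)


rng = np.random.default_rng(0)
n_molecules, n_atoms, n_states = 4, 10, 3
q = rng.normal(size=(n_molecules, n_atoms, n_states))
r = rng.normal(size=(n_molecules, n_atoms, 3))
d = dipoles_from_charges(q, r)
print(d.shape)  # (4, 3, 3)
```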

NACR diverges when the energy gap ΔE between two states approaches zero. To
avoid this singularity, the model is trained on NACR*ΔE instead::

    nacr = physics.NACRMultiStateNode(
        "ScaledNACR",
        (charge, positions, energy),
        db_name="ScaledNACR",
        module_kwargs={"n_target": n_states},
    )

For the NACR between states ``i`` and ``j``, :math:`\boldsymbol{d}_{ij}`, the
scaled quantity is expressed as

.. math::

    \boldsymbol{d}_{ij}\Delta E_{ij} = \Delta E_{ij}\boldsymbol{q}_i \frac{\partial\boldsymbol{q}_j}{\partial\boldsymbol{R}}

:math:`\Delta E_{ij}` is the energy difference between states ``i`` and ``j``,
which is calculated internally in the NACR node from the ``energy`` node passed
as input. :math:`\boldsymbol{R}` corresponds to the ``positions`` node in the
code. :math:`\boldsymbol{q}_{i}` and :math:`\boldsymbol{q}_{j}` are the
transition atomic charges for states ``i`` and ``j`` contained in the
``charge`` node. This charge node can be constructed from scratch or reused
from the dipole predictions. The database name is ``ScaledNACR`` with a shape
of ``(n_molecules, n_states*(n_states-1)/2, 3*n_atoms)``.
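
The sketch below illustrates how a ``ScaledNACR`` array of that shape could be
prepared from unscaled NACR and state energies. The pair ordering
``(0, 1), (0, 2), (1, 2), ...``, the sign convention
:math:`\Delta E_{ij} = E_j - E_i`, the use of excited-state energies only, and
the function name ``scale_nacr`` are all assumptions made for this example;
check them against your own dataset before relying on them.

```python
from itertools import combinations

import numpy as np


def scale_nacr(nacr, energies):
    """Multiply each NACR vector by the energy gap of its state pair.

    nacr:     (n_molecules, n_pairs, 3 * n_atoms)
    energies: (n_molecules, n_states), excited states only (assumed)
    returns:  array with the same shape as nacr
    """
    nacr = np.asarray(nacr, dtype=float)
    energies = np.asarray(energies, dtype=float)
    n_states = energies.shape[1]
    # Assumed pair ordering: all i < j in lexicographic order.
    pairs = np.array(list(combinations(range(n_states), 2)))
    if nacr.shape[1] != len(pairs):
        raise ValueError("nacr must have n_states*(n_states-1)/2 pairs")
    i, j = pairs.T
    delta_e = energies[:, j] - energies[:, i]  # (n_molecules, n_pairs)
    return nacr * delta_e[:, :, None]


n_atoms = 2
energies = np.array([[1.0, 2.0, 4.0]])      # one molecule, three states
nacr = np.ones((1, 3, 3 * n_atoms))         # three state pairs
scaled = scale_nacr(nacr, energies)
print(scaled.shape)  # (1, 3, 6)
```

With three states there are ``3 * 2 / 2 = 3`` pairs, so the second dimension of
the array is 3.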

Due to the phase problem (the overall sign of transition dipoles and NACR is
arbitrary), the `phase-less` versions of MAE or RMSE should be used for these
quantities when the loss function is constructed::

    energy_mae = loss.MAELoss.of_node(energy)
    dipole_mae = loss.MAEPhaseLoss.of_node(dipole)
    nacr_mae = loss.MAEPhaseLoss.of_node(nacr)

:func:`~hippynn.graphs.nodes.loss.MAEPhaseLoss` and
:func:`~hippynn.graphs.nodes.loss.MSEPhaseLoss` are the `phase-less` versions
of MAE and MSE, respectively. They are used in exactly the same way as the
ordinary versions.
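
One common way to make an error metric insensitive to the sign of a vector is
to compare the prediction against both the reference and its negative and keep
the smaller error. The NumPy sketch below shows this idea for MAE; it is an
illustration under that assumption, with a hypothetical function name
``mae_phase``, and is not taken from the hippynn source.

```python
import numpy as np


def mae_phase(pred, true):
    """Sign-invariant mean absolute error over the last axis.

    For every vector (last axis), the L1 error is computed against both
    `true` and `-true`, and the smaller of the two is kept.
    """
    pred = np.asarray(pred, dtype=float)
    true = np.asarray(true, dtype=float)
    err_same = np.abs(pred - true).sum(axis=-1)
    err_flip = np.abs(pred + true).sum(axis=-1)
    # Normalize by the total number of elements, as for an ordinary MAE.
    return np.minimum(err_same, err_flip).sum() / pred.size


# A prediction that differs from the reference only by sign has zero error.
flipped = mae_phase([[1.0, 2.0, 3.0]], [[-1.0, -2.0, -3.0]])
ordinary = mae_phase([[1.0, 0.0, 0.0]], [[0.0, 0.0, 0.0]])
print(flipped)
```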

For a complete script, please take a look at ``examples/excited_states.py``.
1 change: 1 addition & 0 deletions docs/source/examples/index.rst
@@ -18,4 +18,5 @@ the examples are just snippets. For fully-fledged examples see the
restarting
ase_calculator
mliap_unified
excited_states

173 changes: 173 additions & 0 deletions examples/excited_states.py
@@ -0,0 +1,173 @@
import json

import matplotlib
import numpy as np
import torch

import hippynn
from hippynn import plotting
from hippynn.experiment import setup_training, train_model
from hippynn.experiment.controllers import PatienceController, RaiseBatchSizeOnPlateau
from hippynn.graphs import inputs, loss, networks, physics, targets

matplotlib.use("Agg")
# default types for torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_default_dtype(torch.float32)

hippynn.settings.WARN_LOW_DISTANCES = False
hippynn.settings.TRANSPARENT_PLOT = True

n_atoms = 10
n_states = 3
plot_frequency = 100
dipole_weight = 4
nacr_weight = 2
l2_weight = 2e-5

# Hyperparameters for the network
network_params = {
    "possible_species": [0, 1, 6, 7],
    "n_features": 30,
    "n_sensitivities": 28,
    "dist_soft_min": 0.7665723566179274,
    "dist_soft_max": 3.4134447177301515,
    "dist_hard_max": 4.6860240434651805,
    "n_interaction_layers": 3,
    "n_atom_layers": 3,
}
# print parameters to the terminal (this runs before logging to training_log.txt starts)
print("Network parameters\n\n", json.dumps(network_params, indent=4))

with hippynn.tools.log_terminal("training_log.txt", "wt"):
    # build network
    species = inputs.SpeciesNode(db_name="Z")
    positions = inputs.PositionsNode(db_name="R")
    network = networks.Hipnn(
        "hipnn_model", (species, positions), module_kwargs=network_params
    )
    # add energy
    energy = targets.HEnergyNode("E", network, module_kwargs={"n_target": n_states + 1})
    mol_energy = energy.mol_energy
    mol_energy.db_name = "E"
    # add dipole
    charge = targets.HChargeNode("Q", network, module_kwargs={"n_target": n_states})
    dipole = physics.DipoleNode("D", (charge, positions), db_name="D")
    # add NACR
    nacr = physics.NACRMultiStateNode(
        "ScaledNACR",
        (charge, positions, energy),
        db_name="ScaledNACR",
        module_kwargs={"n_target": n_states},
    )
    # set up plotter
    plotter = []
    for node in [mol_energy, dipole, nacr]:
        plotter.append(plotting.Hist2D.compare(node, saved=True, shown=False))
    for i in range(network_params["n_interaction_layers"]):
        plotter.append(
            plotting.SensitivityPlot(
                network.torch_module.sensitivity_layers[i],
                saved=f"Sensitivity_{i}.pdf",
                shown=False,
            )
        )
    plotter = plotting.PlotMaker(*plotter, plot_every=plot_frequency)
    # build the loss function
    validation_losses = {}
    # energy
    energy_rmse = loss.MSELoss.of_node(energy) ** 0.5
    validation_losses["E-RMSE"] = energy_rmse
    energy_mae = loss.MAELoss.of_node(energy)
    validation_losses["E-MAE"] = energy_mae
    energy_loss = energy_rmse + energy_mae
    validation_losses["E-Loss"] = energy_loss
    total_loss = energy_loss
    # dipole
    dipole_rmse = loss.MSEPhaseLoss.of_node(dipole) ** 0.5
    validation_losses["D-RMSE"] = dipole_rmse
    dipole_mae = loss.MAEPhaseLoss.of_node(dipole)
    validation_losses["D-MAE"] = dipole_mae
    dipole_loss = dipole_rmse / np.sqrt(3) + dipole_mae
    validation_losses["D-Loss"] = dipole_loss
    total_loss += dipole_weight * dipole_loss
    # nacr
    nacr_rmse = loss.MSEPhaseLoss.of_node(nacr) ** 0.5
    validation_losses["NACR-RMSE"] = nacr_rmse
    nacr_mae = loss.MAEPhaseLoss.of_node(nacr)
    validation_losses["NACR-MAE"] = nacr_mae
    nacr_loss = nacr_rmse / np.sqrt(3 * n_atoms) + nacr_mae
    validation_losses["NACR-Loss"] = nacr_loss
    total_loss += nacr_weight * nacr_loss
    # l2 regularization
    l2_reg = loss.l2reg(network)
    validation_losses["L2"] = l2_reg
    loss_regularization = l2_weight * l2_reg
    # add total loss to the dictionary
    validation_losses["Loss_wo_L2"] = total_loss
    validation_losses["Loss"] = total_loss + loss_regularization

    # set up experiment
    training_modules, db_info = hippynn.experiment.assemble_for_training(
        validation_losses["Loss"],
        validation_losses,
        plot_maker=plotter,
    )
    # set up the optimizer
    optimizer = torch.optim.AdamW(training_modules.model.parameters(), lr=1e-3)
    # use higher patience for production runs
    scheduler = RaiseBatchSizeOnPlateau(
        optimizer=optimizer, max_batch_size=2048, patience=10, factor=0.5
    )
    controller = PatienceController(
        optimizer=optimizer,
        scheduler=scheduler,
        batch_size=32,
        eval_batch_size=2048,
        # use higher max_epochs for production runs
        max_epochs=100,
        stopping_key="Loss",
        fraction_train_eval=0.1,
        # use higher termination_patience for production runs
        termination_patience=10,
    )
    experiment_params = hippynn.experiment.SetupParams(controller=controller)

    # load database
    database = hippynn.databases.DirectoryDatabase(
        name="azo_",  # Prefix for arrays in the directory
        directory="./database",
        seed=114514,  # Random seed for splitting data
        **db_info,  # Adds the inputs and targets db_names from the model as things to load
    )
    # use 10% of the dataset just for quick testing purpose
    database.make_random_split("train", 0.07)
    database.make_random_split("valid", 0.02)
    database.make_random_split("test", 0.01)
    database.splitting_completed = True
    # split the whole dataset into train, valid, test in the ratio of 7:2:1
    # database.make_trainvalidtest_split(0.1, 0.2)

    # set up training
    training_modules, controller, metric_tracker = setup_training(
        training_modules=training_modules,
        setup_params=experiment_params,
    )
    # train model
    metric_tracker = train_model(
        training_modules,
        database,
        controller,
        metric_tracker,
        callbacks=None,
        batch_callbacks=None,
    )

del network_params["possible_species"]
network_params["metric"] = metric_tracker.best_metric_values
network_params["avg_epoch_time"] = np.average(metric_tracker.epoch_times)
network_params["Loss"] = metric_tracker.best_metric_values["valid"]["Loss"]

with open("training_summary.json", "w") as out:
    json.dump(network_params, out, indent=4)
1 change: 0 additions & 1 deletion hippynn/additional/__init__.py

This file was deleted.
