Add nodes and example for excited states
Showing 9 changed files with 465 additions and 276 deletions.
Excited States Training
=======================

hippynn is now able to predict excited-state energies, transition dipoles, and
non-adiabatic coupling vectors (NACR) for a given molecule.

Multi-target nodes are recommended for efficiency and fewer recursive layers.

For energies, the node can be constructed just like its ground-state
counterpart::

    energy = targets.HEnergyNode("E", network, module_kwargs={"n_target": n_states + 1})
    mol_energy = energy.mol_energy
    mol_energy.db_name = "E"

Note that a multi-target node is used here, defined by the keyword
``module_kwargs={"n_target": n_states + 1}``. Here, ``n_states`` is the number
of excited states under consideration. The extra target is for the ground
state, which is often useful. The database name is simply ``E``, with a shape
of ``(n_molecules, n_states + 1)``.

Predicting the transition dipoles is similar to predicting the ground-state
permanent dipole::

    charge = targets.HChargeNode("Q", network, module_kwargs={"n_target": n_states})
    dipole = physics.DipoleNode("D", (charge, positions), db_name="D")

The database name is ``D``, with a shape of ``(n_molecules, n_states, 3)``.

For NACR, to avoid singularity problems, we train on the product NACR ⋅ ΔE
instead::

    nacr = physics.NACRMultiStateNode(
        "ScaledNACR",
        (charge, positions, energy),
        db_name="ScaledNACR",
        module_kwargs={"n_target": n_states},
    )

The NACR between states `i` and `j`, :math:`\boldsymbol{d}_{ij}`, enters the
training target in the following way:

.. math::
    \boldsymbol{d}_{ij}\Delta E_{ij} = \Delta E_{ij}\boldsymbol{q}_i \frac{\partial\boldsymbol{q}_j}{\partial\boldsymbol{R}}

:math:`\Delta E_{ij}` is the energy difference between states `i` and `j`,
which is calculated internally by the NACR node from the ``energy`` node
input. :math:`\boldsymbol{R}` corresponds to the ``positions`` node in the
code. :math:`\boldsymbol{q}_{i}` and :math:`\boldsymbol{q}_{j}` are the
transition atomic charges for states `i` and `j` contained in the ``charge``
node. This charge node can be constructed from scratch or reused from the
dipole predictions. The database name is ``ScaledNACR``, with a shape of
``(n_molecules, n_states*(n_states-1)/2, 3*n_atoms)``.
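
For concreteness, here is a small sketch of the array shapes the database is
expected to provide for the three targets. The sizes are hypothetical and
chosen only for illustration; they are not part of the example::

    import numpy as np

    # hypothetical sizes, for illustration only
    n_molecules, n_states, n_atoms = 100, 3, 10
    n_pairs = n_states * (n_states - 1) // 2  # unique state pairs with i < j -> 3

    E = np.zeros((n_molecules, n_states + 1))                   # ground + excited energies
    D = np.zeros((n_molecules, n_states, 3))                    # transition dipoles
    ScaledNACR = np.zeros((n_molecules, n_pairs, 3 * n_atoms))  # NACR * ΔE, flattened over atoms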

Due to the phase problem, the `phase-less` versions of MAE or RMSE should be
used when the loss function is constructed::

    energy_mae = loss.MAELoss.of_node(energy)
    dipole_mae = loss.MAEPhaseLoss.of_node(dipole)
    nacr_mae = loss.MAEPhaseLoss.of_node(nacr)

:func:`~hippynn.graphs.nodes.loss.MAEPhaseLoss` and
:func:`~hippynn.graphs.nodes.loss.MSEPhaseLoss` are the `phase-less` versions
of MAE and MSE, respectively, and otherwise behave exactly like the ordinary
versions.
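
Conceptually, a phase-less error treats a prediction and its negation as
equivalent, since the sign of a transition dipole or NACR depends on the
arbitrary phases of the underlying wavefunctions. A minimal sketch of the idea
behind an MAE-style phase-less metric (an illustration only, not hippynn's
actual implementation)::

    import torch

    def phaseless_mae(pred: torch.Tensor, true: torch.Tensor) -> torch.Tensor:
        """Compare each prediction against both +true and -true and keep the smaller error."""
        err_pos = (pred - true).abs().mean(dim=-1)  # error assuming the same phase
        err_neg = (pred + true).abs().mean(dim=-1)  # error assuming the opposite phase
        return torch.minimum(err_pos, err_neg).mean()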

For a complete script, please take a look at ``examples/excited_states.py``.
import json

import matplotlib
import numpy as np
import torch

import hippynn
from hippynn import plotting
from hippynn.experiment import setup_training, train_model
from hippynn.experiment.controllers import PatienceController, RaiseBatchSizeOnPlateau
from hippynn.graphs import inputs, loss, networks, physics, targets

matplotlib.use("Agg")
# default types for torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_default_dtype(torch.float32)

hippynn.settings.WARN_LOW_DISTANCES = False
hippynn.settings.TRANSPARENT_PLOT = True

n_atoms = 10
n_states = 3
plot_frequency = 100
dipole_weight = 4
nacr_weight = 2
l2_weight = 2e-5

# Hyperparameters for the network
network_params = {
    "possible_species": [0, 1, 6, 7],
    "n_features": 30,
    "n_sensitivities": 28,
    "dist_soft_min": 0.7665723566179274,
    "dist_soft_max": 3.4134447177301515,
    "dist_hard_max": 4.6860240434651805,
    "n_interaction_layers": 3,
    "n_atom_layers": 3,
}
# dump parameters to the log file
print("Network parameters\n\n", json.dumps(network_params, indent=4))

with hippynn.tools.log_terminal("training_log.txt", "wt"):
    # build network
    species = inputs.SpeciesNode(db_name="Z")
    positions = inputs.PositionsNode(db_name="R")
    network = networks.Hipnn(
        "hipnn_model", (species, positions), module_kwargs=network_params
    )
    # add energy
    energy = targets.HEnergyNode("E", network, module_kwargs={"n_target": n_states + 1})
    mol_energy = energy.mol_energy
    mol_energy.db_name = "E"
    # add dipole
    charge = targets.HChargeNode("Q", network, module_kwargs={"n_target": n_states})
    dipole = physics.DipoleNode("D", (charge, positions), db_name="D")
    # add NACR
    nacr = physics.NACRMultiStateNode(
        "ScaledNACR",
        (charge, positions, energy),
        db_name="ScaledNACR",
        module_kwargs={"n_target": n_states},
    )
    # set up plotter
    plotter = []
    for node in [mol_energy, dipole, nacr]:
        plotter.append(plotting.Hist2D.compare(node, saved=True, shown=False))
    for i in range(network_params["n_interaction_layers"]):
        plotter.append(
            plotting.SensitivityPlot(
                network.torch_module.sensitivity_layers[i],
                saved=f"Sensitivity_{i}.pdf",
                shown=False,
            )
        )
    plotter = plotting.PlotMaker(*plotter, plot_every=plot_frequency)
    # build the loss function
    validation_losses = {}
    # energy
    energy_rmse = loss.MSELoss.of_node(energy) ** 0.5
    validation_losses["E-RMSE"] = energy_rmse
    energy_mae = loss.MAELoss.of_node(energy)
    validation_losses["E-MAE"] = energy_mae
    energy_loss = energy_rmse + energy_mae
    validation_losses["E-Loss"] = energy_loss
    total_loss = energy_loss
    # dipole
    dipole_rmse = loss.MSEPhaseLoss.of_node(dipole) ** 0.5
    validation_losses["D-RMSE"] = dipole_rmse
    dipole_mae = loss.MAEPhaseLoss.of_node(dipole)
    validation_losses["D-MAE"] = dipole_mae
    dipole_loss = dipole_rmse / np.sqrt(3) + dipole_mae
    validation_losses["D-Loss"] = dipole_loss
    total_loss += dipole_weight * dipole_loss
    # nacr
    nacr_rmse = loss.MSEPhaseLoss.of_node(nacr) ** 0.5
    validation_losses["NACR-RMSE"] = nacr_rmse
    nacr_mae = loss.MAEPhaseLoss.of_node(nacr)
    validation_losses["NACR-MAE"] = nacr_mae
    nacr_loss = nacr_rmse / np.sqrt(3 * n_atoms) + nacr_mae
    validation_losses["NACR-Loss"] = nacr_loss
    total_loss += nacr_weight * nacr_loss
    # l2 regularization
    l2_reg = loss.l2reg(network)
    validation_losses["L2"] = l2_reg
    loss_regularization = l2_weight * l2_reg
    # add total loss to the dictionary
    validation_losses["Loss_wo_L2"] = total_loss
    validation_losses["Loss"] = total_loss + loss_regularization

    # set up experiment
    training_modules, db_info = hippynn.experiment.assemble_for_training(
        validation_losses["Loss"],
        validation_losses,
        plot_maker=plotter,
    )
    # set up the optimizer
    optimizer = torch.optim.AdamW(training_modules.model.parameters(), lr=1e-3)
    # use higher patience for production runs
    scheduler = RaiseBatchSizeOnPlateau(
        optimizer=optimizer, max_batch_size=2048, patience=10, factor=0.5
    )
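    # RaiseBatchSizeOnPlateau raises the batch size (up to max_batch_size) when the
    # monitored metric plateaus; once the maximum batch size is reached, the learning
    # rate is decayed instead (controlled by `factor`).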
    controller = PatienceController(
        optimizer=optimizer,
        scheduler=scheduler,
        batch_size=32,
        eval_batch_size=2048,
        # use higher max_epochs for production runs
        max_epochs=100,
        stopping_key="Loss",
        fraction_train_eval=0.1,
        # use higher termination_patience for production runs
        termination_patience=10,
    )
    experiment_params = hippynn.experiment.SetupParams(controller=controller)

    # load database
    database = hippynn.databases.DirectoryDatabase(
        name="azo_",  # Prefix for arrays in the directory
        directory="./database",
        seed=114514,  # Random seed for splitting data
        **db_info,  # Adds the inputs and targets db_names from the model as things to load
    )
    # for quick testing, use only 10% of the dataset (7% train, 2% valid, 1% test)
    database.make_random_split("train", 0.07)
    database.make_random_split("valid", 0.02)
    database.make_random_split("test", 0.01)
    database.splitting_completed = True
    # to split the whole dataset into train/valid/test with a 7:2:1 ratio instead, use
    # database.make_trainvalidtest_split(0.1, 0.2)

    # set up training
    training_modules, controller, metric_tracker = setup_training(
        training_modules=training_modules,
        setup_params=experiment_params,
    )
    # train model
    metric_tracker = train_model(
        training_modules,
        database,
        controller,
        metric_tracker,
        callbacks=None,
        batch_callbacks=None,
    )

    del network_params["possible_species"]
    network_params["metric"] = metric_tracker.best_metric_values
    network_params["avg_epoch_time"] = np.average(metric_tracker.epoch_times)
    network_params["Loss"] = metric_tracker.best_metric_values["valid"]["Loss"]

    with open("training_summary.json", "w") as out:
        json.dump(network_params, out, indent=4)