diff --git a/docs/source/examples/excited_states.rst b/docs/source/examples/excited_states.rst new file mode 100644 index 00000000..11974d1d --- /dev/null +++ b/docs/source/examples/excited_states.rst @@ -0,0 +1,67 @@ +Excited States Training +======================= + +hippynn is now able to predict excited-state energies, transition dipoles, and +the non-adiabatic coupling vectors (NACR) for a given molecule. + +Multi-target nodes are recommended due to efficiency and fewer recursive +layers. + +For energies, the node can be constructed just like the ground-state +counterpart:: + + energy = targets.HEnergyNode("E", network, module_kwargs={"n_target": n_states + 1}) + mol_energy = energy.mol_energy + mol_energy.db_name = "E" + +Note that a ``multi-target node`` is used here, defined by the keyword +``module_kwargs={"n_target": n_states + 1}``. Here, `n_states` is the number of +excited states under consideration. The extra target is the ground state, which is often +useful. The database name is simply `E` with a shape of ``(n_molecules, +n_states+1)``. + +Predicting the transition dipoles is also similar to the ground-state permanent +dipole:: + + charge = targets.HChargeNode("Q", network, module_kwargs={"n_target": n_states}) + dipole = physics.DipoleNode("D", (charge, positions), db_name="D") + +The database name is `D` with a shape of ``(n_molecules, n_states, 3)``. + +For NACR, to avoid singularity problems, we train on the scaled quantity NACR*ΔE +instead:: + + nacr = physics.NACRMultiStateNode( + "ScaledNACR", + (charge, positions, energy), + db_name="ScaledNACR", + module_kwargs={"n_target": n_states}, + ) + +The NACR between states `i` and `j`, :math:`\boldsymbol{d}_{ij}`, is expressed +in the following way: + +.. 
math:: + \boldsymbol{d}_{ij}\Delta E_{ij} = \Delta E_{ij}\boldsymbol{q}_i \frac{\partial\boldsymbol{q}_j}{\partial\boldsymbol{R}} + +:math:`\Delta E_{ij}` is the energy difference between states `i` and `j`, which is +calculated internally in the NACR node based on the input of the ``energy`` +node. :math:`\boldsymbol{R}` corresponds to the ``positions`` node in the code. +:math:`\boldsymbol{q}_{i}` and :math:`\boldsymbol{q}_{j}` are the transition +atomic charges for states `i` and `j` contained in the ``charge`` node. This +charge node can be constructed from scratch or reused from the dipole +predictions. The database name is `ScaledNACR` with a shape of ``(n_molecules, +n_states*(n_states-1)/2, 3*n_atoms)``. + +Because the overall sign (phase) of transition dipoles and NACR is arbitrary, the +`phase-less` versions of MAE or RMSE should be used when the loss function is constructed:: + + energy_mae = loss.MAELoss.of_node(energy) + dipole_mae = loss.MAEPhaseLoss.of_node(dipole) + nacr_mae = loss.MAEPhaseLoss.of_node(nacr) + +:class:`~hippynn.graphs.nodes.loss.MAEPhaseLoss` and +:class:`~hippynn.graphs.nodes.loss.MSEPhaseLoss` are the `phase-less` versions of MAE +and MSE, respectively, and are used exactly like the common versions. + +For a complete script, please take a look at ``examples/excited_states.py``. diff --git a/docs/source/examples/index.rst b/docs/source/examples/index.rst index f7b4ba50..548b884b 100644 --- a/docs/source/examples/index.rst +++ b/docs/source/examples/index.rst @@ -18,4 +18,5 @@ the examples are just snippets. 
For fully-fledged examples see the restarting ase_calculator mliap_unified + excited_states diff --git a/examples/excited_states.py b/examples/excited_states.py new file mode 100644 index 00000000..2423668c --- /dev/null +++ b/examples/excited_states.py @@ -0,0 +1,173 @@ +import json + +import matplotlib +import numpy as np +import torch + +import hippynn +from hippynn import plotting +from hippynn.experiment import setup_training, train_model +from hippynn.experiment.controllers import PatienceController, RaiseBatchSizeOnPlateau +from hippynn.graphs import inputs, loss, networks, physics, targets + +matplotlib.use("Agg") +# default types for torch +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True +torch.set_default_dtype(torch.float32) + +hippynn.settings.WARN_LOW_DISTANCES = False +hippynn.settings.TRANSPARENT_PLOT = True + +n_atoms = 10 +n_states = 3 +plot_frequency = 100 +dipole_weight = 4 +nacr_weight = 2 +l2_weight = 2e-5 + +# Hyperparameters for the network +network_params = { + "possible_species": [0, 1, 6, 7], + "n_features": 30, + "n_sensitivities": 28, + "dist_soft_min": 0.7665723566179274, + "dist_soft_max": 3.4134447177301515, + "dist_hard_max": 4.6860240434651805, + "n_interaction_layers": 3, + "n_atom_layers": 3, +} +# dump parameters to the log file +print("Network parameters\n\n", json.dumps(network_params, indent=4)) + +with hippynn.tools.log_terminal("training_log.txt", "wt"): + # build network + species = inputs.SpeciesNode(db_name="Z") + positions = inputs.PositionsNode(db_name="R") + network = networks.Hipnn( + "hipnn_model", (species, positions), module_kwargs=network_params + ) + # add energy + energy = targets.HEnergyNode("E", network, module_kwargs={"n_target": n_states + 1}) + mol_energy = energy.mol_energy + mol_energy.db_name = "E" + # add dipole + charge = targets.HChargeNode("Q", network, module_kwargs={"n_target": n_states}) + dipole = physics.DipoleNode("D", (charge, positions), db_name="D") + # 
add NACR + nacr = physics.NACRMultiStateNode( + "ScaledNACR", + (charge, positions, energy), + db_name="ScaledNACR", + module_kwargs={"n_target": n_states}, + ) + # set up plotter + plotter = [] + for node in [mol_energy, dipole, nacr]: + plotter.append(plotting.Hist2D.compare(node, saved=True, shown=False)) + for i in range(network_params["n_interaction_layers"]): + plotter.append( + plotting.SensitivityPlot( + network.torch_module.sensitivity_layers[i], + saved=f"Sensitivity_{i}.pdf", + shown=False, + ) + ) + plotter = plotting.PlotMaker(*plotter, plot_every=plot_frequency) + # build the loss function + validation_losses = {} + # energy + energy_rmse = loss.MSELoss.of_node(energy) ** 0.5 + validation_losses["E-RMSE"] = energy_rmse + energy_mae = loss.MAELoss.of_node(energy) + validation_losses["E-MAE"] = energy_mae + energy_loss = energy_rmse + energy_mae + validation_losses["E-Loss"] = energy_loss + total_loss = energy_loss + # dipole + dipole_rmse = loss.MSEPhaseLoss.of_node(dipole) ** 0.5 + validation_losses["D-RMSE"] = dipole_rmse + dipole_mae = loss.MAEPhaseLoss.of_node(dipole) + validation_losses["D-MAE"] = dipole_mae + dipole_loss = dipole_rmse / np.sqrt(3) + dipole_mae + validation_losses["D-Loss"] = dipole_loss + total_loss += dipole_weight * dipole_loss + # nacr + nacr_rmse = loss.MSEPhaseLoss.of_node(nacr) ** 0.5 + validation_losses["NACR-RMSE"] = nacr_rmse + nacr_mae = loss.MAEPhaseLoss.of_node(nacr) + validation_losses["NACR-MAE"] = nacr_mae + nacr_loss = nacr_rmse / np.sqrt(3 * n_atoms) + nacr_mae + validation_losses["NACR-Loss"] = nacr_loss + total_loss += nacr_weight * nacr_loss + # l2 regularization + l2_reg = loss.l2reg(network) + validation_losses["L2"] = l2_reg + loss_regularization = l2_weight * l2_reg + # add total loss to the dictionary + validation_losses["Loss_wo_L2"] = total_loss + validation_losses["Loss"] = total_loss + loss_regularization + + # set up experiment + training_modules, db_info = hippynn.experiment.assemble_for_training( + 
validation_losses["Loss"], + validation_losses, + plot_maker=plotter, + ) + # set up the optimizer + optimizer = torch.optim.AdamW(training_modules.model.parameters(), lr=1e-3) + # use higher patience for production runs + scheduler = RaiseBatchSizeOnPlateau( + optimizer=optimizer, max_batch_size=2048, patience=10, factor=0.5 + ) + controller = PatienceController( + optimizer=optimizer, + scheduler=scheduler, + batch_size=32, + eval_batch_size=2048, + # use higher max_epochs for production runs + max_epochs=100, + stopping_key="Loss", + fraction_train_eval=0.1, + # use higher termination_patience for production runs + termination_patience=10, + ) + experiment_params = hippynn.experiment.SetupParams(controller=controller) + + # load database + database = hippynn.databases.DirectoryDatabase( + name="azo_", # Prefix for arrays in the directory + directory="./database", + seed=114514, # Random seed for splitting data + **db_info, # Adds the inputs and targets db_names from the model as things to load + ) + # use 10% of the dataset just for quick testing purpose + database.make_random_split("train", 0.07) + database.make_random_split("valid", 0.02) + database.make_random_split("test", 0.01) + database.splitting_completed = True + # split the whole dataset into train, valid, test in the ratio of 7:2:1 + # database.make_trainvalidtest_split(0.1, 0.2) + + # set up training + training_modules, controller, metric_tracker = setup_training( + training_modules=training_modules, + setup_params=experiment_params, + ) + # train model + metric_tracker = train_model( + training_modules, + database, + controller, + metric_tracker, + callbacks=None, + batch_callbacks=None, + ) + +del network_params["possible_species"] +network_params["metric"] = metric_tracker.best_metric_values +network_params["avg_epoch_time"] = np.average(metric_tracker.epoch_times) +network_params["Loss"] = metric_tracker.best_metric_values["valid"]["Loss"] + +with open("training_summary.json", "w") as out: + 
json.dump(network_params, out, indent=4) diff --git a/hippynn/additional/__init__.py b/hippynn/additional/__init__.py deleted file mode 100644 index 085649fe..00000000 --- a/hippynn/additional/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .new_ops import * diff --git a/hippynn/additional/new_ops.py b/hippynn/additional/new_ops.py deleted file mode 100644 index 54cc94e9..00000000 --- a/hippynn/additional/new_ops.py +++ /dev/null @@ -1,224 +0,0 @@ -""" -Additional nodes and loss functions used for excited states training. -""" -from typing import List, Optional, Tuple - -import torch -from torch import Tensor -from torch.optim.optimizer import Optimizer - -from ..graphs import loss -from ..graphs.nodes.base import AutoKw, SingleNode -from ..graphs.indextypes import IdxType - - -class NACR(torch.nn.Module): - """ - Compute NAC vector * ΔE. Originally in hippynn.layers.physics. - """ - - def __init__(self): - super().__init__() - - def forward( - self, - charges1: Tensor, - charges2: Tensor, - positions: Tensor, - energy1: Tensor, - energy2: Tensor, - ): - dE = energy2 - energy1 - nacr = torch.autograd.grad( - charges2, [positions], grad_outputs=[charges1], create_graph=True - )[0].reshape(len(dE), -1) - return nacr * dE - - -class NACRMultiState(torch.nn.Module): - """ - Compute NAC vector * ΔE for all paris of states. Originally in hippynn.layers.physics. 
- """ - - def __init__(self, n_target=1): - self.n_target = n_target - super().__init__() - - def forward(self, charges: Tensor, positions: Tensor, energies: Tensor): - # charges shape: n_molecules, n_atoms, n_targets - # positions shape: n_molecules, n_atoms, 3 - # energies shape: n_molecules, n_targets - # dE shape: n_molecules, n_targets, n_targets - dE = energies.unsqueeze(1) - energies.unsqueeze(2) - # take the upper triangle excluding the diagonal - indices = torch.triu_indices( - self.n_target, self.n_target, offset=1, device=dE.device - ) - # dE shape: n_molecules, n_pairs - # n_pairs = n_targets * (n_targets - 1) / 2 - dE = dE[..., indices[0], indices[1]] - # compute q1 * dq2/dR - nacr_ij = [] - for i, j in zip(*indices): - nacr = torch.autograd.grad( - charges[..., j], - positions, - grad_outputs=charges[..., i], - create_graph=True, - )[0] - nacr_ij.append(nacr) - # nacr shape: n_molecules, n_atoms, 3, n_pairs - nacr = torch.stack(nacr_ij, dim=1) - n_molecule, n_pairs, n_atoms, n_dims = nacr.shape - nacr = nacr.reshape(n_molecule, n_pairs, n_atoms * n_dims) - # multiply dE - return nacr * dE.unsqueeze(2) - - -class NACRNode(AutoKw, SingleNode): - """ - Compute the non-adiabatic coupling vector multiplied by the energy difference - between two states. Originally in hippynn.graphs.nodes.physics. - """ - - _input_names = "charges i", "charges j", "coordinates", "energy i", "energy j" - # _auto_module_class = physics_layers.NACR - _auto_module_class = NACR - - def __init__( - self, name: str, parents: Tuple, module="auto", module_kwargs=None, **kwargs - ): - """Automatically build the node for calculating NACR * ΔE between two states i - and j. 
- - :param name: name of the node - :type name: str - :param parents: parents of the NACR node in the sequence of (charges i, \ - charges j, positions, energy i, energy j) - :type parents: Tuple - :param module: _description_, defaults to "auto" - :type module: str, optional - :param module_kwargs: keyword arguments passed to the corresponding layer, - defaults to None - :type module_kwargs: dict, optional - """ - - self.module_kwargs = {} - if module_kwargs is not None: - self.module_kwargs.update(module_kwargs) - charges1, charges2, positions, energy1, energy2 = parents - positions.requires_grad = True - self._index_state = IdxType.Molecules - # self._index_state = positions._index_state - parents = ( - charges1.main_output, - charges2.main_output, - positions, - energy1.main_output, - energy2.main_output, - ) - super().__init__(name, parents, module=module, **kwargs) - - -class NACRMultiStateNode(AutoKw, SingleNode): - """ - Compute the non-adiabatic coupling vector multiplied by the energy difference - between all pairs of states. Originally in hippynn.graphs.nodes.physics. - """ - - _input_names = "charges", "coordinates", "energies" - # _auto_module_class = physics_layers.NACR - _auto_module_class = NACRMultiState - - def __init__(self, name, parents, module="auto", module_kwargs=None, **kwargs): - """Automatically build the node for calculating NACR * ΔE between all pairs of - states. 
- - :param name: name of the node - :type name: str - :param parents: parents of the NACR node in the sequence of (charges, \ - positions, energies) - :type parents: Tuple - :param module: _description_, defaults to "auto" - :type module: str, optional - :param module_kwargs: keyword arguments passed to the corresponding layer, - defaults to None - :type module_kwargs: dict, optional - """ - - self.module_kwargs = {} - if module_kwargs is not None: - self.module_kwargs.update(module_kwargs) - charges, positions, energies = parents - positions.requires_grad = True - self._index_state = IdxType.Molecules - # self._index_state = positions._index_state - parents = ( - charges.main_output, - positions, - energies.main_output, - ) - super().__init__(name, parents, module=module, **kwargs) - - -# For loss functions with phases -def absolute_errors(predict: Tensor, true: Tensor): - """Compute the absolute errors with phases between predicted and true values. In - other words, prediction should be close to the absolute value of true, and the sign - does not matter. 
- - :param predict: predicted values - :type predict: torch.Tensor - :param true: true values - :type true: torch.Tensor - :return: errors - :rtype: torch.Tensor - """ - - return torch.minimum(torch.abs(true - predict), torch.abs(true + predict)) - - -def mae_with_phases(predict: Tensor, true: Tensor): - """MAE with phases - - :param predict: predicted values - :type predict: torch.Tensor - :param true: true values - :type true: torch.Tensor - :return: MAE with phases - :rtype: torch.Tensor - """ - - errors = torch.minimum( - torch.linalg.norm(true - predict, ord=1, dim=-1), - torch.linalg.norm(true + predict, ord=1, dim=-1), - ) - # errors = absolute_errors(predict, true) - return torch.sum(errors) / predict.numel() - - -def mse_with_phases(predict: Tensor, true: Tensor): - """MSE with phases - - :param predict: predicted values - :type predict: torch.Tensor - :param true: true values - :type true: torch.Tensor - :return: MSE with phases - :rtype: torch.Tensor - """ - - errors = torch.minimum( - torch.linalg.norm(true - predict, dim=-1), - torch.linalg.norm(true + predict, dim=-1), - ) - # errors = absolute_errors(predict, true) ** 2 - return torch.sum(errors**2) / predict.numel() - - -class MAEPhaseLoss(loss._BaseCompareLoss, op=mae_with_phases): - pass - - -class MSEPhaseLoss(loss._BaseCompareLoss, op=mse_with_phases): - pass diff --git a/hippynn/additional/test_nacr.py b/hippynn/additional/test_nacr.py deleted file mode 100644 index 03d0daa0..00000000 --- a/hippynn/additional/test_nacr.py +++ /dev/null @@ -1,40 +0,0 @@ -import unittest - -import numpy as np -import torch - - -class TestNACRLayers(unittest.TestCase): - - # random number of molecules, atoms, and states - n_mol, n_atoms, n_states = np.random.randint(3, 8, 3) - # random initial positions and charges - positions = torch.rand(n_mol, n_atoms, 3) - positions.requires_grad = True - energies = torch.rand(n_mol, n_states) - layer = torch.nn.Linear(3, n_states) - charges = layer(positions) - - def 
setUp(self): - from hippynn.additional import NACR, NACRMultiState - - self.NACR_layer = NACR() - self.NACR_multi_layer = NACRMultiState(self.n_states) - - def test_multi_targets(self): - indices = torch.triu_indices(self.n_states, self.n_states, offset=1, device=self.positions.device) - nacr_singles = torch.empty(self.n_mol, self.n_atoms, 3, len(indices[0])) - for i, (j, k) in enumerate(zip(*indices)): - nacr_singles[..., i] = self.NACR_layer( - self.charges[..., j], - self.charges[..., k], - self.positions, - self.energies[:, j].unsqueeze(1), - self.energies[:, k].unsqueeze(1), - ) - nacr_multi = self.NACR_multi_layer(self.charges, self.positions, self.energies) - self.assertTrue(torch.equal(nacr_singles, nacr_multi)) - - def _numpy_implementation(self): - # TODO: with a linear layer, it's possible to use its weights to implement an analytical gradient in numpy. - pass diff --git a/hippynn/graphs/nodes/loss.py b/hippynn/graphs/nodes/loss.py index 2320df1d..067571c7 100644 --- a/hippynn/graphs/nodes/loss.py +++ b/hippynn/graphs/nodes/loss.py @@ -134,3 +134,65 @@ def l2reg(network): def l1reg(network): return lpreg(network, p=1) + +# For loss functions with phases +def absolute_errors(predict: torch.Tensor, true: torch.Tensor): + """Compute the absolute errors with phases between predicted and true values. In + other words, prediction should be close to the absolute value of true, and the sign + does not matter. 
+ + :param predict: predicted values + :type predict: torch.Tensor + :param true: true values + :type true: torch.Tensor + :return: errors + :rtype: torch.Tensor + """ + + return torch.minimum(torch.abs(true - predict), torch.abs(true + predict)) + + +def mae_with_phases(predict: torch.Tensor, true: torch.Tensor): + """MAE with phases + + :param predict: predicted values + :type predict: torch.Tensor + :param true: true values + :type true: torch.Tensor + :return: MAE with phases + :rtype: torch.Tensor + """ + + errors = torch.minimum( + torch.linalg.norm(true - predict, ord=1, dim=-1), + torch.linalg.norm(true + predict, ord=1, dim=-1), + ) + # errors = absolute_errors(predict, true) + return torch.sum(errors) / predict.numel() + + +def mse_with_phases(predict: torch.Tensor, true: torch.Tensor): + """MSE with phases + + :param predict: predicted values + :type predict: torch.Tensor + :param true: true values + :type true: torch.Tensor + :return: MSE with phases + :rtype: torch.Tensor + """ + + errors = torch.minimum( + torch.linalg.norm(true - predict, dim=-1), + torch.linalg.norm(true + predict, dim=-1), + ) + # errors = absolute_errors(predict, true) ** 2 + return torch.sum(errors**2) / predict.numel() + + +class MAEPhaseLoss(_BaseCompareLoss, op=mae_with_phases): + pass + + +class MSEPhaseLoss(_BaseCompareLoss, op=mse_with_phases): + pass diff --git a/hippynn/graphs/nodes/physics.py b/hippynn/graphs/nodes/physics.py index 558937b5..e789186f 100644 --- a/hippynn/graphs/nodes/physics.py +++ b/hippynn/graphs/nodes/physics.py @@ -2,18 +2,26 @@ Nodes for physics transformations """ import warnings +from typing import List, Optional, Tuple -from .base import SingleNode, MultiNode, AutoNoKw, AutoKw, ExpandParents, find_unique_relative, _BaseNode +from ...layers import indexers as index_layers +from ...layers import pairs as pair_layers +from ...layers import physics as physics_layers +from ..indextypes import IdxType, elementwise_compare_reduce, index_type_coercion 
+from .base import ( + AutoKw, + AutoNoKw, + ExpandParents, + MultiNode, + SingleNode, + _BaseNode, + find_unique_relative, +) from .base.node_functions import NodeNotFound from .indexers import AtomIndexer, PaddingIndexer, acquire_encoding_padding -from .pairs import OpenPairIndexer -from .tags import Encoder, PairIndexer, Charges from .inputs import PositionsNode, SpeciesNode - -from ..indextypes import IdxType, index_type_coercion, elementwise_compare_reduce -from ...layers import indexers as index_layers -from ...layers import physics as physics_layers -from ...layers import pairs as pair_layers +from .pairs import OpenPairIndexer +from .tags import Charges, Encoder, PairIndexer class GradientNode(AutoKw, SingleNode): @@ -159,8 +167,7 @@ def _validate_pairfinder(pairfinder, cutoff_distance): if pairfinder.torch_module.hard_dist_cutoff is not None: raise ValueError( - "dist_hard_max is set to a finite value,\n" - "coulomb energy requires summing over the entire set of pairs" + "dist_hard_max is set to a finite value,\ncoulomb energy requires summing over the entire set of pairs" ) def __init__(self, name, parents, energy_conversion, module="auto"): @@ -256,7 +263,6 @@ def __init__(self, name, parents, module="auto", **kwargs): # TODO: This seems broken for parent expanders, check the signature of the layer. class BondToMolSummmer(ExpandParents, AutoNoKw, SingleNode): - _input_names = "pairfeatures", "mol_index", "n_molecules", "pair_first" _auto_module_class = pair_layers.MolPairSummer _index_state = IdxType.Molecules @@ -300,3 +306,85 @@ def expansion1(self, features, species, **kwargs): def __init__(self, name, parents, module="auto", **kwargs): parents = self.expand_parents(parents) super().__init__(name, parents, module=module, **kwargs) + + +class NACRNode(AutoKw, SingleNode): + """ + Compute the non-adiabatic coupling vector multiplied by the energy difference + between two states. 
+ """ + + _input_names = "charges i", "charges j", "coordinates", "energy i", "energy j" + _auto_module_class = physics_layers.NACR + + def __init__(self, name: str, parents: Tuple, module="auto", module_kwargs=None, **kwargs): + """Automatically build the node for calculating NACR * ΔE between two states i + and j. + + :param name: name of the node + :type name: str + :param parents: parents of the NACR node in the sequence of (charges i, \ + charges j, positions, energy i, energy j) + :type parents: Tuple + :param module: the torch module to attach to this node, or "auto" to let the base class construct ``_auto_module_class`` (presumably with ``module_kwargs``), defaults to "auto" + :type module: str, optional + :param module_kwargs: keyword arguments passed to the corresponding layer, + defaults to None + :type module_kwargs: dict, optional + """ + + self.module_kwargs = {} + if module_kwargs is not None: + self.module_kwargs.update(module_kwargs) + charges1, charges2, positions, energy1, energy2 = parents + positions.requires_grad = True + self._index_state = IdxType.Molecules + # self._index_state = positions._index_state + parents = ( + charges1.main_output, + charges2.main_output, + positions, + energy1.main_output, + energy2.main_output, + ) + super().__init__(name, parents, module=module, **kwargs) + + +class NACRMultiStateNode(AutoKw, SingleNode): + """ + Compute the non-adiabatic coupling vector multiplied by the energy difference + between all pairs of states. + """ + + _input_names = "charges", "coordinates", "energies" + _auto_module_class = physics_layers.NACRMultiState + + def __init__(self, name, parents, module="auto", module_kwargs=None, **kwargs): + """Automatically build the node for calculating NACR * ΔE between all pairs of + states. 
+ + :param name: name of the node + :type name: str + :param parents: parents of the NACR node in the sequence of (charges, \ + positions, energies) + :type parents: Tuple + :param module: the torch module to attach to this node, or "auto" to let the base class construct ``_auto_module_class`` (presumably with ``module_kwargs``), defaults to "auto" + :type module: str, optional + :param module_kwargs: keyword arguments passed to the corresponding layer, + defaults to None + :type module_kwargs: dict, optional + """ + + self.module_kwargs = {} + if module_kwargs is not None: + self.module_kwargs.update(module_kwargs) + charges, positions, energies = parents + positions.requires_grad = True + self._index_state = IdxType.Molecules + # self._index_state = positions._index_state + parents = ( + charges.main_output, + positions, + energies.main_output, + ) + super().__init__(name, parents, module=module, **kwargs) diff --git a/hippynn/layers/physics.py b/hippynn/layers/physics.py index 2a7a18b4..6816dea0 100644 --- a/hippynn/layers/physics.py +++ b/hippynn/layers/physics.py @@ -239,3 +239,66 @@ def forward(self, features, species): class VecMag(torch.nn.Module): def forward(self, vector_feature): return torch.norm(vector_feature, dim=1).unsqueeze(1) + + +class NACR(torch.nn.Module): + """ + Compute NAC vector * ΔE for one pair of states, with ΔE = energy2 - energy1. + """ + + def __init__(self): + super().__init__() + + def forward( + self, + charges1: Tensor, + charges2: Tensor, + positions: Tensor, + energy1: Tensor, + energy2: Tensor, + ): + dE = energy2 - energy1 + nacr = torch.autograd.grad( + charges2, [positions], grad_outputs=[charges1], create_graph=True + )[0].reshape(len(dE), -1) + return nacr * dE + + +class NACRMultiState(torch.nn.Module): + """ + Compute NAC vector * ΔE for all pairs of states. 
+ """ + + def __init__(self, n_target=1): + self.n_target = n_target + super().__init__() + + def forward(self, charges: Tensor, positions: Tensor, energies: Tensor): + # charges shape: n_molecules, n_atoms, n_targets + # positions shape: n_molecules, n_atoms, 3 + # energies shape: n_molecules, n_targets + # dE[m, i, j] = E_j - E_i; shape: n_molecules, n_targets, n_targets + dE = energies.unsqueeze(1) - energies.unsqueeze(2) + # take the upper triangle excluding the diagonal + indices = torch.triu_indices( + self.n_target, self.n_target, offset=1, device=dE.device + ) + # dE shape: n_molecules, n_pairs + # n_pairs = n_targets * (n_targets - 1) / 2 + dE = dE[..., indices[0], indices[1]] + # compute q_i * dq_j/dR for every pair (i, j) with i < j + nacr_ij = [] + for i, j in zip(*indices): + nacr = torch.autograd.grad( + charges[..., j], + positions, + grad_outputs=charges[..., i], + create_graph=True, + )[0] + nacr_ij.append(nacr) + # nacr shape: n_molecules, n_pairs, n_atoms, 3 + nacr = torch.stack(nacr_ij, dim=1) + n_molecule, n_pairs, n_atoms, n_dims = nacr.shape + nacr = nacr.reshape(n_molecule, n_pairs, n_atoms * n_dims) + # multiply by dE, broadcast over the flattened atom/xyz axis + return nacr * dE.unsqueeze(2)