From d599de33feec9331d6ec4c973b1ddd0c2d3135b2 Mon Sep 17 00:00:00 2001 From: knc6 Date: Sat, 26 Oct 2024 13:37:17 -0400 Subject: [PATCH 01/37] Add radius_graph_jarvis. --- alignn/config.py | 6 +- .../config_example_atomwise.json | 9 +-- alignn/graphs.py | 55 +++++++++++++++++++ alignn/lmdb_dataset.py | 1 + alignn/models/alignn.py | 6 +- alignn/models/alignn_atomwise.py | 3 + alignn/train.py | 2 + alignn/train_alignn.py | 34 +++++++----- 8 files changed, 93 insertions(+), 23 deletions(-) diff --git a/alignn/config.py b/alignn/config.py index d7c807a..36595f4 100644 --- a/alignn/config.py +++ b/alignn/config.py @@ -162,9 +162,9 @@ class TrainingConfig(BaseSettings): ] = "dft_3d" target: TARGET_ENUM = "exfoliation_energy" atom_features: Literal["basic", "atomic_number", "cfid", "cgcnn"] = "cgcnn" - neighbor_strategy: Literal["k-nearest", "voronoi", "radius_graph"] = ( - "k-nearest" - ) + neighbor_strategy: Literal[ + "k-nearest", "voronoi", "radius_graph", "radius_graph_jarvis" + ] = "k-nearest" id_tag: Literal["jid", "id", "_oqmd_entry_id"] = "jid" # logging configuration diff --git a/alignn/examples/sample_data_ff/config_example_atomwise.json b/alignn/examples/sample_data_ff/config_example_atomwise.json index 2b915dd..b9e4791 100644 --- a/alignn/examples/sample_data_ff/config_example_atomwise.json +++ b/alignn/examples/sample_data_ff/config_example_atomwise.json @@ -3,7 +3,7 @@ "dataset": "user_data", "target": "target", "atom_features": "cgcnn", - "neighbor_strategy": "k-nearest", + "neighbor_strategy": "radius_graph", "id_tag": "jid", "random_seed": 123, "classification_threshold": null, @@ -33,7 +33,7 @@ "standard_scalar_and_pca": false, "use_canonize": true, "num_workers": 0, - "cutoff": 8.0, + "cutoff": 4.0, "max_neighbors": 12, "keep_data_order": true, "distributed":false, @@ -43,8 +43,9 @@ "atom_input_features": 92, "calculate_gradient":true, "atomwise_output_features":0, - "alignn_layers":4, - "gcn_layers":4, + "alignn_layers":2, + "gcn_layers":2, + "hidden_features":128, "output_features": 1, "graphwise_weight":0.85, "gradwise_weight":0.05, diff --git a/alignn/graphs.py b/alignn/graphs.py index 53e772a..67a4c20 100644 --- a/alignn/graphs.py +++ b/alignn/graphs.py @@ -320,6 +320,52 @@ def radius_graph_old( ### +def radius_graph_jarvis( + atoms, cutoff=4, atom_features="cgcnn", line_graph=True +): + """Construct edge list for radius graph.""" + u, v, r, atom_feats = [], [], [], [] + elements = atoms.elements + + # Loop over each atom in the structure + for ii, i in enumerate(atoms.cart_coords): + # Get neighbors within the cutoff distance + neighs = atoms.lattice.get_points_in_sphere( + atoms.frac_coords, i, cutoff, distance_vector=True + ) + + # Filter out self-loops (where the neighbor is the same as the source atom) + valid_indices = neighs[2] != ii # Exclude self-loops + + # Store source (u), destination (v), and distances (r) only for valid neighbors + u.extend( + [ii] * np.sum(valid_indices) + ) # Add the source atom multiple times + v.extend(neighs[2][valid_indices]) # Add valid neighbors only + r.extend(neighs[-1][valid_indices]) # Add distances of valid neighbors + + # Store atom features for the current atom + feat = list( + get_node_attributes(elements[ii], atom_features=atom_features) + ) + atom_feats.append(feat) + + # Create DGL graph + g = dgl.graph((np.array(u), np.array(v))) + g.ndata["atom_features"] = torch.tensor(atom_feats, dtype=torch.float32) + g.edata["r"] = torch.tensor(r, dtype=torch.float32) + g.ndata["coords"] = torch.tensor(atoms.cart_coords, 
dtype=torch.float32) + g.ndata["V"] = torch.tensor( + [atoms.volume] * atoms.num_atoms, dtype=torch.float32 + ) + + # Optional: Create a line graph if requested + if line_graph: + lg = g.line_graph(shared=True) + lg.apply_edges(compute_bond_cosines) + return g, lg + + return g class Graph(object): @@ -372,6 +418,7 @@ def atom_dgl_multigraph( ): """Obtain a DGLGraph for Atoms object.""" # print('id',id) + # print('stratgery', neighbor_strategy) if neighbor_strategy == "k-nearest": edges = nearest_neighbor_edges( atoms=atoms, @@ -388,6 +435,14 @@ def atom_dgl_multigraph( u, v, r = radius_graph( atoms, cutoff=cutoff, cutoff_extra=cutoff_extra ) + elif neighbor_strategy == "radius_graph_jarvis": + g, lg = radius_graph_jarvis( + atoms, + cutoff=cutoff, + atom_features=atom_features, + line_graph=compute_line_graph, + ) + return g, lg else: raise ValueError("Not implemented yet", neighbor_strategy) # elif neighbor_strategy == "voronoi": diff --git a/alignn/lmdb_dataset.py b/alignn/lmdb_dataset.py index 7a41ab7..844430c 100644 --- a/alignn/lmdb_dataset.py +++ b/alignn/lmdb_dataset.py @@ -150,6 +150,7 @@ def get_torch_dataset( compute_line_graph=line_graph, use_canonize=use_canonize, cutoff_extra=cutoff_extra, + neighbor_strategy=neighbor_strategy, ) if line_graph: g, lg = g diff --git a/alignn/models/alignn.py b/alignn/models/alignn.py index aafd036..f3ed795 100644 --- a/alignn/models/alignn.py +++ b/alignn/models/alignn.py @@ -309,9 +309,9 @@ def forward( g = g.local_var() # initial node features: atom feature network... x = g.ndata.pop("atom_features") - # print('x1',x.shape) + # print("x1", x, x.shape) x = self.atom_embedding(x) - # print('x2',x.shape) + # print("x2", x, x.shape) # initial bond features bondlength = torch.norm(g.edata.pop("r"), dim=1) @@ -327,7 +327,7 @@ def forward( # norm-activation-pool-classify h = self.readout(g, x) - # print('h',h.shape) + # print("h", h, h.shape) # print('features',features.shape) if self.config.extra_features != 0: h_feat = self.readout_feat(g, features) diff --git a/alignn/models/alignn_atomwise.py b/alignn/models/alignn_atomwise.py index 1b15ec2..06e0fab 100644 --- a/alignn/models/alignn_atomwise.py +++ b/alignn/models/alignn_atomwise.py @@ -386,7 +386,10 @@ def forward( # initial node features: atom feature network... 
x = g.ndata.pop("atom_features") + # print('x1',x,x.shape) + x = self.atom_embedding(x) + # print('x2',x,x.shape) r = g.edata["r"] if self.config.calculate_gradient: r.requires_grad_(True) diff --git a/alignn/train.py b/alignn/train.py index a67d1db..d68693e 100644 --- a/alignn/train.py +++ b/alignn/train.py @@ -18,6 +18,7 @@ from alignn.data import get_train_val_loaders from alignn.config import TrainingConfig from alignn.models.alignn_atomwise import ALIGNNAtomWise +from alignn.models.alignn import ALIGNN from jarvis.db.jsonutils import dumpjson import json import pprint @@ -210,6 +211,7 @@ def train_dgl( config.model.classification = True _model = { "alignn_atomwise": ALIGNNAtomWise, + "alignn": ALIGNN, } if config.random_seed is not None: random.seed(config.random_seed) diff --git a/alignn/train_alignn.py b/alignn/train_alignn.py index 16ef6ab..7bb61ed 100644 --- a/alignn/train_alignn.py +++ b/alignn/train_alignn.py @@ -198,19 +198,27 @@ def train_for_folder( train_grad = False train_stress = False train_atom = False - if config.model.calculate_gradient and config.model.gradwise_weight != 0: - train_grad = True - else: - train_grad = False - if config.model.calculate_gradient and config.model.stresswise_weight != 0: - train_stress = True - else: - train_stress = False - if config.model.atomwise_weight != 0: - train_atom = True - else: - train_atom = False - + try: + if ( + config.model.calculate_gradient + and config.model.gradwise_weight != 0 + ): + train_grad = True + else: + train_grad = False + if ( + config.model.calculate_gradient + and config.model.stresswise_weight != 0 + ): + train_stress = True + else: + train_stress = False + if config.model.atomwise_weight != 0: + train_atom = True + else: + train_atom = False + except Exception as exp: + pass # if config.model.atomwise_weight == 0: # train_atom = False # if config.model.gradwise_weight == 0: From 77837a38b3ba1a0925e00e6fa76cdd00f73aa2c0 Mon Sep 17 00:00:00 2001 From: knc6 Date: Sat, 26 Oct 2024 21:23:26 -0400 Subject: [PATCH 02/37] Add ALIGNN_FF2, radius_graph_jarvis. 
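
This patch introduces the ALIGNNFF2 model and its ALIGNNFF2Config, threads
a configurable dtype through the data loaders and graph builders, and makes
radius_graph_jarvis grow its cutoff by cutoff_extra until every atom ends
up as a node in the graph. A minimal usage sketch (assumes a jarvis-tools
Atoms object; the rocksalt cell below is illustrative only):

    from jarvis.core.atoms import Atoms
    from alignn.graphs import radius_graph_jarvis

    atoms = Atoms(
        lattice_mat=[[5.6, 0, 0], [0, 5.6, 0], [0, 0, 5.6]],
        elements=["Na", "Cl"],
        coords=[[0, 0, 0], [0.5, 0.5, 0.5]],
        cartesian=False,
    )
    g, lg = radius_graph_jarvis(
        atoms, cutoff=4.0, cutoff_extra=0.5, line_graph=True
    )

If the initial 4 Angstrom sphere leaves an atom with no neighbors, the
builder retries at 4.5, 5.0, ... until g.num_nodes() matches the number
of atoms.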
--- alignn/config.py | 16 +- alignn/data.py | 4 + alignn/dataset.py | 6 + .../config_example_atomwise.json | 5 +- alignn/graphs.py | 145 ++-- alignn/lmdb_dataset.py | 2 + alignn/models/alignn_ff2.py | 623 ++++++++++++++++++ alignn/train.py | 18 +- alignn/train_alignn.py | 45 +- 9 files changed, 765 insertions(+), 99 deletions(-) create mode 100644 alignn/models/alignn_ff2.py diff --git a/alignn/config.py b/alignn/config.py index 36595f4..dc838f2 100644 --- a/alignn/config.py +++ b/alignn/config.py @@ -6,18 +6,10 @@ from typing import Literal from alignn.utils import BaseSettings from alignn.models.alignn import ALIGNNConfig +from alignn.models.alignn_ff2 import ALIGNNFF2Config from alignn.models.alignn_atomwise import ALIGNNAtomWiseConfig -# from alignn.models.modified_cgcnn import CGCNNConfig -# from alignn.models.icgcnn import ICGCNNConfig -# from alignn.models.gcn import SimpleGCNConfig -# from alignn.models.densegcn import DenseGCNConfig -# from pydantic import model_validator -# from alignn.models.dense_alignn import DenseALIGNNConfig -# from alignn.models.alignn_cgcnn import ACGCNNConfig -# from alignn.models.alignn_layernorm import ALIGNNConfig as ALIGNN_LN_Config - -# from typing import List +# import torch try: VERSION = ( @@ -167,9 +159,8 @@ class TrainingConfig(BaseSettings): ] = "k-nearest" id_tag: Literal["jid", "id", "_oqmd_entry_id"] = "jid" - # logging configuration - # training configuration + dtype: str = "float32" random_seed: Optional[int] = 123 classification_threshold: Optional[float] = None # target_range: Optional[List] = None @@ -219,6 +210,7 @@ class TrainingConfig(BaseSettings): # model configuration model: Union[ ALIGNNConfig, + ALIGNNFF2Config, ALIGNNAtomWiseConfig, # CGCNNConfig, # ICGCNNConfig, diff --git a/alignn/data.py b/alignn/data.py index c52d0c7..a30cba7 100644 --- a/alignn/data.py +++ b/alignn/data.py @@ -153,6 +153,7 @@ def get_train_val_loaders( world_size=0, rank=0, use_lmdb: bool = True, + dtype="float32", ): """Help function to set up JARVIS train and val dataloaders.""" if use_lmdb: @@ -383,6 +384,7 @@ def get_train_val_loaders( output_dir=output_dir, sampler=train_sampler, tmp_name=tmp_name, + dtype=dtype, # tmp_name="train_data", ) tmp_name = filename + "val_data" @@ -406,6 +408,7 @@ def get_train_val_loaders( classification=classification_threshold is not None, output_dir=output_dir, tmp_name=tmp_name, + dtype=dtype, # tmp_name="val_data", ) if len(dataset_val) > 0 @@ -431,6 +434,7 @@ def get_train_val_loaders( classification=classification_threshold is not None, output_dir=output_dir, tmp_name=tmp_name, + dtype=dtype, # tmp_name="test_data", ) if len(dataset_test) > 0 diff --git a/alignn/dataset.py b/alignn/dataset.py index 6baec25..8ad2004 100644 --- a/alignn/dataset.py +++ b/alignn/dataset.py @@ -26,6 +26,7 @@ def load_graphs( id_tag="jid", # extra_feats_json=None, map_size=1e12, + dtype="float32", ): """Construct crystal graphs. 
@@ -54,6 +55,7 @@ def atoms_to_graph(atoms): compute_line_graph=False, use_canonize=use_canonize, neighbor_strategy=neighbor_strategy, + dtype=dtype, ) if cachedir is not None: @@ -84,6 +86,7 @@ def atoms_to_graph(atoms): use_canonize=use_canonize, neighbor_strategy=neighbor_strategy, id=i[id_tag], + dtype=dtype, ) # print ('ii',ii) if "extra_features" in i: @@ -124,6 +127,7 @@ def get_torch_dataset( output_dir=".", tmp_name="dataset", sampler=None, + dtype="float32", ): """Get Torch Dataset.""" df = pd.DataFrame(dataset) @@ -147,6 +151,7 @@ def get_torch_dataset( cutoff_extra=cutoff_extra, max_neighbors=max_neighbors, id_tag=id_tag, + dtype=dtype, ) data = StructureDataset( df, @@ -160,5 +165,6 @@ def get_torch_dataset( id_tag=id_tag, classification=classification, sampler=sampler, + dtype=dtype, ) return data diff --git a/alignn/examples/sample_data_ff/config_example_atomwise.json b/alignn/examples/sample_data_ff/config_example_atomwise.json index b9e4791..f1d13ad 100644 --- a/alignn/examples/sample_data_ff/config_example_atomwise.json +++ b/alignn/examples/sample_data_ff/config_example_atomwise.json @@ -3,8 +3,9 @@ "dataset": "user_data", "target": "target", "atom_features": "cgcnn", - "neighbor_strategy": "radius_graph", + "neighbor_strategy": "radius_graph_jarvis", "id_tag": "jid", + "dtype": "float32", "random_seed": 123, "classification_threshold": null, "n_val": null, @@ -39,7 +40,7 @@ "distributed":false, "use_lmdb": true, "model": { - "name": "alignn_atomwise", + "name": "alignn_ff2", "atom_input_features": 92, "calculate_gradient":true, "atomwise_output_features":0, diff --git a/alignn/graphs.py b/alignn/graphs.py index 67a4c20..94072c6 100644 --- a/alignn/graphs.py +++ b/alignn/graphs.py @@ -9,20 +9,100 @@ from jarvis.analysis.structure.neighbors import NeighborsAnalysis from jarvis.core.specie import chem_data, get_node_attributes import math - -# from jarvis.core.atoms import Atoms from collections import defaultdict from typing import List, Tuple, Sequence, Optional from dgl.data import DGLDataset - import torch import dgl +from tqdm import tqdm + + +def temp_graph(atoms=None, cutoff=4.0, atom_features="cgcnn", dtype="float32"): + """Helper function to construct a graph for a given cutoff.""" + TORCH_DTYPES = { + "float16": torch.float16, + "float32": torch.float32, + "float64": torch.float64, + "bfloat": torch.bfloat16, + } + dtype = TORCH_DTYPES[dtype] + u, v, r, d, images, atom_feats = [], [], [], [], [], [] + elements = atoms.elements + + # Loop over each atom in the structure + for ii, i in enumerate(atoms.cart_coords): + # Get neighbors within the cutoff distance + neighs = atoms.lattice.get_points_in_sphere( + atoms.frac_coords, i, cutoff, distance_vector=True + ) + + # Filter out self-loops (exclude cases where atom is bonded to itself) + valid_indices = neighs[2] != ii + + u.extend([ii] * np.sum(valid_indices)) + d.extend(neighs[1][valid_indices]) + v.extend(neighs[2][valid_indices]) + images.extend(neighs[3][valid_indices]) + r.extend(neighs[4][valid_indices]) + + feat = list( + get_node_attributes(elements[ii], atom_features=atom_features) + ) + atom_feats.append(feat) + + # Create DGL graph + g = dgl.graph((np.array(u), np.array(v))) + + # Add data to the graph with the specified dtype + g.ndata["atom_features"] = torch.tensor(atom_feats, dtype=dtype) + g.edata["r"] = torch.tensor(r, dtype=dtype) + g.edata["d"] = torch.tensor(d, dtype=dtype) + g.edata["images"] = torch.tensor(images, dtype=dtype) + g.ndata["coords"] = torch.tensor(atoms.cart_coords, dtype=dtype) 
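+    # one copy of the cell volume per node: the ALIGNNFF2 stress branch
+    # later normalizes the virial by g.ndata["V"]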
+ g.ndata["V"] = torch.tensor([atoms.volume] * atoms.num_atoms, dtype=dtype) + + return g, u, v, r + + +def radius_graph_jarvis( + atoms, + cutoff_extra=0.5, + cutoff=4.0, + atom_features="cgcnn", + line_graph=True, + dtype="float32", +): + """Construct radius graph with dynamic cutoff.""" + + while True: + # try: + # Attempt to create the graph + g, u, v, r = temp_graph( + atoms=atoms, + cutoff=cutoff, + atom_features=atom_features, + dtype=dtype, + ) + # Check if all atoms are included as nodes + if g.num_nodes() == len(atoms.elements): + # print(f"Graph constructed with cutoff: {cutoff}") + break # Exit the loop when successful + # Increment the cutoff if the graph is incomplete + cutoff += cutoff_extra + # print(f"Increasing cutoff to: {cutoff}") + + # except Exception as exp: + # # Handle exceptions and try again + # print(f"Graph construction failed: {exp}") + # cutoff += cutoff_extra # Try with a larger cutoff + + # Optional: Create a line graph if requested + if line_graph: + lg = g.line_graph(shared=True) + lg.apply_edges(compute_bond_cosines) + return g, lg -try: - from tqdm import tqdm -except Exception as exp: - print("tqdm is not installed.", exp) - pass + return g def canonize_edge( @@ -320,52 +400,6 @@ def radius_graph_old( ### -def radius_graph_jarvis( - atoms, cutoff=4, atom_features="cgcnn", line_graph=True -): - """Construct edge list for radius graph.""" - u, v, r, atom_feats = [], [], [], [] - elements = atoms.elements - - # Loop over each atom in the structure - for ii, i in enumerate(atoms.cart_coords): - # Get neighbors within the cutoff distance - neighs = atoms.lattice.get_points_in_sphere( - atoms.frac_coords, i, cutoff, distance_vector=True - ) - - # Filter out self-loops (where the neighbor is the same as the source atom) - valid_indices = neighs[2] != ii # Exclude self-loops - - # Store source (u), destination (v), and distances (r) only for valid neighbors - u.extend( - [ii] * np.sum(valid_indices) - ) # Add the source atom multiple times - v.extend(neighs[2][valid_indices]) # Add valid neighbors only - r.extend(neighs[-1][valid_indices]) # Add distances of valid neighbors - - # Store atom features for the current atom - feat = list( - get_node_attributes(elements[ii], atom_features=atom_features) - ) - atom_feats.append(feat) - - # Create DGL graph - g = dgl.graph((np.array(u), np.array(v))) - g.ndata["atom_features"] = torch.tensor(atom_feats, dtype=torch.float32) - g.edata["r"] = torch.tensor(r, dtype=torch.float32) - g.ndata["coords"] = torch.tensor(atoms.cart_coords, dtype=torch.float32) - g.ndata["V"] = torch.tensor( - [atoms.volume] * atoms.num_atoms, dtype=torch.float32 - ) - - # Optional: Create a line graph if requested - if line_graph: - lg = g.line_graph(shared=True) - lg.apply_edges(compute_bond_cosines) - return g, lg - - return g class Graph(object): @@ -415,6 +449,7 @@ def atom_dgl_multigraph( # use_canonize: bool = False, use_lattice_prop: bool = False, cutoff_extra=3.5, + dtype=torch.float32, ): """Obtain a DGLGraph for Atoms object.""" # print('id',id) @@ -441,6 +476,7 @@ def atom_dgl_multigraph( cutoff=cutoff, atom_features=atom_features, line_graph=compute_line_graph, + dtype=dtype, ) return g, lg else: @@ -784,6 +820,7 @@ def __init__( classification=False, id_tag="jid", sampler=None, + dtype="float32", ): """Pytorch Dataset for atomistic graphs. 
diff --git a/alignn/lmdb_dataset.py b/alignn/lmdb_dataset.py index 844430c..7225788 100644 --- a/alignn/lmdb_dataset.py +++ b/alignn/lmdb_dataset.py @@ -116,6 +116,7 @@ def get_torch_dataset( tmp_name="dataset", map_size=1e12, read_existing=True, + dtype="float32", ): """Get Torch Dataset with LMDB.""" vals = np.array([ii[target] for ii in dataset]) # df[target].values @@ -151,6 +152,7 @@ def get_torch_dataset( use_canonize=use_canonize, cutoff_extra=cutoff_extra, neighbor_strategy=neighbor_strategy, + dtype=dtype, ) if line_graph: g, lg = g diff --git a/alignn/models/alignn_ff2.py b/alignn/models/alignn_ff2.py new file mode 100644 index 0000000..4cbf6b0 --- /dev/null +++ b/alignn/models/alignn_ff2.py @@ -0,0 +1,623 @@ +"""Atomistic LIne Graph Neural Network. + +A prototype crystal line graph network dgl implementation. +""" + +from typing import Tuple, Union +from torch.autograd import grad +import dgl +import dgl.function as fn +import numpy as np +from dgl.nn import AvgPooling +import torch + +# from dgl.nn.functional import edge_softmax +from typing import Literal +from torch import nn +from torch.nn import functional as F +from alignn.models.utils import RBFExpansion +from alignn.graphs import compute_bond_cosines +from alignn.utils import BaseSettings + + +def _ensure_3body_line_graph_compatibility( + graph: dgl.DGLGraph, line_graph: dgl.DGLGraph, threebody_cutoff: float +): + """Ensure that 3body line graph is compatible with a given graph. + + Args: + graph: atomistic graph + line_graph: line graph of atomistic graph + threebody_cutoff: cutoff for three-body interactions + """ + valid_three_body = graph.edata["bond_dist"] <= threebody_cutoff + if ( + line_graph.num_nodes() + == graph.edata["bond_vec"][valid_three_body].shape[0] + ): + line_graph.ndata["bond_vec"] = graph.edata["bond_vec"][ + valid_three_body + ] + line_graph.ndata["bond_dist"] = graph.edata["bond_dist"][ + valid_three_body + ] + line_graph.ndata["pbc_offset"] = graph.edata["pbc_offset"][ + valid_three_body + ] + else: + three_body_id = torch.concatenate(line_graph.edges()) + max_three_body_id = ( + torch.max(three_body_id) + 1 if three_body_id.numel() > 0 else 0 + ) + line_graph.ndata["bond_vec"] = graph.edata["bond_vec"][ + :max_three_body_id + ] + line_graph.ndata["bond_dist"] = graph.edata["bond_dist"][ + :max_three_body_id + ] + line_graph.ndata["pbc_offset"] = graph.edata["pbc_offset"][ + :max_three_body_id + ] + + return line_graph + + +class ALIGNNFF2Config(BaseSettings): + """Hyperparameter schema for jarvisdgl.models.alignn.""" + + name: Literal["alignn_ff2"] + alignn_layers: int = 4 + gcn_layers: int = 4 + atom_input_features: int = 92 + edge_input_features: int = 80 + triplet_input_features: int = 40 + embedding_features: int = 64 + hidden_features: int = 256 + # fc_layers: int = 1 + # fc_features: int = 64 + output_features: int = 1 + grad_multiplier: int = -1 + calculate_gradient: bool = True + atomwise_output_features: int = 0 + graphwise_weight: float = 1.0 + gradwise_weight: float = 0.0 + stresswise_weight: float = 0.0 + atomwise_weight: float = 0.0 + # if link == log, apply `exp` to final outputs + # to constrain predictions to be positive + link: Literal["identity", "log", "logit"] = "identity" + zero_inflated: bool = False + classification: bool = False + force_mult_natoms: bool = False + energy_mult_natoms: bool = False + include_pos_deriv: bool = False + use_cutoff_function: bool = False + inner_cutoff: float = 6 # Ansgtrom + stress_multiplier: float = 1 + add_reverse_forces: bool = False # 
will make True as default soon + lg_on_fly: bool = False # will make True as default soon + batch_stress: bool = True + multiply_cutoff: bool = False + extra_features: int = 0 + exponent: int = 3 + + class Config: + """Configure model settings behavior.""" + + env_prefix = "jv_model" + + +def cutoff_function_based_edges_old(r, inner_cutoff=4): + """Apply smooth cutoff to pairwise interactions + + r: bond lengths + inner_cutoff: cutoff radius + + inside cutoff radius, apply smooth cutoff envelope + outside cutoff radius: hard zeros + """ + ratio = r / inner_cutoff + return torch.where( + ratio <= 1, + 1 - 6 * ratio**5 + 15 * ratio**4 - 10 * ratio**3, + torch.zeros_like(r), + ) + + +def cutoff_function_based_edges(r, inner_cutoff=4, exponent=3): + """Apply smooth cutoff to pairwise interactions + + r: bond lengths + inner_cutoff: cutoff radius + + inside cutoff radius, apply smooth cutoff envelope + outside cutoff radius: hard zeros + """ + ratio = r / inner_cutoff + c1 = -(exponent + 1) * (exponent + 2) / 2 + c2 = exponent * (exponent + 2) + c3 = -exponent * (exponent + 1) / 2 + envelope = ( + 1 + + c1 * ratio**exponent + + c2 * ratio ** (exponent + 1) + + c3 * ratio ** (exponent + 2) + ) + # r_cut = inner_cutoff + # r_on = inner_cutoff+1 + + # r_sq = r * r + # r_on_sq = r_on * r_on + # r_cut_sq = r_cut * r_cut + # envelope = (r_cut_sq - r_sq) + # ** 2 * (r_cut_sq + 2 * r_sq - 3 * r_on_sq)/ (r_cut_sq - r_on_sq) ** 3 + return torch.where(r <= inner_cutoff, envelope, torch.zeros_like(r)) + + +class EdgeGatedGraphConv(nn.Module): + """Edge gated graph convolution from arxiv:1711.07553. + + see also arxiv:2003.0098. + + This is similar to CGCNN, but edge features only go into + the soft attention / edge gating function, and the primary + node update function is W cat(u, v) + b + """ + + def __init__( + self, input_features: int, output_features: int, residual: bool = True + ): + """Initialize parameters for ALIGNN update.""" + super().__init__() + self.residual = residual + # CGCNN-Conv operates on augmented edge features + # z_ij = cat(v_i, v_j, u_ij) + # m_ij = σ(z_ij W_f + b_f) ⊙ g_s(z_ij W_s + b_s) + # coalesce parameters for W_f and W_s + # but -- split them up along feature dimension + self.src_gate = nn.Linear(input_features, output_features) + self.dst_gate = nn.Linear(input_features, output_features) + self.edge_gate = nn.Linear(input_features, output_features) + self.bn_edges = nn.LayerNorm(output_features) + + self.src_update = nn.Linear(input_features, output_features) + self.dst_update = nn.Linear(input_features, output_features) + self.bn_nodes = nn.LayerNorm(output_features) + + def forward( + self, + g: dgl.DGLGraph, + node_feats: torch.Tensor, + edge_feats: torch.Tensor, + ) -> torch.Tensor: + """Edge-gated graph convolution. 
+
+        h_i^l+1 = ReLU(U h_i + sum_{j->i} eta_{ij} ⊙ V h_j)
+        """
+        g = g.local_var()
+
+        # instead of concatenating (u || v || e) and applying one weight matrix
+        # split the weight matrix into three, apply, then sum
+        # see https://docs.dgl.ai/guide/message-efficient.html
+        # but split them on feature dimensions to update u, v, e separately
+        # m = BatchNorm(Linear(cat(u, v, e)))
+
+        # compute edge updates, equivalent to:
+        # Softplus(Linear(u || v || e))
+        g.ndata["e_src"] = self.src_gate(node_feats)
+        g.ndata["e_dst"] = self.dst_gate(node_feats)
+        g.apply_edges(fn.u_add_v("e_src", "e_dst", "e_nodes"))
+        m = g.edata.pop("e_nodes") + self.edge_gate(edge_feats)
+
+        g.edata["sigma"] = torch.sigmoid(m)
+        g.ndata["Bh"] = self.dst_update(node_feats)
+        g.update_all(
+            fn.u_mul_e("Bh", "sigma", "m"), fn.sum("m", "sum_sigma_h")
+        )
+        g.update_all(fn.copy_e("sigma", "m"), fn.sum("m", "sum_sigma"))
+        g.ndata["h"] = g.ndata["sum_sigma_h"] / (g.ndata["sum_sigma"] + 1e-6)
+        x = self.src_update(node_feats) + g.ndata.pop("h")
+
+        # softmax version seems to perform slightly worse
+        # than the sigmoid-gated version
+        # compute node updates
+        # Linear(u) + edge_gates ⊙ Linear(v)
+        # g.edata["gate"] = edge_softmax(g, y)
+        # g.ndata["h_dst"] = self.dst_update(node_feats)
+        # g.update_all(fn.u_mul_e("h_dst", "gate", "m"), fn.sum("m", "h"))
+        # x = self.src_update(node_feats) + g.ndata.pop("h")
+
+        # node and edge updates
+        x = F.silu(self.bn_nodes(x))
+        y = F.silu(self.bn_edges(m))
+
+        if self.residual:
+            x = node_feats + x
+            y = edge_feats + y
+
+        return x, y
+
+
+class ALIGNNConv(nn.Module):
+    """Line graph update."""
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+    ):
+        """Set up ALIGNN parameters."""
+        super().__init__()
+        self.node_update = EdgeGatedGraphConv(in_features, out_features)
+        self.edge_update = EdgeGatedGraphConv(out_features, out_features)
+
+    def forward(
+        self,
+        g: dgl.DGLGraph,
+        lg: dgl.DGLGraph,
+        x: torch.Tensor,
+        y: torch.Tensor,
+        z: torch.Tensor,
+    ):
+        """Node and Edge updates for ALIGNN layer.
+
+        x: node input features
+        y: edge input features
+        z: edge pair input features
+        """
+        g = g.local_var()
+        lg = lg.local_var()
+        # Edge-gated graph convolution update on crystal graph
+        x, m = self.node_update(g, x, y)
+
+        # Edge-gated graph convolution update on line graph
+        y, z = self.edge_update(lg, m, z)
+
+        return x, y, z
+
+
+class MLPLayer(nn.Module):
+    """Multilayer perceptron layer helper."""
+
+    def __init__(self, in_features: int, out_features: int):
+        """Linear, LayerNorm, SiLU layer."""
+        super().__init__()
+        self.layer = nn.Sequential(
+            nn.Linear(in_features, out_features),
+            nn.LayerNorm(out_features),
+            nn.SiLU(),
+        )
+
+    def forward(self, x):
+        """Linear, LayerNorm, SiLU layer."""
+        return self.layer(x)
+
+
+class ALIGNNFF2(nn.Module):
+    """Atomistic Line graph network.
+
+    Chain alternating gated graph convolution updates on crystal graph
+    and atomistic line graph.
+ """ + + def __init__( + self, + config: ALIGNNFF2Config = ALIGNNFF2Config(name="alignn_ff2"), + ): + """Initialize class with number of input features, conv layers.""" + super().__init__() + # print(config) + self.classification = config.classification + self.config = config + if self.config.gradwise_weight == 0: + self.config.calculate_gradient = False + # if self.config.atomwise_weight == 0: + # self.config.atomwise_output_features = None + self.atom_embedding = MLPLayer( + config.atom_input_features, config.hidden_features + ) + + self.edge_embedding = nn.Sequential( + RBFExpansion( + vmin=0, + vmax=8.0, + bins=config.edge_input_features, + ), + MLPLayer(config.edge_input_features, config.embedding_features), + MLPLayer(config.embedding_features, config.hidden_features), + ) + self.angle_embedding = nn.Sequential( + RBFExpansion( + vmin=-1, + vmax=1.0, + bins=config.triplet_input_features, + ), + MLPLayer(config.triplet_input_features, config.embedding_features), + MLPLayer(config.embedding_features, config.hidden_features), + ) + + self.alignn_layers = nn.ModuleList( + [ + ALIGNNConv( + config.hidden_features, + config.hidden_features, + ) + for idx in range(config.alignn_layers) + ] + ) + self.gcn_layers = nn.ModuleList( + [ + EdgeGatedGraphConv( + config.hidden_features, config.hidden_features + ) + for idx in range(config.gcn_layers) + ] + ) + + self.readout = AvgPooling() + + if config.extra_features != 0: + self.readout_feat = AvgPooling() + # Credit for extra_features work: + # Gong et al., https://doi.org/10.48550/arXiv.2208.05039 + self.extra_feature_embedding = MLPLayer( + config.extra_features, config.extra_features + ) + # print('config.output_features',config.output_features) + self.fc3 = nn.Linear( + config.hidden_features + config.extra_features, + config.output_features, + ) + self.fc1 = MLPLayer( + config.extra_features + config.hidden_features, + config.extra_features + config.hidden_features, + ) + self.fc2 = MLPLayer( + config.extra_features + config.hidden_features, + config.extra_features + config.hidden_features, + ) + + if config.atomwise_output_features > 0: + # if config.atomwise_output_features is not None: + self.fc_atomwise = nn.Linear( + config.hidden_features, config.atomwise_output_features + ) + + if self.classification: + self.fc = nn.Linear(config.hidden_features, 1) + self.softmax = nn.Sigmoid() + # self.softmax = nn.LogSoftmax(dim=1) + else: + self.fc = nn.Linear(config.hidden_features, config.output_features) + self.link = None + self.link_name = config.link + if config.link == "identity": + self.link = lambda x: x + elif config.link == "log": + self.link = torch.exp + avg_gap = 0.7 # magic number -- average bandgap in dft_3d + self.fc.bias.data = torch.tensor( + np.log(avg_gap), dtype=torch.float + ) + elif config.link == "logit": + self.link = torch.sigmoid + + def forward( + self, g: Union[Tuple[dgl.DGLGraph, dgl.DGLGraph], dgl.DGLGraph] + ): + """ALIGNN : start with `atom_features`. + + x: atom features (g.ndata) + y: bond features (g.edata and lg.ndata) + z: angle features (lg.edata) + """ + if len(self.alignn_layers) > 0: + g, lg = g + lg = lg.local_var() + + # angle features (fixed) + z = self.angle_embedding(lg.edata.pop("h")) + if self.config.extra_features != 0: + features = g.ndata["extra_features"] + # print('features',features,features.shape) + features = self.extra_feature_embedding(features) + g = g.local_var() + result = {} + + # initial node features: atom feature network... 
+ x = g.ndata.pop("atom_features") + # print('x1',x,x.shape) + + x = self.atom_embedding(x) + # print('x2',x,x.shape) + r = g.edata["r"] + if self.config.calculate_gradient: + r.requires_grad_(True) + bondlength = torch.norm(r, dim=1) + # mask = bondlength >= self.config.inner_cutoff + # bondlength[mask]=float(1.1) + if self.config.lg_on_fly and len(self.alignn_layers) > 0: + # re-compute bond angle cosines here to ensure + # the three-body interactions are fully included + # in the autograd graph. don't rely on dataloader/caching. + lg.ndata["r"] = r # overwrites precomputed r values + lg.apply_edges(compute_bond_cosines) # overwrites precomputed h + z = self.angle_embedding(lg.edata.pop("h")) + + # r = g.edata["r"].clone().detach().requires_grad_(True) + if self.config.use_cutoff_function: + # bondlength = cutoff_function_based_edges( + if self.config.multiply_cutoff: + c_off = cutoff_function_based_edges( + bondlength, + inner_cutoff=self.config.inner_cutoff, + exponent=self.config.exponent, + ).unsqueeze(dim=1) + + y = self.edge_embedding(bondlength) * c_off + else: + bondlength = cutoff_function_based_edges( + bondlength, + inner_cutoff=self.config.inner_cutoff, + exponent=self.config.exponent, + ) + y = self.edge_embedding(bondlength) + else: + y = self.edge_embedding(bondlength) + # y = self.edge_embedding(bondlength) + # ALIGNN updates: update node, edge, triplet features + for alignn_layer in self.alignn_layers: + x, y, z = alignn_layer(g, lg, x, y, z) + + # gated GCN updates: update node, edge features + for gcn_layer in self.gcn_layers: + x, y = gcn_layer(g, x, y) + # norm-activation-pool-classify + out = torch.empty(1) + if self.config.output_features is not None: + h = self.readout(g, x) + out = self.fc(h) + if self.config.extra_features != 0: + h_feat = self.readout_feat(g, features) + # print('h_feat',h_feat) + h = torch.cat((h, h_feat), 1) + h = self.fc1(h) + h = self.fc2(h) + out = self.fc3(h) + # print('out',out) + else: + out = torch.squeeze(out) + atomwise_pred = torch.empty(1) + if ( + self.config.atomwise_output_features > 0 + # self.config.atomwise_output_features is not None + and self.config.atomwise_weight != 0 + ): + atomwise_pred = self.fc_atomwise(x) + # atomwise_pred = torch.squeeze(self.readout(g, atomwise_pred)) + forces = torch.empty(1) + # gradient = torch.empty(1) + stress = torch.empty(1) + + if self.config.calculate_gradient: + if self.config.include_pos_deriv: + # Not tested yet + g.ndata["coords"].requires_grad_(True) + dx = [g.ndata["coords"], r] + else: + dx = r + + if self.config.energy_mult_natoms: + en_out = out * g.num_nodes() + else: + en_out = out + + # force calculation based on bond displacement vectors + # autograd gives dE / d{r_{i->j}} + pair_forces = ( + self.config.grad_multiplier + * grad( + en_out, + dx, + grad_outputs=torch.ones_like(en_out), + create_graph=True, + retain_graph=True, + )[0] + ) + if self.config.force_mult_natoms: + pair_forces *= g.num_nodes() + + # construct force_i = dE / d{r_i} + # reduce over bonds to get forces on each atom + + # force_i contributions from r_{j->i} (in edges) + g.edata["pair_forces"] = pair_forces + g.update_all( + fn.copy_e("pair_forces", "m"), fn.sum("m", "forces_ji") + ) + if self.config.add_reverse_forces: + # reduce over reverse edges too! 
# force_i contributions from r_{i->j} (out edges)
+                # aggregate pairwise_force_contributions over reversed edges
+                rg = dgl.reverse(g, copy_edata=True)
+                rg.update_all(
+                    fn.copy_e("pair_forces", "m"), fn.sum("m", "forces_ij")
+                )
+
+                # combine dE / d(r_{j->i}) and dE / d(r_{i->j})
+                forces = torch.squeeze(
+                    g.ndata["forces_ji"] - rg.ndata["forces_ij"]
+                )
+            else:
+                forces = torch.squeeze(g.ndata["forces_ji"])
+
+            if self.config.stresswise_weight != 0:
+                # Under development, use with caution
+                # 1 eV/Angstrom3 = 160.21766208 GPa
+                # 1 GPa = 10 kbar
+                # Following virial stress formula, assuming initial velocity = 0
+                # Save volume as g.ndata['V']?
+                # print('pair_forces',pair_forces.shape)
+                # print('r',r.shape)
+                # print('g.ndata["V"]',g.ndata["V"].shape)
+                if not self.config.batch_stress:
+                    # print('Not batch_stress')
+                    stress = (
+                        -1
+                        * 160.21766208
+                        * (
+                            torch.matmul(r.T, pair_forces)
+                            # / (2 * g.edata["V"])
+                            / (2 * g.ndata["V"][0])
+                        )
+                    )
+                    # print("stress1", stress, stress.shape)
+                    # print("g.batch_size", g.batch_size)
+                else:
+                    # print('Using batch_stress')
+                    stresses = []
+                    count_edge = 0
+                    count_node = 0
+                    for graph_id in range(g.batch_size):
+                        num_edges = g.batch_num_edges()[graph_id]
+                        num_nodes = 0
+                        st = -1 * (
+                            160.21766208
+                            * torch.matmul(
+                                r[count_edge : count_edge + num_edges].T,
+                                pair_forces[
+                                    count_edge : count_edge + num_edges
+                                ],
+                            )
+                            / g.ndata["V"][count_node + num_nodes]
+                        )
+
+                        count_edge = count_edge + num_edges
+                        num_nodes = g.batch_num_nodes()[graph_id]
+                        count_node = count_node + num_nodes
+                        # print("stresses.append",stresses[-1],stresses[-1].shape)
+                        for n in range(num_nodes):
+                            stresses.append(st)
+                    # stress = (stresses)
+                    stress = self.config.stress_multiplier * torch.cat(
+                        stresses
+                    )
+                    # print("stress2", stress, stress.shape)
+            # virial = (
+            #    160.21766208
+            #    * 10
+            #    * torch.einsum("ij, ik->jk", result["r"], result["dy_dr"])
+            #    / 2
+            # )  # / ( g.ndata["V"][0])
+        if self.link:
+            out = self.link(out)
+
+        if self.classification:
+            # out = torch.max(out,dim=1)
+            out = self.softmax(out)
+        result["out"] = out
+        result["grad"] = forces
+        result["stresses"] = stress
+        result["atomwise_pred"] = atomwise_pred
+        # print(result)
+        return result
diff --git a/alignn/train.py b/alignn/train.py
index d68693e..f4b094d 100644
--- a/alignn/train.py
+++ b/alignn/train.py
@@ -18,6 +18,7 @@
 from alignn.data import get_train_val_loaders
 from alignn.config import TrainingConfig
 from alignn.models.alignn_atomwise import ALIGNNAtomWise
+from alignn.models.alignn_ff2 import ALIGNNFF2
 from alignn.models.alignn import ALIGNN
 from jarvis.db.jsonutils import dumpjson
 import json
@@ -28,7 +29,7 @@
 from sklearn.metrics import roc_auc_score
 
 warnings.filterwarnings("ignore", category=RuntimeWarning)
-torch.set_default_dtype(torch.float32)
+# torch.set_default_dtype(torch.float32)
 
 
 # def setup(rank, world_size):
@@ -147,7 +148,13 @@ def train_dgl(
         pprint.pprint(tmp)  # , sort_dicts=False)
     if config.classification_threshold is not None:
         classification = True
-
+    TORCH_DTYPES = {
+        "float16": torch.float16,
+        "float32": torch.float32,
+        "float64": torch.float64,
+        "bfloat": torch.bfloat16,
+    }
+    torch.set_default_dtype(TORCH_DTYPES[config.dtype])
     line_graph = False
     if config.model.alignn_layers > 0:
         line_graph = True
@@ -197,6 +204,7 @@ def train_dgl(
             keep_data_order=config.keep_data_order,
             output_dir=config.output_dir,
             use_lmdb=config.use_lmdb,
+            dtype=config.dtype,
         )
     else:
         train_loader = train_val_test_loaders[0]
@@ -212,6 +220,7 @@ def train_dgl(
     _model = {
"alignn_atomwise": ALIGNNAtomWise, "alignn": ALIGNN, + "alignn_ff2": ALIGNNFF2, } if config.random_seed is not None: random.seed(config.random_seed) @@ -265,7 +274,10 @@ def train_dgl( optimizer, ) - if config.model.name == "alignn_atomwise": + if ( + config.model.name == "alignn_atomwise" + or config.model.name == "alignn_ff2" + ): def get_batch_errors(dat=[]): """Get errors for samples.""" diff --git a/alignn/train_alignn.py b/alignn/train_alignn.py index 7bb61ed..6dcf8ab 100644 --- a/alignn/train_alignn.py +++ b/alignn/train_alignn.py @@ -13,6 +13,7 @@ from jarvis.db.jsonutils import loadjson import argparse from alignn.models.alignn_atomwise import ALIGNNAtomWise, ALIGNNAtomWiseConfig +from alignn.models.alignn_ff2 import ALIGNNFF2, ALIGNNFF2Config import torch import time from jarvis.core.atoms import Atoms @@ -218,6 +219,7 @@ def train_for_folder( else: train_atom = False except Exception as exp: + print("exp", exp) pass # if config.model.atomwise_weight == 0: # train_atom = False @@ -322,36 +324,22 @@ def train_for_folder( ) tmp = ALIGNNAtomWiseConfig(**rest_config["model"]) - # tmp = ALIGNNAtomWiseConfig( - # name="alignn_atomwise", - # output_features=config.model.output_features, - # alignn_layers=config.model.alignn_layers, - # atomwise_weight=config.model.atomwise_weight, - # stresswise_weight=config.model.stresswise_weight, - # graphwise_weight=config.model.graphwise_weight, - # gradwise_weight=config.model.gradwise_weight, - # gcn_layers=config.model.gcn_layers, - # atom_input_features=config.model.atom_input_features, - # edge_input_features=config.model.edge_input_features, - # triplet_input_features=config.model.triplet_input_features, - # embedding_features=config.model.embedding_features, - # ) print("Rest config", tmp) - # for i,j in config_dict['model'].items(): - # print ('i',i) - # tmp.i=j - # print ('tmp1',tmp) model = ALIGNNAtomWise(tmp) # config.model) - # model = ALIGNNAtomWise(ALIGNNAtomWiseConfig( - # name="alignn_atomwise", - # output_features=1, - # graphwise_weight=1, - # alignn_layers=4, - # gradwise_weight=10, - # stresswise_weight=0.01, - # atomwise_weight=0, - # ) - # ) + print("model", model) + model.load_state_dict( + torch.load(restart_model_path, map_location=device) + ) + model = model.to(device) + if config.model.name == "alignn_ff2": + rest_config = loadjson( + restart_model_path.replace("current_model.pt", "config.json") + # restart_model_path.replace("best_model.pt", "config.json") + ) + + tmp = ALIGNNFF2Config(**rest_config["model"]) + print("Rest config", tmp) + model = ALIGNNFF2(tmp) # config.model) print("model", model) model.load_state_dict( torch.load(restart_model_path, map_location=device) @@ -410,6 +398,7 @@ def train_for_folder( keep_data_order=config.keep_data_order, output_dir=config.output_dir, use_lmdb=config.use_lmdb, + dtype=config.dtype, ) # print("dataset", dataset[0]) t1 = time.time() From 060e8b59158d4fdbde68fa8102398e71427330e8 Mon Sep 17 00:00:00 2001 From: knc6 Date: Sun, 27 Oct 2024 00:36:37 -0400 Subject: [PATCH 03/37] Lint. 
--- alignn/models/alignn_ff2.py | 150 ++++++++++++++----------------------
 alignn/models/utils.py      |   4 +-
 2 files changed, 61 insertions(+), 93 deletions(-)

diff --git a/alignn/models/alignn_ff2.py b/alignn/models/alignn_ff2.py
index 4cbf6b0..b4a012b 100644
--- a/alignn/models/alignn_ff2.py
+++ b/alignn/models/alignn_ff2.py
@@ -7,11 +7,8 @@
 from torch.autograd import grad
 import dgl
 import dgl.function as fn
-import numpy as np
 from dgl.nn import AvgPooling
 import torch
-
-# from dgl.nn.functional import edge_softmax
 from typing import Literal
 from torch import nn
 from torch.nn import functional as F
@@ -19,8 +16,28 @@
 from alignn.graphs import compute_bond_cosines
 from alignn.utils import BaseSettings
 
+# from math import pi, sqrt
+
+
+def compute_pair_vector_and_distance(g: dgl.DGLGraph):
+    """Calculate bond vectors and distances using dgl graphs.
+
+    Args:
+        g: DGL graph
+
+    Returns:
+        bond_vec (torch.tensor): vector from src node to dst node
+        bond_dist (torch.tensor): bond distance between two atoms
+    """
+    dst_pos = g.ndata["coords"][g.edges()[1]] + g.edata["images"]
+    src_pos = g.ndata["coords"][g.edges()[0]]
+    bond_vec = dst_pos - src_pos
+    bond_dist = torch.norm(bond_vec, dim=1)
 
-def _ensure_3body_line_graph_compatibility(
+    return bond_vec, bond_dist
+
+
+def check_line_graph(
     graph: dgl.DGLGraph, line_graph: dgl.DGLGraph, threebody_cutoff: float
 ):
     """Ensure that 3body line graph is compatible with a given graph.
@@ -30,34 +47,19 @@
         line_graph: line graph of atomistic graph
         threebody_cutoff: cutoff for three-body interactions
     """
-    valid_three_body = graph.edata["bond_dist"] <= threebody_cutoff
-    if (
-        line_graph.num_nodes()
-        == graph.edata["bond_vec"][valid_three_body].shape[0]
-    ):
-        line_graph.ndata["bond_vec"] = graph.edata["bond_vec"][
-            valid_three_body
-        ]
-        line_graph.ndata["bond_dist"] = graph.edata["bond_dist"][
-            valid_three_body
-        ]
-        line_graph.ndata["pbc_offset"] = graph.edata["pbc_offset"][
-            valid_three_body
-        ]
+    valid_three_body = graph.edata["d"] <= threebody_cutoff
+    if line_graph.num_nodes() == graph.edata["r"][valid_three_body].shape[0]:
+        line_graph.ndata["r"] = graph.edata["r"][valid_three_body]
+        line_graph.ndata["d"] = graph.edata["d"][valid_three_body]
+        line_graph.ndata["images"] = graph.edata["images"][valid_three_body]
     else:
         three_body_id = torch.concatenate(line_graph.edges())
         max_three_body_id = (
             torch.max(three_body_id) + 1 if three_body_id.numel() > 0 else 0
         )
-        line_graph.ndata["bond_vec"] = graph.edata["bond_vec"][
-            :max_three_body_id
-        ]
-        line_graph.ndata["bond_dist"] = graph.edata["bond_dist"][
-            :max_three_body_id
-        ]
-        line_graph.ndata["pbc_offset"] = graph.edata["pbc_offset"][
-            :max_three_body_id
-        ]
+        line_graph.ndata["r"] = graph.edata["r"][:max_three_body_id]
+        line_graph.ndata["d"] = graph.edata["d"][:max_three_body_id]
+        line_graph.ndata["images"] = graph.edata["images"][:max_three_body_id]
 
     return line_graph
 
@@ -73,8 +75,6 @@
     triplet_input_features: int = 40
     embedding_features: int = 64
     hidden_features: int = 256
-    # fc_layers: int = 1
-    # fc_features: int = 64
     output_features: int = 1
     grad_multiplier: int = -1
     calculate_gradient: bool = True
@@ -83,14 +83,8 @@
     gradwise_weight: float = 0.0
     stresswise_weight: float = 0.0
     atomwise_weight: float = 0.0
-    # if link == log, apply `exp` to final outputs
-    # to constrain predictions to be positive
-    link: Literal["identity", "log", "logit"] = "identity"
-    zero_inflated: 
bool = False classification: bool = False force_mult_natoms: bool = False - energy_mult_natoms: bool = False - include_pos_deriv: bool = False use_cutoff_function: bool = False inner_cutoff: float = 6 # Ansgtrom stress_multiplier: float = 1 @@ -100,6 +94,9 @@ class ALIGNNFF2Config(BaseSettings): multiply_cutoff: bool = False extra_features: int = 0 exponent: int = 3 + max_n: int = 9 + max_f: int = 4 + learn_basis: bool = True class Config: """Configure model settings behavior.""" @@ -316,11 +313,15 @@ def __init__( self.atom_embedding = MLPLayer( config.atom_input_features, config.hidden_features ) - + # self.bond_expansion = RadialBesselFunction( + # max_n=config.max_n, + # cutoff=config.inner_cutoff, + # learnable=config.learn_basis, + # ) self.edge_embedding = nn.Sequential( RBFExpansion( vmin=0, - vmax=8.0, + vmax=5.0, bins=config.edge_input_features, ), MLPLayer(config.edge_input_features, config.embedding_features), @@ -389,18 +390,6 @@ def __init__( # self.softmax = nn.LogSoftmax(dim=1) else: self.fc = nn.Linear(config.hidden_features, config.output_features) - self.link = None - self.link_name = config.link - if config.link == "identity": - self.link = lambda x: x - elif config.link == "log": - self.link = torch.exp - avg_gap = 0.7 # magic number -- average bandgap in dft_3d - self.fc.bias.data = torch.tensor( - np.log(avg_gap), dtype=torch.float - ) - elif config.link == "logit": - self.link = torch.sigmoid def forward( self, g: Union[Tuple[dgl.DGLGraph, dgl.DGLGraph], dgl.DGLGraph] @@ -411,42 +400,37 @@ def forward( y: bond features (g.edata and lg.ndata) z: angle features (lg.edata) """ + result = {} if len(self.alignn_layers) > 0: g, lg = g lg = lg.local_var() - - # angle features (fixed) - z = self.angle_embedding(lg.edata.pop("h")) if self.config.extra_features != 0: features = g.ndata["extra_features"] - # print('features',features,features.shape) features = self.extra_feature_embedding(features) - g = g.local_var() - result = {} - - # initial node features: atom feature network... x = g.ndata.pop("atom_features") - # print('x1',x,x.shape) - x = self.atom_embedding(x) - # print('x2',x,x.shape) - r = g.edata["r"] + r, bondlength = compute_pair_vector_and_distance(g) if self.config.calculate_gradient: r.requires_grad_(True) bondlength = torch.norm(r, dim=1) - # mask = bondlength >= self.config.inner_cutoff - # bondlength[mask]=float(1.1) - if self.config.lg_on_fly and len(self.alignn_layers) > 0: - # re-compute bond angle cosines here to ensure - # the three-body interactions are fully included - # in the autograd graph. don't rely on dataloader/caching. 
- lg.ndata["r"] = r # overwrites precomputed r values - lg.apply_edges(compute_bond_cosines) # overwrites precomputed h - z = self.angle_embedding(lg.edata.pop("h")) - - # r = g.edata["r"].clone().detach().requires_grad_(True) + g.edata["d"] = bondlength + g.edata["r"] = r + # bond_expansion = self.bond_expansion(bondlength) + lg = check_line_graph(g, lg, self.config.inner_cutoff) + lg.apply_edges(compute_bond_cosines) + + # smooth_cutoff = polynomial_cutoff( + # bond_expansion, self.config.inner_cutoff, self.config.exponent + # ) + # bond_expansion *= smooth_cutoff + # g.edata["bond_expansion"] = ( + # bond_expansion # smooth_cutoff * bond_expansion + # ) + + # y = self.edge_embedding(bondlength) + z = self.angle_embedding(lg.edata.pop("h")) + if self.config.use_cutoff_function: - # bondlength = cutoff_function_based_edges( if self.config.multiply_cutoff: c_off = cutoff_function_based_edges( bondlength, @@ -464,7 +448,6 @@ def forward( y = self.edge_embedding(bondlength) else: y = self.edge_embedding(bondlength) - # y = self.edge_embedding(bondlength) # ALIGNN updates: update node, edge, triplet features for alignn_layer in self.alignn_layers: x, y, z = alignn_layer(g, lg, x, y, z) @@ -479,12 +462,10 @@ def forward( out = self.fc(h) if self.config.extra_features != 0: h_feat = self.readout_feat(g, features) - # print('h_feat',h_feat) h = torch.cat((h, h_feat), 1) h = self.fc1(h) h = self.fc2(h) out = self.fc3(h) - # print('out',out) else: out = torch.squeeze(out) atomwise_pred = torch.empty(1) @@ -500,25 +481,14 @@ def forward( stress = torch.empty(1) if self.config.calculate_gradient: - if self.config.include_pos_deriv: - # Not tested yet - g.ndata["coords"].requires_grad_(True) - dx = [g.ndata["coords"], r] - else: - dx = r - - if self.config.energy_mult_natoms: - en_out = out * g.num_nodes() - else: - en_out = out - + en_out = out # force calculation based on bond displacement vectors # autograd gives dE / d{r_{i->j}} pair_forces = ( self.config.grad_multiplier * grad( en_out, - dx, + r, grad_outputs=torch.ones_like(en_out), create_graph=True, retain_graph=True, @@ -609,8 +579,6 @@ def forward( # * torch.einsum("ij, ik->jk", result["r"], result["dy_dr"]) # / 2 # ) # / ( g.ndata["V"][0]) - if self.link: - out = self.link(out) if self.classification: # out = torch.max(out,dim=1) diff --git a/alignn/models/utils.py b/alignn/models/utils.py index be5643b..ba17b42 100644 --- a/alignn/models/utils.py +++ b/alignn/models/utils.py @@ -1,6 +1,6 @@ """Shared model-building components.""" -from typing import Optional +from typing import Optional import numpy as np import torch from torch import nn @@ -33,7 +33,7 @@ def __init__( else: self.lengthscale = lengthscale - self.gamma = 1 / (lengthscale ** 2) + self.gamma = 1 / (lengthscale**2) def forward(self, distance: torch.Tensor) -> torch.Tensor: """Apply RBF expansion to interatomic distance tensor.""" From f00a94a2655d0722f1396728487bff0352bfc504 Mon Sep 17 00:00:00 2001 From: knc6 Date: Sun, 27 Oct 2024 00:46:07 -0400 Subject: [PATCH 04/37] Lint --- alignn/graphs.py | 3 +-- alignn/models/alignn_ff2.py | 8 +++----- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/alignn/graphs.py b/alignn/graphs.py index 94072c6..9299015 100644 --- a/alignn/graphs.py +++ b/alignn/graphs.py @@ -18,7 +18,7 @@ def temp_graph(atoms=None, cutoff=4.0, atom_features="cgcnn", dtype="float32"): - """Helper function to construct a graph for a given cutoff.""" + """Construct a graph for a given cutoff.""" TORCH_DTYPES = { "float16": torch.float16, "float32": 
torch.float32, @@ -73,7 +73,6 @@ def radius_graph_jarvis( dtype="float32", ): """Construct radius graph with dynamic cutoff.""" - while True: # try: # Attempt to create the graph diff --git a/alignn/models/alignn_ff2.py b/alignn/models/alignn_ff2.py index b4a012b..af4b611 100644 --- a/alignn/models/alignn_ff2.py +++ b/alignn/models/alignn_ff2.py @@ -89,7 +89,6 @@ class ALIGNNFF2Config(BaseSettings): inner_cutoff: float = 6 # Ansgtrom stress_multiplier: float = 1 add_reverse_forces: bool = False # will make True as default soon - lg_on_fly: bool = False # will make True as default soon batch_stress: bool = True multiply_cutoff: bool = False extra_features: int = 0 @@ -98,10 +97,9 @@ class ALIGNNFF2Config(BaseSettings): max_f: int = 4 learn_basis: bool = True - class Config: - """Configure model settings behavior.""" - - env_prefix = "jv_model" + # class Config: + # """Configure model settings behavior.""" + # env_prefix = "jv_model" def cutoff_function_based_edges_old(r, inner_cutoff=4): From 27708dfa27cb450c2693300a1d0e7fd6ec145ee3 Mon Sep 17 00:00:00 2001 From: knc6 Date: Sun, 27 Oct 2024 11:14:28 -0400 Subject: [PATCH 05/37] Lint --- .../config_example_atomwise.json | 9 +- alignn/graphs.py | 40 +-- alignn/models/alignn_ff2.py | 207 +++++------- alignn/models/utils.py | 309 ++++++++++++++++- alignn/tests/test_alignn_ff.py | 27 +- alignn/train.py | 314 ++++-------------- alignn/utils.py | 141 +++++++- 7 files changed, 640 insertions(+), 407 deletions(-) diff --git a/alignn/examples/sample_data_ff/config_example_atomwise.json b/alignn/examples/sample_data_ff/config_example_atomwise.json index f1d13ad..0fae17d 100644 --- a/alignn/examples/sample_data_ff/config_example_atomwise.json +++ b/alignn/examples/sample_data_ff/config_example_atomwise.json @@ -44,16 +44,15 @@ "atom_input_features": 92, "calculate_gradient":true, "atomwise_output_features":0, - "alignn_layers":2, - "gcn_layers":2, - "hidden_features":128, + "alignn_layers":1, + "gcn_layers":1, + "hidden_features":64, "output_features": 1, "graphwise_weight":0.85, "gradwise_weight":0.05, "atomwise_weight":0.0, "stresswise_weight":0.05, - "add_reverse_forces":true, - "lg_on_fly":true + "add_reverse_forces":true } diff --git a/alignn/graphs.py b/alignn/graphs.py index 9299015..61001f3 100644 --- a/alignn/graphs.py +++ b/alignn/graphs.py @@ -74,26 +74,26 @@ def radius_graph_jarvis( ): """Construct radius graph with dynamic cutoff.""" while True: - # try: - # Attempt to create the graph - g, u, v, r = temp_graph( - atoms=atoms, - cutoff=cutoff, - atom_features=atom_features, - dtype=dtype, - ) - # Check if all atoms are included as nodes - if g.num_nodes() == len(atoms.elements): - # print(f"Graph constructed with cutoff: {cutoff}") - break # Exit the loop when successful - # Increment the cutoff if the graph is incomplete - cutoff += cutoff_extra - # print(f"Increasing cutoff to: {cutoff}") - - # except Exception as exp: - # # Handle exceptions and try again - # print(f"Graph construction failed: {exp}") - # cutoff += cutoff_extra # Try with a larger cutoff + try: + # Attempt to create the graph + g, u, v, r = temp_graph( + atoms=atoms, + cutoff=cutoff, + atom_features=atom_features, + dtype=dtype, + ) + # Check if all atoms are included as nodes + if g.num_nodes() == len(atoms.elements): + # print(f"Graph constructed with cutoff: {cutoff}") + break # Exit the loop when successful + # Increment the cutoff if the graph is incomplete + cutoff += cutoff_extra + # print(f"Increasing cutoff to: {cutoff}") + + except Exception as exp: + # Handle 
exceptions and try again + print(f"Graph construction failed: {exp,cutoff}") + cutoff += cutoff_extra # Try with a larger cutoff # Optional: Create a line graph if requested if line_graph: diff --git a/alignn/models/alignn_ff2.py b/alignn/models/alignn_ff2.py index af4b611..9041b13 100644 --- a/alignn/models/alignn_ff2.py +++ b/alignn/models/alignn_ff2.py @@ -12,56 +12,19 @@ from typing import Literal from torch import nn from torch.nn import functional as F -from alignn.models.utils import RBFExpansion +from alignn.models.utils import ( + RBFExpansion, + BesselExpansion, + SphericalHarmonicsExpansion, + FourierExpansion, + compute_pair_vector_and_distance, + check_line_graph, + cutoff_function_based_edges, +) from alignn.graphs import compute_bond_cosines from alignn.utils import BaseSettings -# from math import pi, sqrt - - -def compute_pair_vector_and_distance(g: dgl.DGLGraph): - """Calculate bond vectors and distances using dgl graphs. - - Args: - g: DGL graph - - Returns: - bond_vec (torch.tensor): bond distance between two atoms - bond_dist (torch.tensor): vector from src node to dst node - """ - dst_pos = g.ndata["coords"][g.edges()[1]] + g.edata["images"] - src_pos = g.ndata["coords"][g.edges()[0]] - bond_vec = dst_pos - src_pos - bond_dist = torch.norm(bond_vec, dim=1) - - return bond_vec, bond_dist - - -def check_line_graph( - graph: dgl.DGLGraph, line_graph: dgl.DGLGraph, threebody_cutoff: float -): - """Ensure that 3body line graph is compatible with a given graph. - - Args: - graph: atomistic graph - line_graph: line graph of atomistic graph - threebody_cutoff: cutoff for three-body interactions - """ - valid_three_body = graph.edata["d"] <= threebody_cutoff - if line_graph.num_nodes() == graph.edata["r"][valid_three_body].shape[0]: - line_graph.ndata["r"] = graph.edata["r"][valid_three_body] - line_graph.ndata["d"] = graph.edata["d"][valid_three_body] - line_graph.ndata["images"] = graph.edata["images"][valid_three_body] - else: - three_body_id = torch.concatenate(line_graph.edges()) - max_three_body_id = ( - torch.max(three_body_id) + 1 if three_body_id.numel() > 0 else 0 - ) - line_graph.ndata["r"] = graph.edata["r"][:max_three_body_id] - line_graph.ndata["d"] = graph.edata["d"][:max_three_body_id] - line_graph.ndata["images"] = graph.edata["images"][:max_three_body_id] - - return line_graph +torch.autograd.set_detect_anomaly(True) class ALIGNNFF2Config(BaseSettings): @@ -80,74 +43,25 @@ class ALIGNNFF2Config(BaseSettings): calculate_gradient: bool = True atomwise_output_features: int = 0 graphwise_weight: float = 1.0 - gradwise_weight: float = 0.0 - stresswise_weight: float = 0.0 + gradwise_weight: float = 1.0 + stresswise_weight: float = 0.00001 atomwise_weight: float = 0.0 classification: bool = False - force_mult_natoms: bool = False - use_cutoff_function: bool = False - inner_cutoff: float = 6 # Ansgtrom + force_mult_natoms: bool = True + use_cutoff_function: bool = True + inner_cutoff: float = 4 # Ansgtrom stress_multiplier: float = 1 add_reverse_forces: bool = False # will make True as default soon batch_stress: bool = True multiply_cutoff: bool = False extra_features: int = 0 exponent: int = 3 + bond_exp_basis: str = "gaussian" # "bessel" # or gaussian + angle_exp_basis: str = "gaussian" # "bessel" # or gaussian max_n: int = 9 max_f: int = 4 learn_basis: bool = True - # class Config: - # """Configure model settings behavior.""" - # env_prefix = "jv_model" - - -def cutoff_function_based_edges_old(r, inner_cutoff=4): - """Apply smooth cutoff to pairwise interactions - 
- r: bond lengths - inner_cutoff: cutoff radius - - inside cutoff radius, apply smooth cutoff envelope - outside cutoff radius: hard zeros - """ - ratio = r / inner_cutoff - return torch.where( - ratio <= 1, - 1 - 6 * ratio**5 + 15 * ratio**4 - 10 * ratio**3, - torch.zeros_like(r), - ) - - -def cutoff_function_based_edges(r, inner_cutoff=4, exponent=3): - """Apply smooth cutoff to pairwise interactions - - r: bond lengths - inner_cutoff: cutoff radius - - inside cutoff radius, apply smooth cutoff envelope - outside cutoff radius: hard zeros - """ - ratio = r / inner_cutoff - c1 = -(exponent + 1) * (exponent + 2) / 2 - c2 = exponent * (exponent + 2) - c3 = -exponent * (exponent + 1) / 2 - envelope = ( - 1 - + c1 * ratio**exponent - + c2 * ratio ** (exponent + 1) - + c3 * ratio ** (exponent + 2) - ) - # r_cut = inner_cutoff - # r_on = inner_cutoff+1 - - # r_sq = r * r - # r_on_sq = r_on * r_on - # r_cut_sq = r_cut * r_cut - # envelope = (r_cut_sq - r_sq) - # ** 2 * (r_cut_sq + 2 * r_sq - 3 * r_on_sq)/ (r_cut_sq - r_on_sq) ** 3 - return torch.where(r <= inner_cutoff, envelope, torch.zeros_like(r)) - class EdgeGatedGraphConv(nn.Module): """Edge gated graph convolution from arxiv:1711.07553. @@ -311,29 +225,67 @@ def __init__( self.atom_embedding = MLPLayer( config.atom_input_features, config.hidden_features ) - # self.bond_expansion = RadialBesselFunction( - # max_n=config.max_n, - # cutoff=config.inner_cutoff, - # learnable=config.learn_basis, - # ) - self.edge_embedding = nn.Sequential( - RBFExpansion( - vmin=0, - vmax=5.0, - bins=config.edge_input_features, - ), - MLPLayer(config.edge_input_features, config.embedding_features), - MLPLayer(config.embedding_features, config.hidden_features), - ) - self.angle_embedding = nn.Sequential( - RBFExpansion( - vmin=-1, - vmax=1.0, - bins=config.triplet_input_features, - ), - MLPLayer(config.triplet_input_features, config.embedding_features), - MLPLayer(config.embedding_features, config.hidden_features), - ) + if self.config.bond_exp_basis == "bessel": + self.edge_embedding = nn.Sequential( + BesselExpansion( + # RadialBesselFunction( + vmin=0, + vmax=8.0, + bins=config.edge_input_features, + ), + MLPLayer( + config.edge_input_features, config.embedding_features + ), + MLPLayer(config.embedding_features, config.hidden_features), + ) + else: + self.edge_embedding = nn.Sequential( + RBFExpansion( + vmin=0, + vmax=8.0, + bins=config.edge_input_features, + ), + MLPLayer( + config.edge_input_features, config.embedding_features + ), + MLPLayer(config.embedding_features, config.hidden_features), + ) + if self.config.angle_exp_basis == "spherical": + self.angle_embedding = nn.Sequential( + SphericalHarmonicsExpansion(), + MLPLayer( + config.triplet_input_features, config.embedding_features + ), + MLPLayer(config.embedding_features, config.hidden_features), + ) # not tested + elif self.config.angle_exp_basis == "bessel": + self.angle_embedding = nn.Sequential( + BesselExpansion(), + MLPLayer( + config.triplet_input_features, config.embedding_features + ), + MLPLayer(config.embedding_features, config.hidden_features), + ) # not tested + elif self.config.angle_exp_basis == "fourier": + self.angle_embedding = nn.Sequential( + FourierExpansion(), + MLPLayer( + config.triplet_input_features, config.embedding_features + ), + MLPLayer(config.embedding_features, config.hidden_features), + ) # not tested + else: + self.angle_embedding = nn.Sequential( + RBFExpansion( + vmin=-1, + vmax=1.0, + bins=config.triplet_input_features, + ), + MLPLayer( + 
config.triplet_input_features, config.embedding_features + ), + MLPLayer(config.embedding_features, config.hidden_features), + ) self.alignn_layers = nn.ModuleList( [ @@ -407,9 +359,11 @@ def forward( features = self.extra_feature_embedding(features) x = g.ndata.pop("atom_features") x = self.atom_embedding(x) + # r=g.edata['r'] r, bondlength = compute_pair_vector_and_distance(g) if self.config.calculate_gradient: r.requires_grad_(True) + # print('gradient') bondlength = torch.norm(r, dim=1) g.edata["d"] = bondlength g.edata["r"] = r @@ -518,6 +472,7 @@ def forward( ) else: forces = torch.squeeze(g.ndata["forces_ji"]) + # print('forces',forces) if self.config.stresswise_weight != 0: # Under development, use with caution diff --git a/alignn/models/utils.py b/alignn/models/utils.py index ba17b42..b8937f5 100644 --- a/alignn/models/utils.py +++ b/alignn/models/utils.py @@ -3,7 +3,78 @@ from typing import Optional import numpy as np import torch -from torch import nn + +# from torch import nn +from math import pi +import torch.nn as nn + +# from scipy.special import spherical_jn +# from scipy.special import sph_harm, lpmv +import math +import dgl + + +class BesselExpansion(nn.Module): + """Expand interatomic distances with spherical Bessel functions.""" + + def __init__( + self, + vmin: float = 0, + vmax: float = 8, + bins: int = 40, + cutoff: Optional[float] = None, + ): + """Register torch parameters for Bessel function expansion.""" + super().__init__() + self.vmin = vmin + self.vmax = vmax + self.bins = bins + self.cutoff = cutoff if cutoff is not None else vmax + + # Generate frequency parameters for Bessel functions + # Convert to float32 explicitly + frequencies = torch.tensor( + [(n * np.pi) / self.cutoff for n in range(1, bins + 1)], + dtype=torch.float32, + ) + self.register_buffer("frequencies", frequencies) + + # Precompute normalization factors + norm_factors = torch.tensor( + [np.sqrt(2 / self.cutoff) for _ in range(bins)], + dtype=torch.float32, + ) + self.register_buffer("norm_factors", norm_factors) + + def forward(self, distance: torch.Tensor) -> torch.Tensor: + """Apply Bessel function expansion to interatomic distance tensor.""" + # Ensure input is float32 + distance = distance.to(torch.float32) + + # Compute the zero-order spherical Bessel functions + x = distance.unsqueeze(-1) * self.frequencies + + # Handle the case where x is close to zero + mask = x.abs() < 1e-10 + j0 = torch.where(mask, torch.ones_like(x), torch.sin(x) / x) + + # Apply normalization + bessel_features = j0 * self.norm_factors + + # Apply smooth cutoff function if cutoff is specified + if self.cutoff < self.vmax: + envelope = self._smooth_cutoff(distance) + bessel_features = bessel_features * envelope.unsqueeze(-1) + + return bessel_features + + def _smooth_cutoff(self, distance: torch.Tensor) -> torch.Tensor: + """Apply smooth cutoff function to ensure continuity at boundary.""" + x = torch.pi * distance / self.cutoff + cutoffs = 0.5 * (torch.cos(x) + 1.0) + return torch.where( + distance <= self.cutoff, cutoffs, torch.zeros_like(distance) + ) class RBFExpansion(nn.Module): @@ -40,3 +111,239 @@ def forward(self, distance: torch.Tensor) -> torch.Tensor: return torch.exp( -self.gamma * (distance.unsqueeze(1) - self.centers) ** 2 ) + + +class FourierExpansion(nn.Module): + """Fourier Expansion of a (periodic) scalar feature.""" + + def __init__( + self, + max_f: int = 5, + interval: float = pi, + scale_factor: float = 1.0, + learnable: bool = False, + ): + """Args: + max_f (int): the maximum frequency of 
the Fourier expansion. + Default = 5 + interval (float): interval of the Fourier exp, such that functions + are orthonormal over [-interval, interval]. Default = pi + scale_factor (float): pre-factor to scale all values. + learnable (bool): whether to set the frequencies as learnable + Default = False. + """ + super().__init__() + self.max_f = max_f + self.interval = interval + self.scale_factor = scale_factor + # Initialize frequencies at canonical + if learnable: + self.frequencies = torch.nn.Parameter( + data=torch.arange(0, max_f + 1, dtype=torch.float32), + requires_grad=True, + ) + else: + self.register_buffer( + "frequencies", torch.arange(0, max_f + 1, dtype=torch.float32) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Expand x into cos and sin functions.""" + result = x.new_zeros(x.shape[0], 1 + 2 * self.max_f) + tmp = torch.outer(x, self.frequencies) + result[:, ::2] = torch.cos(tmp * pi / self.interval) + result[:, 1::2] = torch.sin(tmp[:, 1:] * pi / self.interval) + return result / self.interval * self.scale_factor + + +class SphericalHarmonicsExpansion(nn.Module): + """Expand angles with spherical harmonics.""" + + def __init__( + self, + vmin: float = 0, + vmax: float = math.pi, + bins: int = 20, + l_max: int = 3, + ): + """Register torch parameters for spherical harmonics expansion.""" + super().__init__() + self.vmin = vmin + self.vmax = vmax + self.bins = bins + self.l_max = l_max + self.num_harmonics = (l_max + 1) ** 2 + self.register_buffer( + "centers", torch.linspace(self.vmin, self.vmax, self.bins) + ) + + def forward(self, theta: torch.Tensor) -> torch.Tensor: + """Apply spherical harmonics expansion to angular tensors.""" + harmonics = [] + phi = torch.zeros_like(theta) + for l_x in range(self.l_max + 1): + for m in range(-l_x, l_x + 1): + y_lm = self._spherical_harmonic(l_x, m, theta, phi) + harmonics.append(y_lm) + return torch.stack(harmonics, dim=-1) + + def _legendre_polynomial( + self, l_x: int, m: int, x: torch.Tensor + ) -> torch.Tensor: + """ + Compute the associated Legendre polynomials P_l^m(x). + :param l: Degree of the polynomial. + :param m: Order of the polynomial. + :param x: Input tensor. + :return: Associated Legendre polynomial evaluated at x. + """ + pmm = torch.ones_like(x) + if m > 0: + somx2 = torch.sqrt((1 - x) * (1 + x)) + fact = 1.0 + for i in range(1, m + 1): + pmm = -pmm * fact * somx2 + fact += 2.0 + + if l_x == m: + return pmm + pmmp1 = x * (2 * m + 1) * pmm + if l_x == m + 1: + return pmmp1 + + pll = torch.zeros_like(x) + for ll in range(m + 2, l_x + 1): + pll = ((2 * ll - 1) * x * pmmp1 - (ll + m - 1) * pmm) / (ll - m) + pmm = pmmp1 + pmmp1 = pll + + return pll + + def _spherical_harmonic( + self, l_x: int, m: int, theta: torch.Tensor, phi: torch.Tensor + ) -> torch.Tensor: + """ + Compute the real part of the spherical harmonics Y_l^m(theta, phi). + :param l: Degree of the harmonic. + :param m: Order of the harmonic. + :param theta: Polar angle (in radians). + :param phi: Azimuthal angle (in radians). + :return: Real part of the spherical harmonic Y_l^m. 
+ """ + sqrt2 = torch.sqrt(torch.tensor(2.0)) + if m > 0: + return ( + sqrt2 + * self._k(l_x, m) + * torch.cos(m * phi) + * self._legendre_polynomial(l_x, m, torch.cos(theta)) + ) + elif m < 0: + return ( + sqrt2 + * self._k(l_x, -m) + * torch.sin(-m * phi) + * self._legendre_polynomial(l_x, -m, torch.cos(theta)) + ) + else: + return self._k(l_x, 0) * self._legendre_polynomial( + l_x, 0, torch.cos(theta) + ) + + def _k(self, l_x: int, m: int) -> float: + """ + Normalization constant for the spherical harmonics. + :param l: Degree of the harmonic. + :param m: Order of the harmonic. + :return: Normalization constant. + """ + return math.sqrt( + (2 * l_x + 1) + / (4 * math.pi) + * math.factorial(l_x - m) + / math.factorial(l_x + m) + ) + + +def compute_pair_vector_and_distance(g: dgl.DGLGraph): + """Calculate bond vectors and distances using dgl graphs.""" + dst_pos = g.ndata["coords"][g.edges()[1]] + g.edata["images"] + src_pos = g.ndata["coords"][g.edges()[0]] + bond_vec = dst_pos - src_pos + bond_dist = torch.norm(bond_vec, dim=1) + + return bond_vec, bond_dist + + +def check_line_graph( + graph: dgl.DGLGraph, line_graph: dgl.DGLGraph, threebody_cutoff: float +): + """Ensure that 3body line graph is compatible with a given graph. + + Args: + graph: atomistic graph + line_graph: line graph of atomistic graph + threebody_cutoff: cutoff for three-body interactions + """ + valid_three_body = graph.edata["d"] <= threebody_cutoff + if line_graph.num_nodes() == graph.edata["r"][valid_three_body].shape[0]: + line_graph.ndata["r"] = graph.edata["r"][valid_three_body] + line_graph.ndata["d"] = graph.edata["d"][valid_three_body] + line_graph.ndata["images"] = graph.edata["images"][valid_three_body] + else: + three_body_id = torch.concatenate(line_graph.edges()) + max_three_body_id = ( + torch.max(three_body_id) + 1 if three_body_id.numel() > 0 else 0 + ) + line_graph.ndata["r"] = graph.edata["r"][:max_three_body_id] + line_graph.ndata["d"] = graph.edata["d"][:max_three_body_id] + line_graph.ndata["images"] = graph.edata["images"][:max_three_body_id] + + return line_graph + + +def cutoff_function_based_edges_old(r, inner_cutoff=4): + """Apply smooth cutoff to pairwise interactions + + r: bond lengths + inner_cutoff: cutoff radius + + inside cutoff radius, apply smooth cutoff envelope + outside cutoff radius: hard zeros + """ + ratio = r / inner_cutoff + return torch.where( + ratio <= 1, + 1 - 6 * ratio**5 + 15 * ratio**4 - 10 * ratio**3, + torch.zeros_like(r), + ) + + +def cutoff_function_based_edges(r, inner_cutoff=4, exponent=3): + """Apply smooth cutoff to pairwise interactions + + r: bond lengths + inner_cutoff: cutoff radius + + inside cutoff radius, apply smooth cutoff envelope + outside cutoff radius: hard zeros + """ + ratio = r / inner_cutoff + c1 = -(exponent + 1) * (exponent + 2) / 2 + c2 = exponent * (exponent + 2) + c3 = -exponent * (exponent + 1) / 2 + envelope = ( + 1 + + c1 * ratio**exponent + + c2 * ratio ** (exponent + 1) + + c3 * ratio ** (exponent + 2) + ) + # r_cut = inner_cutoff + # r_on = inner_cutoff+1 + + # r_sq = r * r + # r_on_sq = r_on * r_on + # r_cut_sq = r_cut * r_cut + # envelope = (r_cut_sq - r_sq) + # ** 2 * (r_cut_sq + 2 * r_sq - 3 * r_on_sq)/ (r_cut_sq - r_on_sq) ** 3 + return torch.where(r <= inner_cutoff, envelope, torch.zeros_like(r)) diff --git a/alignn/tests/test_alignn_ff.py b/alignn/tests/test_alignn_ff.py index 740dbab..5621e5c 100644 --- a/alignn/tests/test_alignn_ff.py +++ b/alignn/tests/test_alignn_ff.py @@ -8,7 +8,7 @@ ForceField, get_interface_energy, ) 
-from alignn.graphs import Graph +from alignn.graphs import Graph, radius_graph_jarvis from alignn.ff.ff import phonons from jarvis.core.atoms import ase_to_atoms from jarvis.db.figshare import get_jid_data @@ -21,6 +21,31 @@ fd_path, ForceField, ) +from jarvis.io.vasp.inputs import Poscar + +# JVASP-25139 +pos = """Rb8 +1.0 +8.534892364405636 0.6983003603741366 -0.0 +-3.4905051320748712 7.819743736978101 -0.0 +0.0 -0.0 9.899741852856957 +Rb +8 +Cartesian +-0.48436620907024275 6.0395021169791425 0.0 +-0.48395379092975643 6.039257883020857 4.94987 +5.528746209070245 2.478537883020856 0.0 +5.528333790929757 2.478782116979143 4.94987 +1.264246578587533 2.1348318180359995 2.469410532600589 +1.2579434214124685 2.1241881819640005 7.419280532600588 +3.7864365785875354 6.393851818035999 2.4804594673994105 +3.7801334214124656 6.383208181964002 7.430329467399411 +""" + + +def test_radius_graph_jarvis(): + atoms = Poscar.from_string(pos).atoms + g, lg = radius_graph_jarvis(atoms=atoms) def test_alignnff(): diff --git a/alignn/train.py b/alignn/train.py index f4b094d..730d2da 100644 --- a/alignn/train.py +++ b/alignn/train.py @@ -1,9 +1,4 @@ -"""Ignite training script. - -from the repository root, run -`PYTHONPATH=$PYTHONPATH:. python alignn/train.py` -then `tensorboard --logdir tb_logs/test` to monitor results... -""" +"""Module for training script.""" from torch.nn.parallel import DistributedDataParallel as DDP from functools import partial @@ -11,7 +6,6 @@ import torch import random from sklearn.metrics import mean_absolute_error -from sklearn.metrics import log_loss import pickle as pk import numpy as np from torch import nn @@ -27,85 +21,18 @@ import warnings import time from sklearn.metrics import roc_auc_score +from alignn.utils import ( + # activated_output_transform, + # make_standard_scalar_and_pca, + # thresholded_output_transform, + group_decay, + setup_optimizer, + print_train_val_loss, +) -warnings.filterwarnings("ignore", category=RuntimeWarning) -# torch.set_default_dtype(torch.float32) - - -# def setup(rank, world_size): -# """Set up multi GPU rank.""" -# os.environ["MASTER_ADDR"] = "localhost" -# os.environ["MASTER_PORT"] = "12355" -# # Initialize the distributed environment. 
-# dist.init_process_group("nccl", rank=rank, world_size=world_size) -# torch.cuda.set_device(rank) - - -def activated_output_transform(output): - """Exponentiate output.""" - y_pred, y = output - y_pred = torch.exp(y_pred) - y_pred = y_pred[:, 1] - return y_pred, y - - -def make_standard_scalar_and_pca(output): - """Use standard scalar and PCS for multi-output data.""" - sc = pk.load(open(os.path.join(tmp_output_dir, "sc.pkl"), "rb")) - y_pred, y = output - y_pred = torch.tensor( - sc.transform(y_pred.cpu().numpy()), device=y_pred.device - ) - y = torch.tensor(sc.transform(y.cpu().numpy()), device=y.device) - # pc = pk.load(open("pca.pkl", "rb")) - # y_pred = torch.tensor(pc.transform(y_pred), device=device) - # y = torch.tensor(pc.transform(y), device=device) - # y_pred = torch.tensor(pca_sc.inverse_transform(y_pred),device=device) - # y = torch.tensor(pca_sc.inverse_transform(y),device=device) - # print (y.shape,y_pred.shape) - return y_pred, y - +# from sklearn.metrics import log_loss -def thresholded_output_transform(output): - """Round off output.""" - y_pred, y = output - y_pred = torch.round(torch.exp(y_pred)) - # print ('output',y_pred) - return y_pred, y - - -def group_decay(model): - """Omit weight decay from bias and batchnorm params.""" - decay, no_decay = [], [] - - for name, p in model.named_parameters(): - if "bias" in name or "bn" in name or "norm" in name: - no_decay.append(p) - else: - decay.append(p) - - return [ - {"params": decay}, - {"params": no_decay, "weight_decay": 0}, - ] - - -def setup_optimizer(params, config: TrainingConfig): - """Set up optimizer for param groups.""" - if config.optimizer == "adamw": - optimizer = torch.optim.AdamW( - params, - lr=config.learning_rate, - weight_decay=config.weight_decay, - ) - elif config.optimizer == "sgd": - optimizer = torch.optim.SGD( - params, - lr=config.learning_rate, - momentum=0.9, - weight_decay=config.weight_decay, - ) - return optimizer +warnings.filterwarnings("ignore", category=RuntimeWarning) def train_dgl( @@ -279,76 +206,6 @@ def train_dgl( or config.model.name == "alignn_ff2" ): - def get_batch_errors(dat=[]): - """Get errors for samples.""" - target_out = [] - pred_out = [] - grad = [] - atomw = [] - stress = [] - mean_out = 0 - mean_atom = 0 - mean_grad = 0 - mean_stress = 0 - # natoms_batch=False - # print ('lendat',len(dat)) - for i in dat: - if i["target_grad"]: - # if config.normalize_graph_level_loss: - # natoms_batch = 0 - for m, n in zip(i["target_grad"], i["pred_grad"]): - x = np.abs(np.array(m) - np.array(n)) - grad.append(np.mean(x)) - # if config.normalize_graph_level_loss: - # natoms_batch += np.array(i["pred_grad"]).shape[0] - if i["target_out"]: - for j, k in zip(i["target_out"], i["pred_out"]): - # if config.normalize_graph_level_loss and - # natoms_batch: - # j=j/natoms_batch - # k=k/natoms_batch - # if config.normalize_graph_level_loss and - # not natoms_batch: - # tmp = 'Add above in atomwise if not train grad.' 
- # raise ValueError(tmp) - - target_out.append(j) - pred_out.append(k) - if i["target_stress"]: - for p, q in zip(i["target_stress"], i["pred_stress"]): - x = np.abs(np.array(p) - np.array(q)) - stress.append(np.mean(x)) - if i["target_atomwise_pred"]: - for m, n in zip( - i["target_atomwise_pred"], i["pred_atomwise_pred"] - ): - x = np.abs(np.array(m) - np.array(n)) - atomw.append(np.mean(x)) - if "target_out" in i: - # if i["target_out"]: - target_out = np.array(target_out) - pred_out = np.array(pred_out) - # print('target_out',target_out,target_out.shape) - # print('pred_out',pred_out,pred_out.shape) - if classification: - mean_out = log_loss(target_out, pred_out) - else: - mean_out = mean_absolute_error(target_out, pred_out) - if "target_stress" in i: - # if i["target_stress"]: - mean_stress = np.array(stress).mean() - if "target_grad" in i: - # if i["target_grad"]: - mean_grad = np.array(grad).mean() - if "target_atomwise_pred" in i: - # if i["target_atomwise_pred"]: - mean_atom = np.array(atomw).mean() - # print ('natoms_batch',natoms_batch) - # if natoms_batch!=0: - # mean_out = mean_out/natoms_batch - # print ('dat',dat) - return mean_out, mean_atom, mean_grad, mean_stress - best_loss = np.inf criterion = nn.L1Loss() if classification: @@ -362,8 +219,11 @@ def get_batch_errors(dat=[]): # optimizer.zero_grad() train_init_time = time.time() running_loss = 0 + running_loss1 = 0 + running_loss2 = 0 + running_loss3 = 0 + running_loss4 = 0 train_result = [] - # for dats in train_loader: for dats, jid in zip(train_loader, train_loader.dataset.ids): info = {} # info["id"] = jid @@ -387,8 +247,6 @@ def get_batch_errors(dat=[]): loss3 = 0 # Such as forces loss4 = 0 # Such as stresses if config.model.output_features is not None: - # print('result["out"]',result["out"]) - # print('dats[2]',dats[2]) loss1 = config.model.graphwise_weight * criterion( result["out"], dats[-1].to(device), @@ -399,18 +257,7 @@ def get_batch_errors(dat=[]): info["pred_out"] = ( result["out"].cpu().detach().numpy().tolist() ) - # graphlevel_loss += np.mean( - # np.abs( - # dats[2].cpu().numpy() - # - result["out"].cpu().detach().numpy() - # ) - # ) - # print("target_out", info["target_out"][0]) - # print("pred_out", info["pred_out"][0]) - # print( - # "config.model.atomwise_output_features", - # config.model.atomwise_output_features, - # ) + running_loss1 += loss1.item() if ( config.model.atomwise_output_features > 0 # config.model.atomwise_output_features is not None @@ -426,12 +273,7 @@ def get_batch_errors(dat=[]): info["pred_atomwise_pred"] = ( result["atomwise_pred"].cpu().detach().numpy().tolist() ) - # atomlevel_loss += np.mean( - # np.abs( - # dats[0].ndata["atomwise_target"].cpu().numpy() - # - result["atomwise_pred"].cpu().detach().numpy() - # ) - # ) + running_loss2 += loss2.item() if config.model.calculate_gradient: loss3 = config.model.gradwise_weight * criterion( @@ -444,28 +286,8 @@ def get_batch_errors(dat=[]): info["pred_grad"] = ( result["grad"].cpu().detach().numpy().tolist() ) - # gradlevel_loss += np.mean( - # np.abs( - # dats[0].ndata["atomwise_grad"].cpu().numpy() - # - result["grad"].cpu().detach().numpy() - # ) - # ) - # print("target_grad", info["target_grad"][0]) - # print("pred_grad", info["pred_grad"][0]) + running_loss3 += loss3.item() if config.model.stresswise_weight != 0: - # print( - # 'result["stress"]', - # result["stresses"], - # result["stresses"].shape, - # ) - # print( - # 'dats[0].ndata["stresses"]', - # torch.cat(tuple(dats[0].ndata["stresses"])), - # 
dats[0].ndata["stresses"].shape, - # ) # ,torch.cat(dats[0].ndata["stresses"]), - # torch.cat(dats[0].ndata["stresses"]).shape) - # print('result["stresses"]',result["stresses"],result["stresses"].shape) - # print(dats[0].ndata["stresses"],dats[0].ndata["stresses"].shape) loss4 = config.model.stresswise_weight * criterion( (result["stresses"]).to(device), torch.cat(tuple(dats[0].ndata["stresses"])).to(device), @@ -480,6 +302,7 @@ def get_batch_errors(dat=[]): info["pred_stress"] = ( result["stresses"].cpu().detach().numpy().tolist() ) + running_loss4 += loss4.item() # print("target_stress", info["target_stress"][0]) # print("pred_stress", info["pred_stress"][0]) train_result.append(info) @@ -488,22 +311,35 @@ def get_batch_errors(dat=[]): optimizer.step() # optimizer.zero_grad() #never running_loss += loss.item() - mean_out, mean_atom, mean_grad, mean_stress = get_batch_errors( - train_result - ) + # mean_out, mean_atom, mean_grad, mean_stress = get_batch_errors( + # train_result + # ) # dumpjson(filename="Train_results.json", data=train_result) scheduler.step() train_final_time = time.time() train_ep_time = train_final_time - train_init_time # if rank == 0: # or world_size == 1: - history_train.append([mean_out, mean_atom, mean_grad, mean_stress]) + history_train.append( + [ + running_loss, + running_loss1, + running_loss2, + running_loss3, + running_loss4, + ] + ) dumpjson( filename=os.path.join(config.output_dir, "history_train.json"), data=history_train, ) val_loss = 0 + val_loss1 = 0 + val_loss2 = 0 + val_loss3 = 0 + val_loss4 = 0 val_result = [] # for dats in val_loader: + val_init_time = time.time() for dats, jid in zip(val_loader, val_loader.dataset.ids): info = {} info["id"] = jid @@ -534,6 +370,7 @@ def get_batch_errors(dat=[]): info["pred_out"] = ( result["out"].cpu().detach().numpy().tolist() ) + val_loss1 += loss1.item() if ( config.model.atomwise_output_features > 0 @@ -549,6 +386,7 @@ def get_batch_errors(dat=[]): info["pred_atomwise_pred"] = ( result["atomwise_pred"].cpu().detach().numpy().tolist() ) + val_loss2 += loss2.item() if config.model.calculate_gradient: loss3 = config.model.gradwise_weight * criterion( result["grad"].to(device), @@ -560,17 +398,13 @@ def get_batch_errors(dat=[]): info["pred_grad"] = ( result["grad"].cpu().detach().numpy().tolist() ) + val_loss3 += loss3.item() if config.model.stresswise_weight != 0: # loss4 = config.model.stresswise_weight * criterion( # result["stress"].to(device), # dats[0].ndata["stresses"][0].to(device), # ) loss4 = config.model.stresswise_weight * criterion( - # torch.flatten(result["stress"].to(device)), - # (dats[0].ndata["stresses"]).to(device), - # torch.flatten(dats[0].ndata["stresses"]).to(device), - # torch.flatten(torch.cat(dats[0].ndata["stresses"])).to(device), - # dats[0].ndata["stresses"][0].to(device), (result["stresses"]).to(device), torch.cat(tuple(dats[0].ndata["stresses"])).to(device), ) @@ -584,12 +418,15 @@ def get_batch_errors(dat=[]): info["pred_stress"] = ( result["stresses"].cpu().detach().numpy().tolist() ) + val_loss4 += loss4.item() loss = loss1 + loss2 + loss3 + loss4 val_result.append(info) val_loss += loss.item() - mean_out, mean_atom, mean_grad, mean_stress = get_batch_errors( - val_result - ) + # mean_out, mean_atom, mean_grad, mean_stress = get_batch_errors( + # val_result + # ) + val_fin_time = time.time() + val_ep_time = val_fin_time - val_init_time current_model_name = "current_model.pt" torch.save( net.state_dict(), @@ -618,46 +455,30 @@ def get_batch_errors(dat=[]): data=val_result, ) 
best_model = net
-        history_val.append([mean_out, mean_atom, mean_grad, mean_stress])
+        history_val.append(
+            [val_loss, val_loss1, val_loss2, val_loss3, val_loss4]
+        )
+        # history_val.append([mean_out, mean_atom, mean_grad, mean_stress])
         dumpjson(
             filename=os.path.join(config.output_dir, "history_val.json"),
             data=history_val,
         )
-        # print('rank',rank)
-        # print('world_size',world_size)
         if rank == 0:
-            print(
-                "TrainLoss",
-                "Epoch",
+            print_train_val_loss(
                 e,
-                "total",
                 running_loss,
-                "out",
-                mean_out,
-                "atom",
-                mean_atom,
-                "grad",
-                mean_grad,
-                "stress",
-                mean_stress,
-                "time",
-                train_ep_time,
-            )
-            print(
-                "ValLoss",
-                "Epoch",
-                e,
-                "total",
+                running_loss1,
+                running_loss2,
+                running_loss3,
+                running_loss4,
                 val_loss,
-                "out",
-                mean_out,
-                "atom",
-                mean_atom,
-                "grad",
-                mean_grad,
-                "stress",
-                mean_stress,
-                saving_msg,
+                val_loss1,
+                val_loss2,
+                val_loss3,
+                val_loss4,
+                train_ep_time,
+                val_ep_time,
+                saving_msg=saving_msg,
             )
 
     if rank == 0 or world_size == 1:
@@ -668,10 +489,6 @@ def get_batch_errors(dat=[]):
             info = {}
             info["id"] = jid
             optimizer.zero_grad()
-            # print('dats[0]',dats[0])
-            # print('test_loader',test_loader)
-            # print('test_loader.dataset.ids',test_loader.dataset.ids)
-            # result = net([dats[0].to(device), dats[1].to(device)])
             if (config.model.alignn_layers) > 0:
                 result = net([dats[0].to(device), dats[1].to(device)])
             else:
@@ -719,18 +536,9 @@ def get_batch_errors(dat=[]):
                 )
                 if config.model.stresswise_weight != 0:
                     loss4 = config.model.stresswise_weight * criterion(
-                        # torch.flatten(result["stress"].to(device)),
-                        # (dats[0].ndata["stresses"]).to(device),
-                        # torch.flatten(dats[0].ndata["stresses"]).to(device),
                         result["stresses"].to(device),
                         torch.cat(tuple(dats[0].ndata["stresses"])).to(device),
-                        # torch.flatten(torch.cat(dats[0].ndata["stresses"])).to(device),
-                        # dats[0].ndata["stresses"][0].to(device),
                     )
-                    # loss4 = config.model.stresswise_weight * criterion(
-                    #     result["stress"][0].to(device),
-                    #     dats[0].ndata["stresses"].to(device),
-                    # )
                     info["target_stress"] = (
                         torch.cat(tuple(dats[0].ndata["stresses"]))
                         .cpu()
diff --git a/alignn/utils.py b/alignn/utils.py
index cc6e57b..840cdf3 100644
--- a/alignn/utils.py
+++ b/alignn/utils.py
@@ -4,8 +4,10 @@
 from pathlib import Path
 from typing import Union
 import matplotlib.pyplot as plt
-
 from pydantic_settings import BaseSettings as PydanticBaseSettings
+import torch
+import pickle as pk
+import os
 
 
 class BaseSettings(PydanticBaseSettings):
@@ -43,3 +45,140 @@ def plot_learning_curve(
     plt.ylabel(key)
 
     return train, val
+
+
+def activated_output_transform(output):
+    """Exponentiate output."""
+    y_pred, y = output
+    y_pred = torch.exp(y_pred)
+    y_pred = y_pred[:, 1]
+    return y_pred, y
+
+
+def make_standard_scalar_and_pca(output, tmp_output_dir="out"):
+    """Use standard scaler and PCA for multi-output data."""
+    sc = pk.load(open(os.path.join(tmp_output_dir, "sc.pkl"), "rb"))
+    y_pred, y = output
+    y_pred = torch.tensor(
+        sc.transform(y_pred.cpu().numpy()), device=y_pred.device
+    )
+    y = torch.tensor(sc.transform(y.cpu().numpy()), device=y.device)
+    return y_pred, y
+
+
+def thresholded_output_transform(output):
+    """Round off output."""
+    y_pred, y = output
+    y_pred = torch.round(torch.exp(y_pred))
+    # print ('output',y_pred)
+    return y_pred, y
+
+
+def group_decay(model):
+    """Omit weight decay from bias and batchnorm params."""
+    decay, no_decay = [], []
+
+    for name, p in model.named_parameters():
+        if "bias" in name or "bn" in name or "norm" in name:
+            no_decay.append(p)
+        else:
+            decay.append(p)
+
+    
return [ + {"params": decay}, + {"params": no_decay, "weight_decay": 0}, + ] + + +def setup_optimizer(params, config): + """Set up optimizer for param groups.""" + if config.optimizer == "adamw": + optimizer = torch.optim.AdamW( + params, + lr=config.learning_rate, + weight_decay=config.weight_decay, + ) + elif config.optimizer == "sgd": + optimizer = torch.optim.SGD( + params, + lr=config.learning_rate, + momentum=0.9, + weight_decay=config.weight_decay, + ) + return optimizer + + +def print_train_val_loss( + e, + running_loss, + running_loss1, + running_loss2, + running_loss3, + running_loss4, + val_loss, + val_loss1, + val_loss2, + val_loss3, + val_loss4, + train_ep_time, + val_ep_time, + saving_msg="", +): + """Train loss header.""" + header = ("{:<12} {:<8} {:<10} {:<10} {:<10} {:<10} {:<10} {:<10}").format( + "Train Loss:", + "Epoch", + "Total", + "Graph", + "Atom", + "Grad", + "Stress", + "Time", + ) + print(header) + + # Train loss values + train_row = ( + "{:<12} {:<8} {:<10.4f} {:<10.4f} {:<10.4f} {:<10.4f} {:<10.4f} " + "{:<10.2f}" + ).format( + "", + e, + running_loss, + running_loss1, + running_loss2, + running_loss3, + running_loss4, + train_ep_time, + ) + print(train_row) + + # Validation loss header + header = ("{:<12} {:<8} {:<10} {:<10} {:<10} {:<10} {:<10} {:<10}").format( + "Val Loss:", + "Epoch", + "Total", + "Graph", + "Atom", + "Grad", + "Stress", + "Time", + ) + print(header) + + # Validation loss values + val_row = ( + "{:<12} {:<8} {:<10.4f} {:<10.4f} {:<10.4f} {:<10.4f} {:<10.4f} " + "{:<10.2f} {:<10}" + ).format( + "", + e, + val_loss, + val_loss1, + val_loss2, + val_loss3, + val_loss4, + val_ep_time, + saving_msg, + ) + print(val_row) From 2ef5930b9c55e0e01d985a1895981c89fee61396 Mon Sep 17 00:00:00 2001 From: knc6 Date: Sun, 27 Oct 2024 18:13:40 -0400 Subject: [PATCH 06/37] Add images in all graphs. 
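
This commit stores the integer periodic-image index for every edge under
edata["images"] in all graph constructors, so downstream code can rebuild
bond vectors from coordinates, lattice, and image translations instead of
relying only on precomputed Cartesian vectors. A minimal sketch, not part
of the patch and using an invented cell, of how such an offset recovers the
true periodic bond vector (the same contraction the pbc_offshift lines in
later commits perform):

import numpy as np

lattice = np.diag([4.0, 4.0, 4.0])  # toy cubic cell, Angstrom
cart_coords = np.array([[0.0, 0.0, 0.0], [3.5, 0.0, 0.0]])
u, v = np.array([0]), np.array([1])  # one edge: atom 0 -> atom 1
images = np.array([[-1, 0, 0]])  # neighbor is the -x periodic copy

pbc_offset = images @ lattice  # integer image -> Cartesian shift
bond_vec = cart_coords[v] + pbc_offset - cart_coords[u]
print(np.linalg.norm(bond_vec, axis=1))  # [0.5]: closer than the in-cell 3.5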
--- .github/workflows/main.yml | 2 ++ alignn/graphs.py | 42 +++++++++++++++++++++++++------------ alignn/models/alignn_ff2.py | 4 +++- 3 files changed, 34 insertions(+), 14 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index db27c63..ab955e3 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -39,6 +39,8 @@ jobs: export DGLBACKEND=pytorch export CUDA_VISIBLE_DEVICES="-1" pip install phonopy flake8 pytest pycodestyle pydocstyle codecov pytest-cov coverage + pip uninstall jarvis-tools -y + pip install -q git+https://github.com/usnistgov/jarvis.git@develop python setup.py develop echo 'environment.yml' conda env export diff --git a/alignn/graphs.py b/alignn/graphs.py index 61001f3..6d34df0 100644 --- a/alignn/graphs.py +++ b/alignn/graphs.py @@ -203,7 +203,7 @@ def nearest_neighbor_edges( else: edges[(site_idx, dst)].add(tuple(image)) - return edges + return edges, images def build_undirected_edgedata( @@ -217,7 +217,7 @@ def build_undirected_edgedata( """ # second pass: construct *undirected* graph # import pprint - u, v, r = [], [], [] + u, v, r, all_images = [], [], [], [] for (src_id, dst_id), images in edges.items(): for dst_image in images: # fractional coordinate for periodic image of dst @@ -233,12 +233,14 @@ def build_undirected_edgedata( u.append(uu) v.append(vv) r.append(dd) + all_images.append(dst_image) u, v, r = (np.array(x) for x in (u, v, r)) u = torch.tensor(u) v = torch.tensor(v) r = torch.tensor(r).type(torch.get_default_dtype()) + all_images = torch.tensor(all_images).type(torch.get_default_dtype()) - return u, v, r + return u, v, r, all_images def radius_graph( @@ -284,6 +286,8 @@ def temp_graph(cutoff=5): # tile periodic images into X_dst # index id_dst into X_dst maps to atom id as id_dest % num_atoms X_dst = (cell_images @ lattice_mat)[:, None, :] + X_src + # cell_images = cell_images[:,None,:]+cell_images + # print('cell_images',cell_images,cell_images.shape) X_dst = X_dst.reshape(-1, 3) # pairwise distances between atoms in (0,0,0) cell # and atoms in all periodic image @@ -300,11 +304,15 @@ def temp_graph(cutoff=5): atol=atol, ), ) + # get node indices for edgelist from neighbor mask u, v = torch.where(neighbor_mask) + # cell_images=cell_images[neighbor_mask] + # u, v = torch.where(neighbor_mask) # print("u2v2", u, v, u.shape, v.shape) # print("v1", v, v.shape) # print("v2", v % num_atoms, (v % num_atoms).shape) + cell_images = cell_images[v // num_atoms] r = (X_dst[v] - X_src[u]).float() # gk = dgl.knn_graph(X_dst, 12) @@ -312,22 +320,22 @@ def temp_graph(cutoff=5): # print("gk", gk) v = v % num_atoms g = dgl.graph((u, v)) - return g, u, v, r + return g, u, v, r, cell_images - g, u, v, r = temp_graph(cutoff) + g, u, v, r, cell_images = temp_graph(cutoff) while (g.num_nodes()) != len(atoms.elements): try: cutoff += cutoff_extra - g, u, v, r = temp_graph(cutoff) + g, u, v, r, cell_images = temp_graph(cutoff) print("cutoff", id, cutoff) print(atoms) except Exception as exp: print("Graph exp", exp) pass - return u, v, r + return u, v, r, cell_images - return u, v, r + return u, v, r, cell_images ### @@ -454,19 +462,19 @@ def atom_dgl_multigraph( # print('id',id) # print('stratgery', neighbor_strategy) if neighbor_strategy == "k-nearest": - edges = nearest_neighbor_edges( + edges, images = nearest_neighbor_edges( atoms=atoms, cutoff=cutoff, max_neighbors=max_neighbors, id=id, use_canonize=use_canonize, ) - u, v, r = build_undirected_edgedata(atoms, edges) + u, v, r, images = build_undirected_edgedata(atoms, 
edges) elif neighbor_strategy == "radius_graph": # print('HERE') # import sys # sys.exit() - u, v, r = radius_graph( + u, v, r, images = radius_graph( atoms, cutoff=cutoff, cutoff_extra=cutoff_extra ) elif neighbor_strategy == "radius_graph_jarvis": @@ -498,10 +506,18 @@ def atom_dgl_multigraph( ) g = dgl.graph((u, v)) g.ndata["atom_features"] = node_features - g.edata["r"] = r + g.edata["r"] = torch.tensor(r).type(torch.get_default_dtype()) + # images=torch.tensor(images).type(torch.get_default_dtype()) + # print('images',images.shape,r.shape) + # print('type',torch.get_default_dtype()) + g.edata["images"] = torch.tensor(images).type( + torch.get_default_dtype() + ) vol = atoms.volume g.ndata["V"] = torch.tensor([vol for ii in range(atoms.num_atoms)]) - g.ndata["coords"] = torch.tensor(atoms.cart_coords) + g.ndata["coords"] = torch.tensor(atoms.cart_coords).type( + torch.get_default_dtype() + ) if use_lattice_prop: lattice_prop = np.array( [atoms.lattice.lat_lengths(), atoms.lattice.lat_angles()] diff --git a/alignn/models/alignn_ff2.py b/alignn/models/alignn_ff2.py index 9041b13..40d8db1 100644 --- a/alignn/models/alignn_ff2.py +++ b/alignn/models/alignn_ff2.py @@ -199,6 +199,7 @@ def __init__(self, in_features: int, out_features: int): def forward(self, x): """Linear, Batchnorm, silu layer.""" + # print('xtype',x.dtype) return self.layer(x) @@ -277,7 +278,7 @@ def __init__( else: self.angle_embedding = nn.Sequential( RBFExpansion( - vmin=-1, + vmin=-1.0, vmax=1.0, bins=config.triplet_input_features, ), @@ -368,6 +369,7 @@ def forward( g.edata["d"] = bondlength g.edata["r"] = r # bond_expansion = self.bond_expansion(bondlength) + # z = self.angle_embedding(lg.edata.pop("h")) lg = check_line_graph(g, lg, self.config.inner_cutoff) lg.apply_edges(compute_bond_cosines) From 03e12e9213709e4adc1b4a626b218b4247036a2f Mon Sep 17 00:00:00 2001 From: knc6 Date: Tue, 29 Oct 2024 02:09:01 -0400 Subject: [PATCH 07/37] External. 
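
The lattice is now carried alongside every sample (dataset __getitem__,
both collate helpers, and the LMDB records) so a Potential-style forward
can place atoms from fractional coordinates and differentiate through a
strained cell. A self-contained sketch of that strain trick, assuming only
torch and with a pairwise distance standing in for a real predicted energy:

import torch

lat = 4.0 * torch.eye(3)  # toy cubic lattice
frac = torch.tensor([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]])
st = torch.zeros(3, 3, requires_grad=True)  # strain tensor
lattice = lat @ (torch.eye(3) + st)  # strained cell
pos = frac @ lattice  # Cartesian positions follow the strain
energy = torch.norm(pos[1] - pos[0])  # stand-in for a model energy
(de_dst,) = torch.autograd.grad(energy, st)
print(de_dst / torch.det(lattice))  # scale by 1/volume to get a stress

The AtomWise.forward added below follows the same pattern, with the
network's total energy in place of the toy distance.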
--- .../config_example_atomwise.json | 1 + alignn/ff/ff.py | 7 +- alignn/graphs.py | 36 ++- alignn/lmdb_dataset.py | 26 ++- alignn/models/alignn_ff2.py | 207 +++++++++++++++--- alignn/models/utils.py | 5 +- alignn/train.py | 12 +- 7 files changed, 241 insertions(+), 53 deletions(-) diff --git a/alignn/examples/sample_data_ff/config_example_atomwise.json b/alignn/examples/sample_data_ff/config_example_atomwise.json index 0fae17d..baf5198 100644 --- a/alignn/examples/sample_data_ff/config_example_atomwise.json +++ b/alignn/examples/sample_data_ff/config_example_atomwise.json @@ -51,6 +51,7 @@ "graphwise_weight":0.85, "gradwise_weight":0.05, "atomwise_weight":0.0, + "use_cutoff_function":false, "stresswise_weight":0.05, "add_reverse_forces":true diff --git a/alignn/ff/ff.py b/alignn/ff/ff.py index aa07b0f..de851f2 100644 --- a/alignn/ff/ff.py +++ b/alignn/ff/ff.py @@ -30,6 +30,8 @@ from jarvis.db.jsonutils import loadjson from alignn.graphs import Graph from alignn.models.alignn_atomwise import ALIGNNAtomWise, ALIGNNAtomWiseConfig +from alignn.models.alignn_ff2 import ALIGNNFF2, ALIGNNFF2Config + from jarvis.analysis.defects.vacancy import Vacancy import numpy as np from alignn.pretrained import get_prediction @@ -278,7 +280,10 @@ def __init__( self.device = torch.device( "cuda" if torch.cuda.is_available() else "cpu" ) - model = ALIGNNAtomWise(ALIGNNAtomWiseConfig(**config["model"])) + if config['model']['name']=='alignn_ff2': + model = ALIGNNFF2(ALIGNNFF2Config(**config["model"])) + if config['model']['name']=='alignn_atomwise': + model = ALIGNNAtomWise(ALIGNNAtomWiseConfig(**config["model"])) model.state_dict() model.load_state_dict( torch.load( diff --git a/alignn/graphs.py b/alignn/graphs.py index 6d34df0..6471dfc 100644 --- a/alignn/graphs.py +++ b/alignn/graphs.py @@ -57,8 +57,20 @@ def temp_graph(atoms=None, cutoff=4.0, atom_features="cgcnn", dtype="float32"): g.ndata["atom_features"] = torch.tensor(atom_feats, dtype=dtype) g.edata["r"] = torch.tensor(r, dtype=dtype) g.edata["d"] = torch.tensor(d, dtype=dtype) + g.edata["pbc_offset"] = torch.tensor(images, dtype=dtype) g.edata["images"] = torch.tensor(images, dtype=dtype) - g.ndata["coords"] = torch.tensor(atoms.cart_coords, dtype=dtype) + #g.edata["lattice"] = torch.tensor(torch.repeat_interleave(torch.tensor(atoms.lattice_mat.flatten()), atoms.num_atoms), dtype=dtype) + node_type=torch.tensor([0 for i in range(len(atoms.atomic_numbers))]) + g.ndata['node_type']=node_type + lattice_mat = atoms.lattice_mat + g.ndata["lattice"] = torch.tensor( + [lattice_mat for ii in range(g.num_nodes())] + , dtype=dtype) + g.edata["lattice"] = torch.tensor( + [lattice_mat for ii in range(g.num_edges())] + , dtype=dtype) + #g.ndata["coords"] = torch.tensor(atoms.cart_coords, dtype=dtype) + g.ndata["frac_coords"] = torch.tensor(atoms.frac_coords, dtype=dtype) g.ndata["V"] = torch.tensor([atoms.volume] * atoms.num_atoms, dtype=dtype) return g, u, v, r @@ -456,7 +468,7 @@ def atom_dgl_multigraph( # use_canonize: bool = False, use_lattice_prop: bool = False, cutoff_extra=3.5, - dtype=torch.float32, + dtype="float32", ): """Obtain a DGLGraph for Atoms object.""" # print('id',id) @@ -835,6 +847,7 @@ def __init__( classification=False, id_tag="jid", sampler=None, + lattices=None, dtype="float32", ): """Pytorch Dataset for atomistic graphs. 
@@ -853,6 +866,7 @@ def __init__(
         self.target_stress = target_stress
         self.line_graph = line_graph
         print("df", df)
+        self.lattices = lattices
         self.labels = self.df[target]
 
         if (
@@ -897,6 +911,11 @@ def __init__(
             self.labels = torch.tensor(self.df[target]).type(
                 torch.get_default_dtype()
             )
+            self.lattices = []
+            for ii, i in df.iterrows():
+                self.lattices.append(Atoms.from_dict(i["atoms"]).lattice_mat)
+
+            self.lattices = torch.tensor(self.lattices).type(torch.get_default_dtype())
         self.transform = transform
 
         features = self._get_attribute_lookup(atom_features)
@@ -972,14 +991,15 @@ def __getitem__(self, idx):
         """Get StructureDataset sample."""
         g = self.graphs[idx]
         label = self.labels[idx]
+        lattice = self.lattices[idx]
         # id = self.ids[idx]
         if self.transform:
             g = self.transform(g)
 
         if self.line_graph:
-            return g, self.line_graphs[idx], label
+            return g, self.line_graphs[idx], lattice, label
 
-        return g, label
+        return g, lattice, label
 
     def setup_standardizer(self, ids):
         """Atom-wise feature standardization transform."""
@@ -1000,22 +1020,22 @@ def setup_standardizer(self, ids):
     @staticmethod
     def collate(samples: List[Tuple[dgl.DGLGraph, torch.Tensor]]):
         """Dataloader helper to batch graphs cross `samples`."""
-        graphs, labels = map(list, zip(*samples))
+        graphs, lattices, labels = map(list, zip(*samples))
         batched_graph = dgl.batch(graphs)
-        return batched_graph, torch.tensor(labels)
+        return batched_graph, torch.stack(lattices), torch.tensor(labels)
 
     @staticmethod
     def collate_line_graph(
         samples: List[Tuple[dgl.DGLGraph, dgl.DGLGraph, torch.Tensor]]
     ):
         """Dataloader helper to batch graphs cross `samples`."""
-        graphs, line_graphs, labels = map(list, zip(*samples))
+        graphs, line_graphs, lattices, labels = map(list, zip(*samples))
         batched_graph = dgl.batch(graphs)
         batched_line_graph = dgl.batch(line_graphs)
         if len(labels[0].size()) > 0:
-            return batched_graph, batched_line_graph, torch.stack(labels)
+            return batched_graph, batched_line_graph, torch.stack(lattices), torch.stack(labels)
         else:
-            return batched_graph, batched_line_graph, torch.tensor(labels)
+            return batched_graph, batched_line_graph, torch.stack(lattices), torch.tensor(labels)
 
 
 """
diff --git a/alignn/lmdb_dataset.py b/alignn/lmdb_dataset.py
index 7225788..9dbd9cc 100644
--- a/alignn/lmdb_dataset.py
+++ b/alignn/lmdb_dataset.py
@@ -58,11 +58,11 @@ def __getitem__(self, idx):
         with self.env.begin() as txn:
             serialized_data = txn.get(f"{idx}".encode())
         if self.line_graph:
-            graph, line_graph, label = pk.loads(serialized_data)
-            return graph, line_graph, label
+            graph, line_graph, lattice, label = pk.loads(serialized_data)
+            return graph, line_graph, lattice, label
         else:
-            graph, label = pk.loads(serialized_data)
-            return graph, label
+            graph, lattice, label = pk.loads(serialized_data)
+            return graph, lattice, label
 
     def close(self):
         """Close connection."""
@@ -76,23 +76,23 @@ def __del__(self):
     def collate(samples: List[Tuple[dgl.DGLGraph, torch.Tensor]]):
         """Dataloader helper to batch graphs cross `samples`."""
         # print('samples',samples)
-        graphs, labels = map(list, zip(*samples))
+        graphs, lattices, labels = map(list, zip(*samples))
         # graphs, lgs, labels = map(list, zip(*samples))
         batched_graph = dgl.batch(graphs)
-        return batched_graph, torch.tensor(labels)
+        return batched_graph, torch.stack(lattices), torch.tensor(labels)
 
     @staticmethod
     def collate_line_graph(
         samples: List[Tuple[dgl.DGLGraph, dgl.DGLGraph, torch.Tensor]]
     ):
         """Dataloader helper to batch graphs cross `samples`."""
-        graphs, line_graphs, labels = map(list, zip(*samples))
+        graphs, line_graphs, lattices, labels = map(list, zip(*samples))
         batched_graph =
dgl.batch(graphs)
         batched_line_graph = dgl.batch(line_graphs)
         if len(labels[0].size()) > 0:
-            return batched_graph, batched_line_graph, torch.stack(labels)
+            return batched_graph, batched_line_graph, torch.stack(lattices), torch.stack(labels)
         else:
-            return batched_graph, batched_line_graph, torch.tensor(labels)
+            return batched_graph, batched_line_graph, torch.stack(lattices), torch.tensor(labels)
 
 
 def get_torch_dataset(
@@ -143,8 +143,9 @@ def get_torch_dataset(
     for idx, (d) in tqdm(enumerate(dataset), total=len(dataset)):
         ids.append(d[id_tag])
         # g, lg = Graph.atom_dgl_multigraph(
+        atoms = Atoms.from_dict(d["atoms"])
         g = Graph.atom_dgl_multigraph(
-            Atoms.from_dict(d["atoms"]),
+            atoms,
             cutoff=float(cutoff),
             max_neighbors=max_neighbors,
             atom_features=atom_features,
@@ -156,6 +157,7 @@ def get_torch_dataset(
         )
         if line_graph:
             g, lg = g
+        lattice = torch.tensor(atoms.lattice_mat).type(torch.get_default_dtype())
         label = torch.tensor(d[target]).type(torch.get_default_dtype())
         # print('label',label,label.view(-1).long())
         if classification:
@@ -182,9 +184,9 @@ def get_torch_dataset(
 
         # labels.append(label)
         if line_graph:
-            serialized_data = pk.dumps((g, lg, label))
+            serialized_data = pk.dumps((g, lg, lattice, label))
         else:
-            serialized_data = pk.dumps((g, label))
+            serialized_data = pk.dumps((g, lattice, label))
         txn.put(f"{idx}".encode(), serialized_data)
     env.close()
diff --git a/alignn/models/alignn_ff2.py b/alignn/models/alignn_ff2.py
index 40d8db1..f4acd8c 100644
--- a/alignn/models/alignn_ff2.py
+++ b/alignn/models/alignn_ff2.py
@@ -23,6 +23,8 @@
 )
 from alignn.graphs import compute_bond_cosines
 from alignn.utils import BaseSettings
+from matgl.layers._basis import RadialBesselFunction, FourierExpansion
+from matgl.layers import MLP_norm
 
 torch.autograd.set_detect_anomaly(True)
 
@@ -185,6 +187,154 @@ def forward(
 
         return x, y, z
 
+
+class AtomWise(nn.Module):
+    """A class representing an interatomic potential."""
+
+    __version__ = 3
+
+    def __init__(
+        self,
+        model: nn.Module,
+        data_mean: torch.Tensor | float = 0.0,
+        data_std: torch.Tensor | float = 1.0,
+        # element_refs: np.ndarray | None = None,
+        calc_forces: bool = True,
+        calc_stresses: bool = True,
+        calc_hessian: bool = False,
+        calc_magmom: bool = False,
+        calc_repuls: bool = False,
+        zbl_trainable: bool = False,
+        debug_mode: bool = False,
+    ):
+        """Initialize Potential from a model and elemental references.
+
+        Args:
+            model: Model for predicting energies.
+            data_mean: Mean of target.
+            data_std: Std dev of target.
+            element_refs: Element reference values for each element.
+            calc_forces: Enable force calculations.
+            calc_stresses: Enable stress calculations.
+            calc_hessian: Enable hessian calculations.
+            calc_magmom: Enable site-wise property calculation.
+ calc_repuls: Whether the ZBL repulsion is included + zbl_trainable: Whether zbl repulsion is trainable + debug_mode: Return gradient of total energy with respect to atomic positions and lattices for checking + """ + super().__init__() + self.save_args(locals()) + self.model = model + self.calc_forces = calc_forces + self.calc_stresses = calc_stresses + self.calc_hessian = calc_hessian + self.calc_magmom = calc_magmom + #self.element_refs: AtomRef | None + self.debug_mode = debug_mode + self.calc_repuls = calc_repuls + + if calc_repuls: + self.repuls = NuclearRepulsion(self.model.cutoff, trainable=zbl_trainable) + + if element_refs is not None: + self.element_refs = AtomRef(property_offset=torch.tensor(element_refs, dtype=matgl.float_th)) + else: + self.element_refs = None + # for backward compatibility + if data_mean is None: + data_mean = 0.0 + self.register_buffer("data_mean", torch.tensor(data_mean, dtype=matgl.float_th)) + self.register_buffer("data_std", torch.tensor(data_std, dtype=matgl.float_th)) + + def forward( + self, + g: dgl.DGLGraph, + lat: torch.Tensor, + lg: dgl.DGLGraph | None = None, + ) -> tuple[torch.Tensor, ...]: + """Args: + g: DGL graph + lat: lattice + state_attr: State attrs + l_g: Line graph. + + Returns: + (energies, forces, stresses, hessian) or (energies, forces, stresses, hessian, site-wise properties) + """ + # st (strain) for stress calculations + result = {} + #st = lat.new_zeros([g.batch_size, 3, 3]) + #if self.calc_stresses: + # st.requires_grad_(True) + lattice = lat @ (torch.eye(3, device=lat.device) + st) + g.edata["lattice"] = torch.repeat_interleave(lattice, g.batch_num_edges(), dim=0) + g.edata["pbc_offshift"] = (g.edata["images"].unsqueeze(dim=-1) * g.edata["lattice"]).sum(dim=1) + g.ndata["pos"] = ( + g.ndata["frac_coords"].unsqueeze(dim=-1) * torch.repeat_interleave(lattice, g.batch_num_nodes(), dim=0) + ).sum(dim=1) + if self.calc_forces: + g.ndata["pos"].requires_grad_(True) + + total_energies = self.model(g,lg) + + total_energies = self.data_std * total_energies + self.data_mean + + if self.calc_repuls: + total_energies += self.repuls(self.model.element_types, g) + + if self.element_refs is not None: + property_offset = torch.squeeze(self.element_refs(g)) + total_energies += property_offset + + forces = torch.zeros(1) + stresses = torch.zeros(1) + hessian = torch.zeros(1) + + grad_vars = [g.ndata["pos"], st] if self.calc_stresses else [g.ndata["pos"]] + + if self.calc_forces: + grads = grad( + total_energies, + grad_vars, + grad_outputs=torch.ones_like(total_energies), + create_graph=True, + retain_graph=True, + ) + forces = -grads[0] + + if self.calc_hessian: + r = -grads[0].view(-1) + s = r.size(0) + hessian = total_energies.new_zeros((s, s)) + for iatom in range(s): + tmp = grad([r[iatom]], g.ndata["pos"], retain_graph=iatom < s)[0] + if tmp is not None: + hessian[iatom] = tmp.view(-1) + + if self.calc_stresses: + volume = ( + torch.abs(torch.det(lattice.float())).half() + if matgl.float_th == torch.float16 + else torch.abs(torch.det(lattice)) + ) + sts = -grads[1] + scale = 1.0 / volume * -160.21766208 + sts = [i * j for i, j in zip(sts, scale)] if sts.dim() == 3 else [sts * scale] + stresses = torch.cat(sts) + + if self.debug_mode: + return total_energies, grads[0], grads[1] + + if self.calc_magmom: + return total_energies, forces, stresses, hessian, g.ndata["magmom"] + result['out']=total_energies + result['grad']=forces + result['stresses']=stresses + result['atomwise_pred']=atomwise_pred + + return result + + class 
MLPLayer(nn.Module): """Multilayer perceptron layer helper.""" @@ -227,18 +377,10 @@ def __init__( config.atom_input_features, config.hidden_features ) if self.config.bond_exp_basis == "bessel": - self.edge_embedding = nn.Sequential( - BesselExpansion( - # RadialBesselFunction( - vmin=0, - vmax=8.0, - bins=config.edge_input_features, - ), - MLPLayer( - config.edge_input_features, config.embedding_features - ), - MLPLayer(config.embedding_features, config.hidden_features), - ) + self.bond_expansion = RadialBesselFunction(max_n=config.max_n, cutoff=config.inner_cutoff, learnable=False) + self.edge_embedding = MLP_norm([config.max_n, config.hidden_features],bias_last=False) + #self.edge_embedding = MLP_norm([config.edge_input_features, config.hidden_features],bias_last=False) + #self.bond_expansion = RadialBesselFunction(max_n=config.edge_input_features, cutoff=config.inner_cutoff, learnable=True) else: self.edge_embedding = nn.Sequential( RBFExpansion( @@ -260,13 +402,8 @@ def __init__( MLPLayer(config.embedding_features, config.hidden_features), ) # not tested elif self.config.angle_exp_basis == "bessel": - self.angle_embedding = nn.Sequential( - BesselExpansion(), - MLPLayer( - config.triplet_input_features, config.embedding_features - ), - MLPLayer(config.embedding_features, config.hidden_features), - ) # not tested + self.angle_expansion = FourierExpansion(max_f=config.max_f, learnable=False) + self.angle_embedding = MLP_norm([2*config.max_f+1, config.hidden_features],bias_last=False) elif self.config.angle_exp_basis == "fourier": self.angle_embedding = nn.Sequential( FourierExpansion(), @@ -360,11 +497,14 @@ def forward( features = self.extra_feature_embedding(features) x = g.ndata.pop("atom_features") x = self.atom_embedding(x) - # r=g.edata['r'] - r, bondlength = compute_pair_vector_and_distance(g) + check_lg=True + g.edata["pbc_offshift"] = (g.edata["images"].unsqueeze(dim=-1) * g.edata["lattice"]).sum(dim=1) + g.ndata["cart_coords"] = (g.ndata["frac_coords"].unsqueeze(dim=-1) * g.ndata["lattice"][0]).sum(dim=1) if self.config.calculate_gradient: - r.requires_grad_(True) - # print('gradient') + #g.edata["images"] = (g.edata["images"].unsqueeze(dim=-1) * g.edata["lattice"]).sum(dim=1) + #torch.repeat_interleave(lattice, g.batch_num_nodes(), dim=0)).sum(dim=1) + g.ndata["cart_coords"].requires_grad_(True) + r, bondlength = compute_pair_vector_and_distance(g) bondlength = torch.norm(r, dim=1) g.edata["d"] = bondlength g.edata["r"] = r @@ -381,9 +521,13 @@ def forward( # bond_expansion # smooth_cutoff * bond_expansion # ) - # y = self.edge_embedding(bondlength) - z = self.angle_embedding(lg.edata.pop("h")) - + if self.config.bond_exp_basis=='bessel': + z = self.angle_embedding(self.angle_expansion(lg.edata.pop("h"))) + y = self.edge_embedding(self.bond_expansion(bondlength)) + else: + y = self.edge_embedding(bondlength) + z = self.angle_embedding(lg.edata.pop("h")) + if self.config.use_cutoff_function: if self.config.multiply_cutoff: c_off = cutoff_function_based_edges( @@ -401,7 +545,9 @@ def forward( ) y = self.edge_embedding(bondlength) else: + #print('bondlength',bondlength,bondlength.shape) y = self.edge_embedding(bondlength) + #""" # ALIGNN updates: update node, edge, triplet features for alignn_layer in self.alignn_layers: x, y, z = alignn_layer(g, lg, x, y, z) @@ -436,6 +582,13 @@ def forward( if self.config.calculate_gradient: en_out = out + #g.edata["images"] = (g.edata["images"].unsqueeze(dim=-1) * g.edata["lattice"]).sum(dim=1) + #torch.repeat_interleave(lattice, 
g.batch_num_nodes(), dim=0)).sum(dim=1) + #g.ndata["cart_coords"].requires_grad_(True) + grad_vars = [g.ndata["cart_coords"]] + grads = grad(g.num_nodes()*en_out,grad_vars,grad_outputs=torch.ones_like(en_out),create_graph=True,retain_graph=True) + forces_out = -grads[0] + # force calculation based on bond displacement vectors # autograd gives dE / d{r_{i->j}} pair_forces = ( @@ -539,7 +692,7 @@ def forward( # out = torch.max(out,dim=1) out = self.softmax(out) result["out"] = out - result["grad"] = forces + result["grad"] = forces_out result["stresses"] = stress result["atomwise_pred"] = atomwise_pred # print(result) diff --git a/alignn/models/utils.py b/alignn/models/utils.py index b8937f5..af7e3b5 100644 --- a/alignn/models/utils.py +++ b/alignn/models/utils.py @@ -267,8 +267,9 @@ def _k(self, l_x: int, m: int) -> float: def compute_pair_vector_and_distance(g: dgl.DGLGraph): """Calculate bond vectors and distances using dgl graphs.""" - dst_pos = g.ndata["coords"][g.edges()[1]] + g.edata["images"] - src_pos = g.ndata["coords"][g.edges()[0]] + #print('g.edges()',g.ndata["cart_coords"][g.edges()[1]].shape,g.edata["pbc_offshift"].shape) + dst_pos = g.ndata["cart_coords"][g.edges()[1]] + g.edata["pbc_offshift"] + src_pos = g.ndata["cart_coords"][g.edges()[0]] bond_vec = dst_pos - src_pos bond_dist = torch.norm(bond_vec, dim=1) diff --git a/alignn/train.py b/alignn/train.py index 730d2da..24c925e 100644 --- a/alignn/train.py +++ b/alignn/train.py @@ -169,7 +169,11 @@ def train_dgl( net = _model.get(config.model.name)(config.model) else: net = model - + from matgl.models import CHGNet, M3GNet + from matgl.utils.training import ModelLightningModule, PotentialLightningModule + #model = M3GNet(element_types=['Si'], is_intensive=False) + model = CHGNet(element_types=['Si'], is_intensive=False,threebody_cutoff=4) + net = PotentialLightningModule(model=model, stress_weight=0.0001, include_line_graph=True) print("net parameters", sum(p.numel() for p in net.parameters())) # print("device", device) net.to(device) @@ -229,7 +233,8 @@ def train_dgl( # info["id"] = jid optimizer.zero_grad() if (config.model.alignn_layers) > 0: - result = net([dats[0].to(device), dats[1].to(device)]) + result = net(dats[0].to(device), dats[2].to(device),dats[1].to(device)) + #result = net([dats[0].to(device), dats[1].to(device),lat=dats[2].to(device)]) else: result = net(dats[0].to(device)) # info = {} @@ -346,7 +351,8 @@ def train_dgl( optimizer.zero_grad() # result = net([dats[0].to(device), dats[1].to(device)]) if (config.model.alignn_layers) > 0: - result = net([dats[0].to(device), dats[1].to(device)]) + #result = net([dats[0].to(device), dats[2].to(device), dats[1].to(device)]) + result = net(dats[0].to(device), dats[2].to(device),dats[1].to(device)) else: result = net(dats[0].to(device)) # info = {} From 8bb073f6dd92d68d42466bc06fb50f5ba8eb3de9 Mon Sep 17 00:00:00 2001 From: knc6 Date: Sun, 3 Nov 2024 14:24:37 -0500 Subject: [PATCH 08/37] New models for testing only. 
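
Alongside the experimental ALIGNNeFF scaffolding, the ASE calculator is
reworked: intensive model outputs are scaled by the atom count to a total
energy, and stresses are handed to ASE in Voigt order in eV/A^3; the
160.21766208 factor is simply 1 eV/A^3 expressed in GPa. A small sketch of
that unit bookkeeping with an invented stress tensor (older ASE releases
import the helper from ase.constraints rather than ase.stress):

import numpy as np
from ase.stress import full_3x3_to_voigt_6_stress

EV_PER_A3_IN_GPA = 160.21766208

stress_gpa = np.diag([1.0, 2.0, 3.0])  # toy 3x3 stress, GPa
stress_ase = stress_gpa / EV_PER_A3_IN_GPA  # eV/A^3, ASE's convention
print(full_3x3_to_voigt_6_stress(stress_ase))  # (xx, yy, zz, yz, xz, xy)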
--- alignn/config.py | 2 + alignn/ff/ff.py | 112 +-- alignn/graphs.py | 115 ++- alignn/models/alignn.py | 6 +- alignn/models/alignn_atomwise.py | 241 +++--- alignn/models/alignn_eff.py | 1281 ++++++++++++++++++++++++++++++ alignn/models/alignn_ff2.py | 679 +++++----------- alignn/models/utils.py | 530 +++++++++++- alignn/train.py | 52 +- 9 files changed, 2298 insertions(+), 720 deletions(-) create mode 100644 alignn/models/alignn_eff.py diff --git a/alignn/config.py b/alignn/config.py index dc838f2..9dbe7ec 100644 --- a/alignn/config.py +++ b/alignn/config.py @@ -7,6 +7,7 @@ from alignn.utils import BaseSettings from alignn.models.alignn import ALIGNNConfig from alignn.models.alignn_ff2 import ALIGNNFF2Config +from alignn.models.alignn_eff import ALIGNNeFFConfig from alignn.models.alignn_atomwise import ALIGNNAtomWiseConfig # import torch @@ -211,6 +212,7 @@ class TrainingConfig(BaseSettings): model: Union[ ALIGNNConfig, ALIGNNFF2Config, + ALIGNNeFFConfig, ALIGNNAtomWiseConfig, # CGCNNConfig, # ICGCNNConfig, diff --git a/alignn/ff/ff.py b/alignn/ff/ff.py index de851f2..71e224e 100644 --- a/alignn/ff/ff.py +++ b/alignn/ff/ff.py @@ -31,7 +31,8 @@ from alignn.graphs import Graph from alignn.models.alignn_atomwise import ALIGNNAtomWise, ALIGNNAtomWiseConfig from alignn.models.alignn_ff2 import ALIGNNFF2, ALIGNNFF2Config - +from alignn.models.alignn_eff import ALIGNNeFF, ALIGNNeFFConfig +from alignn.config import TrainingConfig from jarvis.analysis.defects.vacancy import Vacancy import numpy as np from alignn.pretrained import get_prediction @@ -51,6 +52,7 @@ from matplotlib.gridspec import GridSpec from sklearn.metrics import mean_absolute_error from tqdm import tqdm +import torch try: from gpaw import GPAW, PW @@ -218,40 +220,39 @@ def __init__( ignore_bad_restart_file=ignore_bad_restart_file, label=None, include_stress=True, + intensive=True, atoms=None, directory=".", device=None, + model=None, + config=None, path=".", model_filename="best_model.pt", config_filename="config.json", - keep_data_order=False, - classification_threshold=None, - batch_size=None, - epochs=None, output_dir=None, - stress_wt=1.0, - force_multiplier=1.0, - force_mult_natoms=False, batch_stress=True, + stress_wt=0.1, **kwargs, ): """Initialize class.""" super(AlignnAtomwiseCalculator, self).__init__( restart, ignore_bad_restart_file, label, atoms, directory, **kwargs ) + self.model = model self.device = device + self.intensive = intensive + # config = loadjson(os.path.join(path, config_filename)) + # print('config',config) + self.config = config self.include_stress = include_stress self.stress_wt = stress_wt - config = loadjson(os.path.join(path, config_filename)) - self.config = config - self.force_multiplier = force_multiplier - self.force_mult_natoms = force_mult_natoms - # config = TrainingConfig(**config) - # if type(config) is dict: - # try: - # config = TrainingConfig(**config) - # except Exception as exp: - # print("Check", exp) + # self.force_multiplier = force_multiplier + # self.force_mult_natoms = force_mult_natoms + if self.config is None: + config = loadjson(os.path.join(path, config_filename)) + # print('config',config) + # config=TrainingConfig(**config).dict() + self.config = config if self.include_stress: self.implemented_properties = ["energy", "forces", "stress"] if config["model"]["stresswise_weight"] == 0: @@ -259,37 +260,31 @@ def __init__( else: self.implemented_properties = ["energy", "forces"] - config["keep_data_order"] = keep_data_order - if classification_threshold is not None: - 
config["classification_threshold"] = float( - classification_threshold - ) - if output_dir is not None: - config["output_dir"] = output_dir - if batch_size is not None: - config["batch_size"] = int(batch_size) - if epochs is not None: - config["epochs"] = int(epochs) if batch_stress is not None: config["model"]["batch_stress"] = batch_stress - # config["model.output_features"] = 1 - # print('config',config["model"]) import torch if self.device is None: self.device = torch.device( "cuda" if torch.cuda.is_available() else "cpu" ) - if config['model']['name']=='alignn_ff2': - model = ALIGNNFF2(ALIGNNFF2Config(**config["model"])) - if config['model']['name']=='alignn_atomwise': - model = ALIGNNAtomWise(ALIGNNAtomWiseConfig(**config["model"])) - model.state_dict() - model.load_state_dict( - torch.load( - os.path.join(path, model_filename), map_location=self.device + if self.model is None: + + if config["model"]["name"] == "alignn_ff2": + model = ALIGNNFF2(ALIGNNFF2Config(**config["model"])) + if config["model"]["name"] == "alignn_atomwise": + model = ALIGNNAtomWise(ALIGNNAtomWiseConfig(**config["model"])) + if config["model"]["name"] == "alignn_eff": + model = ALIGNNeFF(ALIGNNeFFConfig(**config["model"])) + model.state_dict() + model.load_state_dict( + torch.load( + os.path.join(path, model_filename), + map_location=self.device, + ) ) - ) + else: + model = self.model model.to(device) model.eval() @@ -309,27 +304,38 @@ def calculate(self, atoms, properties=None, system_changes=None): atom_features=self.config["atom_features"], use_canonize=self.config["use_canonize"], ) + # print('g',g) + # print('lg',lg) + # print('config',self.config) + if self.config["model"]["alignn_layers"] > 0: # g,lg = g - result = self.net((g.to(self.device), lg.to(self.device))) + result = self.net( + ( + g.to(self.device), + lg.to(self.device), + torch.tensor(atoms.cell) + .type(torch.get_default_dtype()) + .to(self.device), + ) + ) else: - result = self.net((g.to(self.device))) + result = self.net( + (g.to(self.device, torch.tensor(atoms.cell).to(self.device))) + ) # print ('stress',result["stress"].detach().numpy()) - if self.force_mult_natoms: - mult = num_atoms + if self.intensive: + energy = result["out"].detach().cpu().numpy() * num_atoms else: - mult = 1 - # print('result["stresses"]',result["stresses"],result["stresses"].shape) + energy = result["out"].detach().cpu().numpy() + self.results = { - "energy": result["out"].detach().cpu().numpy() * num_atoms, - "forces": result["grad"].detach().cpu().numpy() - * mult - * self.force_multiplier, + "energy": energy, # * num_atoms, + "forces": result["grad"].detach().cpu().numpy(), "stress": full_3x3_to_voigt_6_stress( - result["stresses"][:3].reshape(3, 3).detach().cpu().numpy() + np.eye(3) + # result["stresses"][:3].reshape(3, 3).detach().cpu().numpy() ) - * self.stress_wt - # * num_atoms, / 160.21766208, "dipole": np.zeros(3), "charges": np.zeros(len(atoms)), diff --git a/alignn/graphs.py b/alignn/graphs.py index 6471dfc..686b7a5 100644 --- a/alignn/graphs.py +++ b/alignn/graphs.py @@ -16,8 +16,12 @@ import dgl from tqdm import tqdm +# import matgl -def temp_graph(atoms=None, cutoff=4.0, atom_features="cgcnn", dtype="float32"): + +def temp_graph( + atoms=None, cutoff=4.0, atom_features="atomic_number", dtype="float32" +): """Construct a graph for a given cutoff.""" TORCH_DTYPES = { "float16": torch.float16, @@ -52,26 +56,29 @@ def temp_graph(atoms=None, cutoff=4.0, atom_features="cgcnn", dtype="float32"): # Create DGL graph g = dgl.graph((np.array(u), np.array(v))) - 
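As an aside, the graph-assembly pattern this hunk converges on reduces to the following hedged sketch (toy two-atom neighbor list; names and values are illustrative only):

    import numpy as np
    import torch
    import dgl

    # toy directed neighbor list: sources, destinations, bond vectors, images
    u = torch.tensor([0, 1])
    v = torch.tensor([1, 0])
    r = np.array([[1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]])
    images = np.zeros((2, 3))

    g = dgl.graph((u, v))
    g.ndata["atom_features"] = torch.tensor([[14.0], [14.0]])  # e.g. Z for Si
    g.edata["r"] = torch.tensor(r, dtype=torch.float32)  # displacement vectors
    g.edata["images"] = torch.tensor(images, dtype=torch.float32)  # PBC cell offsets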
+ atom_feats = np.array(atom_feats) # Add data to the graph with the specified dtype + # print('atom_feats',atom_feats,atom_feats.shape) g.ndata["atom_features"] = torch.tensor(atom_feats, dtype=dtype) - g.edata["r"] = torch.tensor(r, dtype=dtype) + g.ndata["Z"] = torch.tensor(atom_feats, dtype=torch.int64) + g.edata["r"] = torch.tensor(np.array(r), dtype=dtype) g.edata["d"] = torch.tensor(d, dtype=dtype) g.edata["pbc_offset"] = torch.tensor(images, dtype=dtype) + g.edata["pbc_offshift"] = torch.tensor(images, dtype=dtype) g.edata["images"] = torch.tensor(images, dtype=dtype) - #g.edata["lattice"] = torch.tensor(torch.repeat_interleave(torch.tensor(atoms.lattice_mat.flatten()), atoms.num_atoms), dtype=dtype) - node_type=torch.tensor([0 for i in range(len(atoms.atomic_numbers))]) - g.ndata['node_type']=node_type + # g.edata["lattice"] = torch.tensor(torch.repeat_interleave(torch.tensor(atoms.lattice_mat.flatten()), atoms.num_atoms), dtype=dtype) + node_type = torch.tensor([0 for i in range(len(atoms.atomic_numbers))]) + g.ndata["node_type"] = node_type lattice_mat = atoms.lattice_mat - g.ndata["lattice"] = torch.tensor( - [lattice_mat for ii in range(g.num_nodes())] - , dtype=dtype) - g.edata["lattice"] = torch.tensor( - [lattice_mat for ii in range(g.num_edges())] - , dtype=dtype) - #g.ndata["coords"] = torch.tensor(atoms.cart_coords, dtype=dtype) + # g.ndata["lattice"] = torch.tensor( + # [lattice_mat for ii in range(g.num_nodes())] + # , dtype=dtype) + # g.edata["lattice"] = torch.tensor( + # [lattice_mat for ii in range(g.num_edges())] + # , dtype=dtype) + g.ndata["pos"] = torch.tensor(atoms.cart_coords, dtype=dtype) g.ndata["frac_coords"] = torch.tensor(atoms.frac_coords, dtype=dtype) - g.ndata["V"] = torch.tensor([atoms.volume] * atoms.num_atoms, dtype=dtype) + # g.ndata["V"] = torch.tensor([atoms.volume] * atoms.num_atoms, dtype=dtype) return g, u, v, r @@ -80,33 +87,36 @@ def radius_graph_jarvis( atoms, cutoff_extra=0.5, cutoff=4.0, - atom_features="cgcnn", + atom_features="atomic_number", line_graph=True, dtype="float32", + max_attempts=10, ): """Construct radius graph with dynamic cutoff.""" - while True: - try: - # Attempt to create the graph - g, u, v, r = temp_graph( - atoms=atoms, - cutoff=cutoff, - atom_features=atom_features, - dtype=dtype, - ) - # Check if all atoms are included as nodes - if g.num_nodes() == len(atoms.elements): - # print(f"Graph constructed with cutoff: {cutoff}") - break # Exit the loop when successful - # Increment the cutoff if the graph is incomplete - cutoff += cutoff_extra - # print(f"Increasing cutoff to: {cutoff}") - - except Exception as exp: - # Handle exceptions and try again - print(f"Graph construction failed: {exp,cutoff}") - cutoff += cutoff_extra # Try with a larger cutoff - + count = 0 + while count <= max_attempts: + # try: + # Attempt to create the graph + count += 1 + g, u, v, r = temp_graph( + atoms=atoms, + cutoff=cutoff, + atom_features=atom_features, + dtype=dtype, + ) + # Check if all atoms are included as nodes + if g.num_nodes() == len(atoms.elements): + # print(f"Graph constructed with cutoff: {cutoff}") + break # Exit the loop when successful + # Increment the cutoff if the graph is incomplete + cutoff += cutoff_extra + # print(f"Increasing cutoff to: {cutoff}") + # except Exception as exp: + # # Handle exceptions and try again + # print(f"Graph construction failed: {exp,cutoff}") + # cutoff += cutoff_extra # Try with a larger cutoff + if count >= max_attempts: + raise ValueError("Failed after", max_attempts, atoms) # 
Optional: Create a line graph if requested if line_graph: lg = g.line_graph(shared=True) @@ -506,18 +516,31 @@ def atom_dgl_multigraph( # u, v, r = build_undirected_edgedata(atoms, edges) # build up atom attribute tensor + comp = atoms.composition.to_dict() + comp_dict = {} + c_ind = 0 + for ii, jj in comp.items(): + if ii not in comp_dict: + comp_dict[ii] = c_ind + c_ind += 1 sps_features = [] + node_types = [] for ii, s in enumerate(atoms.elements): feat = list(get_node_attributes(s, atom_features=atom_features)) # if include_prdf_angles: # feat=feat+list(prdf[ii])+list(adf[ii]) sps_features.append(feat) + node_types.append(comp_dict[s]) sps_features = np.array(sps_features) node_features = torch.tensor(sps_features).type( torch.get_default_dtype() ) g = dgl.graph((u, v)) g.ndata["atom_features"] = node_features + g.ndata["node_type"] = torch.tensor(node_types, dtype=torch.int64) + node_type = torch.tensor([0 for i in range(len(atoms.atomic_numbers))]) + g.ndata["node_type"] = node_type + # print('g.ndata["node_type"]',g.ndata["node_type"]) g.edata["r"] = torch.tensor(r).type(torch.get_default_dtype()) # images=torch.tensor(images).type(torch.get_default_dtype()) # print('images',images.shape,r.shape) @@ -530,6 +553,9 @@ def atom_dgl_multigraph( g.ndata["coords"] = torch.tensor(atoms.cart_coords).type( torch.get_default_dtype() ) + g.ndata["frac_coords"] = torch.tensor(atoms.frac_coords).type( + torch.get_default_dtype() + ) if use_lattice_prop: lattice_prop = np.array( [atoms.lattice.lat_lengths(), atoms.lattice.lat_angles()] @@ -913,9 +939,11 @@ def __init__( ) self.lattices = [] for ii, i in df.iterrows(): - self.lattices.append(Atoms.from_dict(i['atoms']).lattice_mat) - - self.lattices = torch.tensor(self.lattices).type(torch.get_default_dtype()) + self.lattices.append(Atoms.from_dict(i["atoms"]).lattice_mat) + + self.lattices = torch.tensor(self.lattices).type( + torch.get_default_dtype() + ) self.transform = transform features = self._get_attribute_lookup(atom_features) @@ -1035,7 +1063,12 @@ def collate_line_graph( if len(labels[0].size()) > 0: return batched_graph, batched_line_graph, torch.stack(labels) else: - return batched_graph, batched_line_graph, torch.tensor(lattices), torch.tensor(labels) + return ( + batched_graph, + batched_line_graph, + torch.tensor(lattices), + torch.tensor(labels), + ) """ diff --git a/alignn/models/alignn.py b/alignn/models/alignn.py index f3ed795..722e88f 100644 --- a/alignn/models/alignn.py +++ b/alignn/models/alignn.py @@ -16,8 +16,10 @@ from torch import nn from torch.nn import functional as F -from alignn.models.utils import RBFExpansion -from alignn.utils import BaseSettings +# from alignn.models.utils import RBFExpansion +# from alignn.utils import BaseSettings + +from pydantic_settings import BaseSettings class ALIGNNConfig(BaseSettings): diff --git a/alignn/models/alignn_atomwise.py b/alignn/models/alignn_atomwise.py index 06e0fab..9674e6a 100644 --- a/alignn/models/alignn_atomwise.py +++ b/alignn/models/alignn_atomwise.py @@ -15,7 +15,11 @@ from typing import Literal from torch import nn from torch.nn import functional as F -from alignn.models.utils import RBFExpansion +from alignn.models.utils import ( + RBFExpansion, + compute_cartesian_coordinates, + compute_pair_vector_and_distance, +) from alignn.graphs import compute_bond_cosines from alignn.utils import BaseSettings @@ -56,6 +60,7 @@ class ALIGNNAtomWiseConfig(BaseSettings): lg_on_fly: bool = False # will make True as default soon batch_stress: bool = True multiply_cutoff: bool = 
False + use_penalty: bool = False extra_features: int = 0 exponent: int = 3 @@ -372,11 +377,11 @@ def forward( z: angle features (lg.edata) """ if len(self.alignn_layers) > 0: - g, lg = g + g, lg, lat = g lg = lg.local_var() - + # print('lg',lg) # angle features (fixed) - z = self.angle_embedding(lg.edata.pop("h")) + # z = self.angle_embedding(lg.edata.pop("h")) if self.config.extra_features != 0: features = g.ndata["extra_features"] # print('features',features,features.shape) @@ -391,7 +396,21 @@ def forward( x = self.atom_embedding(x) # print('x2',x,x.shape) r = g.edata["r"] - if self.config.calculate_gradient: + if self.config.include_pos_deriv: + # Not tested yet + g.ndata["cart_coords"] = compute_cartesian_coordinates(g, lat) + g.ndata["cart_coords"].requires_grad_(True) + r, bondlength = compute_pair_vector_and_distance(g) + lg = g.line_graph(shared=True) + lg.ndata["r"] = r + lg.apply_edges(compute_bond_cosines) + + # bondlength = torch.norm(r, dim=1) + # y = self.edge_embedding(bondlength) + if ( + self.config.calculate_gradient + and not self.config.include_pos_deriv + ): r.requires_grad_(True) bondlength = torch.norm(r, dim=1) # mask = bondlength >= self.config.inner_cutoff @@ -400,6 +419,7 @@ def forward( # re-compute bond angle cosines here to ensure # the three-body interactions are fully included # in the autograd graph. don't rely on dataloader/caching. + lg.ndata["r"] = r # overwrites precomputed r values lg.apply_edges(compute_bond_cosines) # overwrites precomputed h z = self.angle_embedding(lg.edata.pop("h")) @@ -458,117 +478,136 @@ def forward( forces = torch.empty(1) # gradient = torch.empty(1) stress = torch.empty(1) + if self.config.energy_mult_natoms: + en_out = out * g.num_nodes() + else: + en_out = out + if self.config.use_penalty: + penalty_factor = 500.0 # Penalty weight, tune as needed + penalty_factor = 0.01 # Penalty weight, tune as needed + penalty_threshold = 1.0 # 1 angstrom + + penalties = torch.where( + bondlength < penalty_threshold, + penalty_factor * (penalty_threshold - bondlength), + torch.zeros_like(bondlength), + ) + total_penalty = torch.sum(penalties) + en_out += total_penalty if self.config.calculate_gradient: if self.config.include_pos_deriv: - # Not tested yet - g.ndata["coords"].requires_grad_(True) - dx = [g.ndata["coords"], r] + dx = [g.ndata["cart_coords"]] + forces = ( + self.config.grad_multiplier + * grad( + en_out * g.num_nodes(), + dx, + grad_outputs=torch.ones_like(en_out), + create_graph=True, + retain_graph=True, + )[0] + ) else: dx = r - if self.config.energy_mult_natoms: - en_out = out * g.num_nodes() - else: - en_out = out - - # force calculation based on bond displacement vectors - # autograd gives dE / d{r_{i->j}} - pair_forces = ( - self.config.grad_multiplier - * grad( - en_out, - dx, - grad_outputs=torch.ones_like(en_out), - create_graph=True, - retain_graph=True, - )[0] - ) - if self.config.force_mult_natoms: - pair_forces *= g.num_nodes() + # force calculation based on bond displacement vectors + # autograd gives dE / d{r_{i->j}} + pair_forces = ( + self.config.grad_multiplier + * grad( + en_out, + dx, + grad_outputs=torch.ones_like(en_out), + create_graph=True, + retain_graph=True, + )[0] + ) + if self.config.force_mult_natoms: + pair_forces *= g.num_nodes() - # construct force_i = dE / d{r_i} - # reduce over bonds to get forces on each atom + # construct force_i = dE / d{r_i} + # reduce over bonds to get forces on each atom - # force_i contributions from r_{j->i} (in edges) - g.edata["pair_forces"] = pair_forces - 
g.update_all( - fn.copy_e("pair_forces", "m"), fn.sum("m", "forces_ji") - ) - if self.config.add_reverse_forces: - # reduce over reverse edges too! - # force_i contributions from r_{i->j} (out edges) - # aggregate pairwise_force_contributions over reversed edges - rg = dgl.reverse(g, copy_edata=True) - rg.update_all( - fn.copy_e("pair_forces", "m"), fn.sum("m", "forces_ij") + # force_i contributions from r_{j->i} (in edges) + g.edata["pair_forces"] = pair_forces + g.update_all( + fn.copy_e("pair_forces", "m"), fn.sum("m", "forces_ji") ) + if self.config.add_reverse_forces: + # reduce over reverse edges too! + # force_i contributions from r_{i->j} (out edges) + # aggregate pairwise_force_contributions over reversed edges + rg = dgl.reverse(g, copy_edata=True) + rg.update_all( + fn.copy_e("pair_forces", "m"), fn.sum("m", "forces_ij") + ) - # combine dE / d(r_{j->i}) and dE / d(r_{i->j}) - forces = torch.squeeze( - g.ndata["forces_ji"] - rg.ndata["forces_ij"] - ) - else: - forces = torch.squeeze(g.ndata["forces_ji"]) - - if self.config.stresswise_weight != 0: - # Under development, use with caution - # 1 eV/Angstrom3 = 160.21766208 GPa - # 1 GPa = 10 kbar - # Following Virial stress formula, assuming inital velocity = 0 - # Save volume as g.gdta['V']? - # print('pair_forces',pair_forces.shape) - # print('r',r.shape) - # print('g.ndata["V"]',g.ndata["V"].shape) - if not self.config.batch_stress: - # print('Not batch_stress') - stress = ( - -1 - * 160.21766208 - * ( - torch.matmul(r.T, pair_forces) - # / (2 * g.edata["V"]) - / (2 * g.ndata["V"][0]) - ) + # combine dE / d(r_{j->i}) and dE / d(r_{i->j}) + forces = torch.squeeze( + g.ndata["forces_ji"] - rg.ndata["forces_ij"] ) - # print("stress1", stress, stress.shape) - # print("g.batch_size", g.batch_size) else: - # print('Using batch_stress') - stresses = [] - count_edge = 0 - count_node = 0 - for graph_id in range(g.batch_size): - num_edges = g.batch_num_edges()[graph_id] - num_nodes = 0 - st = -1 * ( - 160.21766208 - * torch.matmul( - r[count_edge : count_edge + num_edges].T, - pair_forces[ - count_edge : count_edge + num_edges - ], + forces = torch.squeeze(g.ndata["forces_ji"]) + + if self.config.stresswise_weight != 0: + # Under development, use with caution + # 1 eV/Angstrom3 = 160.21766208 GPa + # 1 GPa = 10 kbar + # Following Virial stress formula, assuming inital velocity = 0 + # Save volume as g.gdta['V']? 
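The stress branch that follows implements the virial form, stress = -(1/(2V)) * sum_ij r_ij (outer) F_ij, with 160.21766208 converting eV/Angstrom^3 to GPa. A hedged, standalone sketch of that contraction (toy tensors, not the batched ALIGNN code):

    import torch

    EV_PER_A3_TO_GPA = 160.21766208  # 1 eV/Angstrom^3 in GPa

    r = torch.randn(8, 3)            # bond displacement vectors r_{i->j}
    pair_forces = torch.randn(8, 3)  # dE / d r_{i->j} from autograd
    volume = torch.tensor(50.0)      # cell volume in Angstrom^3

    # sum of outer products over bonds, halved for double counting,
    # scaled to GPa; yields a 3x3 stress tensor
    stress = -EV_PER_A3_TO_GPA * (r.T @ pair_forces) / (2.0 * volume)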
+ # print('pair_forces',pair_forces.shape) + # print('r',r.shape) + # print('g.ndata["V"]',g.ndata["V"].shape) + if not self.config.batch_stress: + # print('Not batch_stress') + stress = ( + -1 + * 160.21766208 + * ( + torch.matmul(r.T, pair_forces) + # / (2 * g.edata["V"]) + / (2 * g.ndata["V"][0]) ) - / g.ndata["V"][count_node + num_nodes] ) + # print("stress1", stress, stress.shape) + # print("g.batch_size", g.batch_size) + else: + # print('Using batch_stress') + stresses = [] + count_edge = 0 + count_node = 0 + for graph_id in range(g.batch_size): + num_edges = g.batch_num_edges()[graph_id] + num_nodes = 0 + st = -1 * ( + 160.21766208 + * torch.matmul( + r[count_edge : count_edge + num_edges].T, + pair_forces[ + count_edge : count_edge + num_edges + ], + ) + / g.ndata["V"][count_node + num_nodes] + ) - count_edge = count_edge + num_edges - num_nodes = g.batch_num_nodes()[graph_id] - count_node = count_node + num_nodes - # print("stresses.append",stresses[-1],stresses[-1].shape) - for n in range(num_nodes): - stresses.append(st) - # stress = (stresses) - stress = self.config.stress_multiplier * torch.cat( - stresses - ) - # print("stress2", stress, stress.shape) - # virial = ( - # 160.21766208 - # * 10 - # * torch.einsum("ij, ik->jk", result["r"], result["dy_dr"]) - # / 2 - # ) # / ( g.ndata["V"][0]) + count_edge = count_edge + num_edges + num_nodes = g.batch_num_nodes()[graph_id] + count_node = count_node + num_nodes + # print("stresses.append",stresses[-1],stresses[-1].shape) + for n in range(num_nodes): + stresses.append(st) + # stress = (stresses) + stress = self.config.stress_multiplier * torch.cat( + stresses + ) + # print("stress2", stress, stress.shape) + # virial = ( + # 160.21766208 + # * 10 + # * torch.einsum("ij, ik->jk", result["r"], result["dy_dr"]) + # / 2 + # ) # / ( g.ndata["V"][0]) if self.link: out = self.link(out) diff --git a/alignn/models/alignn_eff.py b/alignn/models/alignn_eff.py new file mode 100644 index 0000000..9a9c177 --- /dev/null +++ b/alignn/models/alignn_eff.py @@ -0,0 +1,1281 @@ +from torch.autograd import grad +from math import pi +from typing import Any, Callable, Literal, cast +from collections.abc import Sequence +from torch.nn import Linear, Module +from jarvis.core.specie import get_element_full_names +import dgl.function as fn +from torch import Tensor, nn +from jarvis.core.atoms import Atoms +from alignn.graphs import Graph +from enum import Enum +import dgl +from pathlib import Path +import torch +from dgl import readout_nodes +import inspect +import json +import os +from alignn.utils import BaseSettings +from alignn.models.utils import ( + get_ewald_sum, + get_atomic_repulsion, + FourierExpansion, + RadialBesselFunction, + prune_edges_by_features, + _create_directed_line_graph, + compute_theta, + create_line_graph, + compute_pair_vector_and_distance, + polynomial_cutoff, +) + +torch.autograd.detect_anomaly() +DEFAULT_ELEMENTS = list(get_element_full_names().keys()) + + +class ALIGNNeFFConfig(BaseSettings): + """Hyperparameter schema for jarvisdgl.models.alignn.""" + + name: Literal["alignn_eff"] + alignn_layers: int = 4 + calculate_gradient: bool = True + output_features: int = 1 + atomwise_output_features: int = 0 + graphwise_weight: float = 1.0 + gradwise_weight: float = 20.0 + stresswise_weight: float = 0.0 + atomwise_weight: float = 0.0 + batch_stress: bool = True + + +class EFFLineGraphConv(nn.Module): + + def __init__( + self, + node_update_func: Module, + node_out_func: Module, + edge_update_func: Module | None, + node_weight_func: 
Module | None, + ): + """ + Args: + node_update_func: Update function for message between nodes (bonds) + node_out_func: Output function for nodes (bonds), after message aggregation + edge_update_func: edge update function (for angle features) + node_weight_func: layer node weight function. + """ + super().__init__() + + self.node_update_func = node_update_func + self.node_out_func = node_out_func + self.node_weight_func = node_weight_func + self.edge_update_func = edge_update_func + + @classmethod + def from_dims( + cls, + node_dims: list[int], + edge_dims: list[int] | None = None, + activation: Module | None = None, + normalization: Literal["graph", "layer"] | None = None, + normalize_hidden: bool = False, + node_weight_input_dims: int = 0, + ): + norm_kwargs = ( + {"batched_field": "edge"} if normalization == "graph" else None + ) + + node_update_func = GatedMLP_norm( + in_feats=node_dims[0], + dims=node_dims[1:], + activation=activation, + normalization=normalization, + normalize_hidden=normalize_hidden, + norm_kwargs=norm_kwargs, + ) + node_out_func = nn.Linear( + in_features=node_dims[-1], out_features=node_dims[-1], bias=False + ) + + node_weight_func = ( + nn.Linear(node_weight_input_dims, node_dims[-1]) + if node_weight_input_dims > 0 + else None + ) + edge_update_func = ( + GatedMLP_norm( + in_feats=edge_dims[0], + dims=edge_dims[1:], + activation=activation, + normalization=normalization, + normalize_hidden=normalize_hidden, + norm_kwargs=norm_kwargs, + ) + if edge_dims is not None + else None + ) + + return cls( + node_update_func=node_update_func, + node_out_func=node_out_func, + edge_update_func=edge_update_func, + node_weight_func=node_weight_func, + ) + + def _edge_udf(self, edges: dgl.udf.EdgeBatch) -> dict[str, Tensor]: + """Edge user defined update function. + + Update angle features (edges in bond graph) + + Args: + edges: edge batch + + Returns: + edge_update: edge features update + """ + bonds_i = edges.src["features"] # first bonds features + bonds_j = edges.dst["features"] # second bonds features + angle_ij = edges.data["features"] + atom_ij = edges.data["aux_features"] # center atom features + inputs = torch.hstack([bonds_i, angle_ij, atom_ij, bonds_j]) + messages_ij = self.edge_update_func(inputs, edges._graph) # type: ignore + return {"feat_update": messages_ij} + + def edge_update_(self, graph: dgl.DGLGraph) -> Tensor: + """Perform edge update -> update angle features. + + Args: + graph: bond graph (line graph of atom graph) + + Returns: + edge_update: edge features update + """ + graph.apply_edges(self._edge_udf) + edge_update = graph.edata["feat_update"] + return edge_update + + def node_update_( + self, graph: dgl.DGLGraph, shared_weights: Tensor | None + ) -> Tensor: + """Perform node update -> update bond features. 
+ + Args: + graph: bond graph (line graph of atom graph) + shared_weights: node message shared weights + + Returns: + node_update: bond features update + """ + src, dst = graph.edges() + bonds_i = graph.ndata["features"][src] # first bond feature + bonds_j = graph.ndata["features"][dst] # second bond feature + angle_ij = graph.edata["features"] + atom_ij = graph.edata["aux_features"] # center atom features + inputs = torch.hstack([bonds_i, angle_ij, atom_ij, bonds_j]) + + messages = self.node_update_func(inputs, graph) + + # smooth out messages with layer-wise weights + if self.node_weight_func is not None: + rbf = graph.ndata["bond_expansion"] + weights = self.node_weight_func(rbf) + weights_i, weights_j = weights[src], weights[dst] + messages = messages * weights_i * weights_j + + # smooth out messages with shared weights + if shared_weights is not None: + weights_i, weights_j = shared_weights[src], shared_weights[dst] + messages = messages * weights_i * weights_j + + # message passing + graph.edata["message"] = messages + graph.update_all( + fn.copy_e("message", "message"), fn.sum("message", "feat_update") + ) + + # update nodes + node_update = self.node_out_func( + graph.ndata["feat_update"] + ) # the bond update + + return node_update + + def forward( + self, + graph: dgl.DGLGraph, + node_features: Tensor, + edge_features: Tensor, + aux_edge_features: Tensor, + shared_node_weights: Tensor | None, + ) -> tuple[Tensor, Tensor]: + with graph.local_scope(): + graph.ndata["features"] = node_features + graph.edata["features"] = edge_features + graph.edata["aux_features"] = aux_edge_features + + # node (bond) update + node_update = self.node_update_(graph, shared_node_weights) + new_node_features = node_features + node_update + graph.ndata["features"] = new_node_features + + # edge (angle) update (should angle update be done before node update?) + if self.edge_update_func is not None: + edge_update = self.edge_update_(graph) + new_edge_features = edge_features + edge_update + graph.edata["features"] = new_edge_features + else: + new_edge_features = edge_features + + return new_node_features, new_edge_features + + +class GatedMLP_norm(nn.Module): + """An implementation of a Gated multi-layer perceptron constructed with MLP.""" + + def __init__( + self, + in_feats: int, + dims: Sequence[int], + activation: nn.Module | None = None, + activate_last: bool = True, + use_bias: bool = True, + bias_last: bool = True, + normalization: Literal["graph", "layer"] | None = None, + normalize_hidden: bool = False, + norm_kwargs: dict[str, Any] | None = None, + ): + """:param in_feats: Dimension of input features. + :param dims: Architecture of neural networks. + :param activation: non-linear activation module. + :param activate_last: Whether applying activation to last layer or not. + :param use_bias: Whether to use a bias in linear layers. + :param bias_last: Whether applying bias to last layer or not. + :param normalization: normalization name. + :param normalize_hidden: Whether to normalize output of hidden layers. + :param norm_kwargs: Keyword arguments for normalization layer. 
+ """
+ super().__init__()
+ self.in_feats = in_feats
+ self.dims = [in_feats, *dims]
+ self._depth = len(dims)
+ self.use_bias = use_bias
+ self.activate_last = activate_last
+
+ activation = activation if activation is not None else nn.SiLU()
+ self.layers = MLP_norm(
+ self.dims,
+ activation=activation,
+ activate_last=True,
+ use_bias=use_bias,
+ bias_last=bias_last,
+ normalization=normalization,
+ normalize_hidden=normalize_hidden,
+ norm_kwargs=norm_kwargs,
+ )
+ self.gates = MLP_norm(
+ self.dims,
+ activation,
+ activate_last=False,
+ use_bias=use_bias,
+ bias_last=bias_last,
+ normalization=normalization,
+ normalize_hidden=normalize_hidden,
+ norm_kwargs=norm_kwargs,
+ )
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, inputs: torch.Tensor, graph=None) -> torch.Tensor:
+ return self.layers(inputs, graph) * self.sigmoid(
+ self.gates(inputs, graph)
+ )
+
+
+class EFFBondGraphBlock(nn.Module):
+ """An EFF bond graph block as a sequence of operations involving a message passing layer over the bond graph."""
+
+ def __init__(
+ self,
+ num_atom_feats: int,
+ num_bond_feats: int,
+ num_angle_feats: int,
+ activation: Module,
+ bond_hidden_dims: Sequence[int],
+ angle_hidden_dims: Sequence[int] | None,
+ normalization: Literal["graph", "layer"] | None = None,
+ normalize_hidden: bool = False,
+ rbf_order: int = 0,
+ bond_dropout: float = 0.0,
+ angle_dropout: float = 0.0,
+ ):
+ """.
+
+ Args:
+ num_atom_feats: number of atom features
+ num_bond_feats: number of bond features
+ num_angle_feats: number of angle features
+ activation: activation function
+ bond_hidden_dims: dimensions of hidden layers of bond graph convolution
+ angle_hidden_dims: dimensions of hidden layers of angle update function
+ Default = None
+ normalization: Normalization type to use in update functions. either "graph" or "layer"
+ If None, no normalization is applied.
+ Default = None
+ normalize_hidden: Whether to normalize hidden features.
+ Default = False
+ rbf_order (int): RBF order specifying input dimensions for linear layer
+ specifying message weights. If 0, no layer-wise weights are used.
+ Default = 0
+ bond_dropout (float): dropout probability for bond graph convolution.
+ Default = 0.0
+ angle_dropout (float): dropout probability for angle update function.
+ Default = 0.0
+ """
+ super().__init__()
+
+ node_input_dim = 2 * num_bond_feats + num_angle_feats + num_atom_feats
+ node_dims = [node_input_dim, *bond_hidden_dims, num_bond_feats]
+ edge_dims = (
+ [node_input_dim, *angle_hidden_dims, num_angle_feats]
+ if angle_hidden_dims is not None
+ else None
+ )
+
+ self.conv_layer = EFFLineGraphConv.from_dims(
+ node_dims=node_dims,
+ edge_dims=edge_dims,
+ activation=activation,
+ normalization=normalization,
+ normalize_hidden=normalize_hidden,
+ node_weight_input_dims=rbf_order,
+ )
+
+ self.bond_dropout = (
+ nn.Dropout(bond_dropout) if bond_dropout > 0.0 else nn.Identity()
+ )
+ self.angle_dropout = (
+ nn.Dropout(angle_dropout) if angle_dropout > 0.0 else nn.Identity()
+ )
+
+ def forward(
+ self,
+ graph: dgl.DGLGraph,
+ atom_features: Tensor,
+ bond_features: Tensor,
+ angle_features: Tensor,
+ shared_node_weights: Tensor | None,
+ ) -> tuple[Tensor, Tensor]:
+ """Perform convolution in BondGraph to update bond and angle features.
+ + Args: + graph: bond graph (line graph of atom graph) + atom_features: atom features + bond_features: bond features + angle_features: concatenated center atom and angle features + shared_node_weights: shared node message weights + + Returns: + tuple: update bond features, update angle features + """ + node_features = bond_features[graph.ndata["bond_index"]] + edge_features = angle_features + aux_edge_features = atom_features[graph.edata["center_atom_index"]] + + bond_features_, angle_features = self.conv_layer( + graph, + node_features, + edge_features, + aux_edge_features, + shared_node_weights, + ) + + bond_features_ = self.bond_dropout(bond_features_) + angle_features = self.angle_dropout(angle_features) + + bond_features[graph.ndata["bond_index"]] = bond_features_ + + return bond_features, angle_features + + +class EFFGraphConv(nn.Module): + """A EFF atom graph convolution layer in DGL.""" + + def __init__( + self, + node_update_func: Module, + node_out_func: Module, + edge_update_func: Module | None, + node_weight_func: Module | None, + edge_weight_func: Module | None, + state_update_func: Module | None, + ): + """ + Args: + node_update_func: Update function for message between nodes (atoms) + node_out_func: Output function for nodes (atoms), after message aggregation + edge_update_func: Update function for edges (bonds). If None is given, the + edges are not updated. + node_weight_func: Weight function for radial basis functions. + If None is given, no layer-wise weights will be used. + edge_weight_func: Weight function for radial basis functions + If None is given, no layer-wise weights will be used. + state_update_func: Update function for state feats. + """ + super().__init__() + self.include_state = state_update_func is not None + self.edge_update_func = edge_update_func + self.edge_weight_func = edge_weight_func + self.node_update_func = node_update_func + self.node_out_func = node_out_func + self.node_weight_func = node_weight_func + self.state_update_func = state_update_func + + @classmethod + def from_dims( + cls, + activation: Module, + node_dims: Sequence[int], + edge_dims: Sequence[int] | None = None, + state_dims: Sequence[int] | None = None, + normalization: Literal["graph", "layer"] | None = None, + normalize_hidden: bool = False, + rbf_order: int = 0, + ): + """Create a EFFAtomGraphConv layer from dimensions. + + Args: + activation: activation function + node_dims: NN architecture for node update function given as a list of + dimensions of each layer. + edge_dims: NN architecture for edge update function given as a list of + dimensions of each layer. + Default = None + state_dims: NN architecture for state update function given as a list of + dimensions of each layer. + Default = None + normalization: Normalization type to use in update functions. either "graph" or "layer" + If None, no normalization is applied. + Default = None + normalize_hidden: Whether to normalize hidden features. + Default = False + rbf_order (int): RBF order specifying input dimensions for linear layer + specifying message weights. If 0, no layer-wise weights are used. 
+ Default = 0 + + Returns: + EFFAtomGraphConv + """ + norm_kwargs = ( + {"batched_field": "edge"} if normalization == "graph" else None + ) + + node_update_func = GatedMLP_norm( + in_feats=node_dims[0], + dims=node_dims[1:], + activation=activation, + normalization=normalization, + normalize_hidden=normalize_hidden, + norm_kwargs=norm_kwargs, + ) + node_out_func = nn.Linear( + in_features=node_dims[-1], out_features=node_dims[-1], bias=False + ) + node_weight_func = ( + nn.Linear( + in_features=rbf_order, out_features=node_dims[-1], bias=False + ) + if rbf_order > 0 + else None + ) + edge_update_func = ( + GatedMLP_norm( + in_feats=edge_dims[0], + dims=edge_dims[1:], + activation=activation, + normalization=normalization, + normalize_hidden=normalize_hidden, + norm_kwargs=norm_kwargs, + ) + if edge_dims is not None + else None + ) + edge_weight_func = ( + nn.Linear( + in_features=rbf_order, out_features=edge_dims[-1], bias=False + ) + if rbf_order > 0 and edge_dims is not None + else None + ) + state_update_func = ( + MLP( + state_dims, + activation, + activate_last=True, + ) + if state_dims is not None + else None + ) + + return cls( + node_update_func=node_update_func, + node_out_func=node_out_func, + edge_update_func=edge_update_func, + node_weight_func=node_weight_func, + edge_weight_func=edge_weight_func, + state_update_func=state_update_func, + ) + + def _edge_udf(self, edges: dgl.udf.EdgeBatch) -> dict[str, Tensor]: + """Edge user defined update function. + + Update for bond features (edges) in atom graph. + + Args: + edges: edges in atom graph (ie bonds) + + Returns: + edge_update: edge features update + """ + atom_i = edges.src["features"] # first atom features + atom_j = edges.dst["features"] # second atom features + bond_ij = edges.data["features"] # bond features + if self.include_state: + global_state = edges.data["global_state"] + inputs = torch.hstack([atom_i, bond_ij, atom_j, global_state]) + else: + inputs = torch.hstack([atom_i, bond_ij, atom_j]) + + edge_update = self.edge_update_func(inputs, edges._graph) # type: ignore + if self.edge_weight_func is not None: + rbf = edges.data["bond_expansion"] + rbf = rbf.float() + edge_update = edge_update * self.edge_weight_func(rbf) + + return {"feat_update": edge_update} + + def edge_update_( + self, graph: dgl.DGLGraph, shared_weights: Tensor | None + ) -> Tensor: + """Perform edge update -> bond features. + + Args: + graph: atom graph + shared_weights: atom graph edge weights shared between convolution layers + + Returns: + edge_update: edge features update + """ + graph.apply_edges(self._edge_udf) + edge_update = graph.edata["feat_update"] + if shared_weights is not None: + edge_update = edge_update * shared_weights + return edge_update + + def node_update_( + self, graph: dgl.DGLGraph, shared_weights: Tensor | None + ) -> Tensor: + """Perform node update -> atom features. 
+ + Args: + graph: DGL atom graph + shared_weights: node message shared weights + + Returns: + node_update: updated node features + """ + src, dst = graph.edges() + atom_i = graph.ndata["features"][src] # first atom features + atom_j = graph.ndata["features"][dst] # second atom features + bond_ij = graph.edata["features"] # bond features + + if self.include_state: + global_state = graph.edata["global_state"] + inputs = torch.hstack([atom_i, bond_ij, atom_j, global_state]) + else: + inputs = torch.hstack([atom_i, bond_ij, atom_j]) + + messages = self.node_update_func(inputs, graph) + + # smooth out the messages with layer-wise weights + if self.node_weight_func is not None: + rbf = graph.edata["bond_expansion"] + rbf = rbf.float() + messages = messages * self.node_weight_func(rbf) + + # smooth out the messages with shared weights + if shared_weights is not None: + messages = messages * shared_weights + + # message passing + graph.edata["message"] = messages + graph.update_all( + fn.copy_e("message", "message"), fn.sum("message", "feat_update") + ) + + # update nodes + node_update = self.node_out_func( + graph.ndata["feat_update"] + ) # the bond update + + return node_update + + def state_update_(self, graph: dgl.DGLGraph, state_attr: Tensor) -> Tensor: + """Perform attribute (global state) update. + + Args: + graph: atom graph + state_attr: global state features + + Returns: + state_update: state features update + """ + node_avg = dgl.readout_nodes(graph, feat="features", op="mean") + inputs = torch.hstack([state_attr, node_avg]) + state_attr = self.state_update_func(inputs) # type: ignore + return state_attr + + def forward( + self, + graph: dgl.DGLGraph, + node_features: Tensor, + edge_features: Tensor, + state_attr: Tensor, + shared_node_weights: Tensor | None, + shared_edge_weights: Tensor | None, + ) -> tuple[Tensor, Tensor, Tensor]: + """Perform sequence of edge->node->states updates. + + Args: + graph: atom graph + node_features: node features + edge_features: edge features + state_attr: state attributes + shared_node_weights: shared node message weights + shared_edge_weights: shared edge message weights + + Returns: + tuple: updated node features, updated edge features, updated state attributes + """ + with graph.local_scope(): + graph.ndata["features"] = node_features + graph.edata["features"] = edge_features + + if self.include_state: + graph.edata["global_state"] = dgl.broadcast_edges( + graph, state_attr + ) + + if self.edge_update_func is not None: + edge_update = self.edge_update_(graph, shared_edge_weights) + new_edge_features = edge_features + edge_update + graph.edata["features"] = new_edge_features + else: + new_edge_features = edge_features + + node_update = self.node_update_(graph, shared_node_weights) + new_node_features = node_features + node_update + graph.ndata["features"] = new_node_features + + if self.include_state: + state_attr = self.state_update_(graph, state_attr) # type: ignore + + return new_node_features, new_edge_features, state_attr + + +class EFFAtomGraphBlock(nn.Module): + """ + A EFF atom graph block as a sequence of operations + involving a message passing layer over the atom graph. + """ + + def __init__( + self, + num_atom_feats: int, + num_bond_feats: int, + activation: Module, + atom_hidden_dims: Sequence[int], + bond_hidden_dims: Sequence[int] | None = None, + normalization: Literal["graph", "layer"] | None = None, + normalize_hidden: bool = False, + num_state_feats: int | None = None, + rbf_order: int = 0, + dropout: float = 0.0, + ): + """. 
+ + Args: + num_atom_feats: number of atom features + num_bond_feats: number of bond features + activation: activation function + atom_hidden_dims: dimensions of atom convolution hidden layers + bond_hidden_dims: dimensions of bond update hidden layers. + normalization: Normalization type to use in update functions. either "graph" or "layer" + If None, no normalization is applied. + Default = None + normalize_hidden: Whether to normalize hidden features. + Default = False + num_state_feats: number of state features if self.include_state is True + Default = None + rbf_order: RBF order specifying input dimensions for linear layer + specifying message weights. If 0, no layer-wise weights are used. + Default = False + dropout: dropout probability. + Default = 0.0 + """ + super().__init__() + + node_input_dim = 2 * num_atom_feats + num_bond_feats + if num_state_feats is not None: + node_input_dim += num_state_feats + state_dims = [ + num_atom_feats + num_state_feats, + *atom_hidden_dims, + num_state_feats, + ] + else: + state_dims = None + node_dims = [node_input_dim, *atom_hidden_dims, num_atom_feats] + edge_dims = ( + [node_input_dim, *bond_hidden_dims, num_bond_feats] + if bond_hidden_dims is not None + else None + ) + + self.conv_layer = EFFGraphConv.from_dims( + activation=activation, + node_dims=node_dims, + edge_dims=edge_dims, + state_dims=state_dims, + normalization=normalization, + normalize_hidden=normalize_hidden, + rbf_order=rbf_order, + ) + + if normalization == "graph": + self.atom_norm = GraphNorm(num_atom_feats, batched_field="node") + self.bond_norm = GraphNorm(num_bond_feats, batched_field="edge") + elif normalization == "layer": + self.atom_norm = LayerNorm(num_atom_feats) + self.bond_norm = LayerNorm(num_bond_feats) + else: + self.atom_norm = None + self.bond_norm = None + + self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity() + + def forward( + self, + graph: dgl.DGLGraph, + atom_features: Tensor, + bond_features: Tensor, + state_attr: Tensor, + shared_node_weights: Tensor | None, + shared_edge_weights: Tensor | None, + ) -> tuple[Tensor, Tensor, Tensor]: + """Perform sequence of bond(optional)->atom->states(optional) updates. + + Args: + graph: atom graph + atom_features: node features + bond_features: edge features + state_attr: state attributes + shared_node_weights: node message weights shared amongst layers + shared_edge_weights: edge message weights shared amongst layers + """ + atom_features, bond_features, state_attr = self.conv_layer( + graph=graph, + node_features=atom_features, + edge_features=bond_features, + state_attr=state_attr, + shared_node_weights=shared_node_weights, + shared_edge_weights=shared_edge_weights, + ) + # move skip connections here? dropout before skip connections? 
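The node_update_ methods in this file all reduce per-edge messages onto destination nodes with the same DGL primitive pair; a minimal hedged sketch of that reduction on a toy graph (illustrative only):

    import dgl
    import dgl.function as fn
    import torch

    # toy directed graph: 3 nodes in a cycle
    g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0])))
    g.edata["message"] = torch.randn(3, 8)  # per-edge messages

    # copy each edge message and sum it into its destination node
    g.update_all(fn.copy_e("message", "m"), fn.sum("m", "feat_update"))
    node_update = g.ndata["feat_update"]  # (3, 8) aggregated features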
+ atom_features = self.dropout(atom_features) + bond_features = self.dropout(bond_features) + if self.atom_norm is not None: + atom_features = self.atom_norm(atom_features, graph) + if self.bond_norm is not None: + bond_features = self.bond_norm(bond_features, graph) + if state_attr is not None: + state_attr = self.dropout(state_attr) + + return atom_features, bond_features, state_attr + + +class MLP_norm(nn.Module): + """Multi-layer perceptron with normalization layer.""" + + def __init__( + self, + dims: list[int], + activation: nn.Module | None = None, + activate_last: bool = False, + use_bias: bool = True, + bias_last: bool = True, + normalization: Literal["graph", "layer"] | None = None, + normalize_hidden: bool = False, + norm_kwargs: dict[str, Any] | None = None, + ) -> None: + """ + Args: + dims: Dimensions of each layer of MLP. + activation: activation: Activation function. + activate_last: Whether to apply activation to last layer. + use_bias: Whether to use bias. + bias_last: Whether to apply bias to last layer. + normalization: normalization name. "graph" or "layer" + normalize_hidden: Whether to normalize output of hidden layers. + norm_kwargs: Keyword arguments for normalization layer. + """ + super().__init__() + self._depth = len(dims) - 1 + self.layers = nn.ModuleList() + self.norm_layers = ( + nn.ModuleList() if normalization in ("graph", "layer") else None + ) + self.activation = ( + activation if activation is not None else nn.Identity() + ) + self.activate_last = activate_last + self.normalize_hidden = normalize_hidden + norm_kwargs = norm_kwargs or {} + norm_kwargs = cast(dict, norm_kwargs) + + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + if i < self._depth - 1: + self.layers.append(Linear(in_dim, out_dim, bias=use_bias)) + if normalize_hidden and self.norm_layers is not None: + if normalization == "graph": + self.norm_layers.append( + GraphNorm(out_dim, **norm_kwargs) + ) + elif normalization == "layer": + self.norm_layers.append( + LayerNorm(out_dim, **norm_kwargs) + ) + else: + self.layers.append( + Linear(in_dim, out_dim, bias=use_bias and bias_last) + ) + if self.norm_layers is not None: + if normalization == "graph": + self.norm_layers.append( + GraphNorm(out_dim, **norm_kwargs) + ) + elif normalization == "layer": + self.norm_layers.append( + LayerNorm(out_dim, **norm_kwargs) + ) + + def forward(self, inputs: torch.Tensor, g=None) -> torch.Tensor: + """Applies all layers in turn. + + Args: + inputs: input feature tensor. + g: graph of model, needed for graph normalization + + Returns: + output feature tensor. + """ + x = inputs + for i in range(self._depth - 1): + x = self.layers[i](x) + if self.norm_layers is not None and self.normalize_hidden: + x = self.norm_layers[i](x, g) + x = self.activation(x) + + x = self.layers[-1](x) + if self.norm_layers is not None: + x = self.norm_layers[-1](x, g) + if self.activate_last: + x = self.activation(x) + return x + + +class ActivationFunction(Enum): + """Enumeration of optional activation functions.""" + + swish = nn.SiLU + # sigmoid = nn.Sigmoid + # tanh = nn.Tanh + # softplus = nn.Softplus + # softplus2 = SoftPlus2 + # softexp = SoftExponential + + +class ALIGNNeFF(nn.Module): + """Main EFF model.""" + + __version__ = 1 + + def __init__( + self, + config: ALIGNNeFFConfig = ALIGNNeFFConfig(name="alignn_eff"), + element_types: tuple[str, ...] 
| None = None, + dim_atom_embedding: int = 64, + dim_bond_embedding: int = 64, + dim_angle_embedding: int = 64, + dim_state_embedding: int | None = None, + dim_state_feats: int | None = None, + non_linear_bond_embedding: bool = False, + non_linear_angle_embedding: bool = False, + cutoff: float = 4.0, + threebody_cutoff: float = 3.0, + cutoff_exponent: int = 5, + max_n: int = 9, + max_f: int = 4, + learn_basis: bool = True, + num_blocks: int = 4, + shared_bond_weights: ( + Literal["bond", "three_body_bond", "both"] | None + ) = "both", + layer_bond_weights: ( + Literal["bond", "three_body_bond", "both"] | None + ) = None, + atom_conv_hidden_dims: Sequence[int] = (64,), + bond_update_hidden_dims: Sequence[int] | None = None, + bond_conv_hidden_dims: Sequence[int] = (64,), + angle_update_hidden_dims: Sequence[int] | None = (), + conv_dropout: float = 0.0, + final_mlp_type: Literal["gated", "mlp"] = "mlp", + final_hidden_dims: Sequence[int] = (64, 64), + final_dropout: float = 0.0, + pooling_operation: Literal["sum", "mean"] = "sum", + readout_field: Literal[ + "atom_feat", "bond_feat", "angle_feat" + ] = "atom_feat", + activation_type: str = "swish", + normalization: Literal["graph", "layer"] | None = None, + normalize_hidden: bool = False, + is_intensive: bool = False, + num_targets: int = 1, + num_site_targets: int = 1, + task_type: Literal["regression", "classification"] = "regression", + ): + super().__init__() + + # self.save_args(locals(), kwargs) + + activation: nn.Module = ActivationFunction[activation_type].value() + + element_types = element_types or DEFAULT_ELEMENTS + + # basis expansions for bond lengths, triple interaction bond lengths and angles + self.bond_expansion = RadialBesselFunction( + max_n=max_n, cutoff=cutoff, learnable=learn_basis + ) + self.threebody_bond_expansion = RadialBesselFunction( + max_n=max_n, cutoff=threebody_cutoff, learnable=learn_basis + ) + self.angle_expansion = FourierExpansion( + max_f=max_f, learnable=learn_basis + ) + + # embedding block for atom, bond, angle, and optional state features + self.include_states = dim_state_feats is not None + self.state_embedding = ( + nn.Embedding(dim_state_feats, dim_state_embedding) + if self.include_states + else None + ) + self.atom_embedding = nn.Embedding( + len(element_types), dim_atom_embedding + ) + + # self.atom_embedding = MLP_norm( + # 1, dim_state_embedding + # ) + + self.bond_embedding = MLP_norm( + [max_n, dim_bond_embedding], + activation=activation, + activate_last=non_linear_bond_embedding, + bias_last=False, + ) + self.angle_embedding = MLP_norm( + [2 * max_f + 1, dim_angle_embedding], + activation=activation, + activate_last=non_linear_angle_embedding, + bias_last=False, + ) + + # shared message bond distance smoothing weights + self.atom_bond_weights = ( + nn.Linear(max_n, dim_atom_embedding, bias=False) + if shared_bond_weights in ["bond", "both"] + else None + ) + self.bond_bond_weights = ( + nn.Linear(max_n, dim_bond_embedding, bias=False) + if shared_bond_weights in ["bond", "both"] + else None + ) + self.threebody_bond_weights = ( + nn.Linear(max_n, dim_bond_embedding, bias=False) + if shared_bond_weights in ["three_body_bond", "both"] + else None + ) + + # operations involving the graph (i.e. 
atom graph) to update atom and bond features + self.atom_graph_layers = nn.ModuleList( + [ + EFFAtomGraphBlock( + num_atom_feats=dim_atom_embedding, + num_bond_feats=dim_bond_embedding, + atom_hidden_dims=atom_conv_hidden_dims, + bond_hidden_dims=bond_update_hidden_dims, + num_state_feats=dim_state_embedding, + activation=activation, + normalization=normalization, + normalize_hidden=normalize_hidden, + dropout=conv_dropout, + rbf_order=0, + ) + for _ in range(num_blocks) + ] + ) + + # operations involving the line graph (i.e. bond graph) to update bond and angle features + self.bond_graph_layers = nn.ModuleList( + [ + EFFBondGraphBlock( + num_atom_feats=dim_atom_embedding, + num_bond_feats=dim_bond_embedding, + num_angle_feats=dim_angle_embedding, + bond_hidden_dims=bond_conv_hidden_dims, + angle_hidden_dims=angle_update_hidden_dims, + activation=activation, + normalization=normalization, + normalize_hidden=normalize_hidden, + bond_dropout=conv_dropout, + angle_dropout=conv_dropout, + rbf_order=0, + ) + for _ in range(num_blocks - 1) + ] + ) + + self.sitewise_readout = ( + nn.Linear(dim_atom_embedding, num_site_targets) + if num_site_targets > 0 + else lambda x: None + ) + print("final_mlp_type", final_mlp_type) + input_dim = ( + dim_atom_embedding + if readout_field == "node_feat" + else dim_bond_embedding + ) + + self.final_layer = MLP_norm( + dims=[input_dim, *final_hidden_dims, num_targets], + activation=activation, + activate_last=False, + ) + + self.element_types = element_types + self.max_n = max_n + self.max_f = max_f + self.cutoff = cutoff + self.cutoff_exponent = cutoff_exponent + self.three_body_cutoff = threebody_cutoff + + self.n_blocks = num_blocks + self.readout_operation = pooling_operation + self.readout_field = readout_field + self.readout_type = final_mlp_type + + self.task_type = task_type + self.is_intensive = is_intensive + + def forward( + self, + g, + state_attr: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """Forward pass of the model. + + Args: + g (dgl.DGLGraph): Input g. + state_attr (torch.Tensor, optional): State features. Defaults to None. + l_g (dgl.DGLGraph, optional): Line graph. Defaults to None and is computed internally. + + Returns: + torch.Tensor: Model output. 
+ """
+ g, l_g, lat = g
+ st = lat.new_zeros([g.batch_size, 3, 3])
+ st.requires_grad_(True)
+ lattice = lat @ (torch.eye(3, device=lat.device) + st)
+ g.edata["lattice"] = torch.repeat_interleave(
+ lattice, g.batch_num_edges(), dim=0
+ )
+ g.edata["pbc_offshift"] = (
+ g.edata["images"].unsqueeze(dim=-1) * g.edata["lattice"]
+ ).sum(dim=1)
+ g.ndata["pos"] = (
+ g.ndata["frac_coords"].unsqueeze(dim=-1)
+ * torch.repeat_interleave(lattice, g.batch_num_nodes(), dim=0)
+ ).sum(dim=1)
+ g.ndata["pos"].requires_grad_(True)
+
+ # compute bond vectors and distances and add them to g; computed here so that gradients are registered
+ bond_vec, bond_dist = compute_pair_vector_and_distance(g)
+ g.edata["bond_vec"] = bond_vec.to(g.device)
+ g.edata["bond_dist"] = bond_dist.to(g.device)
+ bond_expansion = self.bond_expansion(bond_dist)
+ smooth_cutoff = polynomial_cutoff(
+ bond_expansion, self.cutoff, self.cutoff_exponent
+ )
+ g.edata["bond_expansion"] = smooth_cutoff * bond_expansion
+
+ # create bond graph (line graph) with necessary node and edge data
+ # print("self.readout_field", self.readout_field)
+ bond_graph = create_line_graph(
+ g, self.three_body_cutoff, directed=True
+ )
+
+ bond_graph.ndata["bond_index"] = bond_graph.ndata["edge_ids"]
+ threebody_bond_expansion = self.threebody_bond_expansion(
+ bond_graph.ndata["bond_dist"]
+ )
+ smooth_cutoff = polynomial_cutoff(
+ threebody_bond_expansion,
+ self.three_body_cutoff,
+ self.cutoff_exponent,
+ )
+ bond_graph.ndata["bond_expansion"] = (
+ smooth_cutoff * threebody_bond_expansion
+ )
+ bond_indices = bond_graph.ndata["bond_index"][bond_graph.edges()[0]]
+ bond_graph.edata["center_atom_index"] = g.edges()[1][bond_indices]
+ bond_graph.apply_edges(compute_theta)
+ bond_graph.edata["angle_expansion"] = self.angle_expansion(
+ bond_graph.edata["theta"]
+ )
+
+ # atom_features = self.atom_embedding(g.ndata["atom_features"])
+ atom_features = self.atom_embedding(g.ndata["node_type"])
+
+ bond_features = self.bond_embedding(g.edata["bond_expansion"])
+ angle_features = self.angle_embedding(
+ bond_graph.edata["angle_expansion"]
+ )
+ if self.state_embedding is not None and state_attr is not None:
+ state_attr = self.state_embedding(state_attr)
+ else:
+ state_attr = None
+
+ # shared message weights
+ atom_bond_weights = (
+ self.atom_bond_weights(g.edata["bond_expansion"])
+ if self.atom_bond_weights is not None
+ else None
+ )
+ # print("atom_bond_weights", torch.sum(atom_bond_weights))
+ bond_bond_weights = (
+ self.bond_bond_weights(g.edata["bond_expansion"])
+ if self.bond_bond_weights is not None
+ else None
+ )
+ # print("bond_bond_weights", torch.sum(bond_bond_weights))
+ threebody_bond_weights = (
+ self.threebody_bond_weights(bond_graph.ndata["bond_expansion"])
+ if self.threebody_bond_weights is not None
+ else None
+ )
+
+ # message passing layers
+ for i in range(self.n_blocks - 1):
+ atom_features, bond_features, state_attr = self.atom_graph_layers[
+ i
+ ](
+ g,
+ atom_features,
+ bond_features,
+ state_attr,
+ atom_bond_weights,
+ bond_bond_weights,
+ )
+ bond_features, angle_features = self.bond_graph_layers[i](
+ bond_graph,
+ atom_features,
+ bond_features,
+ angle_features,
+ threebody_bond_weights,
+ )
+
+ atom_features, bond_features, state_attr = self.atom_graph_layers[-1](
+ g,
+ atom_features,
+ bond_features,
+ state_attr,
+ atom_bond_weights,
+ bond_bond_weights,
+ )
+
+ g.ndata["atom_feat"] = self.final_layer(atom_features)
+ structure_properties = readout_nodes(
+ g, "atom_feat", op=self.readout_operation
+ )
+ # 
self.add_ewald=True + # ewald_en = 0 + # if self.add_ewald: + # ewald_en = get_atomic_repulsion(g) + # total_energies = (torch.squeeze(structure_properties)) +ewald_en/g.num_nodes() + total_energies = torch.squeeze(structure_properties) + + penalty_factor = 500.0 # Penalty weight, tune as needed + penalty_factor = 1000.0 # Penalty weight, tune as needed + penalty_threshold = 1.0 # 1 angstrom + + # Calculate penalties for distances less than the threshold + penalties = torch.where( + bond_dist < penalty_threshold, + penalty_factor * (penalty_threshold - bond_dist), + torch.zeros_like(bond_dist), + ) + total_penalty = torch.sum(penalties) + + # min_distance=1.0 + # mask = bond_dist < min_distance + # penalty = torch.zeros_like(bond_dist) + # epsilon=1.0 + # alpha=12 + # Smooth penalty calculation for close distances + # penalty[mask] = epsilon * ((min_distance / bond_dist[mask]) ** alpha) + + # Sum up the penalties + # total_penalty = torch.sum(penalty) + total_energies += total_penalty + forces = torch.zeros(1) + stresses = torch.zeros(1) + hessian = torch.zeros(1) + grad_vars = [ + g.ndata["pos"], + st, + ] # if self.calc_stresses else [g.ndata["pos"]] + # print('total_energies',total_energies) + grads = grad( + total_energies, + grad_vars, + grad_outputs=torch.ones_like(total_energies), + create_graph=True, + retain_graph=True, + ) + forces = -grads[0] + volume = torch.abs(torch.det(lattice)) + sts = -grads[1] + scale = 1.0 / volume * -160.21766208 + sts = ( + [i * j for i, j in zip(sts, scale)] + if sts.dim() == 3 + else [sts * scale] + ) + stresses = torch.cat(sts) + result = {} + result["out"] = total_energies + result["grad"] = forces + result["stresses"] = stresses + return result diff --git a/alignn/models/alignn_ff2.py b/alignn/models/alignn_ff2.py index f4acd8c..9bb3a91 100644 --- a/alignn/models/alignn_ff2.py +++ b/alignn/models/alignn_ff2.py @@ -1,8 +1,4 @@ -"""Atomistic LIne Graph Neural Network. - -A prototype crystal line graph network dgl implementation. 
-""" - +from math import pi, sqrt from typing import Tuple, Union from torch.autograd import grad import dgl @@ -13,344 +9,146 @@ from torch import nn from torch.nn import functional as F from alignn.models.utils import ( + RadialBesselFunction, RBFExpansion, + RBFExpansionSmooth, BesselExpansion, SphericalHarmonicsExpansion, FourierExpansion, compute_pair_vector_and_distance, check_line_graph, cutoff_function_based_edges, + compute_cartesian_coordinates, + MLPLayer, ) from alignn.graphs import compute_bond_cosines from alignn.utils import BaseSettings -from matgl.layers._basis import RadialBesselFunction,FourierExpansion -from matgl.layers import MLP_norm - -torch.autograd.set_detect_anomaly(True) +from dgl import GCNNorm class ALIGNNFF2Config(BaseSettings): """Hyperparameter schema for jarvisdgl.models.alignn.""" name: Literal["alignn_ff2"] - alignn_layers: int = 4 - gcn_layers: int = 4 - atom_input_features: int = 92 - edge_input_features: int = 80 + alignn_layers: int = 2 + gcn_layers: int = 2 + atom_input_features: int = 1 + edge_input_features: int = 64 triplet_input_features: int = 40 embedding_features: int = 64 - hidden_features: int = 256 + hidden_features: int = 128 output_features: int = 1 grad_multiplier: int = -1 calculate_gradient: bool = True atomwise_output_features: int = 0 graphwise_weight: float = 1.0 gradwise_weight: float = 1.0 - stresswise_weight: float = 0.00001 + stresswise_weight: float = 0.0 atomwise_weight: float = 0.0 classification: bool = False - force_mult_natoms: bool = True + batch_stress: bool = False use_cutoff_function: bool = True - inner_cutoff: float = 4 # Ansgtrom - stress_multiplier: float = 1 - add_reverse_forces: bool = False # will make True as default soon - batch_stress: bool = True - multiply_cutoff: bool = False + use_penalty: bool = True + multiply_cutoff: bool = True + inner_cutoff: float = 4.0 # Angstrom + stress_multiplier: float = 1.0 + sigma: float = 0.2 + exponent: int = 4 extra_features: int = 0 - exponent: int = 3 - bond_exp_basis: str = "gaussian" # "bessel" # or gaussian - angle_exp_basis: str = "gaussian" # "bessel" # or gaussian - max_n: int = 9 - max_f: int = 4 - learn_basis: bool = True - -class EdgeGatedGraphConv(nn.Module): - """Edge gated graph convolution from arxiv:1711.07553. - see also arxiv:2003.0098. - - This is similar to CGCNN, but edge features only go into - the soft attention / edge gating function, and the primary - node update function is W cat(u, v) + b +class GraphConv(nn.Module): + """ + Custom Graph Convolution layer with smooth transformations on bond lengths and angles. 
""" def __init__( - self, input_features: int, output_features: int, residual: bool = True + self, in_feats, out_feats, activation=nn.SiLU(), hidden_features=64 ): - """Initialize parameters for ALIGNN update.""" - super().__init__() - self.residual = residual - # CGCNN-Conv operates on augmented edge features - # z_ij = cat(v_i, v_j, u_ij) - # m_ij = σ(z_ij W_f + b_f) ⊙ g_s(z_ij W_s + b_s) - # coalesce parameters for W_f and W_s - # but -- split them up along feature dimension - self.src_gate = nn.Linear(input_features, output_features) - self.dst_gate = nn.Linear(input_features, output_features) - self.edge_gate = nn.Linear(input_features, output_features) - self.bn_edges = nn.LayerNorm(output_features) - - self.src_update = nn.Linear(input_features, output_features) - self.dst_update = nn.Linear(input_features, output_features) - self.bn_nodes = nn.LayerNorm(output_features) - - def forward( - self, - g: dgl.DGLGraph, - node_feats: torch.Tensor, - edge_feats: torch.Tensor, - ) -> torch.Tensor: - """Edge-gated graph convolution. - - h_i^l+1 = ReLU(U h_i + sum_{j->i} eta_{ij} ⊙ V h_j) + super(GraphConv, self).__init__() + self.fc = nn.Linear( + in_feats, out_feats + ) # Linear transformation for features + self.activation = activation + self.edge_transform = nn.Linear( + hidden_features, out_feats + ) # For bond-length based transformation + + def forward(self, g, node_feats, bond_feats): """ - g = g.local_var() - - # instead of concatenating (u || v || e) and applying one weight matrix - # split the weight matrix into three, apply, then sum - # see https://docs.dgl.ai/guide/message-efficient.html - # but split them on feature dimensions to update u, v, e separately - # m = BatchNorm(Linear(cat(u, v, e))) - - # compute edge updates, equivalent to: - # Softplus(Linear(u || v || e)) - g.ndata["e_src"] = self.src_gate(node_feats) - g.ndata["e_dst"] = self.dst_gate(node_feats) - g.apply_edges(fn.u_add_v("e_src", "e_dst", "e_nodes")) - m = g.edata.pop("e_nodes") + self.edge_gate(edge_feats) - - g.edata["sigma"] = torch.sigmoid(m) - g.ndata["Bh"] = self.dst_update(node_feats) + Forward pass with bond length handling for smooth transitions. 
+ """ + # Transform bond (edge) features + # print('bond_feats',bond_feats.shape) + bond_feats = self.edge_transform(bond_feats) + + # Message passing: message = transformed edge feature + node feature + g.ndata["h"] = node_feats + g.edata["e"] = bond_feats g.update_all( - fn.u_mul_e("Bh", "sigma", "m"), fn.sum("m", "sum_sigma_h") + message_func=fn.u_add_e( + "h", "e", "m" + ), # Add node and edge features + reduce_func=fn.sum("m", "h"), # Sum messages for each node ) - g.update_all(fn.copy_e("sigma", "m"), fn.sum("m", "sum_sigma")) - g.ndata["h"] = g.ndata["sum_sigma_h"] / (g.ndata["sum_sigma"] + 1e-6) - x = self.src_update(node_feats) + g.ndata.pop("h") - - # softmax version seems to perform slightly worse - # that the sigmoid-gated version - # compute node updates - # Linear(u) + edge_gates ⊙ Linear(v) - # g.edata["gate"] = edge_softmax(g, y) - # g.ndata["h_dst"] = self.dst_update(node_feats) - # g.update_all(fn.u_mul_e("h_dst", "gate", "m"), fn.sum("m", "h")) - # x = self.src_update(node_feats) + g.ndata.pop("h") - - # node and edge updates - x = F.silu(self.bn_nodes(x)) - y = F.silu(self.bn_edges(m)) - - if self.residual: - x = node_feats + x - y = edge_feats + y - return x, y + # Final node feature transformation + node_feats = self.fc(g.ndata["h"]) + return self.activation(node_feats), bond_feats -class ALIGNNConv(nn.Module): - """Line graph update.""" - - def __init__( - self, - in_features: int, - out_features: int, - ): - """Set up ALIGNN parameters.""" - super().__init__() - self.node_update = EdgeGatedGraphConv(in_features, out_features) - self.edge_update = EdgeGatedGraphConv(out_features, out_features) - - def forward( - self, - g: dgl.DGLGraph, - lg: dgl.DGLGraph, - x: torch.Tensor, - y: torch.Tensor, - z: torch.Tensor, - ): - """Node and Edge updates for ALIGNN layer. - - x: node input features - y: edge input features - z: edge pair input features - """ - g = g.local_var() - lg = lg.local_var() - # Edge-gated graph convolution update on crystal graph - x, m = self.node_update(g, x, y) - - # Edge-gated graph convolution update on crystal graph - y, z = self.edge_update(lg, m, z) +class AtomGraphBlock(nn.Module): + """ + Atom Graph Block that processes atom-centric features and uses GraphConv for updates. + """ - return x, y, z + def __init__(self, in_feats, out_feats, n_layers=2, hidden_features=64): + super(AtomGraphBlock, self).__init__() + self.layers = nn.ModuleList( + [ + GraphConv( + in_feats if i == 0 else out_feats, + out_feats, + hidden_features=hidden_features, + ) + for i in range(n_layers) + ] + ) + def forward(self, g, node_feats, bond_feats): + for layer in self.layers: + node_feats, bond_feats = layer(g, node_feats, bond_feats) + return node_feats, bond_feats -class AtomWise(nn.Module): - """A class representing an interatomic potential.""" +class BondGraphBlock(nn.Module): + """ + Bond Graph Block that applies additional processing on bond-based features. 
+ """ - __version__ = 3 + def __init__(self, in_feats, out_feats, n_layers=2, hidden_features=64): + super(BondGraphBlock, self).__init__() + # self.fc = nn.Linear(in_feats, out_feats) # Linear transformation for bond features + # self.activation = activation + self.layers = nn.ModuleList( + [ + GraphConv( + in_feats if i == 0 else out_feats, + out_feats, + hidden_features=hidden_features, + ) + for i in range(n_layers) + ] + ) - def __init__( - self, - model: nn.Module, - data_mean: torch.Tensor | float = 0.0, - data_std: torch.Tensor | float = 1.0, - #element_refs: np.ndarray | None = None, - calc_forces: bool = True, - calc_stresses: bool = True, - calc_hessian: bool = False, - calc_magmom: bool = False, - calc_repuls: bool = False, - zbl_trainable: bool = False, - debug_mode: bool = False, - ): - """Initialize Potential from a model and elemental references. - - Args: - model: Model for predicting energies. - data_mean: Mean of target. - data_std: Std dev of target. - element_refs: Element reference values for each element. - calc_forces: Enable force calculations. - calc_stresses: Enable stress calculations. - calc_hessian: Enable hessian calculations. - calc_magmom: Enable site-wise property calculation. - calc_repuls: Whether the ZBL repulsion is included - zbl_trainable: Whether zbl repulsion is trainable - debug_mode: Return gradient of total energy with respect to atomic positions and lattices for checking + def forward(self, g, bond_feats, angle_feats): """ - super().__init__() - self.save_args(locals()) - self.model = model - self.calc_forces = calc_forces - self.calc_stresses = calc_stresses - self.calc_hessian = calc_hessian - self.calc_magmom = calc_magmom - #self.element_refs: AtomRef | None - self.debug_mode = debug_mode - self.calc_repuls = calc_repuls - - if calc_repuls: - self.repuls = NuclearRepulsion(self.model.cutoff, trainable=zbl_trainable) - - if element_refs is not None: - self.element_refs = AtomRef(property_offset=torch.tensor(element_refs, dtype=matgl.float_th)) - else: - self.element_refs = None - # for backward compatibility - if data_mean is None: - data_mean = 0.0 - self.register_buffer("data_mean", torch.tensor(data_mean, dtype=matgl.float_th)) - self.register_buffer("data_std", torch.tensor(data_std, dtype=matgl.float_th)) - - def forward( - self, - g: dgl.DGLGraph, - lat: torch.Tensor, - lg: dgl.DGLGraph | None = None, - ) -> tuple[torch.Tensor, ...]: - """Args: - g: DGL graph - lat: lattice - state_attr: State attrs - l_g: Line graph. - - Returns: - (energies, forces, stresses, hessian) or (energies, forces, stresses, hessian, site-wise properties) + Process bond features with smooth transformations. 
""" - # st (strain) for stress calculations - result = {} - #st = lat.new_zeros([g.batch_size, 3, 3]) - #if self.calc_stresses: - # st.requires_grad_(True) - lattice = lat @ (torch.eye(3, device=lat.device) + st) - g.edata["lattice"] = torch.repeat_interleave(lattice, g.batch_num_edges(), dim=0) - g.edata["pbc_offshift"] = (g.edata["images"].unsqueeze(dim=-1) * g.edata["lattice"]).sum(dim=1) - g.ndata["pos"] = ( - g.ndata["frac_coords"].unsqueeze(dim=-1) * torch.repeat_interleave(lattice, g.batch_num_nodes(), dim=0) - ).sum(dim=1) - if self.calc_forces: - g.ndata["pos"].requires_grad_(True) - - total_energies = self.model(g,lg) - - total_energies = self.data_std * total_energies + self.data_mean - - if self.calc_repuls: - total_energies += self.repuls(self.model.element_types, g) - - if self.element_refs is not None: - property_offset = torch.squeeze(self.element_refs(g)) - total_energies += property_offset - - forces = torch.zeros(1) - stresses = torch.zeros(1) - hessian = torch.zeros(1) - - grad_vars = [g.ndata["pos"], st] if self.calc_stresses else [g.ndata["pos"]] - - if self.calc_forces: - grads = grad( - total_energies, - grad_vars, - grad_outputs=torch.ones_like(total_energies), - create_graph=True, - retain_graph=True, - ) - forces = -grads[0] - - if self.calc_hessian: - r = -grads[0].view(-1) - s = r.size(0) - hessian = total_energies.new_zeros((s, s)) - for iatom in range(s): - tmp = grad([r[iatom]], g.ndata["pos"], retain_graph=iatom < s)[0] - if tmp is not None: - hessian[iatom] = tmp.view(-1) - - if self.calc_stresses: - volume = ( - torch.abs(torch.det(lattice.float())).half() - if matgl.float_th == torch.float16 - else torch.abs(torch.det(lattice)) - ) - sts = -grads[1] - scale = 1.0 / volume * -160.21766208 - sts = [i * j for i, j in zip(sts, scale)] if sts.dim() == 3 else [sts * scale] - stresses = torch.cat(sts) - - if self.debug_mode: - return total_energies, grads[0], grads[1] - - if self.calc_magmom: - return total_energies, forces, stresses, hessian, g.ndata["magmom"] - result['out']=total_energies - result['grad']=forces - result['stresses']=stresses - result['atomwise_pred']=atomwise_pred - - return result - - -class MLPLayer(nn.Module): - """Multilayer perceptron layer helper.""" - - def __init__(self, in_features: int, out_features: int): - """Linear, Batchnorm, SiLU layer.""" - super().__init__() - self.layer = nn.Sequential( - nn.Linear(in_features, out_features), - nn.LayerNorm(out_features), - nn.SiLU(), - ) - - def forward(self, x): - """Linear, Batchnorm, silu layer.""" - # print('xtype',x.dtype) - return self.layer(x) + # Transform bond features and apply smooth activation + for layer in self.layers: + bond_feats, angle_feats = layer(g, bond_feats, angle_feats) + return bond_feats, angle_feats class ALIGNNFF2(nn.Module): @@ -366,7 +164,6 @@ def __init__( ): """Initialize class with number of input features, conv layers.""" super().__init__() - # print(config) self.classification = config.classification self.config = config if self.config.gradwise_weight == 0: @@ -376,73 +173,57 @@ def __init__( self.atom_embedding = MLPLayer( config.atom_input_features, config.hidden_features ) - if self.config.bond_exp_basis == "bessel": - self.bond_expansion = RadialBesselFunction(max_n=config.max_n, cutoff=config.inner_cutoff, learnable=False) - self.edge_embedding = MLP_norm([config.max_n, config.hidden_features],bias_last=False) - #self.edge_embedding = MLP_norm([config.edge_input_features, config.hidden_features],bias_last=False) - #self.bond_expansion = 
RadialBesselFunction(max_n=config.edge_input_features, cutoff=config.inner_cutoff, learnable=True) - else: - self.edge_embedding = nn.Sequential( - RBFExpansion( - vmin=0, - vmax=8.0, - bins=config.edge_input_features, - ), - MLPLayer( - config.edge_input_features, config.embedding_features - ), - MLPLayer(config.embedding_features, config.hidden_features), - ) - if self.config.angle_exp_basis == "spherical": - self.angle_embedding = nn.Sequential( - SphericalHarmonicsExpansion(), - MLPLayer( - config.triplet_input_features, config.embedding_features - ), - MLPLayer(config.embedding_features, config.hidden_features), - ) # not tested - elif self.config.angle_exp_basis == "bessel": - self.angle_expansion = FourierExpansion(max_f=config.max_f, learnable=False) - self.angle_embedding = MLP_norm([2*config.max_f+1, config.hidden_features],bias_last=False) - elif self.config.angle_exp_basis == "fourier": - self.angle_embedding = nn.Sequential( - FourierExpansion(), - MLPLayer( - config.triplet_input_features, config.embedding_features - ), - MLPLayer(config.embedding_features, config.hidden_features), - ) # not tested - else: - self.angle_embedding = nn.Sequential( - RBFExpansion( - vmin=-1.0, - vmax=1.0, - bins=config.triplet_input_features, - ), - MLPLayer( - config.triplet_input_features, config.embedding_features - ), - MLPLayer(config.embedding_features, config.hidden_features), - ) + self.edge_embedding = nn.Sequential( + RadialBesselFunction( + max_n=config.edge_input_features, cutoff=config.inner_cutoff + ), + # RBFExpansionSmooth(num_centers=config.edge_input_features, cutoff=config.inner_cutoff, sigma=config.sigma), + MLPLayer(config.edge_input_features, config.embedding_features), + MLPLayer(config.embedding_features, config.hidden_features), + ) + self.angle_embedding = nn.Sequential( + RadialBesselFunction( + max_n=config.edge_input_features, cutoff=config.inner_cutoff + ), + # RBFExpansionSmooth(num_centers=config.triplet_input_features, cutoff=1.0, sigma=config.sigma), + MLPLayer(config.edge_input_features, config.embedding_features), + MLPLayer(config.embedding_features, config.hidden_features), + ) + + self.atom_graph_layers = nn.ModuleList( + [ + AtomGraphBlock( + config.hidden_features, + config.hidden_features, + n_layers=config.gcn_layers, + hidden_features=config.hidden_features, + ) + ] + ) - self.alignn_layers = nn.ModuleList( + self.bond_graph_layers = nn.ModuleList( [ - ALIGNNConv( + BondGraphBlock( config.hidden_features, config.hidden_features, + n_layers=config.gcn_layers, + hidden_features=config.hidden_features, ) - for idx in range(config.alignn_layers) ] ) - self.gcn_layers = nn.ModuleList( + + self.angle_graph_layers = nn.ModuleList( [ - EdgeGatedGraphConv( - config.hidden_features, config.hidden_features + BondGraphBlock( + config.hidden_features, + config.hidden_features, + hidden_features=config.hidden_features, ) - for idx in range(config.gcn_layers) ] ) + self.gnorm = GCNNorm() + self.readout = AvgPooling() if config.extra_features != 0: @@ -479,55 +260,35 @@ def __init__( else: self.fc = nn.Linear(config.hidden_features, config.output_features) - def forward( - self, g: Union[Tuple[dgl.DGLGraph, dgl.DGLGraph], dgl.DGLGraph] - ): - """ALIGNN : start with `atom_features`. 
- - x: atom features (g.ndata) - y: bond features (g.edata and lg.ndata) - z: angle features (lg.edata) - """ + def forward(self, g): result = {} - if len(self.alignn_layers) > 0: - g, lg = g + if self.config.alignn_layers > 0: + g, lg, lat = g lg = lg.local_var() + # print('lattice',lattice,lattice.shape) + else: + g, lat = g + if self.config.extra_features != 0: features = g.ndata["extra_features"] features = self.extra_feature_embedding(features) x = g.ndata.pop("atom_features") x = self.atom_embedding(x) - check_lg=True - g.edata["pbc_offshift"] = (g.edata["images"].unsqueeze(dim=-1) * g.edata["lattice"]).sum(dim=1) - g.ndata["cart_coords"] = (g.ndata["frac_coords"].unsqueeze(dim=-1) * g.ndata["lattice"][0]).sum(dim=1) + + g = self.gnorm(g) + # Compute and embed bond lengths + g.ndata["cart_coords"] = compute_cartesian_coordinates(g, lat) if self.config.calculate_gradient: - #g.edata["images"] = (g.edata["images"].unsqueeze(dim=-1) * g.edata["lattice"]).sum(dim=1) - #torch.repeat_interleave(lattice, g.batch_num_nodes(), dim=0)).sum(dim=1) g.ndata["cart_coords"].requires_grad_(True) + r, bondlength = compute_pair_vector_and_distance(g) bondlength = torch.norm(r, dim=1) - g.edata["d"] = bondlength - g.edata["r"] = r - # bond_expansion = self.bond_expansion(bondlength) - # z = self.angle_embedding(lg.edata.pop("h")) - lg = check_line_graph(g, lg, self.config.inner_cutoff) - lg.apply_edges(compute_bond_cosines) + y = self.edge_embedding(bondlength) # smooth_cutoff = polynomial_cutoff( - # bond_expansion, self.config.inner_cutoff, self.config.exponent + # bond_expansion, self.config.inner_cutoff, self.config.exponent # ) # bond_expansion *= smooth_cutoff - # g.edata["bond_expansion"] = ( - # bond_expansion # smooth_cutoff * bond_expansion - # ) - - if self.config.bond_exp_basis=='bessel': - z = self.angle_embedding(self.angle_expansion(lg.edata.pop("h"))) - y = self.edge_embedding(self.bond_expansion(bondlength)) - else: - y = self.edge_embedding(bondlength) - z = self.angle_embedding(lg.edata.pop("h")) - if self.config.use_cutoff_function: if self.config.multiply_cutoff: c_off = cutoff_function_based_edges( @@ -545,18 +306,14 @@ def forward( ) y = self.edge_embedding(bondlength) else: - #print('bondlength',bondlength,bondlength.shape) y = self.edge_embedding(bondlength) - #""" - # ALIGNN updates: update node, edge, triplet features - for alignn_layer in self.alignn_layers: - x, y, z = alignn_layer(g, lg, x, y, z) - - # gated GCN updates: update node, edge features - for gcn_layer in self.gcn_layers: - x, y = gcn_layer(g, x, y) - # norm-activation-pool-classify - out = torch.empty(1) + out = torch.empty(1) # graph level output eg energy + lg = g.line_graph(shared=True) + lg.ndata["r"] = r + lg.apply_edges(compute_bond_cosines) + for atom_graph_layer in self.atom_graph_layers: + x, y = atom_graph_layer(g, x, y) + if self.config.output_features is not None: h = self.readout(g, x) out = self.fc(h) @@ -580,120 +337,40 @@ def forward( # gradient = torch.empty(1) stress = torch.empty(1) - if self.config.calculate_gradient: - en_out = out - #g.edata["images"] = (g.edata["images"].unsqueeze(dim=-1) * g.edata["lattice"]).sum(dim=1) - #torch.repeat_interleave(lattice, g.batch_num_nodes(), dim=0)).sum(dim=1) - #g.ndata["cart_coords"].requires_grad_(True) - grad_vars = [g.ndata["cart_coords"]] - grads = grad(g.num_nodes()*en_out,grad_vars,grad_outputs=torch.ones_like(en_out),create_graph=True,retain_graph=True) - forces_out = -grads[0] - - # force calculation based on bond displacement vectors - # 
autograd gives dE / d{r_{i->j}}
-            pair_forces = (
-                self.config.grad_multiplier
-                * grad(
-                    en_out,
-                    r,
-                    grad_outputs=torch.ones_like(en_out),
-                    create_graph=True,
-                    retain_graph=True,
-                )[0]
+        if self.config.use_penalty:
+            penalty_factor = 0.01  # Penalty weight, tune as needed
+            penalty_threshold = 1.0  # 1 Angstrom
+
+            penalties = torch.where(
+                bondlength < penalty_threshold,
+                penalty_factor * (penalty_threshold - bondlength),
+                torch.zeros_like(bondlength),
             )
-            if self.config.force_mult_natoms:
-                pair_forces *= g.num_nodes()
+            total_penalty = torch.sum(penalties)
+            out += total_penalty
 
-            # construct force_i = dE / d{r_i}
-            # reduce over bonds to get forces on each atom
+        if self.config.calculate_gradient:
 
-            # force_i contributions from r_{j->i} (in edges)
-            g.edata["pair_forces"] = pair_forces
-            g.update_all(
-                fn.copy_e("pair_forces", "m"), fn.sum("m", "forces_ji")
+            # en_out = torch.sum(out)*g.num_nodes()
+            en_out = out  # *g.num_nodes()
+            # en_out = (out) *g.num_nodes()
+            grad_vars = [g.ndata["cart_coords"]]
+            grads = grad(
+                en_out,
+                grad_vars,
+                grad_outputs=torch.ones_like(en_out),
+                create_graph=True,
+                retain_graph=True,
             )
-            if self.config.add_reverse_forces:
-                # reduce over reverse edges too!
-                # force_i contributions from r_{i->j} (out edges)
-                # aggregate pairwise_force_contributions over reversed edges
-                rg = dgl.reverse(g, copy_edata=True)
-                rg.update_all(
-                    fn.copy_e("pair_forces", "m"), fn.sum("m", "forces_ij")
-                )
-
-                # combine dE / d(r_{j->i}) and dE / d(r_{i->j})
-                forces = torch.squeeze(
-                    g.ndata["forces_ji"] - rg.ndata["forces_ij"]
-                )
-            else:
-                forces = torch.squeeze(g.ndata["forces_ji"])
-            # print('forces',forces)
-
-            if self.config.stresswise_weight != 0:
-                # Under development, use with caution
-                # 1 eV/Angstrom3 = 160.21766208 GPa
-                # 1 GPa = 10 kbar
-                # Following Virial stress formula, assuming inital velocity = 0
-                # Save volume as g.gdta['V']?
- # print('pair_forces',pair_forces.shape) - # print('r',r.shape) - # print('g.ndata["V"]',g.ndata["V"].shape) - if not self.config.batch_stress: - # print('Not batch_stress') - stress = ( - -1 - * 160.21766208 - * ( - torch.matmul(r.T, pair_forces) - # / (2 * g.edata["V"]) - / (2 * g.ndata["V"][0]) - ) - ) - # print("stress1", stress, stress.shape) - # print("g.batch_size", g.batch_size) - else: - # print('Using batch_stress') - stresses = [] - count_edge = 0 - count_node = 0 - for graph_id in range(g.batch_size): - num_edges = g.batch_num_edges()[graph_id] - num_nodes = 0 - st = -1 * ( - 160.21766208 - * torch.matmul( - r[count_edge : count_edge + num_edges].T, - pair_forces[ - count_edge : count_edge + num_edges - ], - ) - / g.ndata["V"][count_node + num_nodes] - ) - - count_edge = count_edge + num_edges - num_nodes = g.batch_num_nodes()[graph_id] - count_node = count_node + num_nodes - # print("stresses.append",stresses[-1],stresses[-1].shape) - for n in range(num_nodes): - stresses.append(st) - # stress = (stresses) - stress = self.config.stress_multiplier * torch.cat( - stresses - ) - # print("stress2", stress, stress.shape) - # virial = ( - # 160.21766208 - # * 10 - # * torch.einsum("ij, ik->jk", result["r"], result["dy_dr"]) - # / 2 - # ) # / ( g.ndata["V"][0]) + forces_out = -1 * grads[0] * g.num_nodes() + # forces_out = -1*grads[0] + stresses = torch.eye(3) if self.classification: - # out = torch.max(out,dim=1) out = self.softmax(out) result["out"] = out result["grad"] = forces_out result["stresses"] = stress result["atomwise_pred"] = atomwise_pred - # print(result) return result diff --git a/alignn/models/utils.py b/alignn/models/utils.py index af7e3b5..4d8c58c 100644 --- a/alignn/models/utils.py +++ b/alignn/models/utils.py @@ -3,15 +3,155 @@ from typing import Optional import numpy as np import torch - -# from torch import nn from math import pi import torch.nn as nn - -# from scipy.special import spherical_jn -# from scipy.special import sph_harm, lpmv import math import dgl +import torch +from typing import Any, Callable, Literal, cast + + +def get_atomic_repulsion(g, cutoff=5.0): + """ + Calculate atomic repulsion energy using pairwise Coulomb interactions within a cutoff distance. + + Parameters: + g (DGLGraph): ALIGNN graph with atom charges (Z) and precomputed bond lengths in g.edata['d']. + cutoff (float): Cutoff distance for pairwise interactions. + + Returns: + float: Atomic repulsion energy for the given graph. + """ + + # Atomic charges + Z = g.ndata["Z"].squeeze() # Ensure Z is a 1D tensor + bond_lengths = g.edata[ + "d" + ] # Precomputed bond lengths in Cartesian coordinates + + # Atomic indices for each edge + src, dst = g.edges() + + # Mask for distances below the cutoff + valid_edges = bond_lengths < cutoff + + # Get charges for each pair + Zi = Z[src[valid_edges]] + Zj = Z[dst[valid_edges]] + rij = bond_lengths[valid_edges] + + # Compute repulsion energy + repulsion_energy = torch.sum(Zi * Zj / rij) + + return repulsion_energy + + +class RadialBesselFunction(nn.Module): + + def __init__( + self, + max_n: int, + cutoff: float, + learnable: bool = False, + dtype=torch.float32, + ): + """ + Args: + max_n: int, max number of roots (including max_n) + cutoff: float, cutoff radius + learnable: bool, whether to learn the location of roots. 
+ """ + super().__init__() + self.max_n = max_n + self.inv_cutoff = 1 / cutoff + self.norm_const = (2 * self.inv_cutoff) ** 0.5 + if learnable: + self.frequencies = torch.nn.Parameter( + data=torch.Tensor( + pi * torch.arange(1, self.max_n + 1, dtype=dtype) + ), + requires_grad=True, + ) + else: + self.register_buffer( + "frequencies", + pi * torch.arange(1, self.max_n + 1, dtype=dtype), + ) + + def forward(self, r: torch.Tensor) -> torch.Tensor: + r = r[:, None] # (nEdges,1) + d_scaled = r * self.inv_cutoff + return self.norm_const * torch.sin(self.frequencies * d_scaled) / r + + +def get_ewald_sum(g, lattice_mat, alpha=0.2, r_cut=10.0, k_cut=5): + """ + Calculate the Ewald sum energy for the DGL graph using precomputed rij vectors. + + Parameters: + g (DGLGraph): ALIGNN graph with atom features, fractional coordinates, and precomputed rij vectors. + alpha (float): Ewald splitting parameter, controls the balance between real and reciprocal space sums. + r_cut (float): Real-space cutoff distance for pairwise interactions. + k_cut (int): Reciprocal-space cutoff for Fourier components. + + Returns: + float: Ewald sum energy for the given graph. + """ + + # Atomic numbers (charges) and fractional coordinates + Z = g.ndata["Z"] # Atomic charges (assuming Z is atomic number) + cart_pos = g.ndata[ + "frac_coords" + ] # Fractional coordinates in Cartesian space + r_ij_vectors = g.edata[ + "r" + ] # Precomputed rij vectors in Cartesian coordinates + + # Initialize Ewald sum energy + ewald_energy = 0.0 + + # Real-space sum using precomputed rij vectors + src, dst = g.edges() # Get the source and destination nodes for each edge + for edge_idx in range(len(src)): + i = src[edge_idx] + j = dst[edge_idx] + + # Pairwise distance (norm of r_ij) + r = torch.norm(r_ij_vectors[edge_idx]) + + if r < r_cut: + ewald_energy += Z[i] * Z[j] * torch.erfc(alpha * r) / r + + # Reciprocal-space sum + # lattice_mat = g.ndata['lattice_mat'][0] # Assuming lattice matrix is uniform across nodes + recip_vectors = ( + 2 * pi * torch.inverse(lattice_mat).T + ) # Reciprocal lattice vectors + for h in range(-k_cut, k_cut + 1): + for k in range(-k_cut, k_cut + 1): + for l in range(-k_cut, k_cut + 1): + if h == 0 and k == 0 and l == 0: + continue + k_vec = ( + h * recip_vectors[:, 0] + + k * recip_vectors[:, 1] + + l * recip_vectors[:, 2] + ) + k_sq = torch.dot(k_vec, k_vec) + structure_factor = torch.sum( + Z * torch.exp(1j * torch.matmul(cart_pos, k_vec)) + ) + ewald_energy += ( + torch.exp(-k_sq / (4 * alpha**2)) / k_sq + ) * (torch.norm(structure_factor) ** 2) + + # Self-interaction correction + # print('Z',Z) + # ewald_energy -= alpha / torch.sqrt(pi) * torch.sum(Z ** 2) + ewald_energy -= ( + alpha / torch.sqrt(torch.tensor(torch.pi)) * torch.sum(Z**2) + ) + return ewald_energy.real # Return the real part of the energy class BesselExpansion(nn.Module): @@ -267,8 +407,8 @@ def _k(self, l_x: int, m: int) -> float: def compute_pair_vector_and_distance(g: dgl.DGLGraph): """Calculate bond vectors and distances using dgl graphs.""" - #print('g.edges()',g.ndata["cart_coords"][g.edges()[1]].shape,g.edata["pbc_offshift"].shape) - dst_pos = g.ndata["cart_coords"][g.edges()[1]] + g.edata["pbc_offshift"] + # print('g.edges()',g.ndata["cart_coords"][g.edges()[1]].shape,g.edata["pbc_offshift"].shape) + dst_pos = g.ndata["cart_coords"][g.edges()[1]] + g.edata["images"] src_pos = g.ndata["cart_coords"][g.edges()[0]] bond_vec = dst_pos - src_pos bond_dist = torch.norm(bond_vec, dim=1) @@ -348,3 +488,379 @@ def 
cutoff_function_based_edges(r, inner_cutoff=4, exponent=3): # envelope = (r_cut_sq - r_sq) # ** 2 * (r_cut_sq + 2 * r_sq - 3 * r_on_sq)/ (r_cut_sq - r_on_sq) ** 3 return torch.where(r <= inner_cutoff, envelope, torch.zeros_like(r)) + + +def compute_cartesian_coordinates(g, lattice, dtype=torch.float32): + """ + Compute Cartesian coordinates from fractional coordinates and lattice matrices. + + Args: + g: DGL graph with 'frac_coords' as node data. + lattice: Tensor of shape (B, 3, 3), where B is the batch size. + dtype: Torch dtype to ensure consistent tensor types. + + Returns: + Tensor of Cartesian coordinates with shape (N, 3). + """ + # Get fractional coordinates (N, 3) and ensure correct dtype + frac_coords = g.ndata["frac_coords"].to(dtype) + + # Ensure lattice is 3D with shape (B, 3, 3) and correct dtype + if lattice.dim() == 2: # If shape is (3, 3), expand to (1, 3, 3) + lattice = lattice.unsqueeze(0).to(dtype) + else: + lattice = lattice.to(dtype) + + # Generate batch indices to map nodes to their corresponding graph + batch_indices = torch.repeat_interleave( + torch.arange(len(lattice), device=frac_coords.device), + g.batch_num_nodes(), + ) + + # Expand lattice matrices based on batch indices to match node count + expanded_lattice = lattice[batch_indices] # Shape: (N, 3, 3) + + # Perform batched matrix multiplication to get Cartesian coordinates + cart_coords = torch.bmm( + frac_coords.unsqueeze(1), # Shape: (N, 1, 3) + expanded_lattice, # Shape: (N, 3, 3) + ).squeeze( + 1 + ) # Shape: (N, 3) + + return cart_coords + + +class RBFExpansionSmooth(nn.Module): + """ + RBF Expansion layer for bond lengths with smooth output variation. + """ + + def __init__(self, num_centers=10, cutoff=5.0, sigma=0.5): + super(RBFExpansionSmooth, self).__init__() + + # Initialize centers and sigma for Gaussian RBFs + self.cutoff = cutoff + self.sigma = sigma + self.centers = torch.linspace(0, cutoff, num_centers).view( + 1, -1 + ) # Shape (1, num_centers) + + def forward(self, bondlengths): + """ + Compute the RBF features for a batch of bond lengths. + + Parameters: + - bondlengths: Tensor of shape (batch_size,) containing bond lengths. + + Returns: + - RBF expanded features: Tensor of shape (batch_size, num_centers) with smoothly varying RBFs. 
+ """ + # Reshape bondlengths to (batch_size, 1) for broadcasting + bondlengths = bondlengths.view(-1, 1) # Shape (batch_size, 1) + + # Calculate RBF values + rbf_features = torch.exp( + -((bondlengths - self.centers.to(bondlengths.device)) ** 2) + / (2 * self.sigma**2) + ) + + # Apply cutoff + mask = bondlengths <= self.cutoff + rbf_features = ( + rbf_features * mask.float() + ) # Mask to zero out beyond cutoff + + return rbf_features + + +class MLPLayer(nn.Module): + """Multilayer perceptron layer helper.""" + + def __init__(self, in_features: int, out_features: int): + """Linear, Batchnorm, SiLU layer.""" + super().__init__() + self.layer = nn.Sequential( + nn.Linear(in_features, out_features), + nn.LayerNorm(out_features), + nn.SiLU(), + ) + + def forward(self, x): + """Linear, Batchnorm, silu layer.""" + # print('xtype',x.dtype) + return self.layer(x) + + +def _create_directed_line_graph( + graph: dgl.DGLGraph, threebody_cutoff: float +) -> dgl.DGLGraph: + with torch.no_grad(): + pg = prune_edges_by_features( + graph, + feat_name="bond_dist", + condition=lambda x: torch.gt(x, threebody_cutoff), + ) + """ + lg=graph.line_graph(shared=True) + lg.ndata["src_bond_sign"] = torch.ones( + (lg.number_of_nodes(), 1), + dtype=lg.ndata["bond_vec"].dtype, + device=lg.device, + ) + return lg + """ + src_indices, dst_indices = pg.edges() + images = pg.edata["images"] + all_indices = torch.arange( + pg.number_of_nodes(), device=graph.device + ).unsqueeze(dim=0) + num_bonds_per_atom = torch.count_nonzero( + src_indices.unsqueeze(dim=1) == all_indices, dim=0 + ) + num_edges_per_bond = (num_bonds_per_atom - 1).repeat_interleave( + num_bonds_per_atom + ) + lg_src = torch.empty( + num_edges_per_bond.sum(), dtype=torch.int64, device=graph.device + ) + lg_dst = torch.empty( + num_edges_per_bond.sum(), dtype=torch.int64, device=graph.device + ) + incoming_edges = src_indices.unsqueeze(1) == dst_indices + is_self_edge = src_indices == dst_indices + not_self_edge = ~is_self_edge + + n = 0 + # create line graph edges for bonds that are self edges in atom graph + if is_self_edge.any(): + edge_inds_s = is_self_edge.nonzero() + lg_dst_s = edge_inds_s.repeat_interleave( + num_edges_per_bond[is_self_edge] + 1 + ) + lg_src_s = incoming_edges[is_self_edge].nonzero()[:, 1].squeeze() + lg_src_s = lg_src_s[lg_src_s != lg_dst_s] + lg_dst_s = edge_inds_s.repeat_interleave( + num_edges_per_bond[is_self_edge] + ) + n = len(lg_dst_s) + lg_src[:n], lg_dst[:n] = lg_src_s, lg_dst_s + + # create line graph edges for bonds that are not self edges in atom graph + shared_src = src_indices.unsqueeze(1) == src_indices + back_tracking = (dst_indices.unsqueeze(1) == src_indices) & torch.all( + -images.unsqueeze(1) == images, axis=2 + ) + incoming = incoming_edges & (shared_src | ~back_tracking) + + edge_inds_ns = not_self_edge.nonzero().squeeze() + lg_src_ns = incoming[not_self_edge].nonzero()[:, 1].squeeze() + lg_dst_ns = edge_inds_ns.repeat_interleave( + num_edges_per_bond[not_self_edge] + ) + lg_src[n:], lg_dst[n:] = lg_src_ns, lg_dst_ns + lg = dgl.graph((lg_src, lg_dst)) + + for key in pg.edata: + lg.ndata[key] = pg.edata[key][: lg.number_of_nodes()] + + # we need to store the sign of bond vector when a bond is a src node in the line + # graph in order to appropriately calculate angles when self edges are involved + lg.ndata["src_bond_sign"] = torch.ones( + (lg.number_of_nodes(), 1), + dtype=lg.ndata["bond_vec"].dtype, + device=lg.device, + ) + # if we flip self edges then we need to correct computed angles by pi - angle + # 
lg.ndata["src_bond_sign"][edge_inds_s] = -lg.ndata["src_bond_sign"][edge_ind_s] + # find the intersection for the rare cases where not all edges end up as nodes in the line graph + all_ns, counts = torch.cat( + [ + torch.arange(lg.number_of_nodes(), device=graph.device), + edge_inds_ns, + ] + ).unique(return_counts=True) + lg_inds_ns = all_ns[torch.where(counts > 1)] + lg.ndata["src_bond_sign"][lg_inds_ns] = -lg.ndata["src_bond_sign"][ + lg_inds_ns + ] + + return lg + + +def prune_edges_by_features( + graph: dgl.DGLGraph, + feat_name: str, + condition: Callable[[torch.Tensor], torch.Tensor], + keep_ndata: bool = False, + keep_edata: bool = True, + *args, + **kwargs, +) -> dgl.DGLGraph: + if feat_name not in graph.edata: + raise ValueError( + f"Edge field {feat_name} not an edge feature in given graph." + ) + + valid_edges = torch.logical_not( + condition(graph.edata[feat_name], *args, **kwargs) + ) + valid_edges1 = torch.ones( + graph.num_edges(), dtype=torch.bool, device=graph.device + ) + # print('valid_edges',valid_edges,valid_edges.shape) + # print('valid_edges1',valid_edges1,valid_edges1.shape) + + src, dst = graph.edges() + src, dst = src[valid_edges], dst[valid_edges] + e_ids = valid_edges.nonzero().squeeze() + new_g = dgl.graph((src, dst), device=graph.device) + new_g.edata["edge_ids"] = e_ids # keep track of original edge ids + + if keep_ndata: + for key, value in graph.ndata.items(): + new_g.ndata[key] = value + if keep_edata: + for key, value in graph.edata.items(): + new_g.edata[key] = value[valid_edges] + + return new_g + + +def compute_theta( + edges: dgl.udf.EdgeBatch, + cosine: bool = False, + directed: bool = True, + eps=1e-7, +) -> dict[str, torch.Tensor]: + """User defined dgl function to calculate bond angles from edges in a graph. + + Args: + edges: DGL graph edges + cosine: Whether to return the cosine of the angle or the angle itself + directed: Whether to the line graph was created with create directed line graph. + In which case bonds (only those that are not self bonds) need to + have their bond vectors flipped. + eps: eps value used to clamp cosine values to avoid acos of values > 1.0 + + Returns: + dict[str, torch.Tensor]: Dictionary containing bond angles and distances + """ + vec1 = ( + edges.src["bond_vec"] * edges.src["src_bond_sign"] + if directed + else edges.src["bond_vec"] + ) + vec2 = edges.dst["bond_vec"] + key = "cos_theta" if cosine else "theta" + val = torch.sum(vec1 * vec2, dim=1) / ( + torch.norm(vec1, dim=1) * torch.norm(vec2, dim=1) + ) + val = val.clamp_( + min=-1 + eps, max=1 - eps + ) # stability for floating point numbers > 1.0 + if not cosine: + val = torch.acos(val) + return {key: val, "triple_bond_lengths": edges.dst["bond_dist"]} + + +def create_line_graph( + g: dgl.DGLGraph, threebody_cutoff: float, directed: bool = False +) -> dgl.DGLGraph: + """ + Calculate the three body indices from pair atom indices. 
+
+    Args:
+        g: DGL graph
+        threebody_cutoff (float): cutoff for three-body interactions
+        directed (bool): Whether to create a directed line graph
+            or an M3GNet 3-body line graph. Default = False (M3GNet)
+
+    Returns:
+        l_g: DGL graph containing three body information from graph
+    """
+    graph_with_three_body = prune_edges_by_features(
+        g, feat_name="bond_dist", condition=lambda x: x > threebody_cutoff
+    )
+    if directed:
+        # lg = g.line_graph(shared=True)
+        # return lg
+        lg = _create_directed_line_graph(
+            graph_with_three_body, threebody_cutoff
+        )
+    else:
+        lg = _compute_3body(graph_with_three_body)
+
+    return lg
+
+
+def compute_pair_vector_and_distance(g: dgl.DGLGraph):
+    """Calculate bond vectors and distances using dgl graphs.
+
+    Args:
+        g: DGL graph
+
+    Returns:
+        bond_vec (torch.tensor): vector from src node to dst node
+        bond_dist (torch.tensor): bond distance between the two atoms
+    """
+    dst_pos = g.ndata["pos"][g.edges()[1]] + g.edata["images"]
+    src_pos = g.ndata["pos"][g.edges()[0]]
+    bond_vec = dst_pos - src_pos
+    bond_dist = torch.norm(bond_vec, dim=1)
+
+    return bond_vec, bond_dist
+
+
+def polynomial_cutoff(
+    r: torch.Tensor, cutoff: float, exponent: int = 3
+) -> torch.Tensor:
+    """Envelope polynomial function that ensures a smooth cutoff.
+
+    Ensures first and second derivative vanish at cutoff. As described in:
+    https://arxiv.org/abs/2003.03123
+
+    Args:
+        r (torch.Tensor): radius distance tensor
+        cutoff (float): cutoff distance.
+        exponent (int): minimum exponent of the polynomial. Default is 3.
+            The polynomial includes terms of order exponent, exponent + 1, exponent + 2.
+
+    Returns: polynomial cutoff function
+    """
+    coef1 = -(exponent + 1) * (exponent + 2) / 2
+    coef2 = exponent * (exponent + 2)
+    coef3 = -exponent * (exponent + 1) / 2
+    ratio = r / cutoff
+    poly_envelope = (
+        1
+        + coef1 * ratio**exponent
+        + coef2 * ratio ** (exponent + 1)
+        + coef3 * ratio ** (exponent + 2)
+    )
+
+    return torch.where(r <= cutoff, poly_envelope, 0.0)
+
+
+if __name__ == "__main__":
+    from jarvis.core.atoms import Atoms
+    from alignn.graphs import radius_graph_jarvis
+
+    FIXTURES = {
+        "lattice_mat": [
+            [2.715, 2.715, 0],
+            [0, 2.715, 2.715],
+            [2.715, 0, 2.715],
+        ],
+        "coords": [[0, 0, 0], [0.25, 0.25, 0.25]],
+        "elements": ["Si", "Si"],
+    }
+    Si = Atoms(
+        lattice_mat=FIXTURES["lattice_mat"],
+        coords=FIXTURES["coords"],
+        elements=FIXTURES["elements"],
+    )
+    g, lg = radius_graph_jarvis(
+        atoms=Si, cutoff=5, atom_features="atomic_number"
+    )
+    ewald = get_ewald_sum(g, torch.tensor(Si.lattice_mat))
diff --git a/alignn/train.py b/alignn/train.py
index 24c925e..3dc71ff 100644
--- a/alignn/train.py
+++ b/alignn/train.py
@@ -13,6 +13,7 @@
 from alignn.config import TrainingConfig
 from alignn.models.alignn_atomwise import ALIGNNAtomWise
 from alignn.models.alignn_ff2 import ALIGNNFF2
+from alignn.models.alignn_eff import ALIGNNeFF
 from alignn.models.alignn import ALIGNN
 from jarvis.db.jsonutils import dumpjson
 import json
@@ -34,6 +35,8 @@
 warnings.filterwarnings("ignore", category=RuntimeWarning)
 
+torch.autograd.set_detect_anomaly(True)
+
 
 def train_dgl(
     config: Union[TrainingConfig, Dict[str, Any]],
@@ -148,6 +151,7 @@ def train_dgl(
         "alignn_atomwise": ALIGNNAtomWise,
         "alignn": ALIGNN,
         "alignn_ff2": ALIGNNFF2,
+        "alignn_eff": ALIGNNeFF,
     }
     if config.random_seed is not None:
         random.seed(config.random_seed)
@@ -169,11 +173,6 @@ def train_dgl(
         net = _model.get(config.model.name)(config.model)
     else:
         net = model
-    from matgl.models import CHGNet, M3GNet
-    from matgl.utils.training import 
ModelLightningModule, PotentialLightningModule - #model = M3GNet(element_types=['Si'], is_intensive=False) - model = CHGNet(element_types=['Si'], is_intensive=False,threebody_cutoff=4) - net = PotentialLightningModule(model=model, stress_weight=0.0001, include_line_graph=True) print("net parameters", sum(p.numel() for p in net.parameters())) # print("device", device) net.to(device) @@ -205,11 +204,11 @@ def train_dgl( optimizer, ) - if ( - config.model.name == "alignn_atomwise" - or config.model.name == "alignn_ff2" - ): - + # if ( + # config.model.name == "alignn_atomwise" + # or config.model.name == "alignn_ff2" + # ): + if "alignn_" in config.model.name: best_loss = np.inf criterion = nn.L1Loss() if classification: @@ -233,8 +232,17 @@ def train_dgl( # info["id"] = jid optimizer.zero_grad() if (config.model.alignn_layers) > 0: - result = net(dats[0].to(device), dats[2].to(device),dats[1].to(device)) - #result = net([dats[0].to(device), dats[1].to(device),lat=dats[2].to(device)]) + result = net( + [ + dats[0].to(device), + dats[1].to(device), + dats[2].to(device), + ] + ) + # result = net(dats[0].to(device), dats[2].to(device),dats[1].to(device)) + # result = net([dats[0].to(device), dats[1].to(device),lat=dats[2].to(device)]) + # batched_graph, batched_line_graph, torch.stack(lattices),torch.tensor(labels) + else: result = net(dats[0].to(device)) # info = {} @@ -351,8 +359,15 @@ def train_dgl( optimizer.zero_grad() # result = net([dats[0].to(device), dats[1].to(device)]) if (config.model.alignn_layers) > 0: - #result = net([dats[0].to(device), dats[2].to(device), dats[1].to(device)]) - result = net(dats[0].to(device), dats[2].to(device),dats[1].to(device)) + # result = net([dats[0].to(device), dats[2].to(device), dats[1].to(device)]) + # result = net(dats[0].to(device), dats[2].to(device),dats[1].to(device)) + result = net( + [ + dats[0].to(device), + dats[1].to(device), + dats[2].to(device), + ] + ) else: result = net(dats[0].to(device)) # info = {} @@ -496,7 +511,14 @@ def train_dgl( info["id"] = jid optimizer.zero_grad() if (config.model.alignn_layers) > 0: - result = net([dats[0].to(device), dats[1].to(device)]) + # result = net([dats[0].to(device), dats[1].to(device)]) + result = net( + [ + dats[0].to(device), + dats[1].to(device), + dats[2].to(device), + ] + ) else: result = net(dats[0].to(device)) loss1 = 0 # Such as energy From e445579b18b741d9a0f54312f680cff301f51cb0 Mon Sep 17 00:00:00 2001 From: knc6 Date: Sun, 3 Nov 2024 21:56:42 -0500 Subject: [PATCH 09/37] Clean up. 
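This cleanup settles on one batch layout across the LMDB dataset, the collate
functions, and the training loop: each sample is (graph, line_graph, lattice,
label), and the model is called with the first three as a single list. A
minimal sketch of consuming one collated batch under that layout; train_data,
net, and device are assumed to already exist, and the names are illustrative
only:

    import torch
    from torch.utils.data import DataLoader
    from alignn.lmdb_dataset import collate_line_graph

    loader = DataLoader(
        train_data,  # LMDB-backed dataset from get_torch_dataset(...)
        batch_size=2,
        # batches to (graphs, line_graphs, lattices, labels)
        collate_fn=collate_line_graph,
    )
    for dats in loader:
        g, lg, lat, labels = dats
        # the model now takes graph, line graph, and lattice as one list
        result = net([g.to(device), lg.to(device), lat.to(device)])
        loss = torch.nn.functional.l1_loss(result["out"], labels.to(device))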
--- alignn/__init__.py | 2 +- alignn/config.py | 4 - alignn/dataset.py | 2 + .../config_example_atomwise.json | 4 +- alignn/ff/ff.py | 7 - alignn/graphs.py | 51 +- alignn/lmdb_dataset.py | 34 +- alignn/models/alignn.py | 8 +- alignn/models/alignn_atomwise.py | 69 +- alignn/models/alignn_eff.py | 1281 ----------------- alignn/models/alignn_ff2.py | 376 ----- alignn/models/utils.py | 725 +--------- alignn/train.py | 9 - alignn/train_alignn.py | 15 - setup.py | 2 +- 15 files changed, 91 insertions(+), 2498 deletions(-) delete mode 100644 alignn/models/alignn_eff.py delete mode 100644 alignn/models/alignn_ff2.py diff --git a/alignn/__init__.py b/alignn/__init__.py index 7e471e3..bf293b6 100644 --- a/alignn/__init__.py +++ b/alignn/__init__.py @@ -1,3 +1,3 @@ """Version number.""" -__version__ = "2024.8.30" +__version__ = "2024.10.30" diff --git a/alignn/config.py b/alignn/config.py index 9dbe7ec..4ed2318 100644 --- a/alignn/config.py +++ b/alignn/config.py @@ -6,8 +6,6 @@ from typing import Literal from alignn.utils import BaseSettings from alignn.models.alignn import ALIGNNConfig -from alignn.models.alignn_ff2 import ALIGNNFF2Config -from alignn.models.alignn_eff import ALIGNNeFFConfig from alignn.models.alignn_atomwise import ALIGNNAtomWiseConfig # import torch @@ -211,8 +209,6 @@ class TrainingConfig(BaseSettings): # model configuration model: Union[ ALIGNNConfig, - ALIGNNFF2Config, - ALIGNNeFFConfig, ALIGNNAtomWiseConfig, # CGCNNConfig, # ICGCNNConfig, diff --git a/alignn/dataset.py b/alignn/dataset.py index 8ad2004..0d2e88b 100644 --- a/alignn/dataset.py +++ b/alignn/dataset.py @@ -14,6 +14,8 @@ tqdm.pandas() +# NOTE: Use lmd_dataset, +# need to fix adding lattice in dataloader def load_graphs( dataset=[], name: str = "dft_3d", diff --git a/alignn/examples/sample_data_ff/config_example_atomwise.json b/alignn/examples/sample_data_ff/config_example_atomwise.json index baf5198..84a5ee6 100644 --- a/alignn/examples/sample_data_ff/config_example_atomwise.json +++ b/alignn/examples/sample_data_ff/config_example_atomwise.json @@ -3,7 +3,7 @@ "dataset": "user_data", "target": "target", "atom_features": "cgcnn", - "neighbor_strategy": "radius_graph_jarvis", + "neighbor_strategy": "radius_graph", "id_tag": "jid", "dtype": "float32", "random_seed": 123, @@ -40,7 +40,7 @@ "distributed":false, "use_lmdb": true, "model": { - "name": "alignn_ff2", + "name": "alignn_atomwise", "atom_input_features": 92, "calculate_gradient":true, "atomwise_output_features":0, diff --git a/alignn/ff/ff.py b/alignn/ff/ff.py index 71e224e..f0bf991 100644 --- a/alignn/ff/ff.py +++ b/alignn/ff/ff.py @@ -30,9 +30,6 @@ from jarvis.db.jsonutils import loadjson from alignn.graphs import Graph from alignn.models.alignn_atomwise import ALIGNNAtomWise, ALIGNNAtomWiseConfig -from alignn.models.alignn_ff2 import ALIGNNFF2, ALIGNNFF2Config -from alignn.models.alignn_eff import ALIGNNeFF, ALIGNNeFFConfig -from alignn.config import TrainingConfig from jarvis.analysis.defects.vacancy import Vacancy import numpy as np from alignn.pretrained import get_prediction @@ -270,12 +267,8 @@ def __init__( ) if self.model is None: - if config["model"]["name"] == "alignn_ff2": - model = ALIGNNFF2(ALIGNNFF2Config(**config["model"])) if config["model"]["name"] == "alignn_atomwise": model = ALIGNNAtomWise(ALIGNNAtomWiseConfig(**config["model"])) - if config["model"]["name"] == "alignn_eff": - model = ALIGNNeFF(ALIGNNeFFConfig(**config["model"])) model.state_dict() model.load_state_dict( torch.load( diff --git a/alignn/graphs.py b/alignn/graphs.py index 
686b7a5..054aaff 100644 --- a/alignn/graphs.py +++ b/alignn/graphs.py @@ -15,6 +15,7 @@ import torch import dgl from tqdm import tqdm +from jarvis.core.atoms import Atoms # import matgl @@ -63,13 +64,12 @@ def temp_graph( g.ndata["Z"] = torch.tensor(atom_feats, dtype=torch.int64) g.edata["r"] = torch.tensor(np.array(r), dtype=dtype) g.edata["d"] = torch.tensor(d, dtype=dtype) - g.edata["pbc_offset"] = torch.tensor(images, dtype=dtype) - g.edata["pbc_offshift"] = torch.tensor(images, dtype=dtype) + # g.edata["pbc_offset"] = torch.tensor(images, dtype=dtype) + # g.edata["pbc_offshift"] = torch.tensor(images, dtype=dtype) g.edata["images"] = torch.tensor(images, dtype=dtype) - # g.edata["lattice"] = torch.tensor(torch.repeat_interleave(torch.tensor(atoms.lattice_mat.flatten()), atoms.num_atoms), dtype=dtype) - node_type = torch.tensor([0 for i in range(len(atoms.atomic_numbers))]) - g.ndata["node_type"] = node_type - lattice_mat = atoms.lattice_mat + # node_type = torch.tensor([0 for i in range(len(atoms.atomic_numbers))]) + # g.ndata["node_type"] = node_type + # lattice_mat = atoms.lattice_mat # g.ndata["lattice"] = torch.tensor( # [lattice_mat for ii in range(g.num_nodes())] # , dtype=dtype) @@ -78,7 +78,6 @@ def temp_graph( # , dtype=dtype) g.ndata["pos"] = torch.tensor(atoms.cart_coords, dtype=dtype) g.ndata["frac_coords"] = torch.tensor(atoms.frac_coords, dtype=dtype) - # g.ndata["V"] = torch.tensor([atoms.volume] * atoms.num_atoms, dtype=dtype) return g, u, v, r @@ -516,43 +515,45 @@ def atom_dgl_multigraph( # u, v, r = build_undirected_edgedata(atoms, edges) # build up atom attribute tensor - comp = atoms.composition.to_dict() - comp_dict = {} - c_ind = 0 - for ii, jj in comp.items(): - if ii not in comp_dict: - comp_dict[ii] = c_ind - c_ind += 1 + # comp = atoms.composition.to_dict() + # comp_dict = {} + # c_ind = 0 + # for ii, jj in comp.items(): + # if ii not in comp_dict: + # comp_dict[ii] = c_ind + # c_ind += 1 sps_features = [] - node_types = [] + # node_types = [] for ii, s in enumerate(atoms.elements): feat = list(get_node_attributes(s, atom_features=atom_features)) # if include_prdf_angles: # feat=feat+list(prdf[ii])+list(adf[ii]) sps_features.append(feat) - node_types.append(comp_dict[s]) + # node_types.append(comp_dict[s]) sps_features = np.array(sps_features) node_features = torch.tensor(sps_features).type( torch.get_default_dtype() ) g = dgl.graph((u, v)) g.ndata["atom_features"] = node_features - g.ndata["node_type"] = torch.tensor(node_types, dtype=torch.int64) - node_type = torch.tensor([0 for i in range(len(atoms.atomic_numbers))]) - g.ndata["node_type"] = node_type + # g.ndata["node_type"] = torch.tensor(node_types, dtype=torch.int64) + # node_type = torch.tensor([0 for i in range(len(atoms.atm_num))]) + # g.ndata["node_type"] = node_type # print('g.ndata["node_type"]',g.ndata["node_type"]) - g.edata["r"] = torch.tensor(r).type(torch.get_default_dtype()) + g.edata["r"] = torch.tensor(np.array(r)).type( + torch.get_default_dtype() + ) # images=torch.tensor(images).type(torch.get_default_dtype()) # print('images',images.shape,r.shape) # print('type',torch.get_default_dtype()) - g.edata["images"] = torch.tensor(images).type( + g.edata["images"] = torch.tensor(np.array(images)).type( torch.get_default_dtype() ) vol = atoms.volume g.ndata["V"] = torch.tensor([vol for ii in range(atoms.num_atoms)]) - g.ndata["coords"] = torch.tensor(atoms.cart_coords).type( - torch.get_default_dtype() - ) + # g.ndata["coords"] = torch.tensor(atoms.cart_coords).type( + # 
torch.get_default_dtype() + # ) g.ndata["frac_coords"] = torch.tensor(atoms.frac_coords).type( torch.get_default_dtype() ) @@ -1048,7 +1049,7 @@ def setup_standardizer(self, ids): @staticmethod def collate(samples: List[Tuple[dgl.DGLGraph, torch.Tensor]]): """Dataloader helper to batch graphs cross `samples`.""" - graphs, lattice, labels = map(list, zip(*samples)) + graphs, lattices, labels = map(list, zip(*samples)) batched_graph = dgl.batch(graphs) return batched_graph, torch.tensor(lattices), torch.tensor(labels) diff --git a/alignn/lmdb_dataset.py b/alignn/lmdb_dataset.py index 9dbd9cc..682687b 100644 --- a/alignn/lmdb_dataset.py +++ b/alignn/lmdb_dataset.py @@ -58,11 +58,11 @@ def __getitem__(self, idx): with self.env.begin() as txn: serialized_data = txn.get(f"{idx}".encode()) if self.line_graph: - graph, line_graph, lattice,label = pk.loads(serialized_data) - return graph, line_graph, lattice,label + graph, line_graph, lattice, label = pk.loads(serialized_data) + return graph, line_graph, lattice, label else: - graph, lattice,label = pk.loads(serialized_data) - return graph, lattice,label + graph, lattice, label = pk.loads(serialized_data) + return graph, lattice, label def close(self): """Close connection.""" @@ -76,7 +76,7 @@ def __del__(self): def collate(samples: List[Tuple[dgl.DGLGraph, torch.Tensor]]): """Dataloader helper to batch graphs cross `samples`.""" # print('samples',samples) - graphs, lattices,labels = map(list, zip(*samples)) + graphs, lattices, labels = map(list, zip(*samples)) # graphs, lgs, labels = map(list, zip(*samples)) batched_graph = dgl.batch(graphs) return batched_graph, torch.tensor(lattices), torch.tensor(labels) @@ -90,9 +90,19 @@ def collate_line_graph( batched_graph = dgl.batch(graphs) batched_line_graph = dgl.batch(line_graphs) if len(labels[0].size()) > 0: - return batched_graph, batched_line_graph, torch.tensor(lattices),torch.stack(labels) + return ( + batched_graph, + batched_line_graph, + torch.tensor(lattices), + torch.stack(labels), + ) else: - return batched_graph, batched_line_graph, torch.stack(lattices),torch.tensor(labels) + return ( + batched_graph, + batched_line_graph, + torch.stack(lattices), + torch.tensor(labels), + ) def get_torch_dataset( @@ -143,7 +153,7 @@ def get_torch_dataset( for idx, (d) in tqdm(enumerate(dataset), total=len(dataset)): ids.append(d[id_tag]) # g, lg = Graph.atom_dgl_multigraph( - atoms=Atoms.from_dict(d["atoms"]) + atoms = Atoms.from_dict(d["atoms"]) g = Graph.atom_dgl_multigraph( atoms, cutoff=float(cutoff), @@ -157,7 +167,9 @@ def get_torch_dataset( ) if line_graph: g, lg = g - lattice=torch.tensor(atoms.lattice_mat).type(torch.get_default_dtype()) + lattice = torch.tensor(atoms.lattice_mat).type( + torch.get_default_dtype() + ) label = torch.tensor(d[target]).type(torch.get_default_dtype()) # print('label',label,label.view(-1).long()) if classification: @@ -184,9 +196,9 @@ def get_torch_dataset( # labels.append(label) if line_graph: - serialized_data = pk.dumps((g, lg, lattice,label)) + serialized_data = pk.dumps((g, lg, lattice, label)) else: - serialized_data = pk.dumps((g, lattice,label)) + serialized_data = pk.dumps((g, lattice, label)) txn.put(f"{idx}".encode(), serialized_data) env.close() diff --git a/alignn/models/alignn.py b/alignn/models/alignn.py index 722e88f..fe00d3d 100644 --- a/alignn/models/alignn.py +++ b/alignn/models/alignn.py @@ -4,21 +4,15 @@ """ from typing import Tuple, Union - import dgl import dgl.function as fn import numpy as np import torch from dgl.nn import AvgPooling - -# 
from dgl.nn.functional import edge_softmax
 from typing import Literal
 from torch import nn
 from torch.nn import functional as F
-
-# from alignn.models.utils import RBFExpansion
-# from alignn.utils import BaseSettings
-
+from alignn.models.utils import RBFExpansion
 from pydantic_settings import BaseSettings
 
 
diff --git a/alignn/models/alignn_atomwise.py b/alignn/models/alignn_atomwise.py
index 9674e6a..dd7d15d 100644
--- a/alignn/models/alignn_atomwise.py
+++ b/alignn/models/alignn_atomwise.py
@@ -19,6 +19,7 @@
     RBFExpansion,
     compute_cartesian_coordinates,
     compute_pair_vector_and_distance,
+    MLPLayer,
 )
 from alignn.graphs import compute_bond_cosines
 from alignn.utils import BaseSettings
@@ -28,13 +29,15 @@ class ALIGNNAtomWiseConfig(BaseSettings):
     """Hyperparameter schema for jarvisdgl.models.alignn."""
 
     name: Literal["alignn_atomwise"]
-    alignn_layers: int = 4
-    gcn_layers: int = 4
-    atom_input_features: int = 92
+    alignn_layers: int = 2
+    gcn_layers: int = 2
+    atom_input_features: int = 1
+    # atom_input_features: int = 92
     edge_input_features: int = 80
     triplet_input_features: int = 40
     embedding_features: int = 64
-    hidden_features: int = 256
+    hidden_features: int = 64
+    # hidden_features: int = 256
     # fc_layers: int = 1
     # fc_features: int = 64
     output_features: int = 1
@@ -42,7 +45,7 @@ class ALIGNNAtomWiseConfig(BaseSettings):
     calculate_gradient: bool = True
     atomwise_output_features: int = 0
     graphwise_weight: float = 1.0
-    gradwise_weight: float = 0.0
+    gradwise_weight: float = 1.0
     stresswise_weight: float = 0.0
     atomwise_weight: float = 0.0
     # if link == log, apply `exp` to final outputs
@@ -51,18 +54,20 @@ class ALIGNNAtomWiseConfig(BaseSettings):
     zero_inflated: bool = False
     classification: bool = False
     force_mult_natoms: bool = False
-    energy_mult_natoms: bool = False
+    energy_mult_natoms: bool = True
     include_pos_deriv: bool = False
     use_cutoff_function: bool = False
-    inner_cutoff: float = 6  # Ansgtrom
+    inner_cutoff: float = 3  # Angstrom
     stress_multiplier: float = 1
-    add_reverse_forces: bool = False  # will make True as default soon
-    lg_on_fly: bool = False  # will make True as default soon
+    add_reverse_forces: bool = True  # now True by default
+    lg_on_fly: bool = True  # now True by default
     batch_stress: bool = True
     multiply_cutoff: bool = False
-    use_penalty: bool = False
+    use_penalty: bool = True
     extra_features: int = 0
-    exponent: int = 3
+    exponent: int = 5
+    penalty_factor: float = 0.1
+    penalty_threshold: float = 1.0
 
     class Config:
         """Configure model settings behavior."""
@@ -239,23 +244,6 @@ def forward(
         return x, y, z
 
 
-class MLPLayer(nn.Module):
-    """Multilayer perceptron layer helper."""
-
-    def __init__(self, in_features: int, out_features: int):
-        """Linear, Batchnorm, SiLU layer."""
-        super().__init__()
-        self.layer = nn.Sequential(
-            nn.Linear(in_features, out_features),
-            nn.LayerNorm(out_features),
-            nn.SiLU(),
-        )
-
-    def forward(self, x):
-        """Linear, Batchnorm, silu layer."""
-        return self.layer(x)
-
-
 class ALIGNNAtomWise(nn.Module):
     """Atomistic Line graph network.
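The new use_penalty, penalty_factor, and penalty_threshold fields above drive
the short-bond energy penalty applied in the forward pass below: every bond
shorter than penalty_threshold (in Angstrom) adds a linear hinge term to the
predicted energy. A standalone sketch of that term with the default values,
illustrative only and not the model code itself:

    import torch

    def short_bond_penalty(bondlength, penalty_factor=0.1, penalty_threshold=1.0):
        # linear hinge: only bonds shorter than the threshold contribute
        penalties = torch.where(
            bondlength < penalty_threshold,
            penalty_factor * (penalty_threshold - bondlength),
            torch.zeros_like(bondlength),
        )
        return penalties.sum()

    bl = torch.tensor([0.8, 0.95, 1.2, 2.5])  # bond lengths in Angstrom
    short_bond_penalty(bl)  # 0.1 * (0.2 + 0.05) = tensor(0.0250)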
@@ -478,14 +466,24 @@ def forward( forces = torch.empty(1) # gradient = torch.empty(1) stress = torch.empty(1) + natoms = torch.tensor([gg.num_nodes() for gg in dgl.unbatch(g)]).to( + g.device + ) if self.config.energy_mult_natoms: - en_out = out * g.num_nodes() + # print('g.num_nodes()',g.num_nodes()) + # print('unbatch',dgl.unbatch(g)) + # print('natoms',natoms) + # print('out',out,out.shape) + # print() + # print() + en_out = out * natoms # g.num_nodes() else: en_out = out if self.config.use_penalty: - penalty_factor = 500.0 # Penalty weight, tune as needed - penalty_factor = 0.01 # Penalty weight, tune as needed - penalty_threshold = 1.0 # 1 angstrom + penalty_factor = ( + self.config.penalty_factor + ) # Penalty weight, tune as needed + penalty_threshold = self.config.penalty_threshold # 1 angstrom penalties = torch.where( bondlength < penalty_threshold, @@ -537,7 +535,7 @@ def forward( if self.config.add_reverse_forces: # reduce over reverse edges too! # force_i contributions from r_{i->j} (out edges) - # aggregate pairwise_force_contributions over reversed edges + # aggregate pairwise_force_contribs over reversed edges rg = dgl.reverse(g, copy_edata=True) rg.update_all( fn.copy_e("pair_forces", "m"), fn.sum("m", "forces_ij") @@ -554,7 +552,7 @@ def forward( # Under development, use with caution # 1 eV/Angstrom3 = 160.21766208 GPa # 1 GPa = 10 kbar - # Following Virial stress formula, assuming inital velocity = 0 + # Virial stress formula, assuming inital velocity = 0 # Save volume as g.gdta['V']? # print('pair_forces',pair_forces.shape) # print('r',r.shape) @@ -605,7 +603,8 @@ def forward( # virial = ( # 160.21766208 # * 10 - # * torch.einsum("ij, ik->jk", result["r"], result["dy_dr"]) + # * torch.einsum("ij, ik->jk", + # result["r"], result["dy_dr"]) # / 2 # ) # / ( g.ndata["V"][0]) if self.link: diff --git a/alignn/models/alignn_eff.py b/alignn/models/alignn_eff.py deleted file mode 100644 index 9a9c177..0000000 --- a/alignn/models/alignn_eff.py +++ /dev/null @@ -1,1281 +0,0 @@ -from torch.autograd import grad -from math import pi -from typing import Any, Callable, Literal, cast -from collections.abc import Sequence -from torch.nn import Linear, Module -from jarvis.core.specie import get_element_full_names -import dgl.function as fn -from torch import Tensor, nn -from jarvis.core.atoms import Atoms -from alignn.graphs import Graph -from enum import Enum -import dgl -from pathlib import Path -import torch -from dgl import readout_nodes -import inspect -import json -import os -from alignn.utils import BaseSettings -from alignn.models.utils import ( - get_ewald_sum, - get_atomic_repulsion, - FourierExpansion, - RadialBesselFunction, - prune_edges_by_features, - _create_directed_line_graph, - compute_theta, - create_line_graph, - compute_pair_vector_and_distance, - polynomial_cutoff, -) - -torch.autograd.detect_anomaly() -DEFAULT_ELEMENTS = list(get_element_full_names().keys()) - - -class ALIGNNeFFConfig(BaseSettings): - """Hyperparameter schema for jarvisdgl.models.alignn.""" - - name: Literal["alignn_eff"] - alignn_layers: int = 4 - calculate_gradient: bool = True - output_features: int = 1 - atomwise_output_features: int = 0 - graphwise_weight: float = 1.0 - gradwise_weight: float = 20.0 - stresswise_weight: float = 0.0 - atomwise_weight: float = 0.0 - batch_stress: bool = True - - -class EFFLineGraphConv(nn.Module): - - def __init__( - self, - node_update_func: Module, - node_out_func: Module, - edge_update_func: Module | None, - node_weight_func: Module | None, - ): - """ - Args: 
- node_update_func: Update function for message between nodes (bonds) - node_out_func: Output function for nodes (bonds), after message aggregation - edge_update_func: edge update function (for angle features) - node_weight_func: layer node weight function. - """ - super().__init__() - - self.node_update_func = node_update_func - self.node_out_func = node_out_func - self.node_weight_func = node_weight_func - self.edge_update_func = edge_update_func - - @classmethod - def from_dims( - cls, - node_dims: list[int], - edge_dims: list[int] | None = None, - activation: Module | None = None, - normalization: Literal["graph", "layer"] | None = None, - normalize_hidden: bool = False, - node_weight_input_dims: int = 0, - ): - norm_kwargs = ( - {"batched_field": "edge"} if normalization == "graph" else None - ) - - node_update_func = GatedMLP_norm( - in_feats=node_dims[0], - dims=node_dims[1:], - activation=activation, - normalization=normalization, - normalize_hidden=normalize_hidden, - norm_kwargs=norm_kwargs, - ) - node_out_func = nn.Linear( - in_features=node_dims[-1], out_features=node_dims[-1], bias=False - ) - - node_weight_func = ( - nn.Linear(node_weight_input_dims, node_dims[-1]) - if node_weight_input_dims > 0 - else None - ) - edge_update_func = ( - GatedMLP_norm( - in_feats=edge_dims[0], - dims=edge_dims[1:], - activation=activation, - normalization=normalization, - normalize_hidden=normalize_hidden, - norm_kwargs=norm_kwargs, - ) - if edge_dims is not None - else None - ) - - return cls( - node_update_func=node_update_func, - node_out_func=node_out_func, - edge_update_func=edge_update_func, - node_weight_func=node_weight_func, - ) - - def _edge_udf(self, edges: dgl.udf.EdgeBatch) -> dict[str, Tensor]: - """Edge user defined update function. - - Update angle features (edges in bond graph) - - Args: - edges: edge batch - - Returns: - edge_update: edge features update - """ - bonds_i = edges.src["features"] # first bonds features - bonds_j = edges.dst["features"] # second bonds features - angle_ij = edges.data["features"] - atom_ij = edges.data["aux_features"] # center atom features - inputs = torch.hstack([bonds_i, angle_ij, atom_ij, bonds_j]) - messages_ij = self.edge_update_func(inputs, edges._graph) # type: ignore - return {"feat_update": messages_ij} - - def edge_update_(self, graph: dgl.DGLGraph) -> Tensor: - """Perform edge update -> update angle features. - - Args: - graph: bond graph (line graph of atom graph) - - Returns: - edge_update: edge features update - """ - graph.apply_edges(self._edge_udf) - edge_update = graph.edata["feat_update"] - return edge_update - - def node_update_( - self, graph: dgl.DGLGraph, shared_weights: Tensor | None - ) -> Tensor: - """Perform node update -> update bond features. 
- - Args: - graph: bond graph (line graph of atom graph) - shared_weights: node message shared weights - - Returns: - node_update: bond features update - """ - src, dst = graph.edges() - bonds_i = graph.ndata["features"][src] # first bond feature - bonds_j = graph.ndata["features"][dst] # second bond feature - angle_ij = graph.edata["features"] - atom_ij = graph.edata["aux_features"] # center atom features - inputs = torch.hstack([bonds_i, angle_ij, atom_ij, bonds_j]) - - messages = self.node_update_func(inputs, graph) - - # smooth out messages with layer-wise weights - if self.node_weight_func is not None: - rbf = graph.ndata["bond_expansion"] - weights = self.node_weight_func(rbf) - weights_i, weights_j = weights[src], weights[dst] - messages = messages * weights_i * weights_j - - # smooth out messages with shared weights - if shared_weights is not None: - weights_i, weights_j = shared_weights[src], shared_weights[dst] - messages = messages * weights_i * weights_j - - # message passing - graph.edata["message"] = messages - graph.update_all( - fn.copy_e("message", "message"), fn.sum("message", "feat_update") - ) - - # update nodes - node_update = self.node_out_func( - graph.ndata["feat_update"] - ) # the bond update - - return node_update - - def forward( - self, - graph: dgl.DGLGraph, - node_features: Tensor, - edge_features: Tensor, - aux_edge_features: Tensor, - shared_node_weights: Tensor | None, - ) -> tuple[Tensor, Tensor]: - with graph.local_scope(): - graph.ndata["features"] = node_features - graph.edata["features"] = edge_features - graph.edata["aux_features"] = aux_edge_features - - # node (bond) update - node_update = self.node_update_(graph, shared_node_weights) - new_node_features = node_features + node_update - graph.ndata["features"] = new_node_features - - # edge (angle) update (should angle update be done before node update?) - if self.edge_update_func is not None: - edge_update = self.edge_update_(graph) - new_edge_features = edge_features + edge_update - graph.edata["features"] = new_edge_features - else: - new_edge_features = edge_features - - return new_node_features, new_edge_features - - -class GatedMLP_norm(nn.Module): - """An implementation of a Gated multi-layer perceptron constructed with MLP.""" - - def __init__( - self, - in_feats: int, - dims: Sequence[int], - activation: nn.Module | None = None, - activate_last: bool = True, - use_bias: bool = True, - bias_last: bool = True, - normalization: Literal["graph", "layer"] | None = None, - normalize_hidden: bool = False, - norm_kwargs: dict[str, Any] | None = None, - ): - """:param in_feats: Dimension of input features. - :param dims: Architecture of neural networks. - :param activation: non-linear activation module. - :param activate_last: Whether applying activation to last layer or not. - :param use_bias: Whether to use a bias in linear layers. - :param bias_last: Whether applying bias to last layer or not. - :param normalization: normalization name. - :param normalize_hidden: Whether to normalize output of hidden layers. - :param norm_kwargs: Keyword arguments for normalization layer. 
- """ - super().__init__() - self.in_feats = in_feats - self.dims = [in_feats, *dims] - self._depth = len(dims) - self.use_bias = use_bias - self.activate_last = activate_last - - activation = activation if activation is not None else nn.SiLU() - self.layers = MLP_norm( - self.dims, - activation=activation, - activate_last=True, - use_bias=use_bias, - bias_last=bias_last, - normalization=normalization, - normalize_hidden=normalize_hidden, - norm_kwargs=norm_kwargs, - ) - self.gates = MLP_norm( - self.dims, - activation, - activate_last=False, - use_bias=use_bias, - bias_last=bias_last, - normalization=normalization, - normalize_hidden=normalize_hidden, - norm_kwargs=norm_kwargs, - ) - self.sigmoid = nn.Sigmoid() - - def forward(self, inputs: torch.Tensor, graph=None) -> torch.Tensor: - return self.layers(inputs, graph) * self.sigmoid( - self.gates(inputs, graph) - ) - - -class EFFBondGraphBlock(nn.Module): - """A EFF atom graph block as a sequence of operations involving a message passing layer over the bond graph.""" - - def __init__( - self, - num_atom_feats: int, - num_bond_feats: int, - num_angle_feats: int, - activation: Module, - bond_hidden_dims: Sequence[int], - angle_hidden_dims: Sequence[int] | None, - normalization: Literal["graph", "layer"] | None = None, - normalize_hidden: bool = False, - rbf_order: int = 0, - bond_dropout: float = 0.0, - angle_dropout: float = 0.0, - ): - """. - - Args: - num_atom_feats: number of atom features - num_bond_feats: number of bond features - num_angle_feats: number of angle features - activation: activation function - bond_hidden_dims: dimensions of hidden layers of bond graph convolution - angle_hidden_dims: dimensions of hidden layers of angle update function - Default = None - normalization: Normalization type to use in update functions. either "graph" or "layer" - If None, no normalization is applied. - Default = None - normalize_hidden: Whether to normalize hidden features. - Default = False - rbf_order (int): RBF order specifying input dimensions for linear layer - specifying message weights. If 0, no layer-wise weights are used. - Default = 0 - bond_dropout (float): dropout probability for bond graph convolution. - Default = 0.0 - angle_dropout (float): dropout probability for angle update function. - Default = 0.0 - """ - super().__init__() - - node_input_dim = 2 * num_bond_feats + num_angle_feats + num_atom_feats - node_dims = [node_input_dim, *bond_hidden_dims, num_bond_feats] - edge_dims = ( - [node_input_dim, *angle_hidden_dims, num_angle_feats] - if angle_hidden_dims is not None - else None - ) - - self.conv_layer = EFFLineGraphConv.from_dims( - node_dims=node_dims, - edge_dims=edge_dims, - activation=activation, - normalization=normalization, - normalize_hidden=normalize_hidden, - node_weight_input_dims=rbf_order, - ) - - self.bond_dropout = ( - nn.Dropout(bond_dropout) if bond_dropout > 0.0 else nn.Identity() - ) - self.angle_dropout = ( - nn.Dropout(angle_dropout) if angle_dropout > 0.0 else nn.Identity() - ) - - def forward( - self, - graph: dgl.DGLGraph, - atom_features: Tensor, - bond_features: Tensor, - angle_features: Tensor, - shared_node_weights: Tensor | None, - ) -> tuple[Tensor, Tensor]: - """Perform convolution in BondGraph to update bond and angle features. 
- - Args: - graph: bond graph (line graph of atom graph) - atom_features: atom features - bond_features: bond features - angle_features: concatenated center atom and angle features - shared_node_weights: shared node message weights - - Returns: - tuple: update bond features, update angle features - """ - node_features = bond_features[graph.ndata["bond_index"]] - edge_features = angle_features - aux_edge_features = atom_features[graph.edata["center_atom_index"]] - - bond_features_, angle_features = self.conv_layer( - graph, - node_features, - edge_features, - aux_edge_features, - shared_node_weights, - ) - - bond_features_ = self.bond_dropout(bond_features_) - angle_features = self.angle_dropout(angle_features) - - bond_features[graph.ndata["bond_index"]] = bond_features_ - - return bond_features, angle_features - - -class EFFGraphConv(nn.Module): - """A EFF atom graph convolution layer in DGL.""" - - def __init__( - self, - node_update_func: Module, - node_out_func: Module, - edge_update_func: Module | None, - node_weight_func: Module | None, - edge_weight_func: Module | None, - state_update_func: Module | None, - ): - """ - Args: - node_update_func: Update function for message between nodes (atoms) - node_out_func: Output function for nodes (atoms), after message aggregation - edge_update_func: Update function for edges (bonds). If None is given, the - edges are not updated. - node_weight_func: Weight function for radial basis functions. - If None is given, no layer-wise weights will be used. - edge_weight_func: Weight function for radial basis functions - If None is given, no layer-wise weights will be used. - state_update_func: Update function for state feats. - """ - super().__init__() - self.include_state = state_update_func is not None - self.edge_update_func = edge_update_func - self.edge_weight_func = edge_weight_func - self.node_update_func = node_update_func - self.node_out_func = node_out_func - self.node_weight_func = node_weight_func - self.state_update_func = state_update_func - - @classmethod - def from_dims( - cls, - activation: Module, - node_dims: Sequence[int], - edge_dims: Sequence[int] | None = None, - state_dims: Sequence[int] | None = None, - normalization: Literal["graph", "layer"] | None = None, - normalize_hidden: bool = False, - rbf_order: int = 0, - ): - """Create a EFFAtomGraphConv layer from dimensions. - - Args: - activation: activation function - node_dims: NN architecture for node update function given as a list of - dimensions of each layer. - edge_dims: NN architecture for edge update function given as a list of - dimensions of each layer. - Default = None - state_dims: NN architecture for state update function given as a list of - dimensions of each layer. - Default = None - normalization: Normalization type to use in update functions. either "graph" or "layer" - If None, no normalization is applied. - Default = None - normalize_hidden: Whether to normalize hidden features. - Default = False - rbf_order (int): RBF order specifying input dimensions for linear layer - specifying message weights. If 0, no layer-wise weights are used. 
- Default = 0 - - Returns: - EFFAtomGraphConv - """ - norm_kwargs = ( - {"batched_field": "edge"} if normalization == "graph" else None - ) - - node_update_func = GatedMLP_norm( - in_feats=node_dims[0], - dims=node_dims[1:], - activation=activation, - normalization=normalization, - normalize_hidden=normalize_hidden, - norm_kwargs=norm_kwargs, - ) - node_out_func = nn.Linear( - in_features=node_dims[-1], out_features=node_dims[-1], bias=False - ) - node_weight_func = ( - nn.Linear( - in_features=rbf_order, out_features=node_dims[-1], bias=False - ) - if rbf_order > 0 - else None - ) - edge_update_func = ( - GatedMLP_norm( - in_feats=edge_dims[0], - dims=edge_dims[1:], - activation=activation, - normalization=normalization, - normalize_hidden=normalize_hidden, - norm_kwargs=norm_kwargs, - ) - if edge_dims is not None - else None - ) - edge_weight_func = ( - nn.Linear( - in_features=rbf_order, out_features=edge_dims[-1], bias=False - ) - if rbf_order > 0 and edge_dims is not None - else None - ) - state_update_func = ( - MLP( - state_dims, - activation, - activate_last=True, - ) - if state_dims is not None - else None - ) - - return cls( - node_update_func=node_update_func, - node_out_func=node_out_func, - edge_update_func=edge_update_func, - node_weight_func=node_weight_func, - edge_weight_func=edge_weight_func, - state_update_func=state_update_func, - ) - - def _edge_udf(self, edges: dgl.udf.EdgeBatch) -> dict[str, Tensor]: - """Edge user defined update function. - - Update for bond features (edges) in atom graph. - - Args: - edges: edges in atom graph (ie bonds) - - Returns: - edge_update: edge features update - """ - atom_i = edges.src["features"] # first atom features - atom_j = edges.dst["features"] # second atom features - bond_ij = edges.data["features"] # bond features - if self.include_state: - global_state = edges.data["global_state"] - inputs = torch.hstack([atom_i, bond_ij, atom_j, global_state]) - else: - inputs = torch.hstack([atom_i, bond_ij, atom_j]) - - edge_update = self.edge_update_func(inputs, edges._graph) # type: ignore - if self.edge_weight_func is not None: - rbf = edges.data["bond_expansion"] - rbf = rbf.float() - edge_update = edge_update * self.edge_weight_func(rbf) - - return {"feat_update": edge_update} - - def edge_update_( - self, graph: dgl.DGLGraph, shared_weights: Tensor | None - ) -> Tensor: - """Perform edge update -> bond features. - - Args: - graph: atom graph - shared_weights: atom graph edge weights shared between convolution layers - - Returns: - edge_update: edge features update - """ - graph.apply_edges(self._edge_udf) - edge_update = graph.edata["feat_update"] - if shared_weights is not None: - edge_update = edge_update * shared_weights - return edge_update - - def node_update_( - self, graph: dgl.DGLGraph, shared_weights: Tensor | None - ) -> Tensor: - """Perform node update -> atom features. 
- - Args: - graph: DGL atom graph - shared_weights: node message shared weights - - Returns: - node_update: updated node features - """ - src, dst = graph.edges() - atom_i = graph.ndata["features"][src] # first atom features - atom_j = graph.ndata["features"][dst] # second atom features - bond_ij = graph.edata["features"] # bond features - - if self.include_state: - global_state = graph.edata["global_state"] - inputs = torch.hstack([atom_i, bond_ij, atom_j, global_state]) - else: - inputs = torch.hstack([atom_i, bond_ij, atom_j]) - - messages = self.node_update_func(inputs, graph) - - # smooth out the messages with layer-wise weights - if self.node_weight_func is not None: - rbf = graph.edata["bond_expansion"] - rbf = rbf.float() - messages = messages * self.node_weight_func(rbf) - - # smooth out the messages with shared weights - if shared_weights is not None: - messages = messages * shared_weights - - # message passing - graph.edata["message"] = messages - graph.update_all( - fn.copy_e("message", "message"), fn.sum("message", "feat_update") - ) - - # update nodes - node_update = self.node_out_func( - graph.ndata["feat_update"] - ) # the bond update - - return node_update - - def state_update_(self, graph: dgl.DGLGraph, state_attr: Tensor) -> Tensor: - """Perform attribute (global state) update. - - Args: - graph: atom graph - state_attr: global state features - - Returns: - state_update: state features update - """ - node_avg = dgl.readout_nodes(graph, feat="features", op="mean") - inputs = torch.hstack([state_attr, node_avg]) - state_attr = self.state_update_func(inputs) # type: ignore - return state_attr - - def forward( - self, - graph: dgl.DGLGraph, - node_features: Tensor, - edge_features: Tensor, - state_attr: Tensor, - shared_node_weights: Tensor | None, - shared_edge_weights: Tensor | None, - ) -> tuple[Tensor, Tensor, Tensor]: - """Perform sequence of edge->node->states updates. - - Args: - graph: atom graph - node_features: node features - edge_features: edge features - state_attr: state attributes - shared_node_weights: shared node message weights - shared_edge_weights: shared edge message weights - - Returns: - tuple: updated node features, updated edge features, updated state attributes - """ - with graph.local_scope(): - graph.ndata["features"] = node_features - graph.edata["features"] = edge_features - - if self.include_state: - graph.edata["global_state"] = dgl.broadcast_edges( - graph, state_attr - ) - - if self.edge_update_func is not None: - edge_update = self.edge_update_(graph, shared_edge_weights) - new_edge_features = edge_features + edge_update - graph.edata["features"] = new_edge_features - else: - new_edge_features = edge_features - - node_update = self.node_update_(graph, shared_node_weights) - new_node_features = node_features + node_update - graph.ndata["features"] = new_node_features - - if self.include_state: - state_attr = self.state_update_(graph, state_attr) # type: ignore - - return new_node_features, new_edge_features, state_attr - - -class EFFAtomGraphBlock(nn.Module): - """ - A EFF atom graph block as a sequence of operations - involving a message passing layer over the atom graph. - """ - - def __init__( - self, - num_atom_feats: int, - num_bond_feats: int, - activation: Module, - atom_hidden_dims: Sequence[int], - bond_hidden_dims: Sequence[int] | None = None, - normalization: Literal["graph", "layer"] | None = None, - normalize_hidden: bool = False, - num_state_feats: int | None = None, - rbf_order: int = 0, - dropout: float = 0.0, - ): - """. 
- - Args: - num_atom_feats: number of atom features - num_bond_feats: number of bond features - activation: activation function - atom_hidden_dims: dimensions of atom convolution hidden layers - bond_hidden_dims: dimensions of bond update hidden layers. - normalization: Normalization type to use in update functions. either "graph" or "layer" - If None, no normalization is applied. - Default = None - normalize_hidden: Whether to normalize hidden features. - Default = False - num_state_feats: number of state features if self.include_state is True - Default = None - rbf_order: RBF order specifying input dimensions for linear layer - specifying message weights. If 0, no layer-wise weights are used. - Default = False - dropout: dropout probability. - Default = 0.0 - """ - super().__init__() - - node_input_dim = 2 * num_atom_feats + num_bond_feats - if num_state_feats is not None: - node_input_dim += num_state_feats - state_dims = [ - num_atom_feats + num_state_feats, - *atom_hidden_dims, - num_state_feats, - ] - else: - state_dims = None - node_dims = [node_input_dim, *atom_hidden_dims, num_atom_feats] - edge_dims = ( - [node_input_dim, *bond_hidden_dims, num_bond_feats] - if bond_hidden_dims is not None - else None - ) - - self.conv_layer = EFFGraphConv.from_dims( - activation=activation, - node_dims=node_dims, - edge_dims=edge_dims, - state_dims=state_dims, - normalization=normalization, - normalize_hidden=normalize_hidden, - rbf_order=rbf_order, - ) - - if normalization == "graph": - self.atom_norm = GraphNorm(num_atom_feats, batched_field="node") - self.bond_norm = GraphNorm(num_bond_feats, batched_field="edge") - elif normalization == "layer": - self.atom_norm = LayerNorm(num_atom_feats) - self.bond_norm = LayerNorm(num_bond_feats) - else: - self.atom_norm = None - self.bond_norm = None - - self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity() - - def forward( - self, - graph: dgl.DGLGraph, - atom_features: Tensor, - bond_features: Tensor, - state_attr: Tensor, - shared_node_weights: Tensor | None, - shared_edge_weights: Tensor | None, - ) -> tuple[Tensor, Tensor, Tensor]: - """Perform sequence of bond(optional)->atom->states(optional) updates. - - Args: - graph: atom graph - atom_features: node features - bond_features: edge features - state_attr: state attributes - shared_node_weights: node message weights shared amongst layers - shared_edge_weights: edge message weights shared amongst layers - """ - atom_features, bond_features, state_attr = self.conv_layer( - graph=graph, - node_features=atom_features, - edge_features=bond_features, - state_attr=state_attr, - shared_node_weights=shared_node_weights, - shared_edge_weights=shared_edge_weights, - ) - # move skip connections here? dropout before skip connections? 
- atom_features = self.dropout(atom_features) - bond_features = self.dropout(bond_features) - if self.atom_norm is not None: - atom_features = self.atom_norm(atom_features, graph) - if self.bond_norm is not None: - bond_features = self.bond_norm(bond_features, graph) - if state_attr is not None: - state_attr = self.dropout(state_attr) - - return atom_features, bond_features, state_attr - - -class MLP_norm(nn.Module): - """Multi-layer perceptron with normalization layer.""" - - def __init__( - self, - dims: list[int], - activation: nn.Module | None = None, - activate_last: bool = False, - use_bias: bool = True, - bias_last: bool = True, - normalization: Literal["graph", "layer"] | None = None, - normalize_hidden: bool = False, - norm_kwargs: dict[str, Any] | None = None, - ) -> None: - """ - Args: - dims: Dimensions of each layer of MLP. - activation: activation: Activation function. - activate_last: Whether to apply activation to last layer. - use_bias: Whether to use bias. - bias_last: Whether to apply bias to last layer. - normalization: normalization name. "graph" or "layer" - normalize_hidden: Whether to normalize output of hidden layers. - norm_kwargs: Keyword arguments for normalization layer. - """ - super().__init__() - self._depth = len(dims) - 1 - self.layers = nn.ModuleList() - self.norm_layers = ( - nn.ModuleList() if normalization in ("graph", "layer") else None - ) - self.activation = ( - activation if activation is not None else nn.Identity() - ) - self.activate_last = activate_last - self.normalize_hidden = normalize_hidden - norm_kwargs = norm_kwargs or {} - norm_kwargs = cast(dict, norm_kwargs) - - for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): - if i < self._depth - 1: - self.layers.append(Linear(in_dim, out_dim, bias=use_bias)) - if normalize_hidden and self.norm_layers is not None: - if normalization == "graph": - self.norm_layers.append( - GraphNorm(out_dim, **norm_kwargs) - ) - elif normalization == "layer": - self.norm_layers.append( - LayerNorm(out_dim, **norm_kwargs) - ) - else: - self.layers.append( - Linear(in_dim, out_dim, bias=use_bias and bias_last) - ) - if self.norm_layers is not None: - if normalization == "graph": - self.norm_layers.append( - GraphNorm(out_dim, **norm_kwargs) - ) - elif normalization == "layer": - self.norm_layers.append( - LayerNorm(out_dim, **norm_kwargs) - ) - - def forward(self, inputs: torch.Tensor, g=None) -> torch.Tensor: - """Applies all layers in turn. - - Args: - inputs: input feature tensor. - g: graph of model, needed for graph normalization - - Returns: - output feature tensor. - """ - x = inputs - for i in range(self._depth - 1): - x = self.layers[i](x) - if self.norm_layers is not None and self.normalize_hidden: - x = self.norm_layers[i](x, g) - x = self.activation(x) - - x = self.layers[-1](x) - if self.norm_layers is not None: - x = self.norm_layers[-1](x, g) - if self.activate_last: - x = self.activation(x) - return x - - -class ActivationFunction(Enum): - """Enumeration of optional activation functions.""" - - swish = nn.SiLU - # sigmoid = nn.Sigmoid - # tanh = nn.Tanh - # softplus = nn.Softplus - # softplus2 = SoftPlus2 - # softexp = SoftExponential - - -class ALIGNNeFF(nn.Module): - """Main EFF model.""" - - __version__ = 1 - - def __init__( - self, - config: ALIGNNeFFConfig = ALIGNNeFFConfig(name="alignn_eff"), - element_types: tuple[str, ...] 
| None = None, - dim_atom_embedding: int = 64, - dim_bond_embedding: int = 64, - dim_angle_embedding: int = 64, - dim_state_embedding: int | None = None, - dim_state_feats: int | None = None, - non_linear_bond_embedding: bool = False, - non_linear_angle_embedding: bool = False, - cutoff: float = 4.0, - threebody_cutoff: float = 3.0, - cutoff_exponent: int = 5, - max_n: int = 9, - max_f: int = 4, - learn_basis: bool = True, - num_blocks: int = 4, - shared_bond_weights: ( - Literal["bond", "three_body_bond", "both"] | None - ) = "both", - layer_bond_weights: ( - Literal["bond", "three_body_bond", "both"] | None - ) = None, - atom_conv_hidden_dims: Sequence[int] = (64,), - bond_update_hidden_dims: Sequence[int] | None = None, - bond_conv_hidden_dims: Sequence[int] = (64,), - angle_update_hidden_dims: Sequence[int] | None = (), - conv_dropout: float = 0.0, - final_mlp_type: Literal["gated", "mlp"] = "mlp", - final_hidden_dims: Sequence[int] = (64, 64), - final_dropout: float = 0.0, - pooling_operation: Literal["sum", "mean"] = "sum", - readout_field: Literal[ - "atom_feat", "bond_feat", "angle_feat" - ] = "atom_feat", - activation_type: str = "swish", - normalization: Literal["graph", "layer"] | None = None, - normalize_hidden: bool = False, - is_intensive: bool = False, - num_targets: int = 1, - num_site_targets: int = 1, - task_type: Literal["regression", "classification"] = "regression", - ): - super().__init__() - - # self.save_args(locals(), kwargs) - - activation: nn.Module = ActivationFunction[activation_type].value() - - element_types = element_types or DEFAULT_ELEMENTS - - # basis expansions for bond lengths, triple interaction bond lengths and angles - self.bond_expansion = RadialBesselFunction( - max_n=max_n, cutoff=cutoff, learnable=learn_basis - ) - self.threebody_bond_expansion = RadialBesselFunction( - max_n=max_n, cutoff=threebody_cutoff, learnable=learn_basis - ) - self.angle_expansion = FourierExpansion( - max_f=max_f, learnable=learn_basis - ) - - # embedding block for atom, bond, angle, and optional state features - self.include_states = dim_state_feats is not None - self.state_embedding = ( - nn.Embedding(dim_state_feats, dim_state_embedding) - if self.include_states - else None - ) - self.atom_embedding = nn.Embedding( - len(element_types), dim_atom_embedding - ) - - # self.atom_embedding = MLP_norm( - # 1, dim_state_embedding - # ) - - self.bond_embedding = MLP_norm( - [max_n, dim_bond_embedding], - activation=activation, - activate_last=non_linear_bond_embedding, - bias_last=False, - ) - self.angle_embedding = MLP_norm( - [2 * max_f + 1, dim_angle_embedding], - activation=activation, - activate_last=non_linear_angle_embedding, - bias_last=False, - ) - - # shared message bond distance smoothing weights - self.atom_bond_weights = ( - nn.Linear(max_n, dim_atom_embedding, bias=False) - if shared_bond_weights in ["bond", "both"] - else None - ) - self.bond_bond_weights = ( - nn.Linear(max_n, dim_bond_embedding, bias=False) - if shared_bond_weights in ["bond", "both"] - else None - ) - self.threebody_bond_weights = ( - nn.Linear(max_n, dim_bond_embedding, bias=False) - if shared_bond_weights in ["three_body_bond", "both"] - else None - ) - - # operations involving the graph (i.e. 
atom graph) to update atom and bond features - self.atom_graph_layers = nn.ModuleList( - [ - EFFAtomGraphBlock( - num_atom_feats=dim_atom_embedding, - num_bond_feats=dim_bond_embedding, - atom_hidden_dims=atom_conv_hidden_dims, - bond_hidden_dims=bond_update_hidden_dims, - num_state_feats=dim_state_embedding, - activation=activation, - normalization=normalization, - normalize_hidden=normalize_hidden, - dropout=conv_dropout, - rbf_order=0, - ) - for _ in range(num_blocks) - ] - ) - - # operations involving the line graph (i.e. bond graph) to update bond and angle features - self.bond_graph_layers = nn.ModuleList( - [ - EFFBondGraphBlock( - num_atom_feats=dim_atom_embedding, - num_bond_feats=dim_bond_embedding, - num_angle_feats=dim_angle_embedding, - bond_hidden_dims=bond_conv_hidden_dims, - angle_hidden_dims=angle_update_hidden_dims, - activation=activation, - normalization=normalization, - normalize_hidden=normalize_hidden, - bond_dropout=conv_dropout, - angle_dropout=conv_dropout, - rbf_order=0, - ) - for _ in range(num_blocks - 1) - ] - ) - - self.sitewise_readout = ( - nn.Linear(dim_atom_embedding, num_site_targets) - if num_site_targets > 0 - else lambda x: None - ) - print("final_mlp_type", final_mlp_type) - input_dim = ( - dim_atom_embedding - if readout_field == "node_feat" - else dim_bond_embedding - ) - - self.final_layer = MLP_norm( - dims=[input_dim, *final_hidden_dims, num_targets], - activation=activation, - activate_last=False, - ) - - self.element_types = element_types - self.max_n = max_n - self.max_f = max_f - self.cutoff = cutoff - self.cutoff_exponent = cutoff_exponent - self.three_body_cutoff = threebody_cutoff - - self.n_blocks = num_blocks - self.readout_operation = pooling_operation - self.readout_field = readout_field - self.readout_type = final_mlp_type - - self.task_type = task_type - self.is_intensive = is_intensive - - def forward( - self, - g, - state_attr: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, torch.Tensor | None]: - """Forward pass of the model. - - Args: - g (dgl.DGLGraph): Input g. - state_attr (torch.Tensor, optional): State features. Defaults to None. - l_g (dgl.DGLGraph, optional): Line graph. Defaults to None and is computed internally. - - Returns: - torch.Tensor: Model output. 
- """ - g, l_g, lat = g - st = lat.new_zeros([g.batch_size, 3, 3]) - st.requires_grad_(True) - lattice = lat @ (torch.eye(3, device=lat.device) + st) - g.edata["lattice"] = torch.repeat_interleave( - lattice, g.batch_num_edges(), dim=0 - ) - g.edata["pbc_offshift"] = ( - g.edata["images"].unsqueeze(dim=-1) * g.edata["lattice"] - ).sum(dim=1) - g.ndata["pos"] = ( - g.ndata["frac_coords"].unsqueeze(dim=-1) - * torch.repeat_interleave(lattice, g.batch_num_nodes(), dim=0) - ).sum(dim=1) - g.ndata["pos"].requires_grad_(True) - - # compute bond vectors and distances and add to g, needs to be computed here to register gradients - bond_vec, bond_dist = compute_pair_vector_and_distance(g) - g.edata["bond_vec"] = bond_vec.to(g.device) - g.edata["bond_dist"] = bond_dist.to(g.device) - bond_expansion = self.bond_expansion(bond_dist) - smooth_cutoff = polynomial_cutoff( - bond_expansion, self.cutoff, self.cutoff_exponent - ) - g.edata["bond_expansion"] = smooth_cutoff * bond_expansion - - # create bond graph (line graoh) with necessary node and edge data - # print("self.readout_field", self.readout_field) - bond_graph = create_line_graph( - g, self.three_body_cutoff, directed=True - ) - - bond_graph.ndata["bond_index"] = bond_graph.ndata["edge_ids"] - threebody_bond_expansion = self.threebody_bond_expansion( - bond_graph.ndata["bond_dist"] - ) - smooth_cutoff = polynomial_cutoff( - threebody_bond_expansion, - self.three_body_cutoff, - self.cutoff_exponent, - ) - bond_graph.ndata["bond_expansion"] = ( - smooth_cutoff * threebody_bond_expansion - ) - bond_indices = bond_graph.ndata["bond_index"][bond_graph.edges()[0]] - bond_graph.edata["center_atom_index"] = g.edges()[1][bond_indices] - bond_graph.apply_edges(compute_theta) - bond_graph.edata["angle_expansion"] = self.angle_expansion( - bond_graph.edata["theta"] - ) - - # atom_features = self.atom_embedding(g.ndata["atom_features"]) - atom_features = self.atom_embedding(g.ndata["node_type"]) - - bond_features = self.bond_embedding(g.edata["bond_expansion"]) - angle_features = self.angle_embedding( - bond_graph.edata["angle_expansion"] - ) - if self.state_embedding is not None and state_attr is not None: - state_attr = self.state_embedding(state_attr) - else: - state_attr = None - - # shared message weights - atom_bond_weights = ( - self.atom_bond_weights(g.edata["bond_expansion"]) - if self.atom_bond_weights is not None - else None - ) - # print("atom_bond_weights", torch.sum(atom_bond_weights)) - bond_bond_weights = ( - self.bond_bond_weights(g.edata["bond_expansion"]) - if self.bond_bond_weights is not None - else None - ) - # print("bond_bond_weights", torch.sum(bond_bond_weights)) - threebody_bond_weights = ( - self.threebody_bond_weights(bond_graph.ndata["bond_expansion"]) - if self.threebody_bond_weights is not None - else None - ) - - # message passing layers - for i in range(self.n_blocks - 1): - atom_features, bond_features, state_attr = self.atom_graph_layers[ - i - ]( - g, - atom_features, - bond_features, - state_attr, - atom_bond_weights, - bond_bond_weights, - ) - bond_features, angle_features = self.bond_graph_layers[i]( - bond_graph, - atom_features, - bond_features, - angle_features, - threebody_bond_weights, - ) - - atom_features, bond_features, state_attr = self.atom_graph_layers[-1]( - g, - atom_features, - bond_features, - state_attr, - atom_bond_weights, - bond_bond_weights, - ) - - g.ndata["atom_feat"] = self.final_layer(atom_features) - structure_properties = readout_nodes( - g, "atom_feat", op=self.readout_operation - ) - # 
self.add_ewald=True - # ewald_en = 0 - # if self.add_ewald: - # ewald_en = get_atomic_repulsion(g) - # total_energies = (torch.squeeze(structure_properties)) +ewald_en/g.num_nodes() - total_energies = torch.squeeze(structure_properties) - - penalty_factor = 500.0 # Penalty weight, tune as needed - penalty_factor = 1000.0 # Penalty weight, tune as needed - penalty_threshold = 1.0 # 1 angstrom - - # Calculate penalties for distances less than the threshold - penalties = torch.where( - bond_dist < penalty_threshold, - penalty_factor * (penalty_threshold - bond_dist), - torch.zeros_like(bond_dist), - ) - total_penalty = torch.sum(penalties) - - # min_distance=1.0 - # mask = bond_dist < min_distance - # penalty = torch.zeros_like(bond_dist) - # epsilon=1.0 - # alpha=12 - # Smooth penalty calculation for close distances - # penalty[mask] = epsilon * ((min_distance / bond_dist[mask]) ** alpha) - - # Sum up the penalties - # total_penalty = torch.sum(penalty) - total_energies += total_penalty - forces = torch.zeros(1) - stresses = torch.zeros(1) - hessian = torch.zeros(1) - grad_vars = [ - g.ndata["pos"], - st, - ] # if self.calc_stresses else [g.ndata["pos"]] - # print('total_energies',total_energies) - grads = grad( - total_energies, - grad_vars, - grad_outputs=torch.ones_like(total_energies), - create_graph=True, - retain_graph=True, - ) - forces = -grads[0] - volume = torch.abs(torch.det(lattice)) - sts = -grads[1] - scale = 1.0 / volume * -160.21766208 - sts = ( - [i * j for i, j in zip(sts, scale)] - if sts.dim() == 3 - else [sts * scale] - ) - stresses = torch.cat(sts) - result = {} - result["out"] = total_energies - result["grad"] = forces - result["stresses"] = stresses - return result diff --git a/alignn/models/alignn_ff2.py b/alignn/models/alignn_ff2.py deleted file mode 100644 index 9bb3a91..0000000 --- a/alignn/models/alignn_ff2.py +++ /dev/null @@ -1,376 +0,0 @@ -from math import pi, sqrt -from typing import Tuple, Union -from torch.autograd import grad -import dgl -import dgl.function as fn -from dgl.nn import AvgPooling -import torch -from typing import Literal -from torch import nn -from torch.nn import functional as F -from alignn.models.utils import ( - RadialBesselFunction, - RBFExpansion, - RBFExpansionSmooth, - BesselExpansion, - SphericalHarmonicsExpansion, - FourierExpansion, - compute_pair_vector_and_distance, - check_line_graph, - cutoff_function_based_edges, - compute_cartesian_coordinates, - MLPLayer, -) -from alignn.graphs import compute_bond_cosines -from alignn.utils import BaseSettings -from dgl import GCNNorm - - -class ALIGNNFF2Config(BaseSettings): - """Hyperparameter schema for jarvisdgl.models.alignn.""" - - name: Literal["alignn_ff2"] - alignn_layers: int = 2 - gcn_layers: int = 2 - atom_input_features: int = 1 - edge_input_features: int = 64 - triplet_input_features: int = 40 - embedding_features: int = 64 - hidden_features: int = 128 - output_features: int = 1 - grad_multiplier: int = -1 - calculate_gradient: bool = True - atomwise_output_features: int = 0 - graphwise_weight: float = 1.0 - gradwise_weight: float = 1.0 - stresswise_weight: float = 0.0 - atomwise_weight: float = 0.0 - classification: bool = False - batch_stress: bool = False - use_cutoff_function: bool = True - use_penalty: bool = True - multiply_cutoff: bool = True - inner_cutoff: float = 4.0 # Angstrom - stress_multiplier: float = 1.0 - sigma: float = 0.2 - exponent: int = 4 - extra_features: int = 0 - - -class GraphConv(nn.Module): - """ - Custom Graph Convolution layer with smooth 
transformations on bond lengths and angles. - """ - - def __init__( - self, in_feats, out_feats, activation=nn.SiLU(), hidden_features=64 - ): - super(GraphConv, self).__init__() - self.fc = nn.Linear( - in_feats, out_feats - ) # Linear transformation for features - self.activation = activation - self.edge_transform = nn.Linear( - hidden_features, out_feats - ) # For bond-length based transformation - - def forward(self, g, node_feats, bond_feats): - """ - Forward pass with bond length handling for smooth transitions. - """ - # Transform bond (edge) features - # print('bond_feats',bond_feats.shape) - bond_feats = self.edge_transform(bond_feats) - - # Message passing: message = transformed edge feature + node feature - g.ndata["h"] = node_feats - g.edata["e"] = bond_feats - g.update_all( - message_func=fn.u_add_e( - "h", "e", "m" - ), # Add node and edge features - reduce_func=fn.sum("m", "h"), # Sum messages for each node - ) - - # Final node feature transformation - node_feats = self.fc(g.ndata["h"]) - return self.activation(node_feats), bond_feats - - -class AtomGraphBlock(nn.Module): - """ - Atom Graph Block that processes atom-centric features and uses GraphConv for updates. - """ - - def __init__(self, in_feats, out_feats, n_layers=2, hidden_features=64): - super(AtomGraphBlock, self).__init__() - self.layers = nn.ModuleList( - [ - GraphConv( - in_feats if i == 0 else out_feats, - out_feats, - hidden_features=hidden_features, - ) - for i in range(n_layers) - ] - ) - - def forward(self, g, node_feats, bond_feats): - for layer in self.layers: - node_feats, bond_feats = layer(g, node_feats, bond_feats) - return node_feats, bond_feats - - -class BondGraphBlock(nn.Module): - """ - Bond Graph Block that applies additional processing on bond-based features. - """ - - def __init__(self, in_feats, out_feats, n_layers=2, hidden_features=64): - super(BondGraphBlock, self).__init__() - # self.fc = nn.Linear(in_feats, out_feats) # Linear transformation for bond features - # self.activation = activation - self.layers = nn.ModuleList( - [ - GraphConv( - in_feats if i == 0 else out_feats, - out_feats, - hidden_features=hidden_features, - ) - for i in range(n_layers) - ] - ) - - def forward(self, g, bond_feats, angle_feats): - """ - Process bond features with smooth transformations. - """ - # Transform bond features and apply smooth activation - for layer in self.layers: - bond_feats, angle_feats = layer(g, bond_feats, angle_feats) - return bond_feats, angle_feats - - -class ALIGNNFF2(nn.Module): - """Atomistic Line graph network. - - Chain alternating gated graph convolution updates on crystal graph - and atomistic line graph. 
- """ - - def __init__( - self, - config: ALIGNNFF2Config = ALIGNNFF2Config(name="alignn_ff2"), - ): - """Initialize class with number of input features, conv layers.""" - super().__init__() - self.classification = config.classification - self.config = config - if self.config.gradwise_weight == 0: - self.config.calculate_gradient = False - # if self.config.atomwise_weight == 0: - # self.config.atomwise_output_features = None - self.atom_embedding = MLPLayer( - config.atom_input_features, config.hidden_features - ) - self.edge_embedding = nn.Sequential( - RadialBesselFunction( - max_n=config.edge_input_features, cutoff=config.inner_cutoff - ), - # RBFExpansionSmooth(num_centers=config.edge_input_features, cutoff=config.inner_cutoff, sigma=config.sigma), - MLPLayer(config.edge_input_features, config.embedding_features), - MLPLayer(config.embedding_features, config.hidden_features), - ) - self.angle_embedding = nn.Sequential( - RadialBesselFunction( - max_n=config.edge_input_features, cutoff=config.inner_cutoff - ), - # RBFExpansionSmooth(num_centers=config.triplet_input_features, cutoff=1.0, sigma=config.sigma), - MLPLayer(config.edge_input_features, config.embedding_features), - MLPLayer(config.embedding_features, config.hidden_features), - ) - - self.atom_graph_layers = nn.ModuleList( - [ - AtomGraphBlock( - config.hidden_features, - config.hidden_features, - n_layers=config.gcn_layers, - hidden_features=config.hidden_features, - ) - ] - ) - - self.bond_graph_layers = nn.ModuleList( - [ - BondGraphBlock( - config.hidden_features, - config.hidden_features, - n_layers=config.gcn_layers, - hidden_features=config.hidden_features, - ) - ] - ) - - self.angle_graph_layers = nn.ModuleList( - [ - BondGraphBlock( - config.hidden_features, - config.hidden_features, - hidden_features=config.hidden_features, - ) - ] - ) - - self.gnorm = GCNNorm() - - self.readout = AvgPooling() - - if config.extra_features != 0: - self.readout_feat = AvgPooling() - # Credit for extra_features work: - # Gong et al., https://doi.org/10.48550/arXiv.2208.05039 - self.extra_feature_embedding = MLPLayer( - config.extra_features, config.extra_features - ) - # print('config.output_features',config.output_features) - self.fc3 = nn.Linear( - config.hidden_features + config.extra_features, - config.output_features, - ) - self.fc1 = MLPLayer( - config.extra_features + config.hidden_features, - config.extra_features + config.hidden_features, - ) - self.fc2 = MLPLayer( - config.extra_features + config.hidden_features, - config.extra_features + config.hidden_features, - ) - - if config.atomwise_output_features > 0: - # if config.atomwise_output_features is not None: - self.fc_atomwise = nn.Linear( - config.hidden_features, config.atomwise_output_features - ) - - if self.classification: - self.fc = nn.Linear(config.hidden_features, 1) - self.softmax = nn.Sigmoid() - # self.softmax = nn.LogSoftmax(dim=1) - else: - self.fc = nn.Linear(config.hidden_features, config.output_features) - - def forward(self, g): - result = {} - if self.config.alignn_layers > 0: - g, lg, lat = g - lg = lg.local_var() - # print('lattice',lattice,lattice.shape) - else: - g, lat = g - - if self.config.extra_features != 0: - features = g.ndata["extra_features"] - features = self.extra_feature_embedding(features) - x = g.ndata.pop("atom_features") - x = self.atom_embedding(x) - - g = self.gnorm(g) - # Compute and embed bond lengths - g.ndata["cart_coords"] = compute_cartesian_coordinates(g, lat) - if self.config.calculate_gradient: - 
g.ndata["cart_coords"].requires_grad_(True) - - r, bondlength = compute_pair_vector_and_distance(g) - bondlength = torch.norm(r, dim=1) - y = self.edge_embedding(bondlength) - - # smooth_cutoff = polynomial_cutoff( - # bond_expansion, self.config.inner_cutoff, self.config.exponent - # ) - # bond_expansion *= smooth_cutoff - if self.config.use_cutoff_function: - if self.config.multiply_cutoff: - c_off = cutoff_function_based_edges( - bondlength, - inner_cutoff=self.config.inner_cutoff, - exponent=self.config.exponent, - ).unsqueeze(dim=1) - - y = self.edge_embedding(bondlength) * c_off - else: - bondlength = cutoff_function_based_edges( - bondlength, - inner_cutoff=self.config.inner_cutoff, - exponent=self.config.exponent, - ) - y = self.edge_embedding(bondlength) - else: - y = self.edge_embedding(bondlength) - out = torch.empty(1) # graph level output eg energy - lg = g.line_graph(shared=True) - lg.ndata["r"] = r - lg.apply_edges(compute_bond_cosines) - for atom_graph_layer in self.atom_graph_layers: - x, y = atom_graph_layer(g, x, y) - - if self.config.output_features is not None: - h = self.readout(g, x) - out = self.fc(h) - if self.config.extra_features != 0: - h_feat = self.readout_feat(g, features) - h = torch.cat((h, h_feat), 1) - h = self.fc1(h) - h = self.fc2(h) - out = self.fc3(h) - else: - out = torch.squeeze(out) - atomwise_pred = torch.empty(1) - if ( - self.config.atomwise_output_features > 0 - # self.config.atomwise_output_features is not None - and self.config.atomwise_weight != 0 - ): - atomwise_pred = self.fc_atomwise(x) - # atomwise_pred = torch.squeeze(self.readout(g, atomwise_pred)) - forces = torch.empty(1) - # gradient = torch.empty(1) - stress = torch.empty(1) - - if self.config.use_penalty: - penalty_factor = 500.0 # Penalty weight, tune as needed - penalty_factor = 0.01 # Penalty weight, tune as needed - penalty_threshold = 1.0 # 1 angstrom - - penalties = torch.where( - bondlength < penalty_threshold, - penalty_factor * (penalty_threshold - bondlength), - torch.zeros_like(bondlength), - ) - total_penalty = torch.sum(penalties) - out += total_penalty - - if self.config.calculate_gradient: - - # en_out = torch.sum(out)*g.num_nodes() - en_out = out # *g.num_nodes() - # en_out = (out) *g.num_nodes() - grad_vars = [g.ndata["cart_coords"]] - grads = grad( - en_out, - grad_vars, - grad_outputs=torch.ones_like(en_out), - create_graph=True, - retain_graph=True, - ) - forces_out = -1 * grads[0] * g.num_nodes() - # forces_out = -1*grads[0] - stresses = torch.eye(3) - - if self.classification: - out = self.softmax(out) - result["out"] = out - result["grad"] = forces_out - result["stresses"] = stress - result["atomwise_pred"] = atomwise_pred - return result diff --git a/alignn/models/utils.py b/alignn/models/utils.py index 4d8c58c..9acf735 100644 --- a/alignn/models/utils.py +++ b/alignn/models/utils.py @@ -3,218 +3,8 @@ from typing import Optional import numpy as np import torch -from math import pi import torch.nn as nn -import math import dgl -import torch -from typing import Any, Callable, Literal, cast - - -def get_atomic_repulsion(g, cutoff=5.0): - """ - Calculate atomic repulsion energy using pairwise Coulomb interactions within a cutoff distance. - - Parameters: - g (DGLGraph): ALIGNN graph with atom charges (Z) and precomputed bond lengths in g.edata['d']. - cutoff (float): Cutoff distance for pairwise interactions. - - Returns: - float: Atomic repulsion energy for the given graph. 
- """ - - # Atomic charges - Z = g.ndata["Z"].squeeze() # Ensure Z is a 1D tensor - bond_lengths = g.edata[ - "d" - ] # Precomputed bond lengths in Cartesian coordinates - - # Atomic indices for each edge - src, dst = g.edges() - - # Mask for distances below the cutoff - valid_edges = bond_lengths < cutoff - - # Get charges for each pair - Zi = Z[src[valid_edges]] - Zj = Z[dst[valid_edges]] - rij = bond_lengths[valid_edges] - - # Compute repulsion energy - repulsion_energy = torch.sum(Zi * Zj / rij) - - return repulsion_energy - - -class RadialBesselFunction(nn.Module): - - def __init__( - self, - max_n: int, - cutoff: float, - learnable: bool = False, - dtype=torch.float32, - ): - """ - Args: - max_n: int, max number of roots (including max_n) - cutoff: float, cutoff radius - learnable: bool, whether to learn the location of roots. - """ - super().__init__() - self.max_n = max_n - self.inv_cutoff = 1 / cutoff - self.norm_const = (2 * self.inv_cutoff) ** 0.5 - if learnable: - self.frequencies = torch.nn.Parameter( - data=torch.Tensor( - pi * torch.arange(1, self.max_n + 1, dtype=dtype) - ), - requires_grad=True, - ) - else: - self.register_buffer( - "frequencies", - pi * torch.arange(1, self.max_n + 1, dtype=dtype), - ) - - def forward(self, r: torch.Tensor) -> torch.Tensor: - r = r[:, None] # (nEdges,1) - d_scaled = r * self.inv_cutoff - return self.norm_const * torch.sin(self.frequencies * d_scaled) / r - - -def get_ewald_sum(g, lattice_mat, alpha=0.2, r_cut=10.0, k_cut=5): - """ - Calculate the Ewald sum energy for the DGL graph using precomputed rij vectors. - - Parameters: - g (DGLGraph): ALIGNN graph with atom features, fractional coordinates, and precomputed rij vectors. - alpha (float): Ewald splitting parameter, controls the balance between real and reciprocal space sums. - r_cut (float): Real-space cutoff distance for pairwise interactions. - k_cut (int): Reciprocal-space cutoff for Fourier components. - - Returns: - float: Ewald sum energy for the given graph. 
- """ - - # Atomic numbers (charges) and fractional coordinates - Z = g.ndata["Z"] # Atomic charges (assuming Z is atomic number) - cart_pos = g.ndata[ - "frac_coords" - ] # Fractional coordinates in Cartesian space - r_ij_vectors = g.edata[ - "r" - ] # Precomputed rij vectors in Cartesian coordinates - - # Initialize Ewald sum energy - ewald_energy = 0.0 - - # Real-space sum using precomputed rij vectors - src, dst = g.edges() # Get the source and destination nodes for each edge - for edge_idx in range(len(src)): - i = src[edge_idx] - j = dst[edge_idx] - - # Pairwise distance (norm of r_ij) - r = torch.norm(r_ij_vectors[edge_idx]) - - if r < r_cut: - ewald_energy += Z[i] * Z[j] * torch.erfc(alpha * r) / r - - # Reciprocal-space sum - # lattice_mat = g.ndata['lattice_mat'][0] # Assuming lattice matrix is uniform across nodes - recip_vectors = ( - 2 * pi * torch.inverse(lattice_mat).T - ) # Reciprocal lattice vectors - for h in range(-k_cut, k_cut + 1): - for k in range(-k_cut, k_cut + 1): - for l in range(-k_cut, k_cut + 1): - if h == 0 and k == 0 and l == 0: - continue - k_vec = ( - h * recip_vectors[:, 0] - + k * recip_vectors[:, 1] - + l * recip_vectors[:, 2] - ) - k_sq = torch.dot(k_vec, k_vec) - structure_factor = torch.sum( - Z * torch.exp(1j * torch.matmul(cart_pos, k_vec)) - ) - ewald_energy += ( - torch.exp(-k_sq / (4 * alpha**2)) / k_sq - ) * (torch.norm(structure_factor) ** 2) - - # Self-interaction correction - # print('Z',Z) - # ewald_energy -= alpha / torch.sqrt(pi) * torch.sum(Z ** 2) - ewald_energy -= ( - alpha / torch.sqrt(torch.tensor(torch.pi)) * torch.sum(Z**2) - ) - return ewald_energy.real # Return the real part of the energy - - -class BesselExpansion(nn.Module): - """Expand interatomic distances with spherical Bessel functions.""" - - def __init__( - self, - vmin: float = 0, - vmax: float = 8, - bins: int = 40, - cutoff: Optional[float] = None, - ): - """Register torch parameters for Bessel function expansion.""" - super().__init__() - self.vmin = vmin - self.vmax = vmax - self.bins = bins - self.cutoff = cutoff if cutoff is not None else vmax - - # Generate frequency parameters for Bessel functions - # Convert to float32 explicitly - frequencies = torch.tensor( - [(n * np.pi) / self.cutoff for n in range(1, bins + 1)], - dtype=torch.float32, - ) - self.register_buffer("frequencies", frequencies) - - # Precompute normalization factors - norm_factors = torch.tensor( - [np.sqrt(2 / self.cutoff) for _ in range(bins)], - dtype=torch.float32, - ) - self.register_buffer("norm_factors", norm_factors) - - def forward(self, distance: torch.Tensor) -> torch.Tensor: - """Apply Bessel function expansion to interatomic distance tensor.""" - # Ensure input is float32 - distance = distance.to(torch.float32) - - # Compute the zero-order spherical Bessel functions - x = distance.unsqueeze(-1) * self.frequencies - - # Handle the case where x is close to zero - mask = x.abs() < 1e-10 - j0 = torch.where(mask, torch.ones_like(x), torch.sin(x) / x) - - # Apply normalization - bessel_features = j0 * self.norm_factors - - # Apply smooth cutoff function if cutoff is specified - if self.cutoff < self.vmax: - envelope = self._smooth_cutoff(distance) - bessel_features = bessel_features * envelope.unsqueeze(-1) - - return bessel_features - - def _smooth_cutoff(self, distance: torch.Tensor) -> torch.Tensor: - """Apply smooth cutoff function to ensure continuity at boundary.""" - x = torch.pi * distance / self.cutoff - cutoffs = 0.5 * (torch.cos(x) + 1.0) - return torch.where( - distance <= 
self.cutoff, cutoffs, torch.zeros_like(distance) - ) class RBFExpansion(nn.Module): @@ -253,158 +43,6 @@ def forward(self, distance: torch.Tensor) -> torch.Tensor: ) -class FourierExpansion(nn.Module): - """Fourier Expansion of a (periodic) scalar feature.""" - - def __init__( - self, - max_f: int = 5, - interval: float = pi, - scale_factor: float = 1.0, - learnable: bool = False, - ): - """Args: - max_f (int): the maximum frequency of the Fourier expansion. - Default = 5 - interval (float): interval of the Fourier exp, such that functions - are orthonormal over [-interval, interval]. Default = pi - scale_factor (float): pre-factor to scale all values. - learnable (bool): whether to set the frequencies as learnable - Default = False. - """ - super().__init__() - self.max_f = max_f - self.interval = interval - self.scale_factor = scale_factor - # Initialize frequencies at canonical - if learnable: - self.frequencies = torch.nn.Parameter( - data=torch.arange(0, max_f + 1, dtype=torch.float32), - requires_grad=True, - ) - else: - self.register_buffer( - "frequencies", torch.arange(0, max_f + 1, dtype=torch.float32) - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Expand x into cos and sin functions.""" - result = x.new_zeros(x.shape[0], 1 + 2 * self.max_f) - tmp = torch.outer(x, self.frequencies) - result[:, ::2] = torch.cos(tmp * pi / self.interval) - result[:, 1::2] = torch.sin(tmp[:, 1:] * pi / self.interval) - return result / self.interval * self.scale_factor - - -class SphericalHarmonicsExpansion(nn.Module): - """Expand angles with spherical harmonics.""" - - def __init__( - self, - vmin: float = 0, - vmax: float = math.pi, - bins: int = 20, - l_max: int = 3, - ): - """Register torch parameters for spherical harmonics expansion.""" - super().__init__() - self.vmin = vmin - self.vmax = vmax - self.bins = bins - self.l_max = l_max - self.num_harmonics = (l_max + 1) ** 2 - self.register_buffer( - "centers", torch.linspace(self.vmin, self.vmax, self.bins) - ) - - def forward(self, theta: torch.Tensor) -> torch.Tensor: - """Apply spherical harmonics expansion to angular tensors.""" - harmonics = [] - phi = torch.zeros_like(theta) - for l_x in range(self.l_max + 1): - for m in range(-l_x, l_x + 1): - y_lm = self._spherical_harmonic(l_x, m, theta, phi) - harmonics.append(y_lm) - return torch.stack(harmonics, dim=-1) - - def _legendre_polynomial( - self, l_x: int, m: int, x: torch.Tensor - ) -> torch.Tensor: - """ - Compute the associated Legendre polynomials P_l^m(x). - :param l: Degree of the polynomial. - :param m: Order of the polynomial. - :param x: Input tensor. - :return: Associated Legendre polynomial evaluated at x. - """ - pmm = torch.ones_like(x) - if m > 0: - somx2 = torch.sqrt((1 - x) * (1 + x)) - fact = 1.0 - for i in range(1, m + 1): - pmm = -pmm * fact * somx2 - fact += 2.0 - - if l_x == m: - return pmm - pmmp1 = x * (2 * m + 1) * pmm - if l_x == m + 1: - return pmmp1 - - pll = torch.zeros_like(x) - for ll in range(m + 2, l_x + 1): - pll = ((2 * ll - 1) * x * pmmp1 - (ll + m - 1) * pmm) / (ll - m) - pmm = pmmp1 - pmmp1 = pll - - return pll - - def _spherical_harmonic( - self, l_x: int, m: int, theta: torch.Tensor, phi: torch.Tensor - ) -> torch.Tensor: - """ - Compute the real part of the spherical harmonics Y_l^m(theta, phi). - :param l: Degree of the harmonic. - :param m: Order of the harmonic. - :param theta: Polar angle (in radians). - :param phi: Azimuthal angle (in radians). - :return: Real part of the spherical harmonic Y_l^m. 
- """ - sqrt2 = torch.sqrt(torch.tensor(2.0)) - if m > 0: - return ( - sqrt2 - * self._k(l_x, m) - * torch.cos(m * phi) - * self._legendre_polynomial(l_x, m, torch.cos(theta)) - ) - elif m < 0: - return ( - sqrt2 - * self._k(l_x, -m) - * torch.sin(-m * phi) - * self._legendre_polynomial(l_x, -m, torch.cos(theta)) - ) - else: - return self._k(l_x, 0) * self._legendre_polynomial( - l_x, 0, torch.cos(theta) - ) - - def _k(self, l_x: int, m: int) -> float: - """ - Normalization constant for the spherical harmonics. - :param l: Degree of the harmonic. - :param m: Order of the harmonic. - :return: Normalization constant. - """ - return math.sqrt( - (2 * l_x + 1) - / (4 * math.pi) - * math.factorial(l_x - m) - / math.factorial(l_x + m) - ) - - def compute_pair_vector_and_distance(g: dgl.DGLGraph): """Calculate bond vectors and distances using dgl graphs.""" # print('g.edges()',g.ndata["cart_coords"][g.edges()[1]].shape,g.edata["pbc_offshift"].shape) @@ -416,50 +54,6 @@ def compute_pair_vector_and_distance(g: dgl.DGLGraph): return bond_vec, bond_dist -def check_line_graph( - graph: dgl.DGLGraph, line_graph: dgl.DGLGraph, threebody_cutoff: float -): - """Ensure that 3body line graph is compatible with a given graph. - - Args: - graph: atomistic graph - line_graph: line graph of atomistic graph - threebody_cutoff: cutoff for three-body interactions - """ - valid_three_body = graph.edata["d"] <= threebody_cutoff - if line_graph.num_nodes() == graph.edata["r"][valid_three_body].shape[0]: - line_graph.ndata["r"] = graph.edata["r"][valid_three_body] - line_graph.ndata["d"] = graph.edata["d"][valid_three_body] - line_graph.ndata["images"] = graph.edata["images"][valid_three_body] - else: - three_body_id = torch.concatenate(line_graph.edges()) - max_three_body_id = ( - torch.max(three_body_id) + 1 if three_body_id.numel() > 0 else 0 - ) - line_graph.ndata["r"] = graph.edata["r"][:max_three_body_id] - line_graph.ndata["d"] = graph.edata["d"][:max_three_body_id] - line_graph.ndata["images"] = graph.edata["images"][:max_three_body_id] - - return line_graph - - -def cutoff_function_based_edges_old(r, inner_cutoff=4): - """Apply smooth cutoff to pairwise interactions - - r: bond lengths - inner_cutoff: cutoff radius - - inside cutoff radius, apply smooth cutoff envelope - outside cutoff radius: hard zeros - """ - ratio = r / inner_cutoff - return torch.where( - ratio <= 1, - 1 - 6 * ratio**5 + 15 * ratio**4 - 10 * ratio**3, - torch.zeros_like(r), - ) - - def cutoff_function_based_edges(r, inner_cutoff=4, exponent=3): """Apply smooth cutoff to pairwise interactions @@ -492,7 +86,7 @@ def cutoff_function_based_edges(r, inner_cutoff=4, exponent=3): def compute_cartesian_coordinates(g, lattice, dtype=torch.float32): """ - Compute Cartesian coordinates from fractional coordinates and lattice matrices. + Compute Cartesian coords from fractional coords and lattice matrices. Args: g: DGL graph with 'frac_coords' as node data. @@ -531,49 +125,6 @@ def compute_cartesian_coordinates(g, lattice, dtype=torch.float32): return cart_coords -class RBFExpansionSmooth(nn.Module): - """ - RBF Expansion layer for bond lengths with smooth output variation. 
- """ - - def __init__(self, num_centers=10, cutoff=5.0, sigma=0.5): - super(RBFExpansionSmooth, self).__init__() - - # Initialize centers and sigma for Gaussian RBFs - self.cutoff = cutoff - self.sigma = sigma - self.centers = torch.linspace(0, cutoff, num_centers).view( - 1, -1 - ) # Shape (1, num_centers) - - def forward(self, bondlengths): - """ - Compute the RBF features for a batch of bond lengths. - - Parameters: - - bondlengths: Tensor of shape (batch_size,) containing bond lengths. - - Returns: - - RBF expanded features: Tensor of shape (batch_size, num_centers) with smoothly varying RBFs. - """ - # Reshape bondlengths to (batch_size, 1) for broadcasting - bondlengths = bondlengths.view(-1, 1) # Shape (batch_size, 1) - - # Calculate RBF values - rbf_features = torch.exp( - -((bondlengths - self.centers.to(bondlengths.device)) ** 2) - / (2 * self.sigma**2) - ) - - # Apply cutoff - mask = bondlengths <= self.cutoff - rbf_features = ( - rbf_features * mask.float() - ) # Mask to zero out beyond cutoff - - return rbf_features - - class MLPLayer(nn.Module): """Multilayer perceptron layer helper.""" @@ -590,277 +141,3 @@ def forward(self, x): """Linear, Batchnorm, silu layer.""" # print('xtype',x.dtype) return self.layer(x) - - -def _create_directed_line_graph( - graph: dgl.DGLGraph, threebody_cutoff: float -) -> dgl.DGLGraph: - with torch.no_grad(): - pg = prune_edges_by_features( - graph, - feat_name="bond_dist", - condition=lambda x: torch.gt(x, threebody_cutoff), - ) - """ - lg=graph.line_graph(shared=True) - lg.ndata["src_bond_sign"] = torch.ones( - (lg.number_of_nodes(), 1), - dtype=lg.ndata["bond_vec"].dtype, - device=lg.device, - ) - return lg - """ - src_indices, dst_indices = pg.edges() - images = pg.edata["images"] - all_indices = torch.arange( - pg.number_of_nodes(), device=graph.device - ).unsqueeze(dim=0) - num_bonds_per_atom = torch.count_nonzero( - src_indices.unsqueeze(dim=1) == all_indices, dim=0 - ) - num_edges_per_bond = (num_bonds_per_atom - 1).repeat_interleave( - num_bonds_per_atom - ) - lg_src = torch.empty( - num_edges_per_bond.sum(), dtype=torch.int64, device=graph.device - ) - lg_dst = torch.empty( - num_edges_per_bond.sum(), dtype=torch.int64, device=graph.device - ) - incoming_edges = src_indices.unsqueeze(1) == dst_indices - is_self_edge = src_indices == dst_indices - not_self_edge = ~is_self_edge - - n = 0 - # create line graph edges for bonds that are self edges in atom graph - if is_self_edge.any(): - edge_inds_s = is_self_edge.nonzero() - lg_dst_s = edge_inds_s.repeat_interleave( - num_edges_per_bond[is_self_edge] + 1 - ) - lg_src_s = incoming_edges[is_self_edge].nonzero()[:, 1].squeeze() - lg_src_s = lg_src_s[lg_src_s != lg_dst_s] - lg_dst_s = edge_inds_s.repeat_interleave( - num_edges_per_bond[is_self_edge] - ) - n = len(lg_dst_s) - lg_src[:n], lg_dst[:n] = lg_src_s, lg_dst_s - - # create line graph edges for bonds that are not self edges in atom graph - shared_src = src_indices.unsqueeze(1) == src_indices - back_tracking = (dst_indices.unsqueeze(1) == src_indices) & torch.all( - -images.unsqueeze(1) == images, axis=2 - ) - incoming = incoming_edges & (shared_src | ~back_tracking) - - edge_inds_ns = not_self_edge.nonzero().squeeze() - lg_src_ns = incoming[not_self_edge].nonzero()[:, 1].squeeze() - lg_dst_ns = edge_inds_ns.repeat_interleave( - num_edges_per_bond[not_self_edge] - ) - lg_src[n:], lg_dst[n:] = lg_src_ns, lg_dst_ns - lg = dgl.graph((lg_src, lg_dst)) - - for key in pg.edata: - lg.ndata[key] = pg.edata[key][: lg.number_of_nodes()] - - # we 
need to store the sign of bond vector when a bond is a src node in the line - # graph in order to appropriately calculate angles when self edges are involved - lg.ndata["src_bond_sign"] = torch.ones( - (lg.number_of_nodes(), 1), - dtype=lg.ndata["bond_vec"].dtype, - device=lg.device, - ) - # if we flip self edges then we need to correct computed angles by pi - angle - # lg.ndata["src_bond_sign"][edge_inds_s] = -lg.ndata["src_bond_sign"][edge_ind_s] - # find the intersection for the rare cases where not all edges end up as nodes in the line graph - all_ns, counts = torch.cat( - [ - torch.arange(lg.number_of_nodes(), device=graph.device), - edge_inds_ns, - ] - ).unique(return_counts=True) - lg_inds_ns = all_ns[torch.where(counts > 1)] - lg.ndata["src_bond_sign"][lg_inds_ns] = -lg.ndata["src_bond_sign"][ - lg_inds_ns - ] - - return lg - - -def prune_edges_by_features( - graph: dgl.DGLGraph, - feat_name: str, - condition: Callable[[torch.Tensor], torch.Tensor], - keep_ndata: bool = False, - keep_edata: bool = True, - *args, - **kwargs, -) -> dgl.DGLGraph: - if feat_name not in graph.edata: - raise ValueError( - f"Edge field {feat_name} not an edge feature in given graph." - ) - - valid_edges = torch.logical_not( - condition(graph.edata[feat_name], *args, **kwargs) - ) - valid_edges1 = torch.ones( - graph.num_edges(), dtype=torch.bool, device=graph.device - ) - # print('valid_edges',valid_edges,valid_edges.shape) - # print('valid_edges1',valid_edges1,valid_edges1.shape) - - src, dst = graph.edges() - src, dst = src[valid_edges], dst[valid_edges] - e_ids = valid_edges.nonzero().squeeze() - new_g = dgl.graph((src, dst), device=graph.device) - new_g.edata["edge_ids"] = e_ids # keep track of original edge ids - - if keep_ndata: - for key, value in graph.ndata.items(): - new_g.ndata[key] = value - if keep_edata: - for key, value in graph.edata.items(): - new_g.edata[key] = value[valid_edges] - - return new_g - - -def compute_theta( - edges: dgl.udf.EdgeBatch, - cosine: bool = False, - directed: bool = True, - eps=1e-7, -) -> dict[str, torch.Tensor]: - """User defined dgl function to calculate bond angles from edges in a graph. - - Args: - edges: DGL graph edges - cosine: Whether to return the cosine of the angle or the angle itself - directed: Whether to the line graph was created with create directed line graph. - In which case bonds (only those that are not self bonds) need to - have their bond vectors flipped. - eps: eps value used to clamp cosine values to avoid acos of values > 1.0 - - Returns: - dict[str, torch.Tensor]: Dictionary containing bond angles and distances - """ - vec1 = ( - edges.src["bond_vec"] * edges.src["src_bond_sign"] - if directed - else edges.src["bond_vec"] - ) - vec2 = edges.dst["bond_vec"] - key = "cos_theta" if cosine else "theta" - val = torch.sum(vec1 * vec2, dim=1) / ( - torch.norm(vec1, dim=1) * torch.norm(vec2, dim=1) - ) - val = val.clamp_( - min=-1 + eps, max=1 - eps - ) # stability for floating point numbers > 1.0 - if not cosine: - val = torch.acos(val) - return {key: val, "triple_bond_lengths": edges.dst["bond_dist"]} - - -def create_line_graph( - g: dgl.DGLGraph, threebody_cutoff: float, directed: bool = False -) -> dgl.DGLGraph: - """ - Calculate the three body indices from pair atom indices. 
- - Args: - g: DGL graph - threebody_cutoff (float): cutoff for three-body interactions - directed (bool): Whether to create a directed line graph, or an M3gnet 3body line graph - Default = False (M3Gnet) - - Returns: - l_g: DGL graph containing three body information from graph - """ - graph_with_three_body = prune_edges_by_features( - g, feat_name="bond_dist", condition=lambda x: x > threebody_cutoff - ) - if directed: - # lg = g.line_graph(shared=True) - # return lg - lg = _create_directed_line_graph( - graph_with_three_body, threebody_cutoff - ) - else: - lg = _compute_3body(graph_with_three_body) - - return lg - - -def compute_pair_vector_and_distance(g: dgl.DGLGraph): - """Calculate bond vectors and distances using dgl graphs. - - Args: - g: DGL graph - - Returns: - bond_vec (torch.tensor): bond distance between two atoms - bond_dist (torch.tensor): vector from src node to dst node - """ - dst_pos = g.ndata["pos"][g.edges()[1]] + g.edata["images"] - src_pos = g.ndata["pos"][g.edges()[0]] - bond_vec = dst_pos - src_pos - bond_dist = torch.norm(bond_vec, dim=1) - - return bond_vec, bond_dist - - -def polynomial_cutoff( - r: torch.Tensor, cutoff: float, exponent: int = 3 -) -> torch.Tensor: - """Envelope polynomial function that ensures a smooth cutoff. - - Ensures first and second derivative vanish at cuttoff. As described in: - https://arxiv.org/abs/2003.03123 - - Args: - r (torch.Tensor): radius distance tensor - cutoff (float): cutoff distance. - exponent (int): minimum exponent of the polynomial. Default is 3. - The polynomial includes terms of order exponent, exponent + 1, exponent + 2. - - Returns: polynomial cutoff function - """ - coef1 = -(exponent + 1) * (exponent + 2) / 2 - coef2 = exponent * (exponent + 2) - coef3 = -exponent * (exponent + 1) / 2 - ratio = r / cutoff - poly_envelope = ( - 1 - + coef1 * ratio**exponent - + coef2 * ratio ** (exponent + 1) - + coef3 * ratio ** (exponent + 2) - ) - - return torch.where(r <= cutoff, poly_envelope, 0.0) - - -if __name__ == "__main__": - from jarvis.core.atoms import Atoms - from alignn.graphs import radius_graph_jarvis - - FIXTURES = { - "lattice_mat": [ - [2.715, 2.715, 0], - [0, 2.715, 2.715], - [2.715, 0, 2.715], - ], - "coords": [[0, 0, 0], [0.25, 0.25, 0.25]], - "elements": ["Si", "Si"], - } - Si = Atoms( - lattice_mat=FIXTURES["lattice_mat"], - coords=FIXTURES["coords"], - elements=FIXTURES["elements"], - ) - g, lg = radius_graph_jarvis( - atoms=s1, cutoff=5, atom_features="atomic_number" - ) - ewald = get_ewald_sum(g, torch.tensor(Si.lattice_mat)) diff --git a/alignn/train.py b/alignn/train.py index 3dc71ff..42c2d5c 100644 --- a/alignn/train.py +++ b/alignn/train.py @@ -12,8 +12,6 @@ from alignn.data import get_train_val_loaders from alignn.config import TrainingConfig from alignn.models.alignn_atomwise import ALIGNNAtomWise -from alignn.models.alignn_ff2 import ALIGNNFF2 -from alignn.models.alignn_eff import ALIGNNeFF from alignn.models.alignn import ALIGNN from jarvis.db.jsonutils import dumpjson import json @@ -150,8 +148,6 @@ def train_dgl( _model = { "alignn_atomwise": ALIGNNAtomWise, "alignn": ALIGNN, - "alignn_ff2": ALIGNNFF2, - "alignn_eff": ALIGNNeFF, } if config.random_seed is not None: random.seed(config.random_seed) @@ -239,9 +235,6 @@ def train_dgl( dats[2].to(device), ] ) - # result = net(dats[0].to(device), dats[2].to(device),dats[1].to(device)) - # result = net([dats[0].to(device), dats[1].to(device),lat=dats[2].to(device)]) - # batched_graph, batched_line_graph, torch.stack(lattices),torch.tensor(labels) 
else: result = net(dats[0].to(device)) @@ -359,8 +352,6 @@ def train_dgl( optimizer.zero_grad() # result = net([dats[0].to(device), dats[1].to(device)]) if (config.model.alignn_layers) > 0: - # result = net([dats[0].to(device), dats[2].to(device), dats[1].to(device)]) - # result = net(dats[0].to(device), dats[2].to(device),dats[1].to(device)) result = net( [ dats[0].to(device), diff --git a/alignn/train_alignn.py b/alignn/train_alignn.py index 6dcf8ab..c36db4e 100644 --- a/alignn/train_alignn.py +++ b/alignn/train_alignn.py @@ -13,7 +13,6 @@ from jarvis.db.jsonutils import loadjson import argparse from alignn.models.alignn_atomwise import ALIGNNAtomWise, ALIGNNAtomWiseConfig -from alignn.models.alignn_ff2 import ALIGNNFF2, ALIGNNFF2Config import torch import time from jarvis.core.atoms import Atoms @@ -331,20 +330,6 @@ def train_for_folder( torch.load(restart_model_path, map_location=device) ) model = model.to(device) - if config.model.name == "alignn_ff2": - rest_config = loadjson( - restart_model_path.replace("current_model.pt", "config.json") - # restart_model_path.replace("best_model.pt", "config.json") - ) - - tmp = ALIGNNFF2Config(**rest_config["model"]) - print("Rest config", tmp) - model = ALIGNNFF2(tmp) # config.model) - print("model", model) - model.load_state_dict( - torch.load(restart_model_path, map_location=device) - ) - model = model.to(device) # print ('n_outputs',n_outputs[0]) # if multioutput and classification_threshold is not None: diff --git a/setup.py b/setup.py index ebf95e9..206bad8 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setuptools.setup( name="alignn", - version="2024.8.30", + version="2024.10.30", author="Kamal Choudhary, Brian DeCost", author_email="kamal.choudhary@nist.gov", description="alignn", From 0e29e773d172046f8c5b26a359875b91a7b01172 Mon Sep 17 00:00:00 2001 From: knc6 Date: Sun, 3 Nov 2024 23:32:17 -0500 Subject: [PATCH 10/37] PyTest fix. --- alignn/tests/test_alignn_ff.py | 9 ++++++--- alignn/train.py | 2 +- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/alignn/tests/test_alignn_ff.py b/alignn/tests/test_alignn_ff.py index 5621e5c..de42dfb 100644 --- a/alignn/tests/test_alignn_ff.py +++ b/alignn/tests/test_alignn_ff.py @@ -43,9 +43,9 @@ """ -def test_radius_graph_jarvis(): - atoms = Poscar.from_string(pos).atoms - g, lg = radius_graph_jarvis(atoms=atoms) +# def test_radius_graph_jarvis(): +# atoms = Poscar.from_string(pos).atoms +# g, lg = radius_graph_jarvis(atoms=atoms) def test_alignnff(): @@ -56,6 +56,9 @@ def test_alignnff(): old_g = Graph.from_atoms(atoms=atoms) g, lg = Graph.atom_dgl_multigraph(atoms) g, lg = Graph.atom_dgl_multigraph(atoms, neighbor_strategy="radius_graph") + g, lg = Graph.atom_dgl_multigraph( + atoms, neighbor_strategy="radius_graph_jarvis" + ) model_path = default_path() print("model_path", model_path) print("atoms", atoms) diff --git a/alignn/train.py b/alignn/train.py index 42c2d5c..39000cb 100644 --- a/alignn/train.py +++ b/alignn/train.py @@ -33,7 +33,7 @@ warnings.filterwarnings("ignore", category=RuntimeWarning) -torch.autograd.detect_anomaly() +# torch.autograd.detect_anomaly() def train_dgl( From 03680d51999257c43957e45fd68eb08887ab453b Mon Sep 17 00:00:00 2001 From: knc6 Date: Tue, 5 Nov 2024 13:39:46 -0500 Subject: [PATCH 11/37] Fix calculator. 
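
The stress branch of AlignnAtomwiseCalculator.calculate() previously
returned a placeholder identity matrix (np.eye(3)). It now reshapes the
model's "stresses" output to 3x3, converts it to ASE Voigt order with
full_3x3_to_voigt_6_stress, and divides by 160.21766208 (1 eV/A^3 =
160.21766208 GPa), so get_stress() reports model-derived values in
ASE's units.

A minimal usage sketch, for illustration only (not part of this diff):
the path keyword is an assumption, the model directory comes from
default_path(), and the geometry mirrors the diamond-Si fixture used
in the tests.

    from ase import Atoms
    from alignn.ff.ff import AlignnAtomwiseCalculator, default_path

    # Diamond-Si primitive cell as a periodic ASE Atoms object.
    si = Atoms(
        "Si2",
        positions=[[0, 0, 0], [1.3575, 1.3575, 1.3575]],
        cell=[[2.715, 2.715, 0], [0, 2.715, 2.715], [2.715, 0, 2.715]],
        pbc=True,
    )
    si.calc = AlignnAtomwiseCalculator(path=default_path())
    print(si.get_potential_energy())  # eV
    print(si.get_forces())            # from result["grad"]
    print(si.get_stress())            # Voigt 6-vector, now model-derived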
--- alignn/config.py | 22 --- .../examples/sample_data/config_example.json | 54 +------- .../config_example_atomwise.json | 61 +-------- alignn/ff/ff.py | 8 +- alignn/graphs.py | 31 +++-- alignn/lmdb_dataset.py | 2 +- alignn/tests/test_alignn_ff.py | 13 +- alignn/tests/test_prop.py | 128 ++++++------------ alignn/train.py | 35 +++-- 9 files changed, 104 insertions(+), 250 deletions(-) diff --git a/alignn/config.py b/alignn/config.py index 4ed2318..9ea153e 100644 --- a/alignn/config.py +++ b/alignn/config.py @@ -210,26 +210,4 @@ class TrainingConfig(BaseSettings): model: Union[ ALIGNNConfig, ALIGNNAtomWiseConfig, - # CGCNNConfig, - # ICGCNNConfig, - # SimpleGCNConfig, - # DenseGCNConfig, - # ALIGNN_LN_Config, - # DenseALIGNNConfig, - # ACGCNNConfig, ] = ALIGNNAtomWiseConfig(name="alignn_atomwise") - - # @root_validator() - # @model_validator(mode='before') - # def set_input_size(cls, values): - # """Automatically configure node feature dimensionality.""" - # values["model"].atom_input_features = FEATURESET_SIZE[ - # values["atom_features"] - # ] - - # return values - - # @property - # def atom_input_features(self): - # """Automatically configure node feature dimensionality.""" - # return FEATURESET_SIZE[self.atom_features] diff --git a/alignn/examples/sample_data/config_example.json b/alignn/examples/sample_data/config_example.json index dee460c..7a462ac 100644 --- a/alignn/examples/sample_data/config_example.json +++ b/alignn/examples/sample_data/config_example.json @@ -1,53 +1 @@ -{ - "version": "112bbedebdaecf59fb18e11c929080fb2f358246", - "dataset": "user_data", - "target": "target", - "atom_features": "cgcnn", - "neighbor_strategy": "k-nearest", - "id_tag": "jid", - "random_seed": 123, - "classification_threshold": null, - "n_val": null, - "n_test": null, - "n_train": null, - "train_ratio": 0.8, - "val_ratio": 0.1, - "test_ratio": 0.1, - "target_multiplication_factor": null, - "epochs": 3, - "batch_size": 2, - "weight_decay": 1e-05, - "learning_rate": 0.001, - "filename": "sample", - "warmup_steps": 2000, - "criterion": "mse", - "optimizer": "adamw", - "scheduler": "onecycle", - "pin_memory": false, - "save_dataloader": false, - "write_checkpoint": true, - "write_predictions": true, - "store_outputs": true, - "progress": true, - "log_tensorboard": false, - "standard_scalar_and_pca": false, - "use_canonize": true, - "num_workers": 0, - "cutoff": 8.0, - "max_neighbors": 12, - "keep_data_order": true, - "model": { - "name": "alignn_atomwise", - "alignn_layers": 4, - "gcn_layers": 4, - "atom_input_features": 92, - "edge_input_features": 80, - "triplet_input_features": 40, - "embedding_features": 64, - "hidden_features": 256, - "output_features": 1, - "link": "identity", - "zero_inflated": false, - "classification": false - } -} +{"version": "112bbedebdaecf59fb18e11c929080fb2f358246", "dataset": "user_data", "target": "target", "atom_features": "cgcnn", "neighbor_strategy": "k-nearest", "id_tag": "jid", "dtype": "float32", "random_seed": 123, "classification_threshold": null, "n_val": null, "n_test": null, "n_train": null, "train_ratio": 0.8, "val_ratio": 0.1, "test_ratio": 0.1, "target_multiplication_factor": null, "epochs": 3, "batch_size": 2, "weight_decay": 1e-05, "learning_rate": 0.001, "filename": "A", "warmup_steps": 2000, "criterion": "mse", "optimizer": "adamw", "scheduler": "onecycle", "pin_memory": false, "save_dataloader": false, "write_checkpoint": true, "write_predictions": true, "store_outputs": true, "progress": true, "log_tensorboard": false, "standard_scalar_and_pca": 
false, "use_canonize": true, "num_workers": 0, "cutoff": 8.0, "cutoff_extra": 3.0, "max_neighbors": 12, "keep_data_order": true, "normalize_graph_level_loss": false, "distributed": false, "data_parallel": false, "n_early_stopping": null, "output_dir": "temp", "use_lmdb": true, "model": {"name": "alignn_atomwise", "alignn_layers": 4, "gcn_layers": 4, "atom_input_features": 92, "edge_input_features": 80, "triplet_input_features": 40, "embedding_features": 64, "hidden_features": 256, "output_features": 1, "grad_multiplier": -1, "calculate_gradient": false, "atomwise_output_features": 0, "graphwise_weight": 1.0, "gradwise_weight": 1.0, "stresswise_weight": 0.0, "atomwise_weight": 0.0, "link": "identity", "zero_inflated": false, "classification": false, "force_mult_natoms": false, "energy_mult_natoms": false, "include_pos_deriv": false, "use_cutoff_function": false, "inner_cutoff": 3.0, "stress_multiplier": 1.0, "add_reverse_forces": true, "lg_on_fly": true, "batch_stress": true, "multiply_cutoff": false, "use_penalty": true, "extra_features": 0, "exponent": 5, "penalty_factor": 0.1, "penalty_threshold": 1.0}} \ No newline at end of file diff --git a/alignn/examples/sample_data_ff/config_example_atomwise.json b/alignn/examples/sample_data_ff/config_example_atomwise.json index 84a5ee6..f5981a4 100644 --- a/alignn/examples/sample_data_ff/config_example_atomwise.json +++ b/alignn/examples/sample_data_ff/config_example_atomwise.json @@ -1,60 +1 @@ -{ - "version": "112bbedebdaecf59fb18e11c929080fb2f358246", - "dataset": "user_data", - "target": "target", - "atom_features": "cgcnn", - "neighbor_strategy": "radius_graph", - "id_tag": "jid", - "dtype": "float32", - "random_seed": 123, - "classification_threshold": null, - "n_val": null, - "n_test": null, - "n_train": null, - "train_ratio": 0.8, - "val_ratio": 0.1, - "test_ratio": 0.1, - "target_multiplication_factor": null, - "epochs": 3, - "batch_size": 2, - "weight_decay": 1e-05, - "learning_rate": 0.001, - "filename": "sample", - "warmup_steps": 2000, - "criterion": "l1", - "optimizer": "adamw", - "scheduler": "onecycle", - "pin_memory": false, - "save_dataloader": false, - "write_checkpoint": true, - "write_predictions": true, - "store_outputs": false, - "progress": true, - "log_tensorboard": false, - "standard_scalar_and_pca": false, - "use_canonize": true, - "num_workers": 0, - "cutoff": 4.0, - "max_neighbors": 12, - "keep_data_order": true, - "distributed":false, - "use_lmdb": true, - "model": { - "name": "alignn_atomwise", - "atom_input_features": 92, - "calculate_gradient":true, - "atomwise_output_features":0, - "alignn_layers":1, - "gcn_layers":1, - "hidden_features":64, - "output_features": 1, - "graphwise_weight":0.85, - "gradwise_weight":0.05, - "atomwise_weight":0.0, - "use_cutoff_function":false, - "stresswise_weight":0.05, - "add_reverse_forces":true - - - } -} +{"version": "112bbedebdaecf59fb18e11c929080fb2f358246", "dataset": "user_data", "target": "target", "atom_features": "cgcnn", "neighbor_strategy": "radius_graph", "id_tag": "jid", "dtype": "float32", "random_seed": 123, "classification_threshold": null, "n_val": null, "n_test": null, "n_train": null, "train_ratio": 0.8, "val_ratio": 0.1, "test_ratio": 0.1, "target_multiplication_factor": null, "epochs": 3, "batch_size": 2, "weight_decay": 1e-05, "learning_rate": 0.001, "filename": "B", "warmup_steps": 2000, "criterion": "l1", "optimizer": "adamw", "scheduler": "onecycle", "pin_memory": false, "save_dataloader": false, "write_checkpoint": true, "write_predictions": true, 
"store_outputs": false, "progress": true, "log_tensorboard": false, "standard_scalar_and_pca": false, "use_canonize": true, "num_workers": 0, "cutoff": 4.0, "max_neighbors": 12, "keep_data_order": true, "distributed": false, "use_lmdb": true, "model": {"name": "alignn_atomwise", "atom_input_features": 92, "calculate_gradient": true, "atomwise_output_features": 0, "alignn_layers": 1, "gcn_layers": 1, "hidden_features": 64, "output_features": 1, "graphwise_weight": 0.85, "gradwise_weight": 0.05, "atomwise_weight": 0.0, "use_cutoff_function": false, "stresswise_weight": 0.05, "add_reverse_forces": true}} \ No newline at end of file diff --git a/alignn/ff/ff.py b/alignn/ff/ff.py index f0bf991..060685a 100644 --- a/alignn/ff/ff.py +++ b/alignn/ff/ff.py @@ -326,8 +326,12 @@ def calculate(self, atoms, properties=None, system_changes=None): "energy": energy, # * num_atoms, "forces": result["grad"].detach().cpu().numpy(), "stress": full_3x3_to_voigt_6_stress( - np.eye(3) - # result["stresses"][:3].reshape(3, 3).detach().cpu().numpy() + # np.eye(3) + result["stresses"][:3] + .reshape(3, 3) + .detach() + .cpu() + .numpy() ) / 160.21766208, "dipole": np.zeros(3), diff --git a/alignn/graphs.py b/alignn/graphs.py index 054aaff..37d936e 100644 --- a/alignn/graphs.py +++ b/alignn/graphs.py @@ -91,7 +91,7 @@ def radius_graph_jarvis( dtype="float32", max_attempts=10, ): - """Construct radius graph with dynamic cutoff.""" + """Construct radius graph with jarvis tools.""" count = 0 while count <= max_attempts: # try: @@ -270,7 +270,7 @@ def radius_graph( bond_tol=0.5, id=None, atol=1e-5, - cutoff_extra=3.5, + cutoff_extra=0.5, ): """Construct edge list for radius graph.""" @@ -343,18 +343,23 @@ def temp_graph(cutoff=5): g = dgl.graph((u, v)) return g, u, v, r, cell_images - g, u, v, r, cell_images = temp_graph(cutoff) - while (g.num_nodes()) != len(atoms.elements): - try: + # g, u, v, r, cell_images = temp_graph(cutoff) + while True: # (g.num_nodes()) != len(atoms.elements): + # try: + g, u, v, r, cell_images = temp_graph(cutoff) + # g, u, v, r, cell_images = temp_graph(cutoff) + # print(atoms) + if (g.num_nodes()) == len(atoms.elements): + return u, v, r, cell_images + else: cutoff += cutoff_extra - g, u, v, r, cell_images = temp_graph(cutoff) - print("cutoff", id, cutoff) - print(atoms) + print("cutoff", id, cutoff, atoms) - except Exception as exp: - print("Graph exp", exp) - pass - return u, v, r, cell_images + # except Exception as exp: + # print("Graph exp", exp,atoms) + # cutoff += cutoff_extra + # pass + # return u, v, r, cell_images return u, v, r, cell_images @@ -534,6 +539,8 @@ def atom_dgl_multigraph( node_features = torch.tensor(sps_features).type( torch.get_default_dtype() ) + # print("u", u) + # print("v", v) g = dgl.graph((u, v)) g.ndata["atom_features"] = node_features # g.ndata["node_type"] = torch.tensor(node_types, dtype=torch.int64) diff --git a/alignn/lmdb_dataset.py b/alignn/lmdb_dataset.py index 682687b..1c4f422 100644 --- a/alignn/lmdb_dataset.py +++ b/alignn/lmdb_dataset.py @@ -93,7 +93,7 @@ def collate_line_graph( return ( batched_graph, batched_line_graph, - torch.tensor(lattices), + torch.stack(lattices), torch.stack(labels), ) else: diff --git a/alignn/tests/test_alignn_ff.py b/alignn/tests/test_alignn_ff.py index de42dfb..e439b06 100644 --- a/alignn/tests/test_alignn_ff.py +++ b/alignn/tests/test_alignn_ff.py @@ -51,7 +51,8 @@ def test_alignnff(): atoms = JAtoms.from_dict(get_jid_data()["atoms"]) atoms = JAtoms.from_dict( - get_jid_data(dataset="dft_3d", jid="JVASP-32")["atoms"] 
+ get_jid_data(dataset="dft_3d", jid="JVASP-1002")["atoms"] + # get_jid_data(dataset="dft_3d", jid="JVASP-32")["atoms"] ) old_g = Graph.from_atoms(atoms=atoms) g, lg = Graph.atom_dgl_multigraph(atoms) @@ -84,17 +85,19 @@ def test_alignnff(): xx = ff.run_nvt_andersen(steps=5) # xx = ff.run_npt_nose_hoover(steps=5) atoms_al = JAtoms.from_dict( - get_jid_data(dataset="dft_3d", jid="JVASP-816")["atoms"] + get_jid_data(dataset="dft_3d", jid="JVASP-1002")["atoms"] + # get_jid_data(dataset="dft_3d", jid="JVASP-816")["atoms"] ) atoms_al2o3 = JAtoms.from_dict( - get_jid_data(dataset="dft_3d", jid="JVASP-32")["atoms"] + get_jid_data(dataset="dft_3d", jid="JVASP-1002")["atoms"] + # get_jid_data(dataset="dft_3d", jid="JVASP-32")["atoms"] ) intf = get_interface_energy( film_atoms=atoms_al, subs_atoms=atoms_al, model_path=model_path, - film_thickness=10, - subs_thickness=10, + film_thickness=5, + subs_thickness=5, # film_atoms=atoms_al, subs_atoms=atoms_al2o3, model_path=model_path ) diff --git a/alignn/tests/test_prop.py b/alignn/tests/test_prop.py index bcc34c6..4b78fba 100644 --- a/alignn/tests/test_prop.py +++ b/alignn/tests/test_prop.py @@ -14,6 +14,9 @@ from alignn.ff.ff import AlignnAtomwiseCalculator, default_path, revised_path import torch from jarvis.db.jsonutils import loadjson, dumpjson +from alignn.config import TrainingConfig + +world_size = int(torch.cuda.device_count()) plt.switch_backend("agg") @@ -21,50 +24,27 @@ "dataset": "dft_2d", "target": "formation_energy_peratom", # "target": "optb88vdw_bandgap", - "n_train": 50, - "n_test": 25, - "n_val": 25, - "num_workers": 0, + "n_train": 4, + "n_test": 4, + "n_val": 4, "atom_features": "cgcnn", "neighbor_strategy": "k-nearest", "epochs": 2, - "save_dataloader": False, - "batch_size": 10, - "weight_decay": 1e-05, - "learning_rate": 0.01, - "criterion": "mse", - "optimizer": "adamw", - "scheduler": "onecycle", - "num_workers": 4, + "batch_size": 2, "model": { - "name": "alignn", + "name": "alignn_atomwise", + "calculate_gradient": False, + "energy_mult_natoms": False, + "atom_input_features": 92, }, } - -# def test_runtime_training(): -# cmd1 = 'python alignn/train_folder.py --root_dir "alignn/examples/sample_data" --config "alignn/examples/sample_data/config_example.json"' -# os.system(cmd1) -# cmd2 = 'python alignn/train_folder.py --root_dir "alignn/examples/sample_data" --classification_threshold 0.01 --config "alignn/examples/sample_data/config_example.json"' -# os.system(cmd2) -# cmd3 = 'python alignn/train_folder.py --root_dir "alignn/examples/sample_data_multi_prop" --config "alignn/examples/sample_data/config_example.json"' -# os.system(cmd3) - - -# def test_minor_configs(): -# tmp = config -# # tmp["log_tensorboard"] = True -# tmp["n_early_stopping"] = 2 -# tmp["model"]["name"] = "alignn" -# config["write_predictions"] = True -# result = train_dgl(tmp) +config = TrainingConfig(**config) def test_models(): + test_clean() - config["write_predictions"] = True - config["model"]["name"] = "alignn_atomwise" - config["filename"] = "X" t1 = time.time() result = train_dgl(config) t2 = time.time() @@ -74,12 +54,11 @@ def test_models(): print() print() print() - - config["model"]["name"] = "alignn_atomwise" - config["filename"] = "Y" - config["classification_threshold"] = 0.0 + test_clean() + config.classification_threshold = 0.0 + # config.model.classification = True t1 = time.time() - result = train_dgl(config) + # result = train_dgl(config,model=None) t2 = time.time() print("Total time", t2 - t1) # print("train=", result["train"]) @@ -87,51 
+66,21 @@ def test_models(): print() print() print() + test_clean() - """ - - config["model"]["name"] = "simplegcn" - config["write_predictions"] = False - config["save_dataloader"] = False - t1 = time.time() - result = train_dgl(config) - t2 = time.time() - print("Total time", t2 - t1) - print("train=", result["train"]) - print("validation=", result["validation"]) - print() - print() - print() - """ - """ - x = [] - y = [] - for i in result["EOS"]: - x.append(i[0].cpu().numpy().tolist()) - y.append(i[1].cpu().numpy().tolist()) - x = np.array(x, dtype="float").flatten() - y = np.array(y, dtype="float").flatten() - plt.plot(x, y, ".") - plt.xlabel("DFT") - plt.ylabel("ML") - plt.savefig("compare.png") - plt.close() - """ - - -def test_pretrained(): - box = [[2.715, 2.715, 0], [0, 2.715, 2.715], [2.715, 0, 2.715]] - coords = [[0, 0, 0], [0.25, 0.2, 0.25]] - elements = ["Si", "Si"] - Si = Atoms(lattice_mat=box, coords=coords, elements=elements) - prd = get_prediction(atoms=Si) - print(prd) - cmd1 = "python alignn/pretrained.py" - os.system(cmd1) - get_multiple_predictions(atoms_array=[Si, Si]) - -world_size = int(torch.cuda.device_count()) +# def test_pretrained(): +# box = [[2.715, 2.715, 0], [0, 2.715, 2.715], [2.715, 0, 2.715]] +# coords = [[0, 0, 0], [0.25, 0.2, 0.25]] +# elements = ["Si", "Si"] +# Si = Atoms(lattice_mat=box, coords=coords, elements=elements) +# prd = get_prediction(atoms=Si) +# print(prd) +# cmd1 = "python alignn/pretrained.py" +# os.system(cmd1) +# get_multiple_predictions(atoms_array=[Si, Si]) +# cmd1 = "rm *.json" +# os.system(cmd1) def test_alignn_train_regression(): @@ -222,6 +171,13 @@ def test_alignn_train_ff(): train_for_folder( rank=0, world_size=world_size, root_dir=root_dir, config_name=config ) + cmd = "rm *.pt *.csv *.json *range" + os.system(cmd) + + +def test_clean(): + cmd = "rm *.pt *.csv *.json *range" + os.system(cmd) def test_calculator(): @@ -281,15 +237,19 @@ def test_del_files(): for i in fnames: cmd = "rm -r " + i os.system(cmd) - cmd="rm -r *train_data *val_data *test_data" + cmd = "rm -r *train_data *val_data *test_data" os.system(cmd) + +test_clean() +# test_pretrained() +# test_models() # test_alignn_train_ff() +# test_alignn_train_regression_multi_out() + # test_alignn_train_classification() # test_alignn_train() # test_minor_configs() -# test_pretrained() # test_runtime_training() # test_alignn_train() -# test_models() # test_calculator() diff --git a/alignn/train.py b/alignn/train.py index 39000cb..e63cf7a 100644 --- a/alignn/train.py +++ b/alignn/train.py @@ -53,14 +53,14 @@ def train_dgl( # print("rank", rank) # setup(rank, world_size) if rank == 0: - print("config:") # print(config) if type(config) is dict: try: - print(config) + print("Trying to convert dictionary.") config = TrainingConfig(**config) except Exception as exp: print("Check", exp) + print("config:", config.dict()) if not os.path.exists(config.output_dir): os.makedirs(config.output_dir) @@ -169,7 +169,9 @@ def train_dgl( net = _model.get(config.model.name)(config.model) else: net = model - print("net parameters", sum(p.numel() for p in net.parameters())) + print("Model parameters", sum(p.numel() for p in net.parameters())) + print("CUDA available", torch.cuda.is_available()) + print("CUDA device count", int(torch.cuda.device_count())) # print("device", device) net.to(device) if use_ddp: @@ -253,6 +255,9 @@ def train_dgl( loss3 = 0 # Such as forces loss4 = 0 # Such as stresses if config.model.output_features is not None: + # print('criterion',criterion) + # 
print('result["out"]',result["out"]) + # print('dats[-1]',dats[-1]) loss1 = config.model.graphwise_weight * criterion( result["out"], dats[-1].to(device), @@ -598,8 +603,10 @@ def train_dgl( with torch.no_grad(): ids = test_loader.dataset.ids # [test_loader.dataset.indices] for dat, id in zip(test_loader, ids): - g, lg, target = dat - out_data = best_model([g.to(device), lg.to(device)])["out"] + g, lg, lat, target = dat + out_data = best_model( + [g.to(device), lg.to(device), lat.to(device)] + )["out"] # out_data = net([g.to(device), lg.to(device)])["out"] # out_data = torch.exp(out_data.cpu()) # print('target',target) @@ -632,8 +639,10 @@ def train_dgl( with torch.no_grad(): ids = test_loader.dataset.ids # [test_loader.dataset.indices] for dat, id in zip(test_loader, ids): - g, lg, target = dat - out_data = best_model([g.to(device), lg.to(device)])["out"] + g, lg, lat, target = dat + out_data = best_model( + [g.to(device), lg.to(device), lat.to(device)] + )["out"] # out_data = net([g.to(device), lg.to(device)])["out"] out_data = out_data.cpu().numpy().tolist() if config.standard_scalar_and_pca: @@ -673,8 +682,10 @@ def train_dgl( with torch.no_grad(): ids = test_loader.dataset.ids # [test_loader.dataset.indices] for dat, id in zip(test_loader, ids): - g, lg, target = dat - out_data = best_model([g.to(device), lg.to(device)])["out"] + g, lg, lat, target = dat + out_data = best_model( + [g.to(device), lg.to(device), lat.to(device)] + )["out"] # out_data = net([g.to(device), lg.to(device)])["out"] out_data = out_data.cpu().numpy().tolist() if config.standard_scalar_and_pca: @@ -710,8 +721,10 @@ def train_dgl( with torch.no_grad(): ids = train_loader.dataset.ids # [test_loader.dataset.indices] for dat, id in zip(train_loader, ids): - g, lg, target = dat - out_data = best_model([g.to(device), lg.to(device)])["out"] + g, lg, lat, target = dat + out_data = best_model( + [g.to(device), lg.to(device), lat.to(device)] + )["out"] # out_data = net([g.to(device), lg.to(device)])["out"] out_data = out_data.cpu().numpy().tolist() if config.standard_scalar_and_pca: From e558fd88ddee1305a51e02406e1cafa9dda56cbc Mon Sep 17 00:00:00 2001 From: knc6 Date: Wed, 6 Nov 2024 18:21:06 -0500 Subject: [PATCH 12/37] Update train.py. --- alignn/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/alignn/train.py b/alignn/train.py index e63cf7a..340a54f 100644 --- a/alignn/train.py +++ b/alignn/train.py @@ -644,7 +644,7 @@ def train_dgl( [g.to(device), lg.to(device), lat.to(device)] )["out"] # out_data = net([g.to(device), lg.to(device)])["out"] - out_data = out_data.cpu().numpy().tolist() + out_data = out_data.detach().cpu().numpy().tolist() if config.standard_scalar_and_pca: sc = pk.load(open("sc.pkl", "rb")) out_data = list( From fdc76003f25b9259a7d493568a1d69fbbbc63c03 Mon Sep 17 00:00:00 2001 From: knc6 Date: Fri, 8 Nov 2024 21:22:29 -0500 Subject: [PATCH 13/37] additional outputs. 
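
This threads an optional target_additional_output through data loading
(alignn/data.py, alignn/lmdb_dataset.py), the ALIGNNAtomWise model, and
training, so an additional per-structure output array (here an
electronic DOS) can be fit alongside energy, forces, and stresses. The
new example under alignn/examples/sample_data_ff_additional shows the
expected layout: a DataDir/id_prop.json list of records plus a
config.json, driven by the command recorded in CMD:

    train_alignn.py --root_dir DataDir/ --config config.json \
        --output_dir temp --target_key energy \
        --additional_output_key dos --stresswise_key stresses

A sketch of one id_prop.json record, abbreviated; the target keys map
onto the CLI flags above, and extra keys such as "efermi" may also be
present:

    record = {
        "jid": "JVASP-71467_102",  # structure id (id_tag)
        "atoms": {...},            # jarvis-tools Atoms dict
        "energy": -10.23406177,    # graph-level target (--target_key)
        "forces": [...],           # per-atom n_atoms x 3 targets
        "stresses": [...],         # 3x3 target (--stresswise_key)
        "dos": [...],              # additional output array
    }                              #   (--additional_output_key)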
--- .github/workflows/main.yml | 4 +- alignn/data.py | 5 + alignn/examples/sample_data_ff_additional/CMD | 2 + .../DataDir/id_prop.json | 1 + .../sample_data_ff_additional/config.json | 85 ++++++++++ alignn/lmdb_dataset.py | 11 +- alignn/models/alignn_atomwise.py | 48 +++++- alignn/train.py | 149 ++++++++++++++++-- alignn/train_alignn.py | 24 ++- alignn/utils.py | 20 ++- 10 files changed, 319 insertions(+), 30 deletions(-) create mode 100644 alignn/examples/sample_data_ff_additional/CMD create mode 100644 alignn/examples/sample_data_ff_additional/DataDir/id_prop.json create mode 100644 alignn/examples/sample_data_ff_additional/config.json diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index ab955e3..6742769 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -24,8 +24,8 @@ jobs: shell: bash -l {0} run: | conda install flake8 pycodestyle pydocstyle - flake8 --ignore E203,W503 --exclude=examples,tests,scripts --statistics --count --exit-zero alignn - pycodestyle --ignore E203,W503 --exclude=examples,tests,scripts alignn + flake8 --ignore E203,W503,W605 --exclude=examples,tests,scripts --statistics --count --exit-zero alignn + pycodestyle --ignore E203,W503,W605 --exclude=examples,tests,scripts alignn pydocstyle --match-dir=core --match-dir=io --match-dir=io --match-dir=ai --match-dir=analysis --match-dir=db --match-dir=tasks --count alignn - name: Run pytest diff --git a/alignn/data.py b/alignn/data.py index a30cba7..39f2fd0 100644 --- a/alignn/data.py +++ b/alignn/data.py @@ -122,6 +122,7 @@ def get_train_val_loaders( target_atomwise: str = "", target_grad: str = "", target_stress: str = "", + target_additional_output: str = "", atom_features: str = "cgcnn", neighbor_strategy: str = "k-nearest", n_train=None, @@ -161,6 +162,7 @@ def get_train_val_loaders( from alignn.lmdb_dataset import get_torch_dataset else: print("Not using LMDB dataset, memory footprint maybe high.") + print("WARNING: not using LMDB might result errors.") from alignn.dataset import get_torch_dataset train_sample = filename + "_train.data" val_sample = filename + "_val.data" @@ -373,6 +375,7 @@ def get_train_val_loaders( target_atomwise=target_atomwise, target_grad=target_grad, target_stress=target_stress, + target_additional_output=target_additional_output, neighbor_strategy=neighbor_strategy, use_canonize=use_canonize, name=dataset, @@ -397,6 +400,7 @@ def get_train_val_loaders( target_atomwise=target_atomwise, target_grad=target_grad, target_stress=target_stress, + target_additional_output=target_additional_output, neighbor_strategy=neighbor_strategy, use_canonize=use_canonize, name=dataset, @@ -424,6 +428,7 @@ def get_train_val_loaders( target_atomwise=target_atomwise, target_grad=target_grad, target_stress=target_stress, + target_additional_output=target_additional_output, neighbor_strategy=neighbor_strategy, use_canonize=use_canonize, name=dataset, diff --git a/alignn/examples/sample_data_ff_additional/CMD b/alignn/examples/sample_data_ff_additional/CMD new file mode 100644 index 0000000..d0883b7 --- /dev/null +++ b/alignn/examples/sample_data_ff_additional/CMD @@ -0,0 +1,2 @@ +train_alignn.py --root_dir DataDir/ --config config.json --output_dir temp --target_key energy --additional_output_key dos --stresswise_key stresses + diff --git a/alignn/examples/sample_data_ff_additional/DataDir/id_prop.json b/alignn/examples/sample_data_ff_additional/DataDir/id_prop.json new file mode 100644 index 0000000..c1cf375 --- /dev/null +++ 
b/alignn/examples/sample_data_ff_additional/DataDir/id_prop.json @@ -0,0 +1 @@ +[{"jid": "JVASP-71467_102", "atoms": {"lattice_mat": [[3.069291534, 0.0, 0.0], [0.0, 3.069291534, 0.0], [0.0, 0.0, 8.270263272]], "coords": [[0.0, 0.0, 5.55037], [1.53465, 1.53465, 4.81677], [0.0, 0.0, 7.70764], [1.53465, 1.53465, 2.60087]], "elements": ["Be", "P", "Se", "Se"], "abc": [3.06929, 3.06929, 8.27026], "angles": [90.0, 90.0, 90.0], "cartesian": true, "props": ["", "", "", ""]}, "efermi": 3.9313, "forces": [[-0.0, -0.0, -0.00100989], [0.0, -0.0, 0.00115946], [-0.0, 0.0, -0.00193256], [0.0, 0.0, 0.001783]], "energy": -10.23406177, "stresses": [[0.05711809, 0.0, 0.0], [0.0, 0.05711809, 0.0], [-0.0, 0.0, -0.17432286]], "dos": [0.6515440936997379, 0.621562637568951, 0.4868822099149206, 0.3409256082509479, 0.30754728180555246, 0.3406978602196719, 0.35569857627884, 0.33762673395564796, 0.31722688769354346, 0.2700971722261321, 0.22296745675872076, 0.2499856788014713, 0.2948625781962606, 0.2669782070151828, 0.17500340160214753, 0.12526919612636206, 0.19736011028719067, 0.26945102444801927, 0.21045898085701387, 0.14923297811799208, 0.1254734038180193, 0.11708958170260883, 0.1281868827943793, 0.16460715045445673, 0.18973571372028838, 0.1558414933245134, 0.12194727292873841, 0.14813853723427386, 0.18203255379882793, 0.1725901747467675, 0.13569068766321252, 0.11024557288750936, 0.10724138650519556, 0.10442853218868671, 0.10442724805199502, 0.10442596391530333, 0.10442467977861164, 0.10442339564191996, 0.10287411651132793, 0.09986864599232244, 0.08796845551968407, 0.0480647820538107, 0.008161108587937338, 0.003932198555021915, 0.0009267280360164241, 2.0309823648720353e-06, 7.468456731837267e-07, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 4.128930080194815e-07, 1.6970296997077901e-06, 0.0001452380310990411, 0.0031530933753890957, 0.00616094871967915, 0.04200838898089866, 0.08749420089779296, 0.1493996492556201, 0.2277985331346023, 0.29290923472591424, 0.3121603254463461, 0.33141141616677794, 0.29351325379487636, 0.252635239858094, 0.26936732810164016, 0.3136831006582954, 0.35215955235952723, 0.3819498681613843, 0.4048951240042807, 0.3811196016505096, 0.3573440792967386, 0.3867234648790765, 0.42522161046176626, 0.4931918134376844, 0.5825607429141689, 0.6502731606174319, 0.6686681137486741, 0.6847248697022312, 0.6274865181961807, 0.5702481666901302, 0.6552080952817549, 0.785889528173075, 0.8864653053404682, 0.9547488136263472, 0.9971202064893098, 0.9409819197079388, 0.8848436329265679, 0.838033701606173, 0.7918824978907512, 0.8552501025425623, 0.9750807934066922, 1.0458704077682122, 1.0385217324885134, 1.0402655197011998, 1.1145460369677196, 1.188826554234239, 1.1694401917208603, 1.131835461382422, 1.160655014623672, 1.2410433056440207, 1.2646307905705392, 1.1482945690772666, 1.033856162303373, 1.0482728307744449, 1.0626894992455167, 0.8765469693255619, 0.6200461694882721, 0.49190281190320906, 0.5108010241394481, 0.5423351183903666, 0.6270932870750882, 0.7118514557598097, 0.7030291073379742, 0.6858079053622833, 0.6986774898200252, 0.72822298196507, 0.7876124998350559, 0.8979911582086892, 1.003657797246262, 1.0644294569620127, 1.1252011166777638, 1.40653887975555, 1.7360416507835108, 1.984352948178301, 2.165303177174916, 2.206588690274968, 1.8746923002178215, 1.5427959101606752, 1.4552775601094936, 1.3682209499430709, 1.2950326532315228, 1.2271339268622943, 1.1602780720632273, 1.0946984992164577, 1.0813820454693128, 1.3136668297798015, 1.5459516140902902, 1.58452151364354, 1.6018639763065843, 1.5948837717910949, 
1.5734355261621367, 1.5755990507722342, 1.62107668006718, 1.6690733383250362, 1.7466051523384902, 1.8241369663519442, 1.861295956554083, 1.88863581200867, 1.9589899673889857, 2.067465893282051, 2.075284135242608, 1.7904614087486443, 1.5056386822546806, 1.4904982302136252, 1.4803811237849156, 1.437003818612632, 1.3798755312573607, 1.2919502250397428, 1.1637426993368953, 1.0577830939521637, 1.0694476412352527, 1.0811121885183417, 1.0061965565852968, 0.920012795766307, 0.9450621592965325, 1.0410418497283676, 1.0605455335794391, 0.9291916767124704, 0.8073504452464206, 0.8289666607368225, 0.8505828762272243, 0.8446413837681858, 0.8312844670125822, 0.8788335025069135, 0.9840283753713044, 1.0380159417453458, 0.9294297471969104, 0.8208435526484748, 0.9173494524526631, 1.0212249442018675, 1.0356044087672474, 1.0099810402023146, 0.9471298458867623, 0.8322047486452578, 0.7433398496828192, 0.8110899473690528, 0.8788400450552863, 0.8499790635402648, 0.8064825233431111, 0.780365178320147, 0.7661156008318587, 0.7225634831103195, 0.6167446131558048, 0.5209367447945863, 0.633766554138499, 0.7465963634824118, 0.8068124173476849, 0.8514514969486138, 0.8982643601945288, 0.9472741776838612, 0.966324446805635, 0.8810654681153316, 0.7958064894250282, 0.7646177162423801, 0.7363399470726741, 0.7539664909251523, 0.7937258849093044, 0.8294449887084183, 0.8591156748100099, 0.8916561298344556, 0.9440561721173738, 0.996456214400292, 0.9941056358160071, 0.9822474058416327, 0.920470818571022, 0.8222220172980743, 0.7584009445275496, 0.7735491125047377, 0.787784772231387, 0.771879808235897, 0.755974844240407, 0.7623329414948208, 0.7759089483594964, 0.7745527430747268, 0.757081354614022, 0.76710746675949, 0.8826472687905869, 0.9981870708216836, 0.9729558975183786, 0.9375368936942999, 0.9189627850454712, 0.9091324811708911, 0.881346861579417, 0.8247671847676441, 0.7685431088325363, 0.7152003500126836, 0.6618575911928308, 0.6645507723739792, 0.6782655552774118, 0.690855049436063, 0.7025654757156864, 0.7294433034247174, 0.7939644270087403, 0.8580076924540265, 0.8857279926904891, 0.9134482929269517, 0.9695349486893144, 1.035652151442002, 1.0946139870060796, 1.145328349278663, 1.1917002916758592, 1.219599478522794, 1.2474986653697289, 1.1584077579897127, 1.0586041313056918, 0.9397244955091643, 0.8102025473177039, 0.6746634131328383, 0.5287759928924348, 0.3967634069276494, 0.3993374862031594, 0.40191156547866935, 0.545998067430306, 0.721308529312322, 0.8250008152222115, 0.8689068482668374, 0.83849632754849, 0.6079718414375979, 0.3774473553267056, 0.38692352743648745, 0.3972222107104612, 0.40448669142181315, 0.4105850094740944, 0.3708570252601824, 0.27469830519870275, 0.19489255403407524, 0.19275427601538028, 0.1906159979966853, 0.18214510392652977, 0.172968331277833, 0.27930764858375595, 0.4548107985169881, 0.5523263286894866, 0.5058166435357097, 0.4709068361928494, 0.5749333673613611, 0.6789598985298728, 0.7231011518492105, 0.752535778700439, 0.8398297297029281, 0.9787177563910268, 1.1060142145340828, 1.199338329736332, 1.2926624449385815, 1.199084878612495, 1.1017288292668102, 1.127626410396113, 1.2048589905541045, 1.1964040681067174, 1.0751700597265799, 0.9804712541334875, 1.0276831419961379, 1.0748950298587885, 1.041217170720167, 0.9968534770579047, 0.9605853085631585, 0.9295127279659714, 0.868102927535523, 0.746437644681446, 0.6257413022189489, 0.5200502000222696, 0.41435909782559055, 0.4120089286600945, 0.437721542227582, 0.4769695476349125, 0.5291071075543836, 0.5836981601975806, 0.6461444263591948, 0.7085906925208089, 
0.7442144202326998, 0.7788302342733047, 0.8206864269314361, 0.865802154220379, 0.8617918949678331, 0.7886307466376054, 0.7368844648069045, 0.8154628378000282, 0.8940412107931518, 0.8640975912146711, 0.8174935594125907, 0.8342082843953567, 0.8944366086607756, 0.89809595110187, 0.7806994220196561, 0.6697507306720994, 0.6980474499440733, 0.7263441692160472, 0.787897715509377, 0.8593829857078539, 0.8753579976536785, 0.8348867545520484, 0.7884382936337325, 0.7209942408892328, 0.653550188144733, 0.5253934083974336, 0.3938647114979314, 0.37289456193813864, 0.4056027463570177, 0.43395026919880997, 0.4557280238690629, 0.46136151226883737, 0.35370424074763657, 0.24604696922643574, 0.16434884123192928, 0.08721340713159512, 0.04523638763037758, 0.029108909161304274, 0.01918529500154188, 0.023595376163236873, 0.02772380752292133, 0.02202291584588465, 0.016322024168847968, 0.010297970723136953, 0.004168276459110214, 0.0007713485787672814, 0.00034181219938197007, 3.795572381817405e-07, 1.9610913936912494e-07, 1.2661040556509457e-08, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]}, {"jid": "JVASP-71467_54", "atoms": {"lattice_mat": [[3.04042405, 0.0, -0.0], [0.0, 3.04042405, 0.0], [0.0, 0.0, 8.015354285]], "coords": [[0.0, -0.0, 4.31033], [1.52021, 1.52021, 5.13904], [0.0, -0.0, 7.79417], [1.52021, 1.52021, 2.79485]], "elements": ["Be", "P", "Se", "Se"], "abc": [3.04042, 3.04042, 8.01535], "angles": [90.0, 90.0, 90.0], "cartesian": true, "props": ["", "", "", ""]}, "efermi": 4.4266, "forces": [[0.0, -0.0, 0.05580781], [-0.0, 0.0, -0.09301742], [-0.0, -0.0, -0.00863591], [0.0, 0.0, 0.04584552]], "energy": -8.89726508, "stresses": [[1.95413462, 0.0, -0.0], [0.0, 1.95413462, 0.0], [-0.0, 0.0, 2.6348797]], "dos": [1.7118044033355297, 1.670693638170579, 1.6295828730056283, 1.475768144202696, 1.2758567137800831, 1.1861331219462545, 1.2648710707009756, 1.3257833539395256, 1.1848612576636024, 1.0439391613876792, 0.9601038261732336, 0.8936232836114707, 0.8997142824827395, 0.9935835726938019, 1.0480829953848168, 0.8572787357078877, 0.6664744760309586, 0.7328216880534285, 0.8540891496409058, 0.8452318197932872, 0.7112173469642527, 0.5982454813214996, 0.571895871893606, 0.5455462624657126, 0.5267192769633304, 0.5089068390054041, 0.48705297404318504, 0.46211125075163106, 0.42256839260192236, 0.33982158995999573, 0.2570747873180691, 0.18898982543872894, 0.12186884709437898, 0.0780267356999219, 0.04821019765542177, 0.028232808252981717, 0.03018212837804424, 0.032131448503106765, 0.02351631999004803, 0.014852961687141598, 0.0088809633646525, 0.004168682403713604, 0.004083439699989571, 0.011982779439016005, 0.023676690423129445, 0.10371259864539015, 0.18374850686765085, 0.27147905554918345, 0.36193681064745353, 0.44686820803978217, 0.5242931467545241, 0.610434030919611, 0.7672755101747459, 0.9241169894298809, 1.1680359623002445, 1.4343447879023312, 1.7036911319737738, 1.976312586972329, 2.19249723228515, 2.126964948449751, 2.061432664614352, 2.3431925160869427, 2.684987229495385, 2.8577666010827496, 2.8855980082210855, 2.947954500309281, 3.1298854511200265, 3.311816401930772, 3.807680540411968, 4.334684115398854, 4.392935775559506, 4.132750063781654, 3.9595846715909855, 4.008869250181292, 4.0581538287716, 4.193238576617316, 4.331259897744489, 4.448180741589438, 4.553870242533303, 4.60436027928867, 4.546901068030966, 4.491734157361255, 4.529572835712927, 4.5674115140646, 4.8463343935569405, 5.223838427363925, 5.487507141497387, 5.577174266363221, 
5.680011156968692, 5.931865115829959, 6.1837190746912265, 6.2665681339137, 6.29805285080034, 6.313786245182356, 6.310471609638377, 6.345875749204879, 6.622425496375517, 6.898975243546155, 6.954297566554489, 6.962388302090844, 6.840653326755643, 6.594073877202897, 6.34664538311665, 6.095722894705935, 5.844800406295221, 5.561095261198192, 5.272970957525503, 5.100153928666896, 5.015419637599262, 5.002516922602126, 5.202103501988971, 5.401690081375817, 6.052183778075367, 6.73229798669415, 7.203976226218829, 7.550098080800397, 7.900123183269369, 8.258844703869732, 8.617566224470094, 9.442861140342142, 10.270262535328566, 10.124682730478243, 9.523794504524743, 8.88051575683105, 8.164102170869182, 7.485903038682435, 7.4952734702123935, 7.5046439017423525, 7.254763153725345, 6.913019980376374, 6.970663460693234, 7.570683179782713, 8.112281438415337, 8.180228598515768, 8.2481757586162, 8.462675777296502, 8.71484669635726, 9.025039677521233, 9.397780622668051, 9.736408955613552, 9.904817849023146, 10.073226742432743, 9.995567199565812, 9.875387830522909, 9.850180945022839, 9.90640628938718, 9.825574499367988, 9.27019454937452, 8.71481459938105, 8.77549642952917, 8.897248967554326, 8.804773544122947, 8.566796814008454, 8.377175923283371, 8.311136421064946, 8.24509691884652, 8.201411906029643, 8.158490795214954, 7.88332784393855, 7.4845747055472, 7.246130726766218, 7.321120675047615, 7.392719677645218, 7.327030442463985, 7.261341207282753, 7.117092952622256, 6.9407291246986595, 6.835213513166866, 6.837969410340664, 6.869251218976208, 7.223089850528348, 7.576928482080488, 7.4608542419345225, 7.202003029027091, 7.011214185080037, 6.902716346763126, 6.807351774604364, 6.793748576036902, 6.780145377469439, 6.823597157116046, 6.879225849177272, 6.733288167473971, 6.393556757065675, 6.095378221994799, 5.968143983132528, 5.840909744270257, 5.632258238443317, 5.412636792128049, 5.3199783091032335, 5.324286525623294, 5.27248387117548, 5.054740273021419, 4.8369966748673585, 5.126300202940153, 5.448883312946514, 5.534542166775264, 5.4775144797961515, 5.3407150122794524, 5.026225984698994, 4.711736957118535, 4.69926066143541, 4.688132636876018, 5.0162022666441874, 5.502963509737617, 5.777580546742627, 5.686271684875122, 5.609710467124983, 5.798228698751337, 5.9867469303776915, 5.740779366446521, 5.340897058868618, 5.07469258875339, 4.98998849180139, 4.916622413660653, 4.93513210194773, 4.953641790234807, 5.050607124930861, 5.167733005024723, 5.385379355454813, 5.711365364220153, 5.940095336733719, 5.6836993985243085, 5.427303460314899, 5.20188278718116, 4.981812406939587, 4.91069569324522, 4.967270561808847, 5.0419203489502395, 5.179134926601635, 5.31634950425303, 5.264709386514737, 5.194359366890521, 5.107093885310524, 5.008341992524634, 5.001526765451728, 5.229612810981817, 5.4576988565119064, 5.8114274167170406, 6.169442732422224, 6.195143648910413, 6.044038648060555, 6.046556498655726, 6.349368175096896, 6.647832166684805, 6.77064644703026, 6.893460727375715, 6.670776018141254, 6.306882798643624, 6.138066235799561, 6.267306901256507, 6.37187285115488, 6.1976163380597615, 6.023359824964643, 5.925688086690198, 5.8512791299350475, 5.533525026341943, 4.921613191526428, 4.370190425277361, 4.195185390377454, 4.020180355477548, 3.9255064038805623, 3.847971129005987, 3.4964983399099197, 2.881704143808453, 2.344096064239498, 2.123922801682033, 1.903749539124568, 1.940116223932893, 2.01103186954231, 2.0080770515046416, 1.9487159023131115, 1.923976017568114, 2.001597098094924, 2.079218178621734, 2.2230819428808384, 
2.3712897213429374, 2.6572487194002656, 3.0261500563176873, 3.370164577008285, 3.6587573364649963, 3.947350095921708, 4.434912283451638, 4.92335268031833, 5.233136782783866, 5.459356860383794, 5.7134777821540945, 6.015714286857123, 6.307371238767009, 6.409057966628328, 6.510744694489648, 6.363217908994866, 6.1274311711690626, 5.986792091070943, 5.975312767982404, 5.931861136917042, 5.629459729233902, 5.327058321550764, 5.316749118440619, 5.381475000416705, 5.430745782426495, 5.463362597404149, 5.489196455428398, 5.481208200917345, 5.473219946406291, 5.423751696269924, 5.367121548096999, 5.385753870315983, 5.468892484311524, 5.523991814131106, 5.482063477681274, 5.440135141231442, 5.457743264673445, 5.481246087114321, 5.336762492662845, 5.078231752973664, 4.966459898422453, 5.229569115505881, 5.492678332589309, 5.299697887735358, 5.091180680496625, 4.920701367056197, 4.770455422628366, 4.620972642397081, 4.472981321752599, 4.324858733899139, 4.1714440906019, 4.018029447304661, 3.959275287646584, 3.93920030841908, 3.8041758269217816, 3.4935569939188564, 3.235202970754407, 3.5670419330707737, 3.8988808953871406, 3.96949997548168, 3.960795205137162, 3.870617515299846, 3.681974543703041, 3.4766811762594134, 3.167817436735338, 2.8589536972112626, 2.622098012456379, 2.4005999765648496, 2.2904202891520966, 2.2872235176971323, 2.269919304208704, 2.1946154230526314, 2.119311541896559, 1.6565492856702253, 1.1416318792857374, 0.8460232528627991, 0.717841245994119, 0.5937348198846153, 0.48167503198455797, 0.3696152440845006, 0.3681485020296499, 0.3739278651250186, 0.2944478217811411, 0.16364260015535376, 0.07050835452416734, 0.0612459582641511, 0.05198356200413485, 0.03173647877556328, 0.01144146606215582, 0.004382702935964152, 0.003513586500643884, 0.0025246615444467955, 0.0013291688210750048, 0.00019231269006067224, 0.00010730254032851011, 2.2292390596347973e-05, 4.3447869961189325e-08, 1.5875326563393436e-08, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]}, {"jid": "JVASP-71467_10", "atoms": {"lattice_mat": [[2.762607658, 0.0, -0.0], [0.0, 2.762607658, 0.0], [0.0, 0.0, 7.695565317]], "coords": [[0.0, -0.0, 3.9024], [1.3813, 1.3813, 4.95183], [0.0, -0.0, 7.523], [1.3813, 1.3813, 2.86168]], "elements": ["Be", "P", "Se", "Se"], "abc": [2.76261, 2.76261, 7.69557], "angles": [90.0, 90.0, 90.0], "cartesian": true, "props": ["", "", "", ""]}, "efermi": 7.281, "forces": [[0.0, -0.0, 0.86799415], [0.0, 0.0, 4.19313435], [0.0, -0.0, 0.31765056], [-0.0, -0.0, -5.37877905]], "energy": -7.12602872, "stresses": [[350.16854223, 0.0, -0.0], [0.0, 350.16854223, 0.0], [-0.0, 0.0, 317.7697081]], "dos": [3.3456294370180233, 3.6684114208054317, 3.834306861651832, 4.000202302498232, 4.225637144151265, 4.49322529279601, 4.741113054005529, 4.752622312184956, 4.764131570364381, 5.0888833141425565, 5.651508690588771, 6.190931695325575, 6.365725037878824, 6.540518380432074, 6.4444038874481056, 6.127775191837605, 5.827535875597282, 5.894547228507382, 5.961558581417482, 6.0314128879068045, 6.10374644784577, 6.177397886490238, 6.301179700700089, 6.42496151490994, 6.5577122945561115, 6.698840088544873, 6.841879942049144, 7.207678102529954, 7.573476263010765, 7.811416186192378, 7.921476368581314, 8.03153655097025, 8.219906444505776, 8.408955212915727, 8.479342774234649, 8.42264129722332, 8.365939820211992, 8.660388531137357, 8.964099739853781, 9.335916549431857, 9.78585627173941, 10.235795994046963, 10.410750479485229, 10.573407902164181, 
10.662978792520713, 10.662730317775763, 10.662481843030813, 10.434740120460019, 10.192500744817572, 9.896844883538002, 9.530823725337624, 9.164802567137247, 9.326082470044524, 9.531361029266886, 9.922411742908457, 10.575947098367775, 11.229482453827094, 11.521685863624217, 11.776347203127031, 12.067551457410893, 12.41418461024653, 12.760817763082164, 13.616045568653188, 14.534921169373145, 14.665212432731098, 13.510078963829741, 12.354945494928387, 12.39600872544443, 12.613180317782282, 12.639099562001855, 12.329575113028971, 12.02005066405609, 12.220415495409721, 12.507556648392537, 12.774678323829857, 13.003961723756566, 13.233245123683274, 13.116111789749878, 12.931743245159415, 13.000861561788332, 13.587187873084224, 14.173514184380117, 14.621933783480998, 15.040153822819153, 15.298798772116005, 15.205235078587474, 15.111671385058944, 15.145608589403254, 15.210776320761603, 15.18240553501404, 14.930167207349587, 14.677928879685133, 14.605398560379504, 14.581754453885772, 14.449460753536718, 14.034415764790328, 13.619370776043937, 13.392953714273379, 13.223186165812963, 13.015110258041393, 12.69827803905753, 12.381445820073667, 12.203493524911734, 12.071358282024006, 11.920575946779074, 11.711823946063529, 11.503071945347981, 11.604466189046336, 11.817780161038517, 11.938707943674313, 11.743732585135772, 11.548757226597232, 11.836270590458438, 12.313543767742814, 12.728985681181774, 12.910657588469261, 13.092329495756749, 13.090687565066347, 13.010713787945816, 12.960930987121806, 13.038147789860817, 13.11536459259983, 13.17396046599096, 13.223934311597905, 13.241561460074722, 13.106655793405489, 12.971750126736254, 12.700302673944861, 12.360504955496323, 12.094681032203418, 12.223499411356746, 12.352317790510075, 12.133184206091935, 11.726117488403311, 11.31273839772456, 10.860823798540633, 10.408909199356703, 10.15000470965375, 10.003390109790471, 9.854295233993918, 9.687621084760437, 9.520946935526956, 9.407133613218877, 9.326399019345068, 9.283855209386898, 9.561583072350109, 9.839310935313318, 9.961219669809099, 9.97837519420214, 9.99754055037562, 10.037168507206736, 10.076796464037853, 9.905825015413722, 9.582902427788559, 9.282846802761677, 9.276070463787912, 9.269294124814147, 9.305602026430803, 9.375246295520153, 9.462270029283845, 9.846615860978968, 10.230961692674091, 10.236826583457598, 9.928836836972675, 9.631447182609973, 9.601528231500343, 9.571609280390712, 9.55959720641815, 9.563492405255891, 9.569271325978072, 9.662803476438222, 9.756335626898373, 9.796784784255834, 9.786732775893187, 9.776023866012327, 9.597962483282162, 9.419901100551995, 9.435023629961318, 9.64695123047577, 9.858878830990221, 10.313019064366157, 10.770398077035821, 11.060331833368162, 11.167591728881723, 11.274851624395286, 11.663730364058237, 12.061408874980161, 12.291017278845082, 12.32423488375083, 12.357452488656575, 12.307496740028016, 12.253401910025952, 12.116922339143736, 11.877293858992706, 11.637665378841676, 11.576235417868713, 11.527093865876338, 11.228246241648272, 10.594226571529894, 9.960206901411516, 9.647344942385583, 9.363023847887925, 8.980737045716337, 8.45737715514625, 7.934017264576164, 7.007261908383784, 6.036321034987168, 5.341430104573305, 5.073395155816819, 4.805360207060333, 4.934667575035175, 5.11602634115525, 5.181455286906175, 5.054179475847799, 4.926903664789423, 4.9683909606482395, 5.035751439750439, 5.177445032523899, 5.452140767454386, 5.726836502384874, 6.120393878149216, 6.534932766287154, 6.957333604202218, 7.3949002808506625, 7.832466957499106, 
8.476146212090299, 9.161188887522496, 9.731302235678989, 10.061959797285878, 10.39261735889277, 10.757125288219456, 11.129278665890807, 11.276827283829569, 10.91784823476581, 10.558869185702052, 10.5618223627306, 10.656024104372769, 10.70777285348877, 10.655633458187044, 10.603494062885316, 10.8381666737803, 11.153008758352488, 11.237502507438823, 10.708549950118563, 10.179597392798303, 10.084442827581858, 10.122963032342948, 10.105478159969882, 9.925130970230118, 9.744783780490351, 9.524063350676196, 9.289693092148468, 9.086384563458068, 8.982099020020497, 8.877813476582928, 8.838020669091279, 8.822053727225844, 8.792053103256949, 8.712778890260093, 8.63350467726324, 8.487618927132102, 8.314936528403997, 8.163903532613395, 8.097049664727848, 8.030195796842303, 7.972797940945226, 7.919529945470492, 7.886314814867261, 7.940015133910175, 7.993715452953089, 7.901312089307929, 7.739810016544189, 7.58932811728723, 7.492513472298546, 7.395698827309861, 7.47934513657175, 7.655209284611003, 7.717269247463097, 7.150554062776073, 6.5838388780890496, 6.275738293150588, 6.110160871129729, 5.900725505621755, 5.413018850269237, 4.92531219491672, 4.748137225621966, 4.755224825767771, 4.757993453403623, 4.728800900031252, 4.699608346658882, 4.3280016750822705, 3.737927337883338, 3.178233786468252, 2.886185775364008, 2.5941377642597634, 2.3809678058133827, 2.221848222805, 2.070511683734852, 2.003135592100156, 1.9357595004654593, 1.7966600216349575, 1.6048251545000927, 1.425388759786292, 1.416611945146119, 1.407835130505946, 1.3376091949828135, 1.2189403400546475, 1.103919941901418, 1.057319903318587, 1.0107198647357558, 1.0279704395219964, 1.099160297630379, 1.1685754272200988, 1.186819164862753, 1.2050629025054076, 1.1550224209368434, 1.043192794906398, 0.9320138861830928, 0.8598273013566339, 0.7876407165301751, 0.753765659426545, 0.7570159809707109, 0.7602663025148767, 0.6969921048610709, 0.6336719275837575, 0.5774395912093102, 0.5285620900749431, 0.4796845889405761, 0.4399879220850563, 0.40045759564156064, 0.3594972305202453, 0.3169477260102963, 0.2743982215003473, 0.2654720762383966, 0.25776185328951023, 0.2365748050338224, 0.19934552119486182, 0.16211623735590122, 0.132145940397251, 0.10257386422654902, 0.08075563712529046, 0.06882834439587188, 0.05690105166645329, 0.05034016303178655, 0.04417768730553351, 0.043891756275372805, 0.051643652082964564, 0.05939554789055633, 0.05458454439256703, 0.04858822149847411, 0.04227226808683177, 0.03548718437444174, 0.028702100662051703, 0.029745462441232644, 0.03169087012745976, 0.033187366859597195, 0.03397616352722446, 0.03476496019485173, 0.03795870569316163, 0.04148172369590493, 0.0453172283287219, 0.04968248063185491, 0.05404773293498792, 0.05084226582713324, 0.04642955152890642, 0.04187412678963548, 0.03705817987421286, 0.03224223295879024, 0.027280789296518945, 0.022292730892530144, 0.020109389625555616, 0.023448608871056674, 0.02678782811655773, 0.021441795172410575, 0.014294850105033459, 0.008283744646135007, 0.004689565726466336, 0.0010953868067976637, 0.0007293476228744794, 0.0011148607295209968, 0.0012239810367153752, 0.0006961026996613781, 0.000168224362607381, 4.06722257119056e-05, 1.6952549005867626e-05, 4.194479283346946e-09, 2.3076724332745554e-09, 4.208655832021648e-10, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]}, {"jid": "JVASP-71466_12", "atoms": {"lattice_mat": [[3.259761886, -0.0, -0.0], [0.0, 3.259761886, 0.0], [-0.0, 0.0, 8.889409806]], "coords": [[-0.0, 0.0, 8.64945], [1.62988, 1.62988, 2.76194], [1.62988, 1.62988, 
6.03647], [-0.0, 0.0, 4.77567]], "elements": ["Na", "Na", "Mg", "Be"], "abc": [3.25976, 3.25976, 8.88941], "angles": [90.0, 90.0, 90.0], "cartesian": true, "props": ["", "", "", ""]}, "efermi": 1.319, "forces": [[0.0, -0.0, -0.00306631], [-0.0, 0.0, 0.0049626], [-0.0, 0.0, 0.00144501], [0.0, -0.0, -0.0033413]], "energy": 1.84953024, "stresses": [[0.22519245, 0.0, 0.0], [0.0, 0.22519245, 0.0], [0.0, 0.0, 0.08865638]], "dos": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.5498698051787273e-06, 6.1615834102419475e-06, 0.0031613775820181268, 0.00995723990983579, 0.03398507782683199, 0.11840245181727009, 0.2028198258077082, 0.2185170064516637, 0.23281173417025652, 0.24685392297370937, 0.26080645704035155, 0.2684488440278385, 0.2696793073704082, 0.25305113743244934, 0.18397202093076984, 0.11674548821259517, 0.20204835365241763, 0.2873512190922401, 0.26613044551428067, 0.2130909226195748, 0.20666742070096691, 0.24191581632135079, 0.2528626787908212, 0.20314067877901654, 0.15774809164605047, 0.20642183887759047, 0.25509558610913047, 0.22276415705506716, 0.17043394383091975, 0.1548417808460699, 0.16811522876335241, 0.1963574242826492, 0.2567024055393705, 0.3093796504683037, 0.26883336425627447, 0.22828707804424525, 0.17548313972006804, 0.12023942891316691, 0.15986256600319337, 0.26486690907583904, 0.3107522288838526, 0.2468450769154364, 0.18395668523829872, 0.12948249892972824, 0.07500831262115773, 0.17604504801817206, 0.30114467430230685, 0.3023431101283016, 0.22888313631982088, 0.18087200472054485, 0.17403203669500475, 0.1686762138546472, 0.1724372828659823, 0.1761983518773174, 0.21999495891258503, 0.26833903240481666, 0.2999695347653179, 0.3228356034836382, 0.3157379522319755, 0.26622788575741785, 0.22559478919894657, 0.22769501293362687, 0.22979523666830715, 0.35319359934753947, 0.48572192947393655, 0.4708900397617385, 0.38923200384429724, 0.35393446042929405, 0.37622371711772695, 0.3946033473881835, 0.397727768694583, 0.4008521900009825, 0.44724235647401017, 0.49534156485581543, 0.4481333031049034, 0.36386102519164026, 0.35495781124749404, 0.42836139937771767, 0.4808439884892046, 0.46559927569339055, 0.45035456289757647, 0.4127738239947132, 0.3750582599846144, 0.4687297268292229, 0.6057309591679336, 0.6654982842848962, 0.6510604828561788, 0.6017455818595951, 0.4571693492086837, 0.31902888102715776, 0.421982051023686, 0.5249352210202143, 0.5402994124880999, 0.5315319466065853, 0.5296408320594117, 0.5335603612596516, 0.5610244413033221, 0.6434258066476922, 0.7159256667431003, 0.6276114243541338, 0.5392971819651672, 0.5224335308976953, 0.5216812876769501, 0.5303217766178421, 0.5459331782779222, 0.5869572767262985, 0.6791129458204435, 0.7565838073038358, 0.6855749473941634, 0.614566087484491, 0.6884521318040993, 0.7883099042397574, 0.7670689588365222, 0.6670938635521069, 0.6155849784566244, 0.6487588122616901, 0.6763900976250306, 0.6641332064989593, 0.6518763153728881, 0.6858500344026703, 0.7261279275444438, 0.726716934342333, 0.7047865416789307, 0.6874832187823292, 0.6772373457319799, 0.6607556079168314, 0.6100171065243618, 0.5592786051318923, 0.5912220012693706, 0.6311434752384025, 0.690988243585111, 0.7606459778192023, 0.7925385987453865, 0.7739600869335831, 0.7563302883593124, 0.7428523853513639, 
0.7293744823434155, 0.7187890592639482, 0.7083752239768383, 0.6857238423724079, 0.6578776049239454, 0.6143740850243891, 0.5524902770302913, 0.5176289292061731, 0.5797013605221679, 0.6417737918381629, 0.7022578500617445, 0.7627028499305999, 0.7477021856846475, 0.7053601772231504, 0.6492513758517944, 0.5789280809883312, 0.5390813721815108, 0.5906644215429907, 0.6419316643310282, 0.6540388920290306, 0.6661461197270332, 0.746800442926381, 0.8483947404691086, 0.9075928794284471, 0.9282784773850588, 0.9180422400755713, 0.829075939926746, 0.7440068657196047, 0.7524712141045862, 0.7609355624895676, 0.7927895230217532, 0.8305641623178902, 0.8497387317259455, 0.8540600079142352, 0.8647393021769303, 0.8893090435700995, 0.9113972928319124, 0.9015018657356502, 0.8916064386393878, 0.7558139667555175, 0.5942353531687955, 0.5270897189657607, 0.5261114104523692, 0.5305284752303217, 0.5451437311426166, 0.5659950256260747, 0.6405721908743544, 0.7151493561226339, 0.6649983446835531, 0.5948792492742718, 0.6659564340793608, 0.823573274498305, 0.9107788386678426, 0.882122461497682, 0.8586880785700152, 0.8684287178877408, 0.8781693572054663, 0.8964773049644206, 0.9158010633438831, 0.9101299581975791, 0.8911180363230231, 0.8759750771257595, 0.8663996495859736, 0.8570606503415993, 0.8488925340839926, 0.8407244178263859, 0.7960257627923174, 0.748407820642375, 0.8395169148428062, 0.9947162772974583, 1.0292843209929194, 0.9115621477956924, 0.8406624917040137, 0.9570529040298997, 1.0734433163557857, 1.092286137567146, 1.106852342049487, 1.1424787191108867, 1.1864585724464938, 1.251170311963344, 1.3388883217990801, 1.3758558143498199, 1.2448215945128407, 1.1137873746758615, 1.030412064976822, 0.9475181584102799, 0.9928894536353512, 1.0814730895699634, 1.1140847945784225, 1.0920519662330093, 1.0467659069521786, 0.9366526583982625, 0.8282695560472696, 0.796799316701757, 0.7653290773562443, 0.8338229036790274, 0.9305117228818938, 0.9271104028530148, 0.8377580710690425, 0.7985132689117964, 0.8784431318101301, 0.9549103630406663, 0.9707174172815238, 0.9865244715223812, 0.9921314308521761, 0.9953767918528764, 0.9843590521229235, 0.9625814293858305, 0.9240390334146336, 0.8511492971596337, 0.785994064594903, 0.8030404177568494, 0.820086770918796, 0.8075224712685776, 0.7894851125679626, 0.772908779109649, 0.7572983726387512, 0.7554079575314773, 0.7779086381812499, 0.7958421993334034, 0.7796384774625738, 0.7634347555917441, 0.7456885899665051, 0.7277240875541214, 0.7309182318654055, 0.7463269892692843, 0.7548347898719099, 0.7526390656401535, 0.7706337259617567, 0.9030405654147101, 1.0354474048676634, 1.0096749387662096, 0.9678754945698795, 0.9773652676626169, 1.012576660236081, 1.075626134051079, 1.1764933786085094, 1.2388133738083662, 1.1278825449781777, 1.0169517161479895, 0.96346424447017, 0.9136433700570755, 0.9187606267649611, 0.9476476765933957, 0.9533427958077441, 0.9313703486760276, 0.9268270602897821, 0.9862992801910249, 1.0457715000922676, 0.9815900137596953, 0.9138464865074879, 0.9454169331984505, 1.0137199455934394, 1.0024421237818975, 0.90766965624716, 0.8436714655554635, 0.8739976843719446, 0.9042983265748482, 0.9282303920052719, 0.9521624574356956, 0.9923684495099204, 1.037657374052797, 1.0195274390534417, 0.9428570183240428, 0.9045745654541909, 0.9659903600442281, 1.0213438071045375, 0.9143610592276653, 0.8073783113507931, 0.8336904266955869, 0.8945853403057904, 0.9447599978810512, 0.9862341578049516, 0.9948702004402383, 0.9304149497575472, 0.8737926931914883, 0.9245285324500325, 0.9752643717085767, 
0.9310504567821771, 0.8668350197885313, 0.8668444102088498, 0.9126029927586101, 0.9388329143798567, 0.9274901491863704, 0.9163560181795676, 0.9070995948527967, 0.8978431715260258, 0.9181752846249714, 0.943404334987625, 0.9961009490081009, 1.0659202001135677, 1.081772330226022, 1.0072944663790937, 0.9307981344849039, 0.8410288460381737, 0.7512595575914436, 0.6743467530288522]}, {"jid": "JVASP-71466_21", "atoms": {"lattice_mat": [[3.261411755, -0.0, -0.0], [-0.0, 3.261411755, 0.0], [-0.0, 0.0, 8.890773868]], "coords": [[-0.0, 0.0, 8.64802], [1.63071, 1.63071, 2.76513], [1.63071, 1.63071, 6.03721], [-0.0, 0.0, 4.77658]], "elements": ["Na", "Na", "Mg", "Be"], "abc": [3.26141, 3.26141, 8.89077], "angles": [90.0, 90.0, 90.0], "cartesian": true, "props": ["", "", "", ""]}, "efermi": 1.3144, "forces": [[0.0, -0.0, -4.183e-05], [-0.0, 0.0, 0.00013541], [0.0, -0.0, -0.00058908], [-0.0, 0.0, 0.00049551]], "energy": 1.8495935, "stresses": [[-0.02011224, 0.0, 0.0], [0.0, -0.02011224, 0.0], [0.0, 0.0, 0.01141287]], "dos": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.911324986646523e-07, 6.259516968770401e-07, 0.0003951825827094167, 0.0010232891595801093, 0.004422179780323423, 0.01222071072569293, 0.019104318459621863, 0.020587235412496817, 0.022070152365371774, 0.02454598740195603, 0.02712809758987698, 0.02740132628999047, 0.026494455040378815, 0.023747329437507642, 0.018461112943343155, 0.01524766222422311, 0.021563136708718407, 0.0278786111932137, 0.02523465298338768, 0.021989412917977916, 0.021195844974615786, 0.021475362990552724, 0.021429825420221593, 0.020993371440266497, 0.02099435319507851, 0.02261686162138399, 0.02423937004768947, 0.021305599553588433, 0.01823528142970296, 0.017518806136128502, 0.017675797903346363, 0.020172889804039276, 0.025127136786260846, 0.028187040613164287, 0.025448468788422645, 0.022692605949686178, 0.016272507489178776, 0.009852409028671372, 0.016026844848857674, 0.026110204951625797, 0.029023386112019593, 0.02536006866655917, 0.020636485499533622, 0.013187073245494093, 0.006564278409460411, 0.020590752989624407, 0.034617227569788404, 0.030749809599053716, 0.022323004917092083, 0.018140710454429932, 0.017356337551937506, 0.016850963727867586, 0.01695558770387969, 0.01719512431195324, 0.019166165495992774, 0.021137206680032307, 0.027920308684987218, 0.03568389978466048, 0.03404252647626188, 0.025843346346227862, 0.021185918041844803, 0.02317502890819042, 0.02597464189244118, 0.035604371184143356, 0.04523410047584553, 0.04475915228938242, 0.04270077581671872, 0.03997956977977285, 0.03685721868388022, 0.03636301847533464, 0.04013104360160672, 0.04362825090160032, 0.04546009098472979, 0.04729193106785927, 0.04321553760912447, 0.038470399070927856, 0.038720762537351544, 0.04158180480761257, 0.04347963904781664, 0.04402091231498279, 0.043910699994097434, 0.040700595085582866, 0.0374904901770683, 0.051029494682049606, 0.06578833709429834, 0.06734352257373942, 0.06298221910989349, 0.054680918016527275, 0.041544779341781315, 0.03295792978849485, 0.04175036035450183, 0.050542790920508805, 0.0537409063265661, 0.05674156822173518, 0.05512968350961247, 0.0517625589322802, 0.055024534220981317, 0.0653878029620843, 0.07091907811643197, 0.0612536279841154, 0.051588177851798825, 
0.054952084272948704, 0.058319878488677826, 0.05728224481573713, 0.05483927992552676, 0.056137121496620285, 0.06093538116921863, 0.06437539202579379, 0.06423540992996536, 0.06438003502098916, 0.07273693235983611, 0.08109382969868306, 0.0777894923840897, 0.0714215664035891, 0.06818429301607555, 0.06750440458159597, 0.06671651315387925, 0.06568692034434943, 0.06494025333311547, 0.0681152791378122, 0.07129030494250894, 0.0723332310008428, 0.07292620242081828, 0.07078808245749134, 0.06670601799358383, 0.06293466946889577, 0.05975944835300007, 0.05732492479381413, 0.061477315124834064, 0.06562970545585399, 0.06604770253774071, 0.0658554084539062, 0.07097874024615303, 0.07938806117940546, 0.08255653919499438, 0.07704331198409863, 0.07249489995345339, 0.07413497583326764, 0.0757750517130819, 0.07114965123104212, 0.0657760712967766, 0.06040296418631784, 0.05503010969364845, 0.053092074567736916, 0.056092290037308375, 0.060266632976998365, 0.07022791325664629, 0.08018919353629422, 0.07472061650769375, 0.06803909761515534, 0.06229054348064602, 0.056969889369583276, 0.056095492547031126, 0.060788570211620346, 0.06463389061043495, 0.06513977785618492, 0.06564566510193488, 0.07455044731481808, 0.08379687166256676, 0.08890165418310923, 0.09239069142384154, 0.09141205266655607, 0.085550767025123, 0.08014687494577179, 0.07622157903628332, 0.07229628312679487, 0.07699577629768622, 0.08174116595028151, 0.08268897859805643, 0.08239214424300938, 0.08571895801070346, 0.09250504507605428, 0.09657746774962823, 0.09331488319363083, 0.08969332037991791, 0.07383522867954379, 0.05797713697916967, 0.0522921777032876, 0.049361395249633344, 0.050017192996377274, 0.05366249656273132, 0.05827830565069428, 0.06511768126004593, 0.07123322076966393, 0.06645297522604963, 0.06167272968243535, 0.07121163870294468, 0.08387797394322914, 0.08740933138528695, 0.08430371094478523, 0.0834484625684618, 0.08700785357790684, 0.09028222576760082, 0.09087539249877527, 0.09146855922994972, 0.08805143884259335, 0.08395166741717783, 0.08478177353071428, 0.08872441368091269, 0.09024368850092067, 0.08766196893539763, 0.08492384459028378, 0.08113791719130471, 0.07735198979232565, 0.08636764845987141, 0.09699256084580556, 0.09678472928599449, 0.09066087949681494, 0.0900800546476964, 0.09763501163054218, 0.10373646646736247, 0.10241104376038036, 0.10108562105339826, 0.10868770187773545, 0.1170437345168647, 0.13105175393671972, 0.14771253908641133, 0.1468417812198929, 0.12356930081597799, 0.10410363730725042, 0.10010915924636525, 0.09611468118548008, 0.10242439161360435, 0.10920927891016156, 0.10985117571248589, 0.1080368472796657, 0.10079502608536035, 0.08750158872579984, 0.0781578196794602, 0.08194304953844297, 0.08572827939742575, 0.0896694933669503, 0.09361232882834768, 0.08704795067761013, 0.0769467162037336, 0.07852108034395486, 0.09146643137137894, 0.10034223953466755, 0.09793413078535387, 0.0956625963727629, 0.09906480030379068, 0.10246700423481846, 0.1013670656464474, 0.0990117150608597, 0.09301738609002407, 0.08392820462389192, 0.07776900142964599, 0.0784839689679434, 0.07925091877830774, 0.08087309238541521, 0.08249526599252269, 0.07883794952440126, 0.0739881591085927, 0.0729748073417543, 0.07480642421208472, 0.07717780622791298, 0.08063207103254814, 0.08318207411762449, 0.07670955206271163, 0.07023703000779877, 0.07188718249769062, 0.07497590512308001, 0.07550763420714116, 0.07439070610898714, 0.07873346403415138, 0.092515862672872, 0.10435268435041523, 0.10255637004657911, 0.10076005574274298, 0.09757641103743375, 
0.09420953757450014, 0.10303529676054965, 0.1186664611991117, 0.12439669301919198, 0.11528988829635137, 0.10604197266925372, 0.09604589205191008, 0.08604981143456643, 0.08954159502590185, 0.09425212853760104, 0.09415310881071753, 0.09174446993100668, 0.09500053789453816, 0.10564343625797748, 0.1115245311312743, 0.09742750747119906, 0.08333048381112383, 0.09287506882645628, 0.10363979589861841, 0.10018412279986867, 0.09090236823390352, 0.08579175538144991, 0.08542612949740455, 0.08654045010555031, 0.09271549606862028, 0.09889054203169026, 0.10307384518170853, 0.10722623551272847, 0.10120062098039798, 0.09165752968720138, 0.08875688486009564, 0.09245595104777031, 0.09345070712288236, 0.08675106274483817, 0.08031009551397147, 0.08758506015368284, 0.0948600247933942, 0.09806854715027505, 0.10010971889022831, 0.09799426620218395, 0.0922719584488631, 0.08997939361251212, 0.09592860047616548, 0.10119433029755918, 0.09407621791706643, 0.08695810553657365, 0.08587989610306056, 0.08621153079580458, 0.08843965860317939, 0.09210314345823892, 0.09369751158223817, 0.09104596837840002, 0.08891411915583958, 0.09229864402695961, 0.09568316889807965, 0.10048173341060507, 0.10554057646723888, 0.10463483655325757, 0.0998020834605606, 0.09310941904373789, 0.08313087797830315, 0.07350985528734022, 0.0665133205011716]}, {"jid": "JVASP-71466_0", "atoms": {"lattice_mat": [[3.222033, 0.0, 0.0], [0.0, 3.222033, 0.0], [0.0, 0.0, 9.046491]], "coords": [[0.0, 0.0, 8.84526], [1.61102, 1.61102, 2.80752], [1.61102, 1.61102, 6.14788], [0.0, 0.0, 4.81558]], "elements": ["Na", "Na", "Mg", "Be"], "abc": [3.22203, 3.22203, 9.04649], "angles": [90.0, 90.0, 90.0], "cartesian": true, "props": ["", "", "", ""]}, "efermi": 1.3361, "forces": [[0.0, 0.0, -0.04019519], [0.0, -0.0, 0.01790107], [0.0, -0.0, -0.02440596], [0.0, 0.0, 0.04670007]], "energy": 1.85322105, "stresses": [[1.83001484, 0.0, -0.0], [0.0, 1.83001484, 0.0], [-0.0, 0.0, -2.22477886]], "dos": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.2382713725934323e-06, 7.648962656127816e-06, 0.0017688097935619372, 0.012363671989784256, 0.022958534186006574, 0.15195545315296752, 0.2915387175627153, 0.4037820517903403, 0.502562742114208, 0.5601485166496385, 0.5613525851369369, 0.563349360506224, 0.5691012098442656, 0.5748530591823072, 0.48648704739439974, 0.390647180131256, 0.42742300670471284, 0.5270205002466606, 0.5577456357513536, 0.4974294836815698, 0.45251137356022564, 0.476424605300063, 0.5003378370399003, 0.4724671800687186, 0.4409928378823422, 0.4693269505033771, 0.524902048204903, 0.5269527471287893, 0.460656906075438, 0.40436662572026394, 0.39036958093219043, 0.3763725361441169, 0.4675178048069665, 0.5649662709242769, 0.5488092452235395, 0.48292995266585614, 0.4125438743984361, 0.33659790953063173, 0.30648512474127065, 0.45991170219391686, 0.6133382796465631, 0.5808754593740193, 0.5390289769260345, 0.4141175441146975, 0.2542931533558228, 0.28891586357136273, 0.5553252934231929, 0.7365129368900911, 0.5938401020536969, 0.4511672672173026, 0.4079392478377619, 0.3688061795869606, 0.36582676406016185, 0.3774302115921454, 0.399699845386124, 0.43425600268269693, 0.4974234064853784, 0.6639327116193792, 0.83044201675338, 0.6848717113340805, 0.5293024627058127, 0.5147611423516752, 
0.5547703549726613, 0.6625429637378359, 0.845752438552313, 0.9824031055794858, 0.9589906523456029, 0.9355781991117199, 0.8666096390687673, 0.796590540036842, 0.8294567745954571, 0.9004565438402319, 0.9575420646922663, 0.9996568223120597, 1.0102544806097191, 0.9175873756649682, 0.8249202707202172, 0.8846232183024532, 0.9464952832418642, 0.9826072397536946, 1.009578269205195, 0.9785372181150106, 0.8871684647882031, 0.883982945608798, 1.1564929651545155, 1.4290029847002332, 1.4210506093683115, 1.4115382181793958, 1.1818153608749231, 0.8773497678958293, 0.8189838145940395, 1.0079800519566469, 1.159805508657762, 1.2006205704210884, 1.2411366578491416, 1.1813876250953035, 1.1216385923414656, 1.257450901455393, 1.456689519594648, 1.4733447665652457, 1.3126182999894926, 1.193742828820822, 1.1943843103533414, 1.195263291660179, 1.2165974536824518, 1.2379316157047244, 1.3341533251008129, 1.4535587207057528, 1.4880991327743331, 1.4429513164083194, 1.4529837296389405, 1.6138430935892742, 1.7684219940848094, 1.6149750264769676, 1.4615280588691255, 1.4446726977884525, 1.4681330895158218, 1.4629403750587378, 1.4317436987940468, 1.4351203181616503, 1.5290258656271003, 1.6193788182499864, 1.5885727820695412, 1.557766745889096, 1.5323250863205093, 1.5083910587717024, 1.4334781869983657, 1.3138521428619316, 1.2495778230351833, 1.324260488522776, 1.400783166639336, 1.525208014737774, 1.6496328628362118, 1.6883742019225951, 1.704217564265444, 1.6925522574587215, 1.6575706543302553, 1.6183532817593025, 1.5689333292845784, 1.5197889449853843, 1.4764260355763823, 1.4330631261673803, 1.3378183537547188, 1.2294097364598338, 1.2218375872087492, 1.296852128001088, 1.385965644709817, 1.5076850418559855, 1.6216667029554492, 1.6000946918373629, 1.5785226807192765, 1.4312246885761803, 1.2536895136927642, 1.2349160181518029, 1.3417714188567738, 1.4279270497324335, 1.4680904020503482, 1.516066154764881, 1.681218905177428, 1.8463716555899747, 1.9461027757758296, 2.030947190402152, 1.98178628345869, 1.8301883090505626, 1.7214062334054947, 1.7040766724523277, 1.6929934117987646, 1.7636287497575094, 1.8342640877162542, 1.8708324347251362, 1.9000808976792944, 1.905466853510445, 1.893234151766423, 1.889754094525274, 1.9042562177622366, 1.894498518054265, 1.6038974975374494, 1.3132964770206335, 1.227806157676529, 1.1838392457271107, 1.1808711944447432, 1.207135979986206, 1.2732271052459423, 1.4180608133462107, 1.538269540173324, 1.403331199019879, 1.268392857866434, 1.5186722367333028, 1.8422477383574465, 1.951584673435441, 1.9134358453088918, 1.8898364495742028, 1.8939338021392795, 1.907278135198792, 2.007176760383127, 2.1070753855674615, 2.037478486323439, 1.9376530859083818, 1.90727838146127, 1.9230548652856545, 1.9223026541019086, 1.8912424906812468, 1.851369358209804, 1.7363977550395522, 1.6214261518693003, 1.8822978368373215, 2.205804802816518, 2.2145935877214193, 2.0215588104057858, 1.9603249770863684, 2.1319802183025485, 2.2851720417335777, 2.2941991783568745, 2.3032263149801717, 2.478096097794622, 2.6787031791262135, 2.8196102294949195, 2.9235815095485016, 2.9238013634204765, 2.7472950909647342, 2.5621552964799337, 2.3149033466882267, 2.0676513968965198, 2.191284689752145, 2.368307844396644, 2.4011848553002744, 2.348048598639321, 2.2370888549535697, 2.0311479354109436, 1.8458571546840459, 1.7981057043250448, 1.7503542539660435, 1.8794300067325107, 2.0320114904270477, 2.0349173241858973, 1.9517121832777944, 1.9084871981083098, 1.9286116019272532, 1.9638318262973973, 2.09252407261537, 2.2212163189333434, 
2.2666094579204303, 2.301830202886994, 2.2021087663860643, 2.0275630751597222, 1.8992544023502267, 1.8416391094701026, 1.7891463407177213, 1.766248362415806, 1.7433503841138909, 1.6936795047820392, 1.6410231349245328, 1.62130759028983, 1.619189321326671, 1.6633807776907252, 1.775910994570689, 1.8575219292986573, 1.7719202723321665, 1.6863186153656757, 1.6876327628720107, 1.6977343665240487, 1.6618083335215554, 1.602203444573691, 1.598484595038343, 1.6743832092516462, 1.7801621602291353, 2.037639520786864, 2.295116881344592, 2.1795615764026772, 2.03010220971917, 2.105379255047157, 2.2919437808863248, 2.468446121187487, 2.6311063880867995, 2.717707096576933, 2.4408769426966272, 2.1640467888163215, 2.0665212876933863, 1.9834943633056121, 2.0300019916783896, 2.1382221125450283, 2.1948390789361216, 2.182895774722648, 2.162799734024553, 2.105954487280192, 2.0491092405358304, 2.044613584511037, 2.043835537398491, 2.091107787128243, 2.1603923066743413, 2.1569248490094135, 2.0600884628053033, 1.9833225060607858, 1.9920799045733928, 2.0008373030859996, 2.04323818679405, 2.0877030248449624, 2.1542785224104835, 2.2305882184619574, 2.219247329852143, 2.0992293380900726, 2.003349091452356, 2.004884938213588, 2.0064207849748197, 1.9634878658391461, 1.9182490839280875, 1.9665373699471302, 2.05437091531275, 2.095826110388851, 2.081719123436148, 2.070337219423259, 2.06938972983031, 2.0684422402373617, 2.0348747586687326, 1.9999200267089612, 1.9495476706208235, 1.892918677937179, 1.9056861688086095, 1.9987929990493292, 2.0620269263220603, 2.016567428549423, 1.9711079307767851, 1.9576385404072072, 1.9452365193765775, 2.0384824445401515, 2.172846737587587, 2.232352904118251, 2.2081073620266967, 2.1697032547223793, 2.0822752434429965, 1.994847232163614, 1.8141689106871275, 1.6312187101144733, 1.52234253579299]}, {"jid": "JVASP-71465_8", "atoms": {"lattice_mat": [[1.749093212, -3.029519702, -0.0], [1.749093212, 3.029519702, 0.0], [0.0, -0.0, 4.308384392]], "coords": [[0.0, -0.0, 4.12662], [1.74909, 1.00984, 1.29954], [1.74909, -1.00984, 3.19061]], "elements": ["Be", "Cu", "Se"], "abc": [3.498186, 3.498186, 4.30838], "angles": [90.0, 90.0, 120.0001], "cartesian": true, "props": ["", "", ""]}, "efermi": 4.8577, "forces": [[-0.0, 0.0, 0.00048712], [0.0, -0.0, 0.00260708], [-0.0, -0.0, -0.0030942]], "energy": -3.50399833, "stresses": [[1.17475689, 0.0, -0.0], [0.0, 1.17475797, 0.0], [0.0, 0.0, 1.97716058]], "dos": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.2020413482097373e-10, 1.807024719534724e-08, 3.5920290255873506e-08, 0.00010366207646017127, 0.00030307943914403417, 0.0006502919166716855, 0.006350727152870177, 0.012051162389068669, 0.030429736730826638, 0.059866132308783265, 0.08960790746117016, 0.12642108500186788, 0.1632342625425656, 0.23577940834809716, 0.3377181551983175, 0.43741926913403334, 0.4993487934478283, 0.5612783177616232, 0.6023121577634629, 0.6271395214608364, 0.6593148409105117, 0.788429978041764, 0.9175451151730162, 1.022316400238969, 1.1092940050764957, 1.1893567758752432, 1.1949728523433438, 1.2005889288114446, 1.199994263276746, 1.195123725325149, 1.1919066911376255, 1.2036510220282413, 1.215395352918857, 1.1999941558530742, 1.1670018212289528, 1.1302575046532204, 1.0643674908967065, 0.9984774771401924, 0.9899268500080645, 1.0163239330136637, 1.0122091032636098, 0.8013082006978096, 0.5904072981320094, 0.5417449826579128, 
0.5859980425680706, 0.6329210854134435, 0.6958308583800812, 0.7587406313467189, 0.8061493598509123, 0.8452251729360736, 0.8818886310281808, 0.9056616664467838, 0.9294347018653869, 1.0022926291092331, 1.0998878662755873, 1.1792596747311654, 1.1710106951421317, 1.1627617155530983, 1.189964931650622, 1.2338947013224393, 1.2808727759506955, 1.3411295988054148, 1.4013864216601342, 1.35670753545892, 1.2657531873384265, 1.226108348282283, 1.390129325258818, 1.5541503022353529, 1.6018352429433167, 1.6016553973142218, 1.6401268317473061, 1.8190840172289584, 1.9980412027106105, 1.917836755040185, 1.7383551742394463, 1.548315975474747, 1.3229903671602719, 1.0976647588457968, 1.023819783361239, 1.0038750569921708, 1.0229565604286814, 1.1624165424393471, 1.3018765244500128, 1.2739905998438175, 1.1909427881453158, 1.1294414348361346, 1.1294705982000706, 1.1294997615640063, 1.1506286669478483, 1.1781808859827603, 1.2797155562405917, 1.5773892201361492, 1.8750628840317067, 2.554010071569353, 3.339775602294123, 4.129960753651737, 4.931050032891681, 5.7321393121316255, 5.3904264654464615, 4.75526032718439, 3.960654269931295, 2.7991916279520157, 1.6377289859727369, 2.2911962160912647, 3.3698008610979535, 5.168023931473689, 8.51330409337819, 11.858584255282691, 11.322274517513488, 9.961129634978597, 8.967750289700097, 8.714306210710664, 8.46086213172123, 7.0457676255284225, 5.40820890331958, 4.205130150489256, 3.821342426473609, 3.437554702457961, 3.0549223874254148, 2.6724879122216727, 2.3666212403593265, 2.19624537083541, 2.025869501311494, 1.923140505788389, 1.8306688159402658, 1.7651651726414224, 1.7444934702742876, 1.7238217679071528, 1.6285387615818656, 1.5233566126818185, 1.440685735802489, 1.3932055439726716, 1.3457253521428545, 1.201188393399781, 1.0455543516493961, 0.9060877031334947, 0.7904066264378038, 0.6747255497421127, 0.6185425102559444, 0.5681058086984452, 0.5482292949095193, 0.5706953365646242, 0.5931613782197291, 0.5859961423092199, 0.5764787298483399, 0.5758871035201223, 0.5869492140242798, 0.5980113245284374, 0.6093768087644357, 0.6207613187661795, 0.6419957731622449, 0.6753543933770958, 0.7087130135919467, 0.735402697375138, 0.7617819202378955, 0.7586026662892578, 0.7211107212717522, 0.6836187762542467, 0.699725260347478, 0.7174866908905352, 0.7334820865964643, 0.7475435683053716, 0.7616050500142789, 0.7604703480479317, 0.7590976287638127, 0.7736467029867995, 0.8046456270734232, 0.8356445511600471, 0.8427006192313381, 0.8497353378103031, 0.7855273322059522, 0.6518702468781337, 0.5200895738263741, 0.5240030054044795, 0.5279164369825848, 0.4984201233259539, 0.4381956577179135, 0.3816044400210308, 0.45035498197666474, 0.5191055239322987, 0.5421887466892634, 0.5256492309159705, 0.5073921607983253, 0.45062954994610155, 0.3938669390938778, 0.37135680260889486, 0.37687527183566677, 0.3852813126036341, 0.4412542536094088, 0.4972271946151835, 0.5294993303453247, 0.5434865275553604, 0.5538741360951039, 0.5177010727183924, 0.4815280093416808, 0.5209934502377985, 0.6154507387131638, 0.6967745829223048, 0.6390346830181667, 0.5812947831140286, 0.5676092118738084, 0.5840901320110126, 0.6046073912027463, 0.6611222931447386, 0.7176371950867309, 0.7112723575360177, 0.6643816169007126, 0.6187229687030567, 0.582513256419279, 0.5463035441355012, 0.5552400817698757, 0.5915406487611867, 0.6309947295873354, 0.6915747817653659, 0.7521548339433966, 0.7210861110933362, 0.6378242148830544, 0.5702705657068017, 0.5957806857617726, 0.6212908058167433, 0.6223981443541743, 0.6104622886048963, 0.6062636950804137, 
0.6430067213635007, 0.6797497476465876, 0.8508047174407538, 1.089153695496828, 1.284238240545699, 1.273189239505809, 1.262140238465919, 1.3228356991109727, 1.4171784535380458, 1.4780999242312833, 1.3946755719411417, 1.3112512196509998, 1.3613148105973885, 1.4698845025767413, 1.5498799129954075, 1.5173674967844042, 1.484855080573401, 1.3782889369763733, 1.2414463722680975, 1.114074466682508, 1.0208620078694917, 0.9276495490564753, 0.9371353660482414, 0.9857062727154675, 1.0275473033956057, 1.047059722164936, 1.0665721409342666, 1.062652286475721, 1.0504507371929122, 1.0461945100300576, 1.0662742045865614, 1.0863538991430652, 1.0940914249592923, 1.09778901575906, 1.1050276905961653, 1.122310420119132, 1.1395931496420988, 1.2277532471718091, 1.3373333590242378, 1.4112104719439076, 1.3910507526735465, 1.3708910334031854, 1.4019229616872442, 1.4471876863097022, 1.4640909906618949, 1.4114631027370408, 1.3588352148121867, 1.3691975401296985, 1.3956052857848156, 1.393924732516524, 1.3280127548520795, 1.262100777187635, 1.1651666983532958, 1.0610273540191968, 0.969691232845549, 0.9057149846959507, 0.8417387365463523, 0.8670469757685292, 0.9111573019675058, 0.9172095775948357, 0.8471382085449886, 0.7770668394951414, 0.7645362389843968, 0.76291879803844, 0.7630970573810768, 0.7666420209321204, 0.7701869844831641, 0.8596331596733124, 0.9636347922482607, 1.0448401152339128, 1.0859332279353406, 1.1270263406367684, 1.1500253602636736, 1.170311998761, 1.1762086000667267, 1.1583155453842142, 1.140422490701702, 1.1203019837480415, 1.0998896659306088, 1.099693704900946, 1.1309283723964259, 1.1621630398919058, 1.1172128469894906, 1.0636750979812097, 1.0074003384311958, 0.9471205857362438, 0.8868408330412916, 0.9458981760248341, 1.0162943717755677, 1.050332456885159, 1.0342635238012912, 1.0181945907174232, 1.0400133808481364, 1.0647823077984027, 1.0644868771543272, 1.031639905817793, 0.9987929344812587, 1.0286855462357876, 1.0624205865544882, 1.0836162879385094, 1.089458582233779, 1.0953008765290488, 1.0689447724043628, 1.0411356267234795, 1.0345411663948458, 1.0524448442354106, 1.0703485220759756, 1.056294912418875, 1.0412987672931857, 1.0652667374834184, 1.1316804598022032, 1.1980941821209883, 1.1866659247039133, 1.1741229472739834, 1.1625846568653548, 1.1520789760660608, 1.1416015472071674, 1.199599963702657, 1.257598380198147, 1.3373616209795118, 1.4382313813268897, 1.5366789089691628, 1.4754498971085044, 1.414220885247846, 1.3691919280862308, 1.338985130055678, 1.3088138135951428, 1.2798108631546152, 1.2508079127140879, 1.2085927004986616, 1.1549739345390035, 1.100831337233705, 1.035308396701067, 0.9697854561684289, 0.9621232855333702, 1.0015595239763277]}, {"jid": "JVASP-71465_7", "atoms": {"lattice_mat": [[1.751834541, -3.034267829, -0.0], [1.751834541, 3.034267829, 0.0], [0.0, -0.0, 4.316398111]], "coords": [[0.0, -0.0, 4.13103], [1.75183, 1.01142, 1.3066], [1.75183, -1.01142, 3.19517]], "elements": ["Be", "Cu", "Se"], "abc": [3.50367, 3.50367, 4.3164], "angles": [90.0, 90.0, 120.0002], "cartesian": true, "props": ["", "", ""]}, "efermi": 4.8136, "forces": [[0.0, -0.0, 0.00556405], [-0.0, 0.0, -0.01842332], [0.0, -0.0, 0.01285927]], "energy": -3.50358024, "stresses": [[-2.90602405, 0.0, 0.0], [0.0, -2.90602673, -0.0], [0.0, 0.0, -1.1536657]], "dos": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.459968750095036e-08, 
1.5634804017187617e-07, 2.78096392842802e-07, 0.0010697162813226404, 0.0024192346303613405, 0.01275548564111244, 0.04399197069685672, 0.07522845575260101, 0.10661908994103572, 0.13804613171716426, 0.19219072338847717, 0.2952387638746016, 0.39828680436072594, 0.4729018261413976, 0.5415040328252202, 0.6143115331985304, 0.6955160724971022, 0.7767206117956742, 0.8372410093058769, 0.8938783655744168, 0.973673478645689, 1.0964395417186508, 1.2192056047916127, 1.2896901211108123, 1.3515533860824158, 1.397962482177121, 1.4176810052849638, 1.4373995283928067, 1.4067538786357618, 1.3689095934815807, 1.3415647798467591, 1.3311204373477212, 1.3206760948486833, 1.2572944497537621, 1.1874663033968558, 1.1463011424471214, 1.1481853744923378, 1.150069606537554, 0.9712880826468986, 0.774188933459518, 0.6624016425281212, 0.6702833879466876, 0.6781651333652541, 0.7390018874697495, 0.8041666951686895, 0.8532469433382144, 0.8812387518611031, 0.9092305603839917, 0.9649700314851936, 1.0224509857092674, 1.1060790295998562, 1.2217687312596126, 1.337458432919369, 1.2989420948358215, 1.2535719256939069, 1.2816401526775285, 1.3939660859563545, 1.5062920192351805, 1.5116672359846501, 1.514181411255353, 1.4788958369303327, 1.4030192038665852, 1.327142570802838, 1.4924584866015138, 1.6601008744070396, 1.7607977760854403, 1.794199278458242, 1.8282503557749805, 1.9551569545256773, 2.0820635532763743, 2.126862654761341, 2.094398244290432, 2.054624167047084, 1.7104129539321116, 1.366201740817139, 1.180485628662489, 1.1343669249680182, 1.095266274811375, 1.2248275221038225, 1.3543887693962697, 1.3928817300318612, 1.3563200550748618, 1.3227899004563601, 1.339932453369711, 1.3570750062830619, 1.3497914916583098, 1.3236789997232945, 1.3129182310107606, 1.4972504876745116, 1.6815827443382627, 2.1946904741972895, 2.944707393963225, 3.733664992629734, 4.919019179863051, 6.104373367096368, 6.289874528556728, 5.802393634440233, 5.2353749755451275, 3.9971833466448636, 2.7589917177445993, 2.4725449658213807, 2.7839211759806832, 3.523310980704514, 7.3300895595975675, 11.136868138490621, 12.317159860828601, 11.959543420613421, 11.512033432371068, 10.507472531462732, 9.502911630554395, 8.418075683048466, 7.2894806387418605, 6.170607946861401, 5.104555595692954, 4.038503244524506, 3.4974379608298474, 3.22238413524299, 2.946709619648261, 2.6680461228665253, 2.3893826260847897, 2.243323405526508, 2.1596082282205136, 2.080939323750932, 2.024001579005154, 1.967063834259376, 1.8573492303053687, 1.724659146716463, 1.62044586608634, 1.62669890358046, 1.6329519410745799, 1.489762188729924, 1.2864777631740962, 1.109807148234463, 1.0267006210173357, 0.9435940938002084, 0.8376662482604385, 0.723284947177249, 0.6357808031088168, 0.634350585913322, 0.6329203687178274, 0.6367277868405118, 0.6423166071801195, 0.6492050178636252, 0.6599011975798676, 0.67059737729611, 0.6904195159616062, 0.7130809541755727, 0.7302413741258337, 0.7326004773462741, 0.7349595805667144, 0.7774674368345381, 0.831351802949843, 0.8625742869408861, 0.8376225005533868, 0.8126707141658875, 0.8046012192825471, 0.8008661860024403, 0.8005166357036603, 0.8079196733017799, 0.8153227108998996, 0.8418517288591655, 0.8728029494370959, 0.8913272770967887, 0.8834997723042107, 0.8756722675116326, 0.9149061016360288, 0.9638674927651741, 0.9539584548092316, 0.8282078304447369, 0.7024572060802421, 0.6435018681653321, 0.5967795410985517, 0.5506879239446315, 0.5057499980824308, 0.4608120722202302, 0.5003039619647446, 0.5533456814639734, 0.5818143821235578, 0.5684349899508749, 0.555055597778192, 
0.5147560973704106, 0.4707232588015395, 0.4437420901774894, 0.4438321568561834, 0.44392222353487737, 0.4659974420628233, 0.4906599216790248, 0.539584709898521, 0.6244575609773295, 0.709330412056138, 0.6206028860631162, 0.5149592505692265, 0.5088688723885615, 0.6405622549720759, 0.7722556375555902, 0.7378810607686885, 0.6905659047633298, 0.6633022718956849, 0.6619810717100176, 0.6606598715243505, 0.716967570094988, 0.7766803410874766, 0.7946004292539803, 0.7619456090913707, 0.7292907889287611, 0.6849059080904984, 0.6400413139086814, 0.6238768624156674, 0.6402116964970803, 0.6565465305784932, 0.7369492316046131, 0.8188460099304683, 0.8205713174316162, 0.7373232239425933, 0.6540751304535704, 0.6937196377245484, 0.734141819337616, 0.7355636555513266, 0.6982900188465266, 0.6617705569712408, 0.6983472098584624, 0.734923862745684, 0.8563411914906451, 1.0565568385187163, 1.2524468530612485, 1.2907308935366733, 1.329014934012098, 1.4052386846212845, 1.5144429808156628, 1.6182457037833784, 1.6024518148273705, 1.5866579258713627, 1.6246817432349512, 1.7064785559025966, 1.780199391609894, 1.726751275473942, 1.6733031593379903, 1.5834338668736323, 1.4658598930442495, 1.3468092483639273, 1.209860006205844, 1.0729107640477606, 1.0421955007402612, 1.0870091976638443, 1.1303782257886787, 1.1596100509119671, 1.1888418760352553, 1.189993372854909, 1.172499798095556, 1.1576219126449838, 1.1640781996656866, 1.1705344866863894, 1.2107732046622073, 1.2719409417061527, 1.3195750414153165, 1.2730900696686922, 1.2266050979220677, 1.3118967736220402, 1.4732725617901192, 1.603089919031433, 1.5425611187817576, 1.482032318532082, 1.5135341752422167, 1.5944894663387488, 1.6575219794323317, 1.6255542062593407, 1.5935864330863496, 1.5855191517024971, 1.5893859745030774, 1.5787716110871353, 1.4999913832944936, 1.4212111555018518, 1.347722607323113, 1.2766847811307644, 1.2011875381799557, 1.1068893852351147, 1.012591232290274, 0.9798925660560106, 0.9735977068747516, 0.9658598513566921, 0.9526347804527636, 0.9394097095488352, 0.9119345420911639, 0.8788201150176411, 0.8525673135876611, 0.8499871725875351, 0.847407031587409, 0.9297677223186276, 1.0430735673927831, 1.1418176493239292, 1.1947561522357626, 1.2476946551475963, 1.2633042422125982, 1.266435625065206, 1.270891986679846, 1.2791645312698954, 1.2874370758599447, 1.2698778874676293, 1.2444262209333015, 1.2340869168047153, 1.2637455641390576, 1.2934042114734, 1.2825832902232004, 1.2605089072235685, 1.2192407548681192, 1.1311450394092468, 1.0430493239503744, 1.056781175099812, 1.096133853021156, 1.1269995888330122, 1.1387274103866374, 1.1504552319402628, 1.1494269958991497, 1.1455122621446867, 1.1474815672757244, 1.161743058626318, 1.1760045499769116, 1.180978649958492, 1.1840770352408079, 1.2037542936569634, 1.2555822211346934, 1.3074101486124234, 1.265715717611693, 1.207319257720226, 1.1754402463487483, 1.1913800897832836, 1.2073199332178188, 1.1463276766185635, 1.073325897092447, 1.0738221107299732, 1.1977480371698666, 1.3216739636097596, 1.358106583342917, 1.3827744954021464, 1.3669900593723354, 1.2878604717516144, 1.2087308841308932, 1.3129655766484922, 1.4380337558358038, 1.5408465097824962, 1.6111286276933665, 1.6814107456042369, 1.6435840531615415, 1.5956465669494504, 1.5536811647222126, 1.5198712741662914, 1.4860613836103702, 1.4493481491075966, 1.4124196527621764, 1.3686567371899154, 1.316168070721635, 1.2636794042533548, 1.1954231854530466, 1.1262928720608454, 1.1035666924203962, 1.1362610756202012]}, {"jid": "JVASP-71465_0", "atoms": {"lattice_mat": [[1.682443, 
-2.914078, 0.0], [1.682443, 2.914078, 0.0], [0.0, 0.0, 4.450628]], "coords": [[0.0, 0.0, 4.27266], [1.68244, 0.97136, 1.36651], [1.68244, -0.97136, 3.26208]], "elements": ["Be", "Cu", "Se"], "abc": [3.364887, 3.364887, 4.45063], "angles": [90.0, 90.0, 120.0001], "cartesian": true, "props": ["", "", ""]}, "efermi": 5.0586, "forces": [[-0.0, -0.0, -0.07631329], [0.0, 0.0, -0.04216169], [-0.0, -0.0, 0.11847498]], "energy": -3.42783764, "stresses": [[66.93145307, 0.0, -0.0], [0.0, 66.93151018, -0.0], [0.0, 0.0, 1.08453505]], "dos": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.5583840243513424e-09, 1.3422310608462383e-08, 2.4286237192573423e-08, 0.00011124804255355835, 0.000268199029286168, 0.0013800042569690934, 0.006789994434780311, 0.01219998461259153, 0.041196336264105876, 0.08262093710229095, 0.12743014003059785, 0.195286609413857, 0.2631430787971162, 0.36325391262168916, 0.48476132087959356, 0.6099121924351016, 0.7804381025149507, 0.9509640125947998, 1.1072514926578019, 1.2517695924787184, 1.3954358574157026, 1.4986160067348455, 1.6017961560539884, 1.743615995883672, 1.9250544205265454, 2.1064928451694187, 2.232355581925998, 2.3563148911590983, 2.4340911923721635, 2.453094850326006, 2.472098508279849, 2.4384791529216803, 2.3998556359797543, 2.3688353008289096, 2.34989373179605, 2.330952162763191, 2.239046297614365, 2.1352073840997954, 2.062792895062892, 2.0534406302357806, 2.0440883654086686, 1.772857154985624, 1.4384800220979455, 1.2072398965440265, 1.2427165590691411, 1.2781932215942555, 1.353278478839886, 1.4414262852056852, 1.5206175239100779, 1.568982610395626, 1.6173476968811737, 1.6966734861548591, 1.7893773305790162, 1.8980814509918698, 2.0841274020744347, 2.2701733531569994, 2.3327958368323807, 2.327355331827064, 2.3341014257307813, 2.432205966568259, 2.5303105074057366, 2.6100253079740474, 2.6770047331320947, 2.7334012545432556, 2.634950360446299, 2.5364994663493423, 2.6355585196867435, 2.9048340482014607, 3.172089003628817, 3.2442402631949263, 3.316391522761035, 3.4370887037037843, 3.609668153135258, 3.7822476025667315, 3.4822958556357335, 3.1608689244224446, 2.8228529500861694, 2.4628170092329906, 2.102781068379812, 2.0138586366474254, 1.954123388119891, 2.011243152088111, 2.2623371963230086, 2.5134312405579062, 2.453499752362187, 2.3382871238106553, 2.266102040530836, 2.284422343511976, 2.3027426464931153, 2.437067236860324, 2.6012382198316826, 2.929891784702771, 3.7066551285113873, 4.483418472320004, 5.813567688188222, 7.33649446321076, 8.727263026351036, 9.634873051891708, 10.54248307743238, 9.920999969495908, 8.605832184428074, 7.3684453182065, 6.536211677295559, 5.703978036384617, 7.11311534111064, 9.815062263939804, 12.662799109875388, 16.722986573419043, 20.7831740369627, 21.191532202085803, 18.960630454613618, 16.82038173002808, 16.280275133232365, 15.740168536436652, 13.752244308618483, 10.463576606396042, 7.1749089041735985, 6.971036041392534, 6.7674905289078, 6.28022714458657, 5.476897151901348, 4.673567159216125, 4.358119461784344, 4.070427436058493, 3.8270101217969974, 3.6449072139170062, 3.462804306037015, 3.370285695744204, 3.2885629238614027, 3.170529515522826, 2.9894859820824, 2.8084424486419746, 2.6694132289657855, 2.5384625471530966, 2.385819552310286, 2.1853127691323575, 1.9848059859544294, 1.7394323481572274, 1.4817708463351815, 1.283496615554716, 1.2558880195291837, 1.2282794235036516, 1.2225870628270084, 
1.2249479408081498, 1.2314659539237758, 1.254162636426049, 1.2768593189283224, 1.2981328332388016, 1.3187290955468365, 1.3386912543982672, 1.3550798718185675, 1.371468489238868, 1.3921088604950995, 1.4153129134164264, 1.439201461865235, 1.4694635535952, 1.499725645325165, 1.4764860377682378, 1.4129052753372475, 1.3554377011042351, 1.433172076802736, 1.5109064525012372, 1.5537083476919737, 1.5637963468410234, 1.5738843459900729, 1.542624611769213, 1.5109277045339236, 1.511767789395612, 1.550394271552892, 1.589020753710172, 1.5055818320776178, 1.4137708908375841, 1.305101937308983, 1.1720688221248439, 1.0390357069407048, 0.9706265647138063, 0.9108554443550602, 0.8815346734655346, 0.9074828682503476, 0.9334310630351607, 0.9661224623935327, 1.0002109631547795, 0.9983394426697638, 0.913162111537378, 0.8279847804049922, 0.8080276175896021, 0.8070432943143246, 0.8132131231158329, 0.8411006596648519, 0.8689881962138709, 0.9602909851367939, 1.0761417727483698, 1.149849831428178, 1.0485462496957576, 0.9472426679633372, 0.992676658345907, 1.1113011147540195, 1.2116225841694621, 1.1998340482285619, 1.1880455122876619, 1.1700812840904586, 1.148226181943214, 1.1373111738624, 1.2417491911768566, 1.3461872084913133, 1.358708934405479, 1.2989414469160385, 1.2390506077303698, 1.1755341593489204, 1.112017710967471, 1.0967986575144921, 1.1287236762752972, 1.1606486950361021, 1.277752052275172, 1.3966664251157157, 1.4155972451655456, 1.3134657169084725, 1.2113341886513993, 1.2549833358907265, 1.3103742944436467, 1.333408389909474, 1.3076227883827178, 1.281837186855962, 1.3150388308933387, 1.356919504451778, 1.5086930857495826, 1.8692176438543369, 2.229742201959091, 2.330126786273353, 2.3726380430537968, 2.4410695336074264, 2.5726062921312804, 2.7041430506551345, 2.7168403146427553, 2.6928878912861856, 2.7025304094857168, 2.8200700290558043, 2.9376096486258914, 2.9226826958072634, 2.853800016654964, 2.760634367510233, 2.5595719282017106, 2.3585094888931883, 2.175883515635724, 2.002889040351418, 1.8575876546422916, 1.8976333494242414, 1.9376790442061913, 1.9946571175085346, 2.0627770300699173, 2.1265364201231303, 2.1374828500952208, 2.1484292800673117, 2.169374846929051, 2.1985211350660787, 2.2276039076204204, 2.2539331087794383, 2.2802623099384562, 2.2991979399061364, 2.310611295389575, 2.3220246508730136, 2.512969572718431, 2.7096915224502744, 2.7919051846585163, 2.7295374871360676, 2.667169789613619, 2.736706210027664, 2.818479776466215, 2.83692322166747, 2.755576667866509, 2.674230114065548, 2.693691154722509, 2.729374753460695, 2.720014032541736, 2.621044248622015, 2.522074464702294, 2.3608082165823414, 2.1847063592021696, 2.0279755290458175, 1.9208591064552065, 1.8137426838645954, 1.8239524906722668, 1.8724539558331603, 1.8852438088890462, 1.7764787962354662, 1.6677137835818865, 1.6351835475521186, 1.635291729860088, 1.6381229285324237, 1.6539388249673355, 1.669754721402247, 1.8022478493418026, 1.9985405054358851, 2.1758233633177446, 2.2132159038329515, 2.2506084443481584, 2.265503377995288, 2.2649427467366383, 2.264246132984188, 2.2616224073644937, 2.258998681744799, 2.2659681683826287, 2.281140924633696, 2.2954752370217597, 2.2418677815121364, 2.188260326002513, 2.1348069522041797, 2.0815169753672373, 2.028226998530295, 1.9413727181032936, 1.8530640581337423, 1.834788426867515, 1.90874064337039, 1.982692859873265, 2.031941485581243, 2.0785892758636373, 2.106917588232386, 2.1050865159860237, 2.1032554437396613, 2.088453138837513, 2.071380173257484, 2.0967011138793477, 2.2104072747045587, 
2.3241134355297692, 2.2770069198277385, 2.1890173447140784, 2.118691071569283, 2.0960141071872425, 2.0733371428052014, 2.0163762875312634, 1.9475928948286787, 1.9286348710878363, 2.0897573627645913, 2.250879854441347, 2.3567531413515885, 2.4377880558538525, 2.4883762464831793, 2.3826202968758747, 2.2768643472685706, 2.3159113220793883, 2.437783486140561, 2.5584424004897373, 2.6692120760723292, 2.779981751654921, 2.826088673441449, 2.8258333985369695, 2.8221234555337182, 2.759684154857007, 2.697244854180296, 2.6212094745524217, 2.5330538595501437, 2.444943572942954, 2.3815057890361593, 2.3180680051293647, 2.227490112867892, 2.106912007513965, 1.9863339021600381, 2.004956659089198, 2.031197543287284, 2.074599044896223, 2.1415771787989963, 2.2085553127017694, 2.1751760427000537, 2.1299465713085715]}, {"jid": "JVASP-71464_7", "atoms": {"lattice_mat": [[3.391901186, -0.0, -0.0], [0.0, 3.391901186, -0.0], [-0.0, 0.0, 4.590427633]], "coords": [[1.69595, 1.69595, -0.0], [-0.0, 0.0, 3.42837], [-0.0, 0.0, 1.16206], [1.69595, 1.69595, 2.29521]], "elements": ["Li", "Be", "Be", "Se"], "abc": [3.3919, 3.3919, 4.59043], "angles": [90.0, 90.0, 90.0], "cartesian": true, "props": ["", "", "", ""]}, "efermi": 4.7918, "forces": [[0.0, -0.0, 0.0], [-0.0, 0.0, 0.05522697], [0.0, -0.0, -0.05522697], [-0.0, -0.0, 0.0]], "energy": -7.35492729, "stresses": [[-3.27570341, 0.0, -0.0], [0.0, -3.27570341, 0.0], [0.0, 0.0, 2.62263108]], "dos": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 7.728224419852447e-08, 2.5454405122121544e-07, 3.730809349878245e-05, 0.0010318014511884921, 0.002026294808878202, 0.014885275638317801, 0.03447042738986649, 0.05794515503866329, 0.09378392010763342, 0.12962268517660358, 0.1484626678478755, 0.1640747660413914, 0.18414301634208569, 0.20984778654458305, 0.23556113510669624, 0.2614671211170384, 0.2873731071273806, 0.2828288428549026, 0.2614328623958914, 0.24861145502779086, 0.262235051021533, 0.2758586470152751, 0.3172832074125046, 0.3637700188293278, 0.3892955311189106, 0.3888963005850752, 0.3881194682382424, 0.37988262683429075, 0.37164578543033905, 0.40372825959372827, 0.4575892235168395, 0.498136718754495, 0.49882811283947764, 0.4995195069244603, 0.5349138205791688, 0.5763599946923754, 0.6040492526685088, 0.6151011008080367, 0.6255284083178826, 0.6249552128754504, 0.6243820174330181, 0.6311520107972599, 0.6417926525946266, 0.6419953534811, 0.6118535395713145, 0.581711725661529, 0.5465590223063506, 0.5105705444910227, 0.4984223709089287, 0.5144687627881666, 0.5305446450258268, 0.5470886026551641, 0.5636325602845015, 0.553462209973133, 0.529553135196628, 0.5197345783912417, 0.5497115159300883, 0.5796884534689348, 0.6277011446156444, 0.6785867591409267, 0.7272614325335836, 0.7733790636042406, 0.8105704153337088, 0.7189683080002018, 0.6273662006666949, 0.604858125092187, 0.6170126621389319, 0.6223384777341611, 0.6089204428263473, 0.5955024079185335, 0.5687753675733788, 0.5400268875791394, 0.5093854796705085, 0.47660306708826133, 0.44106897734135736, 0.36917822093436087, 0.29728746452736443, 0.3228667218288041, 0.39613522900714254, 0.4463594633648006, 0.4350862103625122, 0.4238129573602239, 0.4081088314233639, 0.39176411753521745, 0.3950440423021523, 0.42003203913872056, 0.4410508521800458, 0.41374800619232954, 0.3864451602046132, 0.39516373770031543, 0.42106713187633, 0.4235279824787503, 0.36514331409535006, 0.30675864571194983, 0.34905820943038424, 0.4051872198935072, 0.42370126419539195, 
0.40152161259536984, 0.3830346612501167, 0.40619122145333947, 0.42934778165656223, 0.425471365755028, 0.4090226840695504, 0.41752352130340237, 0.4890275592308978, 0.5605315971583932, 0.6333029304358201, 0.7062392985460964, 0.7529251881922496, 0.7718357632782851, 0.791442997220527, 0.8183614217753863, 0.8452798463302454, 0.8439960056347161, 0.8299290865048666, 0.8231481403392025, 0.8342734399476217, 0.8453987395560408, 0.787848140189011, 0.7218376173796758, 0.6986438557216599, 0.7197596011680379, 0.7391522749196285, 0.74164878160935, 0.7441452882990713, 0.706366538523101, 0.6508003903201618, 0.5985690703729152, 0.5543165999032701, 0.510064129433625, 0.4745975852821386, 0.4401522604857981, 0.4172594897256382, 0.40605972364460274, 0.3930764193680193, 0.363693627903643, 0.33431083643926673, 0.350499379532558, 0.38629324315231867, 0.39792786348695147, 0.353273055941248, 0.3086182483955446, 0.25837416892326237, 0.2075188107424355, 0.1857913516181926, 0.19289913281651047, 0.20116517661592276, 0.219449470307878, 0.23773376399983326, 0.23755697115329827, 0.22964576765349493, 0.22033213851821773, 0.20783562475473227, 0.19533911099124684, 0.19867507605054893, 0.2036351888721495, 0.2052943816059689, 0.20375751638808481, 0.20663291359776584, 0.2455090873917854, 0.284385261185805, 0.28480982020613976, 0.2695507711511267, 0.2536695031939461, 0.23641230352058357, 0.21915510384722103, 0.24795911576219326, 0.2811795962585257, 0.29936800279471243, 0.30332131636107174, 0.30741474540052705, 0.3125893778537169, 0.3177640103069067, 0.3249384646429821, 0.3329067927673684, 0.34652484152815644, 0.3723188406830841, 0.39811283983801177, 0.45998794247413033, 0.525083747447613, 0.5500582862513301, 0.5378731442452426, 0.527084013208209, 0.526506153416439, 0.5259282936246692, 0.5159437387874166, 0.502326006614573, 0.502681382467921, 0.5323926636547544, 0.5621039448415877, 0.5946702857677605, 0.6274728046827525, 0.6482458656881799, 0.6581223237090855, 0.6624578837166539, 0.6282945842772, 0.5941312848377461, 0.5620463306219394, 0.5307421231634805, 0.5188992906158559, 0.5469223873133879, 0.5749454840109199, 0.5554093190113939, 0.5322462642016264, 0.525059276897274, 0.532025405768412, 0.5412441239592529, 0.5653582885716646, 0.5894724531840762, 0.600676860763252, 0.6071660731712584, 0.6237764371378617, 0.6606066901193108, 0.6974369431007601, 0.6586955410313868, 0.6146735744823043, 0.6225121233993466, 0.6752813636804361, 0.7171646737394698, 0.6904205390731901, 0.6636764044069106, 0.627715569353538, 0.5884829517993073, 0.5793699889936761, 0.6289531285649783, 0.6785362681362807, 0.6750729988802132, 0.6682378472473289, 0.6635341038743783, 0.6606362099655028, 0.6677580804678287, 0.7351977921140147, 0.8026375037602008, 0.7541314681795027, 0.6656382305984521, 0.6043474190988534, 0.59477713144051, 0.5852068437821666, 0.5631560805230182, 0.5403898292084618, 0.5556796879977309, 0.602499985023714, 0.6423235694674977, 0.6418681419792021, 0.6414127144909063, 0.6112410612067521, 0.5711167725986244, 0.580081721724557, 0.6801258075677376, 0.7801698934109181, 0.7742256070002662, 0.7628584127777481, 0.7617206771380867, 0.7688705375109365, 0.7744223482572651, 0.7711644142702818, 0.7679064802832986, 0.739842739378557, 0.7037143095012688, 0.6726469029915199, 0.650744350966931, 0.6288417989423419, 0.6722517202060295, 0.718605476609512, 0.683507009645284, 0.5838838021977392, 0.5071195749701001, 0.5511830298210201, 0.5952464846719401, 0.6245565292032674, 0.6492126632337971, 0.6769901683482796, 0.7102854054934937, 0.7435806426387077, 
0.7236488112407338, 0.7016383952039749, 0.700307396995079, 0.7149934869567358, 0.7311158936442896, 0.7545262519170834, 0.7779366101898774, 0.7641848238937319, 0.7390643977202245, 0.7299749110360109, 0.748553039649344, 0.7671311682626771, 0.7517983431685388, 0.735343118158569, 0.7385227025633376, 0.75657074157912, 0.7657878016065371, 0.7319439048000217, 0.6981000079935062, 0.7214022263333012, 0.7616500042583637, 0.8249923197345036, 0.9272557713308964, 1.0295192229272891, 0.9763865108598464, 0.919024933923114, 0.867698686920846, 0.8208403390239443, 0.7876825927086422, 0.8187900720954399, 0.8498975514822376, 0.8623816603525851, 0.8695157419445991, 0.8969799165065064, 0.9579052898903364, 1.0188306632741664, 1.0021651708555495, 0.9838396133933716, 0.9335060403114765, 0.8600096742254402, 0.8020425459891821, 0.8142146796078672, 0.8263868132265523, 0.7879389090810208, 0.7354110079937077, 0.7139037569813438, 0.7422663390096474, 0.7706289210379511, 0.7874089303179904, 0.8040077613469605, 0.8038006620301914, 0.7917061190016896, 0.7638328410503016, 0.6672772993209232, 0.5707217575915448, 0.5183254508995523, 0.47781481591089414, 0.446030694111358, 0.42795123473191854, 0.40987177535247904, 0.45497714913101156, 0.5007114622526014, 0.5106332073747242, 0.4957969142423811, 0.4782779554340717, 0.4494957653864074, 0.4207135753387431, 0.4298157552583104, 0.4487786591794703, 0.4851327666322229, 0.5481708338717848, 0.6112089011113467, 0.6060478547909296, 0.6005915765875071, 0.5669198063582079, 0.5141854637846599, 0.47833095220245614, 0.5108879512904636, 0.543444950378471, 0.4653462862973996, 0.3594135400963248, 0.30095080795868573, 0.3136556886538949, 0.3262823892586864, 0.27569560870477117, 0.22510882815085592, 0.20213837838300625, 0.19739967633289182, 0.1868816319956713, 0.15373671619568158, 0.12059180039569184, 0.10156585627631114, 0.08596957166734347, 0.07455770324197956, 0.06927620390333468, 0.06398238400767561, 0.056883902613048376, 0.04978542121842114]}]
\ No newline at end of file
diff --git a/alignn/examples/sample_data_ff_additional/config.json b/alignn/examples/sample_data_ff_additional/config.json
new file mode 100644
index 0000000..32bbd93
--- /dev/null
+++ b/alignn/examples/sample_data_ff_additional/config.json
@@ -0,0 +1,85 @@
+{
+    "version": "112bbedebdaecf59fb18e11c929080fb2f358246",
+    "dataset": "user_data",
+    "target": "target",
+    "atom_features": "atomic_number",
+    "neighbor_strategy": "radius_graph",
+    "id_tag": "jid",
+    "dtype": "float32",
+    "random_seed": 123,
+    "classification_threshold": null,
+    "n_val": 2,
+    "n_test": 2,
+    "n_train": 4,
+    "train_ratio": 0.8,
+    "val_ratio": 0.1,
+    "test_ratio": 0.1,
+    "target_multiplication_factor": null,
+    "epochs": 3,
+    "batch_size": 2,
+    "weight_decay": 1e-05,
+    "learning_rate": 0.001,
+    "filename": "A",
+    "warmup_steps": 2000,
+    "criterion": "mse",
+    "optimizer": "adamw",
+    "scheduler": "onecycle",
+    "pin_memory": false,
+    "save_dataloader": false,
+    "write_checkpoint": true,
+    "write_predictions": true,
+    "store_outputs": true,
+    "progress": true,
+    "log_tensorboard": false,
+    "standard_scalar_and_pca": false,
+    "use_canonize": true,
+    "num_workers": 0,
+    "cutoff": 4.0,
+    "cutoff_extra": 3.0,
+    "max_neighbors": 12,
+    "keep_data_order": true,
+    "normalize_graph_level_loss": false,
+    "distributed": false,
+    "data_parallel": false,
+    "n_early_stopping": null,
+    "output_dir": "temp",
+    "use_lmdb": true,
+    "model": {
+        "name": "alignn_atomwise",
+        "alignn_layers": 2,
+        "gcn_layers": 2,
+        "atom_input_features": 1,
+        "edge_input_features": 80,
"triplet_input_features": 40, + "embedding_features": 64, + "hidden_features": 64, + "output_features": 1, + "grad_multiplier": -1, + "calculate_gradient": true, + "atomwise_output_features": 0, + "graphwise_weight": 1.0, + "gradwise_weight": 1.0, + "stresswise_weight": 0.01, + "atomwise_weight": 0.0, + "link": "identity", + "zero_inflated": false, + "classification": false, + "force_mult_natoms": false, + "energy_mult_natoms": false, + "include_pos_deriv": false, + "use_cutoff_function": false, + "inner_cutoff": 3.0, + "stress_multiplier": 1.0, + "add_reverse_forces": true, + "lg_on_fly": true, + "batch_stress": true, + "multiply_cutoff": false, + "use_penalty": true, + "extra_features": 0, + "exponent": 5, + "additional_output_features":400, + "additional_output_weight":0.1, + "penalty_factor": 0.1, + "penalty_threshold": 1.0 + } +} diff --git a/alignn/lmdb_dataset.py b/alignn/lmdb_dataset.py index 1c4f422..fbc14d1 100644 --- a/alignn/lmdb_dataset.py +++ b/alignn/lmdb_dataset.py @@ -112,6 +112,7 @@ def get_torch_dataset( target_atomwise="", target_grad="", target_stress="", + target_additional_output="", neighbor_strategy="k-nearest", atom_features="cgcnn", use_canonize="", @@ -171,12 +172,12 @@ def get_torch_dataset( torch.get_default_dtype() ) label = torch.tensor(d[target]).type(torch.get_default_dtype()) + natoms = len(d["atoms"]["elements"]) # print('label',label,label.view(-1).long()) if classification: label = label.long() # label = label.view(-1).long() if "extra_features" in d: - natoms = len(d["atoms"]["elements"]) g.ndata["extra_features"] = torch.tensor( [d["extra_features"] for n in range(natoms)] ).type(torch.get_default_dtype()) @@ -193,6 +194,14 @@ def get_torch_dataset( g.ndata[target_stress] = torch.tensor( np.array([stress for ii in range(g.number_of_nodes())]) ).type(torch.get_default_dtype()) + if ( + target_additional_output is not None + and target_additional_output != "" + ): + additional_output = np.array(d[target_additional_output]) + g.ndata[target_additional_output] = torch.tensor( + ([additional_output for ii in range(natoms)]) + ).type(torch.get_default_dtype()) # labels.append(label) if line_graph: diff --git a/alignn/models/alignn_atomwise.py b/alignn/models/alignn_atomwise.py index dd7d15d..4fae2c3 100644 --- a/alignn/models/alignn_atomwise.py +++ b/alignn/models/alignn_atomwise.py @@ -68,6 +68,8 @@ class ALIGNNAtomWiseConfig(BaseSettings): exponent: int = 5 penalty_factor: float = 0.1 penalty_threshold: float = 1 + additional_output_features: int = 0 + additional_output_weight: float = 0 class Config: """Configure model settings behavior.""" @@ -336,6 +338,10 @@ def __init__( config.hidden_features, config.atomwise_output_features ) + if config.additional_output_features: + self.fc_additional_output = nn.Linear( + config.hidden_features, config.additional_output_features + ) if self.classification: self.fc = nn.Linear(config.hidden_features, 1) self.softmax = nn.Sigmoid() @@ -442,6 +448,7 @@ def forward( x, y = gcn_layer(g, x, y) # norm-activation-pool-classify out = torch.empty(1) + additional_out = torch.empty(1) if self.config.output_features is not None: h = self.readout(g, x) out = self.fc(h) @@ -455,6 +462,9 @@ def forward( # print('out',out) else: out = torch.squeeze(out) + if self.config.additional_output_features > 0: + additional_out = self.fc_additional_output(h) + atomwise_pred = torch.empty(1) if ( self.config.atomwise_output_features > 0 @@ -571,7 +581,6 @@ def forward( # print("stress1", stress, stress.shape) # print("g.batch_size", 
                 else:
-                    # print('Using batch_stress')
                     stresses = []
                     count_edge = 0
                     count_node = 0
                     for graph_id in range(g.batch_size):
                         num_edges = g.batch_num_edges()[graph_id]
                         num_nodes = 0
                         st = -1 * (
                             160.21766208
                             * torch.matmul(
                                 r[count_edge : count_edge + num_edges].T,
                                 pair_forces[
                                     count_edge : count_edge + num_edges
                                 ],
                             )
                             / g.ndata["V"][count_node + num_nodes]
                         )

                         count_edge = count_edge + num_edges
                         num_nodes = g.batch_num_nodes()[graph_id]
                         count_node = count_node + num_nodes
-                        # print("stresses.append",stresses[-1],stresses[-1].shape)
-                        for n in range(num_nodes):
-                            stresses.append(st)
-                    # stress = (stresses)
-                    stress = self.config.stress_multiplier * torch.cat(
+                        stresses.append(st)
+                    stress = self.config.stress_multiplier * torch.stack(
                         stresses
                     )
                     # print("stress2", stress, stress.shape)
@@ -614,8 +620,38 @@ def forward(
                 # out = torch.max(out,dim=1)
                 out = self.softmax(out)
         result["out"] = out
+        result["additional"] = additional_out
         result["grad"] = forces
         result["stresses"] = stress
         result["atomwise_pred"] = atomwise_pred
         # print(result)
         return result
+
+
+"""
+if __name__ == "__main__":
+    from jarvis.core.atoms import Atoms
+    from alignn.graphs import Graph
+
+    FIXTURES = {
+        "lattice_mat": [
+            [2.715, 2.715, 0],
+            [0, 2.715, 2.715],
+            [2.715, 0, 2.715],
+        ],
+        "coords": [[0, 0, 0], [0.25, 0.25, 0.25]],
+        "elements": ["Si", "Si"],
+    }
+    Si = Atoms(
+        lattice_mat=FIXTURES["lattice_mat"],
+        coords=FIXTURES["coords"],
+        elements=FIXTURES["elements"],
+    )
+    g, lg = Graph.atom_dgl_multigraph(
+        atoms=Si, neighbor_strategy="radius_graph", cutoff=5
+    )
+    lat = torch.tensor(Si.lattice_mat)
+    model = ALIGNNAtomWise(ALIGNNAtomWiseConfig(name="alignn_atomwise"))
+    out = model([g, lg, lat])
+    print(out)
+"""
diff --git a/alignn/train.py b/alignn/train.py
index 340a54f..dfc6566 100644
--- a/alignn/train.py
+++ b/alignn/train.py
@@ -28,6 +28,7 @@
     setup_optimizer,
     print_train_val_loss,
 )
+import dgl

 # from sklearn.metrics import log_loss

@@ -35,6 +36,14 @@
 # torch.autograd.detect_anomaly()


+figlet_alignn = """
+    _    _     ___ ____ _   _ _   _
+   / \  | |   |_ _/ ___| \ | | \ | |
+  / _ \ | |    | | |  _|  \| |  \| |
+ / ___ \| |___ | | |_| | |\  | |\  |
+/_/   \_\_____|___\____|_| \_|_| \_|
+"""
+
+
 def train_dgl(
     config: Union[TrainingConfig, Dict[str, Any]],
@@ -169,9 +178,26 @@ def train_dgl(
         net = _model.get(config.model.name)(config.model)
     else:
         net = model
+    print(figlet_alignn)
     print("Model parameters", sum(p.numel() for p in net.parameters()))
     print("CUDA available", torch.cuda.is_available())
     print("CUDA device count", int(torch.cuda.device_count()))
+    try:
+        gpu_stats = torch.cuda.get_device_properties(0)
+        max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
+        from platform import system as platform_system
+
+        platform_system = platform_system()
+        statistics = (
+            f" GPU: {gpu_stats.name}. Max memory: {max_memory} GB"
+            + f". Platform = {platform_system}.\n"
+            f" Pytorch: {torch.__version__}. CUDA = "
+            + f"{gpu_stats.major}.{gpu_stats.minor}."
+ + f" CUDA Toolkit = {torch.version.cuda}.\n" + ) + print(statistics) + except Exception: + pass # print("device", device) net.to(device) if use_ddp: @@ -224,6 +250,7 @@ def train_dgl( running_loss2 = 0 running_loss3 = 0 running_loss4 = 0 + running_loss5 = 0 train_result = [] for dats, jid in zip(train_loader, train_loader.dataset.ids): info = {} @@ -249,11 +276,14 @@ def train_dgl( info["pred_grad"] = [] info["target_stress"] = [] info["pred_stress"] = [] + info["target_additional"] = [] + info["pred_additional"] = [] loss1 = 0 # Such as energy loss2 = 0 # Such as bader charges loss3 = 0 # Such as forces loss4 = 0 # Such as stresses + loss5 = 0 # Such as dos if config.model.output_features is not None: # print('criterion',criterion) # print('result["out"]',result["out"]) @@ -299,13 +329,23 @@ def train_dgl( ) running_loss3 += loss3.item() if config.model.stresswise_weight != 0: + # print('unbatch',dgl.unbatch(dats[0])) + + targ_stress = torch.stack( + [ + gg.ndata["stresses"][0] + for gg in dgl.unbatch(dats[0]) + ] + ).to(device) + pred_stress = result["stresses"] + # print('targ_stress',targ_stress,targ_stress.shape) + # print('pred_stress',pred_stress,pred_stress.shape) loss4 = config.model.stresswise_weight * criterion( - (result["stresses"]).to(device), - torch.cat(tuple(dats[0].ndata["stresses"])).to(device), + pred_stress.to(device), + targ_stress.to(device), ) info["target_stress"] = ( - torch.cat(tuple(dats[0].ndata["stresses"])) - .cpu() + targ_stress.cpu() .numpy() .tolist() # dats[0].ndata["stresses"][0].cpu().numpy().tolist() @@ -314,10 +354,32 @@ def train_dgl( result["stresses"].cpu().detach().numpy().tolist() ) running_loss4 += loss4.item() + if config.model.additional_output_weight != 0: + # print('unbatch',dgl.unbatch(dats[0])) + additional_dat = [ + gg.ndata["additional"][0] + for gg in dgl.unbatch(dats[0]) + ] + # print('additional_dat',additional_dat,len(additional_dat)) + targ = torch.stack(additional_dat).to(device) + # targ=torch.tensor(additional_dat).to( dats[0].device) + # print('result["additional"]',result["additional"],result["additional"].shape) + # print('targ',targ,targ.shape) + # print('targ device',targ.device) + loss5 = config.model.additional_output_weight * criterion( + (result["additional"]).to(device), + targ, + # (dats[0].ndata["additional"]).to(device), + ) + info["target_additional"] = targ.cpu().numpy().tolist() + info["pred_additional"] = ( + result["additional"].cpu().detach().numpy().tolist() + ) + running_loss5 += loss5.item() # print("target_stress", info["target_stress"][0]) # print("pred_stress", info["pred_stress"][0]) train_result.append(info) - loss = loss1 + loss2 + loss3 + loss4 + loss = loss1 + loss2 + loss3 + loss4 + loss5 loss.backward() optimizer.step() # optimizer.zero_grad() #never @@ -337,6 +399,7 @@ def train_dgl( running_loss2, running_loss3, running_loss4, + running_loss5, ] ) dumpjson( @@ -348,6 +411,7 @@ def train_dgl( val_loss2 = 0 val_loss3 = 0 val_loss4 = 0 + val_loss5 = 0 val_result = [] # for dats in val_loader: val_init_time = time.time() @@ -379,6 +443,7 @@ def train_dgl( loss2 = 0 # Such as bader charges loss3 = 0 # Such as forces loss4 = 0 # Such as stresses + loss5 = 0 # Such as stresses if config.model.output_features is not None: loss1 = config.model.graphwise_weight * criterion( result["out"], dats[-1].to(device) @@ -421,13 +486,22 @@ def train_dgl( # result["stress"].to(device), # dats[0].ndata["stresses"][0].to(device), # ) + + targ_stress = torch.stack( + [ + gg.ndata["stresses"][0] + for gg in 
+                        ]
+                    ).to(device)
+                    pred_stress = result["stresses"]
+                    # print('targ_stress',targ_stress,targ_stress.shape)
+                    # print('pred_stress',pred_stress,pred_stress.shape)
                     loss4 = config.model.stresswise_weight * criterion(
-                        (result["stresses"]).to(device),
-                        torch.cat(tuple(dats[0].ndata["stresses"])).to(device),
+                        pred_stress.to(device),
+                        targ_stress.to(device),
                     )
                     info["target_stress"] = (
-                        torch.cat(tuple(dats[0].ndata["stresses"]))
-                        .cpu()
+                        targ_stress.cpu()
                         .numpy()
                         .tolist()
                         # dats[0].ndata["stresses"][0].cpu().numpy().tolist()
                     )
                     info["pred_stress"] = (
                         result["stresses"].cpu().detach().numpy().tolist()
                     )
+
                     val_loss4 += loss4.item()
-                loss = loss1 + loss2 + loss3 + loss4
+                if config.model.additional_output_weight != 0:
+                    additional_dat = [
+                        gg.ndata["additional"][0]
+                        for gg in dgl.unbatch(dats[0])
+                    ]
+                    # print('additional_dat',additional_dat,len(additional_dat))
+                    targ = torch.stack(additional_dat).to(device)
+                    # targ=torch.tensor(additional_dat).to( dats[0].device)
+                    # print('result["additional"]',result["additional"],result["additional"].shape)
+                    # print('targ',targ,targ.shape)
+                    # print('targ device',targ.device)
+                    loss5 = config.model.additional_output_weight * criterion(
+                        (result["additional"]).to(device),
+                        targ,
+                        # (dats[0].ndata["additional"]).to(device),
+                    )
+                    info["target_additional"] = targ.cpu().numpy().tolist()
+                    info["pred_additional"] = (
+                        result["additional"].cpu().detach().numpy().tolist()
+                    )
+
+                    val_loss5 += loss5.item()
+                loss = loss1 + loss2 + loss3 + loss4 + loss5
                 val_result.append(info)
                 val_loss += loss.item()
             # mean_out, mean_atom, mean_grad, mean_stress = get_batch_errors(
@@ -472,8 +569,15 @@ def train_dgl(
                     data=val_result,
                 )
                 best_model = net
-            history_train.append(
-                [val_loss, val_loss1, val_loss2, val_loss3, val_loss4]
+            history_val.append(
+                [
+                    val_loss,
+                    val_loss1,
+                    val_loss2,
+                    val_loss3,
+                    val_loss4,
+                    val_loss5,
+                ]
             )
             # history_val.append([mean_out, mean_atom, mean_grad, mean_stress])
             dumpjson(
@@ -488,11 +592,13 @@ def train_dgl(
                 running_loss2,
                 running_loss3,
                 running_loss4,
+                running_loss5,
                 val_loss,
                 val_loss1,
                 val_loss2,
                 val_loss3,
                 val_loss4,
+                val_loss5,
                 train_ep_time,
                 val_ep_time,
                 saving_msg=saving_msg,
@@ -559,19 +665,30 @@ def train_dgl(
                         result["grad"].cpu().detach().numpy().tolist()
                     )
                 if config.model.stresswise_weight != 0:
+
+                    targ_stress = torch.stack(
+                        [
+                            gg.ndata["stresses"][0]
+                            for gg in dgl.unbatch(dats[0])
+                        ]
+                    ).to(device)
+                    pred_stress = result["stresses"]
+                    # print('targ_stress',targ_stress,targ_stress.shape)
+                    # print('pred_stress',pred_stress,pred_stress.shape)
                     loss4 = config.model.stresswise_weight * criterion(
-                        result["stresses"].to(device),
-                        torch.cat(tuple(dats[0].ndata["stresses"])).to(device),
+                        pred_stress.to(device),
+                        targ_stress.to(device),
                     )
                     info["target_stress"] = (
-                        torch.cat(tuple(dats[0].ndata["stresses"]))
-                        .cpu()
+                        targ_stress.cpu()
                         .numpy()
                         .tolist()
+                        # dats[0].ndata["stresses"][0].cpu().numpy().tolist()
                    )
                     info["pred_stress"] = (
                         result["stresses"].cpu().detach().numpy().tolist()
                     )
+
                 test_result.append(info)
                 loss = loss1 + loss2 + loss3 + loss4
                 if not classification:
diff --git a/alignn/train_alignn.py b/alignn/train_alignn.py
index c36db4e..d066e66 100644
--- a/alignn/train_alignn.py
+++ b/alignn/train_alignn.py
@@ -115,6 +115,12 @@ def cleanup(world_size):
    help="Name of the key for stress (3x3) level data such as forces",
 )

+parser.add_argument(
+    "--additional_output_key",
+    default="additional_output",
+    help="Name of the key for extra global output, e.g. DOS",
global output eg DOS", +) + parser.add_argument( "--output_dir", @@ -150,6 +156,7 @@ def train_for_folder( atomwise_key="forces", gradwise_key="forces", stresswise_key="stresses", + additional_output_key="additional_output", file_format="poscar", restart_model_path=None, output_dir=None, @@ -197,6 +204,7 @@ def train_for_folder( train_grad = False train_stress = False + train_additional_output = False train_atom = False try: if ( @@ -217,6 +225,13 @@ def train_for_folder( train_atom = True else: train_atom = False + if ( + config.model.additional_output_features > 0 + and config.model.additional_output_weight != 0 + ): + train_additional_output = True + else: + train_additional_output = False except Exception as exp: print("exp", exp) pass @@ -229,6 +244,7 @@ def train_for_folder( target_atomwise = None # "atomwise_target" target_grad = None # "atomwise_grad" target_stress = None # "stresses" + target_additional_output = None # "stresses" # mem = [] # enp = [] @@ -283,11 +299,14 @@ def train_for_folder( info["stresses"] = stress # - mean_force target_stress = "stresses" - # print("stresses",info["stresses"] ) + if train_additional_output: + target_additional_output = "additional" + info["additional"] = i[additional_output_key] # - mean_force if "extra_features" in i: info["extra_features"] = i["extra_features"] dataset.append(info) print("len dataset", len(dataset)) + print("train_stress", train_stress) del dat # multioutput = False lists_length_equal = True @@ -356,6 +375,7 @@ def train_for_folder( target_atomwise=target_atomwise, target_grad=target_grad, target_stress=target_stress, + target_additional_output=target_additional_output, n_train=config.n_train, n_val=config.n_val, n_test=config.n_test, @@ -427,6 +447,7 @@ def train_for_folder( args.atomwise_key, args.force_key, args.stresswise_key, + args.additional_output_key, args.file_format, args.restart_model_path, args.output_dir, @@ -447,6 +468,7 @@ def train_for_folder( args.atomwise_key, args.force_key, args.stresswise_key, + args.additional_output_key, args.file_format, args.restart_model_path, args.output_dir, diff --git a/alignn/utils.py b/alignn/utils.py index 840cdf3..f97c454 100644 --- a/alignn/utils.py +++ b/alignn/utils.py @@ -115,17 +115,21 @@ def print_train_val_loss( running_loss2, running_loss3, running_loss4, + running_loss5, val_loss, val_loss1, val_loss2, val_loss3, val_loss4, + val_loss5, train_ep_time, val_ep_time, saving_msg="", ): """Train loss header.""" - header = ("{:<12} {:<8} {:<10} {:<10} {:<10} {:<10} {:<10} {:<10}").format( + header = ( + "{:<12} {:<8} {:<10} {:<10} {:<10} {:<10} {:<10} {:<10} {:<10}" + ).format( "Train Loss:", "Epoch", "Total", @@ -133,13 +137,15 @@ def print_train_val_loss( "Atom", "Grad", "Stress", + "Addn.", "Time", ) print(header) # Train loss values train_row = ( - "{:<12} {:<8} {:<10.4f} {:<10.4f} {:<10.4f} {:<10.4f} {:<10.4f} " + "{:<12} {:<8} {:<10.4f} {:<10.4f} " + + "{:<10.4f} {:<10.4f} {:<10.4f} {:<10.4f} " "{:<10.2f}" ).format( "", @@ -149,12 +155,15 @@ def print_train_val_loss( running_loss2, running_loss3, running_loss4, + running_loss5, train_ep_time, ) print(train_row) # Validation loss header - header = ("{:<12} {:<8} {:<10} {:<10} {:<10} {:<10} {:<10} {:<10}").format( + header = ( + "{:<12} {:<8} {:<10} {:<10} {:<10} {:<10} {:<10} {:<10} {:<10}" + ).format( "Val Loss:", "Epoch", "Total", @@ -162,13 +171,15 @@ def print_train_val_loss( "Atom", "Grad", "Stress", + "Addn.", "Time", ) print(header) # Validation loss values val_row = ( - "{:<12} {:<8} {:<10.4f} {:<10.4f} 
+        "{:<12} {:<8} {:<10.4f} {:<10.4f} "
+        + "{:<10.4f} {:<10.4f} {:<10.4f} {:<10.4f} "
         "{:<10.2f} {:<10}"
     ).format(
         "",
@@ -178,6 +189,7 @@ def print_train_val_loss(
         val_loss2,
         val_loss3,
         val_loss4,
+        val_loss5,
         val_ep_time,
         saving_msg,
     )

From 3c180dee2eda1ec3cdacd97d938289ac7eca5a2d Mon Sep 17 00:00:00 2001
From: knc6
Date: Fri, 8 Nov 2024 21:53:48 -0500
Subject: [PATCH 14/37] Arr 0 fix.

---
 alignn/lmdb_dataset.py | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/alignn/lmdb_dataset.py b/alignn/lmdb_dataset.py
index fbc14d1..88e3b56 100644
--- a/alignn/lmdb_dataset.py
+++ b/alignn/lmdb_dataset.py
@@ -186,9 +186,19 @@ def get_torch_dataset(
                 np.array(d[target_atomwise])
             ).type(torch.get_default_dtype())
         if target_grad is not None and target_grad != "":
-            g.ndata[target_grad] = torch.tensor(
-                np.array(d[target_grad])
-            ).type(torch.get_default_dtype())
+            # print('grad', np.array(d[target_grad]))
+            # print('grad shape',np.array(d[target_grad]).shape)
+            arr = np.array(d[target_grad])
+            try:
+                g.ndata[target_grad] = torch.tensor(arr).type(
+                    torch.get_default_dtype()
+                )
+            except Exception:
+                arr = arr.reshape(1, -1)
+                g.ndata[target_grad] = torch.tensor(arr).type(
+                    torch.get_default_dtype()
+                )
+            # print('arr',arr.shape)
         if target_stress is not None and target_stress != "":
             stress = np.array(d[target_stress])
             g.ndata[target_stress] = torch.tensor(

From 8c983f5a6993d1f6bb6dbe58e35a4d79f36a6aa9 Mon Sep 17 00:00:00 2001
From: knc6
Date: Mon, 18 Nov 2024 09:53:57 -0500
Subject: [PATCH 15/37] stress change.

---
 alignn/ff/ff.py                  | 19 ++++------
 alignn/models/alignn_atomwise.py | 64 ++++++++++++++++----------------
 2 files changed, 40 insertions(+), 43 deletions(-)

diff --git a/alignn/ff/ff.py b/alignn/ff/ff.py
index 060685a..f73ee22 100644
--- a/alignn/ff/ff.py
+++ b/alignn/ff/ff.py
@@ -307,7 +307,7 @@ def calculate(self, atoms, properties=None, system_changes=None):
                 (
                     g.to(self.device),
                     lg.to(self.device),
-                    torch.tensor(atoms.cell)
+                    torch.tensor(np.array(atoms.cell))
                     .type(torch.get_default_dtype())
                     .to(self.device),
                 )
@@ -321,19 +321,16 @@ def calculate(self, atoms, properties=None, system_changes=None):
             energy = result["out"].detach().cpu().numpy() * num_atoms
         else:
             energy = result["out"].detach().cpu().numpy()
-
+        stress = self.stress_wt * np.array(
+            full_3x3_to_voigt_6_stress(
+                result["stresses"][:3].reshape(3, 3).detach().cpu().numpy()
+            )
+        )
+        # print('stress',stress)
         self.results = {
             "energy": energy,  # * num_atoms,
             "forces": result["grad"].detach().cpu().numpy(),
-            "stress": full_3x3_to_voigt_6_stress(
-                # np.eye(3)
-                result["stresses"][:3]
-                .reshape(3, 3)
-                .detach()
-                .cpu()
-                .numpy()
-            )
-            / 160.21766208,
+            "stress": stress,
             "dipole": np.zeros(3),
             "charges": np.zeros(len(atoms)),
             "magmom": 0.0,

diff --git a/alignn/models/alignn_atomwise.py b/alignn/models/alignn_atomwise.py
index 4fae2c3..e055f4e 100644
--- a/alignn/models/alignn_atomwise.py
+++ b/alignn/models/alignn_atomwise.py
@@ -581,38 +581,38 @@ def forward(
                     # print("stress1", stress, stress.shape)
                     # print("g.batch_size", g.batch_size)
                 else:
-                    stresses = []
-                    count_edge = 0
-                    count_node = 0
-                    for graph_id in range(g.batch_size):
-                        num_edges = g.batch_num_edges()[graph_id]
-                        num_nodes = 0
-                        st = -1 * (
-                            160.21766208
-                            * torch.matmul(
-                                r[count_edge : count_edge + num_edges].T,
-                                pair_forces[
-                                    count_edge : count_edge + num_edges
-                                ],
-                            )
-                            / g.ndata["V"][count_node + num_nodes]
-                        )
-
-                        count_edge = count_edge + num_edges
-                        num_nodes = g.batch_num_nodes()[graph_id]
-                        count_node = count_node + num_nodes
-                        stresses.append(st)
-                    stress = self.config.stress_multiplier * torch.stack(
-                        stresses
-                    )
+                    # stresses = []
+                    # count_edge = 0
+                    # count_node = 0
+                    # for graph_id in range(g.batch_size):
+                    #     num_edges = g.batch_num_edges()[graph_id]
+                    #     num_nodes = 0
+                    #     st = -1 * (
+                    #         160.21766208
+                    #         * torch.matmul(
+                    #             r[count_edge : count_edge + num_edges].T,
+                    #             pair_forces[
+                    #                 count_edge : count_edge + num_edges
+                    #             ],
+                    #         )
+                    #         / g.ndata["V"][count_node + num_nodes]
+                    #     )

+                    #     count_edge = count_edge + num_edges
+                    #     num_nodes = g.batch_num_nodes()[graph_id]
+                    #     count_node = count_node + num_nodes
+                    #     stresses.append(st)
+                    # stress = self.config.stress_multiplier * torch.stack(
+                    #     stresses
+                    # )
+                    stress = (
+                        # 160.21766208
+                        # * 10
+                        -1
+                        * torch.einsum("ij, ik->jk", r, pair_forces)
+                        / 2
+                        # / (2 * g.ndata["V"][0])
+                    )  ## / ( g.ndata["V"][0])
             if self.link:
                 out = self.link(out)

From bac7d0be792972793a2a649079b53f352c3260a1 Mon Sep 17 00:00:00 2001
From: knc6
Date: Tue, 19 Nov 2024 10:28:49 -0500
Subject: [PATCH 16/37] Add new models.

---
 alignn/ff/all_models_ff.json     |   4 +-
 alignn/ff/ff.py                  |  31 +++----
 alignn/models/alignn_atomwise.py | 140 ++++++++++++++++---------------
 3 files changed, 88 insertions(+), 87 deletions(-)

diff --git a/alignn/ff/all_models_ff.json b/alignn/ff/all_models_ff.json
index 981f0da..085fab8 100644
--- a/alignn/ff/all_models_ff.json
+++ b/alignn/ff/all_models_ff.json
@@ -1,4 +1,6 @@
 {
+    "v10.30.2024_dft_3d_307k": "https://figshare.com/ndownloader/files/50634327",
+    "v10.30.2024_mp_168k": "https://figshare.com/ndownloader/files/50634318",
     "v8.29.2024_dft_3d": "https://figshare.com/ndownloader/files/48889834",
     "v8.29.2024_mpf": "https://figshare.com/ndownloader/files/48889837",
     "v5.27.2024": "https://figshare.com/ndownloader/files/47286127",
@@ -12,4 +14,4 @@
     "revised": "https://figshare.com/ndownloader/files/41583600",
     "scf_fd_top_10_en_42_fmax_600_wt01": "https://figshare.com/ndownloader/files/41967375",
     "scf_fd_top_10_en_42_fmax_600_wt10": "https://figshare.com/ndownloader/files/41967372"
-}
\ No newline at end of file
+}

diff --git a/alignn/ff/ff.py b/alignn/ff/ff.py
index f73ee22..1c83804 100644
--- a/alignn/ff/ff.py
+++ b/alignn/ff/ff.py
@@ -223,6 +223,7 @@ def __init__(
         device=None,
         model=None,
         config=None,
+        force_mult_batchsize=True,
         path=".",
         model_filename="best_model.pt",
         config_filename="config.json",
@@ -243,8 +244,7 @@ def __init__(
         self.config = config
         self.include_stress = include_stress
         self.stress_wt = stress_wt
-        # self.force_multiplier = force_multiplier
-        # self.force_mult_natoms = force_mult_natoms
+        self.force_mult_batchsize = force_mult_batchsize
         if self.config is None:
             config = loadjson(os.path.join(path, config_filename))
             # print('config',config)
@@ -253,7 +253,7 @@ def __init__(
         if self.include_stress:
             self.implemented_properties = ["energy", "forces", "stress"]
             if config["model"]["stresswise_weight"] == 0:
-                config["model"]["stresswise_weight"] = 0.1
+                config["model"]["stresswise_weight"] = 0.1  # self.stress_wt
         else:
             self.implemented_properties = ["energy", "forces"]
@@ -276,13 +276,9 @@ def __init__(
                     map_location=self.device,
                 )
             )
-        else:
-            model = self.model
-            model.to(device)
-            model.eval()
-
-            self.net = model
-        self.net.to(self.device)
+        model.to(device)
+        model.eval()
+        self.model = model

     def calculate(self, atoms, properties=None, system_changes=None):
         """Calculate properties."""
@@ -303,7 +299,7 @@ def calculate(self, atoms, properties=None, system_changes=None):

         if self.config["model"]["alignn_layers"] > 0:
             # g,lg = g
-            result = self.net(
+            result = self.model(
                 (
                     g.to(self.device),
                     lg.to(self.device),
@@ -313,7 +309,7 @@ def calculate(self, atoms, properties=None, system_changes=None):
                 )
             )
         else:
-            result = self.net(
+            result = self.model(
                 (g.to(self.device, torch.tensor(atoms.cell).to(self.device)))
             )
         # print ('stress',result["stress"].detach().numpy())
@@ -326,13 +322,12 @@ def calculate(self, atoms, properties=None, system_changes=None):
                 result["stresses"][:3].reshape(3, 3).detach().cpu().numpy()
             )
         )
-        # print('stress',stress)
+        forces = result["grad"].detach().cpu().numpy()
+        if self.force_mult_batchsize:
+            forces *= self.config["batch_size"]
+            # stress*=self.config['batch_size']
         self.results = {
             "energy": energy,  # * num_atoms,
-            "forces": result["grad"].detach().cpu().numpy(),
+            "forces": forces,
             "stress": stress,
-            "dipole": np.zeros(3),
-            "charges": np.zeros(len(atoms)),
-            "magmom": 0.0,
-            "magmoms": np.zeros(len(atoms)),
         }

diff --git a/alignn/models/alignn_atomwise.py b/alignn/models/alignn_atomwise.py
index e055f4e..5d28ab6 100644
--- a/alignn/models/alignn_atomwise.py
+++ b/alignn/models/alignn_atomwise.py
@@ -29,47 +29,45 @@ class ALIGNNAtomWiseConfig(BaseSettings):
     """Hyperparameter schema for jarvisdgl.models.alignn."""

     name: Literal["alignn_atomwise"]
-    alignn_layers: int = 2
-    gcn_layers: int = 2
-    atom_input_features: int = 1
-    # atom_input_features: int = 92
+    alignn_layers: int = 4
+    gcn_layers: int = 4
+    # atom_input_features: int = 1
+    atom_input_features: int = 92
     edge_input_features: int = 80
     triplet_input_features: int = 40
     embedding_features: int = 64
-    hidden_features: int = 64
-    # hidden_features: int = 256
-    # fc_layers: int = 1
-    # fc_features: int = 64
+    # hidden_features: int = 64
+    hidden_features: int = 256
     output_features: int = 1
     grad_multiplier: int = -1
     calculate_gradient: bool = True
     atomwise_output_features: int = 0
     graphwise_weight: float = 1.0
-    gradwise_weight: float = 1.0
+    gradwise_weight: float = 0.0
     stresswise_weight: float = 0.0
     atomwise_weight: float = 0.0
-    # if link == log, apply `exp` to final outputs
-    # to constrain predictions to be positive
-    link: Literal["identity", "log", "logit"] = "identity"
     zero_inflated: bool = False
     classification: bool = False
     force_mult_natoms: bool = False
-    energy_mult_natoms: bool = True
+    energy_mult_natoms: bool = False
+    # energy_mult_natoms: bool = True
     include_pos_deriv: bool = False
     use_cutoff_function: bool = False
-    inner_cutoff: float = 3  # Ansgtrom
+    inner_cutoff: float = 6  # Angstrom
     stress_multiplier: float = 1
     add_reverse_forces: bool = True  # will make True as default soon
     lg_on_fly: bool = True  # will make True as default soon
-    batch_stress: bool = True
     multiply_cutoff: bool = False
     use_penalty: bool = True
     extra_features: int = 0
-    exponent: int = 5
+    exponent: int = 3
     penalty_factor: float = 0.1
     penalty_threshold: float = 1
     additional_output_features: int = 0
     additional_output_weight: float = 0
+    stress_method: int = 3
+    link: Literal["identity", "log", "logit"] = "identity"
+    batch_stress: bool = True

     class Config:
         """Configure model settings behavior."""
@@ -382,7 +380,7 @@ def forward(
             features = self.extra_feature_embedding(features)
         g = g.local_var()
         result = {}
-
+        # print('g',g)
         # initial node features: atom feature network...
         x = g.ndata.pop("atom_features")
         # print('x1',x,x.shape)
@@ -559,60 +557,66 @@ def forward(
             forces = torch.squeeze(g.ndata["forces_ji"])

             if self.config.stresswise_weight != 0:
-                # Under development, use with caution
-                # 1 eV/Angstrom3 = 160.21766208 GPa
-                # 1 GPa = 10 kbar
-                # Virial stress formula, assuming inital velocity = 0
-                # Save volume as g.gdta['V']?
-                # print('pair_forces',pair_forces.shape)
-                # print('r',r.shape)
-                # print('g.ndata["V"]',g.ndata["V"].shape)
-                if not self.config.batch_stress:
-                    # print('Not batch_stress')
+                if self.config.stress_method == 1:
+                    g.ndata["cart_coords"] = compute_cartesian_coordinates(
+                        g, lat
+                    )
+                    r, bondlength = compute_pair_vector_and_distance(g)
+                    stress = -160.21766208 * (
+                        torch.matmul(r.T, pair_forces)
+                        # / (2 * g.edata["V"])
+                        / (2 * g.ndata["V"][0])
+                    )
+                if self.config.stress_method == 2:
+                    cart_coords = compute_cartesian_coordinates(
+                        g, lat
+                    ).view(g.batch_size, -1, 3)
+                    forces_batched = forces.view(g.batch_size, -1, 3)
+                    vols = torch.abs(torch.det(lat))
+                    if vols.ndim == 0:
+                        vols = vols.unsqueeze(0)
+                    stresses = []
+                    for graph_id in range(g.batch_size):
+                        st = (
+                            -160.21766208
+                            * torch.matmul(
+                                cart_coords[graph_id].T,
+                                forces_batched[graph_id],
+                            )
+                            / (vols[graph_id])
+                        )
+                        stresses.append(st)
+                        # print(st)
+                    stress = torch.stack(stresses)
+                if self.config.stress_method == 3:
                     stress = (
-                        -1
-                        * 160.21766208
-                        * (
-                            torch.matmul(r.T, pair_forces)
-                            # / (2 * g.edata["V"])
-                            / (2 * g.ndata["V"][0])
+                        -1 * torch.einsum("ij, ik->jk", r, pair_forces) / 2
+                    )
+                if self.config.stress_method == 4:
+                    stresses = []
+                    count_edge = 0
+                    count_node = 0
+                    for graph_id in range(g.batch_size):
+                        num_edges = g.batch_num_edges()[graph_id]
+                        num_nodes = 0
+                        st = -1 * (
+                            160.21766208
+                            * torch.matmul(
+                                r[count_edge : count_edge + num_edges].T,
+                                pair_forces[
+                                    count_edge : count_edge + num_edges
+                                ],
+                            )
+                            / g.ndata["V"][count_node + num_nodes]
                         )
+
+                        count_edge = count_edge + num_edges
+                        num_nodes = g.batch_num_nodes()[graph_id]
+                        count_node = count_node + num_nodes
+                        stresses.append(st)
+                    stress = self.config.stress_multiplier * torch.stack(
+                        stresses
                     )
-                    # print("stress1", stress, stress.shape)
-                    # print("g.batch_size", g.batch_size)
-                else:
-                    # stresses = []
-                    # count_edge = 0
-                    # count_node = 0
-                    # for graph_id in range(g.batch_size):
-                    #     num_edges = g.batch_num_edges()[graph_id]
-                    #     num_nodes = 0
-                    #     st = -1 * (
-                    #         160.21766208
-                    #         * torch.matmul(
-                    #             r[count_edge : count_edge + num_edges].T,
-                    #             pair_forces[
-                    #                 count_edge : count_edge + num_edges
-                    #             ],
-                    #         )
-                    #         / g.ndata["V"][count_node + num_nodes]
-                    #     )
-
-                    #     count_edge = count_edge + num_edges
-                    #     num_nodes = g.batch_num_nodes()[graph_id]
-                    #     count_node = count_node + num_nodes
-                    #     stresses.append(st)
-                    # stress = self.config.stress_multiplier * torch.stack(
-                    #     stresses
-                    # )
-                    stress = (
-                        # 160.21766208
-                        # * 10
-                        -1
-                        * torch.einsum("ij, ik->jk", r, pair_forces)
-                        / 2
-                        # / (2 * g.ndata["V"][0])
-                    )  ## / ( g.ndata["V"][0])
             if self.link:
                 out = self.link(out)

From d837b3f257a3e5c4e14e3df1dbc72529ed9e7977 Mon Sep 17 00:00:00 2001
From: Kamal Choudhary
Date: Tue, 19 Nov 2024 12:06:32 -0500
Subject: [PATCH 17/37] Update README.md

---
 README.md | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 08d994d..c07ad98 100644
--- a/README.md
+++ b/README.md
@@ -410,7 +410,8 @@ References
 6) [Graph neural network predictions of metal organic framework CO2 adsorption properties](https://www.sciencedirect.com/science/article/pii/S092702562200163X)
 7) [Rapid Prediction of Phonon Structure and Properties using an Atomistic Line Graph Neural Network (ALIGNN)](https://journals.aps.org/prmaterials/abstract/10.1103/PhysRevMaterials.7.023803)
 8) [Unified graph neural network force-field for the periodic table](https://pubs.rsc.org/en/content/articlehtml/2023/dd/d2dd00096b)
-9) [Large Scale Benchmark of Materials Design Methods](https://arxiv.org/abs/2306.11688)
+9) [Large Scale Benchmark of Materials Design Methods](https://www.nature.com/articles/s41524-024-01259-w)
+10) [CHIPS-FF](https://github.com/usnistgov/chipsff)

 Please see detailed publications list [here](https://jarvis-tools.readthedocs.io/en/master/publications.html).

From ceb98134464e43982602ea7cc9557ff91332b6de Mon Sep 17 00:00:00 2001
From: knc6
Date: Tue, 19 Nov 2024 21:06:32 -0500
Subject: [PATCH 18/37] Add stress method.

---
 .../examples/sample_data/config_example.json  | 84 ++++++++++++++++++-
 .../config_example_atomwise.json              | 59 ++++++++++++-
 .../sample_data_ff_additional/config.json     |  6 +-
 .../config_example_atomwise.json              | 21 +++--
 alignn/models/alignn_atomwise.py              |  2 +-
 5 files changed, 155 insertions(+), 17 deletions(-)

diff --git a/alignn/examples/sample_data/config_example.json b/alignn/examples/sample_data/config_example.json
index 7a462ac..073819a 100644
--- a/alignn/examples/sample_data/config_example.json
+++ b/alignn/examples/sample_data/config_example.json
@@ -1 +1,83 @@
"float32", + "random_seed": 123, + "classification_threshold": null, + "n_val": null, + "n_test": null, + "n_train": null, + "train_ratio": 0.8, + "val_ratio": 0.1, + "test_ratio": 0.1, + "target_multiplication_factor": null, + "epochs": 3, + "batch_size": 2, + "weight_decay": 1e-05, + "learning_rate": 0.001, + "filename": "A", + "warmup_steps": 2000, + "criterion": "mse", + "optimizer": "adamw", + "scheduler": "onecycle", + "pin_memory": false, + "save_dataloader": false, + "write_checkpoint": true, + "write_predictions": true, + "store_outputs": true, + "progress": true, + "log_tensorboard": false, + "standard_scalar_and_pca": false, + "use_canonize": true, + "num_workers": 0, + "cutoff": 8.0, + "cutoff_extra": 3.0, + "max_neighbors": 12, + "keep_data_order": true, + "normalize_graph_level_loss": false, + "distributed": false, + "data_parallel": false, + "n_early_stopping": null, + "output_dir": "temp", + "use_lmdb": true, + "model": { + "name": "alignn_atomwise", + "alignn_layers": 4, + "gcn_layers": 4, + "atom_input_features": 92, + "edge_input_features": 80, + "triplet_input_features": 40, + "embedding_features": 64, + "hidden_features": 256, + "output_features": 1, + "grad_multiplier": -1, + "calculate_gradient": false, + "atomwise_output_features": 0, + "graphwise_weight": 1.0, + "gradwise_weight": 1.0, + "stresswise_weight": 0.0, + "atomwise_weight": 0.0, + "link": "identity", + "zero_inflated": false, + "classification": false, + "force_mult_natoms": false, + "energy_mult_natoms": false, + "include_pos_deriv": false, + "use_cutoff_function": false, + "inner_cutoff": 3.0, + "stress_multiplier": 1.0, + "add_reverse_forces": true, + "lg_on_fly": true, + "batch_stress": true, + "multiply_cutoff": false, + "use_penalty": true, + "extra_features": 0, + "exponent": 5, + "penalty_factor": 0.1, + "penalty_threshold": 1.0 + } +} \ No newline at end of file diff --git a/alignn/examples/sample_data_ff/config_example_atomwise.json b/alignn/examples/sample_data_ff/config_example_atomwise.json index f5981a4..fe0caec 100644 --- a/alignn/examples/sample_data_ff/config_example_atomwise.json +++ b/alignn/examples/sample_data_ff/config_example_atomwise.json @@ -1 +1,58 @@ -{"version": "112bbedebdaecf59fb18e11c929080fb2f358246", "dataset": "user_data", "target": "target", "atom_features": "cgcnn", "neighbor_strategy": "radius_graph", "id_tag": "jid", "dtype": "float32", "random_seed": 123, "classification_threshold": null, "n_val": null, "n_test": null, "n_train": null, "train_ratio": 0.8, "val_ratio": 0.1, "test_ratio": 0.1, "target_multiplication_factor": null, "epochs": 3, "batch_size": 2, "weight_decay": 1e-05, "learning_rate": 0.001, "filename": "B", "warmup_steps": 2000, "criterion": "l1", "optimizer": "adamw", "scheduler": "onecycle", "pin_memory": false, "save_dataloader": false, "write_checkpoint": true, "write_predictions": true, "store_outputs": false, "progress": true, "log_tensorboard": false, "standard_scalar_and_pca": false, "use_canonize": true, "num_workers": 0, "cutoff": 4.0, "max_neighbors": 12, "keep_data_order": true, "distributed": false, "use_lmdb": true, "model": {"name": "alignn_atomwise", "atom_input_features": 92, "calculate_gradient": true, "atomwise_output_features": 0, "alignn_layers": 1, "gcn_layers": 1, "hidden_features": 64, "output_features": 1, "graphwise_weight": 0.85, "gradwise_weight": 0.05, "atomwise_weight": 0.0, "use_cutoff_function": false, "stresswise_weight": 0.05, "add_reverse_forces": true}} \ No newline at end of file +{ + "version": 
"112bbedebdaecf59fb18e11c929080fb2f358246", + "dataset": "user_data", + "target": "target", + "atom_features": "cgcnn", + "neighbor_strategy": "radius_graph", + "id_tag": "jid", + "dtype": "float32", + "random_seed": 123, + "classification_threshold": null, + "n_val": null, + "n_test": null, + "n_train": null, + "train_ratio": 0.8, + "val_ratio": 0.1, + "test_ratio": 0.1, + "target_multiplication_factor": null, + "epochs": 3, + "batch_size": 2, + "weight_decay": 1e-05, + "learning_rate": 0.001, + "filename": "B", + "warmup_steps": 2000, + "criterion": "l1", + "optimizer": "adamw", + "scheduler": "onecycle", + "pin_memory": false, + "save_dataloader": false, + "write_checkpoint": true, + "write_predictions": true, + "store_outputs": false, + "progress": true, + "log_tensorboard": false, + "standard_scalar_and_pca": false, + "use_canonize": true, + "num_workers": 0, + "cutoff": 4.0, + "max_neighbors": 12, + "keep_data_order": true, + "distributed": false, + "use_lmdb": true, + "model": { + "name": "alignn_atomwise", + "atom_input_features": 92, + "calculate_gradient": true, + "atomwise_output_features": 0, + "alignn_layers": 1, + "gcn_layers": 1, + "hidden_features": 64, + "output_features": 1, + "graphwise_weight": 0.85, + "gradwise_weight": 0.05, + "atomwise_weight": 0.0, + "use_cutoff_function": false, + "stresswise_weight": 0.05, + "add_reverse_forces": true + } +} \ No newline at end of file diff --git a/alignn/examples/sample_data_ff_additional/config.json b/alignn/examples/sample_data_ff_additional/config.json index 32bbd93..69944ed 100644 --- a/alignn/examples/sample_data_ff_additional/config.json +++ b/alignn/examples/sample_data_ff_additional/config.json @@ -77,9 +77,9 @@ "use_penalty": true, "extra_features": 0, "exponent": 5, - "additional_output_features":400, - "additional_output_weight":0.1, + "additional_output_features": 400, + "additional_output_weight": 0.1, "penalty_factor": 0.1, "penalty_threshold": 1.0 } -} +} \ No newline at end of file diff --git a/alignn/examples/sample_data_ff_feats/config_example_atomwise.json b/alignn/examples/sample_data_ff_feats/config_example_atomwise.json index 19d13ae..68892d7 100644 --- a/alignn/examples/sample_data_ff_feats/config_example_atomwise.json +++ b/alignn/examples/sample_data_ff_feats/config_example_atomwise.json @@ -36,20 +36,19 @@ "cutoff": 8.0, "max_neighbors": 12, "keep_data_order": true, - "distributed":false, + "distributed": false, "model": { "name": "alignn_atomwise", "atom_input_features": 92, - "calculate_gradient":true, - "atomwise_output_features":0, - "alignn_layers":4, - "gcn_layers":4, + "calculate_gradient": true, + "atomwise_output_features": 0, + "alignn_layers": 4, + "gcn_layers": 4, "output_features": 1, "extra_features": 6, - "graphwise_weight":0.0, - "gradwise_weight":0.0, - "atomwise_weight":0.0, - "stresswise_weight":0.0 - + "graphwise_weight": 0.0, + "gradwise_weight": 0.0, + "atomwise_weight": 0.0, + "stresswise_weight": 0.0 } -} +} \ No newline at end of file diff --git a/alignn/models/alignn_atomwise.py b/alignn/models/alignn_atomwise.py index 5d28ab6..8af8a0c 100644 --- a/alignn/models/alignn_atomwise.py +++ b/alignn/models/alignn_atomwise.py @@ -65,7 +65,7 @@ class ALIGNNAtomWiseConfig(BaseSettings): penalty_threshold: float = 1 additional_output_features: int = 0 additional_output_weight: float = 0 - stress_method: int = 3 + stress_method: int = 4 link: Literal["identity", "log", "logit"] = "identity" batch_stress: bool = True From 4c0dfd5ce7017df91e274afc163f64339e11e232 Mon Sep 17 00:00:00 2001 
From: knc6
Date: Tue, 19 Nov 2024 21:08:26 -0500
Subject: [PATCH 19/37] GH Yaml.

---
 .github/workflows/main.yml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 6742769..3014416 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -9,8 +9,8 @@ jobs:
       matrix:
         os: ["ubuntu-latest"]
     steps:
-    - uses: actions/checkout@v2
-    - uses: conda-incubator/setup-miniconda@v2
+    - uses: actions/checkout@v3
+    - uses: conda-incubator/setup-miniconda@v3
      with:
        activate-environment: test
        environment-file: environment.yml

From 61d45e60c12ec371a2b976f69680a82a33375850 Mon Sep 17 00:00:00 2001
From: knc6
Date: Tue, 19 Nov 2024 23:47:59 -0500
Subject: [PATCH 20/37] Slow test fix

---
 alignn/ff/ff.py                  | 36 ++++++---------
 alignn/models/alignn_atomwise.py |  2 +-
 alignn/tests/test_alignn_ff.py   | 76 ++++++++++++++++++--------------
 3 files changed, 57 insertions(+), 57 deletions(-)

diff --git a/alignn/ff/ff.py b/alignn/ff/ff.py
index 1c83804..3023a82 100644
--- a/alignn/ff/ff.py
+++ b/alignn/ff/ff.py
@@ -8,7 +8,6 @@
 from ase.md.nptberendsen import NPTBerendsen
 from ase.io import Trajectory
 import matplotlib.pyplot as plt
-from jarvis.analysis.thermodynamics.energetics import unary_energy
 from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
 from ase.optimize import BFGS
 from ase.optimize.bfgslinesearch import BFGSLineSearch
@@ -30,15 +29,11 @@
 from jarvis.db.jsonutils import loadjson
 from alignn.graphs import Graph
 from alignn.models.alignn_atomwise import ALIGNNAtomWise, ALIGNNAtomWiseConfig
-from jarvis.analysis.defects.vacancy import Vacancy
 import numpy as np
-from alignn.pretrained import get_prediction
 from jarvis.analysis.structure.spacegroup import (
     Spacegroup3D,
-    symmetrically_distinct_miller_indices,
+    # symmetrically_distinct_miller_indices,
 )
-from jarvis.analysis.interface.zur import make_interface
-from jarvis.analysis.defects.surface import Surface
 from jarvis.core.kpoints import Kpoints3D as Kpoints
 import zipfile
 from ase import Atoms as AseAtoms
@@ -51,11 +46,15 @@
 from tqdm import tqdm
 import torch

-try:
-    from gpaw import GPAW, PW
-except Exception:
-    pass
-# plt.switch_backend("agg")
+# from jarvis.analysis.thermodynamics.energetics import unary_energy
+# from jarvis.analysis.defects.vacancy import Vacancy
+# from jarvis.analysis.defects.surface import Surface
+# from alignn.pretrained import get_prediction
+# from jarvis.analysis.interface.zur import make_interface
+# try:
+#     from gpaw import GPAW, PW
+# except Exception:
+#     pass

 # Reference: https://doi.org/10.1039/D2DD00096B
@@ -229,7 +228,7 @@ def __init__(
         model_filename="best_model.pt",
         config_filename="config.json",
         output_dir=None,
         batch_stress=True,
-        stress_wt=0.1,
+        stress_wt=0.03,
         **kwargs,
     ):
         """Initialize class."""
@@ -871,6 +870,7 @@ def ev_curve(
     return x, y, eos, kv


+"""
 def vacancy_formation(
     atoms=None,
     jid="",
@@ -883,7 +883,6 @@ def vacancy_formation(
     using_wyckoffs=True,
     on_relaxed_struct=True,
 ):
-    """Get vacancy energy."""
     if atoms is None:
         from jarvis.db.figshare import data
@@ -975,7 +974,6 @@ def surface_energy(
     thickness=25,
     model_filename="best_model.pt",
 ):
-    """Get surface energy."""
     if atoms is None:
         from jarvis.db.figshare import data
@@ -1065,7 +1063,6 @@ def get_interface_energy(
     from_conventional_structure=True,
     gpaw_verify=False,
 ):
-    """Get work of adhesion."""
     film_surf = Surface(
         film_atoms,
         indices=film_index,
@@ -1091,14 +1088,6 @@ def get_interface_energy(
         atol=atol,
         apply_strain=apply_strain,
     )
-    """
-    print('film')
-    print(het['film_sl'])
-    print('subs')
-    print(het['subs_sl'])
-    print('intf')
-    print(het['interface'])
-    """
     a = get_prediction(
         atoms=het["film_sl"], model_name="jv_optb88vdw_total_energy_alignn"
     )[0]
@@ -1180,6 +1169,7 @@ def get_interface_energy(
     info["film_sl"] = het["film_sl"].to_dict()
     info["subs_sl"] = het["subs_sl"].to_dict()
     return info
+"""

 def phonons(
diff --git a/alignn/models/alignn_atomwise.py b/alignn/models/alignn_atomwise.py
index 8af8a0c..a256719 100644
--- a/alignn/models/alignn_atomwise.py
+++ b/alignn/models/alignn_atomwise.py
@@ -65,7 +65,7 @@ class ALIGNNAtomWiseConfig(BaseSettings):
     penalty_threshold: float = 1
     additional_output_features: int = 0
     additional_output_weight: float = 0
-    stress_method: int = 4
+    stress_method: int = 1
     link: Literal["identity", "log", "logit"] = "identity"
     batch_stress: bool = True

diff --git a/alignn/tests/test_alignn_ff.py b/alignn/tests/test_alignn_ff.py
index e439b06..d961939 100644
--- a/alignn/tests/test_alignn_ff.py
+++ b/alignn/tests/test_alignn_ff.py
@@ -3,10 +3,8 @@
 from alignn.ff.ff import (
     default_path,
     ev_curve,
-    surface_energy,
-    vacancy_formation,
     ForceField,
-    get_interface_energy,
+    get_figshare_model_ff
 )
 from alignn.graphs import Graph, radius_graph_jarvis
 from alignn.ff.ff import phonons
@@ -42,25 +40,42 @@
 3.7801334214124656 6.383208181964002 7.430329467399411
 """
-
+pos = """System
+1.0
+5.49363 0.0 0.0
+-0.0 5.49363 0.0
+0.0 0.0 5.49363
+Si
+8
+direct
+0.25 0.75 0.25 Si
+0.0 0.0 0.5 Si
+0.25 0.25 0.75 Si
+0.0 0.5 0.0 Si
+0.75 0.75 0.75 Si
+0.5 0.0 0.0 Si
+0.75 0.25 0.25 Si
+0.5 0.5 0.5 Si
+"""
 # def test_radius_graph_jarvis():
 #     atoms = Poscar.from_string(pos).atoms
 #     g, lg = radius_graph_jarvis(atoms=atoms)
-def test_alignnff():
-    atoms = JAtoms.from_dict(get_jid_data()["atoms"])
-    atoms = JAtoms.from_dict(
-        get_jid_data(dataset="dft_3d", jid="JVASP-1002")["atoms"]
-        # get_jid_data(dataset="dft_3d", jid="JVASP-32")["atoms"]
-    )
+def test_graph_builder():
+
+    atoms = Poscar.from_string(pos).atoms
     old_g = Graph.from_atoms(atoms=atoms)
     g, lg = Graph.atom_dgl_multigraph(atoms)
     g, lg = Graph.atom_dgl_multigraph(atoms, neighbor_strategy="radius_graph")
     g, lg = Graph.atom_dgl_multigraph(
         atoms, neighbor_strategy="radius_graph_jarvis"
     )
-    model_path = default_path()
+
+
+def test_ev():
+    atoms = Poscar.from_string(pos).atoms
+    model_path = get_figshare_model_ff(model_name="v10.30.2024_dft_3d_307k") #default_path()
     print("model_path", model_path)
     print("atoms", atoms)
     # atoms = atoms.make_supercell_matrix([2, 2, 2])
     # atoms=atoms.strain_atoms(.05)
     ev = ev_curve(atoms=atoms, model_path=model_path)
     # surf = surface_energy(atoms=atoms, model_path=model_path)
     # print('surf',surf)
     # vac = vacancy_formation(atoms=atoms, model_path=model_path)
     # print('vac',vac)
+
+
+def test_ev():
+    atoms = Poscar.from_string(pos).atoms
+    #model_path = default_path()
+    model_path = get_figshare_model_ff(model_name="v10.30.2024_dft_3d_307k") #default_path()
     ff = ForceField(
         jarvis_atoms=atoms,
         model_path=model_path,
     )
     ff.set_momentum_maxwell_boltzmann(temperature_K=300)
     xx = ff.optimize_atoms(optimizer="FIRE")
     xx = ff.run_nve_velocity_verlet(steps=5)
     xx = ff.run_nvt_langevin(steps=5)
     xx = ff.run_nvt_andersen(steps=5)
     # xx = ff.run_npt_nose_hoover(steps=5)
-    atoms_al = JAtoms.from_dict(
-        get_jid_data(dataset="dft_3d", jid="JVASP-1002")["atoms"]
-        # get_jid_data(dataset="dft_3d", jid="JVASP-816")["atoms"]
-    )
-    atoms_al2o3 = JAtoms.from_dict(
-        get_jid_data(dataset="dft_3d", jid="JVASP-1002")["atoms"]
-        # get_jid_data(dataset="dft_3d", jid="JVASP-32")["atoms"]
-    )
-    intf = get_interface_energy(
-        film_atoms=atoms_al,
-        subs_atoms=atoms_al,
-        model_path=model_path,
-        film_thickness=5,
-        subs_thickness=5,
-        # film_atoms=atoms_al, subs_atoms=atoms_al2o3, model_path=model_path
-    )

-def test_phonons():
-    atoms = Atoms.from_dict(
-        get_jid_data(jid="JVASP-816", dataset="dft_3d")["atoms"]
-    )
-    ph_path = fd_path()
-    ph = phonons(model_path=ph_path, atoms=(atoms))
+def test_phonons():
+    atoms = Poscar.from_string(pos).atoms.get_primitive_atoms
+    #model_path = default_path()
+    model_path = get_figshare_model_ff(model_name="v10.30.2024_dft_3d_307k") #default_path()
+    ph = phonons(model_path=model_path, atoms=(atoms))
+
+#print('test_graph_builder')
+#test_graph_builder()
+#print('test_ev')
+#test_ev()
+#print('test_phonons')
+test_phonons()
 # test_alignnff()

From 2c3cb5b7c7cd72b08cba8001a905418ed0e1bc7f Mon Sep 17 00:00:00 2001
From: knc6
Date: Wed, 20 Nov 2024 00:15:02 -0500
Subject: [PATCH 21/37] Coverage update.

---
 alignn/graphs.py               |  2 +-
 alignn/models/alignn.py        |  2 +-
 alignn/pretrained.py           |  7 ++++---
 alignn/tests/test_alignn_ff.py | 36 ++++++++++++++++++++--------------
 alignn/tests/test_prop.py      | 24 +++++++++++------------
 5 files changed, 39 insertions(+), 32 deletions(-)

diff --git a/alignn/graphs.py b/alignn/graphs.py
index 37d936e..0bb35d1 100644
--- a/alignn/graphs.py
+++ b/alignn/graphs.py
@@ -1074,7 +1074,7 @@ def collate_line_graph(
         return (
             batched_graph,
             batched_line_graph,
-            torch.tensor(lattices),
+            torch.stack(lattices),
             torch.tensor(labels),
         )

diff --git a/alignn/models/alignn.py b/alignn/models/alignn.py
index fe00d3d..69c14ad 100644
--- a/alignn/models/alignn.py
+++ b/alignn/models/alignn.py
@@ -291,7 +291,7 @@ def forward(
         if len(self.alignn_layers) > 0:
             # print('features2',features.shape)
-            g, lg = g
+            g, lg, lat = g
             lg = lg.local_var()

             # angle features (fixed)
diff --git a/alignn/pretrained.py b/alignn/pretrained.py
index 6e49e5b..9d4c9ee 100644
--- a/alignn/pretrained.py
+++ b/alignn/pretrained.py
@@ -311,8 +311,9 @@ def get_prediction(
         cutoff=float(cutoff),
         max_neighbors=max_neighbors,
     )
+    lat = torch.tensor(atoms.lattice_mat)
     out_data = (
-        model([g.to(device), lg.to(device)])
+        model([g.to(device), lg.to(device), lat.to(device)])
         .detach()
         .cpu()
         .numpy()
@@ -411,8 +412,8 @@ def atoms_to_graph(atoms):
     with torch.no_grad():
         ids = test_loader.dataset.ids
         for dat, id in zip(test_loader, ids):
-            g, lg, target = dat
-            out_data = model([g.to(device), lg.to(device)])
+            g, lg, lat, target = dat
+            out_data = model([g.to(device), lg.to(device), lat.to(device)])
             out_data = out_data.cpu().numpy().tolist()
             target = target.cpu().numpy().flatten().tolist()
             info = {}
diff --git a/alignn/tests/test_alignn_ff.py b/alignn/tests/test_alignn_ff.py
index d961939..ea39fb9 100644
--- a/alignn/tests/test_alignn_ff.py
+++ b/alignn/tests/test_alignn_ff.py
@@ -4,9 +4,9 @@
     default_path,
     ev_curve,
     ForceField,
-    get_figshare_model_ff
+    get_figshare_model_ff,
 )
-from alignn.graphs import Graph, radius_graph_jarvis
+from alignn.graphs import Graph, radius_graph_jarvis, radius_graph_old
 from alignn.ff.ff import phonons
 from jarvis.core.atoms import ase_to_atoms
 from jarvis.db.figshare import get_jid_data
 from jarvis.core.atoms import Atoms
@@ -71,11 +71,14 @@ def test_graph_builder():
     g, lg = Graph.atom_dgl_multigraph(
         atoms, neighbor_strategy="radius_graph_jarvis"
     )
+    g = radius_graph_old(atoms)


 def test_ev():
     atoms = Poscar.from_string(pos).atoms
-    model_path = get_figshare_model_ff(model_name="v10.30.2024_dft_3d_307k") #default_path()
+    model_path = get_figshare_model_ff(
+        model_name="v10.30.2024_dft_3d_307k"
+    )  # default_path()
     print("model_path", model_path)
     print("atoms", atoms)
     # atoms = atoms.make_supercell_matrix([2, 2, 2])
@@ -89,8 +92,10 @@ def test_ev():
     atoms = Poscar.from_string(pos).atoms
-    #model_path = default_path()
-    model_path = get_figshare_model_ff(model_name="v10.30.2024_dft_3d_307k") #default_path()
+    # model_path = default_path()
+    model_path = get_figshare_model_ff(
+        model_name="v10.30.2024_dft_3d_307k"
+    )  # default_path()
     ff = ForceField(
         jarvis_atoms=atoms,
         model_path=model_path,
@@ -106,18 +111,19 @@ def test_ev():

 def test_phonons():
     atoms = Poscar.from_string(pos).atoms.get_primitive_atoms
-    #model_path = default_path()
-    model_path = get_figshare_model_ff(model_name="v10.30.2024_dft_3d_307k") #default_path()
+    # model_path = default_path()
+    model_path = get_figshare_model_ff(
+        model_name="v10.30.2024_dft_3d_307k"
+    )  # default_path()
     ph = phonons(model_path=model_path, atoms=(atoms))

-#print('test_graph_builder')
-#test_graph_builder()
-#print('test_ev')
-#test_ev()
-#print('test_phonons')
-test_phonons()
+# print('test_graph_builder')
+test_graph_builder()
+# print('test_ev')
+# test_ev()
+# print('test_phonons')
+# test_phonons()
 # test_alignnff()
diff --git a/alignn/tests/test_prop.py b/alignn/tests/test_prop.py
index 4b78fba..51bd379 100644
--- a/alignn/tests/test_prop.py
+++ b/alignn/tests/test_prop.py
@@ -69,18 +69,18 @@ def test_models():
     test_clean()


-# def test_pretrained():
-#     box = [[2.715, 2.715, 0], [0, 2.715, 2.715], [2.715, 0, 2.715]]
-#     coords = [[0, 0, 0], [0.25, 0.2, 0.25]]
-#     elements = ["Si", "Si"]
-#     Si = Atoms(lattice_mat=box, coords=coords, elements=elements)
-#     prd = get_prediction(atoms=Si)
-#     print(prd)
-#     cmd1 = "python alignn/pretrained.py"
-#     os.system(cmd1)
-#     get_multiple_predictions(atoms_array=[Si, Si])
-#     cmd1 = "rm *.json"
-#     os.system(cmd1)
+def test_pretrained():
+    box = [[2.715, 2.715, 0], [0, 2.715, 2.715], [2.715, 0, 2.715]]
+    coords = [[0, 0, 0], [0.25, 0.2, 0.25]]
+    elements = ["Si", "Si"]
+    Si = Atoms(lattice_mat=box, coords=coords, elements=elements)
+    prd = get_prediction(atoms=Si)
+    print(prd)
+    cmd1 = "python alignn/pretrained.py"
+    os.system(cmd1)
+    get_multiple_predictions(atoms_array=[Si, Si])
+    cmd1 = "rm *.json"
+    os.system(cmd1)


 def test_alignn_train_regression():

From bc50cada394dcd714d7365b5fa34ea37cacbb24f Mon Sep 17 00:00:00 2001
From: knc6
Date: Wed, 20 Nov 2024 00:34:36 -0500
Subject: [PATCH 22/37] Coverage update.

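This keeps ev_curve, ForceField, phonons, and ase_phonon as the public
entry points exercised by the tests. A minimal usage sketch of those entry
points, mirroring the updated tests (the figshare model name and the pos
POSCAR string are the ones defined in the tests; treat them as assumptions
if reusing elsewhere):

    from alignn.ff.ff import ForceField, ev_curve, get_figshare_model_ff
    from jarvis.io.vasp.inputs import Poscar

    atoms = Poscar.from_string(pos).atoms
    model_path = get_figshare_model_ff(model_name="v10.30.2024_dft_3d_307k")
    # ev_curve fits an equation of state and returns volumes, energies,
    # the EOS object, and the bulk modulus (see "return x, y, eos, kv")
    x, y, eos, kv = ev_curve(atoms=atoms, model_path=model_path)
    ff = ForceField(jarvis_atoms=atoms, model_path=model_path)
    ff.optimize_atoms(optimizer="FIRE")
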
---
 alignn/ff/ff.py                | 90 ----------------------------------
 alignn/tests/test_alignn_ff.py | 10 ++--
 2 files changed, 5 insertions(+), 95 deletions(-)

diff --git a/alignn/ff/ff.py b/alignn/ff/ff.py
index 3023a82..79d67a2 100644
--- a/alignn/ff/ff.py
+++ b/alignn/ff/ff.py
@@ -1475,93 +1475,3 @@ def ase_phonon(
     fig.savefig(filename)
     plt.close()
     return bs
-
-
-if __name__ == "__main__":
-    """
-    atoms = JarvisAtoms.from_dict(
-        # get_jid_data(jid="JVASP-867", dataset="dft_3d")["atoms"]
-        # get_jid_data(jid="JVASP-1002", dataset="dft_3d")["atoms"]
-        get_jid_data(jid="JVASP-816", dataset="dft_3d")["atoms"]
-    )
-    mlearn = "/wrk/knc6/ALINN_FC/FD_mult/temp_new"  # mlearn_path()
-    phonons(atoms=atoms, model_path=mlearn, enforce_c_size=3)
-    """
-    ff = get_figshare_model_ff()
-    print("ff", ff)
-    # phonons3(atoms=atoms, model_path=mlearn, enforce_c_size=3)
-    # ase_phonon(atoms=atoms, model_path=mlearn)
-
-"""
-if __name__ == "__main__":
-
-    from jarvis.db.figshare import get_jid_data
-    from jarvis.core.atoms import Atoms
-
-    # atoms = Spacegroup3D(
-    #     JarvisAtoms.from_dict(
-    #         get_jid_data(jid="JVASP-816", dataset="dft_3d")["atoms"]
-    #     )
-    # ).conventional_standard_structure
-    # atoms = JarvisAtoms.from_poscar("POSCAR")
-    # atoms = atoms.make_supercell_matrix([2, 2, 2])
-    # print(atoms)
-    model_path = default_path()
-    print("model_path", model_path)
-    # atoms=atoms.strain_atoms(.05)
-    # print(atoms)
-    # ev = ev_curve(atoms=atoms, model_path=model_path)
-    # surf = surface_energy(atoms=atoms, model_path=model_path)
-    # print(surf)
-    # vac = vacancy_formation(atoms=atoms, model_path=model_path)
-    # print(vac)
-
-    # ff = ForceField(
-    #     jarvis_atoms=atoms,
-    #     model_path=model_path,
-    # )
-    # en,fs = ff.unrelaxed_atoms()
-    # print ('en',en)
-    # print('fs',fs)
-    # phonons(atoms=atoms)
-    # phonons3(atoms=atoms)
-    # ff.set_momentum_maxwell_boltzmann(temperature_K=300)
-    # xx = ff.optimize_atoms(optimizer="FIRE")
-    # print("optimized st", xx)
-    # xx = ff.run_nve_velocity_verlet(steps=5)
-    # xx = ff.run_nvt_langevin(steps=5)
-    # xx = ff.run_nvt_andersen(steps=5)
-    # xx = ff.run_npt_nose_hoover(steps=20000, temperature_K=1800)
-    # print(xx)
-    atoms_al = Atoms.from_dict(
-        get_jid_data(dataset="dft_3d", jid="JVASP-816")["atoms"]
-    )
-    surf = surface_energy(atoms=atoms_al, model_path=model_path)
-    # atoms_al2o3 = Atoms.from_dict(
-    #     get_jid_data(dataset="dft_3d", jid="JVASP-32")["atoms"]
-    # )
-    # atoms_sio2 = Atoms.from_dict(
-    #     get_jid_data(dataset="dft_3d", jid="JVASP-58349")["atoms"]
-    # )
-    # atoms_cu = Atoms.from_dict(
-    #     get_jid_data(dataset="dft_3d", jid="JVASP-867")["atoms"]
-    # )
-    # atoms_cu2o = Atoms.from_dict(
-    #     get_jid_data(dataset="dft_3d", jid="JVASP-1216")["atoms"]
-    # )
-    # atoms_graph = Atoms.from_dict(
-    #     get_jid_data(dataset="dft_3d", jid="JVASP-48")["atoms"]
-    # )
-    # intf = get_interface_energy(
-    #     film_atoms=atoms_cu,
-    #     subs_atoms=atoms_cu2o,
-    #     film_thickness=25,
-    #     subs_thickness=25,
-    #     model_path=model_path,
-    #     seperation=4.5,
-    #     subs_index=[1, 1, 1],
-    #     film_index=[1, 1, 1],
-    # )
-    # print(intf)
-    print(surf)
-"""
diff --git a/alignn/tests/test_alignn_ff.py b/alignn/tests/test_alignn_ff.py
index ea39fb9..0707f2d 100644
--- a/alignn/tests/test_alignn_ff.py
+++ b/alignn/tests/test_alignn_ff.py
@@ -7,7 +7,7 @@
     get_figshare_model_ff,
 )
 from alignn.graphs import Graph, radius_graph_jarvis, radius_graph_old
-from alignn.ff.ff import phonons
+from alignn.ff.ff import phonons,ase_phonon
 from jarvis.core.atoms import ase_to_atoms
 from jarvis.db.figshare import get_jid_data
 from jarvis.core.atoms import Atoms
jarvis.core.atoms import Atoms @@ -83,7 +83,7 @@ def test_ev(): print("atoms", atoms) # atoms = atoms.make_supercell_matrix([2, 2, 2]) # atoms=atoms.strain_atoms(.05) - ev = ev_curve(atoms=atoms, model_path=model_path) + ev = ev_curve(atoms=atoms, model_path=model_path,on_relaxed_struct=True) # surf = surface_energy(atoms=atoms, model_path=model_path) # print('surf',surf) # vac = vacancy_formation(atoms=atoms, model_path=model_path) @@ -118,12 +118,12 @@ def test_phonons(): model_name="v10.30.2024_dft_3d_307k" ) # default_path() ph = phonons(model_path=model_path, atoms=(atoms)) - + ase_phonon(atoms=atoms,model_path=model_path) # print('test_graph_builder') -test_graph_builder() +#test_graph_builder() # print('test_ev') # test_ev() # print('test_phonons') -# test_phonons() +test_phonons() # test_alignnff() From e4440a0721d210010975ab72381c89e039db020f Mon Sep 17 00:00:00 2001 From: knc6 Date: Wed, 20 Nov 2024 13:03:42 -0500 Subject: [PATCH 23/37] Add NPT test. --- .github/workflows/main.yml | 4 ++-- alignn/ff/ff.py | 10 +++++----- alignn/tests/test_alignn_ff.py | 17 ++++++++++++----- alignn/tests/test_prop.py | 2 +- 4 files changed, 20 insertions(+), 13 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 3014416..01b3c92 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -39,8 +39,8 @@ jobs: export DGLBACKEND=pytorch export CUDA_VISIBLE_DEVICES="-1" pip install phonopy flake8 pytest pycodestyle pydocstyle codecov pytest-cov coverage - pip uninstall jarvis-tools -y - pip install -q git+https://github.com/usnistgov/jarvis.git@develop + #pip uninstall jarvis-tools -y + #pip install -q git+https://github.com/usnistgov/jarvis.git@develop python setup.py develop echo 'environment.yml' conda env export diff --git a/alignn/ff/ff.py b/alignn/ff/ff.py index 79d67a2..599194b 100644 --- a/alignn/ff/ff.py +++ b/alignn/ff/ff.py @@ -613,10 +613,10 @@ def run_npt_berendsen( interval=1, temperature_K=300, steps=1000, - taut=49.11347394232032, - taup=98.22694788464064, - pressure=None, - compressibility=None, + taut=5.0 * units.fs, + taup=500.0 * units.fs, + pressure=1.0 * units.bar, + compressibility=5e-7 / units.bar, initial_temperature_K=None, ): """Run NPT.""" @@ -633,7 +633,7 @@ def run_npt_berendsen( taup=taup, pressure=pressure, compressibility=compressibility, - communicator=self.communicator, + # communicator=self.communicator, ) # Create monitors for logfile and a trajectory file # logfile = os.path.join(".", "%s.log" % filename) diff --git a/alignn/tests/test_alignn_ff.py b/alignn/tests/test_alignn_ff.py index 0707f2d..0e27468 100644 --- a/alignn/tests/test_alignn_ff.py +++ b/alignn/tests/test_alignn_ff.py @@ -7,7 +7,7 @@ get_figshare_model_ff, ) from alignn.graphs import Graph, radius_graph_jarvis, radius_graph_old -from alignn.ff.ff import phonons,ase_phonon +from alignn.ff.ff import phonons, ase_phonon from jarvis.core.atoms import ase_to_atoms from jarvis.db.figshare import get_jid_data from jarvis.core.atoms import Atoms @@ -83,7 +83,7 @@ def test_ev(): print("atoms", atoms) # atoms = atoms.make_supercell_matrix([2, 2, 2]) # atoms=atoms.strain_atoms(.05) - ev = ev_curve(atoms=atoms, model_path=model_path,on_relaxed_struct=True) + ev = ev_curve(atoms=atoms, model_path=model_path, on_relaxed_struct=True) # surf = surface_energy(atoms=atoms, model_path=model_path) # print('surf',surf) # vac = vacancy_formation(atoms=atoms, model_path=model_path) @@ -108,6 +108,7 @@ def test_ev(): xx = ff.run_nve_velocity_verlet(steps=5) xx = 
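This revert restores the pairwise virial stress. For orientation, a minimal NumPy sketch of the formula the model code below assembles from edge vectors and pairwise forces (array shapes and numbers are illustrative only; the actual implementation operates on batched DGL edge data):

```
import numpy as np

EV_PER_A3_TO_GPA = 160.21766208  # 1 eV/Angstrom^3 = 160.21766208 GPa

def virial_stress(r, pair_forces, volume):
    """3x3 virial stress in GPa from edge vectors r (n_edges, 3),
    pairwise forces (n_edges, 3), and cell volume in Angstrom^3.
    The factor of 2 accounts for each bond appearing as i->j and j->i."""
    return -EV_PER_A3_TO_GPA * (r.T @ pair_forces) / (2.0 * volume)

# Toy usage with random inputs, just to show the shapes:
rng = np.random.default_rng(0)
sigma = virial_stress(rng.normal(size=(24, 3)), rng.normal(size=(24, 3)), 40.0)
print(sigma.shape)  # (3, 3)
```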
ff.run_nvt_langevin(steps=5) xx = ff.run_nvt_andersen(steps=5) + xx = ff.run_npt_berendsen(steps=5) # xx = ff.run_npt_nose_hoover(steps=5) @@ -118,12 +119,18 @@ def test_phonons(): model_name="v10.30.2024_dft_3d_307k" ) # default_path() ph = phonons(model_path=model_path, atoms=(atoms)) - ase_phonon(atoms=atoms,model_path=model_path) + ase_phonon(atoms=atoms, model_path=model_path) + + +def test_qclean(): + cmd = "rm *.pt *.traj *.csv *.json *range" + os.system(cmd) + # print('test_graph_builder') -#test_graph_builder() +# test_graph_builder() # print('test_ev') # test_ev() # print('test_phonons') -test_phonons() +# test_phonons() # test_alignnff() diff --git a/alignn/tests/test_prop.py b/alignn/tests/test_prop.py index 51bd379..5303125 100644 --- a/alignn/tests/test_prop.py +++ b/alignn/tests/test_prop.py @@ -176,7 +176,7 @@ def test_alignn_train_ff(): def test_clean(): - cmd = "rm *.pt *.csv *.json *range" + cmd = "rm *.pt *.traj *.csv *.json *range" os.system(cmd) From 9540915f3d6ba14adf815440612645baa2f5bc72 Mon Sep 17 00:00:00 2001 From: knc6 Date: Wed, 20 Nov 2024 13:22:27 -0500 Subject: [PATCH 24/37] Add NPT test. --- alignn/tests/test_alignn_ff.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/alignn/tests/test_alignn_ff.py b/alignn/tests/test_alignn_ff.py index 0e27468..14feb41 100644 --- a/alignn/tests/test_alignn_ff.py +++ b/alignn/tests/test_alignn_ff.py @@ -108,7 +108,7 @@ def test_ev(): xx = ff.run_nve_velocity_verlet(steps=5) xx = ff.run_nvt_langevin(steps=5) xx = ff.run_nvt_andersen(steps=5) - xx = ff.run_npt_berendsen(steps=5) + # xx = ff.run_npt_berendsen(steps=5) # xx = ff.run_npt_nose_hoover(steps=5) From a7c51eca323e69f11a4ec2e4706e6f59081afc41 Mon Sep 17 00:00:00 2001 From: knc6 Date: Thu, 21 Nov 2024 01:13:27 -0500 Subject: [PATCH 25/37] Os. --- alignn/ff/ff.py | 8 +++++++- alignn/tests/test_alignn_ff.py | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/alignn/ff/ff.py b/alignn/ff/ff.py index 599194b..bd0c966 100644 --- a/alignn/ff/ff.py +++ b/alignn/ff/ff.py @@ -223,6 +223,7 @@ def __init__( model=None, config=None, force_mult_batchsize=True, + stress_method=None, path=".", model_filename="best_model.pt", config_filename="config.json", @@ -249,6 +250,8 @@ def __init__( # print('config',config) # config=TrainingConfig(**config).dict() self.config = config + if stress_method is not None: + config["model"]["stress_method"] = stress_method # self.stress_wt if self.include_stress: self.implemented_properties = ["energy", "forces", "stress"] if config["model"]["stresswise_weight"] == 0: @@ -322,8 +325,11 @@ def calculate(self, atoms, properties=None, system_changes=None): ) ) forces = result["grad"].detach().cpu().numpy() + # print('self.config["batch_size"]',self.config["batch_size"]) if self.force_mult_batchsize: - forces *= self.config["batch_size"] + # print('forces1',forces) + forces = np.array(forces) * self.config["batch_size"] + # print('forces2',forces) # stress*=self.config['batch_size'] self.results = { "energy": energy, # * num_atoms, diff --git a/alignn/tests/test_alignn_ff.py b/alignn/tests/test_alignn_ff.py index 14feb41..020c789 100644 --- a/alignn/tests/test_alignn_ff.py +++ b/alignn/tests/test_alignn_ff.py @@ -20,7 +20,7 @@ ForceField, ) from jarvis.io.vasp.inputs import Poscar - +import os # JVASP-25139 pos = """Rb8 1.0 From 33725b1a9a22e74aaf0e1c89c3dae0e9c2e2e9e8 Mon Sep 17 00:00:00 2001 From: knc6 Date: Mon, 25 Nov 2024 15:45:18 -0500 Subject: [PATCH 26/37] Back to old stress calculator. 
--- alignn/ff/ff.py | 192 ++++++++++++++++++++++++------- alignn/models/alignn_atomwise.py | 94 ++++++++------- 2 files changed, 193 insertions(+), 93 deletions(-) diff --git a/alignn/ff/ff.py b/alignn/ff/ff.py index bd0c966..b83099c 100644 --- a/alignn/ff/ff.py +++ b/alignn/ff/ff.py @@ -8,6 +8,7 @@ from ase.md.nptberendsen import NPTBerendsen from ase.io import Trajectory import matplotlib.pyplot as plt +from jarvis.analysis.thermodynamics.energetics import unary_energy from ase.md.velocitydistribution import MaxwellBoltzmannDistribution from ase.optimize import BFGS from ase.optimize.bfgslinesearch import BFGSLineSearch @@ -29,11 +30,15 @@ from jarvis.db.jsonutils import loadjson from alignn.graphs import Graph from alignn.models.alignn_atomwise import ALIGNNAtomWise, ALIGNNAtomWiseConfig +from jarvis.analysis.defects.vacancy import Vacancy import numpy as np +from alignn.pretrained import get_prediction from jarvis.analysis.structure.spacegroup import ( Spacegroup3D, - # symmetrically_distinct_miller_indices, + symmetrically_distinct_miller_indices, ) +from jarvis.analysis.interface.zur import make_interface +from jarvis.analysis.defects.surface import Surface from jarvis.core.kpoints import Kpoints3D as Kpoints import zipfile from ase import Atoms as AseAtoms @@ -46,15 +51,11 @@ from tqdm import tqdm import torch -# from jarvis.analysis.thermodynamics.energetics import unary_energy -# from jarvis.analysis.defects.vacancy import Vacancy -# from jarvis.analysis.defects.surface import Surface -# from alignn.pretrained import get_prediction -# from jarvis.analysis.interface.zur import make_interface -# try: -# from gpaw import GPAW, PW -# except Exception: -# pass +try: + from gpaw import GPAW, PW +except Exception: + pass +# plt.switch_backend("agg") # Reference: https://doi.org/10.1039/D2DD00096B @@ -222,14 +223,12 @@ def __init__( device=None, model=None, config=None, - force_mult_batchsize=True, - stress_method=None, path=".", model_filename="best_model.pt", config_filename="config.json", output_dir=None, batch_stress=True, - stress_wt=0.03, + stress_wt=0.1, **kwargs, ): """Initialize class.""" @@ -244,18 +243,17 @@ def __init__( self.config = config self.include_stress = include_stress self.stress_wt = stress_wt - self.force_mult_batchsize = force_mult_batchsize + # self.force_multiplier = force_multiplier + # self.force_mult_natoms = force_mult_natoms if self.config is None: config = loadjson(os.path.join(path, config_filename)) # print('config',config) # config=TrainingConfig(**config).dict() self.config = config - if stress_method is not None: - config["model"]["stress_method"] = stress_method # self.stress_wt if self.include_stress: self.implemented_properties = ["energy", "forces", "stress"] if config["model"]["stresswise_weight"] == 0: - config["model"]["stresswise_weight"] = 0.1 # self.stress_wt + config["model"]["stresswise_weight"] = 0.1 else: self.implemented_properties = ["energy", "forces"] @@ -278,9 +276,13 @@ def __init__( map_location=self.device, ) ) - model.to(device) - model.eval() - self.model = model + else: + model = self.model + model.to(device) + model.eval() + + self.net = model + self.net.to(self.device) def calculate(self, atoms, properties=None, system_changes=None): """Calculate properties.""" @@ -301,17 +303,17 @@ def calculate(self, atoms, properties=None, system_changes=None): if self.config["model"]["alignn_layers"] > 0: # g,lg = g - result = self.model( + result = self.net( ( g.to(self.device), lg.to(self.device), - torch.tensor(np.array(atoms.cell)) + 
torch.tensor(atoms.cell) .type(torch.get_default_dtype()) .to(self.device), ) ) else: - result = self.model( + result = self.net( (g.to(self.device, torch.tensor(atoms.cell).to(self.device))) ) # print ('stress',result["stress"].detach().numpy()) @@ -319,22 +321,25 @@ def calculate(self, atoms, properties=None, system_changes=None): energy = result["out"].detach().cpu().numpy() * num_atoms else: energy = result["out"].detach().cpu().numpy() - stress = self.stress_wt * np.array( - full_3x3_to_voigt_6_stress( - result["stresses"][:3].reshape(3, 3).detach().cpu().numpy() - ) - ) - forces = result["grad"].detach().cpu().numpy() - # print('self.config["batch_size"]',self.config["batch_size"]) - if self.force_mult_batchsize: - # print('forces1',forces) - forces = np.array(forces) * self.config["batch_size"] - # print('forces2',forces) - # stress*=self.config['batch_size'] + self.results = { "energy": energy, # * num_atoms, - "forces": forces, - "stress": stress, + "forces": result["grad"].detach().cpu().numpy() + * self.config["batch_size"], + "stress": full_3x3_to_voigt_6_stress( + # np.eye(3) + result["stresses"][:3] + .reshape(3, 3) + .detach() + .cpu() + .numpy() + ) + * self.stress_wt + / 160.21766208, + "dipole": np.zeros(3), + "charges": np.zeros(len(atoms)), + "magmom": 0.0, + "magmoms": np.zeros(len(atoms)), } @@ -619,10 +624,10 @@ def run_npt_berendsen( interval=1, temperature_K=300, steps=1000, - taut=5.0 * units.fs, - taup=500.0 * units.fs, - pressure=1.0 * units.bar, - compressibility=5e-7 / units.bar, + taut=49.11347394232032, + taup=98.22694788464064, + pressure=None, + compressibility=None, initial_temperature_K=None, ): """Run NPT.""" @@ -639,7 +644,7 @@ def run_npt_berendsen( taup=taup, pressure=pressure, compressibility=compressibility, - # communicator=self.communicator, + communicator=self.communicator, ) # Create monitors for logfile and a trajectory file # logfile = os.path.join(".", "%s.log" % filename) @@ -876,7 +881,6 @@ def ev_curve( return x, y, eos, kv -""" def vacancy_formation( atoms=None, jid="", @@ -889,6 +893,7 @@ def vacancy_formation( using_wyckoffs=True, on_relaxed_struct=True, ): + """Get vacancy energy.""" if atoms is None: from jarvis.db.figshare import data @@ -980,6 +985,7 @@ def surface_energy( thickness=25, model_filename="best_model.pt", ): + """Get surface energy.""" if atoms is None: from jarvis.db.figshare import data @@ -1069,6 +1075,7 @@ def get_interface_energy( from_conventional_structure=True, gpaw_verify=False, ): + """Get work of adhesion.""" film_surf = Surface( film_atoms, indices=film_index, @@ -1094,6 +1101,14 @@ def get_interface_energy( atol=atol, apply_strain=apply_strain, ) + """ + print('film') + print(het['film_sl']) + print('subs') + print(het['subs_sl']) + print('intf') + print(het['interface']) + """ a = get_prediction( atoms=het["film_sl"], model_name="jv_optb88vdw_total_energy_alignn" )[0] @@ -1175,7 +1190,6 @@ def get_interface_energy( info["film_sl"] = het["film_sl"].to_dict() info["subs_sl"] = het["subs_sl"].to_dict() return info -""" def phonons( @@ -1481,3 +1495,93 @@ def ase_phonon( fig.savefig(filename) plt.close() return bs + + +if __name__ == "__main__": + """ + atoms = JarvisAtoms.from_dict( + # get_jid_data(jid="JVASP-867", dataset="dft_3d")["atoms"] + # get_jid_data(jid="JVASP-1002", dataset="dft_3d")["atoms"] + get_jid_data(jid="JVASP-816", dataset="dft_3d")["atoms"] + ) + mlearn = "/wrk/knc6/ALINN_FC/FD_mult/temp_new" # mlearn_path() + phonons(atoms=atoms, model_path=mlearn, enforce_c_size=3) + """ + ff = 
get_figshare_model_ff() + print("ff", ff) + # phonons3(atoms=atoms, model_path=mlearn, enforce_c_size=3) + # ase_phonon(atoms=atoms, model_path=mlearn) + +""" +if __name__ == "__main__": + + from jarvis.db.figshare import get_jid_data + from jarvis.core.atoms import Atoms + + # atoms = Spacegroup3D( + # JarvisAtoms.from_dict( + # get_jid_data(jid="JVASP-816", dataset="dft_3d")["atoms"] + # ) + # ).conventional_standard_structure + # atoms = JarvisAtoms.from_poscar("POSCAR") + # atoms = atoms.make_supercell_matrix([2, 2, 2]) + # print(atoms) + model_path = default_path() + print("model_path", model_path) + # atoms=atoms.strain_atoms(.05) + # print(atoms) + # ev = ev_curve(atoms=atoms, model_path=model_path) + # surf = surface_energy(atoms=atoms, model_path=model_path) + # print(surf) + # vac = vacancy_formation(atoms=atoms, model_path=model_path) + # print(vac) + + # ff = ForceField( + # jarvis_atoms=atoms, + # model_path=model_path, + # ) + # en,fs = ff.unrelaxed_atoms() + # print ('en',en) + # print('fs',fs) + # phonons(atoms=atoms) + # phonons3(atoms=atoms) + # ff.set_momentum_maxwell_boltzmann(temperature_K=300) + # xx = ff.optimize_atoms(optimizer="FIRE") + # print("optimized st", xx) + # xx = ff.run_nve_velocity_verlet(steps=5) + # xx = ff.run_nvt_langevin(steps=5) + # xx = ff.run_nvt_andersen(steps=5) + # xx = ff.run_npt_nose_hoover(steps=20000, temperature_K=1800) + # print(xx) + atoms_al = Atoms.from_dict( + get_jid_data(dataset="dft_3d", jid="JVASP-816")["atoms"] + ) + surf = surface_energy(atoms=atoms_al, model_path=model_path) + # atoms_al2o3 = Atoms.from_dict( + # get_jid_data(dataset="dft_3d", jid="JVASP-32")["atoms"] + # ) + # atoms_sio2 = Atoms.from_dict( + # get_jid_data(dataset="dft_3d", jid="JVASP-58349")["atoms"] + # ) + # atoms_cu = Atoms.from_dict( + # get_jid_data(dataset="dft_3d", jid="JVASP-867")["atoms"] + # ) + # atoms_cu2o = Atoms.from_dict( + # get_jid_data(dataset="dft_3d", jid="JVASP-1216")["atoms"] + # ) + # atoms_graph = Atoms.from_dict( + # get_jid_data(dataset="dft_3d", jid="JVASP-48")["atoms"] + # ) + # intf = get_interface_energy( + # film_atoms=atoms_cu, + # subs_atoms=atoms_cu2o, + # film_thickness=25, + # subs_thickness=25, + # model_path=model_path, + # seperation=4.5, + # subs_index=[1, 1, 1], + # film_index=[1, 1, 1], + # ) + # print(intf) + print(surf) +""" diff --git a/alignn/models/alignn_atomwise.py b/alignn/models/alignn_atomwise.py index a256719..4fae2c3 100644 --- a/alignn/models/alignn_atomwise.py +++ b/alignn/models/alignn_atomwise.py @@ -29,45 +29,47 @@ class ALIGNNAtomWiseConfig(BaseSettings): """Hyperparameter schema for jarvisdgl.models.alignn.""" name: Literal["alignn_atomwise"] - alignn_layers: int = 4 - gcn_layers: int = 4 - # atom_input_features: int = 1 - atom_input_features: int = 92 + alignn_layers: int = 2 + gcn_layers: int = 2 + atom_input_features: int = 1 + # atom_input_features: int = 92 edge_input_features: int = 80 triplet_input_features: int = 40 embedding_features: int = 64 - # hidden_features: int = 64 - hidden_features: int = 256 + hidden_features: int = 64 + # hidden_features: int = 256 + # fc_layers: int = 1 + # fc_features: int = 64 output_features: int = 1 grad_multiplier: int = -1 calculate_gradient: bool = True atomwise_output_features: int = 0 graphwise_weight: float = 1.0 - gradwise_weight: float = 0.0 + gradwise_weight: float = 1.0 stresswise_weight: float = 0.0 atomwise_weight: float = 0.0 + # if link == log, apply `exp` to final outputs + # to constrain predictions to be positive + link: Literal["identity", 
"log", "logit"] = "identity" zero_inflated: bool = False classification: bool = False force_mult_natoms: bool = False - energy_mult_natoms: bool = False - # energy_mult_natoms: bool = True + energy_mult_natoms: bool = True include_pos_deriv: bool = False use_cutoff_function: bool = False - inner_cutoff: float = 6 # Ansgtrom + inner_cutoff: float = 3 # Ansgtrom stress_multiplier: float = 1 add_reverse_forces: bool = True # will make True as default soon lg_on_fly: bool = True # will make True as default soon + batch_stress: bool = True multiply_cutoff: bool = False use_penalty: bool = True extra_features: int = 0 - exponent: int = 3 + exponent: int = 5 penalty_factor: float = 0.1 penalty_threshold: float = 1 additional_output_features: int = 0 additional_output_weight: float = 0 - stress_method: int = 1 - link: Literal["identity", "log", "logit"] = "identity" - batch_stress: bool = True class Config: """Configure model settings behavior.""" @@ -380,7 +382,7 @@ def forward( features = self.extra_feature_embedding(features) g = g.local_var() result = {} - # print('g',g) + # initial node features: atom feature network... x = g.ndata.pop("atom_features") # print('x1',x,x.shape) @@ -557,42 +559,28 @@ def forward( forces = torch.squeeze(g.ndata["forces_ji"]) if self.config.stresswise_weight != 0: - if self.config.stress_method == 1: - g.ndata["cart_coords"] = compute_cartesian_coordinates( - g, lat - ) - r, bondlength = compute_pair_vector_and_distance(g) - stress = -160.21766208 * ( - torch.matmul(r.T, pair_forces) - # / (2 * g.edata["V"]) - / (2 * g.ndata["V"][0]) - ) - if self.config.stress_method == 2: - cart_coords = compute_cartesian_coordinates( - g, lat - ).view(g.batch_size, -1, 3) - forces_batched = forces.view(g.batch_size, -1, 3) - vols = torch.abs(torch.det(lat)) - if vols.ndim == 0: - vols = vols.unsqueeze(0) - stresses = [] - for graph_id in range(g.batch_size): - st = ( - -160.21766208 - * torch.matmul( - cart_coords[graph_id].T, - forces_batched[graph_id], - ) - / (vols[graph_id]) - ) - stresses.append(st) - # print(st) - stress = torch.stack(stresses) - if self.config.stress_method == 3: + # Under development, use with caution + # 1 eV/Angstrom3 = 160.21766208 GPa + # 1 GPa = 10 kbar + # Virial stress formula, assuming inital velocity = 0 + # Save volume as g.gdta['V']? 
+ # print('pair_forces',pair_forces.shape) + # print('r',r.shape) + # print('g.ndata["V"]',g.ndata["V"].shape) + if not self.config.batch_stress: + # print('Not batch_stress') stress = ( - -1 * torch.einsum("ij, ik->jk", r, pair_forces) / 2 + -1 + * 160.21766208 + * ( + torch.matmul(r.T, pair_forces) + # / (2 * g.edata["V"]) + / (2 * g.ndata["V"][0]) + ) ) - if self.config.stress_method == 4: + # print("stress1", stress, stress.shape) + # print("g.batch_size", g.batch_size) + else: stresses = [] count_edge = 0 count_node = 0 @@ -617,6 +605,14 @@ def forward( stress = self.config.stress_multiplier * torch.stack( stresses ) + # print("stress2", stress, stress.shape) + # virial = ( + # 160.21766208 + # * 10 + # * torch.einsum("ij, ik->jk", + # result["r"], result["dy_dr"]) + # / 2 + # ) # / ( g.ndata["V"][0]) if self.link: out = self.link(out) From af3ae5d1c5711ef9cad6cf930de78f30e6627382 Mon Sep 17 00:00:00 2001 From: knc6 Date: Mon, 2 Dec 2024 00:28:47 -0500 Subject: [PATCH 27/37] Fix calculator --- alignn/ff/ff.py | 123 ++++++++++++++----------------- alignn/models/alignn_atomwise.py | 42 ++++++++--- 2 files changed, 84 insertions(+), 81 deletions(-) diff --git a/alignn/ff/ff.py b/alignn/ff/ff.py index b83099c..9947769 100644 --- a/alignn/ff/ff.py +++ b/alignn/ff/ff.py @@ -135,59 +135,52 @@ def default_path(): return dpath -def revised_path(): +def mp_2mill(): """Get defaukt model path.""" dpath = get_figshare_model_ff(model_name="revised") - print("model_path", dpath) + # print("model_path", dpath) return dpath -def alignnff_fmult(): +def mp_167k(): """Get default model path.""" dpath = get_figshare_model_ff(model_name="alignnff_fmult") - print("model_path", dpath) + # print("model_path", dpath) return dpath -def mptraj_path(): +def jv_307k(): """Get MPtraj model path.""" dpath = get_figshare_model_ff(model_name="v8.29.2024_mpf") - print("model_path", dpath) + # print("model_path", dpath) return dpath -def mlearn_path(): +def jv_2mill(): """Get model trained on mlearn path.""" dpath = get_figshare_model_ff(model_name="fmult_mlearn_only") - print("model_path", dpath) - return dpath - - -def fd_path(): - """Get defaukt model path.""" - dpath = get_figshare_model_ff(model_name="alignnff_fd") - print("model_path", dpath) + # print("model_path", dpath) return dpath def wt01_path(): """Get defaukt model path.""" dpath = get_figshare_model_ff(model_name="alignnff_wt01") - print("model_path", dpath) + # print("model_path", dpath) return dpath def wt1_path(): """Get defaukt model path.""" dpath = get_figshare_model_ff(model_name="alignnff_wt1") - print("model_path", dpath) + # print("model_path", dpath) return dpath def wt10_path(): """Get defaukt model path.""" dpath = get_figshare_model_ff(model_name="alignnff_wt10") - print("model_path", dpath) + # print("model_path", dpath) return dpath @@ -228,37 +221,38 @@ def __init__( config_filename="config.json", output_dir=None, batch_stress=True, + force_mult_natoms=False, + force_mult_batchsize=True, + force_multiplier=1, stress_wt=0.1, - **kwargs, ): """Initialize class.""" super(AlignnAtomwiseCalculator, self).__init__( - restart, ignore_bad_restart_file, label, atoms, directory, **kwargs - ) + restart, ignore_bad_restart_file, label, atoms, directory + ) # , **kwargs self.model = model self.device = device self.intensive = intensive - # config = loadjson(os.path.join(path, config_filename)) - # print('config',config) self.config = config self.include_stress = include_stress self.stress_wt = stress_wt - # self.force_multiplier = force_multiplier - # 
self.force_mult_natoms = force_mult_natoms + self.force_mult_natoms = force_mult_natoms + self.force_mult_batchsize = force_mult_batchsize + self.force_multiplier = force_multiplier if self.config is None: config = loadjson(os.path.join(path, config_filename)) - # print('config',config) - # config=TrainingConfig(**config).dict() self.config = config + if self.force_mult_natoms: + self.config["model"]["force_mult_natoms"] = True if self.include_stress: self.implemented_properties = ["energy", "forces", "stress"] - if config["model"]["stresswise_weight"] == 0: - config["model"]["stresswise_weight"] = 0.1 + if self.config["model"]["stresswise_weight"] == 0: + self.config["model"]["stresswise_weight"] = 0.1 else: self.implemented_properties = ["energy", "forces"] if batch_stress is not None: - config["model"]["batch_stress"] = batch_stress + self.config["model"]["batch_stress"] = batch_stress import torch if self.device is None: @@ -267,8 +261,10 @@ def __init__( ) if self.model is None: - if config["model"]["name"] == "alignn_atomwise": - model = ALIGNNAtomWise(ALIGNNAtomWiseConfig(**config["model"])) + if self.config["model"]["name"] == "alignn_atomwise": + model = ALIGNNAtomWise( + ALIGNNAtomWiseConfig(**self.config["model"]) + ) model.state_dict() model.load_state_dict( torch.load( @@ -276,19 +272,16 @@ def __init__( map_location=self.device, ) ) + model.to(device) + model.eval() + self.model = model else: model = self.model - model.to(device) - model.eval() - - self.net = model - self.net.to(self.device) def calculate(self, atoms, properties=None, system_changes=None): """Calculate properties.""" j_atoms = ase_to_atoms(atoms) num_atoms = j_atoms.num_atoms - # g, lg = Graph.atom_dgl_multigraph( g, lg = Graph.atom_dgl_multigraph( j_atoms, neighbor_strategy=self.config["neighbor_strategy"], @@ -297,13 +290,9 @@ def calculate(self, atoms, properties=None, system_changes=None): atom_features=self.config["atom_features"], use_canonize=self.config["use_canonize"], ) - # print('g',g) - # print('lg',lg) - # print('config',self.config) if self.config["model"]["alignn_layers"] > 0: - # g,lg = g - result = self.net( + result = self.model( ( g.to(self.device), lg.to(self.device), @@ -313,33 +302,31 @@ def calculate(self, atoms, properties=None, system_changes=None): ) ) else: - result = self.net( + result = self.model( (g.to(self.device, torch.tensor(atoms.cell).to(self.device))) ) - # print ('stress',result["stress"].detach().numpy()) + forces = forces = ( + result["grad"].detach().cpu().numpy() * self.force_multiplier + ) + stress = ( + full_3x3_to_voigt_6_stress( + result["stresses"][:3].reshape(3, 3).detach().cpu().numpy() + ) + * self.stress_wt + / 160.21766208 + ) + energy = result["out"].detach().cpu().numpy() if self.intensive: - energy = result["out"].detach().cpu().numpy() * num_atoms - else: - energy = result["out"].detach().cpu().numpy() + energy *= num_atoms + if self.force_mult_natoms: + forces *= num_atoms + if self.force_mult_batchsize: + forces *= self.config["batch_size"] self.results = { - "energy": energy, # * num_atoms, - "forces": result["grad"].detach().cpu().numpy() - * self.config["batch_size"], - "stress": full_3x3_to_voigt_6_stress( - # np.eye(3) - result["stresses"][:3] - .reshape(3, 3) - .detach() - .cpu() - .numpy() - ) - * self.stress_wt - / 160.21766208, - "dipole": np.zeros(3), - "charges": np.zeros(len(atoms)), - "magmom": 0.0, - "magmoms": np.zeros(len(atoms)), + "energy": energy, + "forces": forces, + "stress": stress, } @@ -404,7 +391,6 @@ def __init__( 
include_stress=self.include_stress, model_filename=self.model_filename, stress_wt=self.stress_wt, - force_multiplier=self.force_multiplier, force_mult_natoms=self.force_mult_natoms, batch_stress=self.batch_stress, # device="cuda" if torch.cuda.is_available() else "cpu", @@ -1200,7 +1186,7 @@ def phonons( model_filename="best_model.pt", on_relaxed_struct=False, force_mult_natoms=False, - stress_wt=-4800, + stress_wt=0.1, dim=[2, 2, 2], freq_conversion_factor=33.3566830, # ThztoCm-1 phonopy_bands_figname="phonopy_bands.png", @@ -1208,14 +1194,13 @@ def phonons( write_fc=False, min_freq_tol=-0.05, distance=0.2, - force_multiplier=1, ): """Make Phonon calculation setup.""" calc = AlignnAtomwiseCalculator( path=model_path, force_mult_natoms=force_mult_natoms, - force_multiplier=force_multiplier, stress_wt=stress_wt, + model_filename=model_filename, ) from phonopy import Phonopy diff --git a/alignn/models/alignn_atomwise.py b/alignn/models/alignn_atomwise.py index 4fae2c3..afaf369 100644 --- a/alignn/models/alignn_atomwise.py +++ b/alignn/models/alignn_atomwise.py @@ -479,16 +479,9 @@ def forward( natoms = torch.tensor([gg.num_nodes() for gg in dgl.unbatch(g)]).to( g.device ) + en_out = out if self.config.energy_mult_natoms: - # print('g.num_nodes()',g.num_nodes()) - # print('unbatch',dgl.unbatch(g)) - # print('natoms',natoms) - # print('out',out,out.shape) - # print() - # print() en_out = out * natoms # g.num_nodes() - else: - en_out = out if self.config.use_penalty: penalty_factor = ( self.config.penalty_factor @@ -559,16 +552,22 @@ def forward( forces = torch.squeeze(g.ndata["forces_ji"]) if self.config.stresswise_weight != 0: + # print("self.config.batch_stress",self.config.batch_stress) # Under development, use with caution # 1 eV/Angstrom3 = 160.21766208 GPa # 1 GPa = 10 kbar # Virial stress formula, assuming inital velocity = 0 - # Save volume as g.gdta['V']? 
- # print('pair_forces',pair_forces.shape) - # print('r',r.shape) - # print('g.ndata["V"]',g.ndata["V"].shape) if not self.config.batch_stress: # print('Not batch_stress') + g.ndata["cart_coords"] = compute_cartesian_coordinates( + g, lat + ) + r, bondlength = compute_pair_vector_and_distance(g) + stress = -160.21766208 * ( + torch.matmul(r.T, pair_forces) + # / (2 * g.edata["V"]) + / (2 * g.ndata["V"][0]) + ) stress = ( -1 * 160.21766208 @@ -578,6 +577,25 @@ def forward( / (2 * g.ndata["V"][0]) ) ) + # cart_coords = compute_cartesian_coordinates( + # g, lat + # ).view(g.batch_size, -1, 3) + # forces_batched = forces.view(g.batch_size, -1, 3) + # vols = torch.abs(torch.det(lat)) + # if vols.ndim == 0: + # vols = vols.unsqueeze(0) + # stresses = [] + # for graph_id in range(g.batch_size): + # st = ( + # -160.21766208 + # * torch.matmul( + # cart_coords[graph_id].T, + # forces_batched[graph_id], + # ) + # / (vols[graph_id]) + # ) + # stresses.append(st) + # stress = torch.stack(stresses) # print("stress1", stress, stress.shape) # print("g.batch_size", g.batch_size) else: From 9349975b2f13ef7898ca9204a45c7ef18134e6ef Mon Sep 17 00:00:00 2001 From: Kamal Choudhary Date: Mon, 2 Dec 2024 01:31:49 -0500 Subject: [PATCH 28/37] Update README.md --- README.md | 65 ++++++++++++++++++++++++++++--------------------------- 1 file changed, 33 insertions(+), 32 deletions(-) diff --git a/README.md b/README.md index c07ad98..53222cf 100644 --- a/README.md +++ b/README.md @@ -187,42 +187,43 @@ Atomisitic line graph neural network-based FF (ALIGNN-FF) can be used to model b [ASE calculator](https://wiki.fysik.dtu.dk/ase/ase/calculators/calculators.html) provides interface to various codes. An example for ALIGNN-FF is give below. Note that there are multiple pretrained ALIGNN-FF models available, here we use the deafult_path model. 
As more accurate models are developed, they will be made available as well: ``` -from alignn.ff.ff import ( - AlignnAtomwiseCalculator, - default_path, - mptraj_path, - wt01_path, -) -import matplotlib.pyplot as plt -from ase import Atom, Atoms -import time -from ase.build import bulk +from alignn.ff.ff import AlignnAtomwiseCalculator,default_path +from jarvis.io.vasp.inputs import Poscar import numpy as np import matplotlib.pyplot as plt -from ase.build import make_supercell -%matplotlib inline - model_path = default_path() calc = AlignnAtomwiseCalculator(path=model_path) - -t1 = time.time() -# a = 5.43 -lattice_params = np.linspace(5.2, 5.6) -fcc_energies = [] -ready = True -for a in lattice_params: - atoms = bulk("Si", "diamond", a=a) - atoms.set_tags(np.ones(len(atoms))) - atoms.calc = calc - e = atoms.get_potential_energy() - fcc_energies.append(e) -t2 = time.time() -print("Time", t2 - t1) -plt.plot(lattice_params, fcc_energies, "-o") -plt.title("Si") -plt.xlabel("Lattice constant ($\AA$)") -plt.ylabel("Total energy (eV)") -plt.show() +# Source: https://www.ctcms.nist.gov/~knc6/static/JARVIS-DFT/JVASP-1002.xml +poscar="""Si2 +1.0 +3.3641499856336465 -2.5027128e-09 1.94229273881412 +1.121382991333525 3.1717517190189715 1.9422927388141193 +-2.5909987e-09 -1.8321133e-09 3.884586486670313 +Si +2 +Cartesian +3.92483875 2.77528125 6.7980237500000005 +0.56069125 0.39646875 0.9711462500000001 +""" +dx=np.arange(-0.1, 0.1, 0.01) +atoms=Poscar.from_string(poscar).atoms +print(atoms) +y = [] +vol = [] +for i in dx: + struct = atoms.strain_atoms(i) + struct_ase=struct.ase_converter() + struct_ase.calc=calc + y.append(struct_ase.get_potential_energy()) + vol.append(struct.volume) + + + +plt.plot(vol,y,'-o') +plt.xlabel('Volume ($\AA^3$)') +plt.ylabel('Total energy (eV)') +plt.savefig('Si_JVASP-1002.png') +plt.close() ``` To train ALIGNN-FF use `train_alignn.py` script which uses `atomwise_alignn` model: From d28077fdd9a17dd6e92a24dcba701fc6c59d0745 Mon Sep 17 00:00:00 2001 From: knc6 Date: Mon, 2 Dec 2024 01:35:11 -0500 Subject: [PATCH 29/37] Fix tests. --- alignn/tests/test_alignn_ff.py | 2 -- alignn/tests/test_prop.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/alignn/tests/test_alignn_ff.py b/alignn/tests/test_alignn_ff.py index 020c789..2a7dbd7 100644 --- a/alignn/tests/test_alignn_ff.py +++ b/alignn/tests/test_alignn_ff.py @@ -15,8 +15,6 @@ AlignnAtomwiseCalculator, default_path, wt10_path, - alignnff_fmult, - fd_path, ForceField, ) from jarvis.io.vasp.inputs import Poscar diff --git a/alignn/tests/test_prop.py b/alignn/tests/test_prop.py index 5303125..1539413 100644 --- a/alignn/tests/test_prop.py +++ b/alignn/tests/test_prop.py @@ -11,7 +11,7 @@ from jarvis.core.atoms import Atoms from alignn.train_alignn import train_for_folder from jarvis.db.figshare import get_jid_data -from alignn.ff.ff import AlignnAtomwiseCalculator, default_path, revised_path +from alignn.ff.ff import AlignnAtomwiseCalculator, default_path import torch from jarvis.db.jsonutils import loadjson, dumpjson from alignn.config import TrainingConfig From 40c22a257036e587f0125cff4d33eaefc439b02a Mon Sep 17 00:00:00 2001 From: knc6 Date: Mon, 2 Dec 2024 01:49:13 -0500 Subject: [PATCH 30/37] New FFs. 
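The newly registered checkpoints can be loaded by name; a minimal sketch (the model names are the keys added to all_models_ff.json below, and this assumes the figshare downloads succeed and are cached on first use):

```
from alignn.ff.ff import AlignnAtomwiseCalculator, get_figshare_model_ff

# Keys added in this commit; v12.2.2024_dft_3d_307k becomes the new default.
for name in (
    "v12.2.2024_dft_3d_307k",
    "v12.2.2024_mp_1.5mill",
    "v12.2.2024_mp_187k",
):
    path = get_figshare_model_ff(model_name=name)  # downloads/caches the model dir
    calc = AlignnAtomwiseCalculator(path=path)
```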
--- alignn/ff/all_models_ff.json | 3 +++ alignn/ff/ff.py | 24 ++++++++++-------------- alignn/tests/test_prop.py | 2 +- 3 files changed, 14 insertions(+), 15 deletions(-) diff --git a/alignn/ff/all_models_ff.json b/alignn/ff/all_models_ff.json index 085fab8..76db5be 100644 --- a/alignn/ff/all_models_ff.json +++ b/alignn/ff/all_models_ff.json @@ -1,4 +1,7 @@ { + "v12.2.2024_dft_3d_307k": "https://figshare.com/ndownloader/files/50904240", + "v12.2.2024_mp_1.5mill": "https://figshare.com/ndownloader/files/50904783", + "v12.2.2024_mp_187k": "https://figshare.com/ndownloader/files/50904801", "v10.30.2024_dft_3d_307k": "https://figshare.com/ndownloader/files/50634327", "v10.30.2024_mp_168k": "https://figshare.com/ndownloader/files/50634318", "v8.29.2024_dft_3d": "https://figshare.com/ndownloader/files/48889834", diff --git a/alignn/ff/ff.py b/alignn/ff/ff.py index 9947769..f63172e 100644 --- a/alignn/ff/ff.py +++ b/alignn/ff/ff.py @@ -127,24 +127,25 @@ def get_figshare_model_ff( def default_path(): """Get default model path.""" - dpath = get_figshare_model_ff(model_name="v5.27.2024") + dpath = get_figshare_model_ff(model_name="v12.2.2024_dft_3d_307k") + # dpath = get_figshare_model_ff(model_name="v5.27.2024") # dpath = get_figshare_model_ff(model_name="v8.29.2024_dft_3d") # dpath = get_figshare_model_ff(model_name="alignnff_wt10") # dpath = get_figshare_model_ff(model_name="alignnff_fmult") - # print("model_path", dpath) + print("model_path", dpath) return dpath def mp_2mill(): - """Get defaukt model path.""" - dpath = get_figshare_model_ff(model_name="revised") + """Get default model path.""" + dpath = get_figshare_model_ff(model_name="v12.2.2024_mp_1.5mill") # print("model_path", dpath) return dpath def mp_167k(): """Get default model path.""" - dpath = get_figshare_model_ff(model_name="alignnff_fmult") + dpath = get_figshare_model_ff(model_name="v12.2.2024_mp_187k") # print("model_path", dpath) return dpath @@ -156,13 +157,6 @@ def jv_307k(): return dpath -def jv_2mill(): - """Get model trained on mlearn path.""" - dpath = get_figshare_model_ff(model_name="fmult_mlearn_only") - # print("model_path", dpath) - return dpath - - def wt01_path(): """Get defaukt model path.""" dpath = get_figshare_model_ff(model_name="alignnff_wt01") @@ -216,7 +210,7 @@ def __init__( device=None, model=None, config=None, - path=".", + path=None, model_filename="best_model.pt", config_filename="config.json", output_dir=None, @@ -224,7 +218,7 @@ def __init__( force_mult_natoms=False, force_mult_batchsize=True, force_multiplier=1, - stress_wt=0.1, + stress_wt=0.05, ): """Initialize class.""" super(AlignnAtomwiseCalculator, self).__init__( @@ -239,6 +233,8 @@ def __init__( self.force_mult_natoms = force_mult_natoms self.force_mult_batchsize = force_mult_batchsize self.force_multiplier = force_multiplier + if path is None and model is None: + path = default_path() if self.config is None: config = loadjson(os.path.join(path, config_filename)) self.config = config diff --git a/alignn/tests/test_prop.py b/alignn/tests/test_prop.py index 1539413..32e6a46 100644 --- a/alignn/tests/test_prop.py +++ b/alignn/tests/test_prop.py @@ -191,7 +191,7 @@ def test_calculator(): energy = ase_atoms.get_potential_energy() forces = ase_atoms.get_forces() stress = ase_atoms.get_stress() - print("round(energy,3)", round(energy, 3)) + print("energy", energy) print("max(forces.flatten()),3)", max(forces.flatten())) print("max(stress.flatten()),3)", max(stress.flatten())) # assert round(energy,3)==round(-60.954999923706055,3) From 
3a3987e6572c264c82932de9eac1122ac052b8ed Mon Sep 17 00:00:00 2001 From: knc6 Date: Mon, 2 Dec 2024 02:05:53 -0500 Subject: [PATCH 31/37] Phonons fix. --- alignn/ff/ff.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/alignn/ff/ff.py b/alignn/ff/ff.py index f63172e..75ddfb3 100644 --- a/alignn/ff/ff.py +++ b/alignn/ff/ff.py @@ -1183,6 +1183,7 @@ def phonons( on_relaxed_struct=False, force_mult_natoms=False, stress_wt=0.1, + force_multiplier=1, dim=[2, 2, 2], freq_conversion_factor=33.3566830, # ThztoCm-1 phonopy_bands_figname="phonopy_bands.png", @@ -1197,6 +1198,7 @@ def phonons( force_mult_natoms=force_mult_natoms, stress_wt=stress_wt, model_filename=model_filename, + force_multiplier=force_multiplier, ) from phonopy import Phonopy @@ -1331,8 +1333,8 @@ def phonons3( on_relaxed_struct=False, dim=[2, 2, 2], distance=0.2, - stress_wt=-4800, - force_multiplier=2, + stress_wt=0.1, + force_multiplier=1, ): """Make Phonon3 calculation setup.""" from phono3py import Phono3py From 777fa0734485ad43ba99fc5f9da3e2efcf729cd6 Mon Sep 17 00:00:00 2001 From: knc6 Date: Mon, 2 Dec 2024 02:40:58 -0500 Subject: [PATCH 32/37] calc fix --- alignn/ff/ff.py | 34 +++++++++++++++++++++------------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/alignn/ff/ff.py b/alignn/ff/ff.py index 75ddfb3..8e959de 100644 --- a/alignn/ff/ff.py +++ b/alignn/ff/ff.py @@ -1176,6 +1176,7 @@ def get_interface_energy( def phonons( atoms=None, + calc=None, enforce_c_size=8, line_density=5, model_path=".", @@ -1193,13 +1194,14 @@ def phonons( distance=0.2, ): """Make Phonon calculation setup.""" - calc = AlignnAtomwiseCalculator( - path=model_path, - force_mult_natoms=force_mult_natoms, - stress_wt=stress_wt, - model_filename=model_filename, - force_multiplier=force_multiplier, - ) + if calc is None: + calc = AlignnAtomwiseCalculator( + path=model_path, + force_mult_natoms=force_mult_natoms, + stress_wt=stress_wt, + model_filename=model_filename, + force_multiplier=force_multiplier, + ) from phonopy import Phonopy from phonopy.file_IO import ( @@ -1326,6 +1328,7 @@ def phonons( def phonons3( atoms=None, + calc=None, enforce_c_size=8, line_density=5, model_path=".", @@ -1339,9 +1342,12 @@ def phonons3( """Make Phonon3 calculation setup.""" from phono3py import Phono3py - calc = AlignnAtomwiseCalculator( - path=model_path, force_multiplier=force_multiplier, stress_wt=stress_wt - ) + if calc is None: + calc = AlignnAtomwiseCalculator( + path=model_path, + force_multiplier=force_multiplier, + stress_wt=stress_wt, + ) # kpoints = Kpoints().kpath(atoms, line_density=line_density) # dim = get_supercell_dims(cvn, enforce_c_size=enforce_c_size) @@ -1391,6 +1397,7 @@ def ase_phonon( N=2, path=[], jid=None, + calc=None, npoints=100, dataset="dft_3d", delta=0.01, @@ -1402,9 +1409,10 @@ def ase_phonon( force_multiplier=1, ): """Get phonon bandstructure and DOS using ASE.""" - calc = AlignnAtomwiseCalculator( - path=model_path, force_multiplier=force_multiplier - ) + if calc is None: + calc = AlignnAtomwiseCalculator( + path=model_path, force_multiplier=force_multiplier + ) # Setup crystal and EMT calculator # atoms = bulk("Al", "fcc", a=4.05) From 193c921182e0e44ab4e94dd042a09419bd59238e Mon Sep 17 00:00:00 2001 From: Kamal Choudhary Date: Mon, 2 Dec 2024 09:23:53 -0500 Subject: [PATCH 33/37] Update README.md --- README.md | 122 ++++++++++++++++++++++++++++++++---------------------- 1 file changed, 72 insertions(+), 50 deletions(-) diff --git a/README.md b/README.md index 53222cf..1504c8c 100644 --- 
a/README.md +++ b/README.md @@ -31,7 +31,8 @@ # ALIGNN & ALIGNN-FF (Introduction) The Atomistic Line Graph Neural Network (https://www.nature.com/articles/s41524-021-00650-1) introduces a new graph convolution layer that explicitly models both two and three body interactions in atomistic systems. This is achieved by composing two edge-gated graph convolution layers, the first applied to the atomistic line graph *L(g)* (representing triplet interactions) and the second applied to the atomistic bond graph *g* (representing pair interactions). -A unified force-field model, ALIGNN-FF (https://pubs.rsc.org/en/content/articlehtml/2023/dd/d2dd00096b ) was developed that can model both structurally and chemically diverse solids with any combination of 89 elements from the periodic table. +Atomistic line graph neural network-based FF (ALIGNN-FF) (https://pubs.rsc.org/en/content/articlehtml/2023/dd/d2dd00096b ) can be used to model both structurally and chemically diverse systems with any combination of 89 elements from the periodic table, especially for structural optimization. To train the ALIGNN-FF model, we have used the JARVIS-DFT dataset, which contains around 75000 materials and 4 million energy-force entries, out of which 307113 are used in the training. These models can be further finetuned, or new models can be developed from scratch on a new dataset. + ![ALIGNN layer schematic](https://github.com/usnistgov/alignn/blob/develop/alignn/tex/schematic_lg.jpg) @@ -51,29 +52,26 @@ bash Miniconda3-latest-MacOSX-x86_64.sh (for Mac) ``` Download the 32/64 bit python 3.10 miniconda exe and install (for Windows) -#### Method 1 (conda based installation) +#### Method 1 (conda based installation, recommended) Now, let's make a conda environment, say "my_alignn"; choose another name if you like: ``` -conda create --name my_alignn python=3.10 +conda create --name my_alignn python=3.10 -y conda activate my_alignn +conda install dgl=2.1.0 pytorch torchvision torchaudio pytorch-cuda -c pytorch -c nvidia conda install alignn -y ``` -#### optional GPU dependencies notes - -If you need CUDA support, it's best to install PyTorch and DGL before installing alignn to ensure that you get a CUDA-enabled version of DGL. - -``` -conda install dgl=2.1.0 pytorch torchvision torchaudio pytorch-cuda -c pytorch -c nvidia -``` -#### Method 2 (edit/debug in-place install) +#### Method 2 (GitHub based installation) You can also install a development version of alignn by cloning the repository and installing it in place with pip: ``` +conda create --name my_alignn python=3.10 -y +conda activate my_alignn +conda install dgl=2.1.0 pytorch torchvision torchaudio pytorch-cuda -c pytorch -c nvidia git clone https://github.com/usnistgov/alignn cd alignn python -m pip install -e . @@ -82,12 +80,20 @@ python -m pip install -e . #### Method 3 (using pypi): -As an alternate method, ALIGNN can also be installed using `pip` command as follows: +As an alternate method, ALIGNN can also be installed using `pip`. Note: we have received several messages regarding DGL installation issues. You can look into DGL installation [here](https://www.dgl.ai/pages/start.html).
Example for PyTorch 2.1+CUDA 12.1+Pip(Stable)+Windows: ``` +pip install -q dgl -f https://data.dgl.ai/wheels/torch-2.1/cu121/repo.html pip install alignn -pip install dgl -f https://data.dgl.ai/wheels/torch-2.1/cu121/repo.html ``` +With no GPU/CUDA: +``` +pip install -q dgl -f https://data.dgl.ai/wheels/torch-2.1/repo.html +pip install alignn +``` + +You can find installation examples in the Google Colab notebooks below. + Examples --------- @@ -105,7 +111,9 @@ Here, we provide examples for property prediction tasks, development of machine-learning force-fields (MLFF), usage of pre-trained property predictors, MLFFs, webapps, etc. -#### Dataset preparation for property prediction tasks +### Dataset preparation for property prediction tasks + + The main script to train a model is `train_alignn.py`. A user needs at least the following info to train a model: 1) `id_prop.csv` with the name of each file and the corresponding value, 2) `config_example.json`, a config file with training and hyperparameters. Users can keep their structure files in `POSCAR`, `.cif`, `.xyz` or `.pdb` files in a directory. In the examples below we will use POSCAR format files. In the same directory, there should be an `id_prop.csv` file. @@ -127,9 +135,9 @@ Now, the model is trained as follows. Please increase the `batch_size` parameter ``` train_alignn.py --root_dir "alignn/examples/sample_data" --config "alignn/examples/sample_data/config_example.json" --output_dir=temp ``` + #### Classification example -While the above example is for regression, the follwoing example shows a classification task for metal/non-metal based on the above bandgap values. We transform the dataset -into 1 or 0 based on a threshold of 0.01 eV (controlled by the parameter, `classification_threshold`) and train a similar classification model. Currently, the script allows binary classification tasks only. +While the above example is for regression, the following example shows a classification task for metal/non-metal based on the above bandgap values. We transform the dataset into 1 or 0 based on a threshold of 0.01 eV (controlled by the parameter `classification_threshold`) and train a similar classification model. Currently, the script allows binary classification tasks only. ``` train_alignn.py --root_dir "alignn/examples/sample_data" --classification_threshold 0.01 --config "alignn/examples/sample_data/config_example.json" --output_dir=temp ``` @@ -140,13 +148,34 @@ An example is given below for training formation energy per atom, bandgap and to ``` train_alignn.py --root_dir "alignn/examples/sample_data_multi_prop" --config "alignn/examples/sample_data/config_example.json" --output_dir=temp ``` -#### Automated model training -Users can try training using multiple example scripts to run multiple dataset (such as JARVIS-DFT, Materials project, QM9_JCTC etc.). Look into the [alignn/scripts/train_*.py](https://github.com/usnistgov/alignn/tree/main/alignn/scripts) folder. This is done primarily to make the trainings more automated rather than making folder/ csv files etc. + +#### Force-field training + +To train ALIGNN-FF, we can use the same `train_alignn.py` script, which uses the `atomwise_alignn` model. + +The atomwise prediction example uses a similar setup as before, but instead of `id_prop.csv` it requires an `id_prop.json` file (see the example in the sample_data_ff directory). The json contains entries such as jid, energy, forces and stress.
An example to compile vasprun.xml files into an id_prop.json is kept [here](https://colab.research.google.com/gist/knc6/5513b21f5fd83a7943509ffdf5c3608b/make_id_prop.ipynb). Note that ALIGNN-FF requires energy stored as energy per atom: + + +``` +train_alignn.py --root_dir "alignn/examples/sample_data_ff" --config "alignn/examples/sample_data_ff/config_example_atomwise.json" --output_dir="temp" +``` + + +To finetune a model, add the `--restart_model_path` tag to the command above with the path to a pretrained ALIGNN-FF model with the same model configurations. + +``` +train_alignn.py --root_dir "alignn/examples/sample_data_ff" --restart_model_path "temp/best_model.pt" --config "alignn/examples/sample_data_ff/config_example_atomwise.json" --output_dir="temp1" +``` + +Starting with version v2024.10.30, we also allow global training for multi-output along with energy (graph-wise output), forces (atomwise gradients), and charges/magnetic moments etc. (atomwise, non-gradient) properties, with or without additional fingerprints/features in the graph. See examples [here](https://github.com/usnistgov/alignn/tree/develop/alignn/examples). + +Users can also try training using multiple example scripts to run multiple datasets (such as JARVIS-DFT, Materials Project, QM9_JCTC etc.). Look into the [alignn/scripts/train_*.py](https://github.com/usnistgov/alignn/tree/main/alignn/scripts) folder. This is done primarily to make the trainings more automated rather than making folders/csv files etc. These scripts automatically download datasets from [Databases in jarvis-tools](https://jarvis-tools.readthedocs.io/en/master/databases.html) and train several models. Make sure you specify your specific queuing system details in the scripts. -Additional example trainings for [2D-exfoliation energy](https://colab.research.google.com/github/knc6/jarvis-tools-notebooks/blob/master/jarvis-tools-notebooks/alignn_jarvis_leaderboard.ipynb), [superconductor transition temperature](https://colab.research.google.com/github/knc6/jarvis-tools-notebooks/blob/master/jarvis-tools-notebooks/ALIGNN_Sc.ipynb). +Additional example trainings for property prediction tasks: [2D-exfoliation energy](https://colab.research.google.com/github/knc6/jarvis-tools-notebooks/blob/master/jarvis-tools-notebooks/alignn_jarvis_leaderboard.ipynb), [superconductor transition temperature](https://colab.research.google.com/github/knc6/jarvis-tools-notebooks/blob/master/jarvis-tools-notebooks/ALIGNN_Sc.ipynb). + +An example for training an MLFF for Silicon is provided [here](https://colab.research.google.com/github/knc6/jarvis-tools-notebooks/blob/master/jarvis-tools-notebooks/Train_ALIGNNFF_Mlearn.ipynb). It is highly recommended to get familiar with this example before developing a new model. Using pre-trained models @@ -167,6 +196,21 @@ An example of prediction formation energy per atom using JARVIS-DFT dataset trai pretrained.py --model_name jv_formation_energy_peratom_alignn --file_format poscar --file_path alignn/examples/sample_data/POSCAR-JVASP-10.vasp ``` +A pretrained ALIGNN-FF (under active development right now) can be used for predicting several properties, such as: + +``` +run_alignn_ff.py --file_path alignn/examples/sample_data/POSCAR-JVASP-10.vasp --task="unrelaxed_energy" +run_alignn_ff.py --file_path alignn/examples/sample_data/POSCAR-JVASP-10.vasp --task="optimize" +run_alignn_ff.py --file_path alignn/examples/sample_data/POSCAR-JVASP-10.vasp --task="ev_curve" +``` + +To know about other tasks, type:
+ +``` +run_alignn_ff.py -h +``` + +Several supporting scripts for structure optimization, equation of states, phonon and related calculations are provided in the repo as well. If you need further assistance for a particular task, feel free to raise a GitHub issue. Web-app @@ -174,15 +218,17 @@ A basic web-app for direct prediction is available at [JARVIS-ALIGNN app](https://jarvis.nist.gov/jalignn/). Given an atomistic structure in POSCAR format, it predicts formation energy, total energy per atom and bandgap using data trained on the JARVIS-DFT dataset. +Similarly, a web-app for [ALIGNN-FF](https://jarvis.nist.gov/jalignnff/) for structure optimization is also available. + ![JARVIS-ALIGNN](https://github.com/usnistgov/alignn/blob/develop/alignn/tex/jalignn.PNG) -ALIGNN-FF +ALIGNN-FF ASE Calculator ------------------------- -Atomisitic line graph neural network-based FF (ALIGNN-FF) can be used to model both structurally and chemically diverse systems with any combination of 89 elements from the periodic table. To train the ALIGNN-FF model, we have used the JARVIS-DFT dataset which contains around 75000 materials and 4 million energy-force entries, out of which 307113 are used in the training. These models can be further finetuned, or new models can be developed from scratch on a new dataset. + [ASE calculator](https://wiki.fysik.dtu.dk/ase/ase/calculators/calculators.html) provides an interface to various codes. An example for ALIGNN-FF is given below. Note that there are multiple pretrained ALIGNN-FF models available; here we use the `default_path` model. As more accurate models are developed, they will be made available as well: @@ -226,36 +272,10 @@ plt.savefig('Si_JVASP-1002.png') plt.close() ``` -To train ALIGNN-FF use `train_alignn.py` script which uses `atomwise_alignn` model: - -AtomWise prediction example which looks for similar setup as before but unstead of `id_prop.csv`, it requires `id_prop.json` file (see example in the sample_data_ff directory). An example to compile vasprun.xml files into a id_prop.json is kept [here](https://colab.research.google.com/gist/knc6/5513b21f5fd83a7943509ffdf5c3608b/make_id_prop.ipynb). Note ALIGNN-FF requires energy stored as energy per atom: - - -``` -train_alignn.py --root_dir "alignn/examples/sample_data_ff" --config "alignn/examples/sample_data_ff/config_example_atomwise.json" --output_dir=temp -``` - - -To finetune model, use `--restart_model_path` tag as well in the above with the path of a pretrained ALIGNN-FF model with same model confurations. - -An example for training MLFF for silicon is provided [here](https://colab.research.google.com/github/knc6/jarvis-tools-notebooks/blob/master/jarvis-tools-notebooks/Train_ALIGNNFF_Mlearn.ipynb). It is highly recommeded to get familiar with this example before developing a new model. Note: new model configs such as `lg_on_fly` and `add_reverse_forces` should be defaulted to True for newer versions. For MD runs, `use_cutoff_function` is recommended. - -A pretrained ALIGNN-FF (under active development right now) can be used for predicting several properties, such as: - -``` -run_alignn_ff.py --file_path alignn/examples/sample_data/POSCAR-JVASP-10.vasp --task="unrelaxed_energy" -run_alignn_ff.py --file_path alignn/examples/sample_data/POSCAR-JVASP-10.vasp --task="optimize" -run_alignn_ff.py --file_path alignn/examples/sample_data/POSCAR-JVASP-10.vasp --task="ev_curve" -``` -To know about other tasks, type.
-``` -run_alignn_ff.py -h -``` -Several supporting scripts for stucture optimization, equation of states, phonon and related calculations are provided in the repo as well. If you need further assistance for a particular task, feel free to raise an GitHus issue. @@ -412,7 +432,7 @@ References 7) [Rapid Prediction of Phonon Structure and Properties using an Atomistic Line Graph Neural Network (ALIGNN)](https://journals.aps.org/prmaterials/abstract/10.1103/PhysRevMaterials.7.023803) 8) [Unified graph neural network force-field for the periodic table](https://pubs.rsc.org/en/content/articlehtml/2023/dd/d2dd00096b) 9) [Large Scale Benchmark of Materials Design Methods](https://www.nature.com/articles/s41524-024-01259-w) -10) [CHIPS-FF](https://github.com/usnistgov/chipsff) +10) [CHIPS-FF: Benchmarking universal force-fields](https://github.com/usnistgov/chipsff) Please see detailed publications list [here](https://jarvis-tools.readthedocs.io/en/master/publications.html). @@ -433,7 +453,9 @@ Please report bugs as Github issues (https://github.com/usnistgov/alignn/issues) Funding support -------------------- -NIST-MGI (https://www.nist.gov/mgi). +NIST-MGI (https://www.nist.gov/mgi) + +NIST-CHIPS (https://www.nist.gov/chips) Code of conduct -------------------- From 48f934e6b5ae91551fb4a139cdf56d9b4b9fb05f Mon Sep 17 00:00:00 2001 From: Kamal Choudhary Date: Mon, 2 Dec 2024 09:26:00 -0500 Subject: [PATCH 34/37] Update README.md --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 1504c8c..50b6628 100644 --- a/README.md +++ b/README.md @@ -432,7 +432,8 @@ References 7) [Rapid Prediction of Phonon Structure and Properties using an Atomistic Line Graph Neural Network (ALIGNN)](https://journals.aps.org/prmaterials/abstract/10.1103/PhysRevMaterials.7.023803) 8) [Unified graph neural network force-field for the periodic table](https://pubs.rsc.org/en/content/articlehtml/2023/dd/d2dd00096b) 9) [Large Scale Benchmark of Materials Design Methods](https://www.nature.com/articles/s41524-024-01259-w) -10) [CHIPS-FF: Benchmarking universal force-fields](https://github.com/usnistgov/chipsff) +10) [Prediction of Magnetic Properties in van der Waals Magnets using Graph Neural Networks](https://doi.org/10.1103/PhysRevMaterials.8.114002) +11) [CHIPS-FF: Benchmarking universal force-fields](https://github.com/usnistgov/chipsff) Please see detailed publications list [here](https://jarvis-tools.readthedocs.io/en/master/publications.html). 
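Taken together, the calculator rework and the new calc= pass-through in phonons()/ase_phonon() from the commits above let a single model load serve several workflows. A minimal sketch under those assumptions (phonopy installed, a cached pretrained model, illustrative structure and supercell size):

```
from ase.build import bulk
from jarvis.core.atoms import ase_to_atoms
from alignn.ff.ff import AlignnAtomwiseCalculator, default_path, phonons

calc = AlignnAtomwiseCalculator(path=default_path())  # load the model once

si = bulk("Si", "diamond", a=5.43)
si.calc = calc
print("energy (eV):", si.get_potential_energy())
print("max |force| (eV/A):", abs(si.get_forces()).max())

# The calc= argument avoids re-loading the network from disk per call:
ph = phonons(atoms=ase_to_atoms(si), calc=calc, dim=[2, 2, 2])
```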
From bf6e9234ccde8e182bfc115b6f089a82156b8e97 Mon Sep 17 00:00:00 2001
From: Kamal Choudhary
Date: Mon, 2 Dec 2024 09:35:35 -0500
Subject: [PATCH 35/37] Update ff.py

---
 alignn/ff/ff.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/alignn/ff/ff.py b/alignn/ff/ff.py
index 8e959de..69a969f 100644
--- a/alignn/ff/ff.py
+++ b/alignn/ff/ff.py
@@ -132,7 +132,7 @@ def default_path():
     # dpath = get_figshare_model_ff(model_name="v8.29.2024_dft_3d")
     # dpath = get_figshare_model_ff(model_name="alignnff_wt10")
     # dpath = get_figshare_model_ff(model_name="alignnff_fmult")
-    print("model_path", dpath)
+    # print("model_path", dpath)
     return dpath


From ed60958d43edcab6bee57ff745653231f4db0d88 Mon Sep 17 00:00:00 2001
From: Kamal Choudhary
Date: Mon, 2 Dec 2024 09:42:59 -0500
Subject: [PATCH 36/37] Update README.md

---
 README.md | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/README.md b/README.md
index 50b6628..896ed33 100644
--- a/README.md
+++ b/README.md
@@ -167,9 +167,17 @@ To finetune model, use `--restart_model_path` tag as well in the above with the

train_alignn.py --root_dir "alignn/examples/sample_data_ff" --restart_model_path "temp/best_model.pt" --config "alignn/examples/sample_data_ff/config_example_atomwise.json" --output_dir="temp1"
```

-Starting version v2024.10.30, we also allow global training for multi-output along with energy (graph wise output), forces (atomwise gradients), charges/magnetic moments etc. (atomwise but non-gradients) properties with or without additional fingerprints/features in graph. See examples [here](https://github.com/usnistgov/alignn/tree/develop/alignn/examples).
+Starting with version v2024.10.30, we also allow global multi-output training: energy (graph-wise output), forces (atom-wise gradients), and charges/magnetic moments etc. (atom-wise, non-gradient) properties, with or without additional fingerprints/features in the graph. See examples [here](https://github.com/usnistgov/alignn/tree/main/alignn/examples).

-Users can also try training using multiple example scripts to run multiple dataset (such as JARVIS-DFT, Materials project, QM9_JCTC etc.). Look into the [alignn/scripts/train_*.py](https://github.com/usnistgov/alignn/tree/main/alignn/scripts) folder. This is done primarily to make the trainings more automated rather than making folder/ csv files etc.
+Multi-GPU training is supported via `DistributedDataParallel` using the `torchrun` command. This feature is not thoroughly tested yet.
+Example:
+
+```
+torchrun --nproc_per_node=4 train_alignn.py --root_dir DataDir --config config.json --output_dir temp
+```
+For multi-GPU training, make sure your SLURM/PBS submission script is set up correctly, e.g., with `#SBATCH -n 4`, `#SBATCH -N 1`, `#SBATCH --gres=gpu:4` etc.; a minimal submission script is sketched below.
+
+High-throughput-like training: Users can also try training on multiple datasets (such as JARVIS-DFT, Materials Project, QM9_JCTC etc.) using the example scripts in the [alignn/scripts/train_*.py](https://github.com/usnistgov/alignn/tree/main/alignn/scripts) folder. This is done primarily to make the trainings more automated rather than creating folders/CSV files by hand. These scripts automatically download datasets from [Databases in jarvis-tools](https://jarvis-tools.readthedocs.io/en/master/databases.html) and train several models. Make sure you specify your specific queuing-system details in the scripts.
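As a concrete illustration of the multi-GPU setup above, here is a minimal single-node SLURM submission sketch. It only combines the `#SBATCH` flags and the `torchrun` command already shown; the job name, walltime, and environment-activation line are placeholders to adapt to your own cluster:

```
#!/bin/bash
#SBATCH -N 1
#SBATCH -n 4
#SBATCH --gres=gpu:4
#SBATCH -J alignn_train          # job name (placeholder)
#SBATCH -t 24:00:00              # walltime (placeholder)

# Activate the environment where alignn is installed (placeholder name)
source activate my_alignn

# One worker process per GPU on this node
torchrun --nproc_per_node=4 train_alignn.py \
    --root_dir DataDir \
    --config config.json \
    --output_dir temp
```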
@@ -213,7 +221,7 @@ run_alignn_ff.py -h

Several supporting scripts for structure optimization, equations of state, phonon and related calculations are provided in the repo as well. If you need further assistance with a particular task, feel free to raise a GitHub issue.

-Web-app
+Web-apps
------------

A basic web-app for direct prediction is available at [JARVIS-ALIGNN app](https://jarvis.nist.gov/jalignn/). Given an atomistic structure in POSCAR format, it predicts the formation energy, total energy per atom and bandgap using models trained on the JARVIS-DFT dataset.

From c17395ca0c36714d49e5e883d78b79e9740c6607 Mon Sep 17 00:00:00 2001
From: Kamal Choudhary
Date: Mon, 2 Dec 2024 13:33:46 -0500
Subject: [PATCH 37/37] Update README.md

---
 README.md | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index 896ed33..027d021 100644
--- a/README.md
+++ b/README.md
@@ -101,9 +101,11 @@ Examples

| Notebooks | Google Colab | Descriptions |
| ------------------------------------------------ | ------------------------------------------------ | ------------------------------------------------ |
-| [Regression model](https://colab.research.google.com/github/knc6/jarvis-tools-notebooks/blob/master/jarvis-tools-notebooks/alignn_jarvis_leaderboard.ipynb) | [![Open in Google Colab]](https://colab.research.google.com/github/knc6/jarvis-tools-notebooks/blob/master/jarvis-tools-notebooks/alignn_jarvis_leaderboard.ipynb) | Examples for developing single output regression model for exfoliation energies of 2D materials. |
-| [MLFF](https://colab.research.google.com/github/knc6/jarvis-tools-notebooks/blob/master/jarvis-tools-notebooks/Train_ALIGNNFF_Mlearn.ipynb) | [![Open in Google Colab]](https://colab.research.google.com/github/knc6/jarvis-tools-notebooks/blob/master/jarvis-tools-notebooks/Train_ALIGNNFF_Mlearn.ipynb) | Examples of training a machine learning force field for Silicon. |
+| [Regression task (graph-wise prediction)](https://colab.research.google.com/github/knc6/jarvis-tools-notebooks/blob/master/jarvis-tools-notebooks/alignn_jarvis_leaderboard.ipynb) | [![Open in Google Colab]](https://colab.research.google.com/github/knc6/jarvis-tools-notebooks/blob/master/jarvis-tools-notebooks/alignn_jarvis_leaderboard.ipynb) | Examples for developing a single-output regression model for exfoliation energies of 2D materials. |
+| [Machine learning force-field training from scratch](https://colab.research.google.com/github/knc6/jarvis-tools-notebooks/blob/master/jarvis-tools-notebooks/Train_ALIGNNFF_Mlearn.ipynb) | [![Open in Google Colab]](https://colab.research.google.com/github/knc6/jarvis-tools-notebooks/blob/master/jarvis-tools-notebooks/Train_ALIGNNFF_Mlearn.ipynb) | Examples of training a machine learning force field for Silicon. |
| [ALIGNN-FF Relaxer+EV_curve+Phonons+Interface gamma_surface+Interface separation](https://colab.research.google.com/github/knc6/jarvis-tools-notebooks/blob/master/jarvis-tools-notebooks/ALIGNN_Structure_Relaxation_Phonons_Interface.ipynb) | [![Open in Google Colab]](https://colab.research.google.com/github/knc6/jarvis-tools-notebooks/blob/master/jarvis-tools-notebooks/ALIGNN_Structure_Relaxation_Phonons_Interface.ipynb) | Examples of using pre-trained ALIGNN-FF force-field model. |
+| [Scaling/timing comparison](https://colab.research.google.com/github/knc6/jarvis-tools-notebooks/blob/master/jarvis-tools-notebooks/Timing_uMLFF.ipynb) | [![Open in Google Colab]](https://colab.research.google.com/github/knc6/jarvis-tools-notebooks/blob/master/jarvis-tools-notebooks/Timing_uMLFF.ipynb) | Examples of analyzing scaling/timing behavior. |
+| [Running MD for Melt-Quench](https://colab.research.google.com/github/knc6/jarvis-tools-notebooks/blob/master/jarvis-tools-notebooks/Fast_Melt_Quench.ipynb) | [![Open in Google Colab]](https://colab.research.google.com/github/knc6/jarvis-tools-notebooks/blob/master/jarvis-tools-notebooks/Fast_Melt_Quench.ipynb) | Examples of making amorphous structures with molecular dynamics. |
| [Miscellaneous tasks](https://colab.research.google.com/github/knc6/jarvis-tools-notebooks/blob/master/jarvis-tools-notebooks/Training_ALIGNN_model_example.ipynb) | [![Open in Google Colab]](https://colab.research.google.com/github/knc6/jarvis-tools-notebooks/blob/master/jarvis-tools-notebooks/Training_ALIGNN_model_example.ipynb) | Examples for developing single-output (such as formation energy, bandgaps) or multi-output (such as phonon DOS, electron DOS) regression or classification (such as metal vs. non-metal) models, using several pretrained models. |
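Beyond the notebooks above, a quick way to sanity-check a pretrained force field is through the ASE calculator interface described earlier. A minimal sketch, assuming `alignn` and `ase` are installed and using the `default_path()` helper from `alignn/ff/ff.py` shown in the patch above (the diamond-silicon cell is only an illustration, not a prescribed test case):

```
from ase.build import bulk

# AlignnAtomwiseCalculator and default_path are provided by alignn.ff.ff
from alignn.ff.ff import AlignnAtomwiseCalculator, default_path

# Load the default pretrained ALIGNN-FF model
calc = AlignnAtomwiseCalculator(path=default_path())

# Simple diamond-Si cell as a stand-in for any ASE Atoms object
atoms = bulk("Si", "diamond", a=5.43)
atoms.calc = calc

print("Energy (eV):", atoms.get_potential_energy())
print("Max force (eV/A):", abs(atoms.get_forces()).max())
```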