From 3c0fade874bb5a8189746c2484ff253af988cc85 Mon Sep 17 00:00:00 2001 From: rayg1234 <7001989+rayg1234@users.noreply.github.com> Date: Tue, 30 Jul 2024 10:34:22 -0700 Subject: [PATCH 1/8] clone so3 embedding object (#781) --- src/fairchem/core/models/equiformer_v2/transformer_block.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fairchem/core/models/equiformer_v2/transformer_block.py b/src/fairchem/core/models/equiformer_v2/transformer_block.py index bdb97ea468..e7669d301b 100755 --- a/src/fairchem/core/models/equiformer_v2/transformer_block.py +++ b/src/fairchem/core/models/equiformer_v2/transformer_block.py @@ -652,7 +652,7 @@ def forward( batch, # for GraphDropPath node_offset: int = 0, ): - output_embedding = x + output_embedding = x.clone() x_res = output_embedding.embedding output_embedding.embedding = self.norm_1(output_embedding.embedding) From 434b956d12cba04ba132d63b5d583e511dfedda0 Mon Sep 17 00:00:00 2001 From: Daniel Levine Date: Tue, 30 Jul 2024 15:46:04 -0700 Subject: [PATCH 2/8] [BE] Remove large files from fairchem and add references to new location as needed (#761) * Remove large files from fairchem and add references to new location as needed * ruff differs from isort specification... * add fine-tuning supporting-info since it is over 2MB * add unittest * linting * typo * import * Use better function name and re-use fairchem_root function --------- Co-authored-by: Muhammed Shuaibi <45150244+mshuaibii@users.noreply.github.com> --- docs/core/datasets/oc20dense.md | 2 +- docs/tutorials/NRR/NRR_example.md | 2 +- src/fairchem/applications/AdsorbML/README.md | 2 +- .../2023_neurips_challenge/challenge_eval.py | 7 + .../core/scripts/download_large_files.py | 76 ++++ src/fairchem/data/oc/README.md | 3 +- src/fairchem/data/oc/core/bulk.py | 4 + src/fairchem/data/oc/databases/update.py | 17 +- src/fairchem/data/odac/README.md | 4 +- src/fairchem/data/odac/force_field/README.md | 2 +- .../promising_mof_energies/energy.py | 358 +++++++++--------- .../promising_mof_features/readme | 4 +- .../applications/cattsunami/tests/conftest.py | 15 +- tests/core/test_download_large_files.py | 16 + 14 files changed, 319 insertions(+), 193 deletions(-) create mode 100644 src/fairchem/core/scripts/download_large_files.py create mode 100644 tests/core/test_download_large_files.py diff --git a/docs/core/datasets/oc20dense.md b/docs/core/datasets/oc20dense.md index fb07a09ad0..64639889cc 100644 --- a/docs/core/datasets/oc20dense.md +++ b/docs/core/datasets/oc20dense.md @@ -11,7 +11,7 @@ The OC20Dense dataset is a validation dataset which was used to assess model per |ASE Trajectories |29G |112G | [ee937e5290f8f720c914dc9a56e0281f](https://dl.fbaipublicfiles.com/opencatalystproject/data/adsorbml/oc20_dense_trajectories.tar.gz) | The following files are also provided to be used for evaluation and general information: -* `oc20dense_mapping.pkl` : Mapping of the LMDB `sid` to general metadata information - +* `oc20dense_mapping.pkl` : Mapping of the LMDB `sid` to general metadata information. If this file is not present, run the command `python src/fairchem/core/scripts/download_large_files.py adsorbml` from the root of the fairchem repo to download it. - * `system_id`: Unique system identifier for an adsorbate, bulk, surface combination. * `config_id`: Unique configuration identifier, where `rand` and `heur` correspond to random and heuristic initial configurations, respectively. * `mpid`: Materials Project bulk identifier. 
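As a quick illustration of how this mapping file is usually consumed, the sketch below loads it and prints the metadata for one entry. It assumes the pickle deserializes to a dict keyed by the LMDB `sid`, with each value carrying the fields listed above; the exact layout is an assumption here, not something specified by this patch.

```python
import pickle

# Assumed layout: {sid: {"system_id": ..., "config_id": ..., "mpid": ..., ...}, ...}
with open("oc20dense_mapping.pkl", "rb") as fp:
    mapping = pickle.load(fp)

sid = next(iter(mapping))  # an arbitrary LMDB sid
entry = mapping[sid]
print(sid, entry["system_id"], entry["config_id"], entry["mpid"])
```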
diff --git a/docs/tutorials/NRR/NRR_example.md b/docs/tutorials/NRR/NRR_example.md index b69e078d1a..cc5ab6d074 100644 --- a/docs/tutorials/NRR/NRR_example.md +++ b/docs/tutorials/NRR/NRR_example.md @@ -62,7 +62,7 @@ To do this, we will enumerate adsorbate-slab configurations and run ML relaxatio +++ -Be sure to set the path in `fairchem/data/oc/configs/paths.py` to point to the correct place or pass the paths as an argument. The database pickles can be found in `fairchem/data/oc/databases/pkls`. We will show one explicitly here as an example and then run all of them in an automated fashion for brevity. +Be sure to set the path in `fairchem/data/oc/configs/paths.py` to point to the correct place or pass the paths as an argument. The database pickles can be found in `fairchem/data/oc/databases/pkls` (some pkl files are only downloaded by running the command `python src/fairchem/core/scripts/download_large_files.py oc` from the root of the fairchem repo). We will show one explicitly here as an example and then run all of them in an automated fashion for brevity. ```{code-cell} ipython3 import fairchem.data.oc diff --git a/src/fairchem/applications/AdsorbML/README.md b/src/fairchem/applications/AdsorbML/README.md index ca5be57379..700c06b67c 100644 --- a/src/fairchem/applications/AdsorbML/README.md +++ b/src/fairchem/applications/AdsorbML/README.md @@ -21,7 +21,7 @@ NOTE - ASE trajectories exclude systems that were not converged or had invalid c |ASE Trajectories |29G |112G | [ee937e5290f8f720c914dc9a56e0281f](https://dl.fbaipublicfiles.com/opencatalystproject/data/adsorbml/oc20_dense_trajectories.tar.gz) | The following files are also provided to be used for evaluation and general information: -* `oc20dense_mapping.pkl` : Mapping of the LMDB `sid` to general metadata information - +* `oc20dense_mapping.pkl` : Mapping of the LMDB `sid` to general metadata information. If this file is not present, run the command `python src/fairchem/core/scripts/download_large_files.py adsorbml` from the root of the fairchem repo to download it. - * `system_id`: Unique system identifier for an adsorbate, bulk, surface combination. * `config_id`: Unique configuration identifier, where `rand` and `heur` correspond to random and heuristic initial configurations, respectively. * `mpid`: Materials Project bulk identifier. 
diff --git a/src/fairchem/applications/AdsorbML/adsorbml/2023_neurips_challenge/challenge_eval.py b/src/fairchem/applications/AdsorbML/adsorbml/2023_neurips_challenge/challenge_eval.py index d7e801fe0b..01c492bbae 100644 --- a/src/fairchem/applications/AdsorbML/adsorbml/2023_neurips_challenge/challenge_eval.py +++ b/src/fairchem/applications/AdsorbML/adsorbml/2023_neurips_challenge/challenge_eval.py @@ -7,6 +7,8 @@ import numpy as np +from fairchem.core.scripts import download_large_files + def is_successful(best_pred_energy, best_dft_energy, SUCCESS_THRESHOLD=0.1): """ @@ -161,6 +163,11 @@ def main(): # targets and metadata are expected to be in # the same directory as this script + if ( + not Path(__file__).with_name("oc20dense_val_targets.pkl").exists() + or not Path(__file__).with_name("ml_relaxed_dft_targets.pkl").exists() + ): + download_large_files.download_file_group("adsorbml") targets = pickle.load( open(Path(__file__).with_name("oc20dense_val_targets.pkl"), "rb") ) diff --git a/src/fairchem/core/scripts/download_large_files.py b/src/fairchem/core/scripts/download_large_files.py new file mode 100644 index 0000000000..f79fa21561 --- /dev/null +++ b/src/fairchem/core/scripts/download_large_files.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +import argparse +from pathlib import Path +from urllib.request import urlretrieve + +from fairchem.core.common.tutorial_utils import fairchem_root + +S3_ROOT = "https://dl.fbaipublicfiles.com/opencatalystproject/data/large_files/" + +FILE_GROUPS = { + "odac": [ + Path("configs/odac/s2ef/scaling_factors/painn.pt"), + Path("src/fairchem/data/odac/force_field/data_w_oms.json"), + Path( + "src/fairchem/data/odac/promising_mof/promising_mof_features/JmolData.jar" + ), + Path( + "src/fairchem/data/odac/promising_mof/promising_mof_energies/adsorption_energy.txt" + ), + Path("src/fairchem/data/odac/supercell_info.csv"), + ], + "oc": [Path("src/fairchem/data/oc/databases/pkls/bulks.pkl")], + "adsorbml": [ + Path( + "src/fairchem/applications/AdsorbML/adsorbml/2023_neurips_challenge/oc20dense_mapping.pkl" + ), + Path( + "src/fairchem/applications/AdsorbML/adsorbml/2023_neurips_challenge/ml_relaxed_dft_targets.pkl" + ), + ], + "cattsunami": [ + Path("tests/applications/cattsunami/tests/autoframe_inputs_dissociation.pkl"), + Path("tests/applications/cattsunami/tests/autoframe_inputs_transfer.pkl"), + ], + "docs": [ + Path("docs/tutorials/NRR/NRR_example_bulks.pkl"), + Path("docs/core/fine-tuning/supporting-information.json"), + ], +} + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "file_group", + type=str, + help="Group of files to download", + default="ALL", + choices=["ALL", *list(FILE_GROUPS)], + ) + return parser.parse_args() + + +def download_file_group(file_group): + if file_group in FILE_GROUPS: + files_to_download = FILE_GROUPS[file_group] + elif file_group == "ALL": + files_to_download = [item for group in FILE_GROUPS.values() for item in group] + else: + raise ValueError( + f'Requested file group {file_group} not recognized. 
Please select one of {["ALL", *list(FILE_GROUPS)]}' + ) + + fc_root = fairchem_root().parents[1] + for file in files_to_download: + if not (fc_root / file).exists(): + print(f"Downloading {file}...") + urlretrieve(S3_ROOT + file.name, fc_root / file) + else: + print(f"{file} already exists") + + +if __name__ == "__main__": + args = parse_args() + download_file_group(args.file_group) diff --git a/src/fairchem/data/oc/README.md b/src/fairchem/data/oc/README.md index 20205d1d5a..06aba8887f 100644 --- a/src/fairchem/data/oc/README.md +++ b/src/fairchem/data/oc/README.md @@ -9,6 +9,7 @@ This repository hosts the adsorbate-catalyst input generation workflow used in t To install just run in your favorite environment with python >= 3.9 * `pip install fairchem-data-oc` +* `python src/fairchem/core/scripts/download_large_files.py oc` ## Workflow @@ -155,7 +156,7 @@ python structure_generator.py \ ### Bulks -A database of bulk materials taken from existing databases (i.e. Materials Project) and relaxed with consistent RPBE settings may be found in `ocdata/databases/pkls/bulks.pkl`. To preview what bulks are available, view the corresponding mapping between indices and bulks (bulk id and composition): https://dl.fbaipublicfiles.com/opencatalystproject/data/input_generation/mapping_bulks_2021sep20.txt +A database of bulk materials taken from existing databases (i.e. Materials Project) and relaxed with consistent RPBE settings may be found in `databases/pkls/bulks.pkl` (if not, run the command `python src/fairchem/core/scripts/download_large_files.py oc` from the root of the fairchem repo). To preview what bulks are available, view the corresponding mapping between indices and bulks (bulk id and composition): https://dl.fbaipublicfiles.com/opencatalystproject/data/input_generation/mapping_bulks_2021sep20.txt ### Adsorbates diff --git a/src/fairchem/data/oc/core/bulk.py b/src/fairchem/data/oc/core/bulk.py index 9568ad3622..6710b43880 100644 --- a/src/fairchem/data/oc/core/bulk.py +++ b/src/fairchem/data/oc/core/bulk.py @@ -9,6 +9,8 @@ from fairchem.data.oc.core.slab import Slab from fairchem.data.oc.databases.pkls import BULK_PKL_PATH +from fairchem.core.scripts import download_large_files + if TYPE_CHECKING: import ase @@ -51,6 +53,8 @@ def __init__( self.src_id = None else: if bulk_db is None: + if bulk_db_path == BULK_PKL_PATH and not os.path.exists(BULK_PKL_PATH): + download_large_files.download_file_group("oc") with open(bulk_db_path, "rb") as fp: bulk_db = pickle.load(fp) diff --git a/src/fairchem/data/oc/databases/update.py b/src/fairchem/data/oc/databases/update.py index f9ca1f6452..bab75709c3 100644 --- a/src/fairchem/data/oc/databases/update.py +++ b/src/fairchem/data/oc/databases/update.py @@ -6,12 +6,15 @@ from __future__ import annotations import pickle +from pathlib import Path import ase.io from ase.atoms import Atoms from ase.calculators.singlepoint import SinglePointCalculator as SPC from tqdm import tqdm +from fairchem.core.scripts import download_large_files + # Monkey patch fix def pbc_patch(self): @@ -29,7 +32,7 @@ def set_pbc_patch(self, pbc): def update_pkls(): with open( - "ocdata/databases/pkls/adsorbates.pkl", + "oc/databases/pkls/adsorbates.pkl", "rb", ) as fp: data = pickle.load(fp) @@ -38,13 +41,15 @@ def update_pkls(): pbc = data[idx][0].cell._pbc data[idx][0]._pbc = pbc with open( - "ocdata/databases/pkls/adsorbates_new.pkl", + "oc/databases/pkls/adsorbates_new.pkl", "wb", ) as fp: pickle.dump(data, fp) + if not Path("oc/databases/pkls/bulks.pkl").exists(): + 
download_large_files.download_file_group("oc") with open( - "ocdata/databases/pkls/bulks.pkl", + "oc/databases/pkls/bulks.pkl", "rb", ) as fp: data = pickle.load(fp)
@@ -64,7 +69,7 @@ def update_pkls(): bulks.append((atoms, bulk_id)) with open( - "ocdata/databases/pkls/bulks_new.pkl", + "oc/databases/pkls/bulks_new.pkl", "wb", ) as f: pickle.dump(bulks, f)
@@ -73,7 +78,7 @@ def update_dbs(): for db_name in ["adsorbates", "bulks"]: db = ase.io.read( - f"ocdata/databases/ase/{db_name}.db", + f"oc/databases/ase/{db_name}.db", ":", ) new_data = []
@@ -90,7 +95,7 @@ def update_dbs(): new_data.append(atoms) ase.io.write( - f"ocdata/databases/ase/{db_name}_new.db", + f"oc/databases/ase/{db_name}_new.db", new_data, )
diff --git a/src/fairchem/data/odac/README.md b/src/fairchem/data/odac/README.md
index d6529edd74..f46ababd05 100644
--- a/src/fairchem/data/odac/README.md
+++ b/src/fairchem/data/odac/README.md
@@ -4,9 +4,11 @@ To download the ODAC23 dataset, please see the links [here](https://fair-chem.gi Pre-trained ML models and configs are available [here](https://fair-chem.github.io/core/model_checkpoints.html#open-direct-air-capture-2023-odac23). +Large ODAC files can be downloaded by running the command `python src/fairchem/core/scripts/download_large_files.py odac` from the root of the fairchem repo. + This repository contains the list of [promising MOFs](https://github.com/FAIR-Chem/fairchem/tree/main/src/fairchem/data/odac/promising_mof) discovered in the ODAC23 paper, as well as details of the [classical force field calculations](https://github.com/FAIR-Chem/fairchem/tree/main/src/fairchem/data/odac/force_field). -Information about supercells can be found in [supercell_info.csv](https://github.com/FAIR-Chem/fairchem/blob/main/src/fairchem/data/odac/supercell_info.csv) for each example. +Information about supercells can be found in [supercell_info.csv](https://dl.fbaipublicfiles.com/opencatalystproject/data/large_files/supercell_info.csv) for each example (this file is downloaded to the local repo only when the above script is run). ## Citing
diff --git a/src/fairchem/data/odac/force_field/README.md b/src/fairchem/data/odac/force_field/README.md
index debe565bda..25714603f3 100644
--- a/src/fairchem/data/odac/force_field/README.md
+++ b/src/fairchem/data/odac/force_field/README.md
@@ -2,7 +2,7 @@ This folder contains data and scripts related to the classical FF analysis performed in this work. -- The `data_w_oms.json` file contains all successful FF interaction energy calculations with both system information and DFT-computed interaction energies. Calculations were performed across the in-domain training, validation, and test sets. +- The `data_w_oms.json` file contains all successful FF interaction energy calculations with both system information and DFT-computed interaction energies. Calculations were performed across the in-domain training, validation, and test sets. If this file is not present, run the command `python src/fairchem/core/scripts/download_large_files.py odac` from the root of the fairchem repo to download it. - The `data_w_ml.json` file contains the same information for systems with successful ML interaction energy predictions. Only systems in the in-domain test set are included here. - The `FF_analysis.py` script performs the error calculations discussed in the paper and generates the four panels of Figure 5. All of the data used in this analysis is contained in 'data_w_oms.json" for reproducibility.
- The `FF_calcs` folder contains example calculations for classical FF interaction energy predictions.
diff --git a/src/fairchem/data/odac/promising_mof/promising_mof_energies/energy.py b/src/fairchem/data/odac/promising_mof/promising_mof_energies/energy.py
index 6a9d37924e..547806cc01 100644
--- a/src/fairchem/data/odac/promising_mof/promising_mof_energies/energy.py
+++ b/src/fairchem/data/odac/promising_mof/promising_mof_energies/energy.py
@@ -1,8 +1,14 @@ from __future__ import annotations +import os + import matplotlib.pyplot as plt import pandas as pd +from fairchem.core.scripts import download_large_files + +if not os.path.exists("adsorption_energy.txt"): + download_large_files.download_file_group("odac") raw_ads_energy_data = pd.read_csv("adsorption_energy.txt", header=None, sep=" ") complete_data = pd.DataFrame( index=range(raw_ads_energy_data.shape[0]),
@@ -170,12 +176,12 @@ current_lowest_energy < lowest_energy_data_co2.loc[index_this_case, "ads_energy_ev"] ): - lowest_energy_data_co2.loc[index_this_case, "ads_energy_ev"] = ( - current_lowest_energy - ) - lowest_energy_data_co2.loc[index_this_case, "configuration_index"] = ( - current_configuration_index - ) + lowest_energy_data_co2.loc[ + index_this_case, "ads_energy_ev" + ] = current_lowest_energy + lowest_energy_data_co2.loc[ + index_this_case, "configuration_index" + ] = current_configuration_index lowest_energy_data_co2.loc[index_this_case, "Name"] = current_name
@@ -212,12 +218,12 @@ current_lowest_energy < lowest_energy_data_h2o.loc[index_this_case, "ads_energy_ev"] ): - lowest_energy_data_h2o.loc[index_this_case, "ads_energy_ev"] = ( - current_lowest_energy - ) - lowest_energy_data_h2o.loc[index_this_case, "configuration_index"] = ( - current_configuration_index - ) + lowest_energy_data_h2o.loc[ + index_this_case, "ads_energy_ev" + ] = current_lowest_energy + lowest_energy_data_h2o.loc[ + index_this_case, "configuration_index" + ] = current_configuration_index lowest_energy_data_h2o.loc[index_this_case, "Name"] = current_name lowest_energy_data_co_ads = pd.DataFrame(
@@ -254,12 +260,12 @@ current_lowest_energy < lowest_energy_data_co_ads.loc[index_this_case, "ads_energy_ev"] ): - lowest_energy_data_co_ads.loc[index_this_case, "ads_energy_ev"] = ( - current_lowest_energy - ) - lowest_energy_data_co_ads.loc[index_this_case, "configuration_index"] = ( - current_configuration_index - ) + lowest_energy_data_co_ads.loc[ + index_this_case, "ads_energy_ev" + ] = current_lowest_energy + lowest_energy_data_co_ads.loc[ + index_this_case, "configuration_index" + ] = current_configuration_index lowest_energy_data_co_ads.loc[index_this_case, "Name"] = current_name
@@ -298,12 +304,12 @@ current_lowest_energy < lowest_energy_data_co_ads_2.loc[index_this_case, "ads_energy_ev"] ): - lowest_energy_data_co_ads_2.loc[index_this_case, "ads_energy_ev"] = ( - current_lowest_energy - ) - lowest_energy_data_co_ads_2.loc[index_this_case, "configuration_index"] = ( - current_configuration_index - ) + lowest_energy_data_co_ads_2.loc[ + index_this_case, "ads_energy_ev" + ] = current_lowest_energy + lowest_energy_data_co_ads_2.loc[ + index_this_case, "configuration_index" + ] = current_configuration_index lowest_energy_data_co_ads_2.loc[index_this_case, "Name"] = current_name
@@ -439,9 +445,9 @@ current_lowest_energy < lowest_energy_data_co2_defective.loc[index_this_case, "ads_energy_ev"] ): - lowest_energy_data_co2_defective.loc[index_this_case, "ads_energy_ev"] = ( - current_lowest_energy - ) + lowest_energy_data_co2_defective.loc[ +
index_this_case, "ads_energy_ev" + ] = current_lowest_energy lowest_energy_data_co2_defective.loc[ index_this_case, "configuration_index" ] = current_configuration_index @@ -485,9 +491,9 @@ current_lowest_energy < lowest_energy_data_h2o_defective.loc[index_this_case, "ads_energy_ev"] ): - lowest_energy_data_h2o_defective.loc[index_this_case, "ads_energy_ev"] = ( - current_lowest_energy - ) + lowest_energy_data_h2o_defective.loc[ + index_this_case, "ads_energy_ev" + ] = current_lowest_energy lowest_energy_data_h2o_defective.loc[ index_this_case, "configuration_index" ] = current_configuration_index @@ -542,9 +548,9 @@ lowest_energy_data_co_ads_defective.loc[ index_this_case, "configuration_index" ] = current_configuration_index - lowest_energy_data_co_ads_defective.loc[index_this_case, "Name"] = ( - current_name - ) + lowest_energy_data_co_ads_defective.loc[ + index_this_case, "Name" + ] = current_name lowest_energy_data_co_ads_2_defective = pd.DataFrame( columns=complete_data_merged_defective_co_ads_2.columns @@ -600,9 +606,9 @@ lowest_energy_data_co_ads_2_defective.loc[ index_this_case, "configuration_index" ] = current_configuration_index - lowest_energy_data_co_ads_2_defective.loc[index_this_case, "Name"] = ( - current_name - ) + lowest_energy_data_co_ads_2_defective.loc[ + index_this_case, "Name" + ] = current_name adsorption_data_defective = pd.DataFrame( @@ -646,136 +652,132 @@ # adsorption_data_defective_defective.iloc[count,0]=mof_name - adsorption_data_defective.loc[count, "n_converged_CO2"] = ( - complete_data_merged_defective[ - (complete_data_merged_defective["MOF"] == mof_name) - & (complete_data_merged_defective["defect_conc"] == current_defect_conc) - & (complete_data_merged_defective["defect_index"] == current_defect_index) - & (complete_data_merged_defective["n_CO2"] == 1) - & (complete_data_merged_defective["n_H2O"] == 0) - ].shape[0] - ) - adsorption_data_defective.loc[count, "n_converged_H2O"] = ( - complete_data_merged_defective[ - (complete_data_merged_defective["MOF"] == mof_name) - & (complete_data_merged_defective["defect_conc"] == current_defect_conc) - & (complete_data_merged_defective["defect_index"] == current_defect_index) - & (complete_data_merged_defective["n_CO2"] == 0) - & (complete_data_merged_defective["n_H2O"] == 1) - ].shape[0] - ) - adsorption_data_defective.loc[count, "n_converged_co"] = ( - complete_data_merged_defective[ - (complete_data_merged_defective["MOF"] == mof_name) - & (complete_data_merged_defective["defect_conc"] == current_defect_conc) - & (complete_data_merged_defective["defect_index"] == current_defect_index) - & (complete_data_merged_defective["n_CO2"] == 1) - & (complete_data_merged_defective["n_H2O"] == 1) - ].shape[0] - ) - adsorption_data_defective.loc[count, "n_converged_co_2"] = ( - complete_data_merged_defective[ - (complete_data_merged_defective["MOF"] == mof_name) - & (complete_data_merged_defective["defect_conc"] == current_defect_conc) - & (complete_data_merged_defective["defect_index"] == current_defect_index) - & (complete_data_merged_defective["n_CO2"] == 1) - & (complete_data_merged_defective["n_H2O"] == 2) - ].shape[0] - ) + adsorption_data_defective.loc[ + count, "n_converged_CO2" + ] = complete_data_merged_defective[ + (complete_data_merged_defective["MOF"] == mof_name) + & (complete_data_merged_defective["defect_conc"] == current_defect_conc) + & (complete_data_merged_defective["defect_index"] == current_defect_index) + & (complete_data_merged_defective["n_CO2"] == 1) + & (complete_data_merged_defective["n_H2O"] == 
0) + ].shape[ + 0 + ] + adsorption_data_defective.loc[ + count, "n_converged_H2O" + ] = complete_data_merged_defective[ + (complete_data_merged_defective["MOF"] == mof_name) + & (complete_data_merged_defective["defect_conc"] == current_defect_conc) + & (complete_data_merged_defective["defect_index"] == current_defect_index) + & (complete_data_merged_defective["n_CO2"] == 0) + & (complete_data_merged_defective["n_H2O"] == 1) + ].shape[ + 0 + ] + adsorption_data_defective.loc[ + count, "n_converged_co" + ] = complete_data_merged_defective[ + (complete_data_merged_defective["MOF"] == mof_name) + & (complete_data_merged_defective["defect_conc"] == current_defect_conc) + & (complete_data_merged_defective["defect_index"] == current_defect_index) + & (complete_data_merged_defective["n_CO2"] == 1) + & (complete_data_merged_defective["n_H2O"] == 1) + ].shape[ + 0 + ] + adsorption_data_defective.loc[ + count, "n_converged_co_2" + ] = complete_data_merged_defective[ + (complete_data_merged_defective["MOF"] == mof_name) + & (complete_data_merged_defective["defect_conc"] == current_defect_conc) + & (complete_data_merged_defective["defect_index"] == current_defect_index) + & (complete_data_merged_defective["n_CO2"] == 1) + & (complete_data_merged_defective["n_H2O"] == 2) + ].shape[ + 0 + ] if not lowest_energy_data_co2_defective[ (lowest_energy_data_co2_defective["MOF"] == mof_name) & (lowest_energy_data_co2_defective["defect_conc"] == current_defect_conc) & (lowest_energy_data_co2_defective["defect_index"] == current_defect_index) ].empty: - adsorption_data_defective.loc[count, "ads_CO2"] = ( - lowest_energy_data_co2_defective[ - (lowest_energy_data_co2_defective["MOF"] == mof_name) - & ( - lowest_energy_data_co2_defective["defect_conc"] - == current_defect_conc - ) - & ( - lowest_energy_data_co2_defective["defect_index"] - == current_defect_index - ) - ].iloc[0, 6] - ) - adsorption_data_defective.loc[count, "config_CO2"] = ( - lowest_energy_data_co2_defective[ - (lowest_energy_data_co2_defective["MOF"] == mof_name) - & ( - lowest_energy_data_co2_defective["defect_conc"] - == current_defect_conc - ) - & ( - lowest_energy_data_co2_defective["defect_index"] - == current_defect_index - ) - ].iloc[0, 5] - ) + adsorption_data_defective.loc[ + count, "ads_CO2" + ] = lowest_energy_data_co2_defective[ + (lowest_energy_data_co2_defective["MOF"] == mof_name) + & (lowest_energy_data_co2_defective["defect_conc"] == current_defect_conc) + & (lowest_energy_data_co2_defective["defect_index"] == current_defect_index) + ].iloc[ + 0, 6 + ] + adsorption_data_defective.loc[ + count, "config_CO2" + ] = lowest_energy_data_co2_defective[ + (lowest_energy_data_co2_defective["MOF"] == mof_name) + & (lowest_energy_data_co2_defective["defect_conc"] == current_defect_conc) + & (lowest_energy_data_co2_defective["defect_index"] == current_defect_index) + ].iloc[ + 0, 5 + ] if not lowest_energy_data_h2o_defective[ (lowest_energy_data_h2o_defective["MOF"] == mof_name) & (lowest_energy_data_h2o_defective["defect_conc"] == current_defect_conc) & (lowest_energy_data_h2o_defective["defect_index"] == current_defect_index) ].empty: - adsorption_data_defective.loc[count, "ads_H2O"] = ( - lowest_energy_data_h2o_defective[ - (lowest_energy_data_h2o_defective["MOF"] == mof_name) - & ( - lowest_energy_data_h2o_defective["defect_conc"] - == current_defect_conc - ) - & ( - lowest_energy_data_h2o_defective["defect_index"] - == current_defect_index - ) - ].iloc[0, 6] - ) - adsorption_data_defective.loc[count, "config_H2O"] = ( - 
lowest_energy_data_h2o_defective[ - (lowest_energy_data_h2o_defective["MOF"] == mof_name) - & ( - lowest_energy_data_h2o_defective["defect_conc"] - == current_defect_conc - ) - & ( - lowest_energy_data_h2o_defective["defect_index"] - == current_defect_index - ) - ].iloc[0, 5] - ) + adsorption_data_defective.loc[ + count, "ads_H2O" + ] = lowest_energy_data_h2o_defective[ + (lowest_energy_data_h2o_defective["MOF"] == mof_name) + & (lowest_energy_data_h2o_defective["defect_conc"] == current_defect_conc) + & (lowest_energy_data_h2o_defective["defect_index"] == current_defect_index) + ].iloc[ + 0, 6 + ] + adsorption_data_defective.loc[ + count, "config_H2O" + ] = lowest_energy_data_h2o_defective[ + (lowest_energy_data_h2o_defective["MOF"] == mof_name) + & (lowest_energy_data_h2o_defective["defect_conc"] == current_defect_conc) + & (lowest_energy_data_h2o_defective["defect_index"] == current_defect_index) + ].iloc[ + 0, 5 + ] if not lowest_energy_data_co_ads_defective[ (lowest_energy_data_co_ads_defective["MOF"] == mof_name) & (lowest_energy_data_co_ads_defective["defect_conc"] == current_defect_conc) & (lowest_energy_data_co_ads_defective["defect_index"] == current_defect_index) ].empty: - adsorption_data_defective.loc[count, "ads_co"] = ( - lowest_energy_data_co_ads_defective[ - (lowest_energy_data_co_ads_defective["MOF"] == mof_name) - & ( - lowest_energy_data_co_ads_defective["defect_conc"] - == current_defect_conc - ) - & ( - lowest_energy_data_co_ads_defective["defect_index"] - == current_defect_index - ) - ].iloc[0, 6] - ) - adsorption_data_defective.loc[count, "config_co"] = ( - lowest_energy_data_co_ads_defective[ - (lowest_energy_data_co_ads_defective["MOF"] == mof_name) - & ( - lowest_energy_data_co_ads_defective["defect_conc"] - == current_defect_conc - ) - & ( - lowest_energy_data_co_ads_defective["defect_index"] - == current_defect_index - ) - ].iloc[0, 5] - ) + adsorption_data_defective.loc[ + count, "ads_co" + ] = lowest_energy_data_co_ads_defective[ + (lowest_energy_data_co_ads_defective["MOF"] == mof_name) + & ( + lowest_energy_data_co_ads_defective["defect_conc"] + == current_defect_conc + ) + & ( + lowest_energy_data_co_ads_defective["defect_index"] + == current_defect_index + ) + ].iloc[ + 0, 6 + ] + adsorption_data_defective.loc[ + count, "config_co" + ] = lowest_energy_data_co_ads_defective[ + (lowest_energy_data_co_ads_defective["MOF"] == mof_name) + & ( + lowest_energy_data_co_ads_defective["defect_conc"] + == current_defect_conc + ) + & ( + lowest_energy_data_co_ads_defective["defect_index"] + == current_defect_index + ) + ].iloc[ + 0, 5 + ] if not lowest_energy_data_co_ads_2_defective[ (lowest_energy_data_co_ads_2_defective["MOF"] == mof_name) & (lowest_energy_data_co_ads_2_defective["defect_conc"] == current_defect_conc) @@ -784,32 +786,36 @@ == current_defect_index ) ].empty: - adsorption_data_defective.loc[count, "ads_co_2"] = ( - lowest_energy_data_co_ads_2_defective[ - (lowest_energy_data_co_ads_2_defective["MOF"] == mof_name) - & ( - lowest_energy_data_co_ads_2_defective["defect_conc"] - == current_defect_conc - ) - & ( - lowest_energy_data_co_ads_2_defective["defect_index"] - == current_defect_index - ) - ].iloc[0, 6] - ) - adsorption_data_defective.loc[count, "config_co_2"] = ( - lowest_energy_data_co_ads_2_defective[ - (lowest_energy_data_co_ads_2_defective["MOF"] == mof_name) - & ( - lowest_energy_data_co_ads_2_defective["defect_conc"] - == current_defect_conc - ) - & ( - lowest_energy_data_co_ads_2_defective["defect_index"] - == current_defect_index - ) - 
].iloc[0, 5] - ) + adsorption_data_defective.loc[ + count, "ads_co_2" + ] = lowest_energy_data_co_ads_2_defective[ + (lowest_energy_data_co_ads_2_defective["MOF"] == mof_name) + & ( + lowest_energy_data_co_ads_2_defective["defect_conc"] + == current_defect_conc + ) + & ( + lowest_energy_data_co_ads_2_defective["defect_index"] + == current_defect_index + ) + ].iloc[ + 0, 6 + ] + adsorption_data_defective.loc[ + count, "config_co_2" + ] = lowest_energy_data_co_ads_2_defective[ + (lowest_energy_data_co_ads_2_defective["MOF"] == mof_name) + & ( + lowest_energy_data_co_ads_2_defective["defect_conc"] + == current_defect_conc + ) + & ( + lowest_energy_data_co_ads_2_defective["defect_index"] + == current_defect_index + ) + ].iloc[ + 0, 5 + ] # read the mofs missing DDEC charges diff --git a/src/fairchem/data/odac/promising_mof/promising_mof_features/readme b/src/fairchem/data/odac/promising_mof/promising_mof_features/readme index afb41617a0..4910e85eae 100644 --- a/src/fairchem/data/odac/promising_mof/promising_mof_features/readme +++ b/src/fairchem/data/odac/promising_mof/promising_mof_features/readme @@ -7,10 +7,10 @@ Three criterias have to be satisfied: 1. 2 rings are parallel; 2. the distance o 2. metal-oxygen-metal bridges: [$(select {metal})]~[$(select oxygen)]~[$(select {metal})] 3. uncoordinated nitrogen atoms: [$([#7X2r5])] -We recommend using the jmolData.jar for high-throughput calculations. jmol.jar, which takes more time to run, is good for visualization and debug. +We recommend using the JmolData.jar for high-throughput calculations. jmol.jar, which takes more time to run, is good for visualization and debug. Steps: 1. Change the content of 'list_MOF.txt' to the paths of the MOFs -2. Use 'java -jar JmolData.jar -on -s features.txt' to run the script +2. Use 'java -jar JmolData.jar -on -s features.txt' to run the script. If JmolData.jar is missing, run the command `python src/fairchem/core/scripts/download_large_files.py odac` from the root of the fairchem repo to download it. 3. The output will be saved in the 'output.txt' in the same directory by default, and it can be modified at the last line of the code. 'output.txt' has 10 columns: 1. ID is the index in 'list_MOF.txt'. 
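Taken together, the readme's steps amount to roughly the following sketch, driven from Python for convenience. The MOF paths written to `list_MOF.txt` are placeholders, and the commands are assumed to be run from the root of the fairchem repo.

```python
import subprocess
from pathlib import Path

feature_dir = Path("src/fairchem/data/odac/promising_mof/promising_mof_features")

# Fetch JmolData.jar (and the other large ODAC files) if it is not present yet
if not (feature_dir / "JmolData.jar").exists():
    subprocess.run(
        ["python", "src/fairchem/core/scripts/download_large_files.py", "odac"],
        check=True,
    )

# Step 1: point list_MOF.txt at the MOF structures to featurize (placeholder paths)
(feature_dir / "list_MOF.txt").write_text("/path/to/mof_1.cif\n/path/to/mof_2.cif\n")

# Step 2: run the headless Jmol script; step 3: results land in output.txt by default
subprocess.run(
    ["java", "-jar", "JmolData.jar", "-on", "-s", "features.txt"],
    cwd=feature_dir,
    check=True,
)
print((feature_dir / "output.txt").read_text().splitlines()[:5])
```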
diff --git a/tests/applications/cattsunami/tests/conftest.py b/tests/applications/cattsunami/tests/conftest.py index 24222d9cf7..9afdc0a963 100644 --- a/tests/applications/cattsunami/tests/conftest.py +++ b/tests/applications/cattsunami/tests/conftest.py @@ -1,6 +1,9 @@ -from pathlib import Path +import os import pickle +from pathlib import Path + import pytest +from fairchem.core.scripts import download_large_files @pytest.fixture(scope="class") @@ -17,11 +20,17 @@ def desorption_inputs(request): @pytest.fixture(scope="class") def dissociation_inputs(request): - with open(Path(__file__).parent / "autoframe_inputs_dissociation.pkl", "rb") as fp: + pkl_path = Path(__file__).parent / "autoframe_inputs_dissociation.pkl" + if not pkl_path.exists(): + download_large_files.download_file_group("cattsunami") + with open(pkl_path, "rb") as fp: request.cls.inputs = pickle.load(fp) @pytest.fixture(scope="class") def transfer_inputs(request): - with open(Path(__file__).parent / "autoframe_inputs_transfer.pkl", "rb") as fp: + pkl_path = Path(__file__).parent / "autoframe_inputs_transfer.pkl" + if not pkl_path.exists(): + download_large_files.download_file_group("cattsunami") + with open(pkl_path, "rb") as fp: request.cls.inputs = pickle.load(fp) diff --git a/tests/core/test_download_large_files.py b/tests/core/test_download_large_files.py new file mode 100644 index 0000000000..991f8ce348 --- /dev/null +++ b/tests/core/test_download_large_files.py @@ -0,0 +1,16 @@ +import os +from unittest.mock import patch + +from fairchem.core.scripts import download_large_files as dl_large + + +@patch.object(dl_large, "urlretrieve") +def test_download_large_files(url_mock): + def urlretrieve_mock(x, y): + if not os.path.exists(os.path.dirname(y)): + raise ValueError( + f"The path to {y} does not exist. fairchem directory structure has changed," + ) + + url_mock.side_effect = urlretrieve_mock + dl_large.download_file_group("ALL") From 7ec739763765c297fad0ce646b0b9b449f98cee7 Mon Sep 17 00:00:00 2001 From: Misko Date: Thu, 1 Aug 2024 21:31:08 -0700 Subject: [PATCH 3/8] list available fairchem packages; and show how to use pytest in [dev] (#790) --- docs/core/install.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/core/install.md b/docs/core/install.md index 8ad523f326..5eb4569f82 100644 --- a/docs/core/install.md +++ b/docs/core/install.md @@ -44,28 +44,28 @@ You can also install `pytorch` and `torch_geometric` dependencies from PyPI to s similarly by selecting the appropriate versions in the official [PyG docs](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html) -## Install fairchem-core +## Standard installation of fairchem-core Install `fairchem-core` from PyPi ```bash pip install fairchem-core ``` -## Additional packages - +### Additional packages `fairchem` is a namespace package, meaning all packages are installed seperately. 
If you need to install other packages you can do so by: ```bash pip install fairchem-{package-to-install} ``` +Available `fairchem` packages are `fairchem-core`,`fairchem-data-oc`,`fairchem-demo-ocpapi`,`fairchem-applications-cattsunami` -## Development install - +## Development installation If you plan to make contributions you will need to fork and clone (for windows user please see next section) the repo, set up the environment, and install fairchem-core from source in editable mode with dev dependencies, ```bash git clone https://github.com/FAIR-Chem/fairchem.git cd fairchem pip install -e packages/fairchem-core[dev] +pytest tests/core ``` And similarly for any other namespace package: From 04a69b0353360fe9616047662fe9de4c2168b742 Mon Sep 17 00:00:00 2001 From: Misko Date: Fri, 2 Aug 2024 10:19:10 -0700 Subject: [PATCH 4/8] Balanced batch sampler+base dataset (#753) * Update BalancedBatchSampler to use datasets' `data_sizes` method Replace BalancedBatchSampler's `force_balancing` and `throw_on_error` parameters with `on_error` * Remove python 3.10 syntax * Documentation * Added set_epoch method * Format * Changed "resolved dataset" message to be a debug log to reduce log spam * clean up batchsampler and tests * base dataset class * move lin_ref to base dataset * inherit basedataset for ase dataset * filter indices prop * added create_dataset fn * yaml load fix * create dataset function instead of filtering in base * remove filtered_indices * make create_dataset and LMDBDatabase importable from datasets * create_dataset cleanup * test create_dataset * use metadata.natoms directly and add it to subset * use self.indices to handle shard * rename _data_sizes * fix Subset of metadata * minor change to metadata, added full path option * import updates * implement get_metadata for datasets; add tests for max_atoms and balanced partitioning * a[:len(a)+1] does not throw error, change to check for this * off by one fix * fixing tests * plug create_dataset into trainer * remove datasetwithsizes; fix base dataset integration; replace close_db with __del__ * lint * add/fix test; * adding new notebook for using fairchem models with NEBs without CatTSunami enumeration (#764) * adding new notebook for using fairchem models with NEBs * adding md tutorials * blocking code cells that arent needed or take too long * Add extra test case for local batch size = 1 * fix example * fix test case * reorg changes * remove metadata_has_sizes in favor of basedataset function metadata_hasattr * fix data_parallel typo * fix up some tests * rename get_metadata to sample_property_metadata * add slow get_metadata for ase; add tests for get_metadata (ase+lmdb); add test for make lmdb metadata sizes * add support for different backends and ddp in pytest * fix tests and balanced batch sampler * make default dataset lmdb * lint * fix tests * test with world_size=0 by default * fix tests * fix tests.. 
* remove subsample from oc22 dataset * remove old datasets; add test for noddp * remove load balancing from docs * fix docs; add train_split_settings and test for this --------- Co-authored-by: Nima Shoghi Co-authored-by: Nima Shoghi Co-authored-by: lbluque Co-authored-by: Brandon Co-authored-by: Brook Wander <73855115+brookwander@users.noreply.github.com> Co-authored-by: Muhammed Shuaibi <45150244+mshuaibii@users.noreply.github.com> Co-authored-by: Muhammed Shuaibi --- docs/core/fine-tuning/fine-tuning-oxides.md | 1 + .../advanced/fine-tuning-in-python.md | 2 +- src/fairchem/core/common/data_parallel.py | 236 ++++++++--------- src/fairchem/core/common/distutils.py | 6 +- src/fairchem/core/common/test_utils.py | 77 +++--- src/fairchem/core/datasets/__init__.py | 10 +- src/fairchem/core/datasets/ase_datasets.py | 44 ++-- src/fairchem/core/datasets/base_dataset.py | 227 +++++++++++++++++ src/fairchem/core/datasets/lmdb_dataset.py | 61 ++--- .../core/datasets/oc22_lmdb_dataset.py | 22 +- src/fairchem/core/scripts/make_lmdb_sizes.py | 16 +- src/fairchem/core/trainers/base_trainer.py | 91 ++++--- src/fairchem/core/trainers/ocp_trainer.py | 10 +- .../test_data_parallel_batch_sampler.py | 239 ++++++++++-------- tests/core/common/test_gp_utils.py | 110 ++++++-- tests/core/datasets/conftest.py | 28 ++ tests/core/datasets/test_ase_datasets.py | 46 ++-- tests/core/datasets/test_create_dataset.py | 180 +++++++++++++ tests/core/datasets/test_lmdb_dataset.py | 29 +++ tests/core/e2e/test_s2ef.py | 94 ++++++- tests/core/models/test_equiformer_v2.py | 15 +- 21 files changed, 1095 insertions(+), 449 deletions(-) create mode 100644 src/fairchem/core/datasets/base_dataset.py create mode 100644 tests/core/datasets/conftest.py create mode 100644 tests/core/datasets/test_create_dataset.py create mode 100644 tests/core/datasets/test_lmdb_dataset.py diff --git a/docs/core/fine-tuning/fine-tuning-oxides.md b/docs/core/fine-tuning/fine-tuning-oxides.md index 77a9350d3b..39c39cad40 100644 --- a/docs/core/fine-tuning/fine-tuning-oxides.md +++ b/docs/core/fine-tuning/fine-tuning-oxides.md @@ -205,6 +205,7 @@ from fairchem.core.common.tutorial_utils import generate_yml_config yml = generate_yml_config(checkpoint_path, 'config.yml', delete=['slurm', 'cmd', 'logger', 'task', 'model_attributes', 'optim.loss_force', # the checkpoint setting causes an error + 'optim.load_balancing', 'dataset', 'test_dataset', 'val_dataset'], update={'gpus': 1, 'optim.eval_every': 10, diff --git a/docs/tutorials/advanced/fine-tuning-in-python.md b/docs/tutorials/advanced/fine-tuning-in-python.md index 1d14219c88..0eeb8e5485 100644 --- a/docs/tutorials/advanced/fine-tuning-in-python.md +++ b/docs/tutorials/advanced/fine-tuning-in-python.md @@ -75,7 +75,7 @@ We start by making the config.yml. We build this from the calculator checkpoint. 
from fairchem.core.common.tutorial_utils import generate_yml_config yml = generate_yml_config(checkpoint_path, 'config.yml', - delete=['slurm', 'cmd', 'logger', 'task', 'model_attributes', + delete=['slurm', 'cmd', 'logger', 'task', 'model_attributes','optim.load_balancing', 'optim.loss_force', # the checkpoint setting causes an error 'dataset', 'test_dataset', 'val_dataset'], update={'gpus': 1, diff --git a/src/fairchem/core/common/data_parallel.py b/src/fairchem/core/common/data_parallel.py index 4d5836b786..89c3b67445 100644 --- a/src/fairchem/core/common/data_parallel.py +++ b/src/fairchem/core/common/data_parallel.py @@ -9,20 +9,23 @@ import heapq import logging -from typing import TYPE_CHECKING, Literal, Protocol, runtime_checkable +from typing import TYPE_CHECKING, Any, Literal import numba import numpy as np -import numpy.typing as npt import torch -from torch.utils.data import BatchSampler, DistributedSampler, Sampler +import torch.distributed +from torch.utils.data import BatchSampler, Dataset, DistributedSampler +from typing_extensions import override from fairchem.core.common import distutils, gp_utils from fairchem.core.datasets import data_list_collater +from fairchem.core.datasets.base_dataset import ( + UnsupportedDatasetError, +) if TYPE_CHECKING: - from pathlib import Path - + from numpy.typing import NDArray from torch_geometric.data import Batch, Data @@ -35,30 +38,24 @@ def __call__(self, data_list: list[Data]) -> Batch: @numba.njit -def balanced_partition(sizes: npt.NDArray[np.int_], num_parts: int): +def _balanced_partition(sizes: NDArray[np.int_], num_parts: int): """ Greedily partition the given set by always inserting the largest element into the smallest partition. """ sort_idx = np.argsort(-sizes) # Sort in descending order - heap: list[tuple[list[int], list[int]]] = [ - (sizes[idx], [idx]) for idx in sort_idx[:num_parts] - ] + heap = [(sizes[idx], [idx]) for idx in sort_idx[:num_parts]] heapq.heapify(heap) for idx in sort_idx[num_parts:]: smallest_part = heapq.heappop(heap) new_size = smallest_part[0] + sizes[idx] - new_idx = smallest_part[1] + [idx] + new_idx = smallest_part[1] + [ + idx + ] # TODO should this be append to save time/space heapq.heappush(heap, (new_size, new_idx)) return [part[1] for part in heap] -@runtime_checkable -class _HasMetadata(Protocol): - @property - def metadata_path(self) -> Path: ... - - class StatefulDistributedSampler(DistributedSampler): """ More fine-grained state DataSampler that uses training iteration and epoch @@ -105,56 +102,83 @@ def set_epoch_and_start_iteration(self, epoch, start_iter): self.start_iter = start_iter -class BalancedBatchSampler(Sampler): - def _load_dataset(self, dataset, mode: Literal["atoms", "neighbors"]): - errors: list[str] = [] - if not isinstance(dataset, _HasMetadata): - errors.append(f"Dataset {dataset} does not have a metadata_path attribute.") - return None, errors - if not dataset.metadata_path.exists(): - errors.append(f"Metadata file {dataset.metadata_path} does not exist.") - return None, errors +def _ensure_supported(dataset: Any): + if not isinstance(dataset, Dataset): + raise UnsupportedDatasetError("BalancedBatchSampler requires a dataset.") + + if not dataset.metadata_hasattr("natoms"): + raise UnsupportedDatasetError( + "BalancedBatchSampler requires a dataset that has a metadata attributed with number of atoms." 
+ ) - key = {"atoms": "natoms", "neighbors": "neighbors"}[mode] - sizes = np.load(dataset.metadata_path)[key] + logging.debug(f"BalancedBatchSampler: Resolved dataset to {type(dataset)}") + return dataset - return sizes, errors +class BalancedBatchSampler(BatchSampler): def __init__( self, - dataset, + dataset: Dataset, + *, batch_size: int, num_replicas: int, rank: int, device: torch.device, seed: int, - mode: str | bool = "atoms", + mode: bool | Literal["atoms"] = "atoms", shuffle: bool = True, + on_error: Literal["warn_and_balance", "warn_and_no_balance", "raise"] = "raise", drop_last: bool = False, - force_balancing: bool = False, - throw_on_error: bool = False, - ) -> None: - if mode is True: - mode = "atoms" - - if isinstance(mode, str): - mode = mode.lower() - if mode not in ("atoms", "neighbors"): - raise ValueError( - f"Invalid mode {mode}. Must be one of 'atoms', 'neighbors', or a boolean." - ) + ): + """ + Initializes a BalancedBatchSampler object. - self.dataset = dataset - self.batch_size = batch_size - self.num_replicas = num_replicas - self.rank = rank - self.device = device - self.mode = mode - self.shuffle = shuffle - self.drop_last = drop_last + Args: + dataset (Dataset): The dataset to sample from. + batch_size (int): The size of each batch. + num_replicas (int): The number of processes participating in distributed training. + rank (int): The rank of the current process in distributed training. + device (torch.device): The device to use for the batches. + mode (str or bool, optional): The mode to use for balancing the batches. Defaults to "atoms". + shuffle (bool, optional): Whether to shuffle the samples. Defaults to True. + on_error (Literal["warn_and_balance", "warn_and_no_balance", "raise"], optional): The action to take when an error occurs (i.e., when we have an invalid dataset). Defaults to "raise". + - "warn_and_balance": Raise a warning and balance the batch by manually loading the data samples and counting the number of nodes (this is slow). + - "warn_and_no_balance": Raise a warning and do not do any balancing. + - "raise": Raise an error. + drop_last (bool, optional): Whether to drop the last incomplete batch. Defaults to False. + """ + self.disabled = False + self.on_error = on_error + + if mode is False: + logging.warning(f"Disabled BalancedBatchSampler because {mode=}.") + self.disabled = True + elif mode.lower() != "atoms": + raise ValueError( + f"Only mode='atoms' or mode=True is supported, got {mode=}." + ) + elif num_replicas == 1: + logging.warning(f"Disabled BalancedBatchSampler because {num_replicas=}.") + self.disabled = True + + try: + dataset = _ensure_supported(dataset) + except UnsupportedDatasetError as error: + if self.on_error == "raise": + raise error + if self.on_error == "warn_and_balance": + logging.warning( + f"Failed to get data sizes from metadata, loading data to get sizes (THIS IS SLOW). {error}" + ) + elif self.on_error == "warn_and_no_balance": + logging.warning( + f"Failed to get data sizes, falling back to uniform partitioning. 
{error}" + ) + else: + raise ValueError(f"Unknown on_error={self.on_error}") from error - self.single_sampler = StatefulDistributedSampler( - self.dataset, + sampler = StatefulDistributedSampler( + dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, @@ -162,82 +186,59 @@ def __init__( batch_size=batch_size, seed=seed, ) - self.batch_sampler = BatchSampler( - self.single_sampler, - batch_size, - drop_last=drop_last, - ) - - self.sizes = None - self.balance_batches = False - if self.num_replicas <= 1: - logging.info("Batch balancing is disabled for single GPU training.") - return - - if self.mode is False: - logging.info( - "Batch balancing is disabled because `optim.load_balancing` is `False`" - ) - return - - self.sizes, errors = self._load_dataset(dataset, self.mode) - if self.sizes is None: - self.balance_batches = force_balancing - if force_balancing: - errors.append( - "BalancedBatchSampler has to load the data to determine batch sizes, which incurs significant overhead! " - "You can disable balancing by setting `optim.load_balancing` to `False`." - ) - else: - errors.append( - "Batches will not be balanced, which can incur significant overhead!" - ) - else: - self.balance_batches = True - - if errors: - msg = "BalancedBatchSampler: " + " ".join(errors) - if throw_on_error: - raise RuntimeError(msg) + super().__init__(sampler, batch_size=batch_size, drop_last=drop_last) + self.device = device - logging.warning(msg) + logging.info( + f"Created BalancedBatchSampler with {sampler=}, {batch_size=}, {drop_last=}" + ) - def __len__(self) -> int: - return len(self.batch_sampler) + def _get_natoms(self, batch_idx: list[int]): + if self.sampler.dataset.metadata_hasattr("natoms"): + return np.array( + self.sampler.dataset.get_metadata("natoms", batch_idx) + ).reshape(-1) + if self.on_error == "warn_and_balance": + return np.array([self.sampler.dataset[idx].num_nodes for idx in batch_idx]) + return None def set_epoch_and_start_iteration(self, epoch: int, start_iteration: int) -> None: - if not hasattr(self.single_sampler, "set_epoch_and_start_iteration"): + if not isinstance(self.sampler, StatefulDistributedSampler): if start_iteration != 0: raise NotImplementedError( f"{type(self.single_sampler)} does not support resuming from a nonzero step." 
) - self.single_sampler.set_epoch(epoch) + self.sampler.set_epoch(epoch) else: - self.single_sampler.set_epoch_and_start_iteration(epoch, start_iteration) + self.sampler.set_epoch_and_start_iteration(epoch, start_iteration) + + def set_epoch(self, epoch: int) -> None: + if isinstance(self.sampler, DistributedSampler): + self.sampler.set_epoch(epoch) + @staticmethod + def _dist_enabled(): + return torch.distributed.is_available() and torch.distributed.is_initialized() + + @override def __iter__(self): - if not self.balance_batches: - yield from self.batch_sampler + if self.disabled or not self._dist_enabled(): + yield from super().__iter__() return - for batch_idx in self.batch_sampler: - if self.sizes is None: - # Unfortunately, we need to load the data to know the image sizes - data_list = [self.dataset[idx] for idx in batch_idx] - - if self.mode == "atoms": - sizes = [data.num_nodes for data in data_list] - elif self.mode == "neighbors": - sizes = [data.edge_index.shape[1] for data in data_list] - else: - raise NotImplementedError( - f"Unknown load balancing mode: {self.mode}" - ) - else: - sizes = [self.sizes[idx] for idx in batch_idx] - - idx_sizes = torch.stack([torch.tensor(batch_idx), torch.tensor(sizes)]) + for batch_idx in super().__iter__(): + sizes = self._get_natoms(batch_idx) + if sizes is None: # on_error == "warn_and_no_balance" is set + yield batch_idx + continue + + idx_sizes = torch.stack( + [ + torch.tensor(batch_idx, device=self.device), + torch.tensor(sizes, device=self.device), + ] + ) idx_sizes_all = distutils.all_gather(idx_sizes, device=self.device) idx_sizes_all = torch.cat(idx_sizes_all, dim=-1).cpu() if gp_utils.initialized(): @@ -245,9 +246,10 @@ def __iter__(self): idx_all = idx_sizes_all[0] sizes_all = idx_sizes_all[1] - local_idx_balanced = balanced_partition( - sizes_all.numpy(), num_parts=self.num_replicas + local_idx_balanced = _balanced_partition( + sizes_all.numpy(), + num_parts=self.sampler.num_replicas, ) # Since DistributedSampler pads the last batch # this should always have an entry for each replica. 
- yield idx_all[local_idx_balanced[self.rank]] + yield idx_all[local_idx_balanced[self.sampler.rank]] diff --git a/src/fairchem/core/common/distutils.py b/src/fairchem/core/common/distutils.py index 919f7ba66d..8989840641 100644 --- a/src/fairchem/core/common/distutils.py +++ b/src/fairchem/core/common/distutils.py @@ -98,7 +98,7 @@ def setup(config) -> None: ) else: config["local_rank"] = int(os.environ.get("LOCAL_RANK", config["local_rank"])) - dist.init_process_group(backend="nccl") + dist.init_process_group(backend=config.get("backend", "nccl")) def cleanup() -> None: @@ -144,7 +144,7 @@ def all_reduce( if not isinstance(data, torch.Tensor): tensor = torch.tensor(data) if device is not None: - tensor = tensor.cuda(device) + tensor = tensor.to(device) dist.all_reduce(tensor, group=group) if average: tensor /= get_world_size() @@ -162,7 +162,7 @@ def all_gather(data, group=dist.group.WORLD, device=None) -> list[torch.Tensor]: if not isinstance(data, torch.Tensor): tensor = torch.tensor(data) if device is not None: - tensor = tensor.cuda(device) + tensor = tensor.to(device) tensor_list = [tensor.new_zeros(tensor.shape) for _ in range(get_world_size())] dist.all_gather(tensor_list, tensor, group=group) if not isinstance(data, torch.Tensor): diff --git a/src/fairchem/core/common/test_utils.py b/src/fairchem/core/common/test_utils.py index 8aaf822105..130daba2d5 100644 --- a/src/fairchem/core/common/test_utils.py +++ b/src/fairchem/core/common/test_utils.py @@ -44,9 +44,55 @@ class PGConfig: use_gp: bool = True +def init_env_rank_and_launch_test( + rank: int, + pg_setup_params: PGConfig, + mp_output_dict: dict[int, object], + test_method: callable, + args: list[object], + kwargs: dict[str, object], +) -> None: + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = pg_setup_params.port + os.environ["WORLD_SIZE"] = str(pg_setup_params.world_size) + os.environ["LOCAL_RANK"] = str(rank) + os.environ["RANK"] = str(rank) + mp_output_dict[rank] = test_method(*args, **kwargs) # pyre-fixme + + +def init_pg_and_rank_and_launch_test( + rank: int, + pg_setup_params: PGConfig, + mp_output_dict: dict[int, object], + test_method: callable, + args: list[object], + kwargs: dict[str, object], +) -> None: + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = pg_setup_params.port + os.environ["WORLD_SIZE"] = str(pg_setup_params.world_size) + os.environ["LOCAL_RANK"] = str(rank) + # setup default process group + dist.init_process_group( + rank=rank, + world_size=pg_setup_params.world_size, + backend=pg_setup_params.backend, + timeout=timedelta(seconds=10), # setting up timeout for distributed collectives + ) + # setup gp + if pg_setup_params.use_gp: + config = { + "gp_gpus": pg_setup_params.gp_group_size, + "distributed_backend": pg_setup_params.backend, + } + setup_gp(config) + mp_output_dict[rank] = test_method(*args, **kwargs) # pyre-fixme + + def spawn_multi_process( config: PGConfig, test_method: callable, + init_and_launch: callable, *test_method_args: Any, **test_method_kwargs: Any, ) -> list[Any]: @@ -72,7 +118,7 @@ def spawn_multi_process( torch.multiprocessing.spawn( # torch.multiprocessing.spawn sends rank as the first param # https://pytorch.org/docs/stable/multiprocessing.html#torch.multiprocessing.spawn - _init_pg_and_rank_and_launch_test, + init_and_launch, args=( config, mp_output_dict, @@ -84,32 +130,3 @@ def spawn_multi_process( ) return [mp_output_dict[i] for i in range(config.world_size)] - - -def _init_pg_and_rank_and_launch_test( - rank: int, - 
pg_setup_params: PGConfig, - mp_output_dict: dict[int, object], - test_method: callable, - args: list[object], - kwargs: dict[str, object], -) -> None: - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = pg_setup_params.port - os.environ["WORLD_SIZE"] = str(pg_setup_params.world_size) - os.environ["LOCAL_RANK"] = str(rank) - # setup default process group - dist.init_process_group( - rank=rank, - world_size=pg_setup_params.world_size, - backend=pg_setup_params.backend, - timeout=timedelta(seconds=10), # setting up timeout for distributed collectives - ) - # setup gp - if pg_setup_params.use_gp: - config = { - "gp_gpus": pg_setup_params.gp_group_size, - "distributed_backend": pg_setup_params.backend, - } - setup_gp(config) - mp_output_dict[rank] = test_method(*args, **kwargs) # pyre-fixme diff --git a/src/fairchem/core/datasets/__init__.py b/src/fairchem/core/datasets/__init__.py index 1fd4b51cd5..dc3f7d0e4d 100644 --- a/src/fairchem/core/datasets/__init__.py +++ b/src/fairchem/core/datasets/__init__.py @@ -5,23 +5,19 @@ from __future__ import annotations from .ase_datasets import AseDBDataset, AseReadDataset, AseReadMultiStructureDataset +from .base_dataset import create_dataset from .lmdb_database import LMDBDatabase from .lmdb_dataset import ( LmdbDataset, - SinglePointLmdbDataset, - TrajectoryLmdbDataset, data_list_collater, ) -from .oc22_lmdb_dataset import OC22LmdbDataset __all__ = [ "AseDBDataset", "AseReadDataset", "AseReadMultiStructureDataset", "LmdbDataset", - "SinglePointLmdbDataset", - "TrajectoryLmdbDataset", - "data_list_collater", - "OC22LmdbDataset", "LMDBDatabase", + "create_dataset", + "data_list_collater", ] diff --git a/src/fairchem/core/datasets/ase_datasets.py b/src/fairchem/core/datasets/ase_datasets.py index a258f42832..15c22322db 100644 --- a/src/fairchem/core/datasets/ase_datasets.py +++ b/src/fairchem/core/datasets/ase_datasets.py @@ -20,13 +20,12 @@ import ase import numpy as np -import torch.nn from torch import tensor -from torch.utils.data import Dataset from tqdm import tqdm from fairchem.core.common.registry import registry from fairchem.core.datasets._utils import rename_data_object_keys +from fairchem.core.datasets.base_dataset import BaseDataset from fairchem.core.datasets.lmdb_database import LMDBDatabase from fairchem.core.datasets.target_metadata_guesser import guess_property_metadata from fairchem.core.modules.transforms import DataTransforms @@ -60,7 +59,7 @@ def apply_one_tags( return atoms -class AseAtomsDataset(Dataset, ABC): +class AseAtomsDataset(BaseDataset, ABC): """ This is an abstract Dataset that includes helpful utilities for turning ASE atoms objects into OCP-usable data objects. 
This should not be instantiated directly @@ -81,7 +80,7 @@ def __init__( config: dict, atoms_transform: Callable[[ase.Atoms, Any, ...], ase.Atoms] = apply_one_tags, ) -> None: - self.config = config + super().__init__(config) a2g_args = config.get("a2g_args", {}) or {} @@ -96,19 +95,13 @@ def __init__( self.key_mapping = self.config.get("key_mapping", None) self.transforms = DataTransforms(self.config.get("transforms", {})) - self.lin_ref = None - if self.config.get("lin_ref", False): - lin_ref = torch.tensor( - np.load(self.config["lin_ref"], allow_pickle=True)["coeff"] - ) - self.lin_ref = torch.nn.Parameter(lin_ref, requires_grad=False) - self.atoms_transform = atoms_transform if self.config.get("keep_in_memory", False): self.__getitem__ = cache(self.__getitem__) self.ids = self._load_dataset_get_ids(config) + self.num_samples = len(self.ids) if len(self.ids) == 0: raise ValueError( @@ -116,9 +109,6 @@ def __init__( f"Double check that the src path and/or glob search pattern gives ASE compatible data: {config['src']}" ) - def __len__(self) -> int: - return len(self.ids) - def __getitem__(self, idx): # Handle slicing if isinstance(idx, slice): @@ -174,11 +164,7 @@ def _load_dataset_get_ids(self, config): def get_relaxed_energy(self, identifier): raise NotImplementedError("IS2RE-Direct is not implemented with this dataset.") - def close_db(self) -> None: - # This method is sometimes called by a trainer - pass - - def get_metadata(self, num_samples: int = 100) -> dict: + def sample_property_metadata(self, num_samples: int = 100) -> dict: metadata = {} if num_samples < len(self): @@ -197,6 +183,18 @@ def get_metadata(self, num_samples: int = 100) -> dict: return metadata + def get_metadata(self, attr, idx): + # try the parent method + metadata = super().get_metadata(attr, idx) + if metadata is not None: + return metadata + # try to resolve it here + if attr != "natoms": + return None + if isinstance(idx, (list, np.ndarray)): + return np.array([self.get_metadata(attr, i) for i in idx]) + return len(self.get_atoms(idx)) + @registry.register_dataset("ase_read") class AseReadDataset(AseAtomsDataset): @@ -399,7 +397,7 @@ def get_atoms(self, idx: str) -> ase.Atoms: return atoms - def get_metadata(self, num_samples: int = 100) -> dict: + def sample_property_metadata(self, num_samples: int = 100) -> dict: return {} def get_relaxed_energy(self, identifier) -> float: @@ -556,17 +554,17 @@ def connect_db( return ase.db.connect(address, **connect_args) - def close_db(self) -> None: + def __del__(self): for db in self.dbs: if hasattr(db, "close"): db.close() - def get_metadata(self, num_samples: int = 100) -> dict: + def sample_property_metadata(self, num_samples: int = 100) -> dict: logging.warning( "You specific a folder of ASE dbs, so it's impossible to know which metadata to use. Using the first!" ) if self.dbs[0].metadata == {}: - return super().get_metadata(num_samples) + return super().sample_property_metadata(num_samples) return copy.deepcopy(self.dbs[0].metadata) diff --git a/src/fairchem/core/datasets/base_dataset.py b/src/fairchem/core/datasets/base_dataset.py new file mode 100644 index 0000000000..2ca26596c3 --- /dev/null +++ b/src/fairchem/core/datasets/base_dataset.py @@ -0,0 +1,227 @@ +""" +Copyright (c) Meta, Inc. and its affiliates. +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. 
+""" + +from __future__ import annotations + +import logging +from abc import ABCMeta +from functools import cached_property +from pathlib import Path +from typing import ( + TYPE_CHECKING, + Any, + NamedTuple, + TypeVar, +) + +import numpy as np +import torch +from torch import randperm +from torch.utils.data import Dataset +from torch.utils.data import Subset as Subset_ + +from fairchem.core.common.registry import registry + +if TYPE_CHECKING: + from collections.abc import Sequence + + from numpy.typing import ArrayLike + + +T_co = TypeVar("T_co", covariant=True) + + +class DatasetMetadata(NamedTuple): + natoms: ArrayLike | None = None + + +class UnsupportedDatasetError(ValueError): + pass + + +class BaseDataset(Dataset[T_co], metaclass=ABCMeta): + """Base Dataset class for all OCP datasets.""" + + def __init__(self, config: dict): + """Initialize + + Args: + config (dict): dataset configuration + """ + self.config = config + self.paths = [] + + if "src" in self.config: + if isinstance(config["src"], str): + self.paths = [Path(self.config["src"])] + else: + self.paths = tuple(Path(path) for path in config["src"]) + + self.lin_ref = None + if self.config.get("lin_ref", False): + lin_ref = torch.tensor( + np.load(self.config["lin_ref"], allow_pickle=True)["coeff"] + ) + self.lin_ref = torch.nn.Parameter(lin_ref, requires_grad=False) + + def __len__(self) -> int: + return self.num_samples + + def metadata_hasattr(self, attr) -> bool: + if self._metadata is None: + return False + return hasattr(self._metadata, attr) + + @cached_property + def indices(self): + return np.arange(self.num_samples, dtype=int) + + @cached_property + def _metadata(self) -> DatasetMetadata: + # logic to read metadata file here + metadata_npzs = [] + if self.config.get("metadata_path", None) is not None: + metadata_npzs.append( + np.load(self.config["metadata_path"], allow_pickle=True) + ) + + else: + for path in self.paths: + if path.is_file(): + metadata_file = path.parent / "metadata.npz" + else: + metadata_file = path / "metadata.npz" + if metadata_file.is_file(): + metadata_npzs.append(np.load(metadata_file, allow_pickle=True)) + + if len(metadata_npzs) == 0: + logging.warning( + f"Could not find dataset metadata.npz files in '{self.paths}'" + ) + return None + + metadata = DatasetMetadata( + **{ + field: np.concatenate([metadata[field] for metadata in metadata_npzs]) + for field in DatasetMetadata._fields + } + ) + + assert metadata.natoms.shape[0] == len( + self + ), "Loaded metadata and dataset size mismatch." 
+ + return metadata + + def get_metadata(self, attr, idx): + if self._metadata is not None: + metadata_attr = getattr(self._metadata, attr) + if isinstance(idx, list): + return [metadata_attr[_idx] for _idx in idx] + return metadata_attr[idx] + return None + + +class Subset(Subset_, BaseDataset): + """A pytorch subset that also takes metadata if given.""" + + def __init__( + self, + dataset: BaseDataset, + indices: Sequence[int], + metadata: DatasetMetadata | None = None, + ) -> None: + super().__init__(dataset, indices) + self.metadata = metadata + self.indices = indices + self.num_samples = len(indices) + self.config = dataset.config + + @cached_property + def _metadata(self) -> DatasetMetadata: + return self.dataset._metadata + + def get_metadata(self, attr, idx): + if isinstance(idx, list): + return self.dataset.get_metadata(attr, [[self.indices[i] for i in idx]]) + return self.dataset.get_metadata(attr, self.indices[idx]) + + +def create_dataset(config: dict[str, Any], split: str) -> Subset: + """Create a dataset from a config dictionary + + Args: + config (dict): dataset config dictionary + split (str): name of split + + Returns: + Subset: dataset subset class + """ + # Initialize the dataset + dataset_cls = registry.get_dataset_class(config.get("format", "lmdb")) + assert issubclass(dataset_cls, Dataset), f"{dataset_cls} is not a Dataset" + + # remove information about other splits, only keep specified split + # this may only work with the mt config not main config + current_split_config = config.copy() + if "splits" in current_split_config: + current_split_config.pop("splits") + current_split_config.update(config["splits"][split]) + + seed = current_split_config.get("seed", 0) + if split != "train": + seed += ( + 1 # if we use same dataset for train / val , make sure its diff sampling + ) + + g = torch.Generator() + g.manual_seed(seed) + + dataset = dataset_cls(current_split_config) + # Get indices of the dataset + indices = dataset.indices + max_atoms = current_split_config.get("max_atoms", None) + if max_atoms is not None: + if not dataset.metadata_hasattr("natoms"): + raise ValueError("Cannot use max_atoms without dataset metadata") + indices = indices[dataset.get_metadata("natoms", indices) <= max_atoms] + + # Apply dataset level transforms + # TODO is no_shuffle mutually exclusive though? or what is the purpose of no_shuffle? + first_n = current_split_config.get("first_n") + sample_n = current_split_config.get("sample_n") + no_shuffle = current_split_config.get("no_shuffle") + # this is true if at most one of the mutually exclusive arguments are set + if sum(arg is not None for arg in (first_n, sample_n, no_shuffle)) > 1: + raise ValueError( + "sample_n, first_n, no_shuffle are mutually exclusive arguments. Only one can be provided." + ) + if first_n is not None: + max_index = first_n + elif sample_n is not None: + # shuffle by default, user can disable to optimize if they have confidence in dataset + # shuffle all datasets by default to avoid biasing the sampling in concat dataset + # TODO only shuffle if split is train + max_index = sample_n + indices = indices[randperm(len(indices), generator=g)] + else: + max_index = len(indices) + indices = ( + indices if no_shuffle else indices[randperm(len(indices), generator=g)] + ) + + if max_index > len(indices): + msg = ( + f"Cannot take {max_index} data points from a dataset of only length {len(indices)}.\n" + f"Make sure to set first_n or sample_n to a number =< the total samples in dataset." 
+ ) + if max_atoms is not None: + msg = msg[:-1] + f"that are smaller than the given max_atoms {max_atoms}." + raise ValueError(msg) + + indices = indices[:max_index] + + return Subset(dataset, indices, metadata=dataset._metadata) diff --git a/src/fairchem/core/datasets/lmdb_dataset.py b/src/fairchem/core/datasets/lmdb_dataset.py index 91ced220ea..ca1fcc2b77 100644 --- a/src/fairchem/core/datasets/lmdb_dataset.py +++ b/src/fairchem/core/datasets/lmdb_dataset.py @@ -9,32 +9,33 @@ import bisect import logging import pickle -import warnings -from pathlib import Path -from typing import TypeVar +from typing import TYPE_CHECKING, TypeVar import lmdb import numpy as np import torch -from torch.utils.data import Dataset from torch_geometric.data import Batch -from torch_geometric.data.data import BaseData from fairchem.core.common.registry import registry from fairchem.core.common.typing import assert_is_instance from fairchem.core.common.utils import pyg2_data_transform from fairchem.core.datasets._utils import rename_data_object_keys +from fairchem.core.datasets.base_dataset import BaseDataset from fairchem.core.datasets.target_metadata_guesser import guess_property_metadata from fairchem.core.modules.transforms import DataTransforms +if TYPE_CHECKING: + from pathlib import Path + + from torch_geometric.data.data import BaseData + T_co = TypeVar("T_co", covariant=True) @registry.register_dataset("lmdb") @registry.register_dataset("single_point_lmdb") @registry.register_dataset("trajectory_lmdb") -class LmdbDataset(Dataset[T_co]): - metadata_path: Path +class LmdbDataset(BaseDataset): sharded: bool r"""Dataset class to load from LMDB files containing relaxation @@ -50,20 +51,21 @@ class LmdbDataset(Dataset[T_co]): """ def __init__(self, config) -> None: - super().__init__() - self.config = config + super().__init__(config) assert not self.config.get( "train_on_oc20_total_energies", False ), "For training on total energies set dataset=oc22_lmdb" - self.path = Path(self.config["src"]) + assert ( + len(self.paths) == 1 + ), f"{type(self)} does not support a list of src paths." + self.path = self.paths[0] + if not self.path.is_file(): db_paths = sorted(self.path.glob("*.lmdb")) assert len(db_paths) > 0, f"No LMDBs found in '{self.path}'" - self.metadata_path = self.path / "metadata.npz" - self._keys = [] self.envs = [] for db_path in db_paths: @@ -86,7 +88,6 @@ def __init__(self, config) -> None: self._keylen_cumulative = np.cumsum(keylens).tolist() self.num_samples = sum(keylens) else: - self.metadata_path = self.path.parent / "metadata.npz" self.env = self.connect_db(self.path) # If "length" encoded as ascii is present, use that @@ -113,19 +114,15 @@ def __init__(self, config) -> None: self.indices, self.config.get("total_shards", 1) ) # limit each process to see a subset of data based off defined shard - self.available_indices = self.shards[self.config.get("shard", 0)] - self.num_samples = len(self.available_indices) + self.indices = self.shards[self.config.get("shard", 0)] + self.num_samples = len(self.indices) self.key_mapping = self.config.get("key_mapping", None) self.transforms = DataTransforms(self.config.get("transforms", {})) - def __len__(self) -> int: - return self.num_samples - def __getitem__(self, idx: int) -> T_co: # if sharding, remap idx to appropriate idx of the sharded set - if self.sharded: - idx = self.available_indices[idx] + idx = self.indices[idx] if not self.path.is_file(): # Figure out which db this should be indexed from. 
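For reference, a sketch of the config dictionary consumed by the `create_dataset` helper above; the source path and counts are placeholders:

```python
from fairchem.core.datasets import create_dataset

config = {
    "format": "ase_db",            # registered dataset class; defaults to "lmdb"
    "src": "datasets/asedb.lmdb",  # placeholder source path
    "max_atoms": 50,               # requires natoms metadata (metadata.npz or metadata_path)
    "splits": {
        "train": {"sample_n": 1000, "seed": 0},  # per-split sampling settings
        "val": {"first_n": 100},
    },
}

train_set = create_dataset(config, split="train")  # returns a Subset
val_set = create_dataset(config, split="val")
```

In trainer configs the same per-split keys can also be written as `train_split_settings` / `val_split_settings` / `test_split_settings`; `convert_settings_to_split_settings` in `base_trainer.py`, further down in this patch, rewrites them into the `splits` entry shown here.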
db_idx = bisect.bisect(self._keylen_cumulative, idx) @@ -165,14 +162,14 @@ def connect_db(self, lmdb_path: Path | None = None) -> lmdb.Environment: max_readers=1, ) - def close_db(self) -> None: + def __del__(self): if not self.path.is_file(): for env in self.envs: env.close() else: self.env.close() - def get_metadata(self, num_samples: int = 100): + def sample_property_metadata(self, num_samples: int = 100): # This will interogate the classic OCP LMDB format to determine # which properties are present and attempt to guess their shapes # and whether they are intensive or extensive. @@ -214,26 +211,6 @@ def get_metadata(self, num_samples: int = 100): } -class SinglePointLmdbDataset(LmdbDataset[BaseData]): - def __init__(self, config, transform=None) -> None: - super().__init__(config) - warnings.warn( - "SinglePointLmdbDataset is deprecated and will be removed in the future." - "Please use 'LmdbDataset' instead.", - stacklevel=3, - ) - - -class TrajectoryLmdbDataset(LmdbDataset[BaseData]): - def __init__(self, config, transform=None) -> None: - super().__init__(config) - warnings.warn( - "TrajectoryLmdbDataset is deprecated and will be removed in the future." - "Please use 'LmdbDataset' instead.", - stacklevel=3, - ) - - def data_list_collater(data_list: list[BaseData], otf_graph: bool = False) -> BaseData: batch = Batch.from_data_list(data_list) diff --git a/src/fairchem/core/datasets/oc22_lmdb_dataset.py b/src/fairchem/core/datasets/oc22_lmdb_dataset.py index 0c6d4e8bfb..867a72726f 100644 --- a/src/fairchem/core/datasets/oc22_lmdb_dataset.py +++ b/src/fairchem/core/datasets/oc22_lmdb_dataset.py @@ -9,22 +9,21 @@ import bisect import pickle -from pathlib import Path import lmdb import numpy as np import torch -from torch.utils.data import Dataset from fairchem.core.common.registry import registry from fairchem.core.common.typing import assert_is_instance as aii from fairchem.core.common.utils import pyg2_data_transform from fairchem.core.datasets._utils import rename_data_object_keys +from fairchem.core.datasets.base_dataset import BaseDataset from fairchem.core.modules.transforms import DataTransforms @registry.register_dataset("oc22_lmdb") -class OC22LmdbDataset(Dataset): +class OC22LmdbDataset(BaseDataset): r"""Dataset class to load from LMDB files containing relaxation trajectories or single point computations. @@ -43,10 +42,13 @@ class OC22LmdbDataset(Dataset): """ def __init__(self, config, transform=None) -> None: - super().__init__() - self.config = config + super().__init__(config) + + assert ( + len(self.paths) == 1 + ), f"{type(self)} does not support a list of src paths." + self.path = self.paths[0] - self.path = Path(self.config["src"]) self.data2train = self.config.get("data2train", "all") if not self.path.is_file(): db_paths = sorted(self.path.glob("*.lmdb")) @@ -114,19 +116,11 @@ def __init__(self, config, transform=None) -> None: if self.train_on_oc20_total_energies: with open(config["oc20_ref"], "rb") as fp: self.oc20_ref = pickle.load(fp) - if self.config.get("lin_ref", False): - coeff = np.load(self.config["lin_ref"], allow_pickle=True)["coeff"] - self.lin_ref = torch.nn.Parameter(torch.tensor(coeff), requires_grad=False) - self.subsample = aii(self.config.get("subsample", False), bool) def __len__(self) -> int: - if self.subsample: - return min(self.subsample, self.num_samples) return self.num_samples def __getitem__(self, idx): - if self.data2train != "all": - idx = self.indices[idx] if not self.path.is_file(): # Figure out which db this should be indexed from. 
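With `SinglePointLmdbDataset` and `TrajectoryLmdbDataset` deleted above, `LmdbDataset` is constructed directly for both single-file and directory sources (the old `single_point_lmdb` / `trajectory_lmdb` registry names still resolve to it). A small sketch with placeholder paths:

```python
from fairchem.core.datasets.lmdb_dataset import LmdbDataset

train_set = LmdbDataset({"src": "datasets/s2ef/train"})    # placeholder: directory of *.lmdb files
val_set = LmdbDataset({"src": "datasets/is2re/val.lmdb"})  # placeholder: a single LMDB file
```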
db_idx = bisect.bisect(self._keylen_cumulative, idx) diff --git a/src/fairchem/core/scripts/make_lmdb_sizes.py b/src/fairchem/core/scripts/make_lmdb_sizes.py index 682fb58e65..ebf2122aeb 100644 --- a/src/fairchem/core/scripts/make_lmdb_sizes.py +++ b/src/fairchem/core/scripts/make_lmdb_sizes.py @@ -15,7 +15,7 @@ from tqdm import tqdm from fairchem.core.common.typing import assert_is_instance -from fairchem.core.datasets import SinglePointLmdbDataset, TrajectoryLmdbDataset +from fairchem.core.datasets.lmdb_dataset import LmdbDataset def get_data(index): @@ -28,14 +28,13 @@ def get_data(index): return index, natoms, neighbors -def main(args) -> None: +def make_lmdb_sizes(args) -> None: path = assert_is_instance(args.data_path, str) global dataset + dataset = LmdbDataset({"src": path}) if os.path.isdir(path): - dataset = TrajectoryLmdbDataset({"src": path}) outpath = os.path.join(path, "metadata.npz") elif os.path.isfile(path): - dataset = SinglePointLmdbDataset({"src": path}) outpath = os.path.join(os.path.dirname(path), "metadata.npz") output_indices = range(len(dataset)) @@ -63,7 +62,7 @@ def main(args) -> None: np.savez(outpath, natoms=sorted_natoms, neighbors=sorted_neighbors) -if __name__ == "__main__": +def get_lmdb_sizes_parser(): parser = argparse.ArgumentParser() parser.add_argument( "--data-path", @@ -77,5 +76,10 @@ def main(args) -> None: type=int, help="Num of workers to parallelize across", ) + return parser + + +if __name__ == "__main__": + parser = get_lmdb_sizes_parser() args: argparse.Namespace = parser.parse_args() - main(args) + make_lmdb_sizes(args) diff --git a/src/fairchem/core/trainers/base_trainer.py b/src/fairchem/core/trainers/base_trainer.py index dce5099452..1c0c975f8a 100644 --- a/src/fairchem/core/trainers/base_trainer.py +++ b/src/fairchem/core/trainers/base_trainer.py @@ -7,6 +7,7 @@ from __future__ import annotations +import copy import datetime import errno import logging @@ -38,6 +39,7 @@ save_checkpoint, update_config, ) +from fairchem.core.datasets.base_dataset import create_dataset from fairchem.core.modules.evaluator import Evaluator from fairchem.core.modules.exponential_moving_average import ExponentialMovingAverage from fairchem.core.modules.loss import DDPLoss @@ -241,12 +243,16 @@ def load_logger(self) -> None: def get_sampler( self, dataset, batch_size: int, shuffle: bool ) -> BalancedBatchSampler: - if "load_balancing" in self.config["optim"]: - balancing_mode = self.config["optim"]["load_balancing"] - force_balancing = True + balancing_mode = self.config["optim"].get("load_balancing", None) + on_error = self.config["optim"].get("load_balancing_on_error", None) + if balancing_mode is not None: + if on_error is None: + on_error = "raise" else: balancing_mode = "atoms" - force_balancing = False + + if on_error is None: + on_error = "warn_and_no_balance" if gp_utils.initialized(): num_replicas = gp_utils.get_dp_world_size() @@ -262,7 +268,7 @@ def get_sampler( device=self.device, mode=balancing_mode, shuffle=shuffle, - force_balancing=force_balancing, + on_error=on_error, seed=self.config["cmd"]["seed"], ) @@ -283,15 +289,26 @@ def load_datasets(self) -> None: self.val_loader = None self.test_loader = None + # This is hacky and scheduled to be removed next BE week + # move ['X_split_settings'] to ['splits'][X] + def convert_settings_to_split_settings(config, split_name): + config = copy.deepcopy(config) # make sure we dont modify the original + if f"{split_name}_split_settings" in config: + config["splits"] = { + split_name: 
config.pop(f"{split_name}_split_settings") + } + return config + # load train, val, test datasets if "src" in self.config["dataset"]: logging.info( f"Loading dataset: {self.config['dataset'].get('format', 'lmdb')}" ) - self.train_dataset = registry.get_dataset_class( - self.config["dataset"].get("format", "lmdb") - )(self.config["dataset"]) + self.train_dataset = create_dataset( + convert_settings_to_split_settings(self.config["dataset"], "train"), + "train", + ) self.train_sampler = self.get_sampler( self.train_dataset, self.config["optim"]["batch_size"], @@ -302,6 +319,16 @@ def load_datasets(self) -> None: self.train_sampler, ) + if ( + "first_n" in self.config["dataset"] + or "sample_n" in self.config["dataset"] + or "max_atom" in self.config["dataset"] + ): + logging.warn( + "Dataset attributes (first_n/sample_n/max_atom) passed to all datasets! Please don't do this, its dangerous!\n" + + "Add them under each dataset 'train_split_settings'/'val_split_settings'/'test_split_settings'" + ) + if "src" in self.config["val_dataset"]: if self.config["val_dataset"].get("use_train_settings", True): val_config = self.config["dataset"].copy() @@ -309,9 +336,9 @@ def load_datasets(self) -> None: else: val_config = self.config["val_dataset"] - self.val_dataset = registry.get_dataset_class( - val_config.get("format", "lmdb") - )(val_config) + self.val_dataset = create_dataset( + convert_settings_to_split_settings(val_config, "val"), "val" + ) self.val_sampler = self.get_sampler( self.val_dataset, self.config["optim"].get( @@ -331,9 +358,9 @@ def load_datasets(self) -> None: else: test_config = self.config["test_dataset"] - self.test_dataset = registry.get_dataset_class( - test_config.get("format", "lmdb") - )(test_config) + self.test_dataset = create_dataset( + convert_settings_to_split_settings(test_config, "test"), "test" + ) self.test_sampler = self.get_sampler( self.test_dataset, self.config["optim"].get( @@ -398,15 +425,15 @@ def load_task(self): ][target_name].get("level", "system") if "train_on_free_atoms" not in self.output_targets[subtarget]: self.output_targets[subtarget]["train_on_free_atoms"] = ( - self.config[ - "outputs" - ][target_name].get("train_on_free_atoms", True) + self.config["outputs"][target_name].get( + "train_on_free_atoms", True + ) ) if "eval_on_free_atoms" not in self.output_targets[subtarget]: self.output_targets[subtarget]["eval_on_free_atoms"] = ( - self.config[ - "outputs" - ][target_name].get("eval_on_free_atoms", True) + self.config["outputs"][target_name].get( + "eval_on_free_atoms", True + ) ) # TODO: Assert that all targets, loss fn, metrics defined are consistent @@ -429,11 +456,13 @@ def load_model(self) -> None: loader = self.train_loader or self.val_loader or self.test_loader self.model = registry.get_model_class(self.config["model"])( - loader.dataset[0].x.shape[-1] - if loader - and hasattr(loader.dataset[0], "x") - and loader.dataset[0].x is not None - else None, + ( + loader.dataset[0].x.shape[-1] + if loader + and hasattr(loader.dataset[0], "x") + and loader.dataset[0].x is not None + else None + ), bond_feat_dim, 1, **self.config["model_attributes"], @@ -455,7 +484,9 @@ def load_model(self) -> None: self.logger.log_summary({"num_params": self.model.num_params}) if distutils.initialized() and not self.config["noddp"]: - self.model = DistributedDataParallel(self.model, device_ids=[self.device]) + self.model = DistributedDataParallel( + self.model, device_ids=None if self.cpu else [self.device] + ) @property def _unwrapped_model(self): @@ -639,9 +670,11 
@@ def save( "step": self.step, "state_dict": self.model.state_dict(), "optimizer": self.optimizer.state_dict(), - "scheduler": self.scheduler.scheduler.state_dict() - if self.scheduler.scheduler_type != "Null" - else None, + "scheduler": ( + self.scheduler.scheduler.state_dict() + if self.scheduler.scheduler_type != "Null" + else None + ), "normalizers": { key: value.state_dict() for key, value in self.normalizers.items() diff --git a/src/fairchem/core/trainers/ocp_trainer.py b/src/fairchem/core/trainers/ocp_trainer.py index 9055d2d625..72c005893d 100644 --- a/src/fairchem/core/trainers/ocp_trainer.py +++ b/src/fairchem/core/trainers/ocp_trainer.py @@ -227,12 +227,6 @@ def train(self, disable_eval_tqdm: bool = False) -> None: if checkpoint_every == -1: self.save(checkpoint_file="checkpoint.pt", training_state=True) - self.train_dataset.close_db() - if self.config.get("val_dataset", False): - self.val_dataset.close_db() - if self.config.get("test_dataset", False): - self.test_dataset.close_db() - def _forward(self, batch): out = self.model(batch.to(self.device)) @@ -648,7 +642,9 @@ def run_relaxations(self, split="val"): ) gather_results["chunk_idx"] = np.cumsum( [gather_results["chunk_idx"][i] for i in idx] - )[:-1] # np.split does not need last idx, assumes n-1:end + )[ + :-1 + ] # np.split does not need last idx, assumes n-1:end full_path = os.path.join( self.config["cmd"]["results_dir"], "relaxed_positions.npz" diff --git a/tests/core/common/test_data_parallel_batch_sampler.py b/tests/core/common/test_data_parallel_batch_sampler.py index 6205042652..6bd8effe26 100644 --- a/tests/core/common/test_data_parallel_batch_sampler.py +++ b/tests/core/common/test_data_parallel_batch_sampler.py @@ -1,9 +1,16 @@ +""" +Copyright (c) Meta, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. 
+""" + from __future__ import annotations -import functools -import tempfile from contextlib import contextmanager from pathlib import Path +import functools +import tempfile from typing import TypeVar import numpy as np @@ -13,11 +20,13 @@ from fairchem.core.common.data_parallel import ( BalancedBatchSampler, StatefulDistributedSampler, + UnsupportedDatasetError, + _balanced_partition, ) +from fairchem.core.datasets.base_dataset import BaseDataset, DatasetMetadata DATA = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] -SIZE_ATOMS = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1] -SIZE_NEIGHBORS = [4, 4, 4, 4, 4, 4, 4, 4, 4, 4] +SIZE_ATOMS = [2, 20, 3, 51, 10, 11, 41, 31, 13, 14] T_co = TypeVar("T_co", covariant=True) @@ -28,23 +37,57 @@ def _temp_file(name: str): yield Path(tmpdir) / name +@pytest.fixture() +def valid_dataset(): + class _Dataset(BaseDataset): + @functools.cached_property + def _metadata(self) -> DatasetMetadata: + return DatasetMetadata(natoms=np.array(SIZE_ATOMS)) + + def __init__(self, data) -> None: + super().__init__(config={}) + self.data = data + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + return self.data[idx] + + def get_metadata(self, attr, idx): + assert attr == "natoms" + metadata_attr = getattr(self._metadata, attr) + if isinstance(idx, list): + return [metadata_attr[_idx] for _idx in idx] + return metadata_attr[idx] + + return _Dataset(DATA) + + @pytest.fixture() def valid_path_dataset(): - class _Dataset(Dataset[T_co]): + class _Dataset(BaseDataset): + @functools.cached_property + def _metadata(self) -> DatasetMetadata: + return self.metadata + def __init__(self, data, fpath: Path) -> None: + super().__init__(config={}) self.data = data - self.metadata_path = fpath + self.metadata = DatasetMetadata(natoms=np.load(fpath)["natoms"]) def __len__(self): return len(self.data) def __getitem__(self, idx): - return self.data[idx] + metadata_attr = getattr(self._metadata, "natoms") + if isinstance(idx, list): + return [metadata_attr[_idx] for _idx in idx] + return metadata_attr[idx] with _temp_file("metadata.npz") as file: np.savez( natoms=np.array(SIZE_ATOMS), - neighbors=np.array(SIZE_NEIGHBORS), file=file, ) yield _Dataset(DATA, file) @@ -52,8 +95,10 @@ def __getitem__(self, idx): @pytest.fixture() def invalid_path_dataset(): - class _Dataset(Dataset): + class _Dataset(BaseDataset): + def __init__(self, data) -> None: + super().__init__(config={}) self.data = data self.metadata_path = Path("/tmp/does/not/exist.np") @@ -68,8 +113,10 @@ def __getitem__(self, idx): @pytest.fixture() def invalid_dataset(): - class _Dataset(Dataset): + class _Dataset(BaseDataset): + def __init__(self, data) -> None: + super().__init__(config={}) self.data = data def __len__(self): @@ -81,99 +128,68 @@ def __getitem__(self, idx): return _Dataset(DATA) -def test_lowercase(invalid_dataset) -> None: - sampler = BalancedBatchSampler( - dataset=invalid_dataset, +def test_lowercase(valid_dataset) -> None: + _ = BalancedBatchSampler( + dataset=valid_dataset, batch_size=1, rank=0, num_replicas=2, device=None, mode="ATOMS", - throw_on_error=False, - seed=0 - ) - assert sampler.mode == "atoms" - - sampler = BalancedBatchSampler( - dataset=invalid_dataset, - batch_size=1, - rank=0, - num_replicas=2, - device=None, - mode="NEIGHBORS", - throw_on_error=False, - seed=0 + on_error="raise", + seed=0, ) - assert sampler.mode == "neighbors" def test_invalid_mode(invalid_dataset) -> None: with pytest.raises( - ValueError, match="Must be one of 'atoms', 'neighbors', or a boolean." 
+ ValueError, + match="Only mode='atoms' or mode=True is supported, got mode='natoms'.", ): - BalancedBatchSampler( + _ = BalancedBatchSampler( dataset=invalid_dataset, batch_size=1, rank=0, num_replicas=2, device=None, mode="natoms", - throw_on_error=True, - seed=0 + on_error="raise", + seed=0, ) with pytest.raises( - ValueError, match="Must be one of 'atoms', 'neighbors', or a boolean." + ValueError, + match="Only mode='atoms' or mode=True is supported, got mode='neighbors'.", ): - BalancedBatchSampler( + _ = BalancedBatchSampler( dataset=invalid_dataset, batch_size=1, rank=0, num_replicas=2, device=None, - mode="nneighbors", - throw_on_error=True, - seed=0 + mode="neighbors", + on_error="raise", + seed=0, ) def test_invalid_dataset(invalid_dataset) -> None: - with pytest.raises( - RuntimeError, - match="does not have a metadata_path attribute. BalancedBatchSampler has to load the data to determine batch sizes, which incurs significant overhead!", - ): - BalancedBatchSampler( - dataset=invalid_dataset, - batch_size=1, - rank=0, - num_replicas=2, - device=None, - mode="atoms", - throw_on_error=True, - force_balancing=True, - seed=0 - ) - with pytest.raises( - RuntimeError, - match="does not have a metadata_path attribute. Batches will not be balanced, which can incur significant overhead!", - ): - BalancedBatchSampler( + with pytest.raises(UnsupportedDatasetError): + sampler = BalancedBatchSampler( dataset=invalid_dataset, batch_size=1, rank=0, num_replicas=2, device=None, mode="atoms", - throw_on_error=True, - force_balancing=False, - seed=0 + on_error="raise", + seed=0, ) def test_invalid_path_dataset(invalid_path_dataset) -> None: with pytest.raises( - RuntimeError, - match="Metadata file .+ does not exist. BalancedBatchSampler has to load the data to determine batch sizes, which incurs significant overhead!", + UnsupportedDatasetError, ): BalancedBatchSampler( dataset=invalid_path_dataset, @@ -182,13 +198,11 @@ def test_invalid_path_dataset(invalid_path_dataset) -> None: num_replicas=2, device=None, mode="atoms", - throw_on_error=True, - force_balancing=True, - seed=0 + on_error="raise", + seed=0, ) with pytest.raises( - RuntimeError, - match="Metadata file .+ does not exist. 
Batches will not be balanced, which can incur significant overhead!", + UnsupportedDatasetError, ): BalancedBatchSampler( dataset=invalid_path_dataset, @@ -197,70 +211,59 @@ def test_invalid_path_dataset(invalid_path_dataset) -> None: num_replicas=2, device=None, mode="atoms", - throw_on_error=True, - force_balancing=False, - seed=0 + on_error="raise", + seed=0, ) -def test_valid_dataset(valid_path_dataset) -> None: +def test_valid_dataset(valid_dataset, valid_path_dataset) -> None: sampler = BalancedBatchSampler( - dataset=valid_path_dataset, + dataset=valid_dataset, batch_size=1, rank=0, num_replicas=2, device=None, mode="atoms", - throw_on_error=True, - seed=0 - ) - assert (sampler.sizes == np.array(SIZE_ATOMS)).all() - - sampler = BalancedBatchSampler( - dataset=valid_path_dataset, - batch_size=1, - rank=0, - num_replicas=2, - device=None, - mode="neighbors", - throw_on_error=True, - seed=0 + on_error="raise", + seed=0, ) - assert (sampler.sizes == np.array(SIZE_NEIGHBORS)).all() + assert ( + sampler._get_natoms(list(range(len(SIZE_ATOMS)))) == np.array(SIZE_ATOMS) + ).all() -def test_disabled(valid_path_dataset) -> None: +def test_disabled(valid_dataset) -> None: sampler = BalancedBatchSampler( - dataset=valid_path_dataset, + dataset=valid_dataset, batch_size=1, rank=0, num_replicas=2, device=None, mode=False, - throw_on_error=True, - seed=0 + on_error="raise", + seed=0, ) - assert sampler.balance_batches is False + assert sampler.disabled or not sampler._dist_enabled() -def test_single_node(valid_path_dataset) -> None: +def test_single_node(valid_dataset) -> None: sampler = BalancedBatchSampler( - dataset=valid_path_dataset, + dataset=valid_dataset, batch_size=1, rank=0, num_replicas=1, device=None, mode="atoms", - throw_on_error=True, - seed=0 + on_error="raise", + seed=0, ) - assert sampler.balance_batches is False + assert sampler.disabled or not sampler._dist_enabled() -def test_stateful_distributed_sampler_noshuffle(valid_path_dataset) -> None: +def test_stateful_distributed_sampler_noshuffle(valid_dataset) -> None: for batch_size in range(1, 4): sampler = StatefulDistributedSampler( - dataset=valid_path_dataset, + dataset=valid_dataset, batch_size=batch_size, rank=0, num_replicas=1, @@ -272,12 +275,12 @@ def test_stateful_distributed_sampler_noshuffle(valid_path_dataset) -> None: def test_stateful_distributed_sampler_vs_distributed_sampler( - valid_path_dataset, + valid_dataset, ) -> None: for seed in [0, 100, 200]: for batch_size in range(1, 4): stateful_sampler = StatefulDistributedSampler( - dataset=valid_path_dataset, + dataset=valid_dataset, batch_size=batch_size, rank=0, num_replicas=2, @@ -286,7 +289,7 @@ def test_stateful_distributed_sampler_vs_distributed_sampler( drop_last=True, ) sampler = DistributedSampler( - dataset=valid_path_dataset, + dataset=valid_dataset, rank=0, num_replicas=2, seed=seed, @@ -296,10 +299,10 @@ def test_stateful_distributed_sampler_vs_distributed_sampler( assert list(stateful_sampler) == list(sampler) -def test_stateful_distributed_sampler(valid_path_dataset) -> None: +def test_stateful_distributed_sampler(valid_dataset) -> None: for batch_size in range(1, 4): sampler = StatefulDistributedSampler( - dataset=valid_path_dataset, + dataset=valid_dataset, batch_size=batch_size, rank=0, num_replicas=1, @@ -309,7 +312,7 @@ def test_stateful_distributed_sampler(valid_path_dataset) -> None: offset_step = 2 loaded_sampler = StatefulDistributedSampler( - dataset=valid_path_dataset, + dataset=valid_dataset, batch_size=batch_size, rank=0, seed=0, @@ 
-319,7 +322,7 @@ def test_stateful_distributed_sampler(valid_path_dataset) -> None: assert list(loaded_sampler) == original_order[offset_step * batch_size :] diff_sampler = StatefulDistributedSampler( - dataset=valid_path_dataset, + dataset=valid_dataset, batch_size=batch_size, rank=0, num_replicas=1, @@ -328,14 +331,14 @@ def test_stateful_distributed_sampler(valid_path_dataset) -> None: assert list(diff_sampler) != original_order -def test_stateful_distributed_sampler_numreplicas(valid_path_dataset) -> None: - fullset = set(range(len(valid_path_dataset))) +def test_stateful_distributed_sampler_numreplicas(valid_dataset) -> None: + fullset = set(range(len(valid_dataset))) for drop_last in [True, False]: for num_replicas in range(1, 4): for batch_size in [1]: samplers = [ StatefulDistributedSampler( - dataset=valid_path_dataset, + dataset=valid_dataset, batch_size=batch_size, rank=rank, seed=0, @@ -360,14 +363,14 @@ def test_stateful_distributed_sampler_numreplicas(valid_path_dataset) -> None: def test_stateful_distributed_sampler_numreplicas_drop_last( - valid_path_dataset, + valid_dataset, ) -> None: - fullset = set(range(len(valid_path_dataset))) + fullset = set(range(len(valid_dataset))) for num_replicas in range(1, 4): for batch_size in range(1, 4): samplers = [ StatefulDistributedSampler( - dataset=valid_path_dataset, + dataset=valid_dataset, batch_size=batch_size, rank=rank, seed=0, @@ -387,3 +390,15 @@ def test_stateful_distributed_sampler_numreplicas_drop_last( ) assert len(concat_idxs) == len(np.unique(concat_idxs)) assert len(concat_idxs) == (len(fullset) // num_replicas) * num_replicas + + +def test_balancedbatchsampler_partition(valid_dataset) -> None: + assert np.array( + _balanced_partition(np.array(SIZE_ATOMS), 4) + == [[1, 9, 5, 0], [7, 8, 2], [3], [6, 4]] + ) + # test case with local batch size = 1, GPU0(rank0) always gets smallest + # we cant say anything about the remaining elements because it is a heap + assert np.array( + _balanced_partition(np.array(SIZE_ATOMS)[[3, 6, 7, 1]], 4)[0] == [3] + ) diff --git a/tests/core/common/test_gp_utils.py b/tests/core/common/test_gp_utils.py index 9743d35a2f..05c7475d2c 100644 --- a/tests/core/common/test_gp_utils.py +++ b/tests/core/common/test_gp_utils.py @@ -7,42 +7,112 @@ gather_from_model_parallel_region, scatter_to_model_parallel_region, ) -from fairchem.core.common.test_utils import PGConfig, spawn_multi_process +from fairchem.core.common.test_utils import ( + PGConfig, + init_pg_and_rank_and_launch_test, + spawn_multi_process, +) def _dummy_call(x): return x -@pytest.mark.parametrize("world_size, input, expected_output", [(1, 5, [5]), (3, 0, [0, 0, 0])]) # noqa: PT006 + +@pytest.mark.parametrize( + "world_size, input, expected_output", [(1, 5, [5]), (3, 0, [0, 0, 0])] +) # noqa: PT006 def test_basic_setup(world_size: int, input: torch.Tensor, expected_output: list): - config = PGConfig(backend="gloo", world_size=world_size, gp_group_size=1, use_gp=True) - output = spawn_multi_process(config, _dummy_call, input) + config = PGConfig( + backend="gloo", world_size=world_size, gp_group_size=1, use_gp=True + ) + output = spawn_multi_process( + config, _dummy_call, init_pg_and_rank_and_launch_test, input + ) assert output == expected_output -@pytest.mark.parametrize("world_size, gp_size, input, expected_output", # noqa: PT006 - [(2, 1, torch.Tensor([0,1,2,3]), [torch.Tensor([0,1,2,3]), torch.Tensor([0,1,2,3])]), - (2, 2, torch.Tensor([0,1,2,3]), [torch.Tensor([0,1]), torch.Tensor([2,3])]), - (2, 2, torch.Tensor([0,1,2]), 
[torch.Tensor([0,1]), torch.Tensor([2])]), - (3, 3, torch.Tensor([0,1,2]), [torch.Tensor([0]), torch.Tensor([1]), torch.Tensor([2])])] + +@pytest.mark.parametrize( + "world_size, gp_size, input, expected_output", # noqa: PT006 + [ + ( + 2, + 1, + torch.Tensor([0, 1, 2, 3]), + [torch.Tensor([0, 1, 2, 3]), torch.Tensor([0, 1, 2, 3])], + ), + ( + 2, + 2, + torch.Tensor([0, 1, 2, 3]), + [torch.Tensor([0, 1]), torch.Tensor([2, 3])], + ), + (2, 2, torch.Tensor([0, 1, 2]), [torch.Tensor([0, 1]), torch.Tensor([2])]), + ( + 3, + 3, + torch.Tensor([0, 1, 2]), + [torch.Tensor([0]), torch.Tensor([1]), torch.Tensor([2])], + ), + ], ) -def test_scatter_tensors(world_size: int, gp_size: int, input: torch.Tesnor, expected_output: list): - config = PGConfig(backend="gloo", world_size=world_size, gp_group_size=gp_size, use_gp=True) - output = spawn_multi_process(config, scatter_to_model_parallel_region, input) +def test_scatter_tensors( + world_size: int, gp_size: int, input: torch.Tesnor, expected_output: list +): + config = PGConfig( + backend="gloo", world_size=world_size, gp_group_size=gp_size, use_gp=True + ) + output = spawn_multi_process( + config, + scatter_to_model_parallel_region, + init_pg_and_rank_and_launch_test, + input, + ) for out, expected_out in zip(output, expected_output): assert torch.equal(out, expected_out) + def scatter_gather_fn(input: torch.Tensor, dim: int = 0): x = scatter_to_model_parallel_region(input, dim) return gather_from_model_parallel_region(x, dim) -@pytest.mark.parametrize("world_size, gp_size, input, expected_output", # noqa: PT006 - [(2, 1, torch.Tensor([0,1,2,3]), [torch.Tensor([0,1,2,3]), torch.Tensor([0,1,2,3])]), - (2, 2, torch.Tensor([0,1,2,3]), [torch.Tensor([0,1,2,3]), torch.Tensor([0,1,2,3])]), - (2, 2, torch.Tensor([0,1,2]), [torch.Tensor([0,1,2]), torch.Tensor([0,1,2])]), - (3, 3, torch.Tensor([0,1,2]), [torch.Tensor([0,1,2]), torch.Tensor([0,1,2]), torch.Tensor([0,1,2])])] + +@pytest.mark.parametrize( + "world_size, gp_size, input, expected_output", # noqa: PT006 + [ + ( + 2, + 1, + torch.Tensor([0, 1, 2, 3]), + [torch.Tensor([0, 1, 2, 3]), torch.Tensor([0, 1, 2, 3])], + ), + ( + 2, + 2, + torch.Tensor([0, 1, 2, 3]), + [torch.Tensor([0, 1, 2, 3]), torch.Tensor([0, 1, 2, 3])], + ), + ( + 2, + 2, + torch.Tensor([0, 1, 2]), + [torch.Tensor([0, 1, 2]), torch.Tensor([0, 1, 2])], + ), + ( + 3, + 3, + torch.Tensor([0, 1, 2]), + [torch.Tensor([0, 1, 2]), torch.Tensor([0, 1, 2]), torch.Tensor([0, 1, 2])], + ), + ], ) -def test_gather_tensors(world_size: int, gp_size: int, input: torch.Tesnor, expected_output: list): - config = PGConfig(backend="gloo", world_size=world_size, gp_group_size=gp_size, use_gp=True) - output = spawn_multi_process(config, scatter_gather_fn, input) +def test_gather_tensors( + world_size: int, gp_size: int, input: torch.Tesnor, expected_output: list +): + config = PGConfig( + backend="gloo", world_size=world_size, gp_group_size=gp_size, use_gp=True + ) + output = spawn_multi_process( + config, scatter_gather_fn, init_pg_and_rank_and_launch_test, input + ) for out, expected_out in zip(output, expected_output): assert torch.equal(out, expected_out) diff --git a/tests/core/datasets/conftest.py b/tests/core/datasets/conftest.py new file mode 100644 index 0000000000..eb7be94994 --- /dev/null +++ b/tests/core/datasets/conftest.py @@ -0,0 +1,28 @@ +import numpy as np +import pytest +from ase import build +from ase.calculators.singlepoint import SinglePointCalculator + + +@pytest.fixture(scope="module") +def structures(): + structures = [ + 
build.molecule("H2O", vacuum=4), + build.bulk("Cu"), + build.fcc111("Pt", size=[2, 2, 3], vacuum=8, periodic=True), + ] + for atoms in structures: + calc = SinglePointCalculator( + atoms, + energy=1, + forces=atoms.positions, + # there is an issue with ASE db when writing a db with 3x3 stress if is flattened to (9,) and then + # errors when trying to read it + stress=np.random.random((6,)), + ) + atoms.calc = calc + atoms.info["extensive_property"] = 3 * len(atoms) + atoms.info["tensor_property"] = np.random.random((6, 6)) + + structures[2].set_pbc(True) + return structures diff --git a/tests/core/datasets/test_ase_datasets.py b/tests/core/datasets/test_ase_datasets.py index 01bd4ea2fc..7b114d877f 100644 --- a/tests/core/datasets/test_ase_datasets.py +++ b/tests/core/datasets/test_ase_datasets.py @@ -15,26 +15,6 @@ ) from fairchem.core.datasets.lmdb_database import LMDBDatabase -structures = [ - build.molecule("H2O", vacuum=4), - build.bulk("Cu"), - build.fcc111("Pt", size=[2, 2, 3], vacuum=8, periodic=True), -] -for atoms in structures: - calc = SinglePointCalculator( - atoms, - energy=1, - forces=atoms.positions, - # there is an issue with ASE db when writing a db with 3x3 stress it is flattened to (9,) and then - # errors when trying to read it - stress=np.random.random((6,)), - ) - atoms.calc = calc - atoms.info["extensive_property"] = 3 * len(atoms) - atoms.info["tensor_property"] = np.random.random((6, 6)) - -structures[2].set_pbc(True) - @pytest.fixture( params=[ @@ -46,7 +26,7 @@ "aselmdb_dataset", ], ) -def ase_dataset(request, tmp_path_factory): +def ase_dataset(request, structures, tmp_path_factory): tmp_path = tmp_path_factory.mktemp("dataset") mult = 1 a2g_args = { @@ -110,7 +90,7 @@ def ase_dataset(request, tmp_path_factory): return dataset, mult -def test_ase_dataset(ase_dataset): +def test_ase_dataset(ase_dataset, structures): dataset, mult = ase_dataset assert len(dataset) == mult * len(structures) for data in dataset: @@ -121,7 +101,7 @@ def test_ase_dataset(ase_dataset): assert isinstance(data.extensive_property, int) -def test_ase_read_dataset(tmp_path) -> None: +def test_ase_read_dataset(tmp_path, structures): # unfortunately there is currently no clean (already implemented) way to save atoms.info when saving # individual structures - so test separately for i, structure in enumerate(structures): @@ -137,13 +117,16 @@ def test_ase_read_dataset(tmp_path) -> None: assert len(dataset) == len(structures) data = dataset[0] del data - dataset.close_db() -def test_ase_metadata_guesser(ase_dataset) -> None: +def test_ase_get_metadata(ase_dataset): + assert ase_dataset[0].get_metadata("natoms", [0])[0] == 3 + + +def test_ase_metadata_guesser(ase_dataset): dataset, _ = ase_dataset - metadata = dataset.get_metadata() + metadata = dataset.sample_property_metadata() # Confirm energy metadata guessed properly assert metadata["targets"]["energy"]["extensive"] is False @@ -171,7 +154,7 @@ def test_ase_metadata_guesser(ase_dataset) -> None: assert metadata["targets"]["info.tensor_property"]["type"] == "per-image" -def test_db_add_delete(tmp_path) -> None: +def test_db_add_delete(tmp_path, structures): database = db.connect(tmp_path / "asedb.db") for _i, atoms in enumerate(structures): database.write(atoms, data=atoms.info) @@ -192,10 +175,9 @@ def test_db_add_delete(tmp_path) -> None: dataset = AseDBDataset(config={"src": str(tmp_path / "asedb.db")}) assert len(dataset) == orig_len + len(new_structures) - 1 - dataset.close_db() -def test_ase_multiread_dataset(tmp_path) -> None: +def 
test_ase_multiread_dataset(tmp_path): atoms_objects = [build.bulk("Cu", a=a) for a in np.linspace(3.5, 3.7, 10)] energies = np.linspace(1, 0, len(atoms_objects)) @@ -224,13 +206,17 @@ def test_ase_multiread_dataset(tmp_path) -> None: f.write(f"{tmp_path / 'test.traj'} {len(atoms_objects)}") dataset = AseReadMultiStructureDataset( - config={"index_file": str(tmp_path / "test_index_file")}, + config={ + "src": str(tmp_path), + "index_file": str(tmp_path / "test_index_file"), + }, ) assert len(dataset) == len(atoms_objects) dataset = AseReadMultiStructureDataset( config={ + "src": str(tmp_path), "index_file": str(tmp_path / "test_index_file"), "a2g_args": { "r_energy": True, diff --git a/tests/core/datasets/test_create_dataset.py b/tests/core/datasets/test_create_dataset.py new file mode 100644 index 0000000000..d90271c53d --- /dev/null +++ b/tests/core/datasets/test_create_dataset.py @@ -0,0 +1,180 @@ +import os +import numpy as np +import pytest + +from fairchem.core.datasets import LMDBDatabase, create_dataset +from fairchem.core.datasets.base_dataset import BaseDataset +import tempfile +from fairchem.core.trainers.base_trainer import BaseTrainer + + +@pytest.fixture() +def lmdb_database(structures): + with tempfile.TemporaryDirectory() as tmpdirname: + num_atoms = [] + asedb_fn = f"{tmpdirname}/asedb.lmdb" + with LMDBDatabase(asedb_fn) as database: + for i, atoms in enumerate(structures): + database.write(atoms, data=atoms.info) + num_atoms.append(len(atoms)) + np.savez(f"{tmpdirname}/metadata.npz", natoms=num_atoms) + yield asedb_fn + + +def test_real_dataset_config(lmdb_database): + class TestTrainer(BaseTrainer): + def __init__(self, config): + self.config = config + + def train(self, x): + return None + + def get_sampler(self, *args, **kwargs): + return None + + def get_dataloader(self, *args, **kwargs): + return None + + config = { + "model_attributes": {}, + "optim": {"batch_size": 0}, + "dataset": { + "format": "ase_db", + "src": str(lmdb_database), + "first_n": 2, + "key_mapping": { + "y": "energy", + "force": "forces", + }, + "transforms": { + "normalizer": { + "energy": { + "mean": -0.7554450631141663, + "stdev": 2.887317180633545, + }, + "forces": {"mean": 0, "stdev": 2.887317180633545}, + } + }, + }, + "val_dataset": {"src": str(lmdb_database)}, + "test_dataset": {}, + "relax_dataset": None, + } + + t = TestTrainer(config) + t.load_datasets() + assert len(t.train_dataset) == 2 + assert len(t.val_dataset) == 2 + + # modify the config for split and see if it works as expected + config["dataset"].pop("first_n") + config["dataset"]["train_split_settings"] = {"first_n": 2} + + t = TestTrainer(config) + t.load_datasets() + assert len(t.train_dataset) == 2 + assert len(t.val_dataset) == 3 + + +@pytest.mark.parametrize("max_atoms", [3, None]) +@pytest.mark.parametrize( + "key, value", [("first_n", 2), ("sample_n", 2), ("no_shuffle", True)] +) +def test_create_dataset(key, value, max_atoms, structures, lmdb_database): + # now create a config + config = { + "format": "ase_db", + "src": str(lmdb_database), + key: value, + "max_atoms": max_atoms, + } + + dataset = create_dataset(config, split="train") + if max_atoms is not None: + structures = [s for s in structures if len(s) <= max_atoms] + assert all( + natoms <= max_atoms + for natoms in dataset.metadata.natoms[range(len(dataset))] + ) + if key == "first_n": # this assumes first_n are not shuffled + assert all( + np.allclose(a1.cell.array, a2.cell.numpy()) + for a1, a2 in zip(structures[:value], dataset) + ) + assert all( + 
np.allclose(a1.numbers, a2.atomic_numbers) + for a1, a2 in zip(structures[:value], dataset) + ) + elif key == "sample_n": + assert len(dataset) == value + else: # no shuffle all of them are in there + assert all( + np.allclose(a1.cell.array, a2.cell.numpy()) + for a1, a2 in zip(structures, dataset) + ) + assert all( + np.allclose(a1.numbers, a2.atomic_numbers) + for a1, a2 in zip(structures, dataset) + ) + + +# make sure we cant sample more than the number of elements in the dataset with sample_n +def test_sample_n_dataset(lmdb_database): + with pytest.raises(ValueError): + _ = create_dataset( + config={ + "format": "ase_db", + "src": str(lmdb_database), + "sample_n": 100, + }, + split="train", + ) + + +def test_diff_seed_sample_dataset(lmdb_database): + dataset_a = create_dataset( + config={ + "format": "ase_db", + "src": str(lmdb_database), + "sample_n": 3, + "seed": 0, + }, + split="train", + ) + dataset_b = create_dataset( + config={ + "format": "ase_db", + "src": str(lmdb_database), + "sample_n": 3, + "seed": 0, + }, + split="train", + ) + assert (dataset_a.indices == dataset_b.indices).all() + dataset_b = create_dataset( + config={ + "format": "ase_db", + "src": str(lmdb_database), + "sample_n": 3, + "seed": 1, + }, + split="train", + ) + assert not (dataset_a.indices == dataset_b.indices).all() + + +def test_del_dataset(): + class _Dataset(BaseDataset): + def __init__(self, fn) -> None: + super().__init__(config={}) + self.fn = fn + open(self.fn, "a").close() + + def __del__(self): + os.remove(self.fn) + + with tempfile.TemporaryDirectory() as tmpdirname: + fn = tmpdirname + "/test" + d = _Dataset(fn) + del d + assert not os.path.exists(fn) diff --git a/tests/core/datasets/test_lmdb_dataset.py b/tests/core/datasets/test_lmdb_dataset.py new file mode 100644 index 0000000000..f922e32ce3 --- /dev/null +++ b/tests/core/datasets/test_lmdb_dataset.py @@ -0,0 +1,29 @@ +from fairchem.core.datasets.base_dataset import create_dataset + +import numpy as np + +from fairchem.core.scripts.make_lmdb_sizes import get_lmdb_sizes_parser, make_lmdb_sizes + + +def test_load_lmdb_dataset(tutorial_dataset_path): + + lmdb_path = str(tutorial_dataset_path / "s2ef/val_20") + + # make dataset metadata + parser = get_lmdb_sizes_parser() + args, override_args = parser.parse_known_args(["--data-path", lmdb_path]) + make_lmdb_sizes(args) + + config = { + "format": "lmdb", + "src": lmdb_path, + } + + dataset = create_dataset(config, split="val") + + assert dataset.get_metadata("natoms", 0) == dataset[0].natoms + + all_metadata_natoms = np.array(dataset.get_metadata("natoms", range(len(dataset)))) + all_natoms = np.array([datapoint.natoms for datapoint in dataset]) + + assert (all_natoms == all_metadata_natoms).all() diff --git a/tests/core/e2e/test_s2ef.py b/tests/core/e2e/test_s2ef.py index 8387d6e053..aea07201bd 100644 --- a/tests/core/e2e/test_s2ef.py +++ b/tests/core/e2e/test_s2ef.py @@ -9,6 +9,12 @@ import numpy as np import pytest import yaml +from fairchem.core.common.test_utils import ( + PGConfig, + init_env_rank_and_launch_test, + spawn_multi_process, +) +from fairchem.core.scripts.make_lmdb_sizes import get_lmdb_sizes_parser, make_lmdb_sizes from tensorboard.backend.event_processing.event_accumulator import EventAccumulator from fairchem.core._cli import Runner @@ -84,6 +90,7 @@ def _run_main( update_run_args_with=None, save_checkpoint_to=None, save_predictions_to=None, + world_size=0, ): config_yaml = Path(rundir) / "train_and_val_on_val.yml" @@ -91,6 +98,7 @@ def _run_main( yaml_config = 
yaml.safe_load(yaml_file) if update_dict_with is not None: yaml_config = merge_dictionary(yaml_config, update_dict_with) + yaml_config["backend"] = "gloo" with open(str(config_yaml), "w") as yaml_file: yaml.dump(yaml_config, yaml_file) @@ -110,7 +118,19 @@ def _run_main( for arg_name, arg_value in run_args.items(): setattr(args, arg_name, arg_value) config = build_config(args, override_args) - Runner()(config) + + if world_size > 0: + pg_config = PGConfig( + backend="gloo", world_size=world_size, gp_group_size=1, use_gp=False + ) + spawn_multi_process( + pg_config, + Runner(distributed=True), + init_env_rank_and_launch_test, + config, + ) + else: + Runner()(config) if save_checkpoint_to is not None: checkpoints = glob.glob(f"{rundir}/checkpoints/*/checkpoint.pt") @@ -213,6 +233,72 @@ def test_train_and_predict( tutorial_val_src=tutorial_val_src, ) + @pytest.mark.parametrize( + ("world_size", "ddp"), + [ + pytest.param(2, True), + pytest.param(0, False), + ], + ) + def test_ddp(self, world_size, ddp, configs, tutorial_val_src, torch_deterministic): + with tempfile.TemporaryDirectory() as tempdirname: + tempdir = Path(tempdirname) + extra_args = {"seed": 0} + if not ddp: + extra_args["no_ddp"] = True + _ = _run_main( + rundir=str(tempdir), + update_dict_with={ + "optim": {"max_epochs": 1}, + "dataset": oc20_lmdb_train_and_val_from_paths( + train_src=str(tutorial_val_src), + val_src=str(tutorial_val_src), + test_src=str(tutorial_val_src), + ), + }, + update_run_args_with=extra_args, + input_yaml=configs["equiformer_v2"], + world_size=world_size, + ) + + @pytest.mark.parametrize( + ("world_size", "ddp"), + [ + pytest.param(2, True), + pytest.param(0, False), + ], + ) + def test_balanced_batch_sampler_ddp( + self, world_size, ddp, configs, tutorial_val_src, torch_deterministic + ): + + # make dataset metadata + parser = get_lmdb_sizes_parser() + args, override_args = parser.parse_known_args( + ["--data-path", str(tutorial_val_src)] + ) + make_lmdb_sizes(args) + + with tempfile.TemporaryDirectory() as tempdirname: + tempdir = Path(tempdirname) + extra_args = {"seed": 0} + if not ddp: + extra_args["no_ddp"] = True + _ = _run_main( + rundir=str(tempdir), + update_dict_with={ + "optim": {"max_epochs": 1, "load_balancing": "atoms"}, + "dataset": oc20_lmdb_train_and_val_from_paths( + train_src=str(tutorial_val_src), + val_src=str(tutorial_val_src), + test_src=str(tutorial_val_src), + ), + }, + update_run_args_with=extra_args, + input_yaml=configs["equiformer_v2"], + world_size=world_size, + ) + # train for a few steps and confirm same seeds get same results def test_different_seeds( self, @@ -290,9 +376,9 @@ class TestSmallDatasetOptim: @pytest.mark.parametrize( ("model_name", "expected_energy_mae", "expected_force_mae"), [ - pytest.param("gemnet", 0.4, 0.06, id="gemnet"), - pytest.param("escn", 0.4, 0.06, id="escn"), - pytest.param("equiformer_v2", 0.4, 0.06, id="equiformer_v2"), + pytest.param("gemnet", 0.41, 0.06, id="gemnet"), + pytest.param("escn", 0.41, 0.06, id="escn"), + pytest.param("equiformer_v2", 0.41, 0.06, id="equiformer_v2"), ], ) def test_train_optimization( diff --git a/tests/core/models/test_equiformer_v2.py b/tests/core/models/test_equiformer_v2.py index 34ed79ba2b..0034232cd2 100644 --- a/tests/core/models/test_equiformer_v2.py +++ b/tests/core/models/test_equiformer_v2.py @@ -18,7 +18,11 @@ from torch.nn.parallel.distributed import DistributedDataParallel from fairchem.core.common.registry import registry -from fairchem.core.common.test_utils import PGConfig, spawn_multi_process 
+from fairchem.core.common.test_utils import ( + PGConfig, + init_pg_and_rank_and_launch_test, + spawn_multi_process, +) from fairchem.core.common.utils import load_state_dict, setup_imports from fairchem.core.datasets import data_list_collater from fairchem.core.models.equiformer_v2.so3 import ( @@ -140,7 +144,9 @@ def test_energy_force_shape(self, snapshot): def test_ddp(self, snapshot): data_dist = self.data.clone().detach() config = PGConfig(backend="gloo", world_size=1, gp_group_size=1, use_gp=False) - output = spawn_multi_process(config, _runner, data_dist) + output = spawn_multi_process( + config, _runner, init_pg_and_rank_and_launch_test, data_dist + ) assert len(output) == 1 energy, forces = output[0]["energy"], output[0]["forces"] assert snapshot == energy.shape @@ -151,7 +157,9 @@ def test_ddp(self, snapshot): def test_gp(self, snapshot): data_dist = self.data.clone().detach() config = PGConfig(backend="gloo", world_size=2, gp_group_size=2, use_gp=True) - output = spawn_multi_process(config, _runner, data_dist) + output = spawn_multi_process( + config, _runner, init_pg_and_rank_and_launch_test, data_dist + ) assert len(output) == 2 energy, forces = output[0]["energy"], output[0]["forces"] assert snapshot == energy.shape @@ -225,4 +233,3 @@ def sign(x): embedding._l_primary(c) lp = embedding.embedding.clone() (test_matrix_lp == lp).all() - From 08b8c1ea9f1858d7f8f14df1718f26997c1ca799 Mon Sep 17 00:00:00 2001 From: Misko Date: Fri, 2 Aug 2024 13:50:24 -0700 Subject: [PATCH 5/8] Move select models to backbone + heads format and add support for hydra (#782) * convert escn to bb + heads * convert dimenet to bb + heads * gemnet_oc to backbone and heads * add additional parameter backbone config to heads * gemnet to bb and heads * pain to bb and heads * add eqv2 bb+heads; move to canonical naming * fix calculator loading by leaving original class in code * fix issues with calculator loading * lint fixes * move dimenet++ heads to one * add test for dimenet * add painn test * hydra and tests for gemnetH dppH painnH * add escnH and equiformerv2H * add gemnetdt gemnetdtH * add smoke test for schnet and scn * remove old examples * typo * fix gemnet with grad forces; add test for this * remove unused params; add backbone and head interface; add typing * remove unused second order output heads * remove OC20 suffix from equiformer * remove comment * rename and lint * fix dimenet test * fix tests * refactor generate graph * refactor generate graph * fix a messy cherry pick * final messy fix * graph data interface in eqv2 * refactor * no bbconfigs * no more headconfigs in inits * rename hydra * fix eqV2 * update test configs * final fixes * fix tutorial * rm comments * fix test --------- Co-authored-by: lbluque Co-authored-by: Luis Barroso-Luque --- docs/legacy_tutorials/OCP_Tutorial.md | 2 +- src/fairchem/core/models/base.py | 137 +++++++- src/fairchem/core/models/dimenet_plus_plus.py | 147 +++++++-- .../core/models/equiformer_v2/__init__.py | 2 +- ...equiformer_v2_oc20.py => equiformer_v2.py} | 297 +++++++++++++++--- src/fairchem/core/models/escn/escn.py | 177 +++++++++-- src/fairchem/core/models/gemnet/gemnet.py | 195 +++++++++--- src/fairchem/core/models/gemnet_gp/gemnet.py | 64 ++-- .../core/models/gemnet_oc/gemnet_oc.py | 287 ++++++++++++++--- src/fairchem/core/models/painn/painn.py | 130 ++++++-- src/fairchem/core/models/schnet.py | 26 +- src/fairchem/core/models/scn/scn.py | 31 +- src/fairchem/core/trainers/base_trainer.py | 14 - tests/core/e2e/test_s2ef.py | 46 ++- 
tests/core/models/test_configs/test_dpp.yml | 50 +++ .../models/test_configs/test_dpp_hydra.yml | 55 ++++ .../test_configs/test_equiformerv2_hydra.yml | 98 ++++++ .../models/test_configs/test_escn_hydra.yml | 67 ++++ .../models/test_configs/test_gemnet_dt.yml | 79 +++++ .../test_configs/test_gemnet_dt_hydra.yml | 86 +++++ .../test_gemnet_dt_hydra_grad.yml | 84 +++++ .../{test_gemnet.yml => test_gemnet_oc.yml} | 0 .../test_configs/test_gemnet_oc_hydra.yml | 112 +++++++ .../test_gemnet_oc_hydra_grad.yml | 109 +++++++ tests/core/models/test_configs/test_painn.yml | 50 +++ .../models/test_configs/test_painn_hydra.yml | 58 ++++ .../core/models/test_configs/test_schnet.yml | 45 +++ tests/core/models/test_configs/test_scn.yml | 59 ++++ tests/core/models/test_dimenetpp.py | 3 - tests/core/models/test_equiformer_v2.py | 3 - tests/core/models/test_gemnet.py | 3 - tests/core/models/test_gemnet_oc.py | 3 - .../models/test_gemnet_oc_scaling_mismatch.py | 12 - tests/core/models/test_schnet.py | 2 +- 34 files changed, 2182 insertions(+), 351 deletions(-) rename src/fairchem/core/models/equiformer_v2/{equiformer_v2_oc20.py => equiformer_v2.py} (72%) create mode 100755 tests/core/models/test_configs/test_dpp.yml create mode 100755 tests/core/models/test_configs/test_dpp_hydra.yml create mode 100644 tests/core/models/test_configs/test_equiformerv2_hydra.yml create mode 100644 tests/core/models/test_configs/test_escn_hydra.yml create mode 100644 tests/core/models/test_configs/test_gemnet_dt.yml create mode 100644 tests/core/models/test_configs/test_gemnet_dt_hydra.yml create mode 100644 tests/core/models/test_configs/test_gemnet_dt_hydra_grad.yml rename tests/core/models/test_configs/{test_gemnet.yml => test_gemnet_oc.yml} (100%) create mode 100644 tests/core/models/test_configs/test_gemnet_oc_hydra.yml create mode 100644 tests/core/models/test_configs/test_gemnet_oc_hydra_grad.yml create mode 100644 tests/core/models/test_configs/test_painn.yml create mode 100644 tests/core/models/test_configs/test_painn_hydra.yml create mode 100755 tests/core/models/test_configs/test_schnet.yml create mode 100755 tests/core/models/test_configs/test_scn.yml diff --git a/docs/legacy_tutorials/OCP_Tutorial.md b/docs/legacy_tutorials/OCP_Tutorial.md index 8b5d4d522a..19fd93f6bc 100644 --- a/docs/legacy_tutorials/OCP_Tutorial.md +++ b/docs/legacy_tutorials/OCP_Tutorial.md @@ -1807,7 +1807,7 @@ Similarly, to predict forces, we pass edge features through a fully-connected la @registry.register_model("simple") class SimpleAtomEdgeModel(torch.nn.Module): - def __init__(self, num_atoms, bond_feat_dim, num_targets, emb_size=64, num_radial=64, cutoff=6.0, env_exponent=5): + def __init__(self, emb_size=64, num_radial=64, cutoff=6.0, env_exponent=5): super().__init__() self.radial_basis = RadialBasis( diff --git a/src/fairchem/core/models/base.py b/src/fairchem/core/models/base.py index 42790643a9..eb8c9d543c 100644 --- a/src/fairchem/core/models/base.py +++ b/src/fairchem/core/models/base.py @@ -8,27 +8,42 @@ from __future__ import annotations import logging +from abc import ABCMeta, abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING import torch -import torch.nn as nn +from torch import nn from torch_geometric.nn import radius_graph +from fairchem.core.common.registry import registry from fairchem.core.common.utils import ( compute_neighbors, get_pbc_distances, radius_graph_pbc, ) +if TYPE_CHECKING: + from torch_geometric.data import Batch -class BaseModel(nn.Module): - def __init__(self, 
num_atoms=None, bond_feat_dim=None, num_targets=None) -> None: - super().__init__() - self.num_atoms = num_atoms - self.bond_feat_dim = bond_feat_dim - self.num_targets = num_targets - def forward(self, data): - raise NotImplementedError +@dataclass +class GraphData: + """Class to keep graph attributes nicely packaged.""" + + edge_index: torch.Tensor + edge_distance: torch.Tensor + edge_distance_vec: torch.Tensor + cell_offsets: torch.Tensor + offset_distances: torch.Tensor + neighbors: torch.Tensor + batch_full: torch.Tensor # used for GP functionality + atomic_numbers_full: torch.Tensor # used for GP functionality + node_offset: int = 0 # used for GP functionality + + +class GraphModelMixin: + """Mixin Model class implementing some general convenience properties and methods.""" def generate_graph( self, @@ -109,13 +124,16 @@ def generate_graph( ) neighbors = compute_neighbors(data, edge_index) - return ( - edge_index, - edge_dist, - distance_vec, - cell_offsets, - cell_offset_distances, - neighbors, + return GraphData( + edge_index=edge_index, + edge_distance=edge_dist, + edge_distance_vec=distance_vec, + cell_offsets=cell_offsets, + offset_distances=cell_offset_distances, + neighbors=neighbors, + node_offset=0, + batch_full=data.batch, + atomic_numbers_full=data.atomic_numbers.long(), ) @property @@ -130,3 +148,90 @@ def no_weight_decay(self) -> list: if "embedding" in name or "frequencies" in name or "bias" in name: no_wd_list.append(name) return no_wd_list + + +class HeadInterface(metaclass=ABCMeta): + @abstractmethod + def forward( + self, data: Batch, emb: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + """Head forward. + + Arguments + --------- + data: DataBatch + Atomic systems as input + emb: dict[str->torch.Tensor] + Embeddings of the input as generated by the backbone + + Returns + ------- + outputs: dict[str->torch.Tensor] + Return one or more targets generated by this head + """ + return + + +class BackboneInterface(metaclass=ABCMeta): + @abstractmethod + def forward(self, data: Batch) -> dict[str, torch.Tensor]: + """Backbone forward. + + Arguments + --------- + data: DataBatch + Atomic systems as input + + Returns + ------- + embedding: dict[str->torch.Tensor] + Return backbone embeddings for the given input + """ + return + + +@registry.register_model("hydra") +class HydraModel(nn.Module, GraphModelMixin): + def __init__( + self, + backbone: dict, + heads: dict, + otf_graph: bool = True, + ): + super().__init__() + self.otf_graph = otf_graph + + backbone_model_name = backbone.pop("model") + self.backbone: BackboneInterface = registry.get_model_class( + backbone_model_name + )( + **backbone, + ) + + # Iterate through outputs_cfg and create heads + self.output_heads: dict[str, HeadInterface] = {} + + head_names_sorted = sorted(heads.keys()) + for head_name in head_names_sorted: + head_config = heads[head_name] + if "module" not in head_config: + raise ValueError( + f"{head_name} head does not specify module to use for the head" + ) + + module_name = head_config.pop("module") + self.output_heads[head_name] = registry.get_model_class(module_name)( + self.backbone, + **head_config, + ) + + self.output_heads = torch.nn.ModuleDict(self.output_heads) + + def forward(self, data: Batch): + emb = self.backbone(data) + # Predict all output properties for all structures in the batch for now. 
+ out = {} + for k in self.output_heads: + out.update(self.output_heads[k](data, emb)) + + return out diff --git a/src/fairchem/core/models/dimenet_plus_plus.py b/src/fairchem/core/models/dimenet_plus_plus.py index 296a77bbba..aa08ea0672 100644 --- a/src/fairchem/core/models/dimenet_plus_plus.py +++ b/src/fairchem/core/models/dimenet_plus_plus.py @@ -34,6 +34,8 @@ from __future__ import annotations +import typing + import torch from torch import nn from torch_geometric.nn.inits import glorot_orthogonal @@ -49,7 +51,10 @@ from fairchem.core.common.registry import registry from fairchem.core.common.utils import conditional_grad -from fairchem.core.models.base import BaseModel +from fairchem.core.models.base import BackboneInterface, GraphModelMixin, HeadInterface + +if typing.TYPE_CHECKING: + from torch_geometric.data.batch import Batch try: import sympy as sym @@ -57,7 +62,7 @@ sym = None -class InteractionPPBlock(torch.nn.Module): +class InteractionPPBlock(nn.Module): def __init__( self, hidden_channels: int, @@ -90,11 +95,11 @@ def __init__( self.lin_up = nn.Linear(int_emb_size, hidden_channels, bias=False) # Residual layers before and after skip connection. - self.layers_before_skip = torch.nn.ModuleList( + self.layers_before_skip = nn.ModuleList( [ResidualLayer(hidden_channels, act) for _ in range(num_before_skip)] ) self.lin = nn.Linear(hidden_channels, hidden_channels) - self.layers_after_skip = torch.nn.ModuleList( + self.layers_after_skip = nn.ModuleList( [ResidualLayer(hidden_channels, act) for _ in range(num_after_skip)] ) @@ -153,7 +158,7 @@ def forward(self, x, rbf, sbf, idx_kj, idx_ji): return h -class OutputPPBlock(torch.nn.Module): +class OutputPPBlock(nn.Module): def __init__( self, num_radial: int, @@ -169,7 +174,7 @@ def __init__( self.lin_rbf = nn.Linear(num_radial, hidden_channels, bias=False) self.lin_up = nn.Linear(hidden_channels, out_emb_channels, bias=True) - self.lins = torch.nn.ModuleList() + self.lins = nn.ModuleList() for _ in range(num_layers): self.lins.append(nn.Linear(out_emb_channels, out_emb_channels)) self.lin = nn.Linear(out_emb_channels, out_channels, bias=False) @@ -193,7 +198,7 @@ def forward(self, x, rbf, i, num_nodes: int | None = None): return self.lin(x) -class DimeNetPlusPlus(torch.nn.Module): +class DimeNetPlusPlus(nn.Module): r"""DimeNet++ implementation based on https://github.com/klicperajo/dimenet. 
Args: @@ -241,7 +246,6 @@ def __init__( act = activation_resolver(act) super().__init__() - self.cutoff = cutoff if sym is None: @@ -256,7 +260,7 @@ def __init__( self.emb = EmbeddingBlock(num_radial, hidden_channels, act) - self.output_blocks = torch.nn.ModuleList( + self.output_blocks = nn.ModuleList( [ OutputPPBlock( num_radial, @@ -270,7 +274,7 @@ def __init__( ] ) - self.interaction_blocks = torch.nn.ModuleList( + self.interaction_blocks = nn.ModuleList( [ InteractionPPBlock( hidden_channels, @@ -330,13 +334,42 @@ def forward(self, z, pos, batch=None): raise NotImplementedError +@registry.register_model("dimenetplusplus_energy_and_force_head") +class DimeNetPlusPlusWrapEnergyAndForceHead(nn.Module, HeadInterface): + def __init__(self, backbone): + super().__init__() + self.regress_forces = backbone.regress_forces + + @conditional_grad(torch.enable_grad()) + def forward( + self, data: Batch, emb: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + outputs = { + "energy": ( + emb["P"].sum(dim=0) + if data.batch is None + else scatter(emb["P"], data.batch, dim=0) + ) + } + if self.regress_forces: + outputs["forces"] = ( + -1 + * ( + torch.autograd.grad( + outputs["energy"], + data.pos, + grad_outputs=torch.ones_like(outputs["energy"]), + create_graph=True, + )[0] + ) + ) + return outputs + + @registry.register_model("dimenetplusplus") -class DimeNetPlusPlusWrap(DimeNetPlusPlus, BaseModel): +class DimeNetPlusPlusWrap(DimeNetPlusPlus, GraphModelMixin): def __init__( self, - num_atoms: int, - bond_feat_dim: int, # not used - num_targets: int, use_pbc: bool = True, regress_forces: bool = True, hidden_channels: int = 128, @@ -353,7 +386,6 @@ def __init__( num_after_skip: int = 2, num_output_layers: int = 3, ) -> None: - self.num_targets = num_targets self.regress_forces = regress_forces self.use_pbc = use_pbc self.cutoff = cutoff @@ -362,7 +394,7 @@ def __init__( super().__init__( hidden_channels=hidden_channels, - out_channels=num_targets, + out_channels=1, num_blocks=num_blocks, int_emb_size=int_emb_size, basis_emb_size=basis_emb_size, @@ -380,22 +412,15 @@ def __init__( def _forward(self, data): pos = data.pos batch = data.batch - ( - edge_index, - dist, - _, - cell_offsets, - offsets, - neighbors, - ) = self.generate_graph(data) - - data.edge_index = edge_index - data.cell_offsets = cell_offsets - data.neighbors = neighbors - j, i = edge_index + graph = self.generate_graph(data) + + data.edge_index = graph.edge_index + data.cell_offsets = graph.cell_offsets + data.neighbors = graph.neighbors + j, i = graph.edge_index _, _, idx_i, idx_j, idx_k, idx_kj, idx_ji = self.triplets( - edge_index, + graph.edge_index, data.cell_offsets, num_nodes=data.atomic_numbers.size(0), ) @@ -405,8 +430,8 @@ def _forward(self, data): pos_j = pos[idx_j].detach() if self.use_pbc: pos_ji, pos_kj = ( - pos[idx_j].detach() - pos_i + offsets[idx_ji], - pos[idx_k].detach() - pos_j + offsets[idx_kj], + pos[idx_j].detach() - pos_i + graph.offset_distances[idx_ji], + pos[idx_k].detach() - pos_j + graph.offset_distances[idx_kj], ) else: pos_ji, pos_kj = ( @@ -418,8 +443,8 @@ def _forward(self, data): b = torch.cross(pos_ji, pos_kj).norm(dim=-1) angle = torch.atan2(b, a) - rbf = self.rbf(dist) - sbf = self.sbf(dist, angle, idx_kj) + rbf = self.rbf(graph.edge_distance) + sbf = self.sbf(graph.edge_distance, angle, idx_kj) # Embedding block. 
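# --- Illustrative sketch (assumes the GraphData dataclass and GraphModelMixin
# added in fairchem/core/models/base.py by this patch series; not itself part
# of the patch). The hunks above and below replace the old 6-tuple returned by
# generate_graph() with named attribute access, roughly:
#
#     # before: positional unpacking of a 6-tuple, easy to misorder
#     (edge_index, dist, distance_vec, cell_offsets, offsets, neighbors) = self.generate_graph(data)
#
#     # after: a single GraphData object with named fields
#     graph = self.generate_graph(data)
#     edge_index   = graph.edge_index
#     dist         = graph.edge_distance
#     distance_vec = graph.edge_distance_vec
#     cell_offsets = graph.cell_offsets
#     offsets      = graph.offset_distances
#     neighbors    = graph.neighbors
# ------------------------------------------------------------------------------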
x = self.emb(data.atomic_numbers.long(), rbf, i, j) @@ -459,3 +484,57 @@ def forward(self, data): @property def num_params(self) -> int: return sum(p.numel() for p in self.parameters()) + + +@registry.register_model("dimenetplusplus_backbone") +class DimeNetPlusPlusWrapBackbone(DimeNetPlusPlusWrap, BackboneInterface): + @conditional_grad(torch.enable_grad()) + def forward(self, data: Batch) -> dict[str, torch.Tensor]: + if self.regress_forces: + data.pos.requires_grad_(True) + pos = data.pos + graph = self.generate_graph(data) + data.edge_index = graph.edge_index + data.cell_offsets = graph.cell_offsets + data.neighbors = graph.neighbors + j, i = graph.edge_index + + _, _, idx_i, idx_j, idx_k, idx_kj, idx_ji = self.triplets( + graph.edge_index, + data.cell_offsets, + num_nodes=data.atomic_numbers.size(0), + ) + + # Calculate angles. + pos_i = pos[idx_i].detach() + pos_j = pos[idx_j].detach() + if self.use_pbc: + pos_ji, pos_kj = ( + pos[idx_j].detach() - pos_i + graph.offset_distances[idx_ji], + pos[idx_k].detach() - pos_j + graph.offset_distances[idx_kj], + ) + else: + pos_ji, pos_kj = ( + pos[idx_j].detach() - pos_i, + pos[idx_k].detach() - pos_j, + ) + + a = (pos_ji * pos_kj).sum(dim=-1) + b = torch.cross(pos_ji, pos_kj).norm(dim=-1) + angle = torch.atan2(b, a) + + rbf = self.rbf(graph.edge_distance) + sbf = self.sbf(graph.edge_distance, angle, idx_kj) + + # Embedding block. + x = self.emb(data.atomic_numbers.long(), rbf, i, j) + P = self.output_blocks[0](x, rbf, i, num_nodes=pos.size(0)) + + # Interaction blocks. + for interaction_block, output_block in zip( + self.interaction_blocks, self.output_blocks[1:] + ): + x = interaction_block(x, rbf, sbf, idx_kj, idx_ji) + P += output_block(x, rbf, i, num_nodes=pos.size(0)) + + return {"P": P, "edge_embedding": x, "edge_idx": i} diff --git a/src/fairchem/core/models/equiformer_v2/__init__.py b/src/fairchem/core/models/equiformer_v2/__init__.py index 424b64f9ed..720f890f65 100644 --- a/src/fairchem/core/models/equiformer_v2/__init__.py +++ b/src/fairchem/core/models/equiformer_v2/__init__.py @@ -1,5 +1,5 @@ from __future__ import annotations -from .equiformer_v2_oc20 import EquiformerV2_OC20 as EquiformerV2 +from .equiformer_v2 import EquiformerV2 __all__ = ["EquiformerV2"] diff --git a/src/fairchem/core/models/equiformer_v2/equiformer_v2_oc20.py b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py similarity index 72% rename from src/fairchem/core/models/equiformer_v2/equiformer_v2_oc20.py rename to src/fairchem/core/models/equiformer_v2/equiformer_v2.py index 8edf81319c..e2625eadaf 100644 --- a/src/fairchem/core/models/equiformer_v2/equiformer_v2_oc20.py +++ b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py @@ -10,13 +10,15 @@ from fairchem.core.common import gp_utils from fairchem.core.common.registry import registry from fairchem.core.common.utils import conditional_grad -from fairchem.core.models.base import BaseModel +from fairchem.core.models.base import BackboneInterface, GraphModelMixin, HeadInterface from fairchem.core.models.scn.smearing import GaussianSmearing with contextlib.suppress(ImportError): pass +import typing + from .edge_rot_mat import init_edge_rot_mat from .gaussian_rbf import GaussianRadialBasisLayer from .input_block import EdgeDegreeEmbedding @@ -42,13 +44,18 @@ TransBlockV2, ) +if typing.TYPE_CHECKING: + from torch_geometric.data.batch import Batch + + from fairchem.core.models.base import GraphData + # Statistics of IS2RE 100K _AVG_NUM_NODES = 77.81317 _AVG_DEGREE = 23.395238876342773 # IS2RE: 100k, 
max_radius = 5, max_neighbors = 100 @registry.register_model("equiformer_v2") -class EquiformerV2_OC20(BaseModel): +class EquiformerV2(nn.Module, GraphModelMixin): """ Equiformer with graph attention built upon SO(2) convolution and feedforward network built upon S2 activation @@ -108,9 +115,6 @@ class EquiformerV2_OC20(BaseModel): def __init__( self, - num_atoms: int, # not used - bond_feat_dim: int, # not used - num_targets: int, # not used use_pbc: bool = True, regress_forces: bool = True, otf_graph: bool = True, @@ -436,23 +440,12 @@ def forward(self, data): self.dtype = data.pos.dtype self.device = data.pos.device atomic_numbers = data.atomic_numbers.long() - - ( - edge_index, - edge_distance, - edge_distance_vec, - cell_offsets, - _, # cell offset distances - neighbors, - ) = self.generate_graph( + graph = self.generate_graph( data, enforce_max_neighbors_strictly=self.enforce_max_neighbors_strictly, ) - data_batch_full = data.batch data_batch = data.batch - atomic_numbers_full = atomic_numbers - node_offset = 0 if gp_utils.initialized(): ( atomic_numbers, @@ -462,12 +455,17 @@ def forward(self, data): edge_distance, edge_distance_vec, ) = self._init_gp_partitions( - atomic_numbers_full, - data_batch_full, - edge_index, - edge_distance, - edge_distance_vec, + graph.atomic_numbers_full, + graph.batch_full, + graph.edge_index, + graph.edge_distance, + graph.edge_distance_vec, ) + graph.node_offset = node_offset + graph.edge_index = edge_index + graph.edge_distance = edge_distance + graph.edge_distance_vec = edge_distance_vec + ############################################################### # Entering Graph Parallel Region # after this point, if using gp, then node, edge tensors are split @@ -485,7 +483,9 @@ def forward(self, data): ############################################################### # Compute 3x3 rotation matrix per edge - edge_rot_mat = self._init_edge_rot_mat(data, edge_index, edge_distance_vec) + edge_rot_mat = self._init_edge_rot_mat( + data, graph.edge_index, graph.edge_distance_vec + ) # Initialize the WignerD matrices and other values for spherical harmonic calculations for i in range(self.num_resolutions): @@ -496,7 +496,6 @@ def forward(self, data): ############################################################### # Init per node representations using an atomic number based embedding - offset = 0 x = SO3_Embedding( len(atomic_numbers), self.lmax_list, @@ -519,27 +518,27 @@ def forward(self, data): offset_res = offset_res + int((self.lmax_list[i] + 1) ** 2) # Edge encoding (distance and atom edge) - edge_distance = self.distance_expansion(edge_distance) + graph.edge_distance = self.distance_expansion(graph.edge_distance) if self.share_atom_edge_embedding and self.use_atom_edge_embedding: - source_element = atomic_numbers_full[ - edge_index[0] + source_element = graph.atomic_numbers_full[ + graph.edge_index[0] ] # Source atom atomic number - target_element = atomic_numbers_full[ - edge_index[1] + target_element = graph.atomic_numbers_full[ + graph.edge_index[1] ] # Target atom atomic number source_embedding = self.source_embedding(source_element) target_embedding = self.target_embedding(target_element) - edge_distance = torch.cat( - (edge_distance, source_embedding, target_embedding), dim=1 + graph.edge_distance = torch.cat( + (graph.edge_distance, source_embedding, target_embedding), dim=1 ) # Edge-degree embedding edge_degree = self.edge_degree_embedding( - atomic_numbers_full, - edge_distance, - edge_index, + graph.atomic_numbers_full, + graph.edge_distance, + 
graph.edge_index, len(atomic_numbers), - node_offset, + graph.node_offset, ) x.embedding = x.embedding + edge_degree.embedding @@ -550,11 +549,11 @@ def forward(self, data): for i in range(self.num_layers): x = self.blocks[i]( x, # SO3_Embedding - atomic_numbers_full, - edge_distance, - edge_index, + graph.atomic_numbers_full, + graph.edge_distance, + graph.edge_index, batch=data_batch, # for GraphDropPath - node_offset=node_offset, + node_offset=graph.node_offset, ) # Final layer norm @@ -572,7 +571,7 @@ def forward(self, data): device=node_energy.device, dtype=node_energy.dtype, ) - energy.index_add_(0, data_batch_full, node_energy.view(-1)) + energy.index_add_(0, graph.batch_full, node_energy.view(-1)) energy = energy / self.avg_num_nodes # Add the per-atom linear references to the energy. @@ -594,8 +593,8 @@ def forward(self, data): with torch.cuda.amp.autocast(False): energy = energy.to(self.energy_lin_ref.dtype).index_add( 0, - data_batch_full, - self.energy_lin_ref[atomic_numbers_full], + graph.batch_full, + self.energy_lin_ref[graph.atomic_numbers_full], ) outputs = {"energy": energy} @@ -605,10 +604,10 @@ def forward(self, data): if self.regress_forces: forces = self.force_block( x, - atomic_numbers_full, - edge_distance, - edge_index, - node_offset=node_offset, + graph.atomic_numbers_full, + graph.edge_distance, + graph.edge_index, + node_offset=graph.node_offset, ) forces = forces.embedding.narrow(1, 1, 3) forces = forces.view(-1, 3).contiguous() @@ -678,3 +677,209 @@ def no_weight_decay(self) -> set: no_wd_list.append(global_parameter_name) return set(no_wd_list) + + +@registry.register_model("equiformer_v2_backbone") +class EquiformerV2Backbone(EquiformerV2, BackboneInterface): + @conditional_grad(torch.enable_grad()) + def forward(self, data: Batch) -> dict[str, torch.Tensor]: + self.batch_size = len(data.natoms) + self.dtype = data.pos.dtype + self.device = data.pos.device + atomic_numbers = data.atomic_numbers.long() + graph = self.generate_graph( + data, + enforce_max_neighbors_strictly=self.enforce_max_neighbors_strictly, + ) + + data_batch = data.batch + if gp_utils.initialized(): + ( + atomic_numbers, + data_batch, + node_offset, + edge_index, + edge_distance, + edge_distance_vec, + ) = self._init_gp_partitions( + graph.atomic_numbers_full, + graph.batch_full, + graph.edge_index, + graph.edge_distance, + graph.edge_distance_vec, + ) + graph.node_offset = node_offset + graph.edge_index = edge_index + graph.edge_distance = edge_distance + graph.edge_distance_vec = edge_distance_vec + + ############################################################### + # Entering Graph Parallel Region + # after this point, if using gp, then node, edge tensors are split + # across the graph parallel ranks, some full tensors such as + # atomic_numbers_full are required because we need to index into the + # full graph when computing edge embeddings or reducing nodes from neighbors + # + # all tensors that do not have the suffix "_full" refer to the partial tensors. 
+ # if not using gp, the full values are equal to the partial values + # ie: atomic_numbers_full == atomic_numbers + ############################################################### + + ############################################################### + # Initialize data structures + ############################################################### + + # Compute 3x3 rotation matrix per edge + edge_rot_mat = self._init_edge_rot_mat( + data, graph.edge_index, graph.edge_distance_vec + ) + + # Initialize the WignerD matrices and other values for spherical harmonic calculations + for i in range(self.num_resolutions): + self.SO3_rotation[i].set_wigner(edge_rot_mat) + + ############################################################### + # Initialize node embeddings + ############################################################### + + # Init per node representations using an atomic number based embedding + x = SO3_Embedding( + len(atomic_numbers), + self.lmax_list, + self.sphere_channels, + self.device, + self.dtype, + ) + + offset_res = 0 + offset = 0 + # Initialize the l = 0, m = 0 coefficients for each resolution + for i in range(self.num_resolutions): + if self.num_resolutions == 1: + x.embedding[:, offset_res, :] = self.sphere_embedding(atomic_numbers) + else: + x.embedding[:, offset_res, :] = self.sphere_embedding(atomic_numbers)[ + :, offset : offset + self.sphere_channels + ] + offset = offset + self.sphere_channels + offset_res = offset_res + int((self.lmax_list[i] + 1) ** 2) + + # Edge encoding (distance and atom edge) + graph.edge_distance = self.distance_expansion(graph.edge_distance) + if self.share_atom_edge_embedding and self.use_atom_edge_embedding: + source_element = graph.atomic_numbers_full[ + graph.edge_index[0] + ] # Source atom atomic number + target_element = graph.atomic_numbers_full[ + graph.edge_index[1] + ] # Target atom atomic number + source_embedding = self.source_embedding(source_element) + target_embedding = self.target_embedding(target_element) + graph.edge_distance = torch.cat( + (graph.edge_distance, source_embedding, target_embedding), dim=1 + ) + + # Edge-degree embedding + edge_degree = self.edge_degree_embedding( + graph.atomic_numbers_full, + graph.edge_distance, + graph.edge_index, + len(atomic_numbers), + graph.node_offset, + ) + x.embedding = x.embedding + edge_degree.embedding + + ############################################################### + # Update spherical node embeddings + ############################################################### + + for i in range(self.num_layers): + x = self.blocks[i]( + x, # SO3_Embedding + graph.atomic_numbers_full, + graph.edge_distance, + graph.edge_index, + batch=data_batch, # for GraphDropPath + node_offset=graph.node_offset, + ) + + # Final layer norm + x.embedding = self.norm(x.embedding) + + return {"node_embedding": x, "graph": graph} + + +@registry.register_model("equiformer_v2_energy_head") +class EquiformerV2EnergyHead(nn.Module, HeadInterface): + def __init__(self, backbone): + super().__init__() + self.avg_num_nodes = backbone.avg_num_nodes + self.energy_block = FeedForwardNetwork( + backbone.sphere_channels, + backbone.ffn_hidden_channels, + 1, + backbone.lmax_list, + backbone.mmax_list, + backbone.SO3_grid, + backbone.ffn_activation, + backbone.use_gate_act, + backbone.use_grid_mlp, + backbone.use_sep_s2_act, + ) + + def forward(self, data: Batch, emb: dict[str, torch.Tensor | GraphData]): + node_energy = self.energy_block(emb["node_embedding"]) + node_energy = node_energy.embedding.narrow(1, 0, 1) + if 
gp_utils.initialized(): + node_energy = gp_utils.gather_from_model_parallel_region(node_energy, dim=0) + energy = torch.zeros( + len(data.natoms), + device=node_energy.device, + dtype=node_energy.dtype, + ) + energy.index_add_(0, data.batch, node_energy.view(-1)) + return {"energy": energy / self.avg_num_nodes} + + +@registry.register_model("equiformer_v2_force_head") +class EquiformerV2ForceHead(nn.Module, HeadInterface): + def __init__(self, backbone): + super().__init__() + + self.force_block = SO2EquivariantGraphAttention( + backbone.sphere_channels, + backbone.attn_hidden_channels, + backbone.num_heads, + backbone.attn_alpha_channels, + backbone.attn_value_channels, + 1, + backbone.lmax_list, + backbone.mmax_list, + backbone.SO3_rotation, + backbone.mappingReduced, + backbone.SO3_grid, + backbone.max_num_elements, + backbone.edge_channels_list, + backbone.block_use_atom_edge_embedding, + backbone.use_m_share_rad, + backbone.attn_activation, + backbone.use_s2_act_attn, + backbone.use_attn_renorm, + backbone.use_gate_act, + backbone.use_sep_s2_act, + alpha_drop=0.0, + ) + + def forward(self, data: Batch, emb: dict[str, torch.Tensor]): + forces = self.force_block( + emb["node_embedding"], + emb["graph"].atomic_numbers_full, + emb["graph"].edge_distance, + emb["graph"].edge_index, + node_offset=emb["graph"].node_offset, + ) + forces = forces.embedding.narrow(1, 1, 3) + forces = forces.view(-1, 3).contiguous() + if gp_utils.initialized(): + forces = gp_utils.gather_from_model_parallel_region(forces, dim=0) + return {"forces": forces} diff --git a/src/fairchem/core/models/escn/escn.py b/src/fairchem/core/models/escn/escn.py index 0ec66b9dba..dfa872c398 100644 --- a/src/fairchem/core/models/escn/escn.py +++ b/src/fairchem/core/models/escn/escn.py @@ -10,13 +10,17 @@ import contextlib import logging import time +import typing import torch import torch.nn as nn +if typing.TYPE_CHECKING: + from torch_geometric.data.batch import Batch + from fairchem.core.common.registry import registry from fairchem.core.common.utils import conditional_grad -from fairchem.core.models.base import BaseModel +from fairchem.core.models.base import BackboneInterface, GraphModelMixin, HeadInterface from fairchem.core.models.escn.so3 import ( CoefficientMapping, SO3_Embedding, @@ -36,7 +40,7 @@ @registry.register_model("escn") -class eSCN(BaseModel): +class eSCN(nn.Module, GraphModelMixin): """Equivariant Spherical Channel Network Paper: Reducing SO(3) Convolutions to SO(2) for Efficient Equivariant GNNs @@ -64,9 +68,6 @@ class eSCN(BaseModel): def __init__( self, - num_atoms: int, # not used - bond_feat_dim: int, # not used - num_targets: int, # not used use_pbc: bool = True, regress_forces: bool = True, otf_graph: bool = False, @@ -79,7 +80,6 @@ def __init__( sphere_channels: int = 128, hidden_channels: int = 256, edge_channels: int = 128, - use_grid: bool = True, num_sphere_samples: int = 128, distance_function: str = "gaussian", basis_width_scalar: float = 1.0, @@ -232,22 +232,16 @@ def forward(self, data): start_time = time.time() atomic_numbers = data.atomic_numbers.long() num_atoms = len(atomic_numbers) - - ( - edge_index, - edge_distance, - edge_distance_vec, - cell_offsets, - _, # cell offset distances - neighbors, - ) = self.generate_graph(data) + graph = self.generate_graph(data) ############################################################### # Initialize data structures ############################################################### # Compute 3x3 rotation matrix per edge - edge_rot_mat = 
self._init_edge_rot_mat(data, edge_index, edge_distance_vec) + edge_rot_mat = self._init_edge_rot_mat( + data, graph.edge_index, graph.edge_distance_vec + ) # Initialize the WignerD matrices and other values for spherical harmonic calculations self.SO3_edge_rot = nn.ModuleList() @@ -290,8 +284,8 @@ def forward(self, data): x_message = self.layer_blocks[i]( x, atomic_numbers, - edge_distance, - edge_index, + graph.edge_distance, + graph.edge_index, self.SO3_edge_rot, mappingReduced, ) @@ -304,8 +298,8 @@ def forward(self, data): x = self.layer_blocks[i]( x, atomic_numbers, - edge_distance, - edge_index, + graph.edge_distance, + graph.edge_index, self.SO3_edge_rot, mappingReduced, ) @@ -421,6 +415,149 @@ def num_params(self) -> int: return sum(p.numel() for p in self.parameters()) +@registry.register_model("escn_backbone") +class eSCNBackbone(eSCN, BackboneInterface): + @conditional_grad(torch.enable_grad()) + def forward(self, data: Batch) -> dict[str, torch.Tensor]: + device = data.pos.device + self.batch_size = len(data.natoms) + self.dtype = data.pos.dtype + + atomic_numbers = data.atomic_numbers.long() + num_atoms = len(atomic_numbers) + + graph = self.generate_graph(data) + + ############################################################### + # Initialize data structures + ############################################################### + + # Compute 3x3 rotation matrix per edge + edge_rot_mat = self._init_edge_rot_mat( + data, graph.edge_index, graph.edge_distance_vec + ) + + # Initialize the WignerD matrices and other values for spherical harmonic calculations + self.SO3_edge_rot = nn.ModuleList() + for i in range(self.num_resolutions): + self.SO3_edge_rot.append(SO3_Rotation(edge_rot_mat, self.lmax_list[i])) + + ############################################################### + # Initialize node embeddings + ############################################################### + + # Init per node representations using an atomic number based embedding + offset = 0 + x = SO3_Embedding( + num_atoms, + self.lmax_list, + self.sphere_channels, + device, + self.dtype, + ) + + offset_res = 0 + offset = 0 + # Initialize the l=0,m=0 coefficients for each resolution + for i in range(self.num_resolutions): + x.embedding[:, offset_res, :] = self.sphere_embedding(atomic_numbers)[ + :, offset : offset + self.sphere_channels + ] + offset = offset + self.sphere_channels + offset_res = offset_res + int((self.lmax_list[i] + 1) ** 2) + + # This can be expensive to compute (not implemented efficiently), so only do it once and pass it along to each layer + mappingReduced = CoefficientMapping(self.lmax_list, self.mmax_list, device) + + ############################################################### + # Update spherical node embeddings + ############################################################### + + for i in range(self.num_layers): + if i > 0: + x_message = self.layer_blocks[i]( + x, + atomic_numbers, + graph.edge_distance, + graph.edge_index, + self.SO3_edge_rot, + mappingReduced, + ) + + # Residual layer for all layers past the first + x.embedding = x.embedding + x_message.embedding + + else: + # No residual for the first layer + x = self.layer_blocks[i]( + x, + atomic_numbers, + graph.edge_distance, + graph.edge_index, + self.SO3_edge_rot, + mappingReduced, + ) + + # Sample the spherical channels (node embeddings) at evenly distributed points on the sphere. + # These values are fed into the output blocks. 
+ x_pt = torch.tensor([], device=device) + offset = 0 + # Compute the embedding values at every sampled point on the sphere + for i in range(self.num_resolutions): + num_coefficients = int((x.lmax_list[i] + 1) ** 2) + x_pt = torch.cat( + [ + x_pt, + torch.einsum( + "abc, pb->apc", + x.embedding[:, offset : offset + num_coefficients], + self.sphharm_weights[i], + ).contiguous(), + ], + dim=2, + ) + offset = offset + num_coefficients + + x_pt = x_pt.view(-1, self.sphere_channels_all) + + return {"sphere_values": x_pt, "sphere_points": self.sphere_points} + + +@registry.register_model("escn_energy_head") +class eSCNEnergyHead(nn.Module, HeadInterface): + def __init__(self, backbone): + super().__init__() + + # Output blocks for energy and forces + self.energy_block = EnergyBlock( + backbone.sphere_channels_all, backbone.num_sphere_samples, backbone.act + ) + + def forward( + self, data: Batch, emb: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + node_energy = self.energy_block(emb["sphere_values"]) + energy = torch.zeros(len(data.natoms), device=data.pos.device) + energy.index_add_(0, data.batch, node_energy.view(-1)) + # Scale energy to help balance numerical precision w.r.t. forces + return {"energy": energy * 0.001} + + +@registry.register_model("escn_force_head") +class eSCNForceHead(nn.Module, HeadInterface): + def __init__(self, backbone): + super().__init__() + + self.force_block = ForceBlock( + backbone.sphere_channels_all, backbone.num_sphere_samples, backbone.act + ) + + def forward( + self, data: Batch, emb: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + return {"forces": self.force_block(emb["sphere_values"], emb["sphere_points"])} + + class LayerBlock(torch.nn.Module): """ Layer block: Perform one layer (message passing and aggregation) of the GNN diff --git a/src/fairchem/core/models/gemnet/gemnet.py b/src/fairchem/core/models/gemnet/gemnet.py index e719c219b8..59b3eda08f 100644 --- a/src/fairchem/core/models/gemnet/gemnet.py +++ b/src/fairchem/core/models/gemnet/gemnet.py @@ -7,14 +7,20 @@ from __future__ import annotations +import typing + import numpy as np import torch +import torch.nn as nn + +if typing.TYPE_CHECKING: + from torch_geometric.data.batch import Batch from torch_scatter import scatter from torch_sparse import SparseTensor from fairchem.core.common.registry import registry from fairchem.core.common.utils import conditional_grad -from fairchem.core.models.base import BaseModel +from fairchem.core.models.base import BackboneInterface, GraphModelMixin, HeadInterface from fairchem.core.modules.scaling.compat import load_scales_compat from .layers.atom_update_block import OutputBlock @@ -28,17 +34,12 @@ @registry.register_model("gemnet_t") -class GemNetT(BaseModel): +class GemNetT(nn.Module, GraphModelMixin): """ GemNet-T, triplets-only variant of GemNet Parameters ---------- - num_atoms (int): Unused argument - bond_feat_dim (int): Unused argument - num_targets: int - Number of prediction targets. - num_spherical: int Controls maximum frequency. 
num_radial: int @@ -94,9 +95,6 @@ class GemNetT(BaseModel): def __init__( self, - num_atoms: int | None, - bond_feat_dim: int, - num_targets: int, num_spherical: int, num_radial: int, num_blocks: int, @@ -132,7 +130,6 @@ def __init__( if rbf is None: rbf = {"name": "gaussian"} super().__init__() - self.num_targets = num_targets assert num_blocks > 0 self.num_blocks = num_blocks self.extensive = extensive @@ -235,7 +232,7 @@ def __init__( emb_size_edge=emb_size_edge, emb_size_rbf=emb_size_rbf, nHidden=num_atom, - num_targets=num_targets, + num_targets=1, activation=activation, output_init=output_init, direct_forces=direct_forces, @@ -421,18 +418,10 @@ def select_edges( def generate_interaction_graph(self, data): num_atoms = data.atomic_numbers.size(0) - - ( - edge_index, - D_st, - distance_vec, - cell_offsets, - _, # cell offset distances - neighbors, - ) = self.generate_graph(data) + graph = self.generate_graph(data) # These vectors actually point in the opposite direction. # But we want to use col as idx_t for efficient aggregation. - V_st = -distance_vec / D_st[:, None] + V_st = -graph.edge_distance_vec / graph.edge_distance[:, None] # Mask interaction edges if required if self.otf_graph or np.isclose(self.cutoff, 6): @@ -447,10 +436,10 @@ def generate_interaction_graph(self, data): V_st, ) = self.select_edges( data=data, - edge_index=edge_index, - cell_offsets=cell_offsets, - neighbors=neighbors, - edge_dist=D_st, + edge_index=graph.edge_index, + cell_offsets=graph.cell_offsets, + neighbors=graph.neighbors, + edge_dist=graph.edge_distance, edge_vector=V_st, cutoff=select_cutoff, ) @@ -530,7 +519,7 @@ def forward(self, data): rbf_out = self.mlp_rbf_out(rbf) E_t, F_st = self.out_blocks[0](h, m, rbf_out, idx_t) - # (nAtoms, num_targets), (nEdges, num_targets) + # (nAtoms, 1), (nEdges, 1) for i in range(self.num_blocks): # Interaction block @@ -549,7 +538,7 @@ def forward(self, data): ) # (nAtoms, emb_size_atom), (nEdges, emb_size_edge) E, F = self.out_blocks[i + 1](h, m, rbf_out, idx_t) - # (nAtoms, num_targets), (nEdges, num_targets) + # (nAtoms, 1), (nEdges, 1) F_st += F E_t += E @@ -557,11 +546,11 @@ def forward(self, data): if self.extensive: E_t = scatter( E_t, batch, dim=0, dim_size=nMolecules, reduce="add" - ) # (nMolecules, num_targets) + ) # (nMolecules, 1) else: E_t = scatter( E_t, batch, dim=0, dim_size=nMolecules, reduce="mean" - ) # (nMolecules, num_targets) + ) # (nMolecules, 1) outputs = {"energy": E_t} @@ -569,30 +558,18 @@ def forward(self, data): if self.direct_forces: # map forces in edge directions F_st_vec = F_st[:, :, None] * V_st[:, None, :] - # (nEdges, num_targets, 3) + # (nEdges, 1, 3) F_t = scatter( F_st_vec, idx_t, dim=0, dim_size=data.atomic_numbers.size(0), reduce="add", - ) # (nAtoms, num_targets, 3) + ) # (nAtoms, 1, 3) F_t = F_t.squeeze(1) # (nAtoms, 3) else: - if self.num_targets > 1: - forces = [] - for i in range(self.num_targets): - # maybe this can be solved differently - forces += [ - -torch.autograd.grad( - E_t[:, i].sum(), pos, create_graph=True - )[0] - ] - F_t = torch.stack(forces, dim=1) - # (nAtoms, num_targets, 3) - else: - F_t = -torch.autograd.grad(E_t.sum(), pos, create_graph=True)[0] - # (nAtoms, 3) + F_t = -torch.autograd.grad(E_t.sum(), pos, create_graph=True)[0] + # (nAtoms, 3) outputs["forces"] = F_t @@ -601,3 +578,129 @@ def forward(self, data): @property def num_params(self): return sum(p.numel() for p in self.parameters()) + + +@registry.register_model("gemnet_t_backbone") +class GemNetTBackbone(GemNetT, BackboneInterface): + 
@conditional_grad(torch.enable_grad()) + def forward(self, data: Batch) -> dict[str, torch.Tensor]: + pos = data.pos + atomic_numbers = data.atomic_numbers.long() + + if self.regress_forces and not self.direct_forces: + pos.requires_grad_(True) + + ( + edge_index, + neighbors, + D_st, + V_st, + id_swap, + id3_ba, + id3_ca, + id3_ragged_idx, + ) = self.generate_interaction_graph(data) + idx_s, idx_t = edge_index + + # Calculate triplet angles + cosφ_cab = inner_product_normalized(V_st[id3_ca], V_st[id3_ba]) + rad_cbf3, cbf3 = self.cbf_basis3(D_st, cosφ_cab, id3_ca) + + rbf = self.radial_basis(D_st) + + # Embedding block + h = self.atom_emb(atomic_numbers) + # (nAtoms, emb_size_atom) + m = self.edge_emb(h, rbf, idx_s, idx_t) # (nEdges, emb_size_edge) + + rbf3 = self.mlp_rbf3(rbf) + cbf3 = self.mlp_cbf3(rad_cbf3, cbf3, id3_ca, id3_ragged_idx) + + rbf_h = self.mlp_rbf_h(rbf) + rbf_out = self.mlp_rbf_out(rbf) + + E_t, F_st = self.out_blocks[0](h, m, rbf_out, idx_t) + # (nAtoms, 1), (nEdges, 1) + + for i in range(self.num_blocks): + # Interaction block + h, m = self.int_blocks[i]( + h=h, + m=m, + rbf3=rbf3, + cbf3=cbf3, + id3_ragged_idx=id3_ragged_idx, + id_swap=id_swap, + id3_ba=id3_ba, + id3_ca=id3_ca, + rbf_h=rbf_h, + idx_s=idx_s, + idx_t=idx_t, + ) # (nAtoms, emb_size_atom), (nEdges, emb_size_edge) + + E, F = self.out_blocks[i + 1](h, m, rbf_out, idx_t) + # (nAtoms, 1), (nEdges, 1) + F_st += F + E_t += E + return { + "F_st": F_st, + "E_t": E_t, + "edge_vec": V_st, + "edge_idx": idx_t, + "node_embedding": h, + "edge_embedding": m, + } + + +@registry.register_model("gemnet_t_energy_and_grad_force_head") +class GemNetTEnergyAndGradForceHead(nn.Module, HeadInterface): + def __init__(self, backbone): + super().__init__() + self.extensive = backbone.extensive + self.regress_forces = backbone.regress_forces + self.direct_forces = backbone.direct_forces + + @conditional_grad(torch.enable_grad()) + def forward( + self, data: Batch, emb: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + nMolecules = torch.max(data.batch) + 1 + if self.extensive: + E_t = scatter( + emb["E_t"], data.batch, dim=0, dim_size=nMolecules, reduce="add" + ) # (nMolecules, 1) + else: + E_t = scatter( + emb["E_t"], data.batch, dim=0, dim_size=nMolecules, reduce="mean" + ) # (nMolecules, 1) + + outputs = {"energy": E_t} + + if self.regress_forces and not self.direct_forces: + outputs["forces"] = -torch.autograd.grad( + E_t.sum(), data.pos, create_graph=True + )[0] + # (nAtoms, 3) + return outputs + + +@registry.register_model("gemnet_t_force_head") +class GemNetTForceHead(nn.Module, HeadInterface): + def __init__(self, backbone): + super().__init__() + self.direct_forces = backbone.direct_forces + + def forward( + self, data: Batch, emb: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + # map forces in edge directions + F_st_vec = emb["F_st"][:, :, None] * emb["edge_vec"][:, None, :] + # (nEdges, 1, 3) + F_t = scatter( + F_st_vec, + emb["edge_idx"], + dim=0, + dim_size=data.atomic_numbers.size(0), + reduce="add", + ) # (nAtoms, 1, 3) + return {"forces": F_t.squeeze(1)} # (nAtoms, 3) diff --git a/src/fairchem/core/models/gemnet_gp/gemnet.py b/src/fairchem/core/models/gemnet_gp/gemnet.py index 81fbd40694..a75756dcc1 100644 --- a/src/fairchem/core/models/gemnet_gp/gemnet.py +++ b/src/fairchem/core/models/gemnet_gp/gemnet.py @@ -9,13 +9,14 @@ import numpy as np import torch +from torch import nn from torch_scatter import scatter from torch_sparse import SparseTensor from fairchem.core.common import gp_utils from 
fairchem.core.common.registry import registry from fairchem.core.common.utils import conditional_grad -from fairchem.core.models.base import BaseModel +from fairchem.core.models.base import GraphModelMixin from fairchem.core.modules.scaling.compat import load_scales_compat from .layers.atom_update_block import OutputBlock @@ -29,17 +30,12 @@ @registry.register_model("gp_gemnet_t") -class GraphParallelGemNetT(BaseModel): +class GraphParallelGemNetT(nn.Module, GraphModelMixin): """ GemNet-T, triplets-only variant of GemNet Parameters ---------- - num_atoms (int): Unused argument - bond_feat_dim (int): Unused argument - num_targets: int - Number of prediction targets. - num_spherical: int Controls maximum frequency. num_radial: int @@ -95,9 +91,6 @@ class GraphParallelGemNetT(BaseModel): def __init__( self, - num_atoms: int | None, - bond_feat_dim: int, - num_targets: int, num_spherical: int, num_radial: int, num_blocks: int, @@ -134,7 +127,6 @@ def __init__( if rbf is None: rbf = {"name": "gaussian"} super().__init__() - self.num_targets = num_targets assert num_blocks > 0 self.num_blocks = num_blocks self.extensive = extensive @@ -239,7 +231,7 @@ def __init__( emb_size_edge=emb_size_edge, emb_size_rbf=emb_size_rbf, nHidden=num_atom, - num_targets=num_targets, + num_targets=1, activation=activation, output_init=output_init, direct_forces=direct_forces, @@ -415,18 +407,10 @@ def select_edges( def generate_interaction_graph(self, data): num_atoms = data.atomic_numbers.size(0) - - ( - edge_index, - D_st, - distance_vec, - cell_offsets, - _, # cell offset distances - neighbors, - ) = self.generate_graph(data) + graph = self.generate_graph(data) # These vectors actually point in the opposite direction. # But we want to use col as idx_t for efficient aggregation. 
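# --- Illustrative sketch (assumes the GraphData fields declared in
# fairchem/core/models/base.py in this patch series; not itself part of the
# patch). The dataclass exposes the displacement vector as `edge_distance_vec`
# (there is no `distance_vec` field), so the reversed unit edge vectors would
# be computed from the named fields as:
#
#     graph = self.generate_graph(data)
#     V_st = -graph.edge_distance_vec / graph.edge_distance[:, None]
# ------------------------------------------------------------------------------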
- V_st = -distance_vec / D_st[:, None] + V_st = -graph.distance_vec / graph.edge_distance[:, None] # Mask interaction edges if required if self.otf_graph or np.isclose(self.cutoff, 6): @@ -441,10 +425,10 @@ def generate_interaction_graph(self, data): V_st, ) = self.select_edges( data=data, - edge_index=edge_index, - cell_offsets=cell_offsets, - neighbors=neighbors, - edge_dist=D_st, + edge_index=graph.edge_index, + cell_offsets=graph.cell_offsets, + neighbors=graph.neighbors, + edge_dist=graph.edge_distance, edge_vector=V_st, cutoff=select_cutoff, ) @@ -563,7 +547,7 @@ def forward(self, data): rbf_out = self.mlp_rbf_out(rbf) E_t, F_st = self.out_blocks[0](nAtoms, m, rbf_out, idx_t) - # (nAtoms, num_targets), (nEdges, num_targets) + # (nAtoms, 1), (nEdges, 1) for i in range(self.num_blocks): # Interaction block @@ -585,7 +569,7 @@ def forward(self, data): ) # (nAtoms, emb_size_atom), (nEdges, emb_size_edge) E, F = self.out_blocks[i + 1](nAtoms, m, rbf_out, idx_t) - # (nAtoms, num_targets), (nEdges, num_targets) + # (nAtoms, 1), (nEdges, 1) F_st += F E_t += E @@ -601,41 +585,29 @@ def forward(self, data): E_t = gp_utils.gather_from_model_parallel_region(E_t, dim=0) E_t = scatter( E_t, batch, dim=0, dim_size=nMolecules, reduce="add" - ) # (nMolecules, num_targets) + ) # (nMolecules, 1) else: E_t = scatter( E_t, batch, dim=0, dim_size=nMolecules, reduce="mean" - ) # (nMolecules, num_targets) + ) # (nMolecules, 1) outputs = {"energy": E_t} if self.regress_forces: if self.direct_forces: # map forces in edge directions F_st_vec = F_st[:, :, None] * V_st[:, None, :] - # (nEdges, num_targets, 3) + # (nEdges, 1, 3) F_t = scatter( F_st_vec, idx_t_full, dim=0, dim_size=data.atomic_numbers.size(0), reduce="add", - ) # (nAtoms, num_targets, 3) + ) # (nAtoms, 1, 3) F_t = F_t.squeeze(1) # (nAtoms, 3) else: - if self.num_targets > 1: - forces = [] - for i in range(self.num_targets): - # maybe this can be solved differently - forces += [ - -torch.autograd.grad( - E_t[:, i].sum(), pos, create_graph=True - )[0] - ] - F_t = torch.stack(forces, dim=1) - # (nAtoms, num_targets, 3) - else: - F_t = -torch.autograd.grad(E_t.sum(), pos, create_graph=True)[0] - # (nAtoms, 3) + F_t = -torch.autograd.grad(E_t.sum(), pos, create_graph=True)[0] + # (nAtoms, 3) outputs["forces"] = F_t diff --git a/src/fairchem/core/models/gemnet_oc/gemnet_oc.py b/src/fairchem/core/models/gemnet_oc/gemnet_oc.py index e1176d00c9..0aea3d81ba 100644 --- a/src/fairchem/core/models/gemnet_oc/gemnet_oc.py +++ b/src/fairchem/core/models/gemnet_oc/gemnet_oc.py @@ -7,9 +7,11 @@ from __future__ import annotations import logging +import typing import numpy as np import torch +import torch.nn as nn from torch_scatter import segment_coo from fairchem.core.common.registry import registry @@ -18,7 +20,7 @@ get_max_neighbors_mask, scatter_det, ) -from fairchem.core.models.base import BaseModel +from fairchem.core.models.base import BackboneInterface, GraphModelMixin, HeadInterface from fairchem.core.modules.scaling.compat import load_scales_compat from .initializers import get_initializer @@ -40,17 +42,15 @@ repeat_blocks, ) +if typing.TYPE_CHECKING: + from torch_geometric.data.batch import Batch + @registry.register_model("gemnet_oc") -class GemNetOC(BaseModel): +class GemNetOC(nn.Module, GraphModelMixin): """ Arguments --------- - num_atoms (int): Unused argument - bond_feat_dim (int): Unused argument - num_targets: int - Number of prediction targets. - num_spherical: int Controls maximum frequency. 
num_radial: int @@ -179,9 +179,6 @@ class GemNetOC(BaseModel): def __init__( self, - num_atoms: int | None, - bond_feat_dim: int, - num_targets: int, num_spherical: int, num_radial: int, num_blocks: int, @@ -249,11 +246,11 @@ def __init__( super().__init__() if len(kwargs) > 0: logging.warning(f"Unrecognized arguments: {list(kwargs.keys())}") - self.num_targets = num_targets assert num_blocks > 0 self.num_blocks = num_blocks self.extensive = extensive + self.activation = activation self.atom_edge_interaction = atom_edge_interaction self.edge_atom_interaction = edge_atom_interaction self.atom_interaction = atom_interaction @@ -357,7 +354,7 @@ def __init__( for _ in range(num_global_out_layers) ] self.out_mlp_E = torch.nn.Sequential(*out_mlp_E) - self.out_energy = Dense(emb_size_atom, num_targets, bias=False, activation=None) + self.out_energy = Dense(emb_size_atom, 1, bias=False, activation=None) if direct_forces: out_mlp_F = [ Dense( @@ -373,9 +370,7 @@ def __init__( for _ in range(num_global_out_layers) ] self.out_mlp_F = torch.nn.Sequential(*out_mlp_F) - self.out_forces = Dense( - emb_size_edge, num_targets, bias=False, activation=None - ) + self.out_forces = Dense(emb_size_edge, 1, bias=False, activation=None) out_initializer = get_initializer(output_init) self.out_energy.reset_parameters(out_initializer) @@ -870,15 +865,7 @@ def subselect_edges( def generate_graph_dict(self, data, cutoff, max_neighbors): """Generate a radius/nearest neighbor graph.""" otf_graph = cutoff > 6 or max_neighbors > 50 or self.otf_graph - - ( - edge_index, - edge_dist, - distance_vec, - cell_offsets, - _, # cell offset distances - num_neighbors, - ) = self.generate_graph( + graph = self.generate_graph( data, cutoff=cutoff, max_neighbors=max_neighbors, @@ -886,15 +873,15 @@ def generate_graph_dict(self, data, cutoff, max_neighbors): ) # These vectors actually point in the opposite direction. # But we want to use col as idx_t for efficient aggregation. 
- edge_vector = -distance_vec / edge_dist[:, None] - cell_offsets = -cell_offsets # a - c + offset + edge_vector = -graph.edge_distance_vec / graph.edge_distance[:, None] + cell_offsets = -graph.cell_offsets # a - c + offset graph = { - "edge_index": edge_index, - "distance": edge_dist, + "edge_index": graph.edge_index, + "distance": graph.edge_distance, "vector": edge_vector, "cell_offset": cell_offsets, - "num_neighbors": num_neighbors, + "num_neighbors": graph.neighbors, } # Mask interaction edges if required @@ -1285,11 +1272,11 @@ def forward(self, data): if self.extensive: E_t = scatter_det( E_t, batch, dim=0, dim_size=nMolecules, reduce="add" - ) # (nMolecules, num_targets) + ) # (nMolecules, 1) else: E_t = scatter_det( E_t, batch, dim=0, dim_size=nMolecules, reduce="mean" - ) # (nMolecules, num_targets) + ) # (nMolecules, 1) E_t = E_t.squeeze(1) # (num_molecules) outputs = {"energy": E_t} @@ -1308,19 +1295,19 @@ def forward(self, data): dim=0, dim_size=int(nEdges / 2), reduce="mean", - ) # (nEdges/2, num_targets) - F_st = F_st[id_undir] # (nEdges, num_targets) + ) # (nEdges/2, 1) + F_st = F_st[id_undir] # (nEdges, 1) # map forces in edge directions F_st_vec = F_st[:, :, None] * main_graph["vector"][:, None, :] - # (nEdges, num_targets, 3) + # (nEdges, 1, 3) F_t = scatter_det( F_st_vec, idx_t, dim=0, dim_size=num_atoms, reduce="add", - ) # (nAtoms, num_targets, 3) + ) # (nAtoms, 1, 3) else: F_t = self.force_scaler.calc_forces_and_update(E_t, pos) @@ -1333,3 +1320,233 @@ def forward(self, data): @property def num_params(self) -> int: return sum(p.numel() for p in self.parameters()) + + +@registry.register_model("gemnet_oc_backbone") +class GemNetOCBackbone(GemNetOC, BackboneInterface): + @conditional_grad(torch.enable_grad()) + def forward(self, data: Batch) -> dict[str, torch.Tensor]: + pos = data.pos + atomic_numbers = data.atomic_numbers.long() + num_atoms = atomic_numbers.shape[0] + + if self.regress_forces and not self.direct_forces: + pos.requires_grad_(True) + + ( + main_graph, + a2a_graph, + a2ee2a_graph, + qint_graph, + id_swap, + trip_idx_e2e, + trip_idx_a2e, + trip_idx_e2a, + quad_idx, + ) = self.get_graphs_and_indices(data) + _, idx_t = main_graph["edge_index"] + + ( + basis_rad_raw, + basis_atom_update, + basis_output, + bases_qint, + bases_e2e, + bases_a2e, + bases_e2a, + basis_a2a_rad, + ) = self.get_bases( + main_graph=main_graph, + a2a_graph=a2a_graph, + a2ee2a_graph=a2ee2a_graph, + qint_graph=qint_graph, + trip_idx_e2e=trip_idx_e2e, + trip_idx_a2e=trip_idx_a2e, + trip_idx_e2a=trip_idx_e2a, + quad_idx=quad_idx, + num_atoms=num_atoms, + ) + + # Embedding block + h = self.atom_emb(atomic_numbers) + # (nAtoms, emb_size_atom) + m = self.edge_emb(h, basis_rad_raw, main_graph["edge_index"]) + # (nEdges, emb_size_edge) + + x_E, x_F = self.out_blocks[0](h, m, basis_output, idx_t) + # (nAtoms, emb_size_atom), (nEdges, emb_size_edge) + xs_E, xs_F = [x_E], [x_F] + + for i in range(self.num_blocks): + # Interaction block + h, m = self.int_blocks[i]( + h=h, + m=m, + bases_qint=bases_qint, + bases_e2e=bases_e2e, + bases_a2e=bases_a2e, + bases_e2a=bases_e2a, + basis_a2a_rad=basis_a2a_rad, + basis_atom_update=basis_atom_update, + edge_index_main=main_graph["edge_index"], + a2ee2a_graph=a2ee2a_graph, + a2a_graph=a2a_graph, + id_swap=id_swap, + trip_idx_e2e=trip_idx_e2e, + trip_idx_a2e=trip_idx_a2e, + trip_idx_e2a=trip_idx_e2a, + quad_idx=quad_idx, + ) # (nAtoms, emb_size_atom), (nEdges, emb_size_edge) + + x_E, x_F = self.out_blocks[i + 1](h, m, basis_output, idx_t) + # (nAtoms, 
emb_size_atom), (nEdges, emb_size_edge) + xs_E.append(x_E) + xs_F.append(x_F) + + return { + "xs_E": xs_E, + "xs_F": xs_F, + "edge_vec": main_graph["vector"], + "edge_idx": idx_t, + "num_neighbors": main_graph["num_neighbors"], + } + + +@registry.register_model("gemnet_oc_energy_and_grad_force_head") +class GemNetOCEnergyAndGradForceHead(nn.Module, HeadInterface): + def __init__( + self, + backbone: BackboneInterface, + num_global_out_layers: int, + output_init: str = "HeOrthogonal", + ): + super().__init__() + self.extensive = backbone.extensive + + self.regress_forces = backbone.regress_forces + self.direct_forces = backbone.direct_forces + self.force_scaler = backbone.force_scaler + + out_mlp_E = [ + Dense( + backbone.atom_emb.emb_size * (len(backbone.int_blocks) + 1), + backbone.atom_emb.emb_size, + activation=backbone.activation, + ) + ] + [ + ResidualLayer( + backbone.atom_emb.emb_size, + activation=backbone.activation, + ) + for _ in range(num_global_out_layers) + ] + self.out_mlp_E = torch.nn.Sequential(*out_mlp_E) + + self.out_energy = Dense( + backbone.atom_emb.emb_size, + 1, + bias=False, + activation=None, + ) + + out_initializer = get_initializer(output_init) + self.out_energy.reset_parameters(out_initializer) + + @conditional_grad(torch.enable_grad()) + def forward( + self, data: Batch, emb: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + # Global output block for final predictions + x_E = self.out_mlp_E(torch.cat(emb["xs_E"], dim=-1)) + with torch.cuda.amp.autocast(False): + E_t = self.out_energy(x_E.float()) + + nMolecules = torch.max(data.batch) + 1 + if self.extensive: + E_t = scatter_det( + E_t, data.batch, dim=0, dim_size=nMolecules, reduce="add" + ) # (nMolecules, 1) + else: + E_t = scatter_det( + E_t, data.batch, dim=0, dim_size=nMolecules, reduce="mean" + ) # (nMolecules, 1) + + outputs = {"energy": E_t.squeeze(1)} # (num_molecules) + + if self.regress_forces and not self.direct_forces: + F_t = self.force_scaler.calc_forces_and_update(outputs["energy"], data.pos) + outputs["forces"] = F_t.squeeze(1) + return outputs + + +@registry.register_model("gemnet_oc_force_head") +class GemNetOCForceHead(nn.Module, HeadInterface): + def __init__( + self, backbone, num_global_out_layers: int, output_init: str = "HeOrthogonal" + ): + super().__init__() + + self.direct_forces = backbone.direct_forces + self.forces_coupled = backbone.forces_coupled + + emb_size_edge = backbone.edge_emb.dense.linear.out_features + if self.direct_forces: + out_mlp_F = [ + Dense( + emb_size_edge * (len(backbone.int_blocks) + 1), + emb_size_edge, + activation=backbone.activation, + ) + ] + [ + ResidualLayer( + emb_size_edge, + activation=backbone.activation, + ) + for _ in range(num_global_out_layers) + ] + self.out_mlp_F = torch.nn.Sequential(*out_mlp_F) + self.out_forces = Dense( + emb_size_edge, + 1, + bias=False, + activation=None, + ) + out_initializer = get_initializer(output_init) + self.out_forces.reset_parameters(out_initializer) + + def forward( + self, data: Batch, emb: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + if self.direct_forces: + x_F = self.out_mlp_F(torch.cat(emb["xs_F"], dim=-1)) + with torch.cuda.amp.autocast(False): + F_st = self.out_forces(x_F.float()) + + if self.forces_coupled: # enforce F_st = F_ts + nEdges = emb["edge_idx"].shape[0] + id_undir = repeat_blocks( + emb["num_neighbors"] // 2, + repeats=2, + continuous_indexing=True, + ) + F_st = scatter_det( + F_st, + id_undir, + dim=0, + dim_size=int(nEdges / 2), + reduce="mean", + ) # (nEdges/2, 1) + 
F_st = F_st[id_undir] # (nEdges, 1) + + # map forces in edge directions + F_st_vec = F_st[:, :, None] * emb["edge_vec"][:, None, :] + # (nEdges, 1, 3) + F_t = scatter_det( + F_st_vec, + emb["edge_idx"], + dim=0, + dim_size=data.atomic_numbers.long().shape[0], + reduce="add", + ) # (nAtoms, 1, 3) + return {"forces": F_t.squeeze(1)} # (num_atoms, 3) + return {} diff --git a/src/fairchem/core/models/painn/painn.py b/src/fairchem/core/models/painn/painn.py index 8843f02b2e..ec9e9f465c 100644 --- a/src/fairchem/core/models/painn/painn.py +++ b/src/fairchem/core/models/painn/painn.py @@ -32,15 +32,19 @@ from __future__ import annotations import math +import typing import torch from torch import nn + +if typing.TYPE_CHECKING: + from torch_geometric.data.batch import Batch from torch_geometric.nn import MessagePassing from torch_scatter import scatter, segment_coo from fairchem.core.common.registry import registry from fairchem.core.common.utils import conditional_grad -from fairchem.core.models.base import BaseModel +from fairchem.core.models.base import BackboneInterface, GraphModelMixin, HeadInterface from fairchem.core.models.gemnet.layers.base_layers import ScaledSiLU from fairchem.core.models.gemnet.layers.embedding_block import AtomEmbedding from fairchem.core.models.gemnet.layers.radial_basis import RadialBasis @@ -51,7 +55,7 @@ @registry.register_model("painn") -class PaiNN(BaseModel): +class PaiNN(nn.Module, GraphModelMixin): r"""PaiNN model based on the description in Schütt et al. (2021): Equivariant message passing for the prediction of tensorial properties and molecular spectra, https://arxiv.org/abs/2102.03150. @@ -59,9 +63,6 @@ class PaiNN(BaseModel): def __init__( self, - num_atoms: int, - bond_feat_dim: int, - num_targets: int, hidden_channels: int = 512, num_layers: int = 6, num_rbf: int = 128, @@ -310,23 +311,16 @@ def symmetrize_edges( ) def generate_graph_values(self, data): - ( - edge_index, - edge_dist, - distance_vec, - cell_offsets, - _, # cell offset distances - neighbors, - ) = self.generate_graph(data) + graph = self.generate_graph(data) # Unit vectors pointing from edge_index[1] to edge_index[0], # i.e., edge_index[0] - edge_index[1] divided by the norm. 
# make sure that the distances are not close to zero before dividing - mask_zero = torch.isclose(edge_dist, torch.tensor(0.0), atol=1e-6) - edge_dist[mask_zero] = 1.0e-6 - edge_vector = distance_vec / edge_dist[:, None] + mask_zero = torch.isclose(graph.edge_distance, torch.tensor(0.0), atol=1e-6) + graph.edge_distance[mask_zero] = 1.0e-6 + edge_vector = graph.edge_distance_vec / graph.edge_distance[:, None] - empty_image = neighbors == 0 + empty_image = graph.neighbors == 0 if torch.any(empty_image): raise ValueError( f"An image has no neighbors: id={data.id[empty_image]}, " @@ -342,11 +336,11 @@ def generate_graph_values(self, data): [edge_vector], id_swap, ) = self.symmetrize_edges( - edge_index, - cell_offsets, - neighbors, + graph.edge_index, + graph.cell_offsets, + graph.neighbors, data.batch, - [edge_dist], + [graph.edge_distance], [edge_vector], ) @@ -436,6 +430,50 @@ def __repr__(self) -> str: ) +@registry.register_model("painn_backbone") +class PaiNNBackbone(PaiNN, BackboneInterface): + @conditional_grad(torch.enable_grad()) + def forward(self, data) -> dict[str, torch.Tensor]: + pos = data.pos + z = data.atomic_numbers.long() + + if self.regress_forces and not self.direct_forces: + pos = pos.requires_grad_(True) + + ( + edge_index, + neighbors, + edge_dist, + edge_vector, + id_swap, + ) = self.generate_graph_values(data) + + assert z.dim() == 1 + assert z.dtype == torch.long + + edge_rbf = self.radial_basis(edge_dist) # rbf * envelope + + x = self.atom_emb(z) + vec = torch.zeros(x.size(0), 3, x.size(1), device=x.device) + + #### Interaction blocks ############################################### + + for i in range(self.num_layers): + dx, dvec = self.message_layers[i](x, vec, edge_index, edge_rbf, edge_vector) + + x = x + dx + vec = vec + dvec + x = x * self.inv_sqrt_2 + + dx, dvec = self.update_layers[i](x, vec) + + x = x + dx + vec = vec + dvec + x = getattr(self, "upd_out_scalar_scale_%d" % i)(x) + + return {"node_embedding": x, "node_vec": vec} + + class PaiNNMessage(MessagePassing): def __init__( self, @@ -625,3 +663,53 @@ def forward(self, x, v): x = self.act(x) return x, v + + +@registry.register_model("painn_energy_head") +class PaiNNEnergyHead(nn.Module, HeadInterface): + def __init__(self, backbone): + super().__init__() + + self.out_energy = nn.Sequential( + nn.Linear(backbone.hidden_channels, backbone.hidden_channels // 2), + ScaledSiLU(), + nn.Linear(backbone.hidden_channels // 2, 1), + ) + + nn.init.xavier_uniform_(self.out_energy[0].weight) + self.out_energy[0].bias.data.fill_(0) + nn.init.xavier_uniform_(self.out_energy[2].weight) + self.out_energy[2].bias.data.fill_(0) + + def forward( + self, data: Batch, emb: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + per_atom_energy = self.out_energy(emb["node_embedding"]).squeeze(1) + return {"energy": scatter(per_atom_energy, data.batch, dim=0)} + + +@registry.register_model("painn_force_head") +class PaiNNForceHead(nn.Module, HeadInterface): + def __init__(self, backbone): + super().__init__() + self.direct_forces = backbone.direct_forces + + if self.direct_forces: + self.out_forces = PaiNNOutput(backbone.hidden_channels) + + def forward( + self, data: Batch, emb: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + if self.direct_forces: + forces = self.out_forces(emb["node_embedding"], emb["node_vec"]) + else: + forces = ( + -1 + * torch.autograd.grad( + emb["node_embedding"], + data.pos, + grad_outputs=torch.ones_like(emb["node_embedding"]), + create_graph=True, + )[0] + ) + return {"forces": forces} 
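
Editor's note on the pattern the PaiNN and GemNet-OC changes above share: the model is split into a backbone whose forward returns a dict of intermediate embeddings and one or more heads whose forward takes the batch plus that dict and returns a dict of predictions, so several output heads can share one trunk. The snippet below is a minimal, self-contained sketch of that contract only, assuming nothing beyond plain PyTorch; ToyBackbone, ToyEnergyHead, and the tensor arguments standing in for the batch are illustrative placeholders, not the actual fairchem BackboneInterface/HeadInterface classes or their signatures.

from __future__ import annotations

import torch
from torch import nn


class ToyBackbone(nn.Module):
    """Stand-in for a backbone: forward(data) -> dict of shared embeddings."""

    def __init__(self, hidden: int = 16):
        super().__init__()
        self.emb = nn.Embedding(100, hidden)

    def forward(self, atomic_numbers: torch.Tensor) -> dict[str, torch.Tensor]:
        # one feature vector per atom
        return {"node_embedding": self.emb(atomic_numbers)}


class ToyEnergyHead(nn.Module):
    """Stand-in for a head: forward(data, emb) -> dict of predictions."""

    def __init__(self, hidden: int = 16):
        super().__init__()
        self.out = nn.Linear(hidden, 1)

    def forward(
        self, batch_index: torch.Tensor, emb: dict[str, torch.Tensor]
    ) -> dict[str, torch.Tensor]:
        per_atom = self.out(emb["node_embedding"]).squeeze(-1)
        # sum per-atom contributions into one energy per structure
        num_structures = int(batch_index.max()) + 1
        energy = torch.zeros(num_structures).index_add_(0, batch_index, per_atom)
        return {"energy": energy}


# Composing the two is a single chained call; the "hydra"-style YAML configs
# added later in this patch wire up the real backbone and heads analogously.
backbone, head = ToyBackbone(), ToyEnergyHead()
atomic_numbers = torch.tensor([1, 8, 1])  # three atoms in one structure
batch_index = torch.tensor([0, 0, 0])     # all atoms belong to structure 0
print(head(batch_index, backbone(atomic_numbers))["energy"].shape)  # torch.Size([1])
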
diff --git a/src/fairchem/core/models/schnet.py b/src/fairchem/core/models/schnet.py index 2f89c17e1f..5ca70a354e 100644 --- a/src/fairchem/core/models/schnet.py +++ b/src/fairchem/core/models/schnet.py @@ -13,11 +13,11 @@ from fairchem.core.common.registry import registry from fairchem.core.common.utils import conditional_grad -from fairchem.core.models.base import BaseModel +from fairchem.core.models.base import GraphModelMixin @registry.register_model("schnet") -class SchNetWrap(SchNet, BaseModel): +class SchNetWrap(SchNet, GraphModelMixin): r"""Wrapper around the continuous-filter convolutional neural network SchNet from the `"SchNet: A Continuous-filter Convolutional Neural Network for Modeling Quantum Interactions" `_. Each layer uses interaction @@ -28,9 +28,6 @@ class SchNetWrap(SchNet, BaseModel): h_{\mathbf{\Theta}} ( \exp(-\gamma(\mathbf{e}_{j,i} - \mathbf{\mu}))), Args: - num_atoms (int): Unused argument - bond_feat_dim (int): Unused argument - num_targets (int): Number of targets to predict. use_pbc (bool, optional): If set to :obj:`True`, account for periodic boundary conditions. (default: :obj:`True`) regress_forces (bool, optional): If set to :obj:`True`, predict forces by differentiating @@ -54,9 +51,6 @@ class SchNetWrap(SchNet, BaseModel): def __init__( self, - num_atoms: int, # not used - bond_feat_dim: int, # not used - num_targets: int, use_pbc: bool = True, regress_forces: bool = True, otf_graph: bool = False, @@ -67,7 +61,7 @@ def __init__( cutoff: float = 10.0, readout: str = "add", ) -> None: - self.num_targets = num_targets + self.num_targets = 1 self.regress_forces = regress_forces self.use_pbc = use_pbc self.cutoff = cutoff @@ -88,25 +82,17 @@ def _forward(self, data): z = data.atomic_numbers.long() pos = data.pos batch = data.batch - - ( - edge_index, - edge_weight, - distance_vec, - cell_offsets, - _, # cell offset distances - neighbors, - ) = self.generate_graph(data) + graph = self.generate_graph(data) if self.use_pbc: assert z.dim() == 1 assert z.dtype == torch.long - edge_attr = self.distance_expansion(edge_weight) + edge_attr = self.distance_expansion(graph.edge_distance) h = self.embedding(z) for interaction in self.interactions: - h = h + interaction(h, edge_index, edge_weight, edge_attr) + h = h + interaction(h, graph.edge_index, graph.edge_distance, edge_attr) h = self.lin1(h) h = self.act(h) diff --git a/src/fairchem/core/models/scn/scn.py b/src/fairchem/core/models/scn/scn.py index bf8454f212..84806e19e8 100644 --- a/src/fairchem/core/models/scn/scn.py +++ b/src/fairchem/core/models/scn/scn.py @@ -18,7 +18,7 @@ from fairchem.core.common.registry import registry from fairchem.core.common.utils import conditional_grad -from fairchem.core.models.base import BaseModel +from fairchem.core.models.base import GraphModelMixin from fairchem.core.models.scn.sampling import CalcSpherePoints from fairchem.core.models.scn.smearing import ( GaussianSmearing, @@ -33,7 +33,7 @@ @registry.register_model("scn") -class SphericalChannelNetwork(BaseModel): +class SphericalChannelNetwork(nn.Module, GraphModelMixin): """Spherical Channel Network Paper: Spherical Channels for Modeling Atomic Interactions @@ -75,9 +75,6 @@ class SphericalChannelNetwork(BaseModel): def __init__( self, - num_atoms: int, # not used - bond_feat_dim: int, # not used - num_targets: int, # not used use_pbc: bool = True, regress_forces: bool = True, otf_graph: bool = False, @@ -262,15 +259,7 @@ def _forward_helper(self, data): atomic_numbers = data.atomic_numbers.long() num_atoms = 
len(atomic_numbers) pos = data.pos - - ( - edge_index, - edge_distance, - edge_distance_vec, - cell_offsets, - _, # cell offset distances - neighbors, - ) = self.generate_graph(data) + graph = self.generate_graph(data) ############################################################### # Initialize data structures @@ -278,12 +267,12 @@ def _forward_helper(self, data): # Calculate which message block each edge should use. Based on edge distance rank. edge_rank = self._rank_edge_distances( - edge_distance, edge_index, self.max_num_neighbors + graph.edge_distance, graph.edge_index, self.max_num_neighbors ) # Reorder edges so that they are grouped by distance rank (lowest to highest) last_cutoff = -0.1 - message_block_idx = torch.zeros(len(edge_distance), device=pos.device) + message_block_idx = torch.zeros(len(graph.edge_distance), device=pos.device) edge_distance_reorder = torch.tensor([], device=self.device) edge_index_reorder = torch.tensor([], device=self.device) edge_distance_vec_reorder = torch.tensor([], device=self.device) @@ -297,21 +286,21 @@ def _forward_helper(self, data): edge_distance_reorder = torch.cat( [ edge_distance_reorder, - torch.masked_select(edge_distance, mask), + torch.masked_select(graph.edge_distance, mask), ], dim=0, ) edge_index_reorder = torch.cat( [ edge_index_reorder, - torch.masked_select(edge_index, mask.view(1, -1).repeat(2, 1)).view( - 2, -1 - ), + torch.masked_select( + graph.edge_index, mask.view(1, -1).repeat(2, 1) + ).view(2, -1), ], dim=1, ) edge_distance_vec_mask = torch.masked_select( - edge_distance_vec, mask.view(-1, 1).repeat(1, 3) + graph.edge_distance_vec, mask.view(-1, 1).repeat(1, 3) ).view(-1, 3) edge_distance_vec_reorder = torch.cat( [edge_distance_vec_reorder, edge_distance_vec_mask], dim=0 diff --git a/src/fairchem/core/trainers/base_trainer.py b/src/fairchem/core/trainers/base_trainer.py index 1c0c975f8a..c21409863e 100644 --- a/src/fairchem/core/trainers/base_trainer.py +++ b/src/fairchem/core/trainers/base_trainer.py @@ -450,21 +450,7 @@ def load_model(self) -> None: if distutils.is_master(): logging.info(f"Loading model: {self.config['model']}") - # TODO: depreicated, remove. 
- bond_feat_dim = None - bond_feat_dim = self.config["model_attributes"].get("num_gaussians", 50) - - loader = self.train_loader or self.val_loader or self.test_loader self.model = registry.get_model_class(self.config["model"])( - ( - loader.dataset[0].x.shape[-1] - if loader - and hasattr(loader.dataset[0], "x") - and loader.dataset[0].x is not None - else None - ), - bond_feat_dim, - 1, **self.config["model_attributes"], ).to(self.device) diff --git a/tests/core/e2e/test_s2ef.py b/tests/core/e2e/test_s2ef.py index aea07201bd..1584becd45 100644 --- a/tests/core/e2e/test_s2ef.py +++ b/tests/core/e2e/test_s2ef.py @@ -27,9 +27,32 @@ @pytest.fixture() def configs(): return { + "scn": Path("tests/core/models/test_configs/test_scn.yml"), "escn": Path("tests/core/models/test_configs/test_escn.yml"), - "gemnet": Path("tests/core/models/test_configs/test_gemnet.yml"), + "escn_hydra": Path("tests/core/models/test_configs/test_escn_hydra.yml"), + "schnet": Path("tests/core/models/test_configs/test_schnet.yml"), + "gemnet_dt": Path("tests/core/models/test_configs/test_gemnet_dt.yml"), + "gemnet_dt_hydra": Path( + "tests/core/models/test_configs/test_gemnet_dt_hydra.yml" + ), + "gemnet_dt_hydra_grad": Path( + "tests/core/models/test_configs/test_gemnet_dt_hydra_grad.yml" + ), + "gemnet_oc": Path("tests/core/models/test_configs/test_gemnet_oc.yml"), + "gemnet_oc_hydra": Path( + "tests/core/models/test_configs/test_gemnet_oc_hydra.yml" + ), + "gemnet_oc_hydra_grad": Path( + "tests/core/models/test_configs/test_gemnet_oc_hydra_grad.yml" + ), + "dimenet++": Path("tests/core/models/test_configs/test_dpp.yml"), + "dimenet++_hydra": Path("tests/core/models/test_configs/test_dpp_hydra.yml"), + "painn": Path("tests/core/models/test_configs/test_painn.yml"), + "painn_hydra": Path("tests/core/models/test_configs/test_painn_hydra.yml"), "equiformer_v2": Path("tests/core/models/test_configs/test_equiformerv2.yml"), + "equiformer_v2_hydra": Path( + "tests/core/models/test_configs/test_equiformerv2_hydra.yml" + ), } @@ -173,7 +196,7 @@ def smoke_test_train( rundir=str(train_rundir), input_yaml=input_yaml, update_dict_with={ - "optim": {"max_epochs": 2, "eval_every": 8}, + "optim": {"max_epochs": 2, "eval_every": 8, "batch_size": 5}, "dataset": oc20_lmdb_train_and_val_from_paths( train_src=str(tutorial_val_src), val_src=str(tutorial_val_src), @@ -194,7 +217,7 @@ def smoke_test_train( rundir=str(predictions_rundir), input_yaml=input_yaml, update_dict_with={ - "optim": {"max_epochs": 2, "eval_every": 8}, + "optim": {"max_epochs": 2, "eval_every": 8, "batch_size": 5}, "dataset": oc20_lmdb_train_and_val_from_paths( train_src=str(tutorial_val_src), val_src=str(tutorial_val_src), @@ -216,9 +239,22 @@ def smoke_test_train( @pytest.mark.parametrize( "model_name", [ - pytest.param("gemnet", id="gemnet"), + pytest.param("schnet", id="schnet"), + pytest.param("scn", id="scn"), + pytest.param("gemnet_dt", id="gemnet_dt"), + pytest.param("gemnet_dt_hydra", id="gemnet_dt_hydra"), + pytest.param("gemnet_dt_hydra_grad", id="gemnet_dt_hydra_grad"), + pytest.param("gemnet_oc", id="gemnet_oc"), + pytest.param("gemnet_oc_hydra", id="gemnet_oc_hydra"), + pytest.param("gemnet_oc_hydra_grad", id="gemnet_oc_hydra_grad"), + pytest.param("dimenet++", id="dimenet++"), + pytest.param("dimenet++_hydra", id="dimenet++_hydra"), + pytest.param("painn", id="painn"), + pytest.param("painn_hydra", id="painn_hydra"), pytest.param("escn", id="escn"), + pytest.param("escn_hydra", id="escn_hydra"), pytest.param("equiformer_v2", id="equiformer_v2"), + 
pytest.param("equiformer_v2_hydra", id="equiformer_v2_hydra"), ], ) def test_train_and_predict( @@ -376,7 +412,7 @@ class TestSmallDatasetOptim: @pytest.mark.parametrize( ("model_name", "expected_energy_mae", "expected_force_mae"), [ - pytest.param("gemnet", 0.41, 0.06, id="gemnet"), + pytest.param("gemnet_oc", 0.41, 0.06, id="gemnet_oc"), pytest.param("escn", 0.41, 0.06, id="escn"), pytest.param("equiformer_v2", 0.41, 0.06, id="equiformer_v2"), ], diff --git a/tests/core/models/test_configs/test_dpp.yml b/tests/core/models/test_configs/test_dpp.yml new file mode 100755 index 0000000000..a79294bd15 --- /dev/null +++ b/tests/core/models/test_configs/test_dpp.yml @@ -0,0 +1,50 @@ +trainer: forces + +task: + dataset: lmdb + type: regression + metric: mae + primary_metric: forces_mae + labels: + - potential energy + grad_input: atomic forces + train_on_free_atoms: True + eval_on_free_atoms: True + prediction_dtype: float32 + +logger: + name: tensorboard + +model: + name: dimenetplusplus #_bbwheads + hidden_channels: 4 + out_emb_channels: 4 + num_blocks: 3 + cutoff: 6.0 + num_radial: 6 + num_spherical: 7 + num_before_skip: 1 + num_after_skip: 2 + num_output_layers: 3 + regress_forces: True + use_pbc: True + +# *** Important note *** +# The total number of gpus used for this run was 256. +# If the global batch size (num_gpus * batch_size) is modified +# the lr_milestones and warmup_steps need to be adjusted accordingly. + +optim: + batch_size: 5 + eval_batch_size: 2 + eval_every: 1000 + num_workers: 8 + lr_initial: 0.0001 + lr_gamma: 0.1 + lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma + - 130794 + - 196192 + - 261589 + warmup_steps: 130794 + warmup_factor: 0.2 + max_epochs: 7 diff --git a/tests/core/models/test_configs/test_dpp_hydra.yml b/tests/core/models/test_configs/test_dpp_hydra.yml new file mode 100755 index 0000000000..1120cc905f --- /dev/null +++ b/tests/core/models/test_configs/test_dpp_hydra.yml @@ -0,0 +1,55 @@ +trainer: forces + +task: + dataset: lmdb + type: regression + metric: mae + primary_metric: forces_mae + labels: + - potential energy + grad_input: atomic forces + train_on_free_atoms: True + eval_on_free_atoms: True + prediction_dtype: float32 + +logger: + name: tensorboard + +model: + name: hydra + backbone: + model: dimenetplusplus_backbone + hidden_channels: 4 + out_emb_channels: 4 + num_blocks: 1 + cutoff: 6.0 + num_radial: 6 + num_spherical: 7 + num_before_skip: 1 + num_after_skip: 2 + num_output_layers: 3 + regress_forces: True + use_pbc: True + heads: + energy: + module: dimenetplusplus_energy_and_force_head + +# *** Important note *** +# The total number of gpus used for this run was 256. +# If the global batch size (num_gpus * batch_size) is modified +# the lr_milestones and warmup_steps need to be adjusted accordingly. 
+ +optim: + batch_size: 5 + eval_batch_size: 2 + eval_every: 1000 + num_workers: 8 + lr_initial: 0.0001 + lr_gamma: 0.1 + lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma + - 130794 + - 196192 + - 261589 + warmup_steps: 130794 + warmup_factor: 0.2 + max_epochs: 7 diff --git a/tests/core/models/test_configs/test_equiformerv2_hydra.yml b/tests/core/models/test_configs/test_equiformerv2_hydra.yml new file mode 100644 index 0000000000..4c00fe6a2e --- /dev/null +++ b/tests/core/models/test_configs/test_equiformerv2_hydra.yml @@ -0,0 +1,98 @@ + + +trainer: forces + +model: + name: hydra + backbone: + model: equiformer_v2_backbone + use_pbc: True + regress_forces: True + otf_graph: True + + enforce_max_neighbors_strictly: False + + max_neighbors: 1 + max_radius: 12.0 + max_num_elements: 90 + + num_layers: 1 + sphere_channels: 4 + attn_hidden_channels: 4 # [64, 96] This determines the hidden size of message passing. Do not necessarily use 96. + num_heads: 1 + attn_alpha_channels: 4 # Not used when `use_s2_act_attn` is True. + attn_value_channels: 4 + ffn_hidden_channels: 8 + norm_type: 'layer_norm_sh' # ['rms_norm_sh', 'layer_norm', 'layer_norm_sh'] + + lmax_list: [1] + mmax_list: [1] + grid_resolution: 18 # [18, 16, 14, None] For `None`, simply comment this line. + + num_sphere_samples: 128 + + edge_channels: 32 + use_atom_edge_embedding: True + distance_function: 'gaussian' + num_distance_basis: 16 # not used + + attn_activation: 'silu' + use_s2_act_attn: False # [False, True] Switch between attention after S2 activation or the original EquiformerV1 attention. + ffn_activation: 'silu' # ['silu', 'swiglu'] + use_gate_act: False # [True, False] Switch between gate activation and S2 activation + use_grid_mlp: False # [False, True] If `True`, use projecting to grids and performing MLPs for FFNs. 
+ + alpha_drop: 0.0 # [0.0, 0.1] + drop_path_rate: 0.0 # [0.0, 0.05] + proj_drop: 0.0 + + weight_init: 'normal' # ['uniform', 'normal'] + heads: + energy: + module: equiformer_v2_energy_head + forces: + module: equiformer_v2_force_head + +dataset: + train: + src: tutorial_dset/s2ef/train_100/ + normalize_labels: True + target_mean: -0.7554450631141663 + target_std: 2.887317180633545 + grad_target_mean: 0.0 + grad_target_std: 2.887317180633545 + val: + format: lmdb + src: tutorial_dset/s2ef/val_20/ + +logger: + name: tensorboard + +task: + dataset: lmdb + type: regression + metric: mae + primary_metric: forces_mae + labels: + - potential energy + grad_input: atomic forces + train_on_free_atoms: True + eval_on_free_atoms: True + prediction_dtype: float32 + + +optim: + batch_size: 5 + eval_batch_size: 2 + num_workers: 0 + lr_initial: 0.0025 + optimizer: AdamW + optimizer_params: {"amsgrad": True,weight_decay: 0.0} + eval_every: 190 + max_epochs: 50 + force_coefficient: 20 + scheduler: "Null" + energy_coefficient: 1 + clip_grad_norm: 20 + loss_energy: mae + loss_force: l2mae diff --git a/tests/core/models/test_configs/test_escn_hydra.yml b/tests/core/models/test_configs/test_escn_hydra.yml new file mode 100644 index 0000000000..ba5db1f53e --- /dev/null +++ b/tests/core/models/test_configs/test_escn_hydra.yml @@ -0,0 +1,67 @@ +trainer: forces + +dataset: + train: + src: tutorial_dset/s2ef/train_100/ + normalize_labels: True + target_mean: -0.7554450631141663 + target_std: 2.887317180633545 + grad_target_mean: 0.0 + grad_target_std: 2.887317180633545 + val: + format: lmdb + src: tutorial_dset/s2ef/val_20/ + +logger: + name: tensorboard + +task: + dataset: lmdb + type: regression + metric: mae + primary_metric: forces_mae + labels: + - potential energy + grad_input: atomic forces + train_on_free_atoms: True + eval_on_free_atoms: True + prediction_dtype: float32 + +model: + name: hydra + backbone: + model: escn_backbone + num_layers: 2 + max_neighbors: 10 + cutoff: 12.0 + sphere_channels: 8 + hidden_channels: 8 + lmax_list: [2] + mmax_list: [2] + num_sphere_samples: 64 + distance_function: "gaussian" + regress_forces: True + use_pbc: True + basis_width_scalar: 2.0 + otf_graph: True + heads: + energy: + module: escn_energy_head + forces: + module: escn_force_head + +optim: + batch_size: 5 + eval_batch_size: 2 + num_workers: 0 + lr_initial: 0.0025 + optimizer: AdamW + optimizer_params: {"amsgrad": True,weight_decay: 0.0} + eval_every: 190 + max_epochs: 50 + force_coefficient: 20 + scheduler: "Null" + energy_coefficient: 1 + clip_grad_norm: 20 + loss_energy: mae + loss_force: l2mae diff --git a/tests/core/models/test_configs/test_gemnet_dt.yml b/tests/core/models/test_configs/test_gemnet_dt.yml new file mode 100644 index 0000000000..b04b6dfda0 --- /dev/null +++ b/tests/core/models/test_configs/test_gemnet_dt.yml @@ -0,0 +1,79 @@ +trainer: forces + +dataset: + train: + src: tutorial_dset/s2ef/train_100/ + normalize_labels: True + target_mean: -0.7554450631141663 + target_std: 2.887317180633545 + grad_target_mean: 0.0 + grad_target_std: 2.887317180633545 + val: + format: lmdb + src: tutorial_dset/s2ef/val_20/ + +logger: + name: tensorboard + +task: + dataset: lmdb + type: regression + metric: mae + primary_metric: forces_mae + labels: + - potential energy + grad_input: atomic forces + train_on_free_atoms: True + eval_on_free_atoms: True + prediction_dtype: float32 + +model: + name: gemnet_t + num_spherical: 3 + num_radial: 8 + num_blocks: 2 + emb_size_atom: 8 + emb_size_edge: 16 + emb_size_trip: 4 + 
emb_size_rbf: 4 + emb_size_cbf: 4 + emb_size_bil_trip: 4 + num_before_skip: 1 + num_after_skip: 2 + num_concat: 1 + num_atom: 3 + cutoff: 6.0 + max_neighbors: 50 + rbf: + name: gaussian + envelope: + name: polynomial + exponent: 5 + cbf: + name: spherical_harmonics + extensive: True + otf_graph: False + output_init: HeOrthogonal + activation: silu + scale_file: configs/s2ef/all/gemnet/scaling_factors/gemnet-dT.json + + regress_forces: True + direct_forces: True + +optim: + batch_size: 8 + eval_batch_size: 8 + eval_every: 5000 + num_workers: 2 + lr_initial: 5.e-4 + optimizer: AdamW + optimizer_params: {"amsgrad": True} + scheduler: ReduceLROnPlateau + mode: min + factor: 0.8 + patience: 3 + max_epochs: 80 + force_coefficient: 100 + energy_coefficient: 1 + ema_decay: 0.999 + clip_grad_norm: 10 diff --git a/tests/core/models/test_configs/test_gemnet_dt_hydra.yml b/tests/core/models/test_configs/test_gemnet_dt_hydra.yml new file mode 100644 index 0000000000..a612741470 --- /dev/null +++ b/tests/core/models/test_configs/test_gemnet_dt_hydra.yml @@ -0,0 +1,86 @@ +trainer: forces + +dataset: + train: + src: tutorial_dset/s2ef/train_100/ + normalize_labels: True + target_mean: -0.7554450631141663 + target_std: 2.887317180633545 + grad_target_mean: 0.0 + grad_target_std: 2.887317180633545 + val: + format: lmdb + src: tutorial_dset/s2ef/val_20/ + +logger: + name: tensorboard + +task: + dataset: lmdb + type: regression + metric: mae + primary_metric: forces_mae + labels: + - potential energy + grad_input: atomic forces + train_on_free_atoms: True + eval_on_free_atoms: True + prediction_dtype: float32 + +model: + name: hydra + backbone: + model: gemnet_t_backbone + num_spherical: 3 + num_radial: 8 + num_blocks: 2 + emb_size_atom: 8 + emb_size_edge: 16 + emb_size_trip: 4 + emb_size_rbf: 4 + emb_size_cbf: 4 + emb_size_bil_trip: 4 + num_before_skip: 1 + num_after_skip: 2 + num_concat: 1 + num_atom: 3 + cutoff: 6.0 + max_neighbors: 50 + rbf: + name: gaussian + envelope: + name: polynomial + exponent: 5 + cbf: + name: spherical_harmonics + extensive: True + otf_graph: False + output_init: HeOrthogonal + activation: silu + scale_file: configs/s2ef/all/gemnet/scaling_factors/gemnet-dT.json + + regress_forces: True + direct_forces: True + heads: + energy: + module: gemnet_t_energy_and_grad_force_head + forces: + module: gemnet_t_force_head + +optim: + batch_size: 8 + eval_batch_size: 8 + eval_every: 5000 + num_workers: 2 + lr_initial: 5.e-4 + optimizer: AdamW + optimizer_params: {"amsgrad": True} + scheduler: ReduceLROnPlateau + mode: min + factor: 0.8 + patience: 3 + max_epochs: 80 + force_coefficient: 100 + energy_coefficient: 1 + ema_decay: 0.999 + clip_grad_norm: 10 diff --git a/tests/core/models/test_configs/test_gemnet_dt_hydra_grad.yml b/tests/core/models/test_configs/test_gemnet_dt_hydra_grad.yml new file mode 100644 index 0000000000..83d46bdd4d --- /dev/null +++ b/tests/core/models/test_configs/test_gemnet_dt_hydra_grad.yml @@ -0,0 +1,84 @@ +trainer: forces + +dataset: + train: + src: tutorial_dset/s2ef/train_100/ + normalize_labels: True + target_mean: -0.7554450631141663 + target_std: 2.887317180633545 + grad_target_mean: 0.0 + grad_target_std: 2.887317180633545 + val: + format: lmdb + src: tutorial_dset/s2ef/val_20/ + +logger: + name: tensorboard + +task: + dataset: lmdb + type: regression + metric: mae + primary_metric: forces_mae + labels: + - potential energy + grad_input: atomic forces + train_on_free_atoms: True + eval_on_free_atoms: True + prediction_dtype: float32 + +model: + name: hydra + 
backbone: + model: gemnet_t_backbone + num_spherical: 3 + num_radial: 8 + num_blocks: 2 + emb_size_atom: 8 + emb_size_edge: 16 + emb_size_trip: 4 + emb_size_rbf: 4 + emb_size_cbf: 4 + emb_size_bil_trip: 4 + num_before_skip: 1 + num_after_skip: 2 + num_concat: 1 + num_atom: 3 + cutoff: 6.0 + max_neighbors: 50 + rbf: + name: gaussian + envelope: + name: polynomial + exponent: 5 + cbf: + name: spherical_harmonics + extensive: True + otf_graph: False + output_init: HeOrthogonal + activation: silu + scale_file: configs/s2ef/all/gemnet/scaling_factors/gemnet-dT.json + + regress_forces: True + direct_forces: False + heads: + energy_and_forces: + module: gemnet_t_energy_and_grad_force_head + +optim: + batch_size: 8 + eval_batch_size: 8 + eval_every: 5000 + num_workers: 2 + lr_initial: 5.e-4 + optimizer: AdamW + optimizer_params: {"amsgrad": True} + scheduler: ReduceLROnPlateau + mode: min + factor: 0.8 + patience: 3 + max_epochs: 80 + force_coefficient: 100 + energy_coefficient: 1 + ema_decay: 0.999 + clip_grad_norm: 10 diff --git a/tests/core/models/test_configs/test_gemnet.yml b/tests/core/models/test_configs/test_gemnet_oc.yml similarity index 100% rename from tests/core/models/test_configs/test_gemnet.yml rename to tests/core/models/test_configs/test_gemnet_oc.yml diff --git a/tests/core/models/test_configs/test_gemnet_oc_hydra.yml b/tests/core/models/test_configs/test_gemnet_oc_hydra.yml new file mode 100644 index 0000000000..97343e90e6 --- /dev/null +++ b/tests/core/models/test_configs/test_gemnet_oc_hydra.yml @@ -0,0 +1,112 @@ + + + +trainer: forces + +dataset: + train: + src: tutorial_dset/s2ef/train_100/ + normalize_labels: True + target_mean: -0.7554450631141663 + target_std: 2.887317180633545 + grad_target_mean: 0.0 + grad_target_std: 2.887317180633545 + val: + format: lmdb + src: tutorial_dset/s2ef/val_20/ + +logger: + name: tensorboard + +task: + dataset: lmdb + type: regression + metric: mae + primary_metric: forces_mae + labels: + - potential energy + grad_input: atomic forces + train_on_free_atoms: True + eval_on_free_atoms: True + prediction_dtype: float32 + +model: + name: hydra + backbone: + model: gemnet_oc_backbone + num_spherical: 3 + num_radial: 8 + num_blocks: 2 + emb_size_atom: 8 + emb_size_edge: 16 + emb_size_trip_in: 4 + emb_size_trip_out: 4 + emb_size_quad_in: 2 + emb_size_quad_out: 2 + emb_size_aint_in: 4 + emb_size_aint_out: 4 + emb_size_rbf: 2 + emb_size_cbf: 2 + emb_size_sbf: 4 + num_before_skip: 1 + num_after_skip: 1 + num_concat: 1 + num_atom: 3 + num_output_afteratom: 3 + cutoff: 12.0 + cutoff_qint: 12.0 + cutoff_aeaint: 12.0 + cutoff_aint: 12.0 + max_neighbors: 30 + max_neighbors_qint: 8 + max_neighbors_aeaint: 20 + max_neighbors_aint: 1000 + rbf: + name: gaussian + envelope: + name: polynomial + exponent: 5 + cbf: + name: spherical_harmonics + sbf: + name: legendre_outer + extensive: True + output_init: HeOrthogonal + activation: silu + scale_file: configs/s2ef/all/gemnet/scaling_factors/gemnet-oc.pt + + regress_forces: True + direct_forces: True + forces_coupled: False + + quad_interaction: True + atom_edge_interaction: True + edge_atom_interaction: True + atom_interaction: True + + num_atom_emb_layers: 2 + num_global_out_layers: 2 + qint_tags: [1, 2] + heads: + energy: + module: gemnet_oc_energy_and_grad_force_head + num_global_out_layers: 2 + forces: + module: gemnet_oc_force_head + num_global_out_layers: 2 + +optim: + batch_size: 5 + eval_batch_size: 2 + num_workers: 0 + lr_initial: 0.0025 + optimizer: AdamW + optimizer_params: {"amsgrad": 
True,weight_decay: 0.0} + eval_every: 190 + max_epochs: 50 + force_coefficient: 10 + scheduler: "Null" + energy_coefficient: 1 + clip_grad_norm: 20 + loss_energy: mae + loss_force: l2mae diff --git a/tests/core/models/test_configs/test_gemnet_oc_hydra_grad.yml b/tests/core/models/test_configs/test_gemnet_oc_hydra_grad.yml new file mode 100644 index 0000000000..334c3cb4db --- /dev/null +++ b/tests/core/models/test_configs/test_gemnet_oc_hydra_grad.yml @@ -0,0 +1,109 @@ + + + +trainer: forces + +dataset: + train: + src: tutorial_dset/s2ef/train_100/ + normalize_labels: True + target_mean: -0.7554450631141663 + target_std: 2.887317180633545 + grad_target_mean: 0.0 + grad_target_std: 2.887317180633545 + val: + format: lmdb + src: tutorial_dset/s2ef/val_20/ + +logger: + name: tensorboard + +task: + dataset: lmdb + type: regression + metric: mae + primary_metric: forces_mae + labels: + - potential energy + grad_input: atomic forces + train_on_free_atoms: True + eval_on_free_atoms: True + prediction_dtype: float32 + +model: + name: hydra + backbone: + model: gemnet_oc_backbone + num_spherical: 3 + num_radial: 8 + num_blocks: 2 + emb_size_atom: 8 + emb_size_edge: 16 + emb_size_trip_in: 4 + emb_size_trip_out: 4 + emb_size_quad_in: 2 + emb_size_quad_out: 2 + emb_size_aint_in: 4 + emb_size_aint_out: 4 + emb_size_rbf: 2 + emb_size_cbf: 2 + emb_size_sbf: 4 + num_before_skip: 1 + num_after_skip: 1 + num_concat: 1 + num_atom: 3 + num_output_afteratom: 3 + cutoff: 12.0 + cutoff_qint: 12.0 + cutoff_aeaint: 12.0 + cutoff_aint: 12.0 + max_neighbors: 30 + max_neighbors_qint: 8 + max_neighbors_aeaint: 20 + max_neighbors_aint: 1000 + rbf: + name: gaussian + envelope: + name: polynomial + exponent: 5 + cbf: + name: spherical_harmonics + sbf: + name: legendre_outer + extensive: True + output_init: HeOrthogonal + activation: silu + scale_file: configs/s2ef/all/gemnet/scaling_factors/gemnet-oc.pt + + regress_forces: True + direct_forces: False + forces_coupled: False + + quad_interaction: True + atom_edge_interaction: True + edge_atom_interaction: True + atom_interaction: True + + num_atom_emb_layers: 2 + num_global_out_layers: 2 + qint_tags: [1, 2] + heads: + energy: + module: gemnet_oc_energy_and_grad_force_head + num_global_out_layers: 2 + +optim: + batch_size: 5 + eval_batch_size: 2 + num_workers: 0 + lr_initial: 0.0025 + optimizer: AdamW + optimizer_params: {"amsgrad": True,weight_decay: 0.0} + eval_every: 190 + max_epochs: 50 + force_coefficient: 10 + scheduler: "Null" + energy_coefficient: 1 + clip_grad_norm: 20 + loss_energy: mae + loss_force: l2mae diff --git a/tests/core/models/test_configs/test_painn.yml b/tests/core/models/test_configs/test_painn.yml new file mode 100644 index 0000000000..c1f24d0bb5 --- /dev/null +++ b/tests/core/models/test_configs/test_painn.yml @@ -0,0 +1,50 @@ +trainer: forces + +task: + dataset: lmdb + type: regression + metric: mae + primary_metric: forces_mae + labels: + - potential energy + grad_input: atomic forces + train_on_free_atoms: True + eval_on_free_atoms: True + prediction_dtype: float32 + +logger: + name: tensorboard + +model: + name: painn #_bbwheads + hidden_channels: 32 + num_layers: 6 + num_rbf: 32 + cutoff: 12.0 + max_neighbors: 5 + scale_file: configs/s2ef/all/painn/painn_nb6_scaling_factors.pt + regress_forces: True + direct_forces: True + use_pbc: True + +optim: + batch_size: 32 + eval_batch_size: 32 + load_balancing: atoms + eval_every: 5000 + num_workers: 2 + optimizer: AdamW + optimizer_params: + amsgrad: True + weight_decay: 0. 
# 2e-6 (TF weight decay) / 1e-4 (lr) = 2e-2 + lr_initial: 1.e-4 + lr_gamma: 0.8 + scheduler: ReduceLROnPlateau + mode: min + factor: 0.8 + patience: 3 + max_epochs: 80 + force_coefficient: 100 + energy_coefficient: 1 + ema_decay: 0.999 + clip_grad_norm: 10 diff --git a/tests/core/models/test_configs/test_painn_hydra.yml b/tests/core/models/test_configs/test_painn_hydra.yml new file mode 100644 index 0000000000..0b39aa1731 --- /dev/null +++ b/tests/core/models/test_configs/test_painn_hydra.yml @@ -0,0 +1,58 @@ +trainer: forces + +task: + dataset: lmdb + type: regression + metric: mae + primary_metric: forces_mae + labels: + - potential energy + grad_input: atomic forces + train_on_free_atoms: True + eval_on_free_atoms: True + prediction_dtype: float32 + +logger: + name: tensorboard + +model: + name: hydra + backbone: + model: painn_backbone #_bbwheads + hidden_channels: 32 + num_layers: 6 + num_rbf: 32 + cutoff: 12.0 + max_neighbors: 5 + scale_file: configs/s2ef/all/painn/painn_nb6_scaling_factors.pt + regress_forces: True + direct_forces: True + use_pbc: True + heads: + energy: + module: painn_energy_head + forces: + module: painn_force_head + + +optim: + batch_size: 32 + eval_batch_size: 32 + load_balancing: atoms + eval_every: 5000 + num_workers: 2 + optimizer: AdamW + optimizer_params: + amsgrad: True + weight_decay: 0. # 2e-6 (TF weight decay) / 1e-4 (lr) = 2e-2 + lr_initial: 1.e-4 + lr_gamma: 0.8 + scheduler: ReduceLROnPlateau + mode: min + factor: 0.8 + patience: 3 + max_epochs: 80 + force_coefficient: 100 + energy_coefficient: 1 + ema_decay: 0.999 + clip_grad_norm: 10 diff --git a/tests/core/models/test_configs/test_schnet.yml b/tests/core/models/test_configs/test_schnet.yml new file mode 100755 index 0000000000..97faf3962a --- /dev/null +++ b/tests/core/models/test_configs/test_schnet.yml @@ -0,0 +1,45 @@ +trainer: forces + +task: + dataset: lmdb + type: regression + metric: mae + primary_metric: forces_mae + labels: + - potential energy + grad_input: atomic forces + train_on_free_atoms: True + eval_on_free_atoms: True + prediction_dtype: float32 + +logger: + name: tensorboard + +model: + name: schnet + hidden_channels: 1024 + num_filters: 256 + num_interactions: 5 + num_gaussians: 200 + cutoff: 6.0 + use_pbc: True + +# *** Important note *** +# The total number of gpus used for this run was 64. +# If the global batch size (num_gpus * batch_size) is modified +# the lr_milestones and warmup_steps need to be adjusted accordingly. + +optim: + batch_size: 20 + eval_batch_size: 20 + eval_every: 10000 + num_workers: 16 + lr_initial: 0.0001 + lr_gamma: 0.1 + lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma + - 313907 + - 523179 + - 732451 + warmup_steps: 209271 + warmup_factor: 0.2 + max_epochs: 15 diff --git a/tests/core/models/test_configs/test_scn.yml b/tests/core/models/test_configs/test_scn.yml new file mode 100755 index 0000000000..c080c48557 --- /dev/null +++ b/tests/core/models/test_configs/test_scn.yml @@ -0,0 +1,59 @@ +# A total of 64 32GB GPUs were used for training. 
+trainer: forces + +task: + dataset: lmdb + type: regression + metric: mae + primary_metric: forces_mae + labels: + - potential energy + grad_input: atomic forces + train_on_free_atoms: True + eval_on_free_atoms: True + prediction_dtype: float32 + +logger: + name: tensorboard + +model: + name: scn + num_interactions: 2 + hidden_channels: 16 + sphere_channels: 8 + sphere_channels_reduce: 8 + num_sphere_samples: 8 + num_basis_functions: 8 + distance_function: "gaussian" + show_timing_info: False + max_num_neighbors: 40 + cutoff: 8.0 + lmax: 4 + num_bands: 2 + use_grid: True + regress_forces: True + use_pbc: True + basis_width_scalar: 2.0 + otf_graph: True + +optim: + batch_size: 2 + eval_batch_size: 1 + num_workers: 2 + lr_initial: 0.0004 + optimizer: AdamW + optimizer_params: {"amsgrad": True} + eval_every: 5000 + lr_gamma: 0.3 + lr_milestones: # epochs at which lr_initial <- lr_initial * lr_gamma + - 260000 + - 340000 + - 420000 + - 500000 + - 800000 + - 1000000 + warmup_steps: 100 + warmup_factor: 0.2 + max_epochs: 12 + clip_grad_norm: 100 + ema_decay: 0.999 diff --git a/tests/core/models/test_dimenetpp.py b/tests/core/models/test_dimenetpp.py index 76a546037b..d1daec728b 100644 --- a/tests/core/models/test_dimenetpp.py +++ b/tests/core/models/test_dimenetpp.py @@ -47,9 +47,6 @@ def load_model(request) -> None: setup_imports() model = registry.get_model_class("dimenetplusplus")( - None, - 32, - 1, cutoff=6.0, regress_forces=True, use_pbc=False, diff --git a/tests/core/models/test_equiformer_v2.py b/tests/core/models/test_equiformer_v2.py index 0034232cd2..3194dd2df7 100644 --- a/tests/core/models/test_equiformer_v2.py +++ b/tests/core/models/test_equiformer_v2.py @@ -63,9 +63,6 @@ def _load_model(): checkpoint = torch.load(io.BytesIO(r.content), map_location=torch.device("cpu")) model = registry.get_model_class("equiformer_v2")( - None, - -1, - 1, use_pbc=True, regress_forces=True, otf_graph=True, diff --git a/tests/core/models/test_gemnet.py b/tests/core/models/test_gemnet.py index 3fa0c6babc..b4c5414cc4 100644 --- a/tests/core/models/test_gemnet.py +++ b/tests/core/models/test_gemnet.py @@ -47,9 +47,6 @@ def load_model(request) -> None: setup_imports() model = registry.get_model_class("gemnet_t")( - None, - -1, - 1, cutoff=6.0, num_spherical=7, num_radial=128, diff --git a/tests/core/models/test_gemnet_oc.py b/tests/core/models/test_gemnet_oc.py index d84669750f..7729c14483 100644 --- a/tests/core/models/test_gemnet_oc.py +++ b/tests/core/models/test_gemnet_oc.py @@ -58,9 +58,6 @@ def load_model(request) -> None: checkpoint = torch.load(io.BytesIO(r.content), map_location=torch.device("cpu")) model = registry.get_model_class("gemnet_oc")( - None, - -1, - 1, num_spherical=7, num_radial=128, num_blocks=4, diff --git a/tests/core/models/test_gemnet_oc_scaling_mismatch.py b/tests/core/models/test_gemnet_oc_scaling_mismatch.py index 8f1c36d277..29ea40c0fa 100644 --- a/tests/core/models/test_gemnet_oc_scaling_mismatch.py +++ b/tests/core/models/test_gemnet_oc_scaling_mismatch.py @@ -35,9 +35,6 @@ def test_no_scaling_mismatch(self) -> None: checkpoint = torch.load(io.BytesIO(r.content), map_location=torch.device("cpu")) model = registry.get_model_class("gemnet_oc")( - None, - -1, - 1, num_spherical=7, num_radial=128, num_blocks=4, @@ -111,9 +108,6 @@ def test_scaling_mismatch(self) -> None: checkpoint = torch.load(io.BytesIO(r.content), map_location=torch.device("cpu")) model = registry.get_model_class("gemnet_oc")( - None, - -1, - 1, num_spherical=7, num_radial=128, num_blocks=4, @@ -189,9 
+183,6 @@ def test_no_file_exists(self) -> None: with pytest.raises(ValueError): registry.get_model_class("gemnet_oc")( - None, - -1, - 1, num_spherical=7, num_radial=128, num_blocks=4, @@ -245,9 +236,6 @@ def test_not_fitted(self) -> None: setup_imports() model = registry.get_model_class("gemnet_oc")( - None, - -1, - 1, num_spherical=7, num_radial=128, num_blocks=4, diff --git a/tests/core/models/test_schnet.py b/tests/core/models/test_schnet.py index aa704604f7..3dd21be4e1 100644 --- a/tests/core/models/test_schnet.py +++ b/tests/core/models/test_schnet.py @@ -46,7 +46,7 @@ def load_model(request) -> None: setup_imports() model = registry.get_model_class("schnet")( - None, 32, 1, cutoff=6.0, regress_forces=True, use_pbc=True + cutoff=6.0, regress_forces=True, use_pbc=True ) request.cls.model = model From 029d4d3c3c246d29d46407128807d24a1b56cfa6 Mon Sep 17 00:00:00 2001 From: Luis Barroso-Luque Date: Sun, 4 Aug 2024 21:14:35 -0600 Subject: [PATCH 6/8] (OTF) Normalization and element references (#715) * denorm targets in _forward only * linear reference class * atomref in normalizer * raise input error * clean up normalizer interface * add element refs * add element refs correctly * ruff * fix save_checkpoint * reference and dereference * 2xnorm linref trainer add * clean-up * otf linear reference fit * fix tensor device * otf element references and normalizers * use only present elements when fitting * lint * _forward norm and derefd values * fix list of paths in src * total mean and std * fitted flag to avoid refitting normalizers/references on rerun * allow passing lstsq driver * element ref unit tests * remove superfluous type * lint fix * allow setting batch_size explicitly * test applying element refs * normalizer tests * increase distributed timeout * save normalizers and linear refs in otf_fit * remove debug code * fix removing refs * swap otf_fit for fit, and save all normalizers in one file * log loading and saving normalizers * fit references and normalizer scripts * lint fixes * allow absent optim key in config * lin-ref description * read files based on extension * pass seed * rename dataset fixture * check if file is none * pass generator correctly * separate method for norms and refs * add normalizer code back * fix Generator construction * import order * log warnings if multiple inputs are passed * raise Error if duplicate references or norms are set * use len batch * assert element reference targets are scalar * fix name and rename method * load and save norms and refs using same logic * fix creating normalizer * remove print statements * adding new notebook for using fairchem models with NEBs without CatTSunami enumeration (#764) * adding new notebook for using fairchem models with NEBs * adding md tutorials * blocking code cells that arent needed or take too long * warn instead of error when duplicate norm/ref target names * allow timeout to be read from config * test seed noseed ref fits * lotsa refactoring * lotsa fixing * more fixing... * num_workers zero to prevent mp issues * add otf norms smoke test and fixes * allow overriding normalization fit values * update tests * fix normalizer loading * use rmsd instead of only stdev * fix tests * correct rmsd calc and fix loading * clean up norm loading and log values * logg linear reference metrics * load element references state dict * fix loading and tests * fix imports in scripts * fix test? 
* fix test * use numpy as default to fit references * minor fixes * rm torch_tempdir fixture --------- Co-authored-by: Brook Wander <73855115+brookwander@users.noreply.github.com> Co-authored-by: Muhammed Shuaibi <45150244+mshuaibii@users.noreply.github.com> --- src/fairchem/core/common/distutils.py | 16 +- src/fairchem/core/datasets/ase_datasets.py | 17 +- .../core/modules/normalization/__init__.py | 0 .../core/modules/normalization/_load_utils.py | 113 +++++++ .../normalization/element_references.py | 290 ++++++++++++++++++ .../core/modules/normalization/normalizer.py | 290 ++++++++++++++++++ src/fairchem/core/modules/normalizer.py | 56 ---- src/fairchem/core/modules/transforms.py | 8 +- src/fairchem/core/scripts/fit_normalizers.py | 119 +++++++ src/fairchem/core/scripts/fit_references.py | 91 ++++++ src/fairchem/core/trainers/base_trainer.py | 142 ++++++--- src/fairchem/core/trainers/ocp_trainer.py | 55 ++-- tests/core/e2e/test_s2ef.py | 139 ++++++--- .../models/test_configs/test_equiformerv2.yml | 93 +++--- tests/core/models/test_configs/test_escn.yml | 50 +-- .../models/test_configs/test_gemnet_oc.yml | 53 ++-- tests/core/modules/conftest.py | 48 +++ tests/core/modules/test_element_references.py | 159 ++++++++++ tests/core/modules/test_normalizer.py | 98 ++++++ 19 files changed, 1572 insertions(+), 265 deletions(-) create mode 100644 src/fairchem/core/modules/normalization/__init__.py create mode 100644 src/fairchem/core/modules/normalization/_load_utils.py create mode 100644 src/fairchem/core/modules/normalization/element_references.py create mode 100644 src/fairchem/core/modules/normalization/normalizer.py delete mode 100644 src/fairchem/core/modules/normalizer.py create mode 100644 src/fairchem/core/scripts/fit_normalizers.py create mode 100644 src/fairchem/core/scripts/fit_references.py create mode 100644 tests/core/modules/conftest.py create mode 100644 tests/core/modules/test_element_references.py create mode 100644 tests/core/modules/test_normalizer.py diff --git a/src/fairchem/core/common/distutils.py b/src/fairchem/core/common/distutils.py index 8989840641..f6bf88ccaf 100644 --- a/src/fairchem/core/common/distutils.py +++ b/src/fairchem/core/common/distutils.py @@ -10,7 +10,8 @@ import logging import os import subprocess -from typing import TypeVar +from datetime import timedelta +from typing import Any, TypeVar import torch import torch.distributed as dist @@ -27,6 +28,7 @@ def os_environ_get_or_throw(x: str) -> str: def setup(config) -> None: + timeout = timedelta(minutes=config.get("timeout", 30)) if config["submit"]: node_list = os.environ.get("SLURM_STEP_NODELIST") if node_list is None: @@ -72,6 +74,7 @@ def setup(config) -> None: init_method=config["init_method"], world_size=config["world_size"], rank=config["rank"], + timeout=timeout, ) except subprocess.CalledProcessError as e: # scontrol failed raise e @@ -95,10 +98,11 @@ def setup(config) -> None: rank=world_rank, world_size=world_size, init_method="env://", + timeout=timeout, ) else: config["local_rank"] = int(os.environ.get("LOCAL_RANK", config["local_rank"])) - dist.init_process_group(backend=config.get("backend", "nccl")) + dist.init_process_group(backend=config.get("backend", "nccl"), timeout=timeout) def cleanup() -> None: @@ -135,6 +139,14 @@ def broadcast( dist.broadcast(tensor, src, group, async_op) +def broadcast_object_list( + object_list: list[Any], src: int, group=dist.group.WORLD, device: str | None = None +) -> None: + if get_world_size() == 1: + return + dist.broadcast_object_list(object_list, src, 
group, device) + + def all_reduce( data, group=dist.group.WORLD, average: bool = False, device=None ) -> torch.Tensor: diff --git a/src/fairchem/core/datasets/ase_datasets.py b/src/fairchem/core/datasets/ase_datasets.py index 15c22322db..08618c9f25 100644 --- a/src/fairchem/core/datasets/ase_datasets.py +++ b/src/fairchem/core/datasets/ase_datasets.py @@ -13,7 +13,7 @@ import os import warnings from abc import ABC, abstractmethod -from functools import cache, reduce +from functools import cache from glob import glob from pathlib import Path from typing import Any, Callable @@ -467,13 +467,14 @@ class AseDBDataset(AseAtomsDataset): def _load_dataset_get_ids(self, config: dict) -> list[int]: if isinstance(config["src"], list): - if os.path.isdir(config["src"][0]): - filepaths = reduce( - lambda x, y: x + y, - (glob(f"{path}/*") for path in config["src"]), - ) - else: - filepaths = config["src"] + filepaths = [] + for path in config["src"]: + if os.path.isdir(path): + filepaths.extend(glob(f"{path}/*")) + elif os.path.isfile(path): + filepaths.append(path) + else: + raise RuntimeError(f"Error reading dataset in {path}!") elif os.path.isfile(config["src"]): filepaths = [config["src"]] elif os.path.isdir(config["src"]): diff --git a/src/fairchem/core/modules/normalization/__init__.py b/src/fairchem/core/modules/normalization/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/fairchem/core/modules/normalization/_load_utils.py b/src/fairchem/core/modules/normalization/_load_utils.py new file mode 100644 index 0000000000..0825886db9 --- /dev/null +++ b/src/fairchem/core/modules/normalization/_load_utils.py @@ -0,0 +1,113 @@ +""" +Copyright (c) Meta, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any, Callable + +import torch + +from fairchem.core.common.utils import save_checkpoint + +if TYPE_CHECKING: + from pathlib import Path + + from torch.nn import Module + from torch.utils.data import Dataset + + +def _load_check_duplicates(config: dict, name: str) -> dict[str, torch.nn.Module]: + """Attempt to load a single file with normalizers/element references and check config for duplicate targets. + + Args: + config: configuration dictionary + name: Name of module to use for logging + + Returns: + dictionary of normalizer or element reference modules + """ + modules = {} + if "file" in config: + modules = torch.load(config["file"]) + logging.info(f"Loaded {name} for the following targets: {list(modules.keys())}") + # make sure that element-refs are not specified both as fit and file + fit_targets = config["fit"]["targets"] if "fit" in config else [] + duplicates = list( + filter( + lambda x: x in fit_targets, + list(config) + list(modules.keys()), + ) + ) + if len(duplicates) > 0: + logging.warning( + f"{name} values for the following targets {duplicates} have been specified to be fit and also read" + f" from a file. The files read from file will be used instead of fitting." + ) + duplicates = list(filter(lambda x: x in modules, config)) + if len(duplicates) > 0: + logging.warning( + f"Duplicate {name} values for the following targets {duplicates} where specified in the file " + f"{config['file']} and an explicitly set file. The normalization values read from " + f"{config['file']} will be used." 
+ ) + return modules + + +def _load_from_config( + config: dict, + name: str, + fit_fun: Callable[[list[str], Dataset, Any, ...], dict[str, Module]], + create_fun: Callable[[str | Path], Module], + dataset: Dataset, + checkpoint_dir: str | Path | None = None, + **fit_kwargs, +) -> dict[str, torch.nn.Module]: + """Load or fit normalizers or element references from config + + If a fit is done, a fitted key with value true is added to the config to avoid re-fitting + once a checkpoint has been saved. + + Args: + config: configuration dictionary + name: Name of module to use for logging + fit_fun: Function to fit modules + create_fun: Function to create a module from file + checkpoint_dir: directory to save modules. If not given, modules won't be saved. + + Returns: + dictionary of normalizer or element reference modules + + """ + modules = _load_check_duplicates(config, name) + for target in config: + if target == "fit" and not config["fit"].get("fitted", False): + # remove values for output targets that have already been read from files + targets = [ + target for target in config["fit"]["targets"] if target not in modules + ] + fit_kwargs.update( + {k: v for k, v in config["fit"].items() if k != "targets"} + ) + modules.update(fit_fun(targets=targets, dataset=dataset, **fit_kwargs)) + config["fit"]["fitted"] = True + # if a single file for all outputs is not provided, + # then check if a single file is provided for a specific output + elif target != "file": + modules[target] = create_fun(**config[target]) + # save the linear references for possible subsequent use + if checkpoint_dir is not None: + path = save_checkpoint( + modules, + checkpoint_dir, + f"{name}.pt", + ) + logging.info( + f"{name} checkpoint for targets {list(modules.keys())} have been saved to: {path}" + ) + + return modules diff --git a/src/fairchem/core/modules/normalization/element_references.py b/src/fairchem/core/modules/normalization/element_references.py new file mode 100644 index 0000000000..e41dbe588c --- /dev/null +++ b/src/fairchem/core/modules/normalization/element_references.py @@ -0,0 +1,290 @@ +""" +Copyright (c) Meta, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from __future__ import annotations + +import logging +from functools import partial +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import numpy as np +import torch +from torch import nn +from torch.utils.data import DataLoader, Dataset +from tqdm import tqdm + +from fairchem.core.datasets import data_list_collater + +from ._load_utils import _load_from_config + +if TYPE_CHECKING: + from torch_geometric.data import Batch + + +class LinearReferences(nn.Module): + """Represents an elemental linear references model for a target property. + + In an elemental reference associates a value with each chemical element present in the dataset. + Elemental references define a chemical composition model, i.e. a rough approximation of a target + property (energy) using elemental references is done by summing the elemental references multiplied + by the number of times the corresponding element is present. + + Elemental references energies can be taken as: + - the energy of a chemical species in its elemental state + (i.e. lowest energy polymorph of single element crystal structures for solids) + - fitting a linear model to a dataset, where the features are the counts of each element in each data point. 
+ see the function fit_linear references below for details + + Training GNNs to predict the difference between DFT and the predictions of a chemical composition + model represent a useful normalization scheme that can improve model accuracy. See for example the + "Alternative reference scheme" section of the OC22 manuscript: https://arxiv.org/pdf/2206.08917 + """ + + def __init__( + self, + element_references: torch.Tensor | None = None, + max_num_elements: int = 118, + ): + """ + Args: + element_references (Tensor): tensor with linear reference values + max_num_elements (int): max number of elements - 118 is a stretch + metrics (dict): dictionary with accuracy metrics in predicting values for structures used in fitting. + """ + super().__init__() + self.register_buffer( + name="element_references", + tensor=element_references + if element_references is not None + else torch.zeros(max_num_elements + 1), + ) + + def _apply_refs( + self, target: torch.Tensor, batch: Batch, sign: int, reshaped: bool = True + ) -> torch.Tensor: + """Apply references batch-wise""" + indices = batch.atomic_numbers.to( + dtype=torch.int, device=self.element_references.device + ) + elemrefs = self.element_references[indices].to(dtype=target.dtype) + # this option should not exist, all tensors should have compatible shapes in dataset and trainer outputs + if reshaped: + elemrefs = elemrefs.view(batch.natoms.sum(), -1) + + return target.index_add(0, batch.batch, elemrefs, alpha=sign) + + @torch.autocast(device_type="cuda", enabled=False) + def dereference( + self, target: torch.Tensor, batch: Batch, reshaped: bool = True + ) -> torch.Tensor: + """Remove linear references""" + return self._apply_refs(target, batch, -1, reshaped=reshaped) + + @torch.autocast(device_type="cuda", enabled=False) + def forward( + self, target: torch.Tensor, batch: Batch, reshaped: bool = True + ) -> torch.Tensor: + """Add linear references""" + return self._apply_refs(target, batch, 1, reshaped=reshaped) + + +def create_element_references( + file: str | Path | None = None, + state_dict: dict | None = None, +) -> LinearReferences: + """Create an element reference module. + + Args: + type (str): type of reference (only linear implemented) + file (str or Path): path to pt or npz file + state_dict (dict): a state dict of a element reference module + + Returns: + LinearReference + """ + if file is not None and state_dict is not None: + logging.warning( + "Both a file and a state_dict for element references was given." + "The references will be read from the file and the provided state_dict will be ignored." + ) + + # path takes priority if given + if file is not None: + extension = Path(file).suffix + if extension == ".pt": + # try to load a pt file + state_dict = torch.load(file) + elif extension == ".npz": + state_dict = {} + with np.load(file) as values: + # legacy linref files + if "coeff" in values: + state_dict["element_references"] = torch.tensor(values["coeff"]) + else: + state_dict["element_references"] = torch.tensor( + values["element_references"] + ) + else: + raise RuntimeError( + f"Element references file with extension '{extension}' is not supported." 
+ ) + + if "element_references" not in state_dict: + raise RuntimeError("Unable to load linear element references!") + + return LinearReferences(element_references=state_dict["element_references"]) + + +@torch.no_grad() +def fit_linear_references( + targets: list[str], + dataset: Dataset, + batch_size: int, + num_batches: int | None = None, + num_workers: int = 0, + max_num_elements: int = 118, + log_metrics: bool = True, + use_numpy: bool = True, + driver: str | None = None, + shuffle: bool = True, + seed: int = 0, +) -> dict[str, LinearReferences]: + """Fit a set linear references for a list of targets using a given number of batches. + + Args: + targets: list of target names + dataset: data set to fit linear references with + batch_size: size of batch + num_batches: number of batches to use in fit. If not given will use all batches + num_workers: number of workers to use in data loader. + Note setting num_workers > 1 leads to finicky multiprocessing issues when using this function + in distributed mode. The issue has to do with pickling the functions in load_references_from_config + see function below... + max_num_elements: max number of elements in dataset. If not given will use an ambitious value of 118 + log_metrics: if true will compute MAE, RMSE and R2 score of fit and log. + use_numpy: use numpy.linalg.lstsq instead of torch. This tends to give better solutions. + driver: backend used to solve linear system. See torch.linalg.lstsq docs. Ignored if use_numpy=True + shuffle: whether to shuffle when loading the dataset + seed: random seed used to shuffle the sampler if shuffle=True + + Returns: + dict of fitted LinearReferences objects + """ + data_loader = DataLoader( + dataset, + batch_size=batch_size, + shuffle=shuffle, + collate_fn=partial(data_list_collater, otf_graph=True), + num_workers=num_workers, + persistent_workers=num_workers > 0, + generator=torch.Generator().manual_seed(seed), + ) + + num_batches = num_batches if num_batches is not None else len(data_loader) + if num_batches > len(data_loader): + logging.warning( + f"The given num_batches {num_batches} is larger than total batches of size {batch_size} in dataset. " + f"num_batches will be ignored and the whole dataset will be used." + ) + num_batches = len(data_loader) + + max_num_elements += 1 # + 1 since H starts at index 1 + # solving linear system happens on CPU, which allows handling poorly conditioned and + # rank deficient matrices, unlike torch lstsq on GPU + composition_matrix = torch.zeros( + num_batches * batch_size, + max_num_elements, + ) + + target_vectors = { + target: torch.zeros(num_batches * batch_size) for target in targets + } + + logging.info( + f"Fitting linear references using {num_batches * batch_size} samples in {num_batches} " + f"batches of size {batch_size}." 
+ ) + for i, batch in tqdm( + enumerate(data_loader), total=num_batches, desc="Fitting linear references" + ): + if i == 0: + assert all( + len(batch[target].squeeze().shape) == 1 for target in targets + ), "element references can only be used for scalar targets" + elif i == num_batches: + break + + next_batch_size = len(batch) if i == len(data_loader) - 1 else batch_size + for target in targets: + target_vectors[target][ + i * batch_size : i * batch_size + next_batch_size + ] = batch[target].to(torch.float64) + for j, data in enumerate(batch.to_data_list()): + composition_matrix[i * batch_size + j] = torch.bincount( + data.atomic_numbers.int(), + minlength=max_num_elements, + ).to(torch.float64) + + # reduce the composition matrix to only features that are non-zero to improve rank + mask = composition_matrix.sum(axis=0) != 0.0 + reduced_composition_matrix = composition_matrix[:, mask] + elementrefs = {} + + for target in targets: + coeffs = torch.zeros(max_num_elements) + + if use_numpy: + solution = torch.tensor( + np.linalg.lstsq( + reduced_composition_matrix.numpy(), + target_vectors[target].numpy(), + rcond=None, + )[0] + ) + else: + lstsq = torch.linalg.lstsq( + reduced_composition_matrix, target_vectors[target], driver=driver + ) + solution = lstsq.solution + + coeffs[mask] = solution + elementrefs[target] = LinearReferences(coeffs) + + if log_metrics is True: + y = target_vectors[target] + y_pred = torch.matmul(reduced_composition_matrix, solution) + y_mean = target_vectors[target].mean() + N = len(target_vectors[target]) + ss_res = ((y - y_pred) ** 2).sum() + ss_tot = ((y - y_mean) ** 2).sum() + mae = (abs(y - y_pred)).sum() / N + rmse = (((y - y_pred) ** 2).sum() / N).sqrt() + r2 = 1 - (ss_res / ss_tot) + logging.info( + f"Training accuracy metrics for fitted linear element references: mae={mae}, rmse={rmse}, r2 score={r2}" + ) + + return elementrefs + + +def load_references_from_config( + config: dict[str, Any], + dataset: Dataset, + seed: int = 0, + checkpoint_dir: str | Path | None = None, +) -> dict[str, LinearReferences]: + """Create a dictionary with element references from a config.""" + return _load_from_config( + config, + "element_references", + fit_linear_references, + create_element_references, + dataset, + checkpoint_dir, + seed=seed, + ) diff --git a/src/fairchem/core/modules/normalization/normalizer.py b/src/fairchem/core/modules/normalization/normalizer.py new file mode 100644 index 0000000000..f16db7d398 --- /dev/null +++ b/src/fairchem/core/modules/normalization/normalizer.py @@ -0,0 +1,290 @@ +""" +Copyright (c) Meta, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. 
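Taken together, the `Normalizer` class and the `create_normalizer` factory introduced in this file replace the old plain-Python normalizer (removed later in this patch) with an `nn.Module` that stores `mean` and `rmsd` buffers. A minimal usage sketch, assuming the fairchem package is importable; the statistics are the legacy OC20 energy values quoted elsewhere in this series, and any tensor of targets works:

```python
import torch

from fairchem.core.modules.normalization.normalizer import Normalizer, create_normalizer

# explicit statistics: rmsd plays the role the old "stdev" field used to play
norm = Normalizer(mean=-0.7554450631141663, rmsd=2.887317180633545)

energies = torch.randn(8)
normed = norm.norm(energies)   # (energies - mean) / rmsd
restored = norm(normed)        # forward() is denorm(): normed * rmsd + mean
assert torch.allclose(restored, energies)

# or estimate mean/rmsd directly from a tensor of target values
norm_from_data = create_normalizer(tensor=energies)
```

Because both values live in registered buffers, they move with `.to(device)` and round-trip through the module's `state_dict`, which is what the trainer changes later in this series rely on.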
+""" + +from __future__ import annotations + +import logging +import warnings +from collections import defaultdict +from functools import partial +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import numpy as np +import torch +from torch import nn +from torch.utils.data import DataLoader, Dataset +from tqdm import tqdm + +from fairchem.core.datasets import data_list_collater + +from ._load_utils import _load_from_config + +if TYPE_CHECKING: + from collections.abc import Mapping + + from fairchem.core.modules.normalization.element_references import LinearReferences + + +class Normalizer(nn.Module): + """Normalize/denormalize a tensor and optionally add a atom reference offset.""" + + def __init__( + self, + mean: float | torch.Tensor = 0.0, + rmsd: float | torch.Tensor = 1.0, + ): + """tensor is taken as a sample to calculate the mean and rmsd""" + super().__init__() + + if isinstance(mean, float): + mean = torch.tensor(mean) + if isinstance(rmsd, float): + rmsd = torch.tensor(rmsd) + + self.register_buffer(name="mean", tensor=mean) + self.register_buffer(name="rmsd", tensor=rmsd) + + @torch.autocast(device_type="cuda", enabled=False) + def norm(self, tensor: torch.Tensor) -> torch.Tensor: + return (tensor - self.mean) / self.rmsd + + @torch.autocast(device_type="cuda", enabled=False) + def denorm(self, normed_tensor: torch.Tensor) -> torch.Tensor: + return normed_tensor * self.rmsd + self.mean + + def forward(self, normed_tensor: torch.Tensor) -> torch.Tensor: + return self.denorm(normed_tensor) + + def load_state_dict( + self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False + ): + # check if state dict is legacy state dicts + if "std" in state_dict: + state_dict = { + "mean": torch.tensor(state_dict["mean"]), + "rmsd": state_dict["std"], + } + + return super().load_state_dict(state_dict, strict=strict, assign=assign) + + +def create_normalizer( + file: str | Path | None = None, + state_dict: dict | None = None, + tensor: torch.Tensor | None = None, + mean: float | torch.Tensor | None = None, + rmsd: float | torch.Tensor | None = None, + stdev: float | torch.Tensor | None = None, +) -> Normalizer: + """Build a target data normalizers with optional atom ref + + Only one of file, state_dict, tensor, or (mean and rmsd) will be used to create a normalizer. + If more than one set of inputs are given priority will be given following the order in which they are listed above. + + Args: + file (str or Path): path to pt or npz file. + state_dict (dict): a state dict for Normalizer module + tensor (Tensor): a tensor with target values used to compute mean and std + mean (float | Tensor): mean of target data + rmsd (float | Tensor): rmsd of target data, rmsd from mean = stdev, rmsd from 0 = rms + stdev: standard deviation (deprecated, use rmsd instead) + + Returns: + Normalizer + """ + if stdev is not None: + warnings.warn( + "Use of 'stdev' is deprecated, use 'rmsd' instead", DeprecationWarning + ) + if rmsd is not None: + logging.warning( + "Both 'stdev' and 'rmsd' values where given to create a normalizer, rmsd values will be used." + ) + + # old configs called it stdev, using this in the function signature reduces overhead code elsewhere + if stdev is not None and rmsd is None: + rmsd = stdev + + # path takes priority if given + if file is not None: + if state_dict is not None or tensor is not None or mean is not None: + logging.warning( + "A file to a normalizer has been given. 
Normalization values will be read from it, and all other inputs" + " will be ignored." + ) + extension = Path(file).suffix + if extension == ".pt": + # try to load a pt file + state_dict = torch.load(file) + elif extension == ".npz": + # try to load an NPZ file + values = np.load(file) + mean = values.get("mean") + rmsd = values.get("rmsd") or values.get("std") # legacy files + tensor = None # set to None since values read from file are prioritized + else: + raise RuntimeError( + f"Normalizer file with extension '{extension}' is not supported." + ) + + # state dict is second priority + if state_dict is not None: + if tensor is not None or mean is not None: + logging.warning( + "The state_dict provided will be used to set normalization values. All other inputs will be ignored." + ) + normalizer = Normalizer() + normalizer.load_state_dict(state_dict) + return normalizer + + # if not then read target value tensor + if tensor is not None: + if mean is not None: + logging.warning( + "Normalization values will be computed from input tensor, all other inputs will be ignored." + ) + mean = torch.mean(tensor) + rmsd = torch.std(tensor) + elif mean is not None and rmsd is not None: + if not isinstance(mean, torch.Tensor): + mean = torch.tensor(mean) + if not isinstance(rmsd, torch.Tensor): + rmsd = torch.tensor(rmsd) + + # if mean and rmsd are still None than raise an error + if mean is None or rmsd is None: + raise ValueError( + "Incorrect inputs. One of the following sets of inputs must be given: ", + "a file path to a .pt or .npz file, or mean and rmsd values, or a tensor of target values", + ) + + return Normalizer(mean=mean, rmsd=rmsd) + + +@torch.no_grad() +def fit_normalizers( + targets: list[str], + dataset: Dataset, + batch_size: int, + override_values: dict[str, dict[str, float]] | None = None, + rmsd_correction: int | None = None, + element_references: dict | None = None, + num_batches: int | None = None, + num_workers: int = 0, + shuffle: bool = True, + seed: int = 0, +) -> dict[str, Normalizer]: + """Estimate mean and rmsd from data to create normalizers + + Args: + targets: list of target names + dataset: data set to fit linear references with + batch_size: size of batch + override_values: dictionary with target names and values to override. i.e. {"forces": {"mean": 0.0}} will set + the forces mean to zero. + rmsd_correction: correction to use when computing mean in std/rmsd. See docs for torch.std. + If not given, will always use 0 when mean == 0, and 1 otherwise. + element_references: + num_batches: number of batches to use in fit. If not given will use all batches + num_workers: number of workers to use in data loader + Note setting num_workers > 1 leads to finicky multiprocessing issues when using this function + in distributed mode. The issue has to do with pickling the functions in load_normalizers_from_config + see function below... 
+ shuffle: whether to shuffle when loading the dataset + seed: random seed used to shuffle the sampler if shuffle=True + + Returns: + dict of normalizer objects + """ + data_loader = DataLoader( + dataset, + batch_size=batch_size, + shuffle=shuffle, + collate_fn=partial(data_list_collater, otf_graph=True), + num_workers=num_workers, + persistent_workers=num_workers > 0, + generator=torch.Generator().manual_seed(seed), + ) + + num_batches = num_batches if num_batches is not None else len(data_loader) + if num_batches > len(data_loader): + logging.warning( + f"The given num_batches {num_batches} is larger than total batches of size {batch_size} in dataset. " + f"num_batches will be ignored and the whole dataset will be used." + ) + num_batches = len(data_loader) + + element_references = element_references or {} + target_vectors = defaultdict(list) + + logging.info( + f"Estimating mean and rmsd for normalization using {num_batches * batch_size} samples in {num_batches} batches " + f"of size {batch_size}." + ) + for i, batch in tqdm( + enumerate(data_loader), total=num_batches, desc="Estimating mean and rmsd" + ): + if i == num_batches: + break + + for target in targets: + target_vector = batch[target] + if target in element_references: + target_vector = element_references[target].dereference( + target_vector, batch, reshaped=False + ) + target_vectors[target].append(target_vector) + + normalizers = {} + for target in targets: + target_vector = torch.cat(target_vectors[target], dim=0) + values = {"mean": target_vector.mean()} + if target in override_values: + for name, val in override_values[target].items(): + values[name] = torch.tensor(val) + # calculate root mean square deviation + if "rmsd" not in values: + if rmsd_correction is None: + rmsd_correction = 0 if values["mean"] == 0.0 else 1 + values["rmsd"] = ( + ((target_vector - values["mean"]) ** 2).sum() + / max(len(target_vector) - rmsd_correction, 1) + ).sqrt() + normalizers[target] = create_normalizer(**values) + + return normalizers + + +def load_normalizers_from_config( + config: dict[str, Any], + dataset: Dataset, + seed: int = 0, + checkpoint_dir: str | Path | None = None, + element_references: dict[str, LinearReferences] | None = None, +) -> dict[str, Normalizer]: + """Create a dictionary with element references from a config.""" + # edit the config slightly to extract override args + if "fit" in config: + override_values = { + target: vals + for target, vals in config["fit"]["targets"].items() + if isinstance(vals, dict) + } + config["fit"]["override_values"] = override_values + config["fit"]["targets"] = list(config["fit"]["targets"].keys()) + + return _load_from_config( + config, + "normalizers", + fit_normalizers, + create_normalizer, + dataset, + checkpoint_dir, + seed=seed, + element_references=element_references, + ) diff --git a/src/fairchem/core/modules/normalizer.py b/src/fairchem/core/modules/normalizer.py deleted file mode 100644 index 75f34e83f4..0000000000 --- a/src/fairchem/core/modules/normalizer.py +++ /dev/null @@ -1,56 +0,0 @@ -""" -Copyright (c) Meta, Inc. and its affiliates. - -This source code is licensed under the MIT license found in the -LICENSE file in the root directory of this source tree. 
-""" - -from __future__ import annotations - -import torch - - -class Normalizer: - """Normalize a Tensor and restore it later.""" - - def __init__( - self, - tensor: torch.Tensor | None = None, - mean=None, - std=None, - device=None, - ) -> None: - """tensor is taken as a sample to calculate the mean and std""" - if tensor is None and mean is None: - return - - if device is None: - device = "cpu" - - self.mean: torch.Tensor - self.std: torch.Tensor - if tensor is not None: - self.mean = torch.mean(tensor, dim=0).to(device) - self.std = torch.std(tensor, dim=0).to(device) - return - - if mean is not None and std is not None: - self.mean = torch.tensor(mean).to(device) - self.std = torch.tensor(std).to(device) - - def to(self, device) -> None: - self.mean = self.mean.to(device) - self.std = self.std.to(device) - - def norm(self, tensor: torch.Tensor) -> torch.Tensor: - return (tensor - self.mean) / self.std - - def denorm(self, normed_tensor: torch.Tensor) -> torch.Tensor: - return normed_tensor * self.std + self.mean - - def state_dict(self): - return {"mean": self.mean, "std": self.std} - - def load_state_dict(self, state_dict) -> None: - self.mean = state_dict["mean"].to(self.mean.device) - self.std = state_dict["std"].to(self.mean.device) diff --git a/src/fairchem/core/modules/transforms.py b/src/fairchem/core/modules/transforms.py index 3a86be468c..52675fd28f 100644 --- a/src/fairchem/core/modules/transforms.py +++ b/src/fairchem/core/modules/transforms.py @@ -19,10 +19,12 @@ def __call__(self, data_object): return data_object for transform_fn in self.config: - # TODO: Normalization information used in the trainers. Ignore here - # for now. - if transform_fn == "normalizer": + # TODO: Normalization information used in the trainers. Ignore here for now + # TODO: if we dont use them here, these should not be defined as "transforms" in the config + # TODO: add them as another entry under dataset, maybe "standardize"? + if transform_fn in ("normalizer", "element_references"): continue + data_object = eval(transform_fn)(data_object, self.config[transform_fn]) return data_object diff --git a/src/fairchem/core/scripts/fit_normalizers.py b/src/fairchem/core/scripts/fit_normalizers.py new file mode 100644 index 0000000000..0cfa2f2db5 --- /dev/null +++ b/src/fairchem/core/scripts/fit_normalizers.py @@ -0,0 +1,119 @@ +""" +Copyright (c) Meta, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from __future__ import annotations + +import argparse +import logging +from pathlib import Path + +from fairchem.core.common.registry import registry +from fairchem.core.common.utils import load_config, save_checkpoint +from fairchem.core.modules.normalization.element_references import ( + create_element_references, +) +from fairchem.core.modules.normalization.normalizer import fit_normalizers + + +def fit_norms( + config: dict, + output_path: str | Path, + linref_file: str | Path | None = None, + linref_target: str = "energy", +) -> None: + """Fit dataset mean and std using the standard config + + Args: + config: config + output_path: output path + linref_file: path to fitted linear references. IF these are used in training they must be used to compute mean/std + linref_target: target using linear references, basically always energy. 
+ """ + output_path = Path(output_path).resolve() + elementrefs = ( + {linref_target: create_element_references(linref_file)} + if linref_file is not None + else {} + ) + + try: + # load the training dataset + train_dataset = registry.get_dataset_class( + config["dataset"]["train"].get("format", "lmdb") + )(config["dataset"]["train"]) + except KeyError as err: + raise ValueError("Train dataset is not specified in config!") from err + + try: + norm_config = config["dataset"]["train"]["transforms"]["normalizer"]["fit"] + except KeyError as err: + raise ValueError( + "The provided config does not specify a 'fit' block for 'normalizer'!" + ) from err + + targets = list(norm_config["targets"].keys()) + override_values = { + target: vals + for target, vals in norm_config["targets"].items() + if isinstance(vals, dict) + } + + normalizers = fit_normalizers( + targets=targets, + override_values=override_values, + element_references=elementrefs, + dataset=train_dataset, + batch_size=norm_config.get("batch_size", 32), + num_batches=norm_config.get("num_batches"), + num_workers=config.get("optim", {}).get("num_workers", 16), + ) + path = save_checkpoint( + normalizers, + output_path, + "normalizers.pt", + ) + logging.info(f"normalizers have been saved to {path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--config", + required=True, + type=Path, + help="Path to configuration yaml file", + ) + parser.add_argument( + "--out-path", + default=".", + type=str, + help="Output path to save normalizers", + ) + parser.add_argument( + "--linref-path", + type=str, + help="Path to linear references used.", + ) + parser.add_argument( + "--linref-target", + default="energy", + type=str, + help="target for which linear references are used.", + ) + args = parser.parse_args() + config, dup_warning, dup_error = load_config(args.config) + + if len(dup_warning) > 0: + logging.warning( + f"The following keys in the given config have duplicates: {dup_warning}." + ) + if len(dup_error) > 0: + raise RuntimeError( + f"The following include entries in the config have duplicates: {dup_error}" + ) + + fit_norms(config, args.out_path, args.linref_path) diff --git a/src/fairchem/core/scripts/fit_references.py b/src/fairchem/core/scripts/fit_references.py new file mode 100644 index 0000000000..f7f0c84dd7 --- /dev/null +++ b/src/fairchem/core/scripts/fit_references.py @@ -0,0 +1,91 @@ +""" +Copyright (c) Meta, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. 
+""" + +from __future__ import annotations + +import argparse +import logging +from pathlib import Path + +from fairchem.core.common.registry import registry +from fairchem.core.common.utils import load_config, save_checkpoint +from fairchem.core.modules.normalization.element_references import fit_linear_references + + +def fit_linref(config: dict, output_path: str | Path) -> None: + """Fit linear references using the standard config + + Args: + config: config + output_path: output path + """ + # load the training dataset + output_path = Path(output_path).resolve() + + try: + # load the training dataset + train_dataset = registry.get_dataset_class( + config["dataset"]["train"].get("format", "lmdb") + )(config["dataset"]["train"]) + except KeyError as err: + raise ValueError("Train dataset is not specified in config!") from err + + try: + elementref_config = config["dataset"]["train"]["transforms"][ + "element_references" + ]["fit"] + except KeyError as err: + raise ValueError( + "The provided config does not specify a 'fit' block for 'element_refereces'!" + ) from err + + element_refs = fit_linear_references( + targets=elementref_config["targets"], + dataset=train_dataset, + batch_size=elementref_config.get("batch_size", 32), + num_batches=elementref_config.get("num_batches"), + num_workers=config.get("optim", {}).get("num_workers", 16), + max_num_elements=elementref_config.get("max_num_elements", 118), + driver=elementref_config.get("driver", None), + ) + + for target, references in element_refs.items(): + path = save_checkpoint( + references.state_dict(), + output_path, + f"{target}_linref.pt", + ) + logging.info(f"{target} linear references have been saved to: {path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--config", + required=True, + type=Path, + help="Path to configuration yaml file", + ) + parser.add_argument( + "--out-path", + default=".", + type=str, + help="Output path to save linear references", + ) + args = parser.parse_args() + config, dup_warning, dup_error = load_config(args.config) + + if len(dup_warning) > 0: + logging.warning( + f"The following keys in the given config have duplicates: {dup_warning}." 
+ ) + if len(dup_error) > 0: + raise RuntimeError( + f"The following include entries in the config have duplicates: {dup_error}" + ) + + fit_linref(config, args.out_path) diff --git a/src/fairchem/core/trainers/base_trainer.py b/src/fairchem/core/trainers/base_trainer.py index c21409863e..40c7e65de6 100644 --- a/src/fairchem/core/trainers/base_trainer.py +++ b/src/fairchem/core/trainers/base_trainer.py @@ -43,7 +43,11 @@ from fairchem.core.modules.evaluator import Evaluator from fairchem.core.modules.exponential_moving_average import ExponentialMovingAverage from fairchem.core.modules.loss import DDPLoss -from fairchem.core.modules.normalizer import Normalizer +from fairchem.core.modules.normalization.element_references import ( + LinearReferences, + load_references_from_config, +) +from fairchem.core.modules.normalization.normalizer import load_normalizers_from_config from fairchem.core.modules.scaling.compat import load_scales_compat from fairchem.core.modules.scaling.util import ensure_fitted from fairchem.core.modules.scheduler import LRScheduler @@ -185,6 +189,11 @@ def __init__( if distutils.is_master(): logging.info(yaml.dump(self.config, default_flow_style=False)) + self.elementrefs = {} + self.normalizers = {} + self.train_dataset = None + self.val_dataset = None + self.test_dataset = None self.load() @abstractmethod @@ -208,6 +217,7 @@ def load(self) -> None: self.load_seed_from_config() self.load_logger() self.load_datasets() + self.load_references_and_normalizers() self.load_task() self.load_model() self.load_loss() @@ -395,20 +405,68 @@ def convert_settings_to_split_settings(config, split_name): self.relax_sampler, ) - def load_task(self): - # Normalizer for the dataset. - + def load_references_and_normalizers(self): + """Load or create element references and normalizers from config""" # Is it troublesome that we assume any normalizer info is in train? What if there is no # training dataset? What happens if we just specify a test - normalizer = self.config["dataset"].get("transforms", {}).get("normalizer", {}) - self.normalizers = {} - if normalizer: - for target in normalizer: - self.normalizers[target] = Normalizer( - mean=normalizer[target].get("mean", 0), - std=normalizer[target].get("stdev", 1), + + elementref_config = ( + self.config["dataset"].get("transforms", {}).get("element_references") + ) + norms_config = self.config["dataset"].get("transforms", {}).get("normalizer") + elementrefs, normalizers = {}, {} + if distutils.is_master(): + if elementref_config is not None: + # put them in a list to allow broadcasting python objects + elementrefs = load_references_from_config( + elementref_config, + dataset=self.train_dataset, + seed=self.config["cmd"]["seed"], + checkpoint_dir=self.config["cmd"]["checkpoint_dir"] + if not self.is_debug + else None, + ) + + if norms_config is not None: + normalizers = load_normalizers_from_config( + norms_config, + dataset=self.train_dataset, + seed=self.config["cmd"]["seed"], + checkpoint_dir=self.config["cmd"]["checkpoint_dir"] + if not self.is_debug + else None, + element_references=elementrefs, ) + # log out the values that will be used. + for output, normalizer in normalizers.items(): + logging.info( + f"Normalization values for output {output}: mean={normalizer.mean.item()}, rmsd={normalizer.rmsd.item()}." 
+ ) + + # put them in a list to broadcast them + elementrefs, normalizers = [elementrefs], [normalizers] + distutils.broadcast_object_list( + object_list=elementrefs, src=0, device=self.device + ) + distutils.broadcast_object_list( + object_list=normalizers, src=0, device=self.device + ) + # make sure element refs and normalizers are on this device + self.elementrefs.update( + { + output: elementref.to(self.device) + for output, elementref in elementrefs[0].items() + } + ) + self.normalizers.update( + { + output: normalizer.to(self.device) + for output, normalizer in normalizers[0].items() + } + ) + + def load_task(self): self.output_targets = {} for target_name in self.config["outputs"]: self.output_targets[target_name] = self.config["outputs"][target_name] @@ -425,15 +483,15 @@ def load_task(self): ][target_name].get("level", "system") if "train_on_free_atoms" not in self.output_targets[subtarget]: self.output_targets[subtarget]["train_on_free_atoms"] = ( - self.config["outputs"][target_name].get( - "train_on_free_atoms", True - ) + self.config[ + "outputs" + ][target_name].get("train_on_free_atoms", True) ) if "eval_on_free_atoms" not in self.output_targets[subtarget]: self.output_targets[subtarget]["eval_on_free_atoms"] = ( - self.config["outputs"][target_name].get( - "eval_on_free_atoms", True - ) + self.config[ + "outputs" + ][target_name].get("eval_on_free_atoms", True) ) # TODO: Assert that all targets, loss fn, metrics defined are consistent @@ -550,9 +608,20 @@ def load_checkpoint( target_key = key if target_key in self.normalizers: - self.normalizers[target_key].load_state_dict( + mkeys = self.normalizers[target_key].load_state_dict( checkpoint["normalizers"][key] ) + assert len(mkeys.missing_keys) == 0 + assert len(mkeys.unexpected_keys) == 0 + + for key, state_dict in checkpoint.get("elementrefs", {}).items(): + elementrefs = LinearReferences( + max_num_elements=len(state_dict["element_references"]) - 1 + ) + mkeys = elementrefs.load_state_dict(state_dict) + self.elementrefs[key] = elementrefs + assert len(mkeys.missing_keys) == 0 + assert len(mkeys.unexpected_keys) == 0 if self.scaler and checkpoint["amp"]: self.scaler.load_state_dict(checkpoint["amp"]) @@ -649,32 +718,40 @@ def save( training_state: bool = True, ) -> str | None: if not self.is_debug and distutils.is_master(): + state = { + "state_dict": self.model.state_dict(), + "normalizers": { + key: value.state_dict() for key, value in self.normalizers.items() + }, + "elementrefs": { + key: value.state_dict() for key, value in self.elementrefs.items() + }, + "config": self.config, + "val_metrics": metrics, + "amp": self.scaler.state_dict() if self.scaler else None, + } if training_state: - return save_checkpoint( + state.update( { "epoch": self.epoch, "step": self.step, - "state_dict": self.model.state_dict(), "optimizer": self.optimizer.state_dict(), "scheduler": ( self.scheduler.scheduler.state_dict() if self.scheduler.scheduler_type != "Null" else None ), - "normalizers": { - key: value.state_dict() - for key, value in self.normalizers.items() - }, "config": self.config, - "val_metrics": metrics, "ema": self.ema.state_dict() if self.ema else None, - "amp": self.scaler.state_dict() if self.scaler else None, "best_val_metric": self.best_val_metric, "primary_metric": self.evaluation_metrics.get( "primary_metric", self.evaluator.task_primary_metric[self.name], ), }, + ) + ckpt_path = save_checkpoint( + state, checkpoint_dir=self.config["cmd"]["checkpoint_dir"], checkpoint_file=checkpoint_file, ) @@ -683,22 +760,13 @@ def 
save( self.ema.store() self.ema.copy_to() ckpt_path = save_checkpoint( - { - "state_dict": self.model.state_dict(), - "normalizers": { - key: value.state_dict() - for key, value in self.normalizers.items() - }, - "config": self.config, - "val_metrics": metrics, - "amp": self.scaler.state_dict() if self.scaler else None, - }, + state, checkpoint_dir=self.config["cmd"]["checkpoint_dir"], checkpoint_file=checkpoint_file, ) if self.ema: self.ema.restore() - return ckpt_path + return ckpt_path return None def update_best( diff --git a/src/fairchem/core/trainers/ocp_trainer.py b/src/fairchem/core/trainers/ocp_trainer.py index 72c005893d..26269c6da4 100644 --- a/src/fairchem/core/trainers/ocp_trainer.py +++ b/src/fairchem/core/trainers/ocp_trainer.py @@ -11,6 +11,7 @@ import os from collections import defaultdict from itertools import chain +from typing import TYPE_CHECKING import numpy as np import torch @@ -25,6 +26,9 @@ from fairchem.core.modules.scaling.util import ensure_fitted from fairchem.core.trainers.base_trainer import BaseTrainer +if TYPE_CHECKING: + from torch_geometric.data import Batch + @registry.register_trainer("ocp") @registry.register_trainer("energy") @@ -148,7 +152,6 @@ def train(self, disable_eval_tqdm: bool = False) -> None: # Get a batch. batch = next(train_loader_iter) - # Forward, loss, backward. with torch.cuda.amp.autocast(enabled=self.scaler is not None): out = self._forward(batch) @@ -227,10 +230,21 @@ def train(self, disable_eval_tqdm: bool = False) -> None: if checkpoint_every == -1: self.save(checkpoint_file="checkpoint.pt", training_state=True) + def _denorm_preds(self, target_key: str, prediction: torch.Tensor, batch: Batch): + """Convert model output from a batch into raw prediction by denormalizing and adding references""" + # denorm the outputs + if target_key in self.normalizers: + prediction = self.normalizers[target_key](prediction) + + # add element references + if target_key in self.elementrefs: + prediction = self.elementrefs[target_key](prediction, batch) + + return prediction + def _forward(self, batch): out = self.model(batch.to(self.device)) - ### TODO: Move into BaseModel in OCP 2.0 outputs = {} batch_size = batch.natoms.numel() num_atoms_in_batch = batch.natoms.sum() @@ -254,10 +268,7 @@ def _forward(self, batch): for subtarget_key in self.output_targets[target_key]["decomposition"]: irreps = self.output_targets[subtarget_key]["irrep_dim"] - _pred = out[subtarget_key] - - if self.normalizers.get(subtarget_key, False): - _pred = self.normalizers[subtarget_key].denorm(_pred) + _pred = self._denorm_preds(subtarget_key, out[subtarget_key], batch) ## Fill in the corresponding irreps prediction ## Reshape irrep prediction to (batch_size, irrep_dim) @@ -278,7 +289,6 @@ def _forward(self, batch): pred = pred.view(num_atoms_in_batch, -1) else: pred = pred.view(batch_size, -1) - outputs[target_key] = pred return outputs @@ -307,8 +317,6 @@ def _compute_loss(self, out, batch): natoms = natoms[mask] num_atoms_in_batch = natoms.numel() - if self.normalizers.get(target_name, False): - target = self.normalizers[target_name].norm(target) ### reshape accordingly: num_atoms_in_batch, -1 or num_systems_in_batch, -1 if self.output_targets[target_name]["level"] == "atom": @@ -316,6 +324,14 @@ def _compute_loss(self, out, batch): else: target = target.view(batch_size, -1) + # to keep the loss coefficient weights balanced we remove linear references + # subtract element references from target data + if target_name in self.elementrefs: + target = 
self.elementrefs[target_name].dereference(target, batch) + # normalize the targets data + if target_name in self.normalizers: + target = self.normalizers[target_name].norm(target) + mult = loss_info["coefficient"] loss.append( mult @@ -373,11 +389,8 @@ def _compute_metrics(self, out, batch, evaluator, metrics=None): else: target = target.view(batch_size, -1) + out[target_name] = self._denorm_preds(target_name, out[target_name], batch) targets[target_name] = target - if self.normalizers.get(target_name, False): - out[target_name] = self.normalizers[target_name].denorm( - out[target_name] - ) targets["natoms"] = natoms out["natoms"] = natoms @@ -385,7 +398,7 @@ def _compute_metrics(self, out, batch, evaluator, metrics=None): return evaluator.eval(out, targets, prev_metrics=metrics) # Takes in a new data source and generates predictions on it. - @torch.no_grad() + @torch.no_grad def predict( self, data_loader, @@ -419,7 +432,7 @@ def predict( predictions = defaultdict(list) - for _i, batch in tqdm( + for _, batch in tqdm( enumerate(data_loader), total=len(data_loader), position=rank, @@ -430,9 +443,7 @@ def predict( out = self._forward(batch) for target_key in self.config["outputs"]: - pred = out[target_key] - if self.normalizers.get(target_key, False): - pred = self.normalizers[target_key].denorm(pred) + pred = self._denorm_preds(target_key, out[target_key], batch) if per_image: ### Save outputs in desired precision, default float16 @@ -449,7 +460,8 @@ def predict( else: dtype = torch.float16 - pred = pred.cpu().detach().to(dtype) + pred = pred.detach().cpu().to(dtype) + ### Split predictions into per-image predictions if self.config["outputs"][target_key]["level"] == "atom": batch_natoms = batch.natoms @@ -510,6 +522,7 @@ def predict( return predictions + @torch.no_grad def run_relaxations(self, split="val"): ensure_fitted(self._unwrapped_model) @@ -642,9 +655,7 @@ def run_relaxations(self, split="val"): ) gather_results["chunk_idx"] = np.cumsum( [gather_results["chunk_idx"][i] for i in idx] - )[ - :-1 - ] # np.split does not need last idx, assumes n-1:end + )[:-1] # np.split does not need last idx, assumes n-1:end full_path = os.path.join( self.config["cmd"]["results_dir"], "relaxed_positions.npz" diff --git a/tests/core/e2e/test_s2ef.py b/tests/core/e2e/test_s2ef.py index 1584becd45..54055d0c3b 100644 --- a/tests/core/e2e/test_s2ef.py +++ b/tests/core/e2e/test_s2ef.py @@ -7,19 +7,20 @@ from pathlib import Path import numpy as np +import numpy.testing as npt import pytest import yaml +from tensorboard.backend.event_processing.event_accumulator import EventAccumulator + +from fairchem.core._cli import Runner +from fairchem.core.common.flags import flags from fairchem.core.common.test_utils import ( PGConfig, init_env_rank_and_launch_test, spawn_multi_process, ) -from fairchem.core.scripts.make_lmdb_sizes import get_lmdb_sizes_parser, make_lmdb_sizes -from tensorboard.backend.event_processing.event_accumulator import EventAccumulator - -from fairchem.core._cli import Runner -from fairchem.core.common.flags import flags from fairchem.core.common.utils import build_config, setup_logging +from fairchem.core.scripts.make_lmdb_sizes import get_lmdb_sizes_parser, make_lmdb_sizes setup_logging() @@ -66,21 +67,56 @@ def tutorial_val_src(tutorial_dataset_path): return tutorial_dataset_path / "s2ef/val_20" -def oc20_lmdb_train_and_val_from_paths(train_src, val_src, test_src=None): +def oc20_lmdb_train_and_val_from_paths( + train_src, val_src, test_src=None, otf_norms=False +): datasets = {} if 
train_src is not None: datasets["train"] = { "src": train_src, - "normalize_labels": True, - "target_mean": -0.7554450631141663, - "target_std": 2.887317180633545, - "grad_target_mean": 0.0, - "grad_target_std": 2.887317180633545, + "format": "lmdb", + "key_mapping": {"y": "energy", "force": "forces"}, } + if otf_norms is True: + datasets["train"].update( + { + "transforms": { + "element_references": { + "fit": { + "targets": ["energy"], + "batch_size": 4, + "num_batches": 10, + "driver": "gelsd", + } + }, + "normalizer": { + "fit": { + "targets": {"energy": None, "forces": {"mean": 0.0}}, + "batch_size": 4, + "num_batches": 10, + } + }, + } + } + ) + else: + datasets["train"].update( + { + "transforms": { + "normalizer": { + "energy": { + "mean": -0.7554450631141663, + "stdev": 2.887317180633545, + }, + "forces": {"mean": 0.0, "stdev": 2.887317180633545}, + } + } + } + ) if val_src is not None: - datasets["val"] = {"src": val_src} + datasets["val"] = {"src": val_src, "format": "lmdb"} if test_src is not None: - datasets["test"] = {"src": test_src} + datasets["test"] = {"src": test_src, "format": "lmdb"} return datasets @@ -124,7 +160,6 @@ def _run_main( yaml_config["backend"] = "gloo" with open(str(config_yaml), "w") as yaml_file: yaml.dump(yaml_config, yaml_file) - run_args = { "run_dir": rundir, "logdir": f"{rundir}/logs", @@ -168,11 +203,6 @@ def _run_main( ) -@pytest.fixture(scope="class") -def torch_tempdir(tmpdir_factory): - return tmpdir_factory.mktemp("torch_tempdir") - - """ These tests are intended to be as quick as possible and test only that the network is runnable and outputs training+validation to tensorboard output These should catch errors such as shape mismatches or otherways to code wise break a network @@ -180,12 +210,7 @@ def torch_tempdir(tmpdir_factory): class TestSmoke: - def smoke_test_train( - self, - model_name, - input_yaml, - tutorial_val_src, - ): + def smoke_test_train(self, input_yaml, tutorial_val_src, otf_norms=False): with tempfile.TemporaryDirectory() as tempdirname: # first train a very simple model, checkpoint train_rundir = Path(tempdirname) / "train" @@ -201,6 +226,7 @@ def smoke_test_train( train_src=str(tutorial_val_src), val_src=str(tutorial_val_src), test_src=str(tutorial_val_src), + otf_norms=otf_norms, ), }, save_checkpoint_to=checkpoint_path, @@ -222,6 +248,7 @@ def smoke_test_train( train_src=str(tutorial_val_src), val_src=str(tutorial_val_src), test_src=str(tutorial_val_src), + otf_norms=otf_norms, ), }, update_run_args_with={ @@ -231,42 +258,65 @@ def smoke_test_train( save_predictions_to=predictions_filename, ) + if otf_norms is True: + norm_path = glob.glob( + str(train_rundir / "checkpoints" / "*" / "normalizers.pt") + ) + assert len(norm_path) == 1 + assert os.path.isfile(norm_path[0]) + ref_path = glob.glob( + str(train_rundir / "checkpoints" / "*" / "element_references.pt") + ) + assert len(ref_path) == 1 + assert os.path.isfile(ref_path[0]) + # verify predictions from train and predict are identical energy_from_train = np.load(training_predictions_filename)["energy"] energy_from_checkpoint = np.load(predictions_filename)["energy"] - assert np.isclose(energy_from_train, energy_from_checkpoint).all() + npt.assert_allclose( + energy_from_train, energy_from_checkpoint, rtol=1e-6, atol=1e-6 + ) + # not all models are tested with otf normalization estimation + # only gemnet_oc, escn, equiformer, and their hydra versions @pytest.mark.parametrize( - "model_name", + ("model_name", "otf_norms"), [ - pytest.param("schnet", id="schnet"), - 
pytest.param("scn", id="scn"), - pytest.param("gemnet_dt", id="gemnet_dt"), - pytest.param("gemnet_dt_hydra", id="gemnet_dt_hydra"), - pytest.param("gemnet_dt_hydra_grad", id="gemnet_dt_hydra_grad"), - pytest.param("gemnet_oc", id="gemnet_oc"), - pytest.param("gemnet_oc_hydra", id="gemnet_oc_hydra"), - pytest.param("gemnet_oc_hydra_grad", id="gemnet_oc_hydra_grad"), - pytest.param("dimenet++", id="dimenet++"), - pytest.param("dimenet++_hydra", id="dimenet++_hydra"), - pytest.param("painn", id="painn"), - pytest.param("painn_hydra", id="painn_hydra"), - pytest.param("escn", id="escn"), - pytest.param("escn_hydra", id="escn_hydra"), - pytest.param("equiformer_v2", id="equiformer_v2"), - pytest.param("equiformer_v2_hydra", id="equiformer_v2_hydra"), + ("schnet", False), + ("scn", False), + ("gemnet_dt", False), + ("gemnet_dt_hydra", False), + ("gemnet_dt_hydra_grad", False), + ("gemnet_oc", False), + ("gemnet_oc", True), + ("gemnet_oc_hydra", False), + ("gemnet_oc_hydra", True), + ("gemnet_oc_hydra_grad", False), + ("dimenet++", False), + ("dimenet++_hydra", False), + ("painn", False), + ("painn_hydra", False), + ("escn", False), + ("escn", True), + ("escn_hydra", False), + ("escn_hydra", True), + ("equiformer_v2", False), + ("equiformer_v2", True), + ("equiformer_v2_hydra", False), + ("equiformer_v2_hydra", True), ], ) def test_train_and_predict( self, model_name, + otf_norms, configs, tutorial_val_src, ): self.smoke_test_train( - model_name=model_name, input_yaml=configs[model_name], tutorial_val_src=tutorial_val_src, + otf_norms=otf_norms, ) @pytest.mark.parametrize( @@ -307,7 +357,6 @@ def test_ddp(self, world_size, ddp, configs, tutorial_val_src, torch_determinist def test_balanced_batch_sampler_ddp( self, world_size, ddp, configs, tutorial_val_src, torch_deterministic ): - # make dataset metadata parser = get_lmdb_sizes_parser() args, override_args = parser.parse_known_args( diff --git a/tests/core/models/test_configs/test_equiformerv2.yml b/tests/core/models/test_configs/test_equiformerv2.yml index 54d5e61c95..8c5c200fdf 100644 --- a/tests/core/models/test_configs/test_equiformerv2.yml +++ b/tests/core/models/test_configs/test_equiformerv2.yml @@ -1,6 +1,53 @@ +trainer: forces + +logger: + name: tensorboard +outputs: + energy: + shape: 1 + level: system + forces: + irrep_dim: 1 + level: atom + train_on_free_atoms: True + eval_on_free_atoms: True + +loss_functions: + - energy: + fn: mae + coefficient: 2 + - forces: + fn: l2mae + coefficient: 100 + +evaluation_metrics: + metrics: + energy: + - mae + forces: + - mae + - cosine_similarity + - magnitude_error + misc: + - energy_forces_within_threshold + primary_metric: forces_mae -trainer: forces +optim: + batch_size: 5 + eval_batch_size: 2 + num_workers: 0 + lr_initial: 0.0025 + optimizer: AdamW + optimizer_params: {"amsgrad": True,weight_decay: 0.0} + eval_every: 190 + max_epochs: 50 + force_coefficient: 20 + scheduler: "Null" + energy_coefficient: 1 + clip_grad_norm: 20 + loss_energy: mae + loss_force: l2mae model: name: equiformer_v2 @@ -45,47 +92,3 @@ model: proj_drop: 0.0 weight_init: 'normal' # ['uniform', 'normal'] - -dataset: - train: - src: tutorial_dset/s2ef/train_100/ - normalize_labels: True - target_mean: -0.7554450631141663 - target_std: 2.887317180633545 - grad_target_mean: 0.0 - grad_target_std: 2.887317180633545 - val: - format: lmdb - src: tutorial_dset/s2ef/val_20/ - -logger: - name: tensorboard - -task: - dataset: lmdb - type: regression - metric: mae - primary_metric: forces_mae - labels: - - potential energy - 
grad_input: atomic forces - train_on_free_atoms: True - eval_on_free_atoms: True - prediction_dtype: float32 - - -optim: - batch_size: 5 - eval_batch_size: 2 - num_workers: 0 - lr_initial: 0.0025 - optimizer: AdamW - optimizer_params: {"amsgrad": True,weight_decay: 0.0} - eval_every: 190 - max_epochs: 50 - force_coefficient: 20 - scheduler: "Null" - energy_coefficient: 1 - clip_grad_norm: 20 - loss_energy: mae - loss_force: l2mae diff --git a/tests/core/models/test_configs/test_escn.yml b/tests/core/models/test_configs/test_escn.yml index 5148e409e5..5848587cdd 100644 --- a/tests/core/models/test_configs/test_escn.yml +++ b/tests/core/models/test_configs/test_escn.yml @@ -1,31 +1,37 @@ trainer: forces -dataset: - train: - src: tutorial_dset/s2ef/train_100/ - normalize_labels: True - target_mean: -0.7554450631141663 - target_std: 2.887317180633545 - grad_target_mean: 0.0 - grad_target_std: 2.887317180633545 - val: - format: lmdb - src: tutorial_dset/s2ef/val_20/ - logger: name: tensorboard -task: - dataset: lmdb - type: regression - metric: mae +outputs: + energy: + shape: 1 + level: system + forces: + irrep_dim: 1 + level: atom + train_on_free_atoms: True + eval_on_free_atoms: True + +loss_functions: + - energy: + fn: mae + coefficient: 2 + - forces: + fn: l2mae + coefficient: 100 + +evaluation_metrics: + metrics: + energy: + - mae + forces: + - mae + - cosine_similarity + - magnitude_error + misc: + - energy_forces_within_threshold primary_metric: forces_mae - labels: - - potential energy - grad_input: atomic forces - train_on_free_atoms: True - eval_on_free_atoms: True - prediction_dtype: float32 model: name: escn diff --git a/tests/core/models/test_configs/test_gemnet_oc.yml b/tests/core/models/test_configs/test_gemnet_oc.yml index a720583608..f1c0d01c3a 100644 --- a/tests/core/models/test_configs/test_gemnet_oc.yml +++ b/tests/core/models/test_configs/test_gemnet_oc.yml @@ -1,34 +1,37 @@ - - - trainer: forces -dataset: - train: - src: tutorial_dset/s2ef/train_100/ - normalize_labels: True - target_mean: -0.7554450631141663 - target_std: 2.887317180633545 - grad_target_mean: 0.0 - grad_target_std: 2.887317180633545 - val: - format: lmdb - src: tutorial_dset/s2ef/val_20/ - logger: name: tensorboard -task: - dataset: lmdb - type: regression - metric: mae +outputs: + energy: + shape: 1 + level: system + forces: + irrep_dim: 1 + level: atom + train_on_free_atoms: True + eval_on_free_atoms: True + +loss_functions: + - energy: + fn: mae + coefficient: 2 + - forces: + fn: l2mae + coefficient: 100 + +evaluation_metrics: + metrics: + energy: + - mae + forces: + - mae + - cosine_similarity + - magnitude_error + misc: + - energy_forces_within_threshold primary_metric: forces_mae - labels: - - potential energy - grad_input: atomic forces - train_on_free_atoms: True - eval_on_free_atoms: True - prediction_dtype: float32 model: name: gemnet_oc diff --git a/tests/core/modules/conftest.py b/tests/core/modules/conftest.py new file mode 100644 index 0000000000..1b1e4ab7e6 --- /dev/null +++ b/tests/core/modules/conftest.py @@ -0,0 +1,48 @@ +from itertools import product +from random import choice +import pytest +import numpy as np +from pymatgen.core.periodic_table import Element +from pymatgen.core import Structure + +from fairchem.core.datasets import LMDBDatabase, AseDBDataset + + +@pytest.fixture(scope="session") +def dummy_element_refs(): + # create some dummy elemental energies from ionic radii (ignore deuterium and tritium included in pmg) + return np.concatenate( + [[0], [e.average_ionic_radius 
for e in Element if e.name not in ("D", "T")]] + ) + + +@pytest.fixture(scope="session") +def max_num_elements(dummy_element_refs): + return len(dummy_element_refs) - 1 + + +@pytest.fixture(scope="session") +def dummy_binary_dataset(tmpdir_factory, dummy_element_refs): + # a dummy dataset with binaries with energy that depends on composition only plus noise + all_binaries = list(product(list(Element), repeat=2)) + rng = np.random.default_rng(seed=0) + + tmpdir = tmpdir_factory.mktemp("dataset") + with LMDBDatabase(tmpdir / "dummy.aselmdb") as db: + for _ in range(1000): + elements = choice(all_binaries) + structure = Structure.from_prototype("cscl", species=elements, a=2.0) + energy = ( + sum(e.average_ionic_radius for e in elements) + + 0.05 * rng.random() * dummy_element_refs.mean() + ) + atoms = structure.to_ase_atoms() + db.write(atoms, data={"energy": energy, "forces": rng.random((2, 3))}) + + dataset = AseDBDataset( + config={ + "src": str(tmpdir / "dummy.aselmdb"), + "a2g_args": {"r_data_keys": ["energy", "forces"]}, + } + ) + return dataset diff --git a/tests/core/modules/test_element_references.py b/tests/core/modules/test_element_references.py new file mode 100644 index 0000000000..62928b623c --- /dev/null +++ b/tests/core/modules/test_element_references.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +import numpy as np +import numpy.testing as npt +import pytest +import torch + +from fairchem.core.datasets import data_list_collater +from fairchem.core.modules.normalization.element_references import ( + LinearReferences, + create_element_references, + fit_linear_references, +) + + +@pytest.fixture(scope="session", params=(True, False)) +def element_refs(dummy_binary_dataset, max_num_elements, request): + return fit_linear_references( + ["energy"], + dataset=dummy_binary_dataset, + batch_size=16, + shuffle=False, + max_num_elements=max_num_elements, + seed=0, + use_numpy=request.param, + ) + + +def test_apply_linear_references( + element_refs, dummy_binary_dataset, dummy_element_refs +): + max_noise = 0.05 * dummy_element_refs.mean() + + # check that removing element refs keeps only values within max noise + batch = data_list_collater(list(dummy_binary_dataset), otf_graph=True) + energy = batch.energy.clone().view(len(batch), -1) + deref_energy = element_refs["energy"].dereference(energy, batch) + assert all(deref_energy <= max_noise) + + # and check that we recover the total energy from applying references + ref_energy = element_refs["energy"](deref_energy, batch) + assert torch.allclose(ref_energy, energy) + + +def test_create_element_references(element_refs, tmp_path): + # test from state dict + sdict = element_refs["energy"].state_dict() + + refs = create_element_references(state_dict=sdict) + assert isinstance(refs, LinearReferences) + assert torch.allclose( + element_refs["energy"].element_references, refs.element_references + ) + + # test from saved stated dict + torch.save(sdict, tmp_path / "linref.pt") + refs = create_element_references(file=tmp_path / "linref.pt") + assert isinstance(refs, LinearReferences) + assert torch.allclose( + element_refs["energy"].element_references, refs.element_references + ) + + # from a legacy numpy npz file + np.savez( + tmp_path / "linref.npz", coeff=element_refs["energy"].element_references.numpy() + ) + refs = create_element_references(file=tmp_path / "linref.npz") + assert isinstance(refs, LinearReferences) + assert torch.allclose( + element_refs["energy"].element_references, refs.element_references + ) + + # from a numpy 
npz file + np.savez( + tmp_path / "linref.npz", + element_references=element_refs["energy"].element_references.numpy(), + ) + + refs = create_element_references(file=tmp_path / "linref.npz") + assert isinstance(refs, LinearReferences) + assert torch.allclose( + element_refs["energy"].element_references, refs.element_references + ) + + +def test_fit_linear_references( + element_refs, dummy_binary_dataset, max_num_elements, dummy_element_refs +): + # create the composition matrix + energy = np.array([d.energy for d in dummy_binary_dataset]) + cmatrix = np.vstack( + [ + np.bincount(d.atomic_numbers.int().numpy(), minlength=max_num_elements + 1) + for d in dummy_binary_dataset + ] + ) + mask = cmatrix.sum(axis=0) != 0.0 + + # fit using numpy + element_refs_np = np.zeros(max_num_elements + 1) + element_refs_np[mask] = np.linalg.lstsq(cmatrix[:, mask], energy, rcond=None)[0] + + # length is max_num_elements + 1, since H starts at 1 + assert len(element_refs["energy"].element_references) == max_num_elements + 1 + # first element is dummy, should always be zero + assert element_refs["energy"].element_references[0] == 0.0 + # elements not present should be zero + npt.assert_allclose(element_refs["energy"].element_references.numpy()[~mask], 0.0) + # torch fit vs numpy fit + npt.assert_allclose( + element_refs_np, element_refs["energy"].element_references.numpy(), atol=1e-5 + ) + # close enough to ground truth w/out noise + npt.assert_allclose( + dummy_element_refs[mask], + element_refs["energy"].element_references.numpy()[mask], + atol=5e-2, + ) + + +def test_fit_seed_no_seed(dummy_binary_dataset, max_num_elements): + refs_seed = fit_linear_references( + ["energy"], + dataset=dummy_binary_dataset, + batch_size=16, + num_batches=len(dummy_binary_dataset) // 16 - 2, + shuffle=True, + max_num_elements=max_num_elements, + seed=0, + ) + refs_seed1 = fit_linear_references( + ["energy"], + dataset=dummy_binary_dataset, + batch_size=16, + num_batches=len(dummy_binary_dataset) // 16 - 2, + shuffle=True, + max_num_elements=max_num_elements, + seed=0, + ) + refs_noseed = fit_linear_references( + ["energy"], + dataset=dummy_binary_dataset, + batch_size=16, + num_batches=len(dummy_binary_dataset) // 16 - 2, + shuffle=True, + max_num_elements=max_num_elements, + seed=1, + ) + + assert torch.allclose( + refs_seed["energy"].element_references, + refs_seed1["energy"].element_references, + atol=1e-6, + ) + assert not torch.allclose( + refs_seed["energy"].element_references, + refs_noseed["energy"].element_references, + atol=1e-6, + ) diff --git a/tests/core/modules/test_normalizer.py b/tests/core/modules/test_normalizer.py new file mode 100644 index 0000000000..b0d4a44040 --- /dev/null +++ b/tests/core/modules/test_normalizer.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +import numpy as np +import pytest +import torch + +from fairchem.core.datasets import data_list_collater +from fairchem.core.modules.normalization.normalizer import ( + Normalizer, + create_normalizer, + fit_normalizers, +) + + +@pytest.fixture(scope="session") +def normalizers(dummy_binary_dataset): + return fit_normalizers( + ["energy", "forces"], + override_values={"forces": {"mean": 0.0}}, + dataset=dummy_binary_dataset, + batch_size=16, + shuffle=False, + ) + + +def test_norm_denorm(normalizers, dummy_binary_dataset, dummy_element_refs): + batch = data_list_collater(list(dummy_binary_dataset), otf_graph=True) + # test norm and denorm + for target, normalizer in normalizers.items(): + normed = normalizer.norm(batch[target]) + assert 
torch.allclose( + (batch[target] - normalizer.mean) / normalizer.rmsd, normed + ) + assert torch.allclose( + normalizer.rmsd * normed + normalizer.mean, normalizer(normed) + ) + + +def test_create_normalizers(normalizers, dummy_binary_dataset, tmp_path): + # test that forces mean was overriden + assert normalizers["forces"].mean.item() == 0.0 + + # test from state dict + sdict = normalizers["energy"].state_dict() + + norm = create_normalizer(state_dict=sdict) + assert isinstance(norm, Normalizer) + assert norm.state_dict() == sdict + + # test from saved stated dict + torch.save(sdict, tmp_path / "norm.pt") + norm = create_normalizer(file=tmp_path / "norm.pt") + assert isinstance(norm, Normalizer) + assert norm.state_dict() == sdict + + # from a legacy numpy npz file + np.savez( + tmp_path / "norm.npz", + mean=normalizers["energy"].mean.numpy(), + std=normalizers["energy"].rmsd.numpy(), + ) + norm = create_normalizer(file=tmp_path / "norm.npz") + assert isinstance(norm, Normalizer) + assert norm.state_dict() == sdict + + # from a new npz file + np.savez( + tmp_path / "norm.npz", + mean=normalizers["energy"].mean.numpy(), + rmsd=normalizers["energy"].rmsd.numpy(), + ) + norm = create_normalizer(file=tmp_path / "norm.npz") + assert isinstance(norm, Normalizer) + assert norm.state_dict() == sdict + + # from tensor directly + batch = data_list_collater(list(dummy_binary_dataset), otf_graph=True) + norm = create_normalizer(tensor=batch.energy) + assert isinstance(norm, Normalizer) + # assert norm.state_dict() == sdict + # not sure why the above fails + new_sdict = norm.state_dict() + for key in sdict: + assert torch.allclose(new_sdict[key], sdict[key]) + + # passing values directly + norm = create_normalizer( + mean=batch.energy.mean().item(), rmsd=batch.energy.std().item() + ) + assert isinstance(norm, Normalizer) + # assert norm.state_dict() == sdict + new_sdict = norm.state_dict() + for key in sdict: + assert torch.allclose(new_sdict[key], sdict[key]) + + # bad construction + with pytest.raises(ValueError): + create_normalizer(mean=1.0) From 214522d2eed77c1c773cae27a5ad3487a7c8c8be Mon Sep 17 00:00:00 2001 From: Luis Barroso-Luque Date: Mon, 5 Aug 2024 17:05:03 -0700 Subject: [PATCH 7/8] update ocp example config (#794) --- configs/ocp_example.yml | 59 ++++++++++++++++++++++++++++++++++++----- 1 file changed, 52 insertions(+), 7 deletions(-) diff --git a/configs/ocp_example.yml b/configs/ocp_example.yml index b979a7a324..a988b4ef1a 100644 --- a/configs/ocp_example.yml +++ b/configs/ocp_example.yml @@ -12,7 +12,7 @@ dataset: # Can use 'single_point_lmdb' or 'trajectory_lmdb' for backward compatibility. # 'single_point_lmdb' was for training IS2RE models, and 'trajectory_lmdb' was # for training S2EF models. - format: lmdb # 'lmdb' or 'oc22_lmdb' + format: lmdb # 'lmdb', 'oc22_lmdb', or 'ase_d' # Directory containing training set LMDBs src: data/s2ef/all/train/ # If we want to rename a target value stored in the data object, specify the mapping here. @@ -34,9 +34,11 @@ dataset: irrep_dim: 0 anisotropic_stress: irrep_dim: 2 - # If we want to normalize targets, i.e. subtract the mean and - # divide by standard deviation, then specify the 'mean' and 'stdev' here. + # If we want to normalize targets, there are a couple of ways to specify normalization values. + # normalization values are applied as: (target - mean) / rmsd + # Note root mean squared difference (rmsd) is equal to stdev if mean != 0, and equal to rms if mean == 0. # Statistics will by default be applied to the validation and test set. 
+  # 1) Specify the 'mean' and 'stdev' explicitly here.
   normalizer:
     energy:
       mean: -0.7554450631141663
@@ -49,6 +51,52 @@ dataset:
       stdev: 674.1657344451734
     anisotropic_stress:
       stdev: 143.72764771869745
+  # 2) Estimate the values on-the-fly (OTF) from training data
+  normalizer:
+    fit:
+      targets:
+        forces: { mean: 0.0 } # values can be set explicitly, i.e. if you need RMS forces instead of stdev forces
+        stress_isotropic: { } # to estimate both mean and rmsd set to {} or None
+        stress_anisotropic: { }
+      batch_size: 64
+      num_batches: 5000 # if num_batches is not given, the whole dataset will be used
+  # 3) Specify a single .pt file with a dict of target names and Normalizer modules
+  # (this is the format that OTF values are saved in)
+  # see the Normalizer module in fairchem.core.modules.normalization.normalizer
+  normalizer:
+    file: normalizers.pt
+  # 4) Specify an individual file, either .pt or .npz, with keys 'mean' and 'rmsd' or 'stdev'
+  normalizer:
+    energy:
+      file: energy_norm.pt
+    forces:
+      file: forces_norm.npz
+    isotropic_stress:
+      file: isostress_norm.npz
+    anisotropic_stress:
+      file: anisostress_norm.npz
+  # If we want to train on total energies and use a per-element linear reference
+  # normalization scheme, we can estimate those from the data or specify the path to the per-element references.
+  # 1) Fit element references from data
+  element_references:
+    fit:
+      targets:
+        - energy
+      batch_size: 64
+      num_batches: 5000 # if num_batches is not given, the whole dataset will be used
+  # 2) Specify a file with key energy and a LinearReference object. This is the format OTF references are saved in.
+  # see fairchem.core.modules.normalization.element_references for details
+  element_references:
+    file: element_references.pt
+  # 3) Legacy files in npz format can be specified as well. They must have the element references
+  # under the key coeff
+  element_references:
+    energy:
+      file: element_ref.npz
+  # 4) For backwards compatibility only, linear references can be set as follows. Setting the reference
+  # file this way is a legacy setting and only works with oc22_lmdb and ase_lmdb datasets
+  lin_ref: element_ref.npz
+
   # If we want to train OC20 on total energy, a path to OC20 reference
   # energies `oc20_ref` must be specified to unreference existing OC20 data.
   # download at https://dl.fbaipublicfiles.com/opencatalystproject/data/oc22/oc20_ref.pkl
@@ -56,10 +104,7 @@ dataset:
   # OC22 defaults to total energy, so these flags are not necessary.
   train_on_oc20_total_energies: False # True or False
   oc20_ref: None # path to oc20_ref
-  # If we want to train on total energies and use a linear reference
-  # normalization scheme, we must specify the path to the per-element
-  # coefficients in a `.npz` format.
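For completeness, here is one way the file-based options above could be produced. This is a hedged sketch: the npz key names ('mean', 'rmsd'/'stdev', and 'coeff' for legacy element references) are taken from the config comments and the `create_normalizer` test earlier in this patch, the `torch.save` dict layout for the .pt variant is an assumption, and all file names and numeric values are placeholders.

```python
# Sketch: writing normalization / element-reference files in the formats the
# example config refers to. Key names follow the comments above; the .pt dict
# layout is an assumption, and all numbers and paths are placeholders.
import numpy as np
import torch

# normalizer option 4: a per-target .npz with 'mean' and 'rmsd' (or 'stdev').
np.savez("energy_norm.npz", mean=-0.75, rmsd=2.5)

# Forces with the mean pinned to 0, so 'rmsd' effectively stores an RMS value.
np.savez("forces_norm.npz", mean=0.0, rmsd=1.3)

# element_references option 3 (legacy): one linear coefficient per atomic
# number stored under the key 'coeff'.
np.savez("element_ref.npz", coeff=np.zeros(100))

# A .pt file can hold the same statistics as tensors (assumed layout; the
# canonical format is whatever fairchem.core.modules.normalization.normalizer
# writes via its state dict).
torch.save({"mean": torch.tensor(-0.75), "rmsd": torch.tensor(2.5)}, "energy_norm.pt")
```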
- lin_ref: False # True or False + val: # Directory containing val set LMDBs src: data/s2ef/all/val_id/ From b2eebb6a0f8ab06fef504fc2989daf9216816eab Mon Sep 17 00:00:00 2001 From: Misko Date: Mon, 5 Aug 2024 19:02:52 -0700 Subject: [PATCH 8/8] Add an option to run PBC in single system mode (#795) * do pbc per system * add option to use single system pbc * remove comments * integrate use_pbc_single to all the models in repo; add test --- src/fairchem/core/models/base.py | 49 ++++++++++++++++--- src/fairchem/core/models/dimenet_plus_plus.py | 36 ++++++-------- .../models/equiformer_v2/equiformer_v2.py | 2 + src/fairchem/core/models/escn/escn.py | 3 ++ src/fairchem/core/models/gemnet/gemnet.py | 2 + src/fairchem/core/models/gemnet_gp/gemnet.py | 2 + .../core/models/gemnet_oc/gemnet_oc.py | 4 ++ src/fairchem/core/models/painn/painn.py | 2 + src/fairchem/core/models/schnet.py | 20 ++++---- src/fairchem/core/models/scn/scn.py | 3 ++ tests/core/e2e/test_s2ef.py | 19 +++++++ 11 files changed, 106 insertions(+), 36 deletions(-) diff --git a/src/fairchem/core/models/base.py b/src/fairchem/core/models/base.py index eb8c9d543c..8ce8f3fcb1 100644 --- a/src/fairchem/core/models/base.py +++ b/src/fairchem/core/models/base.py @@ -53,10 +53,12 @@ def generate_graph( use_pbc=None, otf_graph=None, enforce_max_neighbors_strictly=None, + use_pbc_single=False, ): cutoff = cutoff or self.cutoff max_neighbors = max_neighbors or self.max_neighbors use_pbc = use_pbc or self.use_pbc + use_pbc_single = use_pbc_single or self.use_pbc_single otf_graph = otf_graph or self.otf_graph if enforce_max_neighbors_strictly is not None: @@ -84,12 +86,47 @@ def generate_graph( if use_pbc: if otf_graph: - edge_index, cell_offsets, neighbors = radius_graph_pbc( - data, - cutoff, - max_neighbors, - enforce_max_neighbors_strictly, - ) + if use_pbc_single: + ( + edge_index_per_system, + cell_offsets_per_system, + neighbors_per_system, + ) = list( + zip( + *[ + radius_graph_pbc( + data[idx], + cutoff, + max_neighbors, + enforce_max_neighbors_strictly, + ) + for idx in range(len(data)) + ] + ) + ) + + # atom indexs in the edge_index need to be offset + atom_index_offset = data.natoms.cumsum(dim=0).roll(1) + atom_index_offset[0] = 0 + edge_index = torch.hstack( + [ + edge_index_per_system[idx] + atom_index_offset[idx] + for idx in range(len(data)) + ] + ) + cell_offsets = torch.vstack(cell_offsets_per_system) + neighbors = torch.hstack(neighbors_per_system) + else: + ## TODO this is the original call, but blows up with memory + ## using two different samples + ## sid='mp-675045-mp-675045-0-7' (MPTRAJ) + ## sid='75396' (OC22) + edge_index, cell_offsets, neighbors = radius_graph_pbc( + data, + cutoff, + max_neighbors, + enforce_max_neighbors_strictly, + ) out = get_pbc_distances( data.pos, diff --git a/src/fairchem/core/models/dimenet_plus_plus.py b/src/fairchem/core/models/dimenet_plus_plus.py index aa08ea0672..f555448261 100644 --- a/src/fairchem/core/models/dimenet_plus_plus.py +++ b/src/fairchem/core/models/dimenet_plus_plus.py @@ -352,16 +352,13 @@ def forward( ) } if self.regress_forces: - outputs["forces"] = ( - -1 - * ( - torch.autograd.grad( - outputs["energy"], - data.pos, - grad_outputs=torch.ones_like(outputs["energy"]), - create_graph=True, - )[0] - ) + outputs["forces"] = -1 * ( + torch.autograd.grad( + outputs["energy"], + data.pos, + grad_outputs=torch.ones_like(outputs["energy"]), + create_graph=True, + )[0] ) return outputs @@ -371,6 +368,7 @@ class DimeNetPlusPlusWrap(DimeNetPlusPlus, GraphModelMixin): def __init__( 
self, use_pbc: bool = True, + use_pbc_single: bool = False, regress_forces: bool = True, hidden_channels: int = 128, num_blocks: int = 4, @@ -388,6 +386,7 @@ def __init__( ) -> None: self.regress_forces = regress_forces self.use_pbc = use_pbc + self.use_pbc_single = use_pbc_single self.cutoff = cutoff self.otf_graph = otf_graph self.max_neighbors = 50 @@ -466,16 +465,13 @@ def forward(self, data): outputs = {"energy": energy} if self.regress_forces: - forces = ( - -1 - * ( - torch.autograd.grad( - energy, - data.pos, - grad_outputs=torch.ones_like(energy), - create_graph=True, - )[0] - ) + forces = -1 * ( + torch.autograd.grad( + energy, + data.pos, + grad_outputs=torch.ones_like(energy), + create_graph=True, + )[0] ) outputs["forces"] = forces diff --git a/src/fairchem/core/models/equiformer_v2/equiformer_v2.py b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py index e2625eadaf..06a0280e98 100644 --- a/src/fairchem/core/models/equiformer_v2/equiformer_v2.py +++ b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py @@ -116,6 +116,7 @@ class EquiformerV2(nn.Module, GraphModelMixin): def __init__( self, use_pbc: bool = True, + use_pbc_single: bool = False, regress_forces: bool = True, otf_graph: bool = True, max_neighbors: int = 500, @@ -169,6 +170,7 @@ def __init__( raise ImportError self.use_pbc = use_pbc + self.use_pbc_single = use_pbc_single self.regress_forces = regress_forces self.otf_graph = otf_graph self.max_neighbors = max_neighbors diff --git a/src/fairchem/core/models/escn/escn.py b/src/fairchem/core/models/escn/escn.py index dfa872c398..d6367fa9ad 100644 --- a/src/fairchem/core/models/escn/escn.py +++ b/src/fairchem/core/models/escn/escn.py @@ -47,6 +47,7 @@ class eSCN(nn.Module, GraphModelMixin): Args: use_pbc (bool): Use periodic boundary conditions + use_pbc_single (bool): Process batch PBC graphs one at a time regress_forces (bool): Compute forces otf_graph (bool): Compute graph On The Fly (OTF) max_neighbors (int): Maximum number of neighbors per atom @@ -69,6 +70,7 @@ class eSCN(nn.Module, GraphModelMixin): def __init__( self, use_pbc: bool = True, + use_pbc_single: bool = False, regress_forces: bool = True, otf_graph: bool = False, max_neighbors: int = 40, @@ -100,6 +102,7 @@ def __init__( self.regress_forces = regress_forces self.use_pbc = use_pbc + self.use_pbc_single = use_pbc_single self.cutoff = cutoff self.otf_graph = otf_graph self.show_timing_info = show_timing_info diff --git a/src/fairchem/core/models/gemnet/gemnet.py b/src/fairchem/core/models/gemnet/gemnet.py index 59b3eda08f..f5537b9535 100644 --- a/src/fairchem/core/models/gemnet/gemnet.py +++ b/src/fairchem/core/models/gemnet/gemnet.py @@ -118,6 +118,7 @@ def __init__( extensive: bool = True, otf_graph: bool = False, use_pbc: bool = True, + use_pbc_single: bool = False, output_init: str = "HeOrthogonal", activation: str = "swish", num_elements: int = 83, @@ -143,6 +144,7 @@ def __init__( self.regress_forces = regress_forces self.otf_graph = otf_graph self.use_pbc = use_pbc + self.use_pbc_single = use_pbc_single # GemNet variants self.direct_forces = direct_forces diff --git a/src/fairchem/core/models/gemnet_gp/gemnet.py b/src/fairchem/core/models/gemnet_gp/gemnet.py index a75756dcc1..97af540de2 100644 --- a/src/fairchem/core/models/gemnet_gp/gemnet.py +++ b/src/fairchem/core/models/gemnet_gp/gemnet.py @@ -114,6 +114,7 @@ def __init__( extensive: bool = True, otf_graph: bool = False, use_pbc: bool = True, + use_pbc_single: bool = False, output_init: str = "HeOrthogonal", activation: str = "swish", 
scale_num_blocks: bool = False, @@ -142,6 +143,7 @@ def __init__( self.regress_forces = regress_forces self.otf_graph = otf_graph self.use_pbc = use_pbc + self.use_pbc_single = use_pbc_single # GemNet variants self.direct_forces = direct_forces diff --git a/src/fairchem/core/models/gemnet_oc/gemnet_oc.py b/src/fairchem/core/models/gemnet_oc/gemnet_oc.py index 0aea3d81ba..c9dd9e13ed 100644 --- a/src/fairchem/core/models/gemnet_oc/gemnet_oc.py +++ b/src/fairchem/core/models/gemnet_oc/gemnet_oc.py @@ -108,6 +108,8 @@ class GemNetOC(nn.Module, GraphModelMixin): If False predict forces based on negative gradient of energy potential. use_pbc: bool Whether to use periodic boundary conditions. + use_pbc_single: + Process batch PBC graphs one at a time scale_backprop_forces: bool Whether to scale up the energy and then scales down the forces to prevent NaNs and infs in backpropagated forces. @@ -203,6 +205,7 @@ def __init__( regress_forces: bool = True, direct_forces: bool = False, use_pbc: bool = True, + use_pbc_single: bool = False, scale_backprop_forces: bool = False, cutoff: float = 6.0, cutoff_qint: float | None = None, @@ -269,6 +272,7 @@ def __init__( ) self.enforce_max_neighbors_strictly = enforce_max_neighbors_strictly self.use_pbc = use_pbc + self.use_pbc_single = use_pbc_single self.direct_forces = direct_forces self.forces_coupled = forces_coupled diff --git a/src/fairchem/core/models/painn/painn.py b/src/fairchem/core/models/painn/painn.py index ec9e9f465c..33425e8d8d 100644 --- a/src/fairchem/core/models/painn/painn.py +++ b/src/fairchem/core/models/painn/painn.py @@ -73,6 +73,7 @@ def __init__( regress_forces: bool = True, direct_forces: bool = True, use_pbc: bool = True, + use_pbc_single: bool = False, otf_graph: bool = True, num_elements: int = 83, scale_file: str | None = None, @@ -92,6 +93,7 @@ def __init__( self.direct_forces = direct_forces self.otf_graph = otf_graph self.use_pbc = use_pbc + self.use_pbc_single = use_pbc_single # Borrowed from GemNet. self.symmetric_edge_symmetrization = False diff --git a/src/fairchem/core/models/schnet.py b/src/fairchem/core/models/schnet.py index 5ca70a354e..878aee746a 100644 --- a/src/fairchem/core/models/schnet.py +++ b/src/fairchem/core/models/schnet.py @@ -30,6 +30,7 @@ class SchNetWrap(SchNet, GraphModelMixin): Args: use_pbc (bool, optional): If set to :obj:`True`, account for periodic boundary conditions. (default: :obj:`True`) + use_pbc_single (bool,optional): Process batch PBC graphs one at a time regress_forces (bool, optional): If set to :obj:`True`, predict forces by differentiating energy with respect to positions. 
(default: :obj:`True`) @@ -52,6 +53,7 @@ class SchNetWrap(SchNet, GraphModelMixin): def __init__( self, use_pbc: bool = True, + use_pbc_single: bool = False, regress_forces: bool = True, otf_graph: bool = False, hidden_channels: int = 128, @@ -64,6 +66,7 @@ def __init__( self.num_targets = 1 self.regress_forces = regress_forces self.use_pbc = use_pbc + self.use_pbc_single = use_pbc_single self.cutoff = cutoff self.otf_graph = otf_graph self.max_neighbors = 50 @@ -111,16 +114,13 @@ def forward(self, data): outputs = {"energy": energy} if self.regress_forces: - forces = ( - -1 - * ( - torch.autograd.grad( - energy, - data.pos, - grad_outputs=torch.ones_like(energy), - create_graph=True, - )[0] - ) + forces = -1 * ( + torch.autograd.grad( + energy, + data.pos, + grad_outputs=torch.ones_like(energy), + create_graph=True, + )[0] ) outputs["forces"] = forces diff --git a/src/fairchem/core/models/scn/scn.py b/src/fairchem/core/models/scn/scn.py index 84806e19e8..299fa48584 100644 --- a/src/fairchem/core/models/scn/scn.py +++ b/src/fairchem/core/models/scn/scn.py @@ -39,6 +39,7 @@ class SphericalChannelNetwork(nn.Module, GraphModelMixin): Args: use_pbc (bool): Use periodic boundary conditions + use_pbc_single (bool): Process batch PBC graphs one at a time regress_forces (bool): Compute forces otf_graph (bool): Compute graph On The Fly (OTF) max_num_neighbors (int): Maximum number of neighbors per atom @@ -76,6 +77,7 @@ class SphericalChannelNetwork(nn.Module, GraphModelMixin): def __init__( self, use_pbc: bool = True, + use_pbc_single: bool = True, regress_forces: bool = True, otf_graph: bool = False, max_num_neighbors: int = 20, @@ -107,6 +109,7 @@ def __init__( self.regress_forces = regress_forces self.use_pbc = use_pbc + self.use_pbc_single = use_pbc_single self.cutoff = cutoff self.otf_graph = otf_graph self.show_timing_info = show_timing_info diff --git a/tests/core/e2e/test_s2ef.py b/tests/core/e2e/test_s2ef.py index 54055d0c3b..9a68c4771c 100644 --- a/tests/core/e2e/test_s2ef.py +++ b/tests/core/e2e/test_s2ef.py @@ -319,6 +319,25 @@ def test_train_and_predict( otf_norms=otf_norms, ) + def test_use_pbc_single(self, configs, tutorial_val_src, torch_deterministic): + with tempfile.TemporaryDirectory() as tempdirname: + tempdir = Path(tempdirname) + extra_args = {"seed": 0} + _ = _run_main( + rundir=str(tempdir), + update_dict_with={ + "optim": {"max_epochs": 1}, + "model": {"use_pbc_single": True}, + "dataset": oc20_lmdb_train_and_val_from_paths( + train_src=str(tutorial_val_src), + val_src=str(tutorial_val_src), + test_src=str(tutorial_val_src), + ), + }, + update_run_args_with=extra_args, + input_yaml=configs["equiformer_v2"], + ) + @pytest.mark.parametrize( ("world_size", "ddp"), [
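To make the per-system stitching introduced in `base.py` above concrete, here is a toy, standalone illustration (not fairchem code; the tensors are made up) of how each system's edge indices are shifted by the cumulative atom counts of the preceding systems before being concatenated back into a single batch graph.

```python
# Toy illustration of the index stitching used when use_pbc_single=True:
# each system's edge_index is built independently and then shifted by the
# number of atoms in the systems that come before it in the batch.
import torch

natoms = torch.tensor([3, 2])  # two systems in the batch
edge_index_per_system = [
    torch.tensor([[0, 1, 2], [1, 2, 0]]),  # edges within system 0
    torch.tensor([[0, 1], [1, 0]]),        # edges within system 1
]

atom_index_offset = natoms.cumsum(dim=0).roll(1)
atom_index_offset[0] = 0  # offsets per system: [0, 3]

edge_index = torch.hstack(
    [edge_index_per_system[i] + atom_index_offset[i] for i in range(len(natoms))]
)
print(edge_index)
# tensor([[0, 1, 2, 3, 4],
#         [1, 2, 0, 4, 3]])
```

As the new e2e test above shows, the behaviour is opt-in and is switched on from the model section of the config, e.g. `"model": {"use_pbc_single": True}` layered on top of an existing model config.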