-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for atomic tensor, like NMR
- Loading branch information
Showing
5 changed files
with
447 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
## Config files for atomic tensor (i.e. a tensor value for each atom) | ||
|
||
seed_everything: 35 | ||
log_level: info | ||
|
||
data: | ||
tensor_target_name: nmr_tensor | ||
atom_selector: atom_selector | ||
tensor_target_formula: ij=ji | ||
root: . | ||
trainset_filename: /Users/mjwen.admin/Documents/Dataset/NMR_tensor/si_nmr_data_small.json | ||
valset_filename: /Users/mjwen.admin/Documents/Dataset/NMR_tensor/si_nmr_data_small.json | ||
testset_filename: /Users/mjwen.admin/Documents/Dataset/NMR_tensor/si_nmr_data_small.json | ||
r_cut: 5.0 | ||
reuse: false | ||
loader_kwargs: | ||
batch_size: 2 | ||
shuffle: true | ||
|
||
model: | ||
########## | ||
# embedding | ||
########## | ||
|
||
# atom species embedding | ||
species_embedding_dim: 16 | ||
|
||
# spherical harmonics embedding of edge direction | ||
irreps_edge_sh: 0e + 1o + 2e | ||
|
||
# radial edge distance embedding | ||
radial_basis_type: bessel | ||
num_radial_basis: 8 | ||
radial_basis_start: 0. | ||
radial_basis_end: 5. | ||
|
||
########## | ||
# message passing conv layers | ||
########## | ||
num_layers: 3 | ||
|
||
# radial network | ||
invariant_layers: 2 # number of radial layers | ||
invariant_neurons: 32 # number of hidden neurons in radial function | ||
|
||
# Average number of neighbors used for normalization. Options: | ||
# 1. `auto` to determine it automatically, by setting it to average number | ||
# of neighbors of the training set | ||
# 2. float or int provided here. | ||
# 3. `null` to not use it | ||
average_num_neighbors: auto | ||
|
||
# point convolution | ||
conv_layer_irreps: 32x0o+32x0e + 16x1o+16x1e + 4x2o+4x2e | ||
nonlinearity_type: gate | ||
normalization: batch | ||
resnet: true | ||
|
||
########## | ||
# output | ||
########## | ||
|
||
# output_format and output_formula should be used together. | ||
# - output_format (can be `irreps` or `cartesian`) determines what the loss | ||
# function will be on (either on the irreps space or the cartesian space). | ||
# - output_formula gives what the cartesian formula of the tensor is. | ||
# For example, ijkl=jikl=klij specifies a forth-rank elasticity tensor. | ||
output_format: irreps | ||
output_formula: ij=ji | ||
|
||
# pooling node feats to graph feats | ||
reduce: mean | ||
|
||
trainer: | ||
max_epochs: 10 # number of maximum training epochs | ||
num_nodes: 1 | ||
accelerator: cpu | ||
devices: 1 | ||
|
||
callbacks: | ||
- class_path: pytorch_lightning.callbacks.ModelCheckpoint | ||
init_args: | ||
monitor: val/score | ||
mode: min | ||
save_top_k: 3 | ||
save_last: true | ||
verbose: false | ||
- class_path: pytorch_lightning.callbacks.EarlyStopping | ||
init_args: | ||
monitor: val/score | ||
mode: min | ||
patience: 150 | ||
min_delta: 0 | ||
verbose: true | ||
- class_path: pytorch_lightning.callbacks.ModelSummary | ||
init_args: | ||
max_depth: -1 | ||
|
||
#logger: | ||
# class_path: pytorch_lightning.loggers.wandb.WandbLogger | ||
# init_args: | ||
# save_dir: matten_logs | ||
# project: matten_proj | ||
|
||
optimizer: | ||
class_path: torch.optim.Adam | ||
init_args: | ||
lr: 0.01 | ||
weight_decay: 0.00001 | ||
|
||
lr_scheduler: | ||
class_path: torch.optim.lr_scheduler.ReduceLROnPlateau | ||
init_args: | ||
mode: min | ||
factor: 0.5 | ||
patience: 50 | ||
verbose: true |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
"""Script to train the materials tensor model.""" | ||
|
||
from pathlib import Path | ||
from typing import Dict, List, Union | ||
|
||
import yaml | ||
from loguru import logger | ||
from pytorch_lightning import Trainer, seed_everything | ||
from pytorch_lightning.cli import instantiate_class as lit_instantiate_class | ||
|
||
from matten.dataset.structure_scalar_tensor import TensorDataModule | ||
from matten.log import set_logger | ||
from matten.model_factory.task import TensorRegressionTask | ||
from matten.model_factory.tfn_atomic_tensor import AtomicTensorModel | ||
|
||
|
||
def instantiate_class(d: Union[Dict, List]): | ||
args = tuple() # no positional args | ||
if isinstance(d, dict): | ||
return lit_instantiate_class(args, d) | ||
elif isinstance(d, list): | ||
return [lit_instantiate_class(args, x) for x in d] | ||
else: | ||
raise ValueError(f"Cannot instantiate class from {d}") | ||
|
||
|
||
def get_args(path: Path): | ||
"""Get the arguments from the config file.""" | ||
with open(path, "r") as f: | ||
config = yaml.safe_load(f) | ||
return config | ||
|
||
|
||
def main(config: Dict): | ||
dm = TensorDataModule(**config["data"]) | ||
dm.prepare_data() | ||
dm.setup() | ||
|
||
model = AtomicTensorModel( | ||
tasks=TensorRegressionTask(name=config["data"]["tensor_target_name"]), | ||
backbone_hparams=config["model"], | ||
dataset_hparams=dm.get_to_model_info(), | ||
optimizer_hparams=config["optimizer"], | ||
lr_scheduler_hparams=config["lr_scheduler"], | ||
) | ||
|
||
try: | ||
callbacks = instantiate_class(config["trainer"].pop("callbacks")) | ||
lit_logger = instantiate_class(config["trainer"].pop("logger")) | ||
except KeyError: | ||
callbacks = None | ||
lit_logger = None | ||
|
||
trainer = Trainer( | ||
callbacks=callbacks, | ||
logger=lit_logger, | ||
**config["trainer"], | ||
) | ||
|
||
logger.info("Start training!") | ||
trainer.fit(model, datamodule=dm) | ||
|
||
# test | ||
logger.info("Start testing!") | ||
trainer.test(ckpt_path="best", datamodule=dm) | ||
|
||
# print path of best checkpoint | ||
logger.info(f"Best checkpoint path: {trainer.checkpoint_callback.best_model_path}") | ||
|
||
|
||
if __name__ == "__main__": | ||
config_file = Path(__file__).parent / "configs" / "atomic_tensor.yaml" | ||
config = get_args(config_file) | ||
|
||
seed = config.get("seed_everything", 1) | ||
seed_everything(seed) | ||
|
||
log_level = config.get("log_level", "INFO") | ||
set_logger(log_level) | ||
|
||
main(config) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.