"""Train an M3GNet potential with matgl on ionic steps parsed from VASP runs.

The script parses every ``*.xml`` file in ``./test_xml_data`` as a vasprun
file, collects the structure, free energy, forces and stress of each ionic
step, builds an ``M3GNetDataset``, splits it 80/10/10, trains for one epoch
with PyTorch Lightning (4 CUDA devices, DDP), runs the test loop, and finally
saves the model to ``./trained_model/mgl.m3g_out`` and reloads it.
"""
from __future__ import annotations

import json
import os
import shutil
import warnings

import numpy as np
import pytorch_lightning as pl
import torch
from dgl.data.utils import split_dataset
from pytorch_lightning.loggers import CSVLogger
from pymatgen.io.vasp.outputs import Vasprun

import matgl
from matgl.ext.pymatgen import Structure2Graph, get_element_list
from matgl.graph.data import M3GNetDataset, MGLDataLoader, collate_fn_efs
from matgl.models import M3GNet
from matgl.utils.training import PotentialLightningModule

# To suppress warnings for clearer output
warnings.simplefilter("ignore")

torch.set_default_device('cuda')
AVAIL_GPUS = torch.cuda.device_count()

folder_path = './test_xml_data'
# os.listdir returns entries in arbitrary order; sort them so the seeded
# split below always sees the samples in the same order.
xml_files = sorted(f for f in os.listdir(folder_path) if f.endswith(".xml"))

# Parallel lists: index k of each list describes the same ionic step.
structures = []
energies = []
forces = []
stresses = []

for xml_file in xml_files:
    xml_file_path = os.path.join(folder_path, xml_file)
    # Best-effort parsing: a file that cannot be read is reported and skipped.
    try:
        vrun = Vasprun(xml_file_path)
        print(f"File: {xml_file} loaded")
        for step in vrun.ionic_steps:
            # Read all four fields BEFORE appending anything, so a step with a
            # missing key cannot leave the four lists with different lengths.
            step_structure = step['structure']
            step_energy = step['e_fr_energy']
            step_forces = step['forces']
            # NOTE(review): confirm that the units and sign convention of
            # 'stress' match what the matgl trainer expects - TODO verify.
            step_stress = step['stress']

            structures.append(step_structure)
            energies.append(step_energy)
            forces.append(step_forces)
            stresses.append(step_stress)
    except Exception as e:
        print(f"Error parsing {xml_file}: {e}")

labels = {
    "energies": energies,
    "forces": forces,
    "stresses": stresses,
}

print(f"{len(structures)} structures loaded from {folder_path}.")

# Fail early with a clear message instead of an obscure error further down.
if not structures:
    raise RuntimeError(
        f"No ionic steps could be loaded from {folder_path!r}; nothing to train on."
    )

# Optional dump of the collected labels (kept disabled, as in the original):
# formatted_data = json.dumps(labels, indent=4)
# with open("labels.json", "w") as json_file:
#     json_file.write(formatted_data)

element_types = get_element_list(structures)
converter = Structure2Graph(element_types=element_types, cutoff=5.0)

dataset = M3GNetDataset(
    threebody_cutoff=4.0,
    structures=structures,
    converter=converter,
    # Changed when downgrading to matgl 0.7.1: the labels are passed as three
    # separate arguments instead of ``labels=labels``.
    energies=energies,
    forces=forces,
    stresses=stresses,
)

train_data, val_data, test_data = split_dataset(
    dataset,
    frac_list=[0.8, 0.1, 0.1],
    shuffle=True,
    random_state=42,
)

train_loader, val_loader, test_loader = MGLDataLoader(
    train_data=train_data,
    val_data=val_data,
    test_data=test_data,
    collate_fn=collate_fn_efs,
    batch_size=8,
    num_workers=4,
    # A real boolean; the previous string 'True' only worked by being truthy.
    use_ddp=True,
    # The default device is CUDA (set above), so the generator is created on
    # CUDA as well.
    generator=torch.Generator("cuda"),
)

model = M3GNet(
    element_types=element_types,
    is_intensive=False,
    use_smooth=True,
)

lr = 1e-4
lit_module = PotentialLightningModule(model=model, lr=lr)

logger = CSVLogger("logs", name="M3GNet_training")

# If you wish to disable GPU or MPS (M1 mac) training, use the
# accelerator="cpu" kwarg.
# Inference mode = False is required for calculating forces and stresses in
# test mode and prediction mode.
# NOTE(review): ``devices`` is hard-coded to 4 while AVAIL_GPUS is computed
# above but unused - confirm whether it should follow the detected GPU count.
trainer = pl.Trainer(
    max_epochs=1,
    accelerator="cuda",
    devices=4,
    strategy="ddp",
    logger=logger,
    inference_mode=False,
)
trainer.fit(model=lit_module, train_dataloaders=train_loader, val_dataloaders=val_loader)
trainer.test(dataloaders=test_loader)

model_export_path = './trained_model/mgl.m3g_out'
# Under DDP each process reaches this point; only the global rank-0 process
# writes (and re-reads) the export so the processes do not write the same
# files concurrently. Other ranks keep the in-memory trained model.
if trainer.is_global_zero:
    model.save(model_export_path)
    model = matgl.load_model(path=model_export_path)