-
Notifications
You must be signed in to change notification settings - Fork 0
/
sevenn_runner.py
161 lines (141 loc) · 5.09 KB
/
sevenn_runner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import os
import time
from typing import Optional, Union

import numpy as np
import torch
import torch.jit
from ase.calculators.calculator import Calculator, all_changes
from ase.data import chemical_symbols

import sevenn._const
import sevenn._keys as KEY
import sevenn.util
# Class used by SevenNetCalculator.calculate to recognise TorchScript models;
# those receive a plain dict (``data.to_dict()``) instead of the graph-data
# object.
# NOTE(review): ``torch.jit._script`` is a private module path; the public
# base class ``torch.jit.ScriptModule`` may be a more stable choice — confirm
# against the supported torch versions.
torch_script_type = torch.jit._script.RecursiveScriptModule
class SevenNetCalculator(Calculator):
    """ASE calculator for SevenNet models.

    Multi-GPU parallel MD is not supported in this mode; use LAMMPS for
    multi-GPU parallel MD. This class is a convenience for users who want
    to run SevenNet models from ASE. ASE calculators are usually thin
    interfaces to external programs, but here the torch model runs directly
    inside the calculator, so there is no file I/O per evaluation.

    ``free_energy`` is reported equal to ``energy``.
    """

    def __init__(
        self,
        model: str = 'SevenNet-0',
        file_type: str = 'checkpoint',
        device: Union[torch.device, str] = 'auto',
        sevennet_config=None,
        max_force: Optional[float] = 50.0,
        **kwargs,
    ):
        """Initialize the calculator.

        Args:
            model: Path to a checkpoint or TorchScript file. For
                ``file_type='checkpoint'`` a non-existent path is treated as
                a pretrained model name and resolved with
                ``sevenn.util.pretrained_name_to_path``.
            file_type: ``'checkpoint'`` or ``'torchscript'``
                (case-insensitive).
            device: Torch device or device string. ``'auto'`` (default)
                selects CUDA when available, otherwise CPU.
            sevennet_config: Initial value of ``self.sevennet_config``. When
                loading a checkpoint it is replaced by the checkpoint's config.
            max_force: Symmetric bound applied to every predicted force
                component, i.e. each component is clipped to
                ``[-max_force, max_force]`` (same units as the model's force
                output). ``None`` disables clipping.
            **kwargs: Forwarded to ``ase.calculators.calculator.Calculator``.

        Raises:
            ValueError: If ``file_type`` is unsupported, ``device`` is neither
                a ``torch.device`` nor a ``str``, or ``max_force`` is not a
                positive number.
        """
        super().__init__(**kwargs)

        file_type = file_type.lower()
        if file_type not in ('checkpoint', 'torchscript'):
            raise ValueError('file_type should be checkpoint or torchscript')
        if not isinstance(device, (torch.device, str)):
            raise ValueError(
                'device must be an instance of torch.device or str.'
            )
        # ``not > 0`` also rejects NaN.
        if max_force is not None and not max_force > 0:
            raise ValueError('max_force must be a positive number or None.')

        if isinstance(device, str):
            if device == 'auto':
                device = 'cuda' if torch.cuda.is_available() else 'cpu'
            self.device = torch.device(device)
        else:
            self.device = device

        self.max_force = None if max_force is None else float(max_force)
        # Always define the attribute; the TorchScript path has no config.
        self.sevennet_config = sevennet_config

        if file_type == 'checkpoint':
            # Accept either an existing file path or a pretrained model name.
            if os.path.isfile(model):
                checkpoint = model
            else:
                checkpoint = sevenn.util.pretrained_name_to_path(model)
            model_loaded, config = sevenn.util.model_from_checkpoint(
                checkpoint
            )
            # The calculator feeds one structure at a time, not a batch.
            model_loaded.set_is_batch_data(False)
            self.type_map = config[KEY.TYPE_MAP]
            self.cutoff = config[KEY.CUTOFF]
            self.sevennet_config = config
        else:  # file_type == 'torchscript' (validated above)
            # torch.jit.load fills these placeholders with the extra files
            # stored alongside the serialized model.
            extra_dict = {
                'chemical_symbols_to_index': b'',
                'cutoff': b'',
                'num_species': b'',
                'model_type': b'',
                'version': b'',
                'dtype': b'',
                'time': b'',
            }
            model_loaded = torch.jit.load(
                model, _extra_files=extra_dict, map_location=self.device
            )
            chem_symbols = extra_dict['chemical_symbols_to_index'].decode(
                'utf-8'
            )
            # Position in ase.data.chemical_symbols == atomic number.
            sym_to_num = {sym: n for n, sym in enumerate(chemical_symbols)}
            # Atomic number -> model species index (position of the symbol
            # in the whitespace-separated list stored with the model).
            self.type_map = {
                sym_to_num[sym]: i
                for i, sym in enumerate(chem_symbols.split())
            }
            self.cutoff = float(extra_dict['cutoff'].decode('utf-8'))

        self.model = model_loaded
        self.model.to(self.device)
        self.model.eval()
        self.implemented_properties = [
            'free_energy',
            'energy',
            'forces',
            'stress',
            'energies',
        ]

    def calculate(
        self, atoms=None, properties=None, system_changes=all_changes
    ):
        """Evaluate the model and store all properties in ``self.results``.

        All implemented properties are computed in a single forward pass,
        regardless of ``properties``.

        Args:
            atoms: Structure to evaluate. If None, the atoms already attached
                to this calculator (``self.atoms``) are used.
            properties: Forwarded to ``Calculator.calculate``.
            system_changes: Forwarded to ``Calculator.calculate``.

        Raises:
            ValueError: If no atoms are passed and none are attached.
        """
        # The parent call stores a copy of ``atoms`` on ``self.atoms`` when
        # one is given.
        Calculator.calculate(self, atoms, properties, system_changes)
        if atoms is None:
            atoms = self.atoms
        if atoms is None:
            raise ValueError('No atoms to evaluate')

        data = sevenn.util.unlabeled_atoms_to_input(atoms, self.cutoff)
        # Replace atomic numbers with the model's species indices.
        data[KEY.NODE_FEATURE] = torch.LongTensor(
            [self.type_map[z.item()] for z in data[KEY.NODE_FEATURE]]
        )
        data.to(self.device)
        if isinstance(self.model, torch_script_type):
            # TorchScript models take a plain dict without 'data_info'.
            data = data.to_dict()
            del data['data_info']
        output = self.model(data)

        energy = output[KEY.PRED_TOTAL_ENERGY].detach().cpu().item()
        atomic_energies = (
            output[KEY.ATOMIC_ENERGY]
            .detach()
            .cpu()
            .reshape(len(atoms))
            .numpy()
        )
        forces = output[KEY.PRED_FORCE].detach().cpu().numpy()
        if self.max_force is not None:
            # Clip both signs so the bound is symmetric. Clipped forces are no
            # longer the exact negative gradient of the returned energy.
            forces = np.clip(forces, -self.max_force, self.max_force)
        # Sign flip plus reordering into ASE Voigt order
        # (xx, yy, zz, yz, xz, xy); presumably the model emits
        # (xx, yy, zz, xy, yz, zx) — confirm against the model definition.
        stress = np.array(
            (-output[KEY.PRED_STRESS])
            .detach()
            .cpu()
            .numpy()[[0, 1, 2, 4, 5, 3]]
        )

        self.results = {
            'free_energy': energy,
            'energy': energy,
            'energies': atomic_energies,
            'forces': forces,
            'stress': stress,
        }