Merge changes during the E3 Hamiltonian model development into main #85
Conversation
```python
else:
    self.out_layer = AtomicMLP(**config[-1], if_batch_normalized=False, in_field=in_field, out_field=out_field, activation=activation, device=device, dtype=dtype)
    nn.init.normal_(self.out_layer.out_layer.weight, mean=0, std=1e-3)
    nn.init.normal_(self.out_layer.out_layer.bias, mean=0, std=1e-3)
    # self.out_norm = nn.LayerNorm(config[-1]['out_features'], elementwise_affine=True)
```
Shouldn't these std values be passed in from the common options?
examples/silicon/nnsk/checkpoint/best_nnsk_b2.600_c2.600_w0.300.pth
dptb/data/dataset/_deeph_dataset.py
```python
        torch.save(atomic_data, os.path.join(file, "AtomicData.pth"))
        return atomic_data

    def len(self) -> int:
        return self.num_examples

    def E3statistics(self, decay=False):
        assert self.transform is not None
        idp = self.transform

        if self.data[AtomicDataDict.EDGE_FEATURES_KEY].abs().sum() < 1e-7:
            return None

        typed_dataset = idp(self.data.clone().to_dict())
        e3h = E3Hamiltonian(basis=idp.basis, decompose=True)
        with torch.no_grad():
            typed_dataset = e3h(typed_dataset)

        stats = {}
        stats["node"] = self._E3nodespecies_stat(typed_dataset=typed_dataset)
        stats["edge"] = self._E3edgespecies_stat(typed_dataset=typed_dataset, decay=decay)

        return stats

    def _E3edgespecies_stat(self, typed_dataset, decay):
        # we get the bond type marked dataset first
        idp = self.transform

        idp.get_irreps(no_parity=False)
        irrep_slices = idp.orbpair_irreps.slices()
```
Is this statistics code the same in deeph_dataset as in our own default_dataset? If it is the same, could it be reused?
Yes, we will factor it out later.
The main changes are listed below:
Add support for an element-specific cutoff in the AtomicData class. r_max can now be given as a dict such as {A: ra, B: rb}; the neighbour list then uses cutoff ra for A-A bonds, rb for B-B bonds, and (ra+rb)/2 for A-B bonds.
Add a dataset class, DeePHDataset, that loads atomic data in DeePH's format directly, and add an E3statistics method to it.
Add support for loading density-matrix data, stored as an HDF5 file named DM.h5. Refactor the data loading so the user can decide which quantities to load (hamiltonian/eigenvalues/overlap/density matrix) via the tags get_hamiltonian/get_eigenvalues/get_overlap/get_DM.
Refactor the data-parsing module to align with the change above, so the user can choose which of hamiltonian/overlap/eigenvalues/DM to parse.
Update the E3-equivariant tensor-fitting model, including a nonlocal and a local variant.
Implement the SO(2) convolution and its edge-wise parameterization in our method.
Accelerate the E3 Hamiltonian rotation with precalculated JD matrices.
Add a diagonal mapping in the OrbitalMapper class.
Write a prototype to_model_v1 method in the nnsk class to transcribe the v2 model into a v1 JSON model.
Add several features to the loss module: a) support a shift of the diagonal elements during E3 Hamiltonian fitting; b) a per-bond balanced loss that computes the error of each bond separately and sums the per-bond errors with equal weight, to avoid problems from data scarcity and imbalance; c) a HamilLossAnalysis class to analyze the error (per bond, per element, per irrep, and overall) between the predicted Hamiltonian and the reference one.
Adjust argcheck accordingly.
Fix bugs: a) the ref_batch data were using the train_loader's batch information; b) an iterator error in ref_loader.
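The element-specific cutoff rule described in the changes above ({A: ra, B: rb}, with mixed bonds cut at (ra+rb)/2) can be sketched as follows; `pair_cutoff` and the r_max values are illustrative, not dptb's actual API:

```python
def pair_cutoff(r_max: dict, elem_i: str, elem_j: str) -> float:
    """Cutoff radius for an elem_i-elem_j bond: ra for A-A, rb for B-B,
    and the arithmetic mean (ra + rb) / 2 for a mixed A-B bond."""
    return 0.5 * (r_max[elem_i] + r_max[elem_j])

r_max = {"Si": 2.6, "O": 1.8}  # hypothetical element-wise cutoffs (Angstrom)
pair_cutoff(r_max, "Si", "Si")  # 2.6
pair_cutoff(r_max, "Si", "O")   # 2.2
```

A bond A-B is then kept in the neighbour list only when its length is below pair_cutoff for that element pair.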
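The tag-based loading described above might be driven by a config fragment like the following (a hypothetical sketch; only the four tag names come from this PR, the surrounding structure is assumed):

```python
# Hypothetical options dict: each tag toggles whether that quantity is
# loaded from the dataset (DM.h5 is read only when get_DM is True).
data_options = {
    "get_hamiltonian": True,
    "get_overlap": True,
    "get_eigenvalues": False,
    "get_DM": False,
}
```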
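The SO(2) convolution mentioned above can be sketched for a single degree l (assuming an eSCN-style parameterization; `so2_conv_l` and its signature are illustrative, not dptb's actual code). After an edge is rotated onto a fixed axis, the convolution only mixes the +m and -m components of each degree, so every m > 0 pair needs just two weights:

```python
import torch

def so2_conv_l(x: torch.Tensor, w0: float, w: torch.Tensor) -> torch.Tensor:
    """x: (..., 2l+1) features of a single degree l, components ordered m = -l..l.
    w0: weight for the m = 0 component; w: (l, 2) weights (w_re, w_im) for m > 0."""
    l = (x.shape[-1] - 1) // 2
    out = torch.zeros_like(x)
    out[..., l] = w0 * x[..., l]  # m = 0: plain scaling
    for m in range(1, l + 1):
        x_p, x_n = x[..., l + m], x[..., l - m]    # +m and -m components
        w_re, w_im = w[m - 1, 0], w[m - 1, 1]
        out[..., l + m] = w_re * x_p - w_im * x_n  # 2x2 rotation-like mixing
        out[..., l - m] = w_im * x_p + w_re * x_n  # within the (m, -m) pair
    return out
```

In the edge-wise parameterization, w0 and w would be produced per edge (e.g. from the bond length and types) rather than shared.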
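The per-bond balanced loss described above could take roughly this form (an assumed sketch, not the exact dptb implementation): the error of each bond type is averaged separately, and the per-type means are then combined with equal weight, so abundant bond types cannot drown out rare ones.

```python
import torch

def per_bond_balanced_mse(pred: torch.Tensor, ref: torch.Tensor,
                          bond_type: torch.Tensor) -> torch.Tensor:
    """pred, ref: (n_edges, n_feat); bond_type: (n_edges,) integer bond-type labels.
    Computes the MSE of each bond type separately, then averages the per-type
    errors with equal weight regardless of how many bonds each type has."""
    losses = [((pred[bond_type == t] - ref[bond_type == t]) ** 2).mean()
              for t in bond_type.unique()]
    return torch.stack(losses).mean()
```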