-
Notifications
You must be signed in to change notification settings - Fork 527
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(pt): add dpa3 alpha descriptor #4476
Conversation
|
||
# deserialize repflow | ||
statistic_repflows = repflow_variable.pop("@variables") | ||
env_mat = repflow_variable.pop("env_mat") |
Check notice
Code scanning / CodeQL
Unused local variable Note
# cast the input to internal precsion | ||
extended_coord = extended_coord.to(dtype=self.prec) | ||
nframes, nloc, nnei = nlist.shape | ||
nall = extended_coord.view(nframes, -1).shape[1] // 3 |
Check notice
Code scanning / CodeQL
Unused local variable Note
# h2: nb x nloc x nnei x 3 | ||
# msk: nb x nloc x nnei | ||
nb, nloc, nnei, _ = edge_ebd.shape | ||
e_dim = edge_ebd.shape[-1] |
Check notice
Code scanning / CodeQL
Unused local variable Note
def list_update_res_residual( | ||
self, update_list: list[torch.Tensor], update_name: str = "node" | ||
) -> torch.Tensor: | ||
nitem = len(update_list) |
Check notice
Code scanning / CodeQL
Unused local variable Note
[False], # use_econf_tebd | ||
): | ||
dtype = PRECISION_DICT[prec] | ||
rtol, atol = get_tols(prec) |
Check notice
Code scanning / CodeQL
Unused local variable Note test
[False], # use_econf_tebd | ||
): | ||
dtype = PRECISION_DICT[prec] | ||
rtol, atol = get_tols(prec) |
Check notice
Code scanning / CodeQL
Unused local variable Note test
|
||
dd0.repflows.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE) | ||
dd0.repflows.stddev = torch.tensor(dstd, dtype=dtype, device=env.DEVICE) | ||
model = torch.jit.script(dd0) |
Check notice
Code scanning / CodeQL
Unused local variable Note test
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## dpa3-alpha #4476 +/- ##
==============================================
+ Coverage 83.75% 84.48% +0.72%
==============================================
Files 667 674 +7
Lines 61514 62824 +1310
Branches 3486 3486
==============================================
+ Hits 51523 53074 +1551
+ Misses 8866 8623 -243
- Partials 1125 1127 +2 ☔ View full report in Codecov by Sentry. |
# if mae: | ||
# mae_e = torch.mean(torch.abs(energy_pred - energy_label)) * atom_norm | ||
# more_loss["mae_e"] = self.display_if_exist(mae_e.detach(), find_energy) | ||
# mae_e_all = torch.mean(torch.abs(energy_pred - energy_label)) | ||
# more_loss["mae_e_all"] = self.display_if_exist( | ||
# mae_e_all.detach(), find_energy |
Check notice
Code scanning / CodeQL
Commented-out code Note
# if mae: | ||
# mae_f = torch.mean(torch.abs(diff_f)) | ||
# more_loss["mae_f"] = self.display_if_exist( | ||
# mae_f.detach(), find_force |
Check notice
Code scanning / CodeQL
Commented-out code Note
# if mae: | ||
# mae_v = torch.mean(torch.abs(diff_v)) * atom_norm | ||
# more_loss["mae_v"] = self.display_if_exist(mae_v.detach(), find_virial) |
Check notice
Code scanning / CodeQL
Commented-out code Note
This PR is an early experimental preview version of DPA3. Significant changes may occur in subsequent updates. Please use with caution.