Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(pt): add dpa3 alpha descriptor #4476

Merged
merged 5 commits into from
Dec 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 114 additions & 0 deletions deepmd/dpmodel/descriptor/dpa3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# SPDX-License-Identifier: LGPL-3.0-or-later


class RepFlowArgs:
def __init__(
self,
n_dim: int = 128,
e_dim: int = 64,
a_dim: int = 64,
nlayers: int = 6,
e_rcut: float = 6.0,
e_rcut_smth: float = 5.0,
e_sel: int = 120,
a_rcut: float = 4.0,
a_rcut_smth: float = 3.5,
a_sel: int = 20,
a_compress_rate: int = 0,
axis_neuron: int = 4,
update_angle: bool = True,
update_style: str = "res_residual",
update_residual: float = 0.1,
update_residual_init: str = "const",
) -> None:
r"""The constructor for the RepFlowArgs class which defines the parameters of the repflow block in DPA3 descriptor.

Parameters
----------
n_dim : int, optional
The dimension of node representation.
e_dim : int, optional
The dimension of edge representation.
a_dim : int, optional
The dimension of angle representation.
nlayers : int, optional
Number of repflow layers.
e_rcut : float, optional
The edge cut-off radius.
e_rcut_smth : float, optional
Where to start smoothing for edge. For example the 1/r term is smoothed from rcut to rcut_smth.
e_sel : int, optional
Maximally possible number of selected edge neighbors.
a_rcut : float, optional
The angle cut-off radius.
a_rcut_smth : float, optional
Where to start smoothing for angle. For example the 1/r term is smoothed from rcut to rcut_smth.
a_sel : int, optional
Maximally possible number of selected angle neighbors.
a_compress_rate : int, optional
The compression rate for angular messages. The default value is 0, indicating no compression.
If a non-zero integer c is provided, the node and edge dimensions will be compressed
to n_dim/c and e_dim/2c, respectively, within the angular message.
axis_neuron : int, optional
The number of dimension of submatrix in the symmetrization ops.
update_angle : bool, optional
Where to update the angle rep. If not, only node and edge rep will be used.
update_style : str, optional
Style to update a representation.
Supported options are:
-'res_avg': Updates a rep `u` with: u = 1/\\sqrt{n+1} (u + u_1 + u_2 + ... + u_n)
-'res_incr': Updates a rep `u` with: u = u + 1/\\sqrt{n} (u_1 + u_2 + ... + u_n)
-'res_residual': Updates a rep `u` with: u = u + (r1*u_1 + r2*u_2 + ... + r3*u_n)
where `r1`, `r2` ... `r3` are residual weights defined by `update_residual`
and `update_residual_init`.
update_residual : float, optional
When update using residual mode, the initial std of residual vector weights.
update_residual_init : str, optional
When update using residual mode, the initialization mode of residual vector weights.
"""
self.n_dim = n_dim
self.e_dim = e_dim
self.a_dim = a_dim
self.nlayers = nlayers
self.e_rcut = e_rcut
self.e_rcut_smth = e_rcut_smth
self.e_sel = e_sel
self.a_rcut = a_rcut
self.a_rcut_smth = a_rcut_smth
self.a_sel = a_sel
self.a_compress_rate = a_compress_rate
self.axis_neuron = axis_neuron
self.update_angle = update_angle
self.update_style = update_style
self.update_residual = update_residual
self.update_residual_init = update_residual_init

def __getitem__(self, key):
if hasattr(self, key):
return getattr(self, key)
else:
raise KeyError(key)

def serialize(self) -> dict:
return {
"n_dim": self.n_dim,
"e_dim": self.e_dim,
"a_dim": self.a_dim,
"nlayers": self.nlayers,
"e_rcut": self.e_rcut,
"e_rcut_smth": self.e_rcut_smth,
"e_sel": self.e_sel,
"a_rcut": self.a_rcut,
"a_rcut_smth": self.a_rcut_smth,
"a_sel": self.a_sel,
"a_compress_rate": self.a_compress_rate,
"axis_neuron": self.axis_neuron,
"update_angle": self.update_angle,
"update_style": self.update_style,
"update_residual": self.update_residual,
"update_residual_init": self.update_residual_init,
}

@classmethod
def deserialize(cls, data: dict) -> "RepFlowArgs":
return cls(**data)
71 changes: 40 additions & 31 deletions deepmd/pt/loss/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,28 +187,26 @@
)
# more_loss['log_keys'].append('rmse_e')
else: # use l1 and for all atoms
energy_pred = energy_pred * atom_norm
energy_label = energy_label * atom_norm

Check warning on line 191 in deepmd/pt/loss/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/loss/ener.py#L190-L191

Added lines #L190 - L191 were not covered by tests
l1_ener_loss = F.l1_loss(
energy_pred.reshape(-1),
energy_label.reshape(-1),
reduction="sum",
reduction="mean",
)
loss += pref_e * l1_ener_loss
more_loss["mae_e"] = self.display_if_exist(
F.l1_loss(
energy_pred.reshape(-1),
energy_label.reshape(-1),
reduction="mean",
).detach(),
l1_ener_loss.detach(),
find_energy,
)
# more_loss['log_keys'].append('rmse_e')
if mae:
mae_e = torch.mean(torch.abs(energy_pred - energy_label)) * atom_norm
more_loss["mae_e"] = self.display_if_exist(mae_e.detach(), find_energy)
mae_e_all = torch.mean(torch.abs(energy_pred - energy_label))
more_loss["mae_e_all"] = self.display_if_exist(
mae_e_all.detach(), find_energy
)
# if mae:
# mae_e = torch.mean(torch.abs(energy_pred - energy_label)) * atom_norm
# more_loss["mae_e"] = self.display_if_exist(mae_e.detach(), find_energy)
# mae_e_all = torch.mean(torch.abs(energy_pred - energy_label))
# more_loss["mae_e_all"] = self.display_if_exist(
# mae_e_all.detach(), find_energy
Comment on lines +203 to +208

Check notice

Code scanning / CodeQL

Commented-out code Note

This comment appears to contain commented-out code.
# )

if (
(self.has_f or self.has_pf or self.relative_f or self.has_gf)
Expand Down Expand Up @@ -241,17 +239,17 @@
rmse_f.detach(), find_force
)
else:
l1_force_loss = F.l1_loss(force_label, force_pred, reduction="none")
l1_force_loss = F.l1_loss(force_label, force_pred, reduction="mean")

Check warning on line 242 in deepmd/pt/loss/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/loss/ener.py#L242

Added line #L242 was not covered by tests
more_loss["mae_f"] = self.display_if_exist(
l1_force_loss.mean().detach(), find_force
l1_force_loss.detach(), find_force
)
l1_force_loss = l1_force_loss.sum(-1).mean(-1).sum()
# l1_force_loss = l1_force_loss.sum(-1).mean(-1).sum()
loss += (pref_f * l1_force_loss).to(GLOBAL_PT_FLOAT_PRECISION)
if mae:
mae_f = torch.mean(torch.abs(diff_f))
more_loss["mae_f"] = self.display_if_exist(
mae_f.detach(), find_force
)
# if mae:
# mae_f = torch.mean(torch.abs(diff_f))
# more_loss["mae_f"] = self.display_if_exist(
# mae_f.detach(), find_force
Comment on lines +248 to +251

Check notice

Code scanning / CodeQL

Commented-out code Note

This comment appears to contain commented-out code.
# )

if self.has_pf and "atom_pref" in label:
atom_pref = label["atom_pref"]
Expand Down Expand Up @@ -297,18 +295,29 @@
if self.has_v and "virial" in model_pred and "virial" in label:
find_virial = label.get("find_virial", 0.0)
pref_v = pref_v * find_virial
virial_label = label["virial"]
virial_pred = model_pred["virial"].reshape(-1, 9)
diff_v = label["virial"] - model_pred["virial"].reshape(-1, 9)
l2_virial_loss = torch.mean(torch.square(diff_v))
if not self.inference:
more_loss["l2_virial_loss"] = self.display_if_exist(
l2_virial_loss.detach(), find_virial
if not self.use_l1_all:
l2_virial_loss = torch.mean(torch.square(diff_v))
if not self.inference:
more_loss["l2_virial_loss"] = self.display_if_exist(
l2_virial_loss.detach(), find_virial
)
loss += atom_norm * (pref_v * l2_virial_loss)
rmse_v = l2_virial_loss.sqrt() * atom_norm
more_loss["rmse_v"] = self.display_if_exist(
rmse_v.detach(), find_virial
)
else:
l1_virial_loss = F.l1_loss(virial_label, virial_pred, reduction="mean")
more_loss["mae_v"] = self.display_if_exist(

Check warning on line 314 in deepmd/pt/loss/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/loss/ener.py#L313-L314

Added lines #L313 - L314 were not covered by tests
l1_virial_loss.detach(), find_virial
)
loss += atom_norm * (pref_v * l2_virial_loss)
rmse_v = l2_virial_loss.sqrt() * atom_norm
more_loss["rmse_v"] = self.display_if_exist(rmse_v.detach(), find_virial)
if mae:
mae_v = torch.mean(torch.abs(diff_v)) * atom_norm
more_loss["mae_v"] = self.display_if_exist(mae_v.detach(), find_virial)
loss += (pref_v * l1_virial_loss).to(GLOBAL_PT_FLOAT_PRECISION)

Check warning on line 317 in deepmd/pt/loss/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/loss/ener.py#L317

Added line #L317 was not covered by tests
# if mae:
# mae_v = torch.mean(torch.abs(diff_v)) * atom_norm
# more_loss["mae_v"] = self.display_if_exist(mae_v.detach(), find_virial)
Comment on lines +318 to +320

Check notice

Code scanning / CodeQL

Commented-out code Note

This comment appears to contain commented-out code.

if self.has_ae and "atom_energy" in model_pred and "atom_ener" in label:
atom_ener = model_pred["atom_energy"]
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
from .dpa2 import (
DescrptDPA2,
)
from .dpa3 import (
DescrptDPA3,
)
from .env_mat import (
prod_env_mat,
)
Expand Down Expand Up @@ -49,6 +52,7 @@
"DescrptBlockSeTTebd",
"DescrptDPA1",
"DescrptDPA2",
"DescrptDPA3",
"DescrptHybrid",
"DescrptSeA",
"DescrptSeAttenV2",
Expand Down
Loading
Loading