Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat (tf/pt): add atomic weights to tensor loss #4466

Merged
merged 6 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions deepmd/pt/loss/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
pref_atomic: float = 0.0,
pref: float = 0.0,
inference=False,
enable_atomic_weight: bool = False,
**kwargs,
) -> None:
r"""Construct a loss for local and global tensors.
Expand All @@ -40,6 +41,8 @@
The prefactor of the weight of global loss. It should be larger than or equal to 0.
inference : bool
If true, it will output all losses found in output, ignoring the pre-factors.
enable_atomic_weight : bool
If true, atomic weight will be used in the loss calculation.
**kwargs
Other keyword arguments.
"""
Expand All @@ -50,6 +53,7 @@
self.local_weight = pref_atomic
self.global_weight = pref
self.inference = inference
self.enable_atomic_weight = enable_atomic_weight

assert (
self.local_weight >= 0.0 and self.global_weight >= 0.0
Expand Down Expand Up @@ -85,6 +89,12 @@
"""
model_pred = model(**input_dict)
del learning_rate, mae

if self.enable_atomic_weight:
atomic_weight = label["atom_weight"].reshape([-1, 1])

Check warning on line 94 in deepmd/pt/loss/tensor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/loss/tensor.py#L94

Added line #L94 was not covered by tests
else:
atomic_weight = 1.0

loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0]
more_loss = {}
if (
Expand All @@ -103,6 +113,7 @@
diff = (local_tensor_pred - local_tensor_label).reshape(
[-1, self.tensor_size]
)
diff = diff * atomic_weight
if "mask" in model_pred:
diff = diff[model_pred["mask"].reshape([-1]).bool()]
l2_local_loss = torch.mean(torch.square(diff))
Expand Down Expand Up @@ -171,4 +182,15 @@
high_prec=False,
)
)
if self.enable_atomic_weight:
label_requirement.append(

Check warning on line 186 in deepmd/pt/loss/tensor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/loss/tensor.py#L186

Added line #L186 was not covered by tests
DataRequirementItem(
"atomic_weight",
ndof=1,
atomic=True,
must=False,
high_prec=False,
default=1.0,
)
)
return label_requirement
22 changes: 21 additions & 1 deletion deepmd/tf/loss/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
# YWolfeee: modify, use pref / pref_atomic, instead of pref_weight / pref_atomic_weight
self.local_weight = jdata.get("pref_atomic", None)
self.global_weight = jdata.get("pref", None)
self.enable_atomic_weight = jdata.get("enable_atomic_weight", False)

assert (
self.local_weight is not None and self.global_weight is not None
Expand All @@ -66,9 +67,16 @@
"global_loss": global_cvt_2_tf_float(0.0),
}

if self.enable_atomic_weight:
atomic_weight = tf.reshape(label_dict["atom_weight"], [-1, 1])

Check warning on line 71 in deepmd/tf/loss/tensor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/loss/tensor.py#L70-L71

Added lines #L70 - L71 were not covered by tests
else:
atomic_weight = global_cvt_2_tf_float(1.0)

Check warning on line 73 in deepmd/tf/loss/tensor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/loss/tensor.py#L73

Added line #L73 was not covered by tests

if self.local_weight > 0.0:
diff = polar - atomic_polar_hat
diff = tf.reshape(diff, [-1, self.tensor_size]) * atomic_weight

Check warning on line 77 in deepmd/tf/loss/tensor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/loss/tensor.py#L76-L77

Added lines #L76 - L77 were not covered by tests
local_loss = global_cvt_2_tf_float(find_atomic) * tf.reduce_mean(
tf.square(self.scale * (polar - atomic_polar_hat)), name="l2_" + suffix
tf.square(self.scale * diff), name="l2_" + suffix
)
more_loss["local_loss"] = self.display_if_exist(local_loss, find_atomic)
l2_loss += self.local_weight * local_loss
Expand Down Expand Up @@ -163,4 +171,16 @@
type_sel=self.type_sel,
)
)
if self.enable_atomic_weight:
data_requirements.append(

Check warning on line 175 in deepmd/tf/loss/tensor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/loss/tensor.py#L175

Added line #L175 was not covered by tests
DataRequirementItem(
"atom_weight",
1,
atomic=True,
must=False,
high_prec=False,
default=1.0,
type_sel=self.type_sel,
)
)
return data_requirements
12 changes: 10 additions & 2 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -2511,8 +2511,9 @@ def loss_property():
def loss_tensor():
# doc_global_weight = "The prefactor of the weight of global loss. It should be larger than or equal to 0. If only `pref` is provided or both are not provided, training will be global mode, i.e. the shape of 'polarizability.npy` or `dipole.npy` should be #frams x [9 or 3]."
# doc_local_weight = "The prefactor of the weight of atomic loss. It should be larger than or equal to 0. If only `pref_atomic` is provided, training will be atomic mode, i.e. the shape of `polarizability.npy` or `dipole.npy` should be #frames x ([9 or 3] x #selected atoms). If both `pref` and `pref_atomic` are provided, training will be combined mode, and atomic label should be provided as well."
doc_global_weight = "The prefactor of the weight of global loss. It should be larger than or equal to 0. If controls the weight of loss corresponding to global label, i.e. 'polarizability.npy` or `dipole.npy`, whose shape should be #frames x [9 or 3]. If it's larger than 0.0, this npy should be included."
doc_local_weight = "The prefactor of the weight of atomic loss. It should be larger than or equal to 0. If controls the weight of loss corresponding to atomic label, i.e. `atomic_polarizability.npy` or `atomic_dipole.npy`, whose shape should be #frames x ([9 or 3] x #selected atoms). If it's larger than 0.0, this npy should be included. Both `pref` and `pref_atomic` should be provided, and either can be set to 0.0."
doc_global_weight = "The prefactor of the weight of global loss. It should be larger than or equal to 0. It controls the weight of loss corresponding to global label, i.e. 'polarizability.npy` or `dipole.npy`, whose shape should be #frames x [9 or 3]. If it's larger than 0.0, this npy should be included."
doc_local_weight = "The prefactor of the weight of atomic loss. It should be larger than or equal to 0. It controls the weight of loss corresponding to atomic label, i.e. `atomic_polarizability.npy` or `atomic_dipole.npy`, whose shape should be #frames x ([9 or 3] x #atoms). If it's larger than 0.0, this npy should be included. Both `pref` and `pref_atomic` should be provided, and either can be set to 0.0."
doc_enable_atomic_weight = "If true, the atomic loss will be reweighted."
return [
Argument(
"pref", [float, int], optional=False, default=None, doc=doc_global_weight
Expand All @@ -2524,6 +2525,13 @@ def loss_tensor():
default=None,
doc=doc_local_weight,
),
Argument(
"enable_atomic_weight",
bool,
optional=True,
default=False,
doc=doc_enable_atomic_weight,
),
]


Expand Down
Loading