diff --git a/deepmd/pt/loss/tensor.py b/deepmd/pt/loss/tensor.py index 8f2f937a07..69b133de58 100644 --- a/deepmd/pt/loss/tensor.py +++ b/deepmd/pt/loss/tensor.py @@ -22,6 +22,7 @@ def __init__( pref_atomic: float = 0.0, pref: float = 0.0, inference=False, + enable_atomic_weight: bool = False, **kwargs, ) -> None: r"""Construct a loss for local and global tensors. @@ -40,6 +41,8 @@ def __init__( The prefactor of the weight of global loss. It should be larger than or equal to 0. inference : bool If true, it will output all losses found in output, ignoring the pre-factors. + enable_atomic_weight : bool + If true, atomic weight will be used in the loss calculation. **kwargs Other keyword arguments. """ @@ -50,6 +53,7 @@ def __init__( self.local_weight = pref_atomic self.global_weight = pref self.inference = inference + self.enable_atomic_weight = enable_atomic_weight assert ( self.local_weight >= 0.0 and self.global_weight >= 0.0 @@ -85,6 +89,12 @@ def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False """ model_pred = model(**input_dict) del learning_rate, mae + + if self.enable_atomic_weight: + atomic_weight = label["atom_weight"].reshape([-1, 1]) + else: + atomic_weight = 1.0 + loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0] more_loss = {} if ( @@ -103,6 +113,7 @@ def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False diff = (local_tensor_pred - local_tensor_label).reshape( [-1, self.tensor_size] ) + diff = diff * atomic_weight if "mask" in model_pred: diff = diff[model_pred["mask"].reshape([-1]).bool()] l2_local_loss = torch.mean(torch.square(diff)) @@ -171,4 +182,15 @@ def label_requirement(self) -> list[DataRequirementItem]: high_prec=False, ) ) + if self.enable_atomic_weight: + label_requirement.append( + DataRequirementItem( + "atomic_weight", + ndof=1, + atomic=True, + must=False, + high_prec=False, + default=1.0, + ) + ) return label_requirement diff --git a/deepmd/tf/loss/tensor.py b/deepmd/tf/loss/tensor.py index aca9182ff6..f3fdba6ae8 100644 --- a/deepmd/tf/loss/tensor.py +++ b/deepmd/tf/loss/tensor.py @@ -40,6 +40,7 @@ def __init__(self, jdata, **kwarg) -> None: # YWolfeee: modify, use pref / pref_atomic, instead of pref_weight / pref_atomic_weight self.local_weight = jdata.get("pref_atomic", None) self.global_weight = jdata.get("pref", None) + self.enable_atomic_weight = jdata.get("enable_atomic_weight", False) assert ( self.local_weight is not None and self.global_weight is not None @@ -66,9 +67,16 @@ def build(self, learning_rate, natoms, model_dict, label_dict, suffix): "global_loss": global_cvt_2_tf_float(0.0), } + if self.enable_atomic_weight: + atomic_weight = tf.reshape(label_dict["atom_weight"], [-1, 1]) + else: + atomic_weight = global_cvt_2_tf_float(1.0) + if self.local_weight > 0.0: + diff = polar - atomic_polar_hat + diff = tf.reshape(diff, [-1, self.tensor_size]) * atomic_weight local_loss = global_cvt_2_tf_float(find_atomic) * tf.reduce_mean( - tf.square(self.scale * (polar - atomic_polar_hat)), name="l2_" + suffix + tf.square(self.scale * diff), name="l2_" + suffix ) more_loss["local_loss"] = self.display_if_exist(local_loss, find_atomic) l2_loss += self.local_weight * local_loss @@ -163,4 +171,16 @@ def label_requirement(self) -> list[DataRequirementItem]: type_sel=self.type_sel, ) ) + if self.enable_atomic_weight: + data_requirements.append( + DataRequirementItem( + "atom_weight", + 1, + atomic=True, + must=False, + high_prec=False, + default=1.0, + type_sel=self.type_sel, + ) + ) return data_requirements diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 5b57f15979..9eac0e804d 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -2511,8 +2511,9 @@ def loss_property(): def loss_tensor(): # doc_global_weight = "The prefactor of the weight of global loss. It should be larger than or equal to 0. If only `pref` is provided or both are not provided, training will be global mode, i.e. the shape of 'polarizability.npy` or `dipole.npy` should be #frams x [9 or 3]." # doc_local_weight = "The prefactor of the weight of atomic loss. It should be larger than or equal to 0. If only `pref_atomic` is provided, training will be atomic mode, i.e. the shape of `polarizability.npy` or `dipole.npy` should be #frames x ([9 or 3] x #selected atoms). If both `pref` and `pref_atomic` are provided, training will be combined mode, and atomic label should be provided as well." - doc_global_weight = "The prefactor of the weight of global loss. It should be larger than or equal to 0. If controls the weight of loss corresponding to global label, i.e. 'polarizability.npy` or `dipole.npy`, whose shape should be #frames x [9 or 3]. If it's larger than 0.0, this npy should be included." - doc_local_weight = "The prefactor of the weight of atomic loss. It should be larger than or equal to 0. If controls the weight of loss corresponding to atomic label, i.e. `atomic_polarizability.npy` or `atomic_dipole.npy`, whose shape should be #frames x ([9 or 3] x #selected atoms). If it's larger than 0.0, this npy should be included. Both `pref` and `pref_atomic` should be provided, and either can be set to 0.0." + doc_global_weight = "The prefactor of the weight of global loss. It should be larger than or equal to 0. It controls the weight of loss corresponding to global label, i.e. 'polarizability.npy` or `dipole.npy`, whose shape should be #frames x [9 or 3]. If it's larger than 0.0, this npy should be included." + doc_local_weight = "The prefactor of the weight of atomic loss. It should be larger than or equal to 0. It controls the weight of loss corresponding to atomic label, i.e. `atomic_polarizability.npy` or `atomic_dipole.npy`, whose shape should be #frames x ([9 or 3] x #atoms). If it's larger than 0.0, this npy should be included. Both `pref` and `pref_atomic` should be provided, and either can be set to 0.0." + doc_enable_atomic_weight = "If true, the atomic loss will be reweighted." return [ Argument( "pref", [float, int], optional=False, default=None, doc=doc_global_weight @@ -2524,6 +2525,13 @@ def loss_tensor(): default=None, doc=doc_local_weight, ), + Argument( + "enable_atomic_weight", + bool, + optional=True, + default=False, + doc=doc_enable_atomic_weight, + ), ]