From 5cdef87fb781d21efa78e1e0775f07478e57d7bf Mon Sep 17 00:00:00 2001 From: Jia-Xin Zhu Date: Thu, 24 Oct 2024 11:12:49 +0800 Subject: [PATCH] resolve comments --- deepmd/dpmodel/fitting/general_fitting.py | 12 +++++++----- deepmd/pt/model/task/fitting.py | 12 +++++++----- deepmd/tf/fit/ener.py | 17 +++++++++-------- 3 files changed, 23 insertions(+), 18 deletions(-) diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index a63336566e..e0627dc11c 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -155,15 +155,17 @@ def __init__( self.fparam_inv_std = np.ones(self.numb_fparam) # pylint: disable=no-explicit-dtype else: self.fparam_avg, self.fparam_inv_std = None, None - if self.numb_aparam > 0 and not self.use_aparam_as_mask: + if self.numb_aparam > 0: self.aparam_avg = np.zeros(self.numb_aparam) # pylint: disable=no-explicit-dtype self.aparam_inv_std = np.ones(self.numb_aparam) # pylint: disable=no-explicit-dtype else: self.aparam_avg, self.aparam_inv_std = None, None # init networks - in_dim = self.dim_descrpt + self.numb_fparam - if not self.use_aparam_as_mask: - in_dim += self.numb_aparam + in_dim = ( + self.dim_descrpt + + self.numb_fparam + + (0 if self.use_aparam_as_mask else self.numb_aparam) + ) self.nets = NetworkCollection( 1 if not self.mixed_types else 0, self.ntypes, @@ -391,7 +393,7 @@ def _call_common( axis=-1, ) # check aparam dim, concate to input descriptor - if not self.use_aparam_as_mask and self.numb_aparam > 0: + if self.numb_aparam > 0 and not self.use_aparam_as_mask: assert aparam is not None, "aparam should not be None" if aparam.shape[-1] != self.numb_aparam: raise ValueError( diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index 4dfd2e38b7..90a3a2e2f7 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -198,7 +198,7 @@ def __init__( ) else: self.fparam_avg, self.fparam_inv_std = None, None - if not self.use_aparam_as_mask and self.numb_aparam > 0: + if self.numb_aparam > 0: self.register_buffer( "aparam_avg", torch.zeros(self.numb_aparam, dtype=self.prec, device=device), @@ -210,9 +210,11 @@ def __init__( else: self.aparam_avg, self.aparam_inv_std = None, None - in_dim = self.dim_descrpt + self.numb_fparam - if not self.use_aparam_as_mask: - in_dim += self.numb_aparam + in_dim = ( + self.dim_descrpt + + self.numb_fparam + + (0 if self.use_aparam_as_mask else self.numb_aparam) + ) self.filter_layers = NetworkCollection( 1 if not self.mixed_types else 0, @@ -444,7 +446,7 @@ def _forward_common( dim=-1, ) # check aparam dim, concate to input descriptor - if not self.use_aparam_as_mask and self.numb_aparam > 0: + if self.numb_aparam > 0 and not self.use_aparam_as_mask: assert aparam is not None, "aparam should not be None" assert self.aparam_avg is not None assert self.aparam_inv_std is not None diff --git a/deepmd/tf/fit/ener.py b/deepmd/tf/fit/ener.py index fbf77a228d..de22ff6311 100644 --- a/deepmd/tf/fit/ener.py +++ b/deepmd/tf/fit/ener.py @@ -340,7 +340,7 @@ def compute_input_stats(self, all_stat: dict, protection: float = 1e-2) -> None: self.fparam_std[ii] = protection self.fparam_inv_std = 1.0 / self.fparam_std # stat aparam - if self.numb_aparam > 0 and not self.use_aparam_as_mask: + if self.numb_aparam > 0: sys_sumv = [] sys_sumv2 = [] sys_sumn = [] @@ -505,7 +505,7 @@ def build( self.fparam_avg = 0.0 if self.fparam_inv_std is None: self.fparam_inv_std = 1.0 - if self.numb_aparam > 0 and not self.use_aparam_as_mask: + if self.numb_aparam > 0: if self.aparam_avg is None: self.aparam_avg = 0.0 if self.aparam_inv_std is None: @@ -561,7 +561,7 @@ def build( trainable=False, initializer=tf.constant_initializer(self.fparam_inv_std), ) - if self.numb_aparam > 0 and not self.use_aparam_as_mask: + if self.numb_aparam > 0: t_aparam_avg = tf.get_variable( "t_aparam_avg", self.numb_aparam, @@ -602,7 +602,7 @@ def build( fparam = (fparam - t_fparam_avg) * t_fparam_istd aparam = None - if not self.use_aparam_as_mask and self.numb_aparam > 0: + if self.numb_aparam > 0 and not self.use_aparam_as_mask: aparam = input_dict["aparam"] aparam = tf.reshape(aparam, [-1, self.numb_aparam]) aparam = (aparam - t_aparam_avg) * t_aparam_istd @@ -895,9 +895,6 @@ def serialize(self, suffix: str = "") -> dict: dict The serialized data """ - in_dim = self.dim_descrpt + self.numb_fparam - if not self.use_aparam_as_mask: - in_dim += self.numb_aparam data = { "@class": "Fitting", "type": "ener", @@ -924,7 +921,11 @@ def serialize(self, suffix: str = "") -> dict: "nets": self.serialize_network( ntypes=self.ntypes, ndim=0 if self.mixed_types else 1, - in_dim=in_dim, + in_dim=( + self.dim_descrpt + + self.numb_fparam + + (0 if self.use_aparam_as_mask else self.numb_aparam) + ), neuron=self.n_neuron, activation_function=self.activation_function_name, resnet_dt=self.resnet_dt,