From 5381bc65f94690f674fadab7988665dd72247cdf Mon Sep 17 00:00:00 2001 From: Jia-Xin Zhu Date: Thu, 24 Oct 2024 11:54:10 +0800 Subject: [PATCH] init var --- deepmd/tf/fit/ener.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/deepmd/tf/fit/ener.py b/deepmd/tf/fit/ener.py index 90fb0e090b..330ea57179 100644 --- a/deepmd/tf/fit/ener.py +++ b/deepmd/tf/fit/ener.py @@ -561,7 +561,7 @@ def build( trainable=False, initializer=tf.constant_initializer(self.fparam_inv_std), ) - if self.numb_aparam > 0: + if self.numb_aparam > 0 and not self.use_aparam_as_mask: t_aparam_avg = tf.get_variable( "t_aparam_avg", self.numb_aparam, @@ -576,6 +576,13 @@ def build( trainable=False, initializer=tf.constant_initializer(self.aparam_inv_std), ) + else: + t_aparam_avg = tf.zeros( + self.numb_aparam, dtype=GLOBAL_TF_FLOAT_PRECISION + ) + t_aparam_istd = tf.ones( + self.numb_aparam, dtype=GLOBAL_TF_FLOAT_PRECISION + ) inputs = tf.reshape(inputs, [-1, natoms[0], self.dim_descrpt]) if len(self.atom_ener): @@ -602,14 +609,7 @@ def build( fparam = (fparam - t_fparam_avg) * t_fparam_istd aparam = None - if self.numb_aparam > 0: - if self.use_aparam_as_mask: - t_aparam_avg = tf.zeros( - self.numb_aparam, dtype=GLOBAL_TF_FLOAT_PRECISION - ) - t_aparam_istd = tf.ones( - self.numb_aparam, dtype=GLOBAL_TF_FLOAT_PRECISION - ) + if self.numb_aparam > 0 and not self.use_aparam_as_mask: aparam = input_dict["aparam"] aparam = tf.reshape(aparam, [-1, self.numb_aparam]) aparam = (aparam - t_aparam_avg) * t_aparam_istd