Skip to content

Commit

Permalink
init var
Browse files Browse the repository at this point in the history
  • Loading branch information
ChiahsinChu committed Oct 24, 2024
1 parent 610767d commit 5381bc6
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions deepmd/tf/fit/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,7 @@ def build(
trainable=False,
initializer=tf.constant_initializer(self.fparam_inv_std),
)
if self.numb_aparam > 0:
if self.numb_aparam > 0 and not self.use_aparam_as_mask:
t_aparam_avg = tf.get_variable(
"t_aparam_avg",
self.numb_aparam,
Expand All @@ -576,6 +576,13 @@ def build(
trainable=False,
initializer=tf.constant_initializer(self.aparam_inv_std),
)
else:
t_aparam_avg = tf.zeros(
self.numb_aparam, dtype=GLOBAL_TF_FLOAT_PRECISION
)
t_aparam_istd = tf.ones(
self.numb_aparam, dtype=GLOBAL_TF_FLOAT_PRECISION
)

inputs = tf.reshape(inputs, [-1, natoms[0], self.dim_descrpt])
if len(self.atom_ener):
Expand All @@ -602,14 +609,7 @@ def build(
fparam = (fparam - t_fparam_avg) * t_fparam_istd

aparam = None
if self.numb_aparam > 0:
if self.use_aparam_as_mask:
t_aparam_avg = tf.zeros(
self.numb_aparam, dtype=GLOBAL_TF_FLOAT_PRECISION
)
t_aparam_istd = tf.ones(
self.numb_aparam, dtype=GLOBAL_TF_FLOAT_PRECISION
)
if self.numb_aparam > 0 and not self.use_aparam_as_mask:
aparam = input_dict["aparam"]
aparam = tf.reshape(aparam, [-1, self.numb_aparam])
aparam = (aparam - t_aparam_avg) * t_aparam_istd
Expand Down

0 comments on commit 5381bc6

Please sign in to comment.