Skip to content

Commit

Permalink
fix(pt): change fitting_attr variable scope reuse to AUTO_REUSE
Browse files Browse the repository at this point in the history
Fix deepmodeling#3928. Prevent `fitting_attr` from becoming `fitting_attr_1`.

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz committed Jun 29, 2024
1 parent 20aeaf8 commit 39bd6be
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion deepmd/tf/descriptor/se_a_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def build(
aparam[:, :] is the real/virtual sign for each atom.
"""
aparam = input_dict["aparam"]
with tf.variable_scope("fitting_attr" + suffix, reuse=reuse):
with tf.variable_scope("fitting_attr" + suffix, reuse=tf.AUTO_REUSE):
t_aparam_nall = tf.constant(True, name="aparam_nall", dtype=tf.bool)
self.mask = tf.cast(aparam, tf.int32)
self.mask = tf.reshape(self.mask, [-1, natoms[1]])
Expand Down
2 changes: 1 addition & 1 deletion deepmd/tf/fit/dos.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ def build(
if self.aparam_inv_std is None:
self.aparam_inv_std = 1.0

with tf.variable_scope("fitting_attr" + suffix, reuse=reuse):
with tf.variable_scope("fitting_attr" + suffix, reuse=tf.AUTO_REUSE):
t_dfparam = tf.constant(self.numb_fparam, name="dfparam", dtype=tf.int32)
t_daparam = tf.constant(self.numb_aparam, name="daparam", dtype=tf.int32)
t_numb_dos = tf.constant(self.numb_dos, name="numb_dos", dtype=tf.int32)
Expand Down
2 changes: 1 addition & 1 deletion deepmd/tf/fit/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ def build(
if "t_bias_atom_e" in nvnmd_cfg.weight.keys():
self.bias_atom_e = nvnmd_cfg.weight["t_bias_atom_e"]

with tf.variable_scope("fitting_attr" + suffix, reuse=reuse):
with tf.variable_scope("fitting_attr" + suffix, reuse=tf.AUTO_REUSE):
t_dfparam = tf.constant(self.numb_fparam, name="dfparam", dtype=tf.int32)
t_daparam = tf.constant(self.numb_aparam, name="daparam", dtype=tf.int32)
self.t_bias_atom_e = tf.get_variable(
Expand Down

0 comments on commit 39bd6be

Please sign in to comment.