Skip to content

Commit

Permalink
add ut for enable_atomic_weight
Browse files Browse the repository at this point in the history
  • Loading branch information
ChiahsinChu committed Dec 13, 2024
1 parent e5580c7 commit b4899a1
Show file tree
Hide file tree
Showing 2 changed files with 468 additions and 2 deletions.
6 changes: 4 additions & 2 deletions deepmd/tf/loss/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,10 @@ def build(self, learning_rate, natoms, model_dict, label_dict, suffix):
atomic_weight = global_cvt_2_tf_float(1.0)

if self.local_weight > 0.0:
diff = polar - atomic_polar_hat
diff = tf.reshape(diff, [-1, self.tensor_size]) * atomic_weight
diff = tf.reshape(polar, [-1, self.tensor_size]) - tf.reshape(
atomic_polar_hat, [-1, self.tensor_size]
)
diff = diff * atomic_weight
local_loss = global_cvt_2_tf_float(find_atomic) * tf.reduce_mean(
tf.square(self.scale * diff), name="l2_" + suffix
)
Expand Down
Loading

0 comments on commit b4899a1

Please sign in to comment.