From cb30f416a1af73eb4df591a3a307f3f0f159d5df Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 2 Apr 2024 05:19:21 -0400 Subject: [PATCH] fix(tf): make `se_atten_v2` masking smooth when davg is not zero (#3632) Currently, `se_atten_v2` is always masked to zero when `exclude_types` is given. However, for the no neighbor case, the placeholder for a virtual neighbor is `davg`. This causes discontinuity when `set_davg_zero` is not set. This PR uses `davg` for masking. In production, we usually use `set_davg_zero` along with `exclude_types`, so it hasn't caused a real problem. I notice PT hasn't implemented `se_atten_v2` or `exclude_types`, but we need attention in the future. --------- Signed-off-by: Jinzhe Zeng Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> (cherry picked from commit 63601b0f69ee5815702123b7ab863dbf4b4ee2bf) --- deepmd/descriptor/se_atten.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/deepmd/descriptor/se_atten.py b/deepmd/descriptor/se_atten.py index f9193675e2..8c1a179923 100644 --- a/deepmd/descriptor/se_atten.py +++ b/deepmd/descriptor/se_atten.py @@ -649,7 +649,17 @@ def _pass_filter( tf.shape(inputs_i)[0], self.nei_type_vec, # extra input for atten ) - inputs_i *= mask + if self.smooth: + inputs_i = tf.where( + tf.cast(mask, tf.bool), + inputs_i, + # (nframes * nloc, 1) -> (nframes * nloc, ndescrpt) + tf.tile( + tf.reshape(self.avg_looked_up, [-1, 1]), [1, self.ndescrpt] + ), + ) + else: + inputs_i *= mask if nvnmd_cfg.enable and nvnmd_cfg.quantize_descriptor: inputs_i = descrpt2r4(inputs_i, atype) layer, qmat = self._filter(