Skip to content

Commit

Permalink
fix(pt): fix get_dim for DescrptDPA1Compat (#4007)
Browse files Browse the repository at this point in the history
- [x] (Tomorrow) Test if it works for #3997. 

#3997 needs another fix in #4022 .

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit


- **New Features**
- Introduced a method to dynamically determine the output dimension of
the descriptor, enhancing its functionality and interaction with other
components.
- Improved tensor dimensionality handling in tests to ensure
compatibility with the new output dimension method.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Co-authored-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
iProzd and njzjz authored Jul 26, 2024
1 parent c335dcf commit f7aa626
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
17 changes: 16 additions & 1 deletion deepmd/tf/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,7 +765,14 @@ def _pass_filter(
type_embedding=type_embedding,
atype=atype,
)
layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[0], self.get_dim_out()])
layer = tf.reshape(
layer,
[
tf.shape(inputs)[0],
natoms[0],
self.filter_neuron[-1] * self.n_axis_neuron,
],
)
qmat = tf.reshape(
qmat, [tf.shape(inputs)[0], natoms[0], self.get_dim_rot_mat_1() * 3]
)
Expand Down Expand Up @@ -2194,6 +2201,14 @@ def __init__(
else:
self.embd_input_dim = 1

def get_dim_out(self) -> int:
"""Returns the output dimension of this descriptor."""
return (
super().get_dim_out() + self.tebd_dim
if self.concat_output_tebd
else super().get_dim_out()
)

def build(
self,
coord_: tf.Tensor,
Expand Down
2 changes: 2 additions & 0 deletions source/tests/consistent/descriptor/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ def build_tf_descriptor(self, obj, natoms, coords, atype, box, suffix):
{},
suffix=suffix,
)
# ensure get_dim_out gives the correct shape
t_des = tf.reshape(t_des, [1, natoms[0], obj.get_dim_out()])
return [t_des], {
t_coord: coords,
t_type: atype,
Expand Down

0 comments on commit f7aa626

Please sign in to comment.