Skip to content

Commit

Permalink
fix(tf): throw RuntimeError for se_a + type_embedding (#3861)
Browse files Browse the repository at this point in the history
Fix #3541.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **New Features**
- Introduced a new attribute `use_tebd` to enhance serialization
handling based on input conditions.

- **Tests**
  - Added new assertions to improve error handling in the test suite.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz authored Jun 11, 2024
1 parent 7786126 commit 73dab63
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
13 changes: 9 additions & 4 deletions deepmd/tf/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,8 @@ def __init__(
self.stat_descrpt *= tf.reshape(mask, tf.shape(self.stat_descrpt))
self.sub_sess = tf.Session(graph=sub_graph, config=default_tf_session_config)
self.original_sel = None
# Whether type embedding is used
self.use_tebd: bool = False

def get_rcut(self) -> float:
"""Returns the cut-off radius."""
Expand Down Expand Up @@ -746,6 +748,8 @@ def _pass_filter(
):
if input_dict is not None:
type_embedding = input_dict.get("type_embedding", None)
if type_embedding is not None:
self.use_tebd = True
else:
type_embedding = None
if self.stripped_type_embedding and type_embedding is None:
Expand Down Expand Up @@ -1406,7 +1410,7 @@ def serialize(self, suffix: str = "") -> dict:
raise NotImplementedError(
"Serialization is unsupported when tebd_input_mode is set to 'strip'"
)
if (self.original_sel != self.sel_a).any():
if self.original_sel is not None and (self.original_sel != self.sel_a).any():
raise NotImplementedError(
"Adjusting sel is unsupported by the native model"
)
Expand All @@ -1416,9 +1420,10 @@ def serialize(self, suffix: str = "") -> dict:
raise NotImplementedError("spin is unsupported")
assert self.davg is not None
assert self.dstd is not None
# TODO: tf: handle type embedding in DescrptSeA.serialize
# not sure how to handle type embedding - type embedding is not a model parameter,
# but instead a part of the input data. Maybe the interface should be refactored...
if self.use_tebd:
raise RuntimeError(
"Serialization is unsupported when type_embedding is used."
)

return {
"@class": "Descriptor",
Expand Down
3 changes: 3 additions & 0 deletions source/tests/tf/test_model_se_a_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,3 +179,6 @@ def test_model(self):
np.testing.assert_almost_equal(e, refe, places)
np.testing.assert_almost_equal(f, reff, places)
np.testing.assert_almost_equal(v, refv, places)

with self.assertRaises(RuntimeError):
descrpt.serialize(suffix="se_a_type")

0 comments on commit 73dab63

Please sign in to comment.