Skip to content

Commit

Permalink
fix TypeError when type_map is not given (deepmodeling#2890)
Browse files Browse the repository at this point in the history
Fix deepmodeling#2889.

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz authored Oct 7, 2023
1 parent 14c9964 commit 47d985d
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 2 deletions.
7 changes: 7 additions & 0 deletions deepmd/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,11 @@ class DescrptSeAtten(DescrptSeA):
When using stripped type embedding, whether to dot smooth factor on the network output of type embedding
to keep the network smooth, instead of setting `set_davg_zero` to be True.
Default value will be True in `se_atten_v2` descriptor.
Raises
------
ValueError
if ntypes is 0.
"""

def __init__(
Expand Down Expand Up @@ -178,6 +183,8 @@ def __init__(
assert Version(TF_VERSION) > Version(
"2"
), "se_atten only support tensorflow version 2.0 or higher."
if ntypes == 0:
raise ValueError("`model/type_map` is not set or empty!")
self.stripped_type_embedding = stripped_type_embedding
self.smooth = smooth_type_embdding
self.ntypes = ntypes
Expand Down
2 changes: 1 addition & 1 deletion deepmd/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,7 @@ def __init__(
self.descrpt = descriptor
else:
self.descrpt = Descriptor(
**descriptor, ntypes=len(type_map), spin=self.spin
**descriptor, ntypes=len(self.get_type_map()), spin=self.spin
)

if isinstance(fitting_net, Fitting):
Expand Down
2 changes: 1 addition & 1 deletion deepmd/model/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def __init__(
else:
self.descrpt = Descriptor(
**descriptor,
ntypes=len(type_map),
ntypes=len(self.get_type_map()),
multi_task=True,
spin=self.spin,
)
Expand Down

0 comments on commit 47d985d

Please sign in to comment.