Skip to content

Commit

Permalink
Fix uts
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Feb 28, 2024
1 parent 6c171c5 commit 2e87e1d
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 1 deletion.
4 changes: 4 additions & 0 deletions deepmd/pt/model/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,7 @@ def compute_or_load_stat(
The path to the statistics files.
"""
raise NotImplementedError

def data_requirement(self) -> dict:
"""Get the data requirement for the model."""
raise NotImplementedError
5 changes: 5 additions & 0 deletions source/tests/pt/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@
LearningRateExp,
)

from ..test_stat import (
energy_data_requirement,
)

VariableState = collections.namedtuple("VariableState", ["value", "gradient"])


Expand Down Expand Up @@ -281,6 +285,7 @@ def test_consistency(self):
"type_map": self.type_map,
},
)
my_ds.add_data_requirement(energy_data_requirement)
my_model = get_model(
model_params={
"descriptor": {
Expand Down
1 change: 0 additions & 1 deletion source/tests/pt/model/test_polarizability_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,6 @@ def setUp(self):
self.atype = torch.IntTensor([0, 0, 0, 1, 1], device="cpu")
self.dd0 = DescrptSeA(self.rcut, self.rcut_smth, self.sel).to(env.DEVICE)
self.ft0 = PolarFittingNet(
"polar",
self.nt,
self.dd0.dim_out,
embedding_width=self.dd0.get_dim_emb(),
Expand Down

0 comments on commit 2e87e1d

Please sign in to comment.