Skip to content

Commit

Permalink
Update test_fitting_net.py
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Jan 30, 2024
1 parent f174590 commit a4892b7
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions source/tests/pt/test_fitting_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
from deepmd.tf.fit.ener import (
EnerFitting,
)
from deepmd.pt.utils import (
env,
)


class FakeDescriptor:
Expand Down Expand Up @@ -105,7 +108,7 @@ def test_consistency(self):
neuron=self.n_neuron,
bias_atom_e=self.dp_fn.bias_atom_e,
distinguish_types=True,
)
).to(env.DEVICE)
for name, param in my_fn.named_parameters():
matched = re.match(
"filter_layers\.networks\.(\d).layers\.(\d)\.([a-z]+)", name
Expand All @@ -129,9 +132,9 @@ def test_consistency(self):
embedding = torch.from_numpy(self.embedding)
embedding = embedding.view(4, -1, self.embedding_width)
atype = torch.from_numpy(self.atype)
ret = my_fn(embedding, atype)
ret = my_fn(embedding.to(env.DEVICE), atype.to(env.DEVICE))
my_energy = ret["energy"]
my_energy = my_energy.detach()
my_energy = my_energy.detach().cpu()
np.testing.assert_allclose(dp_energy, my_energy.numpy().reshape([-1]))


Expand Down

0 comments on commit a4892b7

Please sign in to comment.