diff --git a/source/tests/pt/test_fitting_net.py b/source/tests/pt/test_fitting_net.py index ed2c428de5..0390043770 100644 --- a/source/tests/pt/test_fitting_net.py +++ b/source/tests/pt/test_fitting_net.py @@ -17,6 +17,9 @@ from deepmd.tf.fit.ener import ( EnerFitting, ) +from deepmd.pt.utils import ( + env, +) class FakeDescriptor: @@ -105,7 +108,7 @@ def test_consistency(self): neuron=self.n_neuron, bias_atom_e=self.dp_fn.bias_atom_e, distinguish_types=True, - ) + ).to(env.DEVICE) for name, param in my_fn.named_parameters(): matched = re.match( "filter_layers\.networks\.(\d).layers\.(\d)\.([a-z]+)", name @@ -129,9 +132,9 @@ def test_consistency(self): embedding = torch.from_numpy(self.embedding) embedding = embedding.view(4, -1, self.embedding_width) atype = torch.from_numpy(self.atype) - ret = my_fn(embedding, atype) + ret = my_fn(embedding.to(env.DEVICE), atype.to(env.DEVICE)) my_energy = ret["energy"] - my_energy = my_energy.detach() + my_energy = my_energy.detach().cpu() np.testing.assert_allclose(dp_energy, my_energy.numpy().reshape([-1]))