diff --git a/deepmd/pt/model/task/dipole.py b/deepmd/pt/model/task/dipole.py index 5784a659bd..56b14677b9 100644 --- a/deepmd/pt/model/task/dipole.py +++ b/deepmd/pt/model/task/dipole.py @@ -189,8 +189,6 @@ def forward( ] # (nframes * nloc, 1, m1) out = out.view(-1, 1, self.embedding_width) - # cast from global to gr precision again - out = out.to(dtype=gr.dtype) # (nframes * nloc, m1, 3) gr = gr.view(nframes * nloc, self.embedding_width, 3) # (nframes, nloc, 3) diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index 0505a12d3f..55cc665e24 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -489,7 +489,7 @@ def _forward_common( outs = torch.zeros( (nf, nloc, net_dim_out), - dtype=env.GLOBAL_PT_FLOAT_PRECISION, + dtype=self.prec, device=descriptor.device, ) # jit assertion if self.mixed_types: @@ -497,7 +497,7 @@ def _forward_common( if xx_zeros is not None: atom_property -= self.filter_layers.networks[0](xx_zeros) outs = ( - outs + atom_property + self.bias_atom_e[atype] + outs + atom_property + self.bias_atom_e[atype].to(self.prec) ) # Shape is [nframes, natoms[0], net_dim_out] else: for type_i, ll in enumerate(self.filter_layers.networks): @@ -512,8 +512,8 @@ def _forward_common( and not self.remove_vaccum_contribution[type_i] ): atom_property -= ll(xx_zeros) - atom_property = atom_property + self.bias_atom_e[type_i] - atom_property = atom_property * mask + atom_property = atom_property + self.bias_atom_e[type_i].to(self.prec) + atom_property = torch.where(mask, atom_property, 0.0) outs = ( outs + atom_property ) # Shape is [nframes, natoms[0], net_dim_out] @@ -521,4 +521,4 @@ def _forward_common( mask = self.emask(atype).to(torch.bool) # nf x nloc x nod outs = torch.where(mask[:, :, None], outs, 0.0) - return {self.var_name: outs.to(env.GLOBAL_PT_FLOAT_PRECISION)} + return {self.var_name: outs} diff --git a/deepmd/pt/model/task/invar_fitting.py b/deepmd/pt/model/task/invar_fitting.py index 230046b74b..42ce59f575 100644 --- a/deepmd/pt/model/task/invar_fitting.py +++ b/deepmd/pt/model/task/invar_fitting.py @@ -175,7 +175,10 @@ def forward( ------- - `torch.Tensor`: Total energy with shape [nframes, natoms[0]]. """ - return self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam) + out = self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam)[ + self.var_name + ] + return {self.var_name: out.to(env.GLOBAL_PT_FLOAT_PRECISION)} # make jit happy with torch 2.0.0 exclude_types: list[int] diff --git a/deepmd/pt/model/task/polarizability.py b/deepmd/pt/model/task/polarizability.py index 587d14e924..645293c1df 100644 --- a/deepmd/pt/model/task/polarizability.py +++ b/deepmd/pt/model/task/polarizability.py @@ -240,9 +240,7 @@ def forward( out = self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam)[ self.var_name ] - out = out * (self.scale.to(atype.device))[atype] - # cast from global to gr precision again - out = out.to(dtype=gr.dtype) + out = out * (self.scale.to(atype.device).to(self.prec))[atype] gr = gr.view(nframes * nloc, self.embedding_width, 3) # (nframes * nloc, m1, 3) diff --git a/source/tests/pt/model/test_dipole_fitting.py b/source/tests/pt/model/test_dipole_fitting.py index 71da2781ac..0c4121f457 100644 --- a/source/tests/pt/model/test_dipole_fitting.py +++ b/source/tests/pt/model/test_dipole_fitting.py @@ -262,7 +262,7 @@ def test_permu(self): nlist, ) - ret0 = ft0(rd0, atype, gr0, fparam=0, aparam=0) + ret0 = ft0(rd0, atype, gr0, fparam=None, aparam=None) res.append(ret0["dipole"]) np.testing.assert_allclose( @@ -303,7 +303,7 @@ def test_trans(self): nlist, ) - ret0 = ft0(rd0, atype, gr0, fparam=0, aparam=0) + ret0 = ft0(rd0, atype, gr0, fparam=None, aparam=None) res.append(ret0["dipole"]) np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1])) diff --git a/source/tests/pt/model/test_polarizability_fitting.py b/source/tests/pt/model/test_polarizability_fitting.py index 1ca563a8c2..4e63145741 100644 --- a/source/tests/pt/model/test_polarizability_fitting.py +++ b/source/tests/pt/model/test_polarizability_fitting.py @@ -326,7 +326,7 @@ def test_trans(self): nlist, ) - ret0 = ft0(rd0, atype, gr0, fparam=0, aparam=0) + ret0 = ft0(rd0, atype, gr0, fparam=None, aparam=None) res.append(ret0["polarizability"]) np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1])) diff --git a/source/tests/pt/model/test_property_fitting.py b/source/tests/pt/model/test_property_fitting.py index dfe2725f3b..ad5f3687e9 100644 --- a/source/tests/pt/model/test_property_fitting.py +++ b/source/tests/pt/model/test_property_fitting.py @@ -228,7 +228,7 @@ def test_trans(self): nlist, ) - ret0 = ft0(rd0, atype, gr0, fparam=0, aparam=0) + ret0 = ft0(rd0, atype, gr0, fparam=None, aparam=None) res.append(ret0["property"]) np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1])) @@ -399,7 +399,7 @@ def test_trans(self): nlist, ) - ret0 = ft0(rd0, atype, gr0, fparam=0, aparam=0) + ret0 = ft0(rd0, atype, gr0, fparam=None, aparam=None) res.append(ret0["property"]) np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1]))