Skip to content

Commit

Permalink
fix prec
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Nov 21, 2024
1 parent b5384dc commit 7597c5a
Show file tree
Hide file tree
Showing 7 changed files with 15 additions and 16 deletions.
2 changes: 0 additions & 2 deletions deepmd/pt/model/task/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,6 @@ def forward(
]
# (nframes * nloc, 1, m1)
out = out.view(-1, 1, self.embedding_width)
# cast from global to gr precision again
out = out.to(dtype=gr.dtype)
# (nframes * nloc, m1, 3)
gr = gr.view(nframes * nloc, self.embedding_width, 3)
# (nframes, nloc, 3)
Expand Down
10 changes: 5 additions & 5 deletions deepmd/pt/model/task/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,15 +489,15 @@ def _forward_common(

outs = torch.zeros(
(nf, nloc, net_dim_out),
dtype=env.GLOBAL_PT_FLOAT_PRECISION,
dtype=self.prec,
device=descriptor.device,
) # jit assertion
if self.mixed_types:
atom_property = self.filter_layers.networks[0](xx)
if xx_zeros is not None:
atom_property -= self.filter_layers.networks[0](xx_zeros)
outs = (
outs + atom_property + self.bias_atom_e[atype]
outs + atom_property + self.bias_atom_e[atype].to(self.prec)
) # Shape is [nframes, natoms[0], net_dim_out]
else:
for type_i, ll in enumerate(self.filter_layers.networks):
Expand All @@ -512,13 +512,13 @@ def _forward_common(
and not self.remove_vaccum_contribution[type_i]
):
atom_property -= ll(xx_zeros)
atom_property = atom_property + self.bias_atom_e[type_i]
atom_property = atom_property * mask
atom_property = atom_property + self.bias_atom_e[type_i].to(self.prec)
atom_property = torch.where(mask, atom_property, 0.0)
outs = (
outs + atom_property
) # Shape is [nframes, natoms[0], net_dim_out]
# nf x nloc
mask = self.emask(atype).to(torch.bool)
# nf x nloc x nod
outs = torch.where(mask[:, :, None], outs, 0.0)
return {self.var_name: outs.to(env.GLOBAL_PT_FLOAT_PRECISION)}
return {self.var_name: outs}
5 changes: 4 additions & 1 deletion deepmd/pt/model/task/invar_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,10 @@ def forward(
-------
- `torch.Tensor`: Total energy with shape [nframes, natoms[0]].
"""
return self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam)
out = self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam)[
self.var_name
]
return {self.var_name: out.to(env.GLOBAL_PT_FLOAT_PRECISION)}

# make jit happy with torch 2.0.0
exclude_types: list[int]
4 changes: 1 addition & 3 deletions deepmd/pt/model/task/polarizability.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,9 +240,7 @@ def forward(
out = self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam)[
self.var_name
]
out = out * (self.scale.to(atype.device))[atype]
# cast from global to gr precision again
out = out.to(dtype=gr.dtype)
out = out * (self.scale.to(atype.device).to(self.prec))[atype]

gr = gr.view(nframes * nloc, self.embedding_width, 3) # (nframes * nloc, m1, 3)

Expand Down
4 changes: 2 additions & 2 deletions source/tests/pt/model/test_dipole_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def test_permu(self):
nlist,
)

ret0 = ft0(rd0, atype, gr0, fparam=0, aparam=0)
ret0 = ft0(rd0, atype, gr0, fparam=None, aparam=None)
res.append(ret0["dipole"])

np.testing.assert_allclose(
Expand Down Expand Up @@ -303,7 +303,7 @@ def test_trans(self):
nlist,
)

ret0 = ft0(rd0, atype, gr0, fparam=0, aparam=0)
ret0 = ft0(rd0, atype, gr0, fparam=None, aparam=None)
res.append(ret0["dipole"])

np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1]))
Expand Down
2 changes: 1 addition & 1 deletion source/tests/pt/model/test_polarizability_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def test_trans(self):
nlist,
)

ret0 = ft0(rd0, atype, gr0, fparam=0, aparam=0)
ret0 = ft0(rd0, atype, gr0, fparam=None, aparam=None)
res.append(ret0["polarizability"])

np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1]))
Expand Down
4 changes: 2 additions & 2 deletions source/tests/pt/model/test_property_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def test_trans(self):
nlist,
)

ret0 = ft0(rd0, atype, gr0, fparam=0, aparam=0)
ret0 = ft0(rd0, atype, gr0, fparam=None, aparam=None)
res.append(ret0["property"])

np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1]))
Expand Down Expand Up @@ -399,7 +399,7 @@ def test_trans(self):
nlist,
)

ret0 = ft0(rd0, atype, gr0, fparam=0, aparam=0)
ret0 = ft0(rd0, atype, gr0, fparam=None, aparam=None)
res.append(ret0["property"])

np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1]))
Expand Down

0 comments on commit 7597c5a

Please sign in to comment.