Skip to content

Commit

Permalink
fix prec
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Nov 21, 2024
1 parent 65990e2 commit b5384dc
Show file tree
Hide file tree
Showing 12 changed files with 112 additions and 73 deletions.
12 changes: 11 additions & 1 deletion deepmd/pt/model/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
env,
)
from deepmd.pt.utils.env import (
PRECISION_DICT,
RESERVED_PRECISON_DICT,
)
from deepmd.pt.utils.update_sel import (
Expand Down Expand Up @@ -304,6 +305,7 @@ def __init__(
use_tebd_bias=use_tebd_bias,
type_map=type_map,
)
self.prec = PRECISION_DICT[precision]
self.tebd_dim = tebd_dim
self.concat_output_tebd = concat_output_tebd
self.trainable = trainable
Expand Down Expand Up @@ -593,6 +595,8 @@ def forward(
The smooth switch function. shape: nf x nloc x nnei
"""
# cast the input to internal precsion
extended_coord = extended_coord.to(dtype=self.prec)
del mapping
nframes, nloc, nnei = nlist.shape
nall = extended_coord.view(nframes, -1).shape[1] // 3
Expand All @@ -608,7 +612,13 @@ def forward(
if self.concat_output_tebd:
g1 = torch.cat([g1, g1_inp], dim=-1)

return g1, rot_mat, g2, h2, sw
return (
g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
g2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION) if g2 is not None else None,
h2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
)

@classmethod
def update_sel(
Expand Down
15 changes: 14 additions & 1 deletion deepmd/pt/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
from deepmd.pt.utils import (
env,
)
from deepmd.pt.utils.env import (
PRECISION_DICT,
)
from deepmd.pt.utils.nlist import (
build_multiple_neighbor_list,
get_multiple_nlist_key,
Expand Down Expand Up @@ -268,6 +271,7 @@ def init_subclass_params(sub_data, sub_class):
)
self.concat_output_tebd = concat_output_tebd
self.precision = precision
self.prec = PRECISION_DICT[self.precision]
self.smooth = smooth
self.exclude_types = exclude_types
self.env_protection = env_protection
Expand Down Expand Up @@ -744,6 +748,9 @@ def forward(
The smooth switch function. shape: nf x nloc x nnei
"""
# cast the input to internal precsion
extended_coord = extended_coord.to(dtype=self.prec)

use_three_body = self.use_three_body
nframes, nloc, nnei = nlist.shape
nall = extended_coord.view(nframes, -1).shape[1] // 3
Expand Down Expand Up @@ -809,7 +816,13 @@ def forward(
)
if self.concat_output_tebd:
g1 = torch.cat([g1, g1_inp], dim=-1)
return g1, rot_mat, g2, h2, sw
return (
g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
g2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
h2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
)

@classmethod
def update_sel(
Expand Down
12 changes: 6 additions & 6 deletions deepmd/pt/model/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
from deepmd.pt.utils import (
env,
)
from deepmd.pt.utils.env import (
PRECISION_DICT,
)
from deepmd.pt.utils.env_mat_stat import (
EnvMatStatSe,
)
Expand Down Expand Up @@ -240,6 +243,7 @@ def __init__(
self.reinit_exclude(exclude_types)
self.env_protection = env_protection
self.precision = precision
self.prec = PRECISION_DICT[precision]
self.trainable_ln = trainable_ln
self.ln_eps = ln_eps
self.epsilon = 1e-4
Expand Down Expand Up @@ -321,12 +325,8 @@ def __init__(
self.layers = torch.nn.ModuleList(layers)

wanted_shape = (self.ntypes, self.nnei, 4)
mean = torch.zeros(
wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
)
stddev = torch.ones(
wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
)
mean = torch.zeros(wanted_shape, dtype=self.prec, device=env.DEVICE)
stddev = torch.ones(wanted_shape, dtype=self.prec, device=env.DEVICE)
self.register_buffer("mean", mean)
self.register_buffer("stddev", stddev)
self.stats = None
Expand Down
18 changes: 15 additions & 3 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def __init__(
raise NotImplementedError("old implementation of spin is not supported.")
super().__init__()
self.type_map = type_map
self.prec = PRECISION_DICT[precision]
self.sea = DescrptBlockSeA(
rcut,
rcut_smth,
Expand Down Expand Up @@ -270,7 +271,18 @@ def forward(
The smooth switch function.
"""
return self.sea.forward(nlist, coord_ext, atype_ext, None, mapping)
# cast the input to internal precsion
coord_ext = coord_ext.to(dtype=self.prec)
g1, rot_mat, g2, h2, sw = self.sea.forward(
nlist, coord_ext, atype_ext, None, mapping
)
return (
g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
None,
None,
sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
)

def set_stat_mean_and_stddev(
self,
Expand Down Expand Up @@ -703,8 +715,8 @@ def forward(
result = result.view(nf, nloc, self.filter_neuron[-1] * self.axis_neuron)
rot_mat = rot_mat.view([nf, nloc] + list(rot_mat.shape[1:])) # noqa:RUF005
return (
result.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
result,
rot_mat,
None,
None,
sw,
Expand Down
8 changes: 2 additions & 6 deletions deepmd/pt/model/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,12 +231,8 @@ def __init__(
)

wanted_shape = (self.ntypes, self.nnei, 4)
mean = torch.zeros(
wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
)
stddev = torch.ones(
wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
)
mean = torch.zeros(wanted_shape, dtype=self.prec, device=env.DEVICE)
stddev = torch.ones(wanted_shape, dtype=self.prec, device=env.DEVICE)
self.register_buffer("mean", mean)
self.register_buffer("stddev", stddev)
self.tebd_dim_input = self.tebd_dim if self.type_one_side else self.tebd_dim * 2
Expand Down
7 changes: 4 additions & 3 deletions deepmd/pt/model/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,9 @@ def forward(
The smooth switch function.
"""
del mapping
# cast the input to internal precsion
coord_ext = coord_ext.to(dtype=self.prec)
del mapping, comm_dict
nf = nlist.shape[0]
nloc = nlist.shape[1]
atype = atype_ext[:, :nloc]
Expand All @@ -361,7 +363,6 @@ def forward(

assert self.filter_layers is not None
dmatrix = dmatrix.view(-1, self.nnei, 1)
dmatrix = dmatrix.to(dtype=self.prec)
nfnl = dmatrix.shape[0]
# pre-allocate a shape to pass jit
xyz_scatter = torch.zeros(
Expand Down Expand Up @@ -389,7 +390,7 @@ def forward(
None,
None,
None,
sw,
sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
)

def set_stat_mean_and_stddev(
Expand Down
17 changes: 14 additions & 3 deletions deepmd/pt/model/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def __init__(
raise NotImplementedError("old implementation of spin is not supported.")
super().__init__()
self.type_map = type_map
self.prec = PRECISION_DICT[precision]
self.seat = DescrptBlockSeT(
rcut,
rcut_smth,
Expand Down Expand Up @@ -300,7 +301,18 @@ def forward(
The smooth switch function.
"""
return self.seat.forward(nlist, coord_ext, atype_ext, None, mapping)
# cast the input to internal precsion
coord_ext = coord_ext.to(dtype=self.prec)
g1, rot_mat, g2, h2, sw = self.seat.forward(
nlist, coord_ext, atype_ext, None, mapping
)
return (
g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
None,
None,
None,
sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
)

def set_stat_mean_and_stddev(
self,
Expand Down Expand Up @@ -680,7 +692,6 @@ def forward(
protection=self.env_protection,
)
dmatrix = dmatrix.view(-1, self.nnei, 4)
dmatrix = dmatrix.to(dtype=self.prec)
nfnl = dmatrix.shape[0]
# pre-allocate a shape to pass jit
result = torch.zeros(
Expand Down Expand Up @@ -718,7 +729,7 @@ def forward(
# xyz_scatter /= (self.nnei * self.nnei)
result = result.view(nf, nloc, self.filter_neuron[-1])
return (
result.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
result,
None,
None,
None,
Expand Down
23 changes: 14 additions & 9 deletions deepmd/pt/model/descriptor/se_t_tebd.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def __init__(
)
self.use_econf_tebd = use_econf_tebd
self.type_map = type_map
self.prec = PRECISION_DICT[precision]
self.smooth = smooth
self.type_embedding = TypeEmbedNet(
ntypes,
Expand Down Expand Up @@ -441,12 +442,14 @@ def forward(
The smooth switch function. shape: nf x nloc x nnei
"""
# cast the input to internal precsion
extended_coord = extended_coord.to(dtype=self.prec)
del mapping
nframes, nloc, nnei = nlist.shape
nall = extended_coord.view(nframes, -1).shape[1] // 3
g1_ext = self.type_embedding(extended_atype)
g1_inp = g1_ext[:, :nloc, :]
g1, g2, h2, rot_mat, sw = self.se_ttebd(
g1, _, _, _, sw = self.se_ttebd(
nlist,
extended_coord,
extended_atype,
Expand All @@ -456,7 +459,13 @@ def forward(
if self.concat_output_tebd:
g1 = torch.cat([g1, g1_inp], dim=-1)

return g1, rot_mat, g2, h2, sw
return (
g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
None,
None,
None,
sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
)

@classmethod
def update_sel(
Expand Down Expand Up @@ -540,12 +549,8 @@ def __init__(
self.reinit_exclude(exclude_types)

wanted_shape = (self.ntypes, self.nnei, 4)
mean = torch.zeros(
wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
)
stddev = torch.ones(
wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
)
mean = torch.zeros(wanted_shape, dtype=self.prec, device=env.DEVICE)
stddev = torch.ones(wanted_shape, dtype=self.prec, device=env.DEVICE)
self.register_buffer("mean", mean)
self.register_buffer("stddev", stddev)
self.tebd_dim_input = self.tebd_dim * 2
Expand Down Expand Up @@ -845,7 +850,7 @@ def forward(
# nf x nl x ng
result = res_ij.view(nframes, nloc, self.filter_neuron[-1])
return (
result.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
result,
None,
None,
None,
Expand Down
2 changes: 0 additions & 2 deletions deepmd/pt/model/network/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,6 @@ def forward(
The output.
"""
ori_prec = xx.dtype
xx = xx.to(self.prec)
yy = (
torch.matmul(xx, self.matrix) + self.bias
if self.bias is not None
Expand All @@ -215,7 +214,6 @@ def forward(
yy += torch.concat([xx, xx], dim=-1)
else:
yy = yy
yy = yy.to(ori_prec)
return yy

def serialize(self) -> dict:
Expand Down
2 changes: 2 additions & 0 deletions deepmd/pt/model/task/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,8 @@ def forward(
]
# (nframes * nloc, 1, m1)
out = out.view(-1, 1, self.embedding_width)
# cast from global to gr precision again
out = out.to(dtype=gr.dtype)
# (nframes * nloc, m1, 3)
gr = gr.view(nframes * nloc, self.embedding_width, 3)
# (nframes, nloc, 3)
Expand Down
Loading

0 comments on commit b5384dc

Please sign in to comment.