diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index 617e8b49b6..71d1b3a30c 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -22,6 +22,7 @@ env, ) from deepmd.pt.utils.env import ( + PRECISION_DICT, RESERVED_PRECISON_DICT, ) from deepmd.pt.utils.update_sel import ( @@ -304,6 +305,7 @@ def __init__( use_tebd_bias=use_tebd_bias, type_map=type_map, ) + self.prec = PRECISION_DICT[precision] self.tebd_dim = tebd_dim self.concat_output_tebd = concat_output_tebd self.trainable = trainable @@ -593,6 +595,8 @@ def forward( The smooth switch function. shape: nf x nloc x nnei """ + # cast the input to internal precision + extended_coord = extended_coord.to(dtype=self.prec) del mapping nframes, nloc, nnei = nlist.shape nall = extended_coord.view(nframes, -1).shape[1] // 3 @@ -608,7 +612,13 @@ def forward( if self.concat_output_tebd: g1 = torch.cat([g1, g1_inp], dim=-1) - return g1, rot_mat, g2, h2, sw + return ( + g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + g2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION) if g2 is not None else None, + h2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + ) @classmethod def update_sel( diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index 7c47759a3f..4314893db8 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -27,6 +27,9 @@ from deepmd.pt.utils import ( env, ) +from deepmd.pt.utils.env import ( + PRECISION_DICT, +) from deepmd.pt.utils.nlist import ( build_multiple_neighbor_list, get_multiple_nlist_key, @@ -268,6 +271,7 @@ def init_subclass_params(sub_data, sub_class): ) self.concat_output_tebd = concat_output_tebd self.precision = precision + self.prec = PRECISION_DICT[self.precision] self.smooth = smooth self.exclude_types = exclude_types self.env_protection = env_protection @@ -744,6 +748,9 @@ def 
forward( The smooth switch function. shape: nf x nloc x nnei """ + # cast the input to internal precision + extended_coord = extended_coord.to(dtype=self.prec) + use_three_body = self.use_three_body nframes, nloc, nnei = nlist.shape nall = extended_coord.view(nframes, -1).shape[1] // 3 @@ -809,7 +816,13 @@ def forward( ) if self.concat_output_tebd: g1 = torch.cat([g1, g1_inp], dim=-1) - return g1, rot_mat, g2, h2, sw + return ( + g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + g2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + h2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + ) @classmethod def update_sel( diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index 8733234a5b..a6242c2cdb 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -22,6 +22,9 @@ from deepmd.pt.utils import ( env, ) +from deepmd.pt.utils.env import ( + PRECISION_DICT, +) from deepmd.pt.utils.env_mat_stat import ( EnvMatStatSe, ) @@ -240,6 +243,7 @@ def __init__( self.reinit_exclude(exclude_types) self.env_protection = env_protection self.precision = precision + self.prec = PRECISION_DICT[precision] self.trainable_ln = trainable_ln self.ln_eps = ln_eps self.epsilon = 1e-4 @@ -321,12 +325,8 @@ def __init__( self.layers = torch.nn.ModuleList(layers) wanted_shape = (self.ntypes, self.nnei, 4) - mean = torch.zeros( - wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE - ) - stddev = torch.ones( - wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE - ) + mean = torch.zeros(wanted_shape, dtype=self.prec, device=env.DEVICE) + stddev = torch.ones(wanted_shape, dtype=self.prec, device=env.DEVICE) self.register_buffer("mean", mean) self.register_buffer("stddev", stddev) self.stats = None diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index 
1b51acfa21..62df142935 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -97,6 +97,7 @@ def __init__( raise NotImplementedError("old implementation of spin is not supported.") super().__init__() self.type_map = type_map + self.prec = PRECISION_DICT[precision] self.sea = DescrptBlockSeA( rcut, rcut_smth, @@ -270,7 +271,18 @@ def forward( The smooth switch function. """ - return self.sea.forward(nlist, coord_ext, atype_ext, None, mapping) + # cast the input to internal precision + coord_ext = coord_ext.to(dtype=self.prec) + g1, rot_mat, g2, h2, sw = self.sea.forward( + nlist, coord_ext, atype_ext, None, mapping + ) + return ( + g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + None, + None, + sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + ) def set_stat_mean_and_stddev( self, @@ -703,8 +715,8 @@ def forward( result = result.view(nf, nloc, self.filter_neuron[-1] * self.axis_neuron) rot_mat = rot_mat.view([nf, nloc] + list(rot_mat.shape[1:])) # noqa:RUF005 return ( - result.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), - rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + result, + rot_mat, None, None, sw, diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index c760f7330b..30cb46d8b3 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -231,12 +231,8 @@ def __init__( ) wanted_shape = (self.ntypes, self.nnei, 4) - mean = torch.zeros( - wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE - ) - stddev = torch.ones( - wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE - ) + mean = torch.zeros(wanted_shape, dtype=self.prec, device=env.DEVICE) + stddev = torch.ones(wanted_shape, dtype=self.prec, device=env.DEVICE) self.register_buffer("mean", mean) self.register_buffer("stddev", stddev) self.tebd_dim_input = self.tebd_dim if self.type_one_side else self.tebd_dim * 2 diff 
--git a/deepmd/pt/model/descriptor/se_r.py b/deepmd/pt/model/descriptor/se_r.py index b873ee20b8..74f060ba3b 100644 --- a/deepmd/pt/model/descriptor/se_r.py +++ b/deepmd/pt/model/descriptor/se_r.py @@ -343,7 +343,9 @@ def forward( The smooth switch function. """ - del mapping + # cast the input to internal precision + coord_ext = coord_ext.to(dtype=self.prec) + del mapping, comm_dict nf = nlist.shape[0] nloc = nlist.shape[1] atype = atype_ext[:, :nloc] @@ -361,7 +363,6 @@ def forward( assert self.filter_layers is not None dmatrix = dmatrix.view(-1, self.nnei, 1) - dmatrix = dmatrix.to(dtype=self.prec) nfnl = dmatrix.shape[0] # pre-allocate a shape to pass jit xyz_scatter = torch.zeros( @@ -389,7 +390,7 @@ def forward( None, None, None, - sw, + sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), ) def set_stat_mean_and_stddev( diff --git a/deepmd/pt/model/descriptor/se_t.py b/deepmd/pt/model/descriptor/se_t.py index 072457b48f..1259081978 100644 --- a/deepmd/pt/model/descriptor/se_t.py +++ b/deepmd/pt/model/descriptor/se_t.py @@ -129,6 +129,7 @@ def __init__( raise NotImplementedError("old implementation of spin is not supported.") super().__init__() self.type_map = type_map + self.prec = PRECISION_DICT[precision] self.seat = DescrptBlockSeT( rcut, rcut_smth, @@ -300,7 +301,18 @@ def forward( The smooth switch function. 
""" - return self.seat.forward(nlist, coord_ext, atype_ext, None, mapping) + # cast the input to internal precsion + coord_ext = coord_ext.to(dtype=self.prec) + g1, rot_mat, g2, h2, sw = self.seat.forward( + nlist, coord_ext, atype_ext, None, mapping + ) + return ( + g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + None, + None, + None, + sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + ) def set_stat_mean_and_stddev( self, @@ -680,7 +692,6 @@ def forward( protection=self.env_protection, ) dmatrix = dmatrix.view(-1, self.nnei, 4) - dmatrix = dmatrix.to(dtype=self.prec) nfnl = dmatrix.shape[0] # pre-allocate a shape to pass jit result = torch.zeros( @@ -718,7 +729,7 @@ def forward( # xyz_scatter /= (self.nnei * self.nnei) result = result.view(nf, nloc, self.filter_neuron[-1]) return ( - result.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + result, None, None, None, diff --git a/deepmd/pt/model/descriptor/se_t_tebd.py b/deepmd/pt/model/descriptor/se_t_tebd.py index 437a464709..821e25fa9b 100644 --- a/deepmd/pt/model/descriptor/se_t_tebd.py +++ b/deepmd/pt/model/descriptor/se_t_tebd.py @@ -163,6 +163,7 @@ def __init__( ) self.use_econf_tebd = use_econf_tebd self.type_map = type_map + self.prec = PRECISION_DICT[precision] self.smooth = smooth self.type_embedding = TypeEmbedNet( ntypes, @@ -441,12 +442,14 @@ def forward( The smooth switch function. 
shape: nf x nloc x nnei """ + # cast the input to internal precision + extended_coord = extended_coord.to(dtype=self.prec) del mapping nframes, nloc, nnei = nlist.shape nall = extended_coord.view(nframes, -1).shape[1] // 3 g1_ext = self.type_embedding(extended_atype) g1_inp = g1_ext[:, :nloc, :] - g1, g2, h2, rot_mat, sw = self.se_ttebd( + g1, _, _, _, sw = self.se_ttebd( nlist, extended_coord, extended_atype, @@ -456,7 +459,13 @@ def forward( if self.concat_output_tebd: g1 = torch.cat([g1, g1_inp], dim=-1) - return g1, rot_mat, g2, h2, sw + return ( + g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + None, + None, + None, + sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + ) @classmethod def update_sel( @@ -540,12 +549,8 @@ def __init__( self.reinit_exclude(exclude_types) wanted_shape = (self.ntypes, self.nnei, 4) - mean = torch.zeros( - wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE - ) - stddev = torch.ones( - wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE - ) + mean = torch.zeros(wanted_shape, dtype=self.prec, device=env.DEVICE) + stddev = torch.ones(wanted_shape, dtype=self.prec, device=env.DEVICE) self.register_buffer("mean", mean) self.register_buffer("stddev", stddev) self.tebd_dim_input = self.tebd_dim * 2 @@ -845,7 +850,7 @@ def forward( # nf x nl x ng result = res_ij.view(nframes, nloc, self.filter_neuron[-1]) return ( - result.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + result, None, None, None, diff --git a/deepmd/pt/model/network/mlp.py b/deepmd/pt/model/network/mlp.py index f2137bd004..2b8383806b 100644 --- a/deepmd/pt/model/network/mlp.py +++ b/deepmd/pt/model/network/mlp.py @@ -200,7 +200,6 @@ def forward( The output. 
""" ori_prec = xx.dtype - xx = xx.to(self.prec) yy = ( torch.matmul(xx, self.matrix) + self.bias if self.bias is not None @@ -215,7 +214,6 @@ def forward( yy += torch.concat([xx, xx], dim=-1) else: yy = yy - yy = yy.to(ori_prec) return yy def serialize(self) -> dict: diff --git a/deepmd/pt/model/task/dipole.py b/deepmd/pt/model/task/dipole.py index 56b14677b9..5784a659bd 100644 --- a/deepmd/pt/model/task/dipole.py +++ b/deepmd/pt/model/task/dipole.py @@ -189,6 +189,8 @@ def forward( ] # (nframes * nloc, 1, m1) out = out.view(-1, 1, self.embedding_width) + # cast from global to gr precision again + out = out.to(dtype=gr.dtype) # (nframes * nloc, m1, 3) gr = gr.view(nframes * nloc, self.embedding_width, 3) # (nframes, nloc, 3) diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index 1827569a17..0505a12d3f 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -415,7 +415,11 @@ def _forward_common( fparam: Optional[torch.Tensor] = None, aparam: Optional[torch.Tensor] = None, ): - xx = descriptor + # cast the input to internal precsion + xx = descriptor.to(self.prec) + fparam = fparam.to(self.prec) if fparam is not None else None + aparam = aparam.to(self.prec) if aparam is not None else None + if self.remove_vaccum_contribution is not None: # TODO: compute the input for vaccm when remove_vaccum_contribution is set # Idealy, the input for vaccum should be computed; @@ -488,49 +492,33 @@ def _forward_common( dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=descriptor.device, ) # jit assertion - if self.old_impl: - assert self.filter_layers_old is not None - assert xx_zeros is None - if self.mixed_types: - atom_property = self.filter_layers_old[0](xx) + self.bias_atom_e[atype] - outs = outs + atom_property # Shape is [nframes, natoms[0], 1] - else: - for type_i, filter_layer in enumerate(self.filter_layers_old): - mask = atype == type_i - atom_property = filter_layer(xx) - atom_property = atom_property + 
self.bias_atom_e[type_i] - atom_property = atom_property * mask.unsqueeze(-1) - outs = outs + atom_property # Shape is [nframes, natoms[0], 1] + if self.mixed_types: + atom_property = self.filter_layers.networks[0](xx) + if xx_zeros is not None: + atom_property -= self.filter_layers.networks[0](xx_zeros) + outs = ( + outs + atom_property + self.bias_atom_e[atype] + ) # Shape is [nframes, natoms[0], net_dim_out] else: - if self.mixed_types: - atom_property = ( - self.filter_layers.networks[0](xx) + self.bias_atom_e[atype] - ) + for type_i, ll in enumerate(self.filter_layers.networks): + mask = (atype == type_i).unsqueeze(-1) + mask = torch.tile(mask, (1, 1, net_dim_out)) + atom_property = ll(xx) if xx_zeros is not None: - atom_property -= self.filter_layers.networks[0](xx_zeros) + # must assert, otherwise jit is not happy + assert self.remove_vaccum_contribution is not None + if not ( + len(self.remove_vaccum_contribution) > type_i + and not self.remove_vaccum_contribution[type_i] + ): + atom_property -= ll(xx_zeros) + atom_property = atom_property + self.bias_atom_e[type_i] + atom_property = atom_property * mask outs = ( outs + atom_property ) # Shape is [nframes, natoms[0], net_dim_out] - else: - for type_i, ll in enumerate(self.filter_layers.networks): - mask = (atype == type_i).unsqueeze(-1) - mask = torch.tile(mask, (1, 1, net_dim_out)) - atom_property = ll(xx) - if xx_zeros is not None: - # must assert, otherwise jit is not happy - assert self.remove_vaccum_contribution is not None - if not ( - len(self.remove_vaccum_contribution) > type_i - and not self.remove_vaccum_contribution[type_i] - ): - atom_property -= ll(xx_zeros) - atom_property = atom_property + self.bias_atom_e[type_i] - atom_property = atom_property * mask - outs = ( - outs + atom_property - ) # Shape is [nframes, natoms[0], net_dim_out] # nf x nloc - mask = self.emask(atype) + mask = self.emask(atype).to(torch.bool) # nf x nloc x nod - outs = outs * mask[:, :, None] + outs = torch.where(mask[:, 
:, None], outs, 0.0) return {self.var_name: outs.to(env.GLOBAL_PT_FLOAT_PRECISION)} diff --git a/deepmd/pt/model/task/polarizability.py b/deepmd/pt/model/task/polarizability.py index a16ab886d4..587d14e924 100644 --- a/deepmd/pt/model/task/polarizability.py +++ b/deepmd/pt/model/task/polarizability.py @@ -241,6 +241,9 @@ def forward( self.var_name ] out = out * (self.scale.to(atype.device))[atype] + # cast from global to gr precision again + out = out.to(dtype=gr.dtype) + gr = gr.view(nframes * nloc, self.embedding_width, 3) # (nframes * nloc, m1, 3) if self.fit_diag: