Skip to content

Commit

Permalink
Feat: add DipoleModel and PolarModel (#3309)
Browse files Browse the repository at this point in the history
This PR is to provide model wrappers for `DipoleFittingNet` and
`PolarFittingNet`, such that the saved model can be used directly in
inference with `DeepDiole` and `DeepPolar`.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
anyangml and pre-commit-ci[bot] authored Feb 21, 2024
1 parent 4956864 commit af14ba4
Show file tree
Hide file tree
Showing 16 changed files with 379 additions and 42 deletions.
36 changes: 25 additions & 11 deletions deepmd/dpmodel/atomic_model/make_base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,33 +118,47 @@ def serialize(self) -> dict:
def deserialize(cls):
pass

def do_grad(
def do_grad_r(
self,
var_name: Optional[str] = None,
) -> bool:
"""Tell if the output variable `var_name` is differentiable.
if var_name is None, returns if any of the variable is differentiable.
"""Tell if the output variable `var_name` is r_differentiable.
if var_name is None, returns if any of the variable is r_differentiable.
"""
odef = self.fitting_output_def()
if var_name is None:
require: List[bool] = []
for vv in odef.keys():
require.append(self.do_grad_(vv))
require.append(self.do_grad_(vv, "r"))
return any(require)
else:
return self.do_grad_(var_name)
return self.do_grad_(var_name, "r")

def do_grad_(
def do_grad_c(
self,
var_name: str,
var_name: Optional[str] = None,
) -> bool:
"""Tell if the output variable `var_name` is c_differentiable.
if var_name is None, returns if any of the variable is c_differentiable.
"""
odef = self.fitting_output_def()
if var_name is None:
require: List[bool] = []
for vv in odef.keys():
require.append(self.do_grad_(vv, "c"))
return any(require)
else:
return self.do_grad_(var_name, "c")

def do_grad_(self, var_name: str, base: str) -> bool:
"""Tell if the output variable `var_name` is differentiable."""
assert var_name is not None
return (
self.fitting_output_def()[var_name].r_differentiable
or self.fitting_output_def()[var_name].c_differentiable
)
assert base in ["c", "r"]
if base == "c":
return self.fitting_output_def()[var_name].c_differentiable
return self.fitting_output_def()[var_name].r_differentiable

setattr(BAM, fwd_method_name, BAM.fwd)
delattr(BAM, "fwd")
Expand Down
19 changes: 15 additions & 4 deletions deepmd/dpmodel/fitting/dipole_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,14 @@ class DipoleFitting(GeneralFitting):
mixed_types
If true, use a uniform fitting net for all atom types, otherwise use
different fitting nets for different atom types.
exclude_types: List[int]
exclude_types
Atomic contributions of the excluded atom types are set zero.
r_differentiable
If the variable is differentiated with respect to coordinates of atoms.
Only reduciable variable are differentiable.
c_differentiable
If the variable is differentiated with respect to the cell tensor (pbc case).
Only reduciable variable are differentiable.
"""

def __init__(
Expand All @@ -94,6 +99,8 @@ def __init__(
spin: Any = None,
mixed_types: bool = False,
exclude_types: List[int] = [],
r_differentiable: bool = True,
c_differentiable: bool = True,
old_impl=False,
):
# seed, uniform_seed are not included
Expand All @@ -109,6 +116,8 @@ def __init__(
raise NotImplementedError("atom_ener is not implemented")

self.embedding_width = embedding_width
self.r_differentiable = r_differentiable
self.c_differentiable = c_differentiable
super().__init__(
var_name=var_name,
ntypes=ntypes,
Expand Down Expand Up @@ -139,6 +148,8 @@ def serialize(self) -> dict:
data = super().serialize()
data["embedding_width"] = self.embedding_width
data["old_impl"] = self.old_impl
data["r_differentiable"] = self.r_differentiable
data["c_differentiable"] = self.c_differentiable
return data

def output_def(self):
Expand All @@ -148,8 +159,8 @@ def output_def(self):
self.var_name,
[3],
reduciable=True,
r_differentiable=True,
c_differentiable=True,
r_differentiable=self.r_differentiable,
c_differentiable=self.c_differentiable,
),
]
)
Expand Down
4 changes: 2 additions & 2 deletions deepmd/dpmodel/fitting/polarizability_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,8 @@ def output_def(self):
self.var_name,
[3, 3],
reduciable=True,
r_differentiable=True,
c_differentiable=True,
r_differentiable=False,
c_differentiable=False,
),
]
)
Expand Down
18 changes: 11 additions & 7 deletions deepmd/infer/deep_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ def eval(
**kwargs,
)
sel_natoms = self._get_sel_natoms(atom_types[0])
if sel_natoms == 0:
sel_natoms = atom_types.shape[-1] # set to natoms
if atomic:
return results[self.output_tensor_name].reshape(nframes, sel_natoms, -1)
else:
Expand Down Expand Up @@ -184,22 +186,24 @@ def eval_full(
aparam=aparam,
**kwargs,
)

sel_natoms = self._get_sel_natoms(atom_types[0])
if sel_natoms == 0:
sel_natoms = atom_types.shape[-1] # set to natoms
energy = results[f"{self.output_tensor_name}_redu"].reshape(nframes, -1)
force = results[f"{self.output_tensor_name}_derv_r"].reshape(
nframes, -1, natoms, 3
)
virial = results[f"{self.output_tensor_name}_derv_c_redu"].reshape(
nframes, -1, 9
)
atomic_energy = results[self.output_tensor_name].reshape(
nframes, sel_natoms, -1
)
atomic_virial = results[f"{self.output_tensor_name}_derv_c"].reshape(
nframes, -1, natoms, 9
)

if atomic:
atomic_energy = results[self.output_tensor_name].reshape(
nframes, sel_natoms, -1
)
atomic_virial = results[f"{self.output_tensor_name}_derv_c"].reshape(
nframes, -1, natoms, 9
)
return (
energy,
force,
Expand Down
3 changes: 3 additions & 0 deletions deepmd/pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,9 @@ def _eval_model(
shape = self._get_output_shape(odef, nframes, natoms)
out = batch_output[pt_name].reshape(shape).detach().cpu().numpy()
results.append(out)
else:
shape = self._get_output_shape(odef, nframes, natoms)
results.append(np.full(np.abs(shape), np.nan)) # this is kinda hacky
return tuple(results)

def _get_output_shape(self, odef, nframes, natoms):
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def forward_atomic(
"""
nframes, nloc, nnei = nlist.shape
atype = extended_atype[:, :nloc]
if self.do_grad():
if self.do_grad_r() or self.do_grad_c():
extended_coord.requires_grad_(True)
descriptor, rot_mat, g2, h2, sw = self.descriptor(
extended_coord,
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def forward_atomic(
the result dict, defined by the fitting net output def.
"""
nframes, nloc, nnei = nlist.shape
if self.do_grad():
if self.do_grad_r() or self.do_grad_c():
extended_coord.requires_grad_(True)
extended_coord = extended_coord.view(nframes, -1, 3)
sorted_rcuts, sorted_sels = self._sort_rcuts_sels()
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def forward_atomic(
) -> Dict[str, torch.Tensor]:
nframes, nloc, nnei = nlist.shape
extended_coord = extended_coord.view(nframes, -1, 3)
if self.do_grad():
if self.do_grad_r() or self.do_grad_c():
extended_coord.requires_grad_(True)

# this will mask all -1 in the nlist
Expand Down
91 changes: 91 additions & 0 deletions deepmd/pt/model/model/dipole_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Dict,
Optional,
)

import torch

from .dp_model import (
DPModel,
)


class DipoleModel(DPModel):
model_type = "dipole"

def __init__(
self,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)

def forward(
self,
coord,
atype,
box: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
) -> Dict[str, torch.Tensor]:
model_ret = self.forward_common(
coord,
atype,
box,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)
if self.fitting_net is not None:
model_predict = {}
model_predict["dipole"] = model_ret["dipole"]
model_predict["global_dipole"] = model_ret["dipole_redu"]
if self.do_grad_r("dipole"):
model_predict["force"] = model_ret["dipole_derv_r"].squeeze(-2)
if self.do_grad_c("dipole"):
model_predict["virial"] = model_ret["dipole_derv_c_redu"].squeeze(-2)
if do_atomic_virial:
model_predict["atom_virial"] = model_ret["dipole_derv_c"].squeeze(
-3
)
else:
model_predict = model_ret
model_predict["updated_coord"] += coord
return model_predict

def forward_lower(
self,
extended_coord,
extended_atype,
nlist,
mapping: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
):
model_ret = self.forward_common_lower(
extended_coord,
extended_atype,
nlist,
mapping,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)
if self.fitting_net is not None:
model_predict = {}
model_predict["dipole"] = model_ret["dipole"]
model_predict["global_dipole"] = model_ret["dipole_redu"]
if self.do_grad_r("dipole"):
model_predict["force"] = model_ret["dipole_derv_r"].squeeze(-2)
if self.do_grad_c("dipole"):
model_predict["virial"] = model_ret["dipole_derv_c_redu"].squeeze(-2)
if do_atomic_virial:
model_predict["atom_virial"] = model_ret["dipole_derv_c"].squeeze(
-3
)
else:
model_predict = model_ret
return model_predict
14 changes: 7 additions & 7 deletions deepmd/pt/model/model/dp_zbl_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,12 @@ def forward(
model_predict = {}
model_predict["atom_energy"] = model_ret["energy"]
model_predict["energy"] = model_ret["energy_redu"]
if self.do_grad("energy"):
if self.do_grad_r("energy"):
model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2)
if self.do_grad_c("energy"):
model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2)
if do_atomic_virial:
model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-3)
model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2)
else:
model_predict["force"] = model_ret["dforce"]
return model_predict
Expand Down Expand Up @@ -80,13 +81,12 @@ def forward_lower(
model_predict = {}
model_predict["atom_energy"] = model_ret["energy"]
model_predict["energy"] = model_ret["energy_redu"]
if self.do_grad("energy"):
model_predict["extended_force"] = model_ret["energy_derv_r"].squeeze(-2)
if self.do_grad_r("energy"):
model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2)
if self.do_grad_c("energy"):
model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2)
if do_atomic_virial:
model_predict["extended_virial"] = model_ret["energy_derv_c"].squeeze(
-2
)
model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-3)
else:
assert model_ret["dforce"] is not None
model_predict["dforce"] = model_ret["dforce"]
Expand Down
10 changes: 6 additions & 4 deletions deepmd/pt/model/model/ener_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,14 @@ def forward(
model_predict = {}
model_predict["atom_energy"] = model_ret["energy"]
model_predict["energy"] = model_ret["energy_redu"]
if self.do_grad("energy"):
if self.do_grad_r("energy"):
model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2)
if self.do_grad_c("energy"):
model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2)
if do_atomic_virial:
model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(
-3
)
model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2)
else:
model_predict["force"] = model_ret["dforce"]
else:
Expand Down Expand Up @@ -79,13 +80,14 @@ def forward_lower(
model_predict = {}
model_predict["atom_energy"] = model_ret["energy"]
model_predict["energy"] = model_ret["energy_redu"]
if self.do_grad("energy"):
if self.do_grad_r("energy"):
model_predict["extended_force"] = model_ret["energy_derv_r"].squeeze(-2)
if self.do_grad_c("energy"):
model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2)
if do_atomic_virial:
model_predict["extended_virial"] = model_ret[
"energy_derv_c"
].squeeze(-2)
].squeeze(-3)
else:
assert model_ret["dforce"] is not None
model_predict["dforce"] = model_ret["dforce"]
Expand Down
Loading

0 comments on commit af14ba4

Please sign in to comment.