Skip to content

Commit

Permalink
feat: add NumPy DeepPot (#3332)
Browse files Browse the repository at this point in the history
While a DPModel cannot be directly trained, it can be converted from
another model:
```sh
dp convert-backend frozen_model.pth frozen_model.dp
dp test -m frozen_model.dp -s ../data/
```
The energy result is consistent with TF and PT. Force and virial are
NaN, as expected.

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz authored Feb 24, 2024
1 parent 03ca9ab commit 15df69b
Show file tree
Hide file tree
Showing 5 changed files with 410 additions and 3 deletions.
8 changes: 6 additions & 2 deletions deepmd/backend/dpmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class DPModelBackend(Backend):
name = "DPModel"
"""The formal name of the backend."""
features: ClassVar[Backend.Feature] = (
Backend.Feature.NEIGHBOR_STAT | Backend.Feature.IO
Backend.Feature.DEEP_EVAL | Backend.Feature.NEIGHBOR_STAT | Backend.Feature.IO
)
"""The features of the backend."""
suffixes: ClassVar[List[str]] = [".dp"]
Expand Down Expand Up @@ -70,7 +70,11 @@ def deep_eval(self) -> Type["DeepEvalBackend"]:
type[DeepEvalBackend]
The Deep Eval backend of the backend.
"""
raise NotImplementedError(f"Unsupported backend: {self.name}")
from deepmd.dpmodel.infer.deep_eval import (
DeepEval,
)

return DeepEval

@property
def neighbor_stat(self) -> Type["NeighborStat"]:
Expand Down
1 change: 1 addition & 0 deletions deepmd/dpmodel/infer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
372 changes: 372 additions & 0 deletions deepmd/dpmodel/infer/deep_eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,372 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Tuple,
Type,
Union,
)

import numpy as np

from deepmd.dpmodel.model.dp_model import (
DPModel,
)
from deepmd.dpmodel.output_def import (
ModelOutputDef,
OutputVariableCategory,
OutputVariableDef,
)
from deepmd.dpmodel.utils.batch_size import (
AutoBatchSize,
)
from deepmd.dpmodel.utils.network import (
load_dp_model,
)
from deepmd.infer.deep_dipole import (
DeepDipole,
)
from deepmd.infer.deep_dos import (
DeepDOS,
)
from deepmd.infer.deep_eval import DeepEval as DeepEvalWrapper
from deepmd.infer.deep_eval import (
DeepEvalBackend,
)
from deepmd.infer.deep_polar import (
DeepPolar,
)
from deepmd.infer.deep_pot import (
DeepPot,
)
from deepmd.infer.deep_wfc import (
DeepWFC,
)

if TYPE_CHECKING:
import ase.neighborlist


class DeepEval(DeepEvalBackend):
    """NumPy backend implementation of DeepEval.

    Parameters
    ----------
    model_file : str
        The name of the frozen model file.
    output_def : ModelOutputDef
        The output definition of the model.
    *args : list
        Positional arguments.
    auto_batch_size : bool or int or AutoBatchSize, default: True
        If True, automatic batch size will be used. If False, the model is
        evaluated on the whole input at once. If int, it will be used
        as the initial batch size. An existing AutoBatchSize instance is
        used as is.
    neighbor_list : ase.neighborlist.NewPrimitiveNeighborList, optional
        The ASE neighbor list class to produce the neighbor list. If None, the
        neighbor list will be built natively in the model.
        NOTE: this backend currently accepts the argument but does not read it.
    **kwargs : dict
        Keyword arguments.

    Raises
    ------
    TypeError
        If ``auto_batch_size`` is not a bool, an int, or an AutoBatchSize.
    """

    def __init__(
        self,
        model_file: str,
        output_def: ModelOutputDef,
        *args: List[Any],
        auto_batch_size: Union[bool, int, AutoBatchSize] = True,
        neighbor_list: Optional["ase.neighborlist.NewPrimitiveNeighborList"] = None,
        **kwargs: Dict[str, Any],
    ):
        self.output_def = output_def
        self.model_path = model_file

        model_data = load_dp_model(model_file)
        self.dp = DPModel.deserialize(model_data["model"])
        self.rcut = self.dp.get_rcut()
        self.type_map = self.dp.get_type_map()
        # bool has to be tested before int because bool is a subclass of int.
        if isinstance(auto_batch_size, bool):
            self.auto_batch_size = AutoBatchSize() if auto_batch_size else None
        elif isinstance(auto_batch_size, int):
            self.auto_batch_size = AutoBatchSize(auto_batch_size)
        elif isinstance(auto_batch_size, AutoBatchSize):
            self.auto_batch_size = auto_batch_size
        else:
            raise TypeError("auto_batch_size should be bool, int, or AutoBatchSize")

    def get_rcut(self) -> float:
        """Get the cutoff radius of this model."""
        return self.rcut

    def get_ntypes(self) -> int:
        """Get the number of atom types of this model."""
        return len(self.type_map)

    def get_type_map(self) -> List[str]:
        """Get the type map (element name of the atom types) of this model."""
        return self.type_map

    def get_dim_fparam(self) -> int:
        """Get the number (dimension) of frame parameters of this DP."""
        return self.dp.get_dim_fparam()

    def get_dim_aparam(self) -> int:
        """Get the number (dimension) of atomic parameters of this DP."""
        return self.dp.get_dim_aparam()

    @property
    def model_type(self) -> Type["DeepEvalWrapper"]:
        """The evaluator class matching the output type of the model.

        Raises
        ------
        RuntimeError
            If the model reports an output type with no known evaluator.
        """
        evaluators = {
            "energy": DeepPot,
            "dos": DeepDOS,
            "dipole": DeepDipole,
            "polar": DeepPolar,
            "wfc": DeepWFC,
        }
        try:
            return evaluators[self.dp.model_output_type()]
        except KeyError:
            raise RuntimeError("Unknown model type") from None

    def get_sel_type(self) -> List[int]:
        """Get the selected atom types of this model.

        Only atoms with selected atom types have atomic contribution
        to the result of the model.
        If returning an empty list, all atom types are selected.
        """
        return self.dp.get_sel_type()

    def get_numb_dos(self) -> int:
        """Get the number of DOS (always 0 in this backend)."""
        return 0

    def get_has_efield(self) -> bool:
        """Check if the model has efield (always False in this backend)."""
        return False

    def get_ntypes_spin(self) -> int:
        """Get the number of spin atom types of this model (always 0 here)."""
        return 0

    def eval(
        self,
        coords: np.ndarray,
        cells: Optional[np.ndarray],
        atom_types: np.ndarray,
        atomic: bool = False,
        fparam: Optional[np.ndarray] = None,
        aparam: Optional[np.ndarray] = None,
        **kwargs: Dict[str, Any],
    ) -> Dict[str, np.ndarray]:
        """Evaluate the energy, force and virial by using this DP.

        Parameters
        ----------
        coords
            The coordinates of atoms.
            The array should be of size nframes x natoms x 3
        cells
            The cell of the region.
            If None then non-PBC is assumed, otherwise using PBC.
            The array should be of size nframes x 9
        atom_types
            The atom types
            The list should contain natoms ints
        atomic
            Calculate the atomic energy and virial
        fparam
            The frame parameter. Not supported by this backend yet;
            must be None.
        aparam
            The atomic parameter. Not supported by this backend yet;
            must be None.
        **kwargs
            Other parameters

        Returns
        -------
        output_dict : dict
            The output of the evaluation. The keys are the names of the output
            variables, and the values are the corresponding output arrays.

        Raises
        ------
        NotImplementedError
            If ``fparam`` or ``aparam`` is given.
        """
        if fparam is not None or aparam is not None:
            raise NotImplementedError(
                "fparam and aparam are not supported by the DPModel backend"
            )
        # convert all of the input to numpy array
        atom_types = np.array(atom_types, dtype=np.int32)
        coords = np.array(coords)
        if cells is not None:
            cells = np.array(cells)
        # A 2-D atom_types array means every frame carries its own types.
        mixed_type = len(atom_types.shape) > 1
        natoms, numb_test = self._get_natoms_and_nframes(
            coords, atom_types, mixed_type
        )
        request_defs = self._get_request_defs(atomic)
        out = self._eval_func(self._eval_model, numb_test, natoms)(
            coords, cells, atom_types, request_defs
        )
        return dict(zip([odef.name for odef in request_defs], out))

    def _get_request_defs(self, atomic: bool) -> List[OutputVariableDef]:
        """Get the requested output definitions.

        When atomic is True, all output_def are requested.
        When atomic is False, only energy (tensor), force, and virial
        are requested.

        Parameters
        ----------
        atomic : bool
            Whether to request the atomic output.

        Returns
        -------
        list[OutputVariableDef]
            The requested output definitions.
        """
        all_defs = list(self.output_def.var_defs.values())
        if atomic:
            return all_defs
        global_categories = (
            OutputVariableCategory.REDU,
            OutputVariableCategory.DERV_R,
            OutputVariableCategory.DERV_C_REDU,
        )
        return [odef for odef in all_defs if odef.category in global_categories]

    def _eval_func(self, inner_func: Callable, numb_test: int, natoms: int) -> Callable:
        """Wrapper method with auto batch size.

        Parameters
        ----------
        inner_func : Callable
            the method to be wrapped
        numb_test : int
            number of tests
        natoms : int
            number of atoms

        Returns
        -------
        Callable
            the wrapper; ``inner_func`` itself when auto batch size is disabled
        """
        if self.auto_batch_size is None:
            return inner_func

        def eval_func(*args, **kwargs):
            return self.auto_batch_size.execute_all(
                inner_func, numb_test, natoms, *args, **kwargs
            )

        return eval_func

    def _get_natoms_and_nframes(
        self,
        coords: np.ndarray,
        atom_types: np.ndarray,
        mixed_type: bool = False,
    ) -> Tuple[int, int]:
        """Infer the number of atoms and frames from the inputs.

        Parameters
        ----------
        coords : np.ndarray
            The coordinates of atoms.
        atom_types : np.ndarray
            The atom types; one row per frame when ``mixed_type`` is True.
        mixed_type : bool
            Whether ``atom_types`` is given per frame.

        Returns
        -------
        natoms : int
            The number of atoms in one frame.
        nframes : int
            The number of frames.
        """
        natoms = len(atom_types[0]) if mixed_type else len(atom_types)
        if natoms == 0:
            # Without atoms the frame count cannot be derived by reshaping;
            # the leading dimension of coords is used as is.
            assert coords.size == 0
        else:
            coords = np.reshape(np.array(coords), [-1, natoms * 3])
        nframes = coords.shape[0]
        return natoms, nframes

    def _eval_model(
        self,
        coords: np.ndarray,
        cells: Optional[np.ndarray],
        atom_types: np.ndarray,
        request_defs: List[OutputVariableDef],
    ):
        """Run the model on one batch and collect the requested outputs.

        Parameters
        ----------
        coords : np.ndarray
            The coordinates of the batch; the first axis is the frame axis.
        cells : np.ndarray or None
            The cells of the batch, or None for non-PBC.
        atom_types : np.ndarray
            The atom types, either shared by all frames (1-D) or per frame.
        request_defs : list[OutputVariableDef]
            The output definitions to return.

        Returns
        -------
        tuple[np.ndarray]
            One array per entry of ``request_defs``, in the same order.
            Outputs the model does not provide are filled with NaN.
        """
        model = self.dp

        nframes = coords.shape[0]
        if len(atom_types.shape) == 1:
            natoms = len(atom_types)
            # Broadcast the shared type list to every frame.
            atom_types = np.tile(atom_types, nframes).reshape(nframes, -1)
        else:
            natoms = len(atom_types[0])

        coord_input = coords.reshape([-1, natoms, 3])
        type_input = atom_types
        box_input = cells.reshape([-1, 3, 3]) if cells is not None else None

        # The atomic virial is the DERV_C output; DERV_C_REDU is the global
        # virial, which is requested even when atomic output is not wanted.
        do_atomic_virial = any(
            x.category == OutputVariableCategory.DERV_C for x in request_defs
        )
        batch_output = model(
            coord_input, type_input, box=box_input, do_atomic_virial=do_atomic_virial
        )
        if isinstance(batch_output, tuple):
            batch_output = batch_output[0]

        results = []
        for odef in request_defs:
            # No name conversion is needed between output_def and the model.
            dp_name = odef.name
            shape = self._get_output_shape(odef, nframes, natoms)
            value = batch_output[dp_name] if dp_name in batch_output else None
            if value is not None:
                results.append(value.reshape(shape))
            else:
                # Output missing or not computed: report NaN. abs() guards
                # against negative dimensions in the shape -- presumably
                # placeholder dims in odef.shape; TODO confirm.
                results.append(np.full([abs(dim) for dim in shape], np.nan))
        return tuple(results)

    def _get_output_shape(self, odef, nframes, natoms):
        """Get the shape an output array is reshaped to.

        Parameters
        ----------
        odef : OutputVariableDef
            The output definition.
        nframes : int
            The number of frames.
        natoms : int
            The number of atoms.

        Returns
        -------
        list[int]
            The target shape of the output.

        Raises
        ------
        RuntimeError
            If the category of ``odef`` is unknown.
        """
        if odef.category == OutputVariableCategory.DERV_C_REDU:
            # virial
            return [nframes, *odef.shape[:-1], 9]
        elif odef.category == OutputVariableCategory.REDU:
            # energy
            return [nframes, *odef.shape, 1]
        elif odef.category == OutputVariableCategory.DERV_C:
            # atom_virial
            return [nframes, *odef.shape[:-1], natoms, 9]
        elif odef.category == OutputVariableCategory.DERV_R:
            # force
            return [nframes, *odef.shape[:-1], natoms, 3]
        elif odef.category == OutputVariableCategory.OUT:
            # atom_energy, atom_tensor
            # NOTE(review): natoms precedes odef.shape here, unlike the
            # derivative outputs above where natoms follows it -- confirm
            # this layout is intended.
            return [nframes, natoms, *odef.shape, 1]
        else:
            raise RuntimeError("unknown category")
Loading

0 comments on commit 15df69b

Please sign in to comment.